feat(mqtt): implement PUBLISH QoS 0, SUBSCRIBE, and UNSUBSCRIBE handlers

Add ParsePub, ParseSubsOrUnsubs, ProcessPub (QoS 0), ProcessSubs,
ProcessUnsubs, EnqueueSubAck, and EnqueueUnsubAck to MqttPacketHandlers.
Wire PUB/SUB/UNSUB dispatch cases in MqttParser. Add ReadSlice to
MqttReader for raw payload extraction. 18 new unit tests covering
parsing, flags, error cases, QoS downgrade, and full flow. 1 new
integration test verifying SUBSCRIBE handshake over TCP.
This commit is contained in:
Joseph Doherty
2026-03-01 16:04:37 -05:00
parent 95cf20b00b
commit 715367b9ea
7 changed files with 947 additions and 21 deletions

View File

@@ -24,7 +24,7 @@ public sealed class MqttParserTests
/// </summary>
private static ClientConnection CreateMqttClient()
{
var c = new ClientConnection(ClientKind.Client);
var c = new ClientConnection(ClientKind.Client, nc: new MemoryStream());
c.InitMqtt(new MqttHandler());
return c;
}
@@ -164,37 +164,38 @@ public sealed class MqttParserTests
[Fact]
public void Parse_SingleByteRemainingLength_ShouldWork()
{
// SUB packet with remaining length = 5 (single byte < 128)
// After CONNECT is received, a SUB packet should parse the remaining length correctly.
// SUBSCRIBE with remaining length = 6 (single byte < 128).
// Proves single-byte remaining length decoding works.
var c = CreateMqttClient();
c.Flags |= ClientFlags.ConnectReceived;
// SUBSCRIBE: type=0x82, remaining len=5, then 5 bytes of payload
var buf = new byte[] { 0x82, 0x05, 0x00, 0x01, 0x00, 0x01, 0x74 };
// SUBSCRIBE: type=0x82, remlen=6, PI=1, filter="t" (len=1), QoS=0
var buf = new byte[] { 0x82, 0x06, 0x00, 0x01, 0x00, 0x01, (byte)'t', 0x00 };
var err = MqttParser.Parse(c, buf, buf.Length);
// Will hit NotImplementedException for SUBSCRIBE — that's fine, it proves parsing worked.
err.ShouldNotBeNull();
err.ShouldBeOfType<NotImplementedException>();
err.ShouldBeNull();
}
[Fact]
public void Parse_TwoByteRemainingLength_ShouldWork()
{
// PUBLISH QoS 0 with remaining length = 200 → encoded as [0xC8, 0x01].
// Proves two-byte remaining length decoding works.
var c = CreateMqttClient();
c.Flags |= ClientFlags.ConnectReceived;
// Remaining length = 200 → encoded as [0xC8, 0x01]
// (200 & 0x7F) | 0x80 = 0xC8, 200 >> 7 = 1 → 0x01
// type(1) + remlen(2) + payload(200) = 203 bytes total.
var buf = new byte[203];
buf[0] = MqttPacket.Pub;
buf[0] = MqttPacket.Pub; // 0x30, QoS 0
buf[1] = 0xC8;
buf[2] = 0x01;
// Remaining 200 bytes are zero (payload).
// Topic "t": length prefix (2 bytes) + 1 byte.
buf[3] = 0x00;
buf[4] = 0x01;
buf[5] = (byte)'t';
// Bytes 6..202 are zero (197-byte payload).
var err = MqttParser.Parse(c, buf, buf.Length);
err.ShouldNotBeNull();
err.ShouldBeOfType<NotImplementedException>(); // PUBLISH not yet implemented
err.ShouldBeNull();
}
// =========================================================================

View File

@@ -0,0 +1,530 @@
// Copyright 2020-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
using System.Text;
using Shouldly;
using ZB.MOM.NatsNet.Server;
using ZB.MOM.NatsNet.Server.Internal;
using ZB.MOM.NatsNet.Server.Mqtt;
namespace ZB.MOM.NatsNet.Server.Tests.Mqtt;
/// <summary>
/// Unit tests for MQTT PUBLISH, SUBSCRIBE, and UNSUBSCRIBE packet handling.
/// </summary>
public sealed class MqttPubSubTests
{
private static ClientConnection CreateConnectedMqttClient()
{
var ms = new MemoryStream();
var c = new ClientConnection(ClientKind.Client, nc: ms);
c.InitMqtt(new MqttHandler());
c.Flags |= ClientFlags.ConnectReceived;
return c;
}
private static MemoryStream GetStream(ClientConnection c)
{
return (MemoryStream)typeof(ClientConnection)
.GetField("_nc", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)!
.GetValue(c)!;
}
// =========================================================================
// PUBLISH parsing
// =========================================================================
[Fact]
public void ParsePub_QoS0_ShouldParseCorrectly()
{
var r = new MqttReader();
// Topic "test/topic" + payload "hello"
var topic = Encoding.UTF8.GetBytes("test/topic");
var payload = Encoding.UTF8.GetBytes("hello");
var data = new List<byte>();
data.Add((byte)(topic.Length >> 8));
data.Add((byte)(topic.Length & 0xFF));
data.AddRange(topic);
data.AddRange(payload);
r.Reset(data.ToArray());
byte flags = 0x00; // QoS 0, no retain, no dup
var (pp, err) = MqttPacketHandlers.ParsePub(r, data.Count, flags, rejectQoS2: false);
err.ShouldBeNull();
pp.ShouldNotBeNull();
pp!.Topic.ShouldBe("test/topic");
pp.Subject.ShouldNotBeEmpty();
pp.Qos.ShouldBe((byte)0);
pp.Pi.ShouldBe((ushort)0);
pp.Retain.ShouldBeFalse();
pp.Dup.ShouldBeFalse();
Encoding.UTF8.GetString(pp.Msg!).ShouldBe("hello");
}
[Fact]
public void ParsePub_QoS1_ShouldParsePacketId()
{
var r = new MqttReader();
var topic = Encoding.UTF8.GetBytes("a/b");
var data = new List<byte>();
data.Add((byte)(topic.Length >> 8));
data.Add((byte)(topic.Length & 0xFF));
data.AddRange(topic);
data.Add(0x00); data.Add(0x07); // PI = 7
data.AddRange(Encoding.UTF8.GetBytes("msg"));
r.Reset(data.ToArray());
byte flags = MqttPubFlag.QoS1; // QoS 1 = 0x02
var (pp, err) = MqttPacketHandlers.ParsePub(r, data.Count, flags, rejectQoS2: false);
err.ShouldBeNull();
pp!.Qos.ShouldBe((byte)1);
pp.Pi.ShouldBe((ushort)7);
}
[Fact]
public void ParsePub_QoS2Rejected_ShouldReturnError()
{
var r = new MqttReader();
var topic = Encoding.UTF8.GetBytes("t");
var data = new List<byte>();
data.Add(0x00); data.Add(0x01);
data.AddRange(topic);
data.Add(0x00); data.Add(0x01); // PI = 1
r.Reset(data.ToArray());
byte flags = MqttPubFlag.QoS2; // QoS 2 = 0x04
var (pp, err) = MqttPacketHandlers.ParsePub(r, data.Count, flags, rejectQoS2: true);
err.ShouldNotBeNull();
err.Message.ShouldContain("QoS-2 PUBLISH rejected");
}
[Fact]
public void ParsePub_EmptyTopic_ShouldReturnError()
{
var r = new MqttReader();
var data = new byte[] { 0x00, 0x00 }; // zero-length topic
r.Reset(data);
var (pp, err) = MqttPacketHandlers.ParsePub(r, data.Length, 0x00, rejectQoS2: false);
err.ShouldNotBeNull();
err.Message.ShouldContain("empty topic");
}
[Fact]
public void ParsePub_RetainAndDup_ShouldSetFlags()
{
var r = new MqttReader();
var topic = Encoding.UTF8.GetBytes("t");
var data = new List<byte>();
data.Add(0x00); data.Add(0x01);
data.AddRange(topic);
r.Reset(data.ToArray());
byte flags = MqttPubFlag.Retain | MqttPubFlag.Dup; // 0x09
var (pp, err) = MqttPacketHandlers.ParsePub(r, data.Count, flags, rejectQoS2: false);
err.ShouldBeNull();
pp!.Retain.ShouldBeTrue();
pp.Dup.ShouldBeTrue();
}
[Fact]
public void ParsePub_EmptyPayload_ShouldSucceed()
{
var r = new MqttReader();
var topic = Encoding.UTF8.GetBytes("t");
var data = new List<byte>();
data.Add(0x00); data.Add(0x01);
data.AddRange(topic);
// No payload bytes.
r.Reset(data.ToArray());
var (pp, err) = MqttPacketHandlers.ParsePub(r, data.Count, 0x00, rejectQoS2: false);
err.ShouldBeNull();
pp!.Msg.ShouldBeNull();
}
// =========================================================================
// PUBLISH processing via parser
// =========================================================================
[Fact]
public void Parser_PublishQoS0_ShouldSucceed()
{
var c = CreateConnectedMqttClient();
// Build PUBLISH: type=0x30 (QoS 0), topic="test", payload="hi"
var topic = Encoding.UTF8.GetBytes("test");
var payload = Encoding.UTF8.GetBytes("hi");
var data = new List<byte>();
data.Add(0x00); data.Add((byte)topic.Length);
data.AddRange(topic);
data.AddRange(payload);
var buf = new List<byte>();
buf.Add(MqttPacket.Pub); // 0x30
buf.Add((byte)data.Count);
buf.AddRange(data);
var err = MqttParser.Parse(c, buf.ToArray(), buf.Count);
err.ShouldBeNull();
}
[Fact]
public void Parser_PublishQoS1_ShouldReturnNotImplemented()
{
var c = CreateConnectedMqttClient();
// PUBLISH QoS 1: type=0x32, topic="t", PI=1, payload="x"
var data = new List<byte>();
data.Add(0x00); data.Add(0x01); data.Add((byte)'t'); // topic
data.Add(0x00); data.Add(0x01); // PI = 1
data.Add((byte)'x'); // payload
var buf = new List<byte>();
buf.Add((byte)(MqttPacket.Pub | MqttPubFlag.QoS1)); // 0x32
buf.Add((byte)data.Count);
buf.AddRange(data);
var err = MqttParser.Parse(c, buf.ToArray(), buf.Count);
err.ShouldNotBeNull();
err.ShouldBeOfType<NotImplementedException>();
}
// =========================================================================
// SUBSCRIBE parsing
// =========================================================================
[Fact]
public void ParseSubs_SingleFilter_ShouldParseCorrectly()
{
var r = new MqttReader();
var filter = Encoding.UTF8.GetBytes("test/topic");
var data = new List<byte>();
data.Add(0x00); data.Add(0x0A); // PI = 10
data.Add((byte)(filter.Length >> 8));
data.Add((byte)(filter.Length & 0xFF));
data.AddRange(filter);
data.Add(0x01); // QoS 1
r.Reset(data.ToArray());
var (pi, filters, err) = MqttPacketHandlers.ParseSubsOrUnsubs(
r, (byte)(MqttPacket.Sub | MqttConst.SubscribeFlags), data.Count, isSub: true);
err.ShouldBeNull();
pi.ShouldBe((ushort)10);
filters.ShouldNotBeNull();
filters!.Count.ShouldBe(1);
filters[0].Filter.ShouldNotBeEmpty();
filters[0].Qos.ShouldBe((byte)1);
}
[Fact]
public void ParseSubs_MultipleFilters_ShouldParseAll()
{
var r = new MqttReader();
var f1 = Encoding.UTF8.GetBytes("a/b");
var f2 = Encoding.UTF8.GetBytes("c/d");
var data = new List<byte>();
data.Add(0x00); data.Add(0x01); // PI = 1
data.Add((byte)(f1.Length >> 8)); data.Add((byte)(f1.Length & 0xFF));
data.AddRange(f1);
data.Add(0x00); // QoS 0
data.Add((byte)(f2.Length >> 8)); data.Add((byte)(f2.Length & 0xFF));
data.AddRange(f2);
data.Add(0x02); // QoS 2
r.Reset(data.ToArray());
var (pi, filters, err) = MqttPacketHandlers.ParseSubsOrUnsubs(
r, (byte)(MqttPacket.Sub | MqttConst.SubscribeFlags), data.Count, isSub: true);
err.ShouldBeNull();
filters!.Count.ShouldBe(2);
filters[0].Qos.ShouldBe((byte)0);
filters[1].Qos.ShouldBe((byte)2);
}
[Fact]
public void ParseSubs_WrongFlags_ShouldReturnError()
{
var r = new MqttReader();
var data = new byte[] { 0x00, 0x01, 0x00, 0x01, (byte)'t', 0x00 };
r.Reset(data);
// Wrong flags: 0x00 instead of 0x02
var (_, _, err) = MqttPacketHandlers.ParseSubsOrUnsubs(
r, MqttPacket.Sub, data.Length, isSub: true); // flags = 0x00
err.ShouldNotBeNull();
err.Message.ShouldContain("reserved flags");
}
[Fact]
public void ParseSubs_ZeroPacketId_ShouldReturnError()
{
var r = new MqttReader();
var data = new byte[] { 0x00, 0x00, 0x00, 0x01, (byte)'t', 0x00 };
r.Reset(data);
var (_, _, err) = MqttPacketHandlers.ParseSubsOrUnsubs(
r, (byte)(MqttPacket.Sub | MqttConst.SubscribeFlags), data.Length, isSub: true);
err.ShouldNotBeNull();
err.Message.ShouldContain("packet identifier must not be 0");
}
[Fact]
public void ParseSubs_InvalidQoS_ShouldReturnError()
{
var r = new MqttReader();
var data = new byte[] { 0x00, 0x01, 0x00, 0x01, (byte)'t', 0x03 }; // QoS=3, invalid
r.Reset(data);
var (_, _, err) = MqttPacketHandlers.ParseSubsOrUnsubs(
r, (byte)(MqttPacket.Sub | MqttConst.SubscribeFlags), data.Length, isSub: true);
err.ShouldNotBeNull();
err.Message.ShouldContain("invalid QoS");
}
// =========================================================================
// SUBSCRIBE processing via parser
// =========================================================================
[Fact]
public void Parser_Subscribe_ShouldSendSubAck()
{
var c = CreateConnectedMqttClient();
// Build SUBSCRIBE: PI=5, filter="test/topic", QoS=1
var filter = Encoding.UTF8.GetBytes("test/topic");
var payload = new List<byte>();
payload.Add(0x00); payload.Add(0x05); // PI = 5
payload.Add((byte)(filter.Length >> 8));
payload.Add((byte)(filter.Length & 0xFF));
payload.AddRange(filter);
payload.Add(0x01); // QoS 1
var buf = new List<byte>();
buf.Add((byte)(MqttPacket.Sub | MqttConst.SubscribeFlags)); // 0x82
buf.Add((byte)payload.Count);
buf.AddRange(payload);
var err = MqttParser.Parse(c, buf.ToArray(), buf.Count);
err.ShouldBeNull();
// Verify SUBACK was written.
var ms = GetStream(c);
var data = ms.ToArray();
data.Length.ShouldBeGreaterThan(0);
data[0].ShouldBe(MqttPacket.SubAck); // 0x90
// SUBACK payload: PI (2 bytes) + QoS per filter (1 byte) = 3
// Format: [0x90] [0x03] [0x00] [0x05] [0x01]
data[1].ShouldBe((byte)0x03); // remaining length
data[2].ShouldBe((byte)0x00); // PI high
data[3].ShouldBe((byte)0x05); // PI low
data[4].ShouldBe((byte)0x01); // granted QoS 1
}
[Fact]
public void Parser_Subscribe_QoS2Downgrade_ShouldGrantQoS1()
{
var c = CreateConnectedMqttClient();
c.Mqtt!.DowngradeQoS2Sub = true;
var filter = Encoding.UTF8.GetBytes("a");
var payload = new List<byte>();
payload.Add(0x00); payload.Add(0x01); // PI = 1
payload.Add(0x00); payload.Add((byte)filter.Length);
payload.AddRange(filter);
payload.Add(0x02); // QoS 2 requested
var buf = new List<byte>();
buf.Add((byte)(MqttPacket.Sub | MqttConst.SubscribeFlags));
buf.Add((byte)payload.Count);
buf.AddRange(payload);
var err = MqttParser.Parse(c, buf.ToArray(), buf.Count);
err.ShouldBeNull();
var ms = GetStream(c);
var data = ms.ToArray();
// Last byte of SUBACK is the granted QoS, should be 1 (downgraded from 2).
data[^1].ShouldBe((byte)0x01);
}
// =========================================================================
// UNSUBSCRIBE parsing
// =========================================================================
[Fact]
public void ParseUnsubs_SingleFilter_ShouldParseCorrectly()
{
var r = new MqttReader();
var filter = Encoding.UTF8.GetBytes("test/topic");
var data = new List<byte>();
data.Add(0x00); data.Add(0x03); // PI = 3
data.Add((byte)(filter.Length >> 8));
data.Add((byte)(filter.Length & 0xFF));
data.AddRange(filter);
// No QoS byte for UNSUBSCRIBE
r.Reset(data.ToArray());
var (pi, filters, err) = MqttPacketHandlers.ParseSubsOrUnsubs(
r, (byte)(MqttPacket.Unsub | MqttConst.UnsubscribeFlags), data.Count, isSub: false);
err.ShouldBeNull();
pi.ShouldBe((ushort)3);
filters!.Count.ShouldBe(1);
filters[0].Qos.ShouldBe((byte)0); // Always 0 for unsub
}
// =========================================================================
// UNSUBSCRIBE processing via parser
// =========================================================================
[Fact]
public void Parser_Unsubscribe_ShouldSendUnsubAck()
{
var c = CreateConnectedMqttClient();
// First subscribe to create the subscription.
var filter = Encoding.UTF8.GetBytes("test/unsub");
var subPayload = new List<byte>();
subPayload.Add(0x00); subPayload.Add(0x01); // PI = 1
subPayload.Add((byte)(filter.Length >> 8));
subPayload.Add((byte)(filter.Length & 0xFF));
subPayload.AddRange(filter);
subPayload.Add(0x00); // QoS 0
var subBuf = new List<byte>();
subBuf.Add((byte)(MqttPacket.Sub | MqttConst.SubscribeFlags));
subBuf.Add((byte)subPayload.Count);
subBuf.AddRange(subPayload);
MqttParser.Parse(c, subBuf.ToArray(), subBuf.Count);
// Reset stream to capture UNSUBACK only.
var ms = GetStream(c);
ms.SetLength(0);
// Now unsubscribe.
var unsubPayload = new List<byte>();
unsubPayload.Add(0x00); unsubPayload.Add(0x02); // PI = 2
unsubPayload.Add((byte)(filter.Length >> 8));
unsubPayload.Add((byte)(filter.Length & 0xFF));
unsubPayload.AddRange(filter);
var unsubBuf = new List<byte>();
unsubBuf.Add((byte)(MqttPacket.Unsub | MqttConst.UnsubscribeFlags)); // 0xA2
unsubBuf.Add((byte)unsubPayload.Count);
unsubBuf.AddRange(unsubPayload);
var err = MqttParser.Parse(c, unsubBuf.ToArray(), unsubBuf.Count);
err.ShouldBeNull();
// Verify UNSUBACK: [0xB0] [0x02] [PI high] [PI low]
var data = ms.ToArray();
data.Length.ShouldBe(4);
data[0].ShouldBe(MqttPacket.UnsubAck); // 0xB0
data[1].ShouldBe((byte)0x02);
data[2].ShouldBe((byte)0x00); // PI high
data[3].ShouldBe((byte)0x02); // PI low
}
// =========================================================================
// Full CONNECT + SUBSCRIBE + PUBLISH + UNSUBSCRIBE flow
// =========================================================================
[Fact]
public void Parser_FullFlow_ConnectSubPubUnsub()
{
var ms = new MemoryStream();
var c = new ClientConnection(ClientKind.Client, nc: ms);
c.InitMqtt(new MqttHandler());
// 1. CONNECT
var connectBuf = BuildConnectPacket("flow-test");
var err = MqttParser.Parse(c, connectBuf, connectBuf.Length);
err.ShouldBeNull();
(c.Flags & ClientFlags.ConnectReceived).ShouldNotBe((ClientFlags)0);
// 2. SUBSCRIBE to "test/flow" QoS 0
var filter = Encoding.UTF8.GetBytes("test/flow");
var subPayload = new List<byte>();
subPayload.Add(0x00); subPayload.Add(0x01); // PI = 1
subPayload.Add((byte)(filter.Length >> 8));
subPayload.Add((byte)(filter.Length & 0xFF));
subPayload.AddRange(filter);
subPayload.Add(0x00); // QoS 0
var subBuf = new List<byte>();
subBuf.Add((byte)(MqttPacket.Sub | MqttConst.SubscribeFlags));
subBuf.Add((byte)subPayload.Count);
subBuf.AddRange(subPayload);
err = MqttParser.Parse(c, subBuf.ToArray(), subBuf.Count);
err.ShouldBeNull();
// 3. PUBLISH to "test/flow" QoS 0
var topic = Encoding.UTF8.GetBytes("test/flow");
var pubData = new List<byte>();
pubData.Add((byte)(topic.Length >> 8));
pubData.Add((byte)(topic.Length & 0xFF));
pubData.AddRange(topic);
pubData.AddRange(Encoding.UTF8.GetBytes("hello"));
var pubBuf = new List<byte>();
pubBuf.Add(MqttPacket.Pub);
pubBuf.Add((byte)pubData.Count);
pubBuf.AddRange(pubData);
err = MqttParser.Parse(c, pubBuf.ToArray(), pubBuf.Count);
err.ShouldBeNull();
// 4. UNSUBSCRIBE from "test/flow"
var unsubPayload = new List<byte>();
unsubPayload.Add(0x00); unsubPayload.Add(0x02); // PI = 2
unsubPayload.Add((byte)(filter.Length >> 8));
unsubPayload.Add((byte)(filter.Length & 0xFF));
unsubPayload.AddRange(filter);
var unsubBuf = new List<byte>();
unsubBuf.Add((byte)(MqttPacket.Unsub | MqttConst.UnsubscribeFlags));
unsubBuf.Add((byte)unsubPayload.Count);
unsubBuf.AddRange(unsubPayload);
err = MqttParser.Parse(c, unsubBuf.ToArray(), unsubBuf.Count);
err.ShouldBeNull();
// Verify: CONNACK(4) + SUBACK(5) + UNSUBACK(4) = 13 bytes written
var data = ms.ToArray();
data.Length.ShouldBe(13);
data[0].ShouldBe(MqttPacket.ConnectAck); // CONNACK
data[4].ShouldBe(MqttPacket.SubAck); // SUBACK
data[9].ShouldBe(MqttPacket.UnsubAck); // UNSUBACK
}
/// <summary>Builds a minimal MQTT CONNECT packet.</summary>
private static byte[] BuildConnectPacket(string clientId)
{
var payload = new List<byte>();
payload.AddRange(new byte[] { 0x00, 0x04 });
payload.AddRange(Encoding.UTF8.GetBytes("MQTT"));
payload.Add(0x04);
payload.Add(0x02); // clean session
payload.AddRange(new byte[] { 0x00, 0x3C });
var cidBytes = Encoding.UTF8.GetBytes(clientId);
payload.Add((byte)(cidBytes.Length >> 8));
payload.Add((byte)(cidBytes.Length & 0xFF));
payload.AddRange(cidBytes);
var result = new List<byte> { MqttPacket.Connect };
var remLen = payload.Count;
do
{
var b = (byte)(remLen & 0x7F);
remLen >>= 7;
if (remLen > 0) b |= 0x80;
result.Add(b);
} while (remLen > 0);
result.AddRange(payload);
return result.ToArray();
}
}