// Copyright 2020-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0

using System.Text;
using Shouldly;
using ZB.MOM.NatsNet.Server;
using ZB.MOM.NatsNet.Server.Internal;
using ZB.MOM.NatsNet.Server.Mqtt;

namespace ZB.MOM.NatsNet.Server.Tests.Mqtt;

/// <summary>
/// Unit tests for MQTT PUBLISH, SUBSCRIBE, and UNSUBSCRIBE packet handling.
/// </summary>
public sealed class MqttPubSubTests
{
    // =========================================================================
    // Test helpers
    // =========================================================================

    /// <summary>
    /// Creates a client backed by a <see cref="MemoryStream"/>, initialises its MQTT state and
    /// sets <see cref="ClientFlags.ConnectReceived"/> so tests can skip the CONNECT exchange.
    /// </summary>
    private static ClientConnection CreateConnectedMqttClient()
    {
        var ms = new MemoryStream();
        var c = new ClientConnection(ClientKind.Client, nc: ms);
        c.InitMqtt(new MqttHandler());
        c.Flags |= ClientFlags.ConnectReceived;
        return c;
    }

    /// <summary>
    /// Returns the stream held in the client's private <c>_nc</c> field (read via reflection)
    /// so tests can inspect the bytes written back to the client.
    /// </summary>
    private static MemoryStream GetStream(ClientConnection c)
    {
        return (MemoryStream)typeof(ClientConnection)
            .GetField(
                "_nc",
                System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)!
            .GetValue(c)!;
    }

    /// <summary>Appends <paramref name="value"/> as two bytes, most significant byte first.</summary>
    private static void AppendUInt16(List<byte> target, ushort value)
    {
        target.Add((byte)(value >> 8));
        target.Add((byte)(value & 0xFF));
    }

    /// <summary>
    /// Appends <paramref name="value"/> as UTF-8 bytes preceded by their two-byte big-endian length.
    /// </summary>
    private static void AppendString(List<byte> target, string value)
    {
        var bytes = Encoding.UTF8.GetBytes(value);
        AppendUInt16(target, (ushort)bytes.Length);
        target.AddRange(bytes);
    }

    /// <summary>
    /// Frames <paramref name="body"/> as a complete packet: the fixed-header byte, the body length
    /// encoded seven bits per byte (high bit set on every byte except the last), then the body.
    /// </summary>
    /// <param name="fixedHeader">First byte of the packet (packet type and flags).</param>
    /// <param name="body">Bytes that follow the length field.</param>
    /// <returns>The framed packet.</returns>
    private static byte[] BuildPacket(byte fixedHeader, IReadOnlyCollection<byte> body)
    {
        var packet = new List<byte>(body.Count + 5) { fixedHeader };

        var remaining = body.Count;
        do
        {
            var b = (byte)(remaining & 0x7F);
            remaining >>= 7;
            if (remaining > 0)
            {
                b |= 0x80; // more length bytes follow
            }

            packet.Add(b);
        }
        while (remaining > 0);

        packet.AddRange(body);
        return packet.ToArray();
    }

    // =========================================================================
    // PUBLISH parsing
    // =========================================================================

    [Fact]
    public void ParsePub_QoS0_ShouldParseCorrectly()
    {
        var r = new MqttReader();

        // Topic "test/topic" + payload "hello"
        var data = new List<byte>();
        AppendString(data, "test/topic");
        data.AddRange(Encoding.UTF8.GetBytes("hello"));
        r.Reset(data.ToArray());

        byte flags = 0x00; // QoS 0, no retain, no dup
        var (pp, err) = MqttPacketHandlers.ParsePub(r, data.Count, flags, rejectQoS2: false);

        err.ShouldBeNull();
        pp.ShouldNotBeNull();
        pp!.Topic.ShouldBe("test/topic");
        pp.Subject.ShouldNotBeEmpty();
        pp.Qos.ShouldBe((byte)0);
        pp.Pi.ShouldBe((ushort)0);
        pp.Retain.ShouldBeFalse();
        pp.Dup.ShouldBeFalse();
        Encoding.UTF8.GetString(pp.Msg!).ShouldBe("hello");
    }

    [Fact]
    public void ParsePub_QoS1_ShouldParsePacketId()
    {
        var r = new MqttReader();

        var data = new List<byte>();
        AppendString(data, "a/b");
        AppendUInt16(data, 7); // PI = 7
        data.AddRange(Encoding.UTF8.GetBytes("msg"));
        r.Reset(data.ToArray());

        byte flags = MqttPubFlag.QoS1; // QoS 1 = 0x02
        var (pp, err) = MqttPacketHandlers.ParsePub(r, data.Count, flags, rejectQoS2: false);

        err.ShouldBeNull();
        pp.ShouldNotBeNull();
        pp!.Qos.ShouldBe((byte)1);
        pp.Pi.ShouldBe((ushort)7);
    }

    [Fact]
    public void ParsePub_QoS2Rejected_ShouldReturnError()
    {
        var r = new MqttReader();

        var data = new List<byte>();
        AppendString(data, "t");
        AppendUInt16(data, 1); // PI = 1
        r.Reset(data.ToArray());

        byte flags = MqttPubFlag.QoS2; // QoS 2 = 0x04
        var (_, err) = MqttPacketHandlers.ParsePub(r, data.Count, flags, rejectQoS2: true);

        err.ShouldNotBeNull();
        err.Message.ShouldContain("QoS-2 PUBLISH rejected");
    }

    [Fact]
    public void ParsePub_EmptyTopic_ShouldReturnError()
    {
        var r = new MqttReader();
        var data = new byte[] { 0x00, 0x00 }; // zero-length topic
        r.Reset(data);

        var (_, err) = MqttPacketHandlers.ParsePub(r, data.Length, 0x00, rejectQoS2: false);

        err.ShouldNotBeNull();
        err.Message.ShouldContain("empty topic");
    }

    [Fact]
    public void ParsePub_RetainAndDup_ShouldSetFlags()
    {
        var r = new MqttReader();

        var data = new List<byte>();
        AppendString(data, "t");
        r.Reset(data.ToArray());

        byte flags = MqttPubFlag.Retain | MqttPubFlag.Dup; // 0x09
        var (pp, err) = MqttPacketHandlers.ParsePub(r, data.Count, flags, rejectQoS2: false);

        err.ShouldBeNull();
        pp.ShouldNotBeNull();
        pp!.Retain.ShouldBeTrue();
        pp.Dup.ShouldBeTrue();
    }

    [Fact]
    public void ParsePub_EmptyPayload_ShouldSucceed()
    {
        var r = new MqttReader();

        var data = new List<byte>();
        AppendString(data, "t");
        // No payload bytes.
        r.Reset(data.ToArray());

        var (pp, err) = MqttPacketHandlers.ParsePub(r, data.Count, 0x00, rejectQoS2: false);

        err.ShouldBeNull();
        pp.ShouldNotBeNull();
        pp!.Msg.ShouldBeNull();
    }

    // =========================================================================
    // PUBLISH processing via parser
    // =========================================================================

    [Fact]
    public void Parser_PublishQoS0_ShouldSucceed()
    {
        var c = CreateConnectedMqttClient();

        // Build PUBLISH: type=0x30 (QoS 0), topic="test", payload="hi"
        var body = new List<byte>();
        AppendString(body, "test");
        body.AddRange(Encoding.UTF8.GetBytes("hi"));
        var packet = BuildPacket(MqttPacket.Pub, body); // 0x30

        var err = MqttParser.Parse(c, packet, packet.Length);

        err.ShouldBeNull();
    }

    [Fact]
    public void Parser_PublishQoS1_ShouldSendPubAck()
    {
        var c = CreateConnectedMqttClient();

        // PUBLISH QoS 1: type=0x32, topic="t", PI=5, payload="x"
        var body = new List<byte>();
        AppendString(body, "t");
        AppendUInt16(body, 5); // PI = 5
        body.Add((byte)'x');   // payload
        var packet = BuildPacket((byte)(MqttPacket.Pub | MqttPubFlag.QoS1), body); // 0x32

        var err = MqttParser.Parse(c, packet, packet.Length);

        err.ShouldBeNull();

        // Verify PUBACK: [0x40] [0x02] [PI high] [PI low]
        var written = GetStream(c).ToArray();
        written.Length.ShouldBe(4);
        written[0].ShouldBe(MqttPacket.PubAck); // 0x40
        written[1].ShouldBe((byte)0x02);
        written[2].ShouldBe((byte)0x00); // PI high
        written[3].ShouldBe((byte)0x05); // PI low
    }

    [Fact]
    public void Parser_PublishQoS2_ShouldSendPubRec()
    {
        var c = CreateConnectedMqttClient();

        // PUBLISH QoS 2: type=0x34, topic="t", PI=10, payload="y"
        var body = new List<byte>();
        AppendString(body, "t");
        AppendUInt16(body, 10); // PI = 10
        body.Add((byte)'y');    // payload
        var packet = BuildPacket((byte)(MqttPacket.Pub | MqttPubFlag.QoS2), body); // 0x34

        var err = MqttParser.Parse(c, packet, packet.Length);

        err.ShouldBeNull();

        // Verify PUBREC: [0x50] [0x02] [PI high] [PI low]
        var written = GetStream(c).ToArray();
        written.Length.ShouldBe(4);
        written[0].ShouldBe(MqttPacket.PubRec); // 0x50
        written[1].ShouldBe((byte)0x02);
        written[2].ShouldBe((byte)0x00); // PI high
        written[3].ShouldBe((byte)0x0A); // PI low

        // Message should be stored in QoS2Pending.
        c.Mqtt!.QoS2Pending.ShouldContainKey((ushort)10);
    }

    [Fact]
    public void Parser_QoS2_FullHandshake_PubRecPubRelPubComp()
    {
        var c = CreateConnectedMqttClient();

        // Step 1: PUBLISH QoS 2 → PUBREC
        var pubBody = new List<byte>();
        AppendString(pubBody, "t");
        AppendUInt16(pubBody, 7); // PI = 7
        pubBody.AddRange(Encoding.UTF8.GetBytes("qos2msg"));
        var pubPacket = BuildPacket((byte)(MqttPacket.Pub | MqttPubFlag.QoS2), pubBody);

        var err = MqttParser.Parse(c, pubPacket, pubPacket.Length);
        err.ShouldBeNull();
        c.Mqtt!.QoS2Pending.ShouldContainKey((ushort)7);

        // Reset stream to capture only PUBCOMP.
        var ms = GetStream(c);
        ms.SetLength(0);

        // Step 2: PUBREL from client → PUBCOMP
        // PUBREL: [0x62] [0x02] [PI high] [PI low]
        var pubrelBuf = new byte[] { 0x62, 0x02, 0x00, 0x07 };
        err = MqttParser.Parse(c, pubrelBuf, pubrelBuf.Length);
        err.ShouldBeNull();

        // Verify PUBCOMP was sent.
        var written = ms.ToArray();
        written.Length.ShouldBe(4);
        written[0].ShouldBe(MqttPacket.PubComp); // 0x70
        written[1].ShouldBe((byte)0x02);
        written[2].ShouldBe((byte)0x00);
        written[3].ShouldBe((byte)0x07);

        // Message should be removed from QoS2Pending.
        c.Mqtt.QoS2Pending.ShouldNotContainKey((ushort)7);
    }

    [Fact]
    public void Parser_PubAck_ShouldRemoveFromPending()
    {
        var c = CreateConnectedMqttClient();

        // Pre-populate pending.
        c.Mqtt!.Pending[(ushort)3] = null;
        c.Mqtt.Pending.ShouldContainKey((ushort)3);

        // PUBACK: [0x40] [0x02] [0x00] [0x03]
        var buf = new byte[] { 0x40, 0x02, 0x00, 0x03 };
        var err = MqttParser.Parse(c, buf, buf.Length);

        err.ShouldBeNull();
        c.Mqtt.Pending.ShouldNotContainKey((ushort)3);
    }

    [Fact]
    public void Parser_PubRec_ShouldSendPubRel()
    {
        var c = CreateConnectedMqttClient();

        // Pre-populate pending.
        c.Mqtt!.Pending[(ushort)9] = null;

        // PUBREC: [0x50] [0x02] [0x00] [0x09]
        var buf = new byte[] { 0x50, 0x02, 0x00, 0x09 };
        var err = MqttParser.Parse(c, buf, buf.Length);

        err.ShouldBeNull();

        // Should have sent PUBREL: [0x62] [0x02] [0x00] [0x09]
        var written = GetStream(c).ToArray();
        written.Length.ShouldBe(4);
        written[0].ShouldBe((byte)0x62); // PUBREL with bit 1 set
        written[1].ShouldBe((byte)0x02);
        written[2].ShouldBe((byte)0x00); // PI high
        written[3].ShouldBe((byte)0x09); // PI low

        // Pending should be cleared.
        c.Mqtt.Pending.ShouldNotContainKey((ushort)9);
    }

    [Fact]
    public void Parser_PubComp_ShouldRemoveFromPending()
    {
        var c = CreateConnectedMqttClient();
        c.Mqtt!.Pending[(ushort)15] = null;

        // PUBCOMP: [0x70] [0x02] [0x00] [0x0F]
        var buf = new byte[] { 0x70, 0x02, 0x00, 0x0F };
        var err = MqttParser.Parse(c, buf, buf.Length);

        err.ShouldBeNull();
        c.Mqtt.Pending.ShouldNotContainKey((ushort)15);
    }

    // =========================================================================
    // SUBSCRIBE parsing
    // =========================================================================

    [Fact]
    public void ParseSubs_SingleFilter_ShouldParseCorrectly()
    {
        var r = new MqttReader();

        var data = new List<byte>();
        AppendUInt16(data, 10); // PI = 10
        AppendString(data, "test/topic");
        data.Add(0x01); // QoS 1
        r.Reset(data.ToArray());

        var (pi, filters, err) = MqttPacketHandlers.ParseSubsOrUnsubs(
            r, (byte)(MqttPacket.Sub | MqttConst.SubscribeFlags), data.Count, isSub: true);

        err.ShouldBeNull();
        pi.ShouldBe((ushort)10);
        filters.ShouldNotBeNull();
        filters!.Count.ShouldBe(1);
        filters[0].Filter.ShouldNotBeEmpty();
        filters[0].Qos.ShouldBe((byte)1);
    }

    [Fact]
    public void ParseSubs_MultipleFilters_ShouldParseAll()
    {
        var r = new MqttReader();

        var data = new List<byte>();
        AppendUInt16(data, 1); // PI = 1
        AppendString(data, "a/b");
        data.Add(0x00); // QoS 0
        AppendString(data, "c/d");
        data.Add(0x02); // QoS 2
        r.Reset(data.ToArray());

        var (_, filters, err) = MqttPacketHandlers.ParseSubsOrUnsubs(
            r, (byte)(MqttPacket.Sub | MqttConst.SubscribeFlags), data.Count, isSub: true);

        err.ShouldBeNull();
        filters.ShouldNotBeNull();
        filters!.Count.ShouldBe(2);
        filters[0].Qos.ShouldBe((byte)0);
        filters[1].Qos.ShouldBe((byte)2);
    }

    [Fact]
    public void ParseSubs_WrongFlags_ShouldReturnError()
    {
        var r = new MqttReader();
        var data = new byte[] { 0x00, 0x01, 0x00, 0x01, (byte)'t', 0x00 };
        r.Reset(data);

        // Wrong flags: 0x00 instead of 0x02
        var (_, _, err) = MqttPacketHandlers.ParseSubsOrUnsubs(
            r, MqttPacket.Sub, data.Length, isSub: true); // flags = 0x00

        err.ShouldNotBeNull();
        err.Message.ShouldContain("reserved flags");
    }

    [Fact]
    public void ParseSubs_ZeroPacketId_ShouldReturnError()
    {
        var r = new MqttReader();
        var data = new byte[] { 0x00, 0x00, 0x00, 0x01, (byte)'t', 0x00 };
        r.Reset(data);

        var (_, _, err) = MqttPacketHandlers.ParseSubsOrUnsubs(
            r, (byte)(MqttPacket.Sub | MqttConst.SubscribeFlags), data.Length, isSub: true);

        err.ShouldNotBeNull();
        err.Message.ShouldContain("packet identifier must not be 0");
    }

    [Fact]
    public void ParseSubs_InvalidQoS_ShouldReturnError()
    {
        var r = new MqttReader();
        var data = new byte[] { 0x00, 0x01, 0x00, 0x01, (byte)'t', 0x03 }; // QoS=3, invalid
        r.Reset(data);

        var (_, _, err) = MqttPacketHandlers.ParseSubsOrUnsubs(
            r, (byte)(MqttPacket.Sub | MqttConst.SubscribeFlags), data.Length, isSub: true);

        err.ShouldNotBeNull();
        err.Message.ShouldContain("invalid QoS");
    }

    // =========================================================================
    // SUBSCRIBE processing via parser
    // =========================================================================

    [Fact]
    public void Parser_Subscribe_ShouldSendSubAck()
    {
        var c = CreateConnectedMqttClient();

        // Build SUBSCRIBE: PI=5, filter="test/topic", QoS=1
        var body = new List<byte>();
        AppendUInt16(body, 5); // PI = 5
        AppendString(body, "test/topic");
        body.Add(0x01); // QoS 1
        var packet = BuildPacket((byte)(MqttPacket.Sub | MqttConst.SubscribeFlags), body); // 0x82

        var err = MqttParser.Parse(c, packet, packet.Length);

        err.ShouldBeNull();

        // Verify SUBACK was written.
        // SUBACK payload: PI (2 bytes) + QoS per filter (1 byte) = 3
        // Format: [0x90] [0x03] [0x00] [0x05] [0x01]
        var written = GetStream(c).ToArray();
        written.Length.ShouldBe(5);
        written[0].ShouldBe(MqttPacket.SubAck); // 0x90
        written[1].ShouldBe((byte)0x03); // remaining length
        written[2].ShouldBe((byte)0x00); // PI high
        written[3].ShouldBe((byte)0x05); // PI low
        written[4].ShouldBe((byte)0x01); // granted QoS 1
    }

    [Fact]
    public void Parser_Subscribe_QoS2Downgrade_ShouldGrantQoS1()
    {
        var c = CreateConnectedMqttClient();
        c.Mqtt!.DowngradeQoS2Sub = true;

        var body = new List<byte>();
        AppendUInt16(body, 1); // PI = 1
        AppendString(body, "a");
        body.Add(0x02); // QoS 2 requested
        var packet = BuildPacket((byte)(MqttPacket.Sub | MqttConst.SubscribeFlags), body);

        var err = MqttParser.Parse(c, packet, packet.Length);

        err.ShouldBeNull();

        var written = GetStream(c).ToArray();
        written.Length.ShouldBeGreaterThan(0);

        // Last byte of SUBACK is the granted QoS, should be 1 (downgraded from 2).
        written[^1].ShouldBe((byte)0x01);
    }

    // =========================================================================
    // UNSUBSCRIBE parsing
    // =========================================================================

    [Fact]
    public void ParseUnsubs_SingleFilter_ShouldParseCorrectly()
    {
        var r = new MqttReader();

        var data = new List<byte>();
        AppendUInt16(data, 3); // PI = 3
        AppendString(data, "test/topic");
        // No QoS byte for UNSUBSCRIBE
        r.Reset(data.ToArray());

        var (pi, filters, err) = MqttPacketHandlers.ParseSubsOrUnsubs(
            r, (byte)(MqttPacket.Unsub | MqttConst.UnsubscribeFlags), data.Count, isSub: false);

        err.ShouldBeNull();
        pi.ShouldBe((ushort)3);
        filters.ShouldNotBeNull();
        filters!.Count.ShouldBe(1);
        filters[0].Qos.ShouldBe((byte)0); // Always 0 for unsub
    }

    // =========================================================================
    // UNSUBSCRIBE processing via parser
    // =========================================================================

    [Fact]
    public void Parser_Unsubscribe_ShouldSendUnsubAck()
    {
        var c = CreateConnectedMqttClient();
        const string filter = "test/unsub";

        // First subscribe to create the subscription.
        var subBody = new List<byte>();
        AppendUInt16(subBody, 1); // PI = 1
        AppendString(subBody, filter);
        subBody.Add(0x00); // QoS 0
        var subPacket = BuildPacket((byte)(MqttPacket.Sub | MqttConst.SubscribeFlags), subBody);

        // Fail here, not at the UNSUBACK checks, if the prerequisite SUBSCRIBE is rejected.
        var subErr = MqttParser.Parse(c, subPacket, subPacket.Length);
        subErr.ShouldBeNull();

        // Reset stream to capture UNSUBACK only.
        var ms = GetStream(c);
        ms.SetLength(0);

        // Now unsubscribe.
        var unsubBody = new List<byte>();
        AppendUInt16(unsubBody, 2); // PI = 2
        AppendString(unsubBody, filter);
        var unsubPacket = BuildPacket(
            (byte)(MqttPacket.Unsub | MqttConst.UnsubscribeFlags), unsubBody); // 0xA2

        var err = MqttParser.Parse(c, unsubPacket, unsubPacket.Length);

        err.ShouldBeNull();

        // Verify UNSUBACK: [0xB0] [0x02] [PI high] [PI low]
        var written = ms.ToArray();
        written.Length.ShouldBe(4);
        written[0].ShouldBe(MqttPacket.UnsubAck); // 0xB0
        written[1].ShouldBe((byte)0x02);
        written[2].ShouldBe((byte)0x00); // PI high
        written[3].ShouldBe((byte)0x02); // PI low
    }

    // =========================================================================
    // Full CONNECT + SUBSCRIBE + PUBLISH + UNSUBSCRIBE flow
    // =========================================================================

    [Fact]
    public void Parser_FullFlow_ConnectSubPubUnsub()
    {
        var ms = new MemoryStream();
        var c = new ClientConnection(ClientKind.Client, nc: ms);
        c.InitMqtt(new MqttHandler());
        const string topic = "test/flow";

        // 1. CONNECT
        var connectBuf = BuildConnectPacket("flow-test");
        var err = MqttParser.Parse(c, connectBuf, connectBuf.Length);
        err.ShouldBeNull();
        (c.Flags & ClientFlags.ConnectReceived).ShouldNotBe((ClientFlags)0);

        // 2. SUBSCRIBE to "test/flow" QoS 0
        var subBody = new List<byte>();
        AppendUInt16(subBody, 1); // PI = 1
        AppendString(subBody, topic);
        subBody.Add(0x00); // QoS 0
        var subPacket = BuildPacket((byte)(MqttPacket.Sub | MqttConst.SubscribeFlags), subBody);
        err = MqttParser.Parse(c, subPacket, subPacket.Length);
        err.ShouldBeNull();

        // 3. PUBLISH to "test/flow" QoS 0
        var pubBody = new List<byte>();
        AppendString(pubBody, topic);
        pubBody.AddRange(Encoding.UTF8.GetBytes("hello"));
        var pubPacket = BuildPacket(MqttPacket.Pub, pubBody);
        err = MqttParser.Parse(c, pubPacket, pubPacket.Length);
        err.ShouldBeNull();

        // 4. UNSUBSCRIBE from "test/flow"
        var unsubBody = new List<byte>();
        AppendUInt16(unsubBody, 2); // PI = 2
        AppendString(unsubBody, topic);
        var unsubPacket = BuildPacket(
            (byte)(MqttPacket.Unsub | MqttConst.UnsubscribeFlags), unsubBody);
        err = MqttParser.Parse(c, unsubPacket, unsubPacket.Length);
        err.ShouldBeNull();

        // Verify: CONNACK(4) + SUBACK(5) + UNSUBACK(4) = 13 bytes written
        var written = ms.ToArray();
        written.Length.ShouldBe(13);
        written[0].ShouldBe(MqttPacket.ConnectAck); // CONNACK
        written[4].ShouldBe(MqttPacket.SubAck);     // SUBACK
        written[9].ShouldBe(MqttPacket.UnsubAck);   // UNSUBACK
    }

    /// <summary>Builds a minimal MQTT CONNECT packet.</summary>
    /// <param name="clientId">Client identifier written as the only payload field.</param>
    /// <returns>The complete CONNECT packet including the fixed header.</returns>
    private static byte[] BuildConnectPacket(string clientId)
    {
        var body = new List<byte>();
        AppendString(body, "MQTT");  // length-prefixed "MQTT" string
        body.Add(0x04);
        body.Add(0x02);              // clean session
        AppendUInt16(body, 0x003C);  // two bytes: 0x00, 0x3C
        AppendString(body, clientId);

        return BuildPacket(MqttPacket.Connect, body);
    }
}