// Copyright 2020-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0

using System.Text;
using Shouldly;
using ZB.MOM.NatsNet.Server;
using ZB.MOM.NatsNet.Server.Internal;
using ZB.MOM.NatsNet.Server.Mqtt;

namespace ZB.MOM.NatsNet.Server.Tests.Mqtt;

/// <summary>
/// Unit tests for MQTT CONNECT/CONNACK/DISCONNECT packet handling.
/// </summary>
public sealed class MqttConnectTests
{
    /// <summary>
    /// Builds a minimal valid MQTT CONNECT packet (protocol name "MQTT", level 4).
    /// </summary>
    /// <param name="clientId">Client identifier written into the payload.</param>
    /// <param name="cleanSession">Sets the clean-session connect flag.</param>
    /// <param name="keepAlive">Keep-alive value, written big-endian.</param>
    /// <param name="willTopic">When non-null, sets the will flag and writes the will topic and message.</param>
    /// <param name="willMessage">Will message bytes; an empty message is written when null.</param>
    /// <param name="willQos">Will QoS; only the low two bits are used (only applied when a will topic is given).</param>
    /// <param name="willRetain">Sets the will-retain flag (only applied when a will topic is given).</param>
    /// <param name="username">When non-null, sets the username flag and writes the username.</param>
    /// <param name="password">When non-null, sets the password flag and writes the password.</param>
    /// <returns>
    /// The complete packet: type byte, variable-length remaining length, then variable header and payload.
    /// </returns>
    private static byte[] BuildConnectPacket(
        string clientId = "test-client",
        bool cleanSession = true,
        ushort keepAlive = 60,
        string? willTopic = null,
        byte[]? willMessage = null,
        byte willQos = 0,
        bool willRetain = false,
        string? username = null,
        string? password = null)
    {
        var payload = new List<byte>();

        // Variable header: protocol name "MQTT", then protocol level.
        WriteLengthPrefixed(payload, Encoding.UTF8.GetBytes("MQTT"));
        payload.Add(0x04);

        // Connect flags.
        byte flags = 0;
        if (cleanSession)
        {
            flags |= MqttConnectFlag.CleanSession;
        }

        if (willTopic != null)
        {
            flags |= MqttConnectFlag.WillFlag;
            // Will QoS occupies bits 3-4 of the connect flags.
            flags |= (byte)((willQos & 0x03) << 3);
            if (willRetain)
            {
                flags |= MqttConnectFlag.WillRetain;
            }
        }

        if (username != null)
        {
            flags |= MqttConnectFlag.UsernameFlag;
        }

        if (password != null)
        {
            flags |= MqttConnectFlag.PasswordFlag;
        }

        payload.Add(flags);

        // Keep alive.
        WriteUInt16(payload, keepAlive);

        // Client ID.
        WriteLengthPrefixed(payload, Encoding.UTF8.GetBytes(clientId));

        // Will topic + message.
        if (willTopic != null)
        {
            WriteLengthPrefixed(payload, Encoding.UTF8.GetBytes(willTopic));
            var msg = willMessage ?? [];
            WriteLengthPrefixed(payload, msg);
        }

        // Username.
        if (username != null)
        {
            WriteLengthPrefixed(payload, Encoding.UTF8.GetBytes(username));
        }

        // Password.
        if (password != null)
        {
            WriteLengthPrefixed(payload, Encoding.UTF8.GetBytes(password));
        }

        // Build full packet: type byte + remaining length + payload.
        var result = new List<byte> { MqttPacket.Connect };

        // Remaining length uses the MQTT variable-length encoding:
        // 7 data bits per byte, high bit set when more bytes follow.
        var remLen = payload.Count;
        do
        {
            var encoded = (byte)(remLen & 0x7F);
            remLen >>= 7;
            if (remLen > 0)
            {
                encoded |= 0x80;
            }

            result.Add(encoded);
        }
        while (remLen > 0);

        result.AddRange(payload);
        return result.ToArray();
    }

    /// <summary>
    /// Appends <paramref name="value"/> to <paramref name="buffer"/> as a big-endian 16-bit integer.
    /// </summary>
    private static void WriteUInt16(List<byte> buffer, ushort value)
    {
        buffer.Add((byte)(value >> 8));
        buffer.Add((byte)(value & 0xFF));
    }

    /// <summary>
    /// Appends a 16-bit big-endian length prefix followed by <paramref name="data"/>.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="data"/> is longer than 65535 bytes. Without this check the prefix would be
    /// silently truncated and the packet would be malformed.
    /// </exception>
    private static void WriteLengthPrefixed(List<byte> buffer, byte[] data)
    {
        if (data.Length > ushort.MaxValue)
        {
            throw new ArgumentOutOfRangeException(
                nameof(data), data.Length, "Length-prefixed MQTT fields are limited to 65535 bytes.");
        }

        WriteUInt16(buffer, (ushort)data.Length);
        buffer.AddRange(data);
    }

    /// <summary>
    /// Strips the fixed header (type byte and variable-length remaining length) from a packet
    /// produced by <see cref="BuildConnectPacket"/>. Returns only the variable header + payload,
    /// which is what <c>MqttPacketHandlers.ParseConnect</c> consumes.
    /// </summary>
    /// <remarks>
    /// Decoding the remaining-length field avoids assuming it is a single byte. A hard-coded
    /// <c>buf[2..]</c> would break as soon as a payload exceeds 127 bytes.
    /// </remarks>
    private static byte[] SkipFixedHeader(byte[] packet)
    {
        var remainingLength = 0;
        var multiplier = 1;
        var headerLength = 1; // Byte 0 is the packet type.
        byte encoded;
        do
        {
            headerLength.ShouldBeLessThan(packet.Length, "Packet ends inside the remaining-length field.");
            encoded = packet[headerLength++];
            remainingLength += (encoded & 0x7F) * multiplier;
            multiplier *= 128;
        }
        while ((encoded & 0x80) != 0);

        // Sanity check on the builder itself: the declared length must match what follows.
        remainingLength.ShouldBe(packet.Length - headerLength);
        return packet[headerLength..];
    }

    /// <summary>
    /// Creates a client connection backed by an in-memory stream, with MQTT initialized.
    /// </summary>
    private static ClientConnection CreateMqttClient()
    {
        var ms = new MemoryStream();
        var c = new ClientConnection(ClientKind.Client, nc: ms);
        c.InitMqtt(new MqttHandler());
        return c;
    }

    /// <summary>
    /// Returns every byte the connection has written to its underlying stream so far.
    /// </summary>
    /// <remarks>
    /// Reads the private <c>_nc</c> field via reflection. This assumes it holds the
    /// <see cref="MemoryStream"/> passed to <see cref="CreateMqttClient"/>.
    /// </remarks>
    private static byte[] GetWrittenBytes(ClientConnection c)
    {
        var field = typeof(ClientConnection).GetField(
            "_nc",
            System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
        field.ShouldNotBeNull();
        return ((MemoryStream)field!.GetValue(c)!).ToArray();
    }

    // =========================================================================
    // ParseConnect tests
    // =========================================================================

    [Fact]
    public void ParseConnect_ValidMinimal_ShouldSucceed()
    {
        var buf = BuildConnectPacket();
        var r = new MqttReader();
        // ParseConnect expects only the variable header + payload; the parser strips the fixed header.
        r.Reset(SkipFixedHeader(buf));

        var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);

        err.ShouldBeNull();
        rc.ShouldBe(MqttConnAckRc.Accepted);
        cp.ShouldNotBeNull();
        cp!.ClientId.ShouldBe("test-client");
        cp.CleanSession.ShouldBeTrue();
        cp.KeepAlive.ShouldBe((ushort)60);
        cp.Will.ShouldBeNull();
        cp.Username.ShouldBeEmpty();
        cp.Password.ShouldBeNull();
    }

    [Fact]
    public void ParseConnect_WithWill_ShouldParseCorrectly()
    {
        var buf = BuildConnectPacket(
            willTopic: "test/will",
            willMessage: Encoding.UTF8.GetBytes("goodbye"),
            willQos: 1,
            willRetain: true);
        var r = new MqttReader();
        r.Reset(SkipFixedHeader(buf));

        var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);

        err.ShouldBeNull();
        rc.ShouldBe(MqttConnAckRc.Accepted);
        cp!.Will.ShouldNotBeNull();
        cp.Will!.Topic.ShouldBe("test/will");
        cp.Will.Subject.ShouldNotBeEmpty(); // NATS-converted subject
        cp.Will.Msg.ShouldNotBeNull();
        Encoding.UTF8.GetString(cp.Will.Msg!).ShouldBe("goodbye");
        cp.Will.Qos.ShouldBe((byte)1);
        cp.Will.Retain.ShouldBeTrue();
    }

    [Fact]
    public void ParseConnect_WithAuth_ShouldParseCorrectly()
    {
        var buf = BuildConnectPacket(username: "user1", password: "pass1");
        var r = new MqttReader();
        r.Reset(SkipFixedHeader(buf));

        var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);

        err.ShouldBeNull();
        rc.ShouldBe(MqttConnAckRc.Accepted);
        cp!.Username.ShouldBe("user1");
        cp.Password.ShouldNotBeNull();
        Encoding.UTF8.GetString(cp.Password!).ShouldBe("pass1");
    }

    [Fact]
    public void ParseConnect_WrongProtocolName_ShouldRejectWithCode()
    {
        // Variable header + payload only (no fixed header), with the MQTT 3.1 protocol name.
        var buf = new List<byte>();
        WriteLengthPrefixed(buf, Encoding.UTF8.GetBytes("MQIsdp"));
        buf.Add(0x03); // level
        buf.Add(0x02); // clean session
        buf.AddRange(new byte[] { 0x00, 0x3C }); // keepalive=60
        buf.AddRange(new byte[] { 0x00, 0x02, 0x41, 0x42 }); // clientId="AB"

        var r = new MqttReader();
        r.Reset(buf.ToArray());

        var (rc, _, err) = MqttPacketHandlers.ParseConnect(r);

        rc.ShouldBe(MqttConnAckRc.UnacceptableProtocol);
        err.ShouldNotBeNull();
    }

    [Fact]
    public void ParseConnect_EmptyClientIdWithoutCleanSession_ShouldReject()
    {
        var buf = BuildConnectPacket(clientId: "", cleanSession: false);
        var r = new MqttReader();
        r.Reset(SkipFixedHeader(buf));

        var (rc, _, err) = MqttPacketHandlers.ParseConnect(r);

        rc.ShouldBe(MqttConnAckRc.IdentifierRejected);
        err.ShouldNotBeNull();
    }

    [Fact]
    public void ParseConnect_EmptyClientIdWithCleanSession_ShouldAutoGenerate()
    {
        var buf = BuildConnectPacket(clientId: "", cleanSession: true);
        var r = new MqttReader();
        r.Reset(SkipFixedHeader(buf));

        var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);

        err.ShouldBeNull();
        rc.ShouldBe(MqttConnAckRc.Accepted);
        cp!.ClientId.ShouldNotBeEmpty(); // Auto-generated
        cp.ClientId.Length.ShouldBe(32); // GUID "N" format
    }

    [Fact]
    public void ParseConnect_ReservedBitSet_ShouldError()
    {
        // Variable header + payload only, with the reserved connect-flag bit (bit 0) set.
        var payload = new List<byte>();
        WriteLengthPrefixed(payload, Encoding.UTF8.GetBytes("MQTT"));
        payload.Add(0x04); // level
        payload.Add(0x03); // clean session + reserved bit!
        payload.AddRange(new byte[] { 0x00, 0x00 }); // keepalive
        payload.AddRange(new byte[] { 0x00, 0x02, 0x41, 0x42 }); // clientId

        var r = new MqttReader();
        r.Reset(payload.ToArray());

        var (_, _, err) = MqttPacketHandlers.ParseConnect(r);

        err.ShouldNotBeNull();
        err!.Message.ShouldContain("reserved bit");
    }

    // =========================================================================
    // ProcessConnect + CONNACK tests
    // =========================================================================

    [Fact]
    public void ProcessConnect_ShouldSetFlagsAndSendConnAck()
    {
        var c = CreateMqttClient();
        var cp = new MqttConnectProto
        {
            ClientId = "test-123",
            CleanSession = true,
            KeepAlive = 30,
        };

        var err = MqttPacketHandlers.ProcessConnect(c, cp);

        err.ShouldBeNull();

        // Verify state.
        c.Mqtt!.ClientId.ShouldBe("test-123");
        c.Mqtt.CleanSession.ShouldBeTrue();
        (c.Flags & ClientFlags.ConnectReceived).ShouldNotBe((ClientFlags)0);

        // Verify CONNACK was written.
        var data = GetWrittenBytes(c);
        data.Length.ShouldBe(4);
        data[0].ShouldBe(MqttPacket.ConnectAck);
        data[1].ShouldBe((byte)0x02);
        data[2].ShouldBe((byte)0x00); // No session present
        data[3].ShouldBe(MqttConnAckRc.Accepted);
    }

    // =========================================================================
    // Full CONNECT via parser integration
    // =========================================================================

    [Fact]
    public void Parser_ConnectPacket_ShouldParseAndSendConnAck()
    {
        var c = CreateMqttClient();
        var buf = BuildConnectPacket(clientId: "mqtt-parser-test");

        var err = MqttParser.Parse(c, buf, buf.Length);

        err.ShouldBeNull();

        // Verify connected.
        (c.Flags & ClientFlags.ConnectReceived).ShouldNotBe((ClientFlags)0);
        c.Mqtt!.ClientId.ShouldBe("mqtt-parser-test");

        // Verify CONNACK written.
        var data = GetWrittenBytes(c);
        data.Length.ShouldBe(4);
        data[0].ShouldBe(MqttPacket.ConnectAck);
        data[3].ShouldBe(MqttConnAckRc.Accepted);
    }

    [Fact]
    public void Parser_ConnectThenPing_ShouldSucceed()
    {
        var c = CreateMqttClient();
        var connectBuf = BuildConnectPacket();
        var pingBuf = new byte[] { MqttPacket.Ping, 0x00 };

        // Concatenate CONNECT + PING into one buffer.
        var buf = new byte[connectBuf.Length + pingBuf.Length];
        Buffer.BlockCopy(connectBuf, 0, buf, 0, connectBuf.Length);
        Buffer.BlockCopy(pingBuf, 0, buf, connectBuf.Length, pingBuf.Length);

        var err = MqttParser.Parse(c, buf, buf.Length);

        err.ShouldBeNull();

        // Verify CONNACK + PINGRESP written.
        var data = GetWrittenBytes(c);
        data.Length.ShouldBe(6); // 4 (CONNACK) + 2 (PINGRESP)
        data[0].ShouldBe(MqttPacket.ConnectAck);
        data[4].ShouldBe(MqttPacket.PingResp);
    }

    // =========================================================================
    // DISCONNECT tests
    // =========================================================================

    [Fact]
    public void Parser_Disconnect_ShouldClearWillAndClose()
    {
        var c = CreateMqttClient();

        // First, process a CONNECT with a will.
        var connectBuf = BuildConnectPacket(
            willTopic: "test/will",
            willMessage: Encoding.UTF8.GetBytes("bye"));
        var err = MqttParser.Parse(c, connectBuf, connectBuf.Length);
        err.ShouldBeNull();
        c.Mqtt!.Will.ShouldNotBeNull();

        // Now send DISCONNECT.
        var disconnectBuf = new byte[] { MqttPacket.Disconnect, 0x00 };
        err = MqttParser.Parse(c, disconnectBuf, disconnectBuf.Length);
        err.ShouldBeNull();

        // Will should be cleared.
        // NOTE(review): the "AndClose" part of this test's name is not asserted here.
        c.Mqtt.Will.ShouldBeNull();
    }
}