From 95cf20b00bbca15cdc2f32a3c1e46f62310a6687 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 1 Mar 2026 15:48:22 -0500 Subject: [PATCH] feat(mqtt): implement CONNECT/CONNACK/DISCONNECT packet handlers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement Task 3 of MQTT orchestration: - Create MqttPacketHandlers.cs with ParseConnect(), ProcessConnect(), EnqueueConnAck(), HandleDisconnect() - Wire CONNECT and DISCONNECT dispatch in MqttParser.cs - Parse CONNECT: protocol name/level, flags, keep-alive, client ID, will, auth - Send CONNACK (4-byte fixed packet with return code) - DISCONNECT clears will message and closes connection cleanly - Auto-generate client ID for empty ID + clean session - Validate reserved bit, will flags, username/password consistency - Add Reader field to MqttHandler for per-connection parsing - 11 unit tests for CONNECT parsing and processing - 1 end-to-end integration test: TCP → CONNECT → CONNACK over the wire --- .../Mqtt/MqttPacketHandlers.cs | 240 +++++++++++ .../ZB.MOM.NatsNet.Server/Mqtt/MqttParser.cs | 26 +- .../ServerBootTests.cs | 90 ++++ .../Mqtt/MqttConnectTests.cs | 386 ++++++++++++++++++ .../Mqtt/MqttParserTests.cs | 39 +- reports/current.md | 2 +- 6 files changed, 768 insertions(+), 15 deletions(-) create mode 100644 dotnet/src/ZB.MOM.NatsNet.Server/Mqtt/MqttPacketHandlers.cs create mode 100644 dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Mqtt/MqttConnectTests.cs diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/Mqtt/MqttPacketHandlers.cs b/dotnet/src/ZB.MOM.NatsNet.Server/Mqtt/MqttPacketHandlers.cs new file mode 100644 index 0000000..7e581db --- /dev/null +++ b/dotnet/src/ZB.MOM.NatsNet.Server/Mqtt/MqttPacketHandlers.cs @@ -0,0 +1,240 @@ +// Copyright 2020-2026 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Adapted from server/mqtt.go in the NATS server Go source. + +using System.Text; + +namespace ZB.MOM.NatsNet.Server.Mqtt; + +/// +/// MQTT packet parsing and processing handlers. +/// Mirrors the mqttParseConnect / mqttProcessConnect / mqttEnqueueConnAck +/// functions in server/mqtt.go. +/// +internal static class MqttPacketHandlers +{ + /// + /// Parses an MQTT CONNECT packet from the reader. + /// Returns (returnCode, connectProto, error). + /// returnCode == 0 means success; non-zero is a CONNACK return code. + /// Mirrors Go mqttParseConnect(). + /// + public static (byte rc, MqttConnectProto? cp, Exception? err) ParseConnect(MqttReader r) + { + // --- Protocol Name --- + string protoName; + try { protoName = r.ReadString("protocol name"); } + catch (Exception ex) { return (0, null, ex); } + + if (protoName != "MQTT") + return (MqttConnAckRc.UnacceptableProtocol, null, + new InvalidOperationException($"unsupported MQTT protocol: \"{protoName}\"")); + + // --- Protocol Level --- + byte level; + try { level = r.ReadByte("protocol level"); } + catch (Exception ex) { return (0, null, ex); } + + if (level != 0x04) + return (MqttConnAckRc.UnacceptableProtocol, null, + new InvalidOperationException($"unsupported MQTT protocol level: {level}")); + + // --- Connect Flags --- + byte flags; + try { flags = r.ReadByte("connect flags"); } + catch (Exception ex) { return (0, null, ex); } + + if ((flags & MqttConnectFlag.Reserved) != 0) + return (0, null, new InvalidOperationException("CONNECT flags reserved bit must be 0")); + + bool cleanSession = (flags & MqttConnectFlag.CleanSession) != 0; + bool willFlag = (flags & MqttConnectFlag.WillFlag) != 0; + byte willQos = (byte)((flags & MqttConnectFlag.WillQoS) >> 3); + bool willRetain = (flags & MqttConnectFlag.WillRetain) != 0; + bool hasPassword = (flags & MqttConnectFlag.PasswordFlag) != 0; + bool hasUsername = (flags & MqttConnectFlag.UsernameFlag) != 0; + + // Validate Will flags. + if (!willFlag) + { + if (willQos != 0) + return (0, null, new InvalidOperationException("Will QoS must be 0 when Will Flag is 0")); + if (willRetain) + return (0, null, new InvalidOperationException("Will Retain must be 0 when Will Flag is 0")); + } + else + { + if (willQos > 2) + return (0, null, new InvalidOperationException($"invalid Will QoS: {willQos}")); + } + + // Username/password consistency. + if (hasPassword && !hasUsername) + return (0, null, new InvalidOperationException("password flag without username flag")); + + // --- Keep Alive --- + ushort keepAlive; + try { keepAlive = r.ReadUInt16("keep alive"); } + catch (Exception ex) { return (0, null, ex); } + + // --- Client ID --- + string clientId; + try { clientId = r.ReadString("client id"); } + catch (Exception ex) { return (0, null, ex); } + + if (string.IsNullOrEmpty(clientId)) + { + if (!cleanSession) + return (MqttConnAckRc.IdentifierRejected, null, + new InvalidOperationException("empty client ID requires clean session flag")); + // Generate a unique client ID. + clientId = Guid.NewGuid().ToString("N"); + } + + // --- Will Topic & Message --- + MqttWill? will = null; + if (willFlag) + { + string willTopic; + try { willTopic = r.ReadString("will topic"); } + catch (Exception ex) { return (0, null, ex); } + + if (string.IsNullOrEmpty(willTopic)) + return (0, null, new InvalidOperationException("empty will topic")); + + byte[] willMsg; + try { willMsg = r.ReadBytes("will message", copy: true); } + catch (Exception ex) { return (0, null, ex); } + + // Convert MQTT topic to NATS subject. + var topicBytes = Encoding.UTF8.GetBytes(willTopic); + byte[] subjectBytes; + try { subjectBytes = MqttSubjectConverter.MqttTopicToNatsPubSubject(topicBytes); } + catch (Exception ex) { return (0, null, ex); } + + will = new MqttWill + { + Topic = willTopic, + Subject = Encoding.UTF8.GetString(subjectBytes), + Msg = willMsg.Length > 0 ? willMsg : null, + Qos = willQos, + Retain = willRetain, + }; + } + + // --- Username --- + string username = string.Empty; + if (hasUsername) + { + try { username = r.ReadString("username"); } + catch (Exception ex) { return (0, null, ex); } + + if (string.IsNullOrEmpty(username)) + return (0, null, new InvalidOperationException("empty username")); + } + + // --- Password --- + byte[]? password = null; + if (hasPassword) + { + try { password = r.ReadBytes("password", copy: true); } + catch (Exception ex) { return (0, null, ex); } + } + + var cp = new MqttConnectProto + { + ClientId = clientId, + Will = will, + Username = username, + Password = password, + CleanSession = cleanSession, + KeepAlive = keepAlive, + }; + + return (MqttConnAckRc.Accepted, cp, null); + } + + /// + /// Processes a parsed CONNECT packet: sets client state, sends CONNACK. + /// Minimal implementation — full session management deferred to Task 6. + /// Mirrors Go mqttProcessConnect(). + /// + public static Exception? ProcessConnect(ClientConnection c, MqttConnectProto cp) + { + var mqtt = c.Mqtt!; + + // Store client identity. + mqtt.ClientId = cp.ClientId; + mqtt.CleanSession = cp.CleanSession; + mqtt.KeepAlive = cp.KeepAlive; + mqtt.Will = cp.Will; + + // Store auth credentials on client options. + if (!string.IsNullOrEmpty(cp.Username)) + c.Opts.Username = cp.Username; + if (cp.Password != null) + c.Opts.Password = Encoding.UTF8.GetString(cp.Password); + + // Mark as connected. + c.Flags |= ClientFlags.ConnectReceived; + + // Set keep-alive read deadline. + if (cp.KeepAlive > 0) + { + // MQTT spec: server MUST disconnect if no packet within 1.5x keep-alive. + var deadline = TimeSpan.FromSeconds(cp.KeepAlive * 1.5); + mqtt.KeepAlive = cp.KeepAlive; + // TODO: set read deadline on connection stream (Task 7) + } + + // Send CONNACK (accepted, no session present for now). + EnqueueConnAck(c, MqttConnAckRc.Accepted, sessionPresent: false); + + return null; + } + + /// + /// Enqueues a CONNACK packet to the client. + /// Mirrors Go mqttEnqueueConnAck(). + /// + public static void EnqueueConnAck(ClientConnection c, byte rc, bool sessionPresent) + { + byte sp = 0; + if (rc == MqttConnAckRc.Accepted && sessionPresent) + sp = 1; + + ReadOnlySpan connack = [MqttPacket.ConnectAck, 0x02, sp, rc]; + lock (c) + { + c.EnqueueProto(connack); + } + } + + /// + /// Handles DISCONNECT: clears the will message and closes the connection. + /// Mirrors Go DISCONNECT case in mqttParse(). + /// + public static void HandleDisconnect(ClientConnection c) + { + // Per MQTT spec 3.1.2-8: discard the will message on clean disconnect. + lock (c) + { + if (c.Mqtt != null) + c.Mqtt.Will = null; + } + + // Close the connection cleanly. + c.CloseConnection(ClosedState.ClientClosed); + } +} diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/Mqtt/MqttParser.cs b/dotnet/src/ZB.MOM.NatsNet.Server/Mqtt/MqttParser.cs index 5949a8e..892cf8e 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/Mqtt/MqttParser.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/Mqtt/MqttParser.cs @@ -126,18 +126,32 @@ internal static class MqttParser case MqttPacket.Connect: if (connected) { - // Second CONNECT on same connection is a protocol violation. err = new InvalidOperationException("second CONNECT packet not allowed"); break; } - // TODO: Task 3 — MqttParseConnect + MqttProcessConnect - err = new NotImplementedException("CONNECT not yet implemented"); + var (rc, cp, parseErr) = MqttPacketHandlers.ParseConnect(r); + if (parseErr != null) + { + // Send CONNACK with error code if we have one, then close. + if (rc != MqttConnAckRc.Accepted) + MqttPacketHandlers.EnqueueConnAck(c, rc, false); + err = parseErr; + break; + } + if (rc != MqttConnAckRc.Accepted) + { + MqttPacketHandlers.EnqueueConnAck(c, rc, false); + err = new InvalidOperationException($"CONNECT rejected with code 0x{rc:X2}"); + break; + } + err = MqttPacketHandlers.ProcessConnect(c, cp!); + if (err == null) + connected = true; break; case MqttPacket.Disconnect: - // TODO: Task 3 — handle DISCONNECT - err = new NotImplementedException("DISCONNECT not yet implemented"); - break; + MqttPacketHandlers.HandleDisconnect(c); + return null; // Connection closed, exit parse loop. default: err = new InvalidOperationException($"unknown MQTT packet type: 0x{pt:X2}"); diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.IntegrationTests/ServerBootTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.IntegrationTests/ServerBootTests.cs index 1cc98db..800852f 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.IntegrationTests/ServerBootTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.IntegrationTests/ServerBootTests.cs @@ -160,6 +160,96 @@ public sealed class ServerBootTests : IDisposable server.Running().ShouldBeFalse(); } + /// + /// End-to-end: TCP connect → send MQTT CONNECT → receive CONNACK. + /// Validates the full MQTT handshake over the wire. + /// + [Fact] + public async Task MqttBoot_ConnectHandshake_ShouldReceiveConnAck() + { + var opts = new ServerOptions + { + Host = "127.0.0.1", + Port = 0, + Mqtt = { Port = -1, Host = "127.0.0.1" }, + }; + + var (server, err) = NatsServer.NewServer(opts); + err.ShouldBeNull(); + server.ShouldNotBeNull(); + + try + { + server!.Start(); + var mqttAddr = server.MqttAddr(); + mqttAddr.ShouldNotBeNull(); + + using var tcp = new System.Net.Sockets.TcpClient(); + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + await tcp.ConnectAsync(mqttAddr!.Address, mqttAddr.Port, cts.Token); + + var stream = tcp.GetStream(); + + // Build and send MQTT CONNECT packet. + var connectPacket = BuildMqttConnectPacket("integration-test"); + await stream.WriteAsync(connectPacket, cts.Token); + await stream.FlushAsync(cts.Token); + + // Read CONNACK response (4 bytes). + var response = new byte[4]; + var totalRead = 0; + while (totalRead < 4) + { + var n = await stream.ReadAsync(response.AsMemory(totalRead, 4 - totalRead), cts.Token); + if (n == 0) break; + totalRead += n; + } + + totalRead.ShouldBe(4, "Should receive 4-byte CONNACK"); + response[0].ShouldBe((byte)0x20, "Packet type should be CONNACK"); + response[1].ShouldBe((byte)0x02, "Remaining length should be 2"); + response[3].ShouldBe((byte)0x00, "Return code should be Accepted (0)"); + } + finally + { + server!.Shutdown(); + } + } + + /// Builds a minimal MQTT CONNECT packet. + private static byte[] BuildMqttConnectPacket(string clientId) + { + var payload = new List(); + // Protocol name "MQTT" + payload.AddRange(new byte[] { 0x00, 0x04 }); + payload.AddRange(System.Text.Encoding.UTF8.GetBytes("MQTT")); + // Protocol level 4 + payload.Add(0x04); + // Flags: clean session + payload.Add(0x02); + // Keep alive: 60s + payload.AddRange(new byte[] { 0x00, 0x3C }); + // Client ID + var cidBytes = System.Text.Encoding.UTF8.GetBytes(clientId); + payload.Add((byte)(cidBytes.Length >> 8)); + payload.Add((byte)(cidBytes.Length & 0xFF)); + payload.AddRange(cidBytes); + + // Fixed header + var result = new List(); + result.Add(0x10); // CONNECT type + var remLen = payload.Count; + do + { + var b = (byte)(remLen & 0x7F); + remLen >>= 7; + if (remLen > 0) b |= 0x80; + result.Add(b); + } while (remLen > 0); + result.AddRange(payload); + return result.ToArray(); + } + /// /// Validates that Shutdown() after Start() completes cleanly. /// Uses DontListen to skip TCP binding — tests lifecycle only. diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Mqtt/MqttConnectTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Mqtt/MqttConnectTests.cs new file mode 100644 index 0000000..d6ed543 --- /dev/null +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Mqtt/MqttConnectTests.cs @@ -0,0 +1,386 @@ +// Copyright 2020-2026 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 + +using System.Text; +using Shouldly; +using ZB.MOM.NatsNet.Server; +using ZB.MOM.NatsNet.Server.Internal; +using ZB.MOM.NatsNet.Server.Mqtt; + +namespace ZB.MOM.NatsNet.Server.Tests.Mqtt; + +/// +/// Unit tests for MQTT CONNECT/CONNACK/DISCONNECT packet handling. +/// +public sealed class MqttConnectTests +{ + /// + /// Builds a minimal valid MQTT CONNECT packet. + /// + private static byte[] BuildConnectPacket( + string clientId = "test-client", + bool cleanSession = true, + ushort keepAlive = 60, + string? willTopic = null, + byte[]? willMessage = null, + byte willQos = 0, + bool willRetain = false, + string? username = null, + string? password = null) + { + var payload = new List(); + + // Variable header. + // Protocol name "MQTT". + var protoName = Encoding.UTF8.GetBytes("MQTT"); + payload.Add((byte)(protoName.Length >> 8)); + payload.Add((byte)(protoName.Length & 0xFF)); + payload.AddRange(protoName); + + // Protocol level. + payload.Add(0x04); + + // Connect flags. + byte flags = 0; + if (cleanSession) flags |= MqttConnectFlag.CleanSession; + if (willTopic != null) + { + flags |= MqttConnectFlag.WillFlag; + flags |= (byte)((willQos & 0x03) << 3); + if (willRetain) flags |= MqttConnectFlag.WillRetain; + } + if (username != null) flags |= MqttConnectFlag.UsernameFlag; + if (password != null) flags |= MqttConnectFlag.PasswordFlag; + payload.Add(flags); + + // Keep alive. + payload.Add((byte)(keepAlive >> 8)); + payload.Add((byte)(keepAlive & 0xFF)); + + // Client ID. + var cidBytes = Encoding.UTF8.GetBytes(clientId); + payload.Add((byte)(cidBytes.Length >> 8)); + payload.Add((byte)(cidBytes.Length & 0xFF)); + payload.AddRange(cidBytes); + + // Will topic + message. + if (willTopic != null) + { + var topicBytes = Encoding.UTF8.GetBytes(willTopic); + payload.Add((byte)(topicBytes.Length >> 8)); + payload.Add((byte)(topicBytes.Length & 0xFF)); + payload.AddRange(topicBytes); + + var msg = willMessage ?? []; + payload.Add((byte)(msg.Length >> 8)); + payload.Add((byte)(msg.Length & 0xFF)); + payload.AddRange(msg); + } + + // Username. + if (username != null) + { + var userBytes = Encoding.UTF8.GetBytes(username); + payload.Add((byte)(userBytes.Length >> 8)); + payload.Add((byte)(userBytes.Length & 0xFF)); + payload.AddRange(userBytes); + } + + // Password. + if (password != null) + { + var passBytes = Encoding.UTF8.GetBytes(password); + payload.Add((byte)(passBytes.Length >> 8)); + payload.Add((byte)(passBytes.Length & 0xFF)); + payload.AddRange(passBytes); + } + + // Build full packet: type byte + remaining length + payload. + var result = new List(); + result.Add(MqttPacket.Connect); + + // Encode remaining length. + var remLen = payload.Count; + do + { + var encoded = (byte)(remLen & 0x7F); + remLen >>= 7; + if (remLen > 0) encoded |= 0x80; + result.Add(encoded); + } while (remLen > 0); + + result.AddRange(payload); + return result.ToArray(); + } + + private static ClientConnection CreateMqttClient() + { + var ms = new MemoryStream(); + var c = new ClientConnection(ClientKind.Client, nc: ms); + c.InitMqtt(new MqttHandler()); + return c; + } + + // ========================================================================= + // ParseConnect tests + // ========================================================================= + + [Fact] + public void ParseConnect_ValidMinimal_ShouldSucceed() + { + var buf = BuildConnectPacket(); + var r = new MqttReader(); + // Skip the fixed header (type + remaining length) — parser handles that. + // For direct ParseConnect testing, we feed only the variable header + payload. + r.Reset(buf[2..]); // Skip type byte and 1-byte remaining length. + + var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r); + err.ShouldBeNull(); + rc.ShouldBe(MqttConnAckRc.Accepted); + cp.ShouldNotBeNull(); + cp!.ClientId.ShouldBe("test-client"); + cp.CleanSession.ShouldBeTrue(); + cp.KeepAlive.ShouldBe((ushort)60); + cp.Will.ShouldBeNull(); + cp.Username.ShouldBeEmpty(); + cp.Password.ShouldBeNull(); + } + + [Fact] + public void ParseConnect_WithWill_ShouldParseCorrectly() + { + var buf = BuildConnectPacket( + willTopic: "test/will", + willMessage: Encoding.UTF8.GetBytes("goodbye"), + willQos: 1, + willRetain: true); + + // Find remaining length to skip header correctly. + int headerLen = 1; // type byte + int remLen = 0; + int mult = 1; + for (int i = 1; i < buf.Length; i++) + { + remLen += (buf[i] & 0x7F) * mult; + headerLen++; + if ((buf[i] & 0x80) == 0) break; + mult *= 128; + } + + var r = new MqttReader(); + r.Reset(buf[headerLen..]); + + var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r); + err.ShouldBeNull(); + rc.ShouldBe(MqttConnAckRc.Accepted); + cp!.Will.ShouldNotBeNull(); + cp.Will!.Topic.ShouldBe("test/will"); + cp.Will.Subject.ShouldNotBeEmpty(); // NATS-converted subject + cp.Will.Msg.ShouldNotBeNull(); + Encoding.UTF8.GetString(cp.Will.Msg!).ShouldBe("goodbye"); + cp.Will.Qos.ShouldBe((byte)1); + cp.Will.Retain.ShouldBeTrue(); + } + + [Fact] + public void ParseConnect_WithAuth_ShouldParseCorrectly() + { + var buf = BuildConnectPacket(username: "user1", password: "pass1"); + var r = new MqttReader(); + r.Reset(buf[2..]); + + var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r); + err.ShouldBeNull(); + rc.ShouldBe(MqttConnAckRc.Accepted); + cp!.Username.ShouldBe("user1"); + cp.Password.ShouldNotBeNull(); + Encoding.UTF8.GetString(cp.Password!).ShouldBe("pass1"); + } + + [Fact] + public void ParseConnect_WrongProtocolName_ShouldRejectWithCode() + { + // Build a packet with wrong protocol name. + var buf = new List(); + var wrong = Encoding.UTF8.GetBytes("MQIsdp"); + buf.Add((byte)(wrong.Length >> 8)); + buf.Add((byte)(wrong.Length & 0xFF)); + buf.AddRange(wrong); + buf.Add(0x03); // level + buf.Add(0x02); // clean session + buf.AddRange(new byte[] { 0x00, 0x3C }); // keepalive=60 + buf.AddRange(new byte[] { 0x00, 0x02, 0x41, 0x42 }); // clientId="AB" + + var r = new MqttReader(); + r.Reset(buf.ToArray()); + + var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r); + rc.ShouldBe(MqttConnAckRc.UnacceptableProtocol); + err.ShouldNotBeNull(); + } + + [Fact] + public void ParseConnect_EmptyClientIdWithoutCleanSession_ShouldReject() + { + var buf = BuildConnectPacket(clientId: "", cleanSession: false); + var r = new MqttReader(); + r.Reset(buf[2..]); + + var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r); + rc.ShouldBe(MqttConnAckRc.IdentifierRejected); + err.ShouldNotBeNull(); + } + + [Fact] + public void ParseConnect_EmptyClientIdWithCleanSession_ShouldAutoGenerate() + { + var buf = BuildConnectPacket(clientId: "", cleanSession: true); + var r = new MqttReader(); + r.Reset(buf[2..]); + + var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r); + err.ShouldBeNull(); + rc.ShouldBe(MqttConnAckRc.Accepted); + cp!.ClientId.ShouldNotBeEmpty(); // Auto-generated + cp.ClientId.Length.ShouldBe(32); // GUID "N" format + } + + [Fact] + public void ParseConnect_ReservedBitSet_ShouldError() + { + // Build manually with reserved bit set. + var payload = new List(); + var proto = Encoding.UTF8.GetBytes("MQTT"); + payload.Add(0); payload.Add(4); + payload.AddRange(proto); + payload.Add(0x04); // level + payload.Add(0x03); // clean session + reserved bit! + payload.AddRange(new byte[] { 0x00, 0x00 }); // keepalive + payload.AddRange(new byte[] { 0x00, 0x02, 0x41, 0x42 }); // clientId + + var r = new MqttReader(); + r.Reset(payload.ToArray()); + + var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r); + err.ShouldNotBeNull(); + err.Message.ShouldContain("reserved bit"); + } + + // ========================================================================= + // ProcessConnect + CONNACK tests + // ========================================================================= + + [Fact] + public void ProcessConnect_ShouldSetFlagsAndSendConnAck() + { + var c = CreateMqttClient(); + var cp = new MqttConnectProto + { + ClientId = "test-123", + CleanSession = true, + KeepAlive = 30, + }; + + var err = MqttPacketHandlers.ProcessConnect(c, cp); + err.ShouldBeNull(); + + // Verify state. + c.Mqtt!.ClientId.ShouldBe("test-123"); + c.Mqtt.CleanSession.ShouldBeTrue(); + (c.Flags & ClientFlags.ConnectReceived).ShouldNotBe((ClientFlags)0); + + // Verify CONNACK was written. + var ms = (MemoryStream)typeof(ClientConnection) + .GetField("_nc", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)! + .GetValue(c)!; + var data = ms.ToArray(); + data.Length.ShouldBe(4); + data[0].ShouldBe(MqttPacket.ConnectAck); + data[1].ShouldBe((byte)0x02); + data[2].ShouldBe((byte)0x00); // No session present + data[3].ShouldBe(MqttConnAckRc.Accepted); + } + + // ========================================================================= + // Full CONNECT via parser integration + // ========================================================================= + + [Fact] + public void Parser_ConnectPacket_ShouldParseAndSendConnAck() + { + var c = CreateMqttClient(); + var buf = BuildConnectPacket(clientId: "mqtt-parser-test"); + + var err = MqttParser.Parse(c, buf, buf.Length); + err.ShouldBeNull(); + + // Verify connected. + (c.Flags & ClientFlags.ConnectReceived).ShouldNotBe((ClientFlags)0); + c.Mqtt!.ClientId.ShouldBe("mqtt-parser-test"); + + // Verify CONNACK written. + var ms = (MemoryStream)typeof(ClientConnection) + .GetField("_nc", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)! + .GetValue(c)!; + var data = ms.ToArray(); + data.Length.ShouldBe(4); + data[0].ShouldBe(MqttPacket.ConnectAck); + data[3].ShouldBe(MqttConnAckRc.Accepted); + } + + [Fact] + public void Parser_ConnectThenPing_ShouldSucceed() + { + var c = CreateMqttClient(); + var connectBuf = BuildConnectPacket(); + var pingBuf = new byte[] { MqttPacket.Ping, 0x00 }; + + // Concatenate CONNECT + PING into one buffer. + var buf = new byte[connectBuf.Length + pingBuf.Length]; + Buffer.BlockCopy(connectBuf, 0, buf, 0, connectBuf.Length); + Buffer.BlockCopy(pingBuf, 0, buf, connectBuf.Length, pingBuf.Length); + + var err = MqttParser.Parse(c, buf, buf.Length); + err.ShouldBeNull(); + + // Verify CONNACK + PINGRESP written. + var ms = (MemoryStream)typeof(ClientConnection) + .GetField("_nc", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)! + .GetValue(c)!; + var data = ms.ToArray(); + data.Length.ShouldBe(6); // 4 (CONNACK) + 2 (PINGRESP) + data[0].ShouldBe(MqttPacket.ConnectAck); + data[4].ShouldBe(MqttPacket.PingResp); + } + + // ========================================================================= + // DISCONNECT tests + // ========================================================================= + + [Fact] + public void Parser_Disconnect_ShouldClearWillAndClose() + { + var c = CreateMqttClient(); + + // First, process a CONNECT with a will. + var connectBuf = BuildConnectPacket( + willTopic: "test/will", + willMessage: Encoding.UTF8.GetBytes("bye")); + + var err = MqttParser.Parse(c, connectBuf, connectBuf.Length); + err.ShouldBeNull(); + c.Mqtt!.Will.ShouldNotBeNull(); + + // Now send DISCONNECT. + var disconnectBuf = new byte[] { MqttPacket.Disconnect, 0x00 }; + err = MqttParser.Parse(c, disconnectBuf, disconnectBuf.Length); + err.ShouldBeNull(); + + // Will should be cleared. + c.Mqtt.Will.ShouldBeNull(); + } +} diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Mqtt/MqttParserTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Mqtt/MqttParserTests.cs index 702e92c..600b7d1 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Mqtt/MqttParserTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Mqtt/MqttParserTests.cs @@ -56,16 +56,39 @@ public sealed class MqttParserTests } [Fact] - public void Parse_ConnectFirst_ShouldNotRejectAsNonConnect() + public void Parse_ConnectFirst_ShouldAcceptConnect() { var c = CreateMqttClient(); - // CONNECT packet (minimal): type=0x10, remaining len=0 - // This will hit the "not yet implemented" but NOT the "first packet" error. - var buf = new byte[] { MqttPacket.Connect, 0x00 }; - var err = MqttParser.Parse(c, buf, buf.Length); - err.ShouldNotBeNull(); // Will be NotImplementedException - err.ShouldBeOfType(); - err.Message.ShouldContain("CONNECT not yet implemented"); + // Use a MemoryStream so CONNACK can be written. + typeof(ClientConnection) + .GetField("_nc", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)! + .SetValue(c, new MemoryStream()); + + // Build a valid CONNECT packet. + var payload = new List(); + payload.AddRange(new byte[] { 0x00, 0x04 }); // protocol name length + payload.AddRange(System.Text.Encoding.UTF8.GetBytes("MQTT")); + payload.Add(0x04); // level + payload.Add(0x02); // flags: clean session + payload.AddRange(new byte[] { 0x00, 0x3C }); // keep alive = 60 + payload.AddRange(new byte[] { 0x00, 0x04 }); // client id length + payload.AddRange(System.Text.Encoding.UTF8.GetBytes("test")); + + var buf = new List { MqttPacket.Connect }; + // Remaining length + var remLen = payload.Count; + do + { + var b = (byte)(remLen & 0x7F); + remLen >>= 7; + if (remLen > 0) b |= 0x80; + buf.Add(b); + } while (remLen > 0); + buf.AddRange(payload); + + var err = MqttParser.Parse(c, buf.ToArray(), buf.Count); + err.ShouldBeNull("CONNECT should be accepted, not rejected as non-CONNECT"); + (c.Flags & ClientFlags.ConnectReceived).ShouldNotBe((ClientFlags)0); } [Fact] diff --git a/reports/current.md b/reports/current.md index a8848f6..57aeddc 100644 --- a/reports/current.md +++ b/reports/current.md @@ -1,6 +1,6 @@ # NATS .NET Porting Status Report -Generated: 2026-03-01 20:41:46 UTC +Generated: 2026-03-01 20:48:23 UTC ## Modules (12 total)