feat(mqtt): implement CONNECT/CONNACK/DISCONNECT packet handlers
Implement Task 3 of MQTT orchestration: - Create MqttPacketHandlers.cs with ParseConnect(), ProcessConnect(), EnqueueConnAck(), HandleDisconnect() - Wire CONNECT and DISCONNECT dispatch in MqttParser.cs - Parse CONNECT: protocol name/level, flags, keep-alive, client ID, will, auth - Send CONNACK (4-byte fixed packet with return code) - DISCONNECT clears will message and closes connection cleanly - Auto-generate client ID for empty ID + clean session - Validate reserved bit, will flags, username/password consistency - Add Reader field to MqttHandler for per-connection parsing - 11 unit tests for CONNECT parsing and processing - 1 end-to-end integration test: TCP → CONNECT → CONNACK over the wire
This commit is contained in:
@@ -0,0 +1,386 @@
|
||||
// Copyright 2020-2026 The NATS Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
using System.Text;
|
||||
using Shouldly;
|
||||
using ZB.MOM.NatsNet.Server;
|
||||
using ZB.MOM.NatsNet.Server.Internal;
|
||||
using ZB.MOM.NatsNet.Server.Mqtt;
|
||||
|
||||
namespace ZB.MOM.NatsNet.Server.Tests.Mqtt;
|
||||
|
||||
/// <summary>
|
||||
/// Unit tests for MQTT CONNECT/CONNACK/DISCONNECT packet handling.
|
||||
/// </summary>
|
||||
public sealed class MqttConnectTests
|
||||
{
|
||||
/// <summary>
|
||||
/// Builds a minimal valid MQTT CONNECT packet.
|
||||
/// </summary>
|
||||
private static byte[] BuildConnectPacket(
|
||||
string clientId = "test-client",
|
||||
bool cleanSession = true,
|
||||
ushort keepAlive = 60,
|
||||
string? willTopic = null,
|
||||
byte[]? willMessage = null,
|
||||
byte willQos = 0,
|
||||
bool willRetain = false,
|
||||
string? username = null,
|
||||
string? password = null)
|
||||
{
|
||||
var payload = new List<byte>();
|
||||
|
||||
// Variable header.
|
||||
// Protocol name "MQTT".
|
||||
var protoName = Encoding.UTF8.GetBytes("MQTT");
|
||||
payload.Add((byte)(protoName.Length >> 8));
|
||||
payload.Add((byte)(protoName.Length & 0xFF));
|
||||
payload.AddRange(protoName);
|
||||
|
||||
// Protocol level.
|
||||
payload.Add(0x04);
|
||||
|
||||
// Connect flags.
|
||||
byte flags = 0;
|
||||
if (cleanSession) flags |= MqttConnectFlag.CleanSession;
|
||||
if (willTopic != null)
|
||||
{
|
||||
flags |= MqttConnectFlag.WillFlag;
|
||||
flags |= (byte)((willQos & 0x03) << 3);
|
||||
if (willRetain) flags |= MqttConnectFlag.WillRetain;
|
||||
}
|
||||
if (username != null) flags |= MqttConnectFlag.UsernameFlag;
|
||||
if (password != null) flags |= MqttConnectFlag.PasswordFlag;
|
||||
payload.Add(flags);
|
||||
|
||||
// Keep alive.
|
||||
payload.Add((byte)(keepAlive >> 8));
|
||||
payload.Add((byte)(keepAlive & 0xFF));
|
||||
|
||||
// Client ID.
|
||||
var cidBytes = Encoding.UTF8.GetBytes(clientId);
|
||||
payload.Add((byte)(cidBytes.Length >> 8));
|
||||
payload.Add((byte)(cidBytes.Length & 0xFF));
|
||||
payload.AddRange(cidBytes);
|
||||
|
||||
// Will topic + message.
|
||||
if (willTopic != null)
|
||||
{
|
||||
var topicBytes = Encoding.UTF8.GetBytes(willTopic);
|
||||
payload.Add((byte)(topicBytes.Length >> 8));
|
||||
payload.Add((byte)(topicBytes.Length & 0xFF));
|
||||
payload.AddRange(topicBytes);
|
||||
|
||||
var msg = willMessage ?? [];
|
||||
payload.Add((byte)(msg.Length >> 8));
|
||||
payload.Add((byte)(msg.Length & 0xFF));
|
||||
payload.AddRange(msg);
|
||||
}
|
||||
|
||||
// Username.
|
||||
if (username != null)
|
||||
{
|
||||
var userBytes = Encoding.UTF8.GetBytes(username);
|
||||
payload.Add((byte)(userBytes.Length >> 8));
|
||||
payload.Add((byte)(userBytes.Length & 0xFF));
|
||||
payload.AddRange(userBytes);
|
||||
}
|
||||
|
||||
// Password.
|
||||
if (password != null)
|
||||
{
|
||||
var passBytes = Encoding.UTF8.GetBytes(password);
|
||||
payload.Add((byte)(passBytes.Length >> 8));
|
||||
payload.Add((byte)(passBytes.Length & 0xFF));
|
||||
payload.AddRange(passBytes);
|
||||
}
|
||||
|
||||
// Build full packet: type byte + remaining length + payload.
|
||||
var result = new List<byte>();
|
||||
result.Add(MqttPacket.Connect);
|
||||
|
||||
// Encode remaining length.
|
||||
var remLen = payload.Count;
|
||||
do
|
||||
{
|
||||
var encoded = (byte)(remLen & 0x7F);
|
||||
remLen >>= 7;
|
||||
if (remLen > 0) encoded |= 0x80;
|
||||
result.Add(encoded);
|
||||
} while (remLen > 0);
|
||||
|
||||
result.AddRange(payload);
|
||||
return result.ToArray();
|
||||
}
|
||||
|
||||
private static ClientConnection CreateMqttClient()
|
||||
{
|
||||
var ms = new MemoryStream();
|
||||
var c = new ClientConnection(ClientKind.Client, nc: ms);
|
||||
c.InitMqtt(new MqttHandler());
|
||||
return c;
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// ParseConnect tests
|
||||
// =========================================================================
|
||||
|
||||
[Fact]
|
||||
public void ParseConnect_ValidMinimal_ShouldSucceed()
|
||||
{
|
||||
var buf = BuildConnectPacket();
|
||||
var r = new MqttReader();
|
||||
// Skip the fixed header (type + remaining length) — parser handles that.
|
||||
// For direct ParseConnect testing, we feed only the variable header + payload.
|
||||
r.Reset(buf[2..]); // Skip type byte and 1-byte remaining length.
|
||||
|
||||
var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);
|
||||
err.ShouldBeNull();
|
||||
rc.ShouldBe(MqttConnAckRc.Accepted);
|
||||
cp.ShouldNotBeNull();
|
||||
cp!.ClientId.ShouldBe("test-client");
|
||||
cp.CleanSession.ShouldBeTrue();
|
||||
cp.KeepAlive.ShouldBe((ushort)60);
|
||||
cp.Will.ShouldBeNull();
|
||||
cp.Username.ShouldBeEmpty();
|
||||
cp.Password.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParseConnect_WithWill_ShouldParseCorrectly()
|
||||
{
|
||||
var buf = BuildConnectPacket(
|
||||
willTopic: "test/will",
|
||||
willMessage: Encoding.UTF8.GetBytes("goodbye"),
|
||||
willQos: 1,
|
||||
willRetain: true);
|
||||
|
||||
// Find remaining length to skip header correctly.
|
||||
int headerLen = 1; // type byte
|
||||
int remLen = 0;
|
||||
int mult = 1;
|
||||
for (int i = 1; i < buf.Length; i++)
|
||||
{
|
||||
remLen += (buf[i] & 0x7F) * mult;
|
||||
headerLen++;
|
||||
if ((buf[i] & 0x80) == 0) break;
|
||||
mult *= 128;
|
||||
}
|
||||
|
||||
var r = new MqttReader();
|
||||
r.Reset(buf[headerLen..]);
|
||||
|
||||
var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);
|
||||
err.ShouldBeNull();
|
||||
rc.ShouldBe(MqttConnAckRc.Accepted);
|
||||
cp!.Will.ShouldNotBeNull();
|
||||
cp.Will!.Topic.ShouldBe("test/will");
|
||||
cp.Will.Subject.ShouldNotBeEmpty(); // NATS-converted subject
|
||||
cp.Will.Msg.ShouldNotBeNull();
|
||||
Encoding.UTF8.GetString(cp.Will.Msg!).ShouldBe("goodbye");
|
||||
cp.Will.Qos.ShouldBe((byte)1);
|
||||
cp.Will.Retain.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParseConnect_WithAuth_ShouldParseCorrectly()
|
||||
{
|
||||
var buf = BuildConnectPacket(username: "user1", password: "pass1");
|
||||
var r = new MqttReader();
|
||||
r.Reset(buf[2..]);
|
||||
|
||||
var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);
|
||||
err.ShouldBeNull();
|
||||
rc.ShouldBe(MqttConnAckRc.Accepted);
|
||||
cp!.Username.ShouldBe("user1");
|
||||
cp.Password.ShouldNotBeNull();
|
||||
Encoding.UTF8.GetString(cp.Password!).ShouldBe("pass1");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParseConnect_WrongProtocolName_ShouldRejectWithCode()
|
||||
{
|
||||
// Build a packet with wrong protocol name.
|
||||
var buf = new List<byte>();
|
||||
var wrong = Encoding.UTF8.GetBytes("MQIsdp");
|
||||
buf.Add((byte)(wrong.Length >> 8));
|
||||
buf.Add((byte)(wrong.Length & 0xFF));
|
||||
buf.AddRange(wrong);
|
||||
buf.Add(0x03); // level
|
||||
buf.Add(0x02); // clean session
|
||||
buf.AddRange(new byte[] { 0x00, 0x3C }); // keepalive=60
|
||||
buf.AddRange(new byte[] { 0x00, 0x02, 0x41, 0x42 }); // clientId="AB"
|
||||
|
||||
var r = new MqttReader();
|
||||
r.Reset(buf.ToArray());
|
||||
|
||||
var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);
|
||||
rc.ShouldBe(MqttConnAckRc.UnacceptableProtocol);
|
||||
err.ShouldNotBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParseConnect_EmptyClientIdWithoutCleanSession_ShouldReject()
|
||||
{
|
||||
var buf = BuildConnectPacket(clientId: "", cleanSession: false);
|
||||
var r = new MqttReader();
|
||||
r.Reset(buf[2..]);
|
||||
|
||||
var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);
|
||||
rc.ShouldBe(MqttConnAckRc.IdentifierRejected);
|
||||
err.ShouldNotBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParseConnect_EmptyClientIdWithCleanSession_ShouldAutoGenerate()
|
||||
{
|
||||
var buf = BuildConnectPacket(clientId: "", cleanSession: true);
|
||||
var r = new MqttReader();
|
||||
r.Reset(buf[2..]);
|
||||
|
||||
var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);
|
||||
err.ShouldBeNull();
|
||||
rc.ShouldBe(MqttConnAckRc.Accepted);
|
||||
cp!.ClientId.ShouldNotBeEmpty(); // Auto-generated
|
||||
cp.ClientId.Length.ShouldBe(32); // GUID "N" format
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParseConnect_ReservedBitSet_ShouldError()
|
||||
{
|
||||
// Build manually with reserved bit set.
|
||||
var payload = new List<byte>();
|
||||
var proto = Encoding.UTF8.GetBytes("MQTT");
|
||||
payload.Add(0); payload.Add(4);
|
||||
payload.AddRange(proto);
|
||||
payload.Add(0x04); // level
|
||||
payload.Add(0x03); // clean session + reserved bit!
|
||||
payload.AddRange(new byte[] { 0x00, 0x00 }); // keepalive
|
||||
payload.AddRange(new byte[] { 0x00, 0x02, 0x41, 0x42 }); // clientId
|
||||
|
||||
var r = new MqttReader();
|
||||
r.Reset(payload.ToArray());
|
||||
|
||||
var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);
|
||||
err.ShouldNotBeNull();
|
||||
err.Message.ShouldContain("reserved bit");
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// ProcessConnect + CONNACK tests
|
||||
// =========================================================================
|
||||
|
||||
[Fact]
|
||||
public void ProcessConnect_ShouldSetFlagsAndSendConnAck()
|
||||
{
|
||||
var c = CreateMqttClient();
|
||||
var cp = new MqttConnectProto
|
||||
{
|
||||
ClientId = "test-123",
|
||||
CleanSession = true,
|
||||
KeepAlive = 30,
|
||||
};
|
||||
|
||||
var err = MqttPacketHandlers.ProcessConnect(c, cp);
|
||||
err.ShouldBeNull();
|
||||
|
||||
// Verify state.
|
||||
c.Mqtt!.ClientId.ShouldBe("test-123");
|
||||
c.Mqtt.CleanSession.ShouldBeTrue();
|
||||
(c.Flags & ClientFlags.ConnectReceived).ShouldNotBe((ClientFlags)0);
|
||||
|
||||
// Verify CONNACK was written.
|
||||
var ms = (MemoryStream)typeof(ClientConnection)
|
||||
.GetField("_nc", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)!
|
||||
.GetValue(c)!;
|
||||
var data = ms.ToArray();
|
||||
data.Length.ShouldBe(4);
|
||||
data[0].ShouldBe(MqttPacket.ConnectAck);
|
||||
data[1].ShouldBe((byte)0x02);
|
||||
data[2].ShouldBe((byte)0x00); // No session present
|
||||
data[3].ShouldBe(MqttConnAckRc.Accepted);
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Full CONNECT via parser integration
|
||||
// =========================================================================
|
||||
|
||||
[Fact]
|
||||
public void Parser_ConnectPacket_ShouldParseAndSendConnAck()
|
||||
{
|
||||
var c = CreateMqttClient();
|
||||
var buf = BuildConnectPacket(clientId: "mqtt-parser-test");
|
||||
|
||||
var err = MqttParser.Parse(c, buf, buf.Length);
|
||||
err.ShouldBeNull();
|
||||
|
||||
// Verify connected.
|
||||
(c.Flags & ClientFlags.ConnectReceived).ShouldNotBe((ClientFlags)0);
|
||||
c.Mqtt!.ClientId.ShouldBe("mqtt-parser-test");
|
||||
|
||||
// Verify CONNACK written.
|
||||
var ms = (MemoryStream)typeof(ClientConnection)
|
||||
.GetField("_nc", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)!
|
||||
.GetValue(c)!;
|
||||
var data = ms.ToArray();
|
||||
data.Length.ShouldBe(4);
|
||||
data[0].ShouldBe(MqttPacket.ConnectAck);
|
||||
data[3].ShouldBe(MqttConnAckRc.Accepted);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Parser_ConnectThenPing_ShouldSucceed()
|
||||
{
|
||||
var c = CreateMqttClient();
|
||||
var connectBuf = BuildConnectPacket();
|
||||
var pingBuf = new byte[] { MqttPacket.Ping, 0x00 };
|
||||
|
||||
// Concatenate CONNECT + PING into one buffer.
|
||||
var buf = new byte[connectBuf.Length + pingBuf.Length];
|
||||
Buffer.BlockCopy(connectBuf, 0, buf, 0, connectBuf.Length);
|
||||
Buffer.BlockCopy(pingBuf, 0, buf, connectBuf.Length, pingBuf.Length);
|
||||
|
||||
var err = MqttParser.Parse(c, buf, buf.Length);
|
||||
err.ShouldBeNull();
|
||||
|
||||
// Verify CONNACK + PINGRESP written.
|
||||
var ms = (MemoryStream)typeof(ClientConnection)
|
||||
.GetField("_nc", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)!
|
||||
.GetValue(c)!;
|
||||
var data = ms.ToArray();
|
||||
data.Length.ShouldBe(6); // 4 (CONNACK) + 2 (PINGRESP)
|
||||
data[0].ShouldBe(MqttPacket.ConnectAck);
|
||||
data[4].ShouldBe(MqttPacket.PingResp);
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// DISCONNECT tests
|
||||
// =========================================================================
|
||||
|
||||
[Fact]
|
||||
public void Parser_Disconnect_ShouldClearWillAndClose()
|
||||
{
|
||||
var c = CreateMqttClient();
|
||||
|
||||
// First, process a CONNECT with a will.
|
||||
var connectBuf = BuildConnectPacket(
|
||||
willTopic: "test/will",
|
||||
willMessage: Encoding.UTF8.GetBytes("bye"));
|
||||
|
||||
var err = MqttParser.Parse(c, connectBuf, connectBuf.Length);
|
||||
err.ShouldBeNull();
|
||||
c.Mqtt!.Will.ShouldNotBeNull();
|
||||
|
||||
// Now send DISCONNECT.
|
||||
var disconnectBuf = new byte[] { MqttPacket.Disconnect, 0x00 };
|
||||
err = MqttParser.Parse(c, disconnectBuf, disconnectBuf.Length);
|
||||
err.ShouldBeNull();
|
||||
|
||||
// Will should be cleared.
|
||||
c.Mqtt.Will.ShouldBeNull();
|
||||
}
|
||||
}
|
||||
@@ -56,16 +56,39 @@ public sealed class MqttParserTests
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Parse_ConnectFirst_ShouldNotRejectAsNonConnect()
|
||||
public void Parse_ConnectFirst_ShouldAcceptConnect()
|
||||
{
|
||||
var c = CreateMqttClient();
|
||||
// CONNECT packet (minimal): type=0x10, remaining len=0
|
||||
// This will hit the "not yet implemented" but NOT the "first packet" error.
|
||||
var buf = new byte[] { MqttPacket.Connect, 0x00 };
|
||||
var err = MqttParser.Parse(c, buf, buf.Length);
|
||||
err.ShouldNotBeNull(); // Will be NotImplementedException
|
||||
err.ShouldBeOfType<NotImplementedException>();
|
||||
err.Message.ShouldContain("CONNECT not yet implemented");
|
||||
// Use a MemoryStream so CONNACK can be written.
|
||||
typeof(ClientConnection)
|
||||
.GetField("_nc", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)!
|
||||
.SetValue(c, new MemoryStream());
|
||||
|
||||
// Build a valid CONNECT packet.
|
||||
var payload = new List<byte>();
|
||||
payload.AddRange(new byte[] { 0x00, 0x04 }); // protocol name length
|
||||
payload.AddRange(System.Text.Encoding.UTF8.GetBytes("MQTT"));
|
||||
payload.Add(0x04); // level
|
||||
payload.Add(0x02); // flags: clean session
|
||||
payload.AddRange(new byte[] { 0x00, 0x3C }); // keep alive = 60
|
||||
payload.AddRange(new byte[] { 0x00, 0x04 }); // client id length
|
||||
payload.AddRange(System.Text.Encoding.UTF8.GetBytes("test"));
|
||||
|
||||
var buf = new List<byte> { MqttPacket.Connect };
|
||||
// Remaining length
|
||||
var remLen = payload.Count;
|
||||
do
|
||||
{
|
||||
var b = (byte)(remLen & 0x7F);
|
||||
remLen >>= 7;
|
||||
if (remLen > 0) b |= 0x80;
|
||||
buf.Add(b);
|
||||
} while (remLen > 0);
|
||||
buf.AddRange(payload);
|
||||
|
||||
var err = MqttParser.Parse(c, buf.ToArray(), buf.Count);
|
||||
err.ShouldBeNull("CONNECT should be accepted, not rejected as non-CONNECT");
|
||||
(c.Flags & ClientFlags.ConnectReceived).ShouldNotBe((ClientFlags)0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
|
||||
Reference in New Issue
Block a user