Files
natsnet/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Mqtt/MqttConnectTests.cs
Joseph Doherty 95cf20b00b feat(mqtt): implement CONNECT/CONNACK/DISCONNECT packet handlers
Implement Task 3 of MQTT orchestration:
- Create MqttPacketHandlers.cs with ParseConnect(), ProcessConnect(), EnqueueConnAck(), HandleDisconnect()
- Wire CONNECT and DISCONNECT dispatch in MqttParser.cs
- Parse CONNECT: protocol name/level, flags, keep-alive, client ID, will, auth
- Send CONNACK (4-byte fixed packet with return code)
- DISCONNECT clears will message and closes connection cleanly
- Auto-generate client ID for empty ID + clean session
- Validate reserved bit, will flags, username/password consistency
- Add Reader field to MqttHandler for per-connection parsing
- 11 unit tests for CONNECT parsing and processing
- 1 end-to-end integration test: TCP → CONNECT → CONNACK over the wire
2026-03-01 15:48:22 -05:00

387 lines
13 KiB
C#

// Copyright 2020-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
using System.Text;
using Shouldly;
using ZB.MOM.NatsNet.Server;
using ZB.MOM.NatsNet.Server.Internal;
using ZB.MOM.NatsNet.Server.Mqtt;
namespace ZB.MOM.NatsNet.Server.Tests.Mqtt;
/// <summary>
/// Unit tests for MQTT CONNECT/CONNACK/DISCONNECT packet handling.
/// </summary>
/// <remarks>
/// <c>ParseConnect_*</c> tests feed the CONNECT parser only the variable header and payload
/// (the fixed header is removed with <c>StripFixedHeader</c>); <c>Parser_*</c> tests push whole
/// packets through <c>MqttParser.Parse</c> and inspect the bytes written back to the client.
/// </remarks>
public sealed class MqttConnectTests
{
    // =========================================================================
    // Packet-building and inspection helpers
    // =========================================================================

    /// <summary>
    /// Appends an MQTT length-prefixed field: a 2-byte big-endian length followed by the bytes.
    /// </summary>
    private static void AppendLengthPrefixed(List<byte> buffer, byte[] bytes)
    {
        buffer.Add((byte)(bytes.Length >> 8));
        buffer.Add((byte)(bytes.Length & 0xFF));
        buffer.AddRange(bytes);
    }

    /// <summary>Appends a UTF-8 string as an MQTT length-prefixed field.</summary>
    private static void AppendMqttString(List<byte> buffer, string value) =>
        AppendLengthPrefixed(buffer, Encoding.UTF8.GetBytes(value));

    /// <summary>
    /// Builds a minimal valid MQTT CONNECT packet (protocol "MQTT", level 4), fixed header included.
    /// </summary>
    /// <param name="clientId">Client identifier; an empty string is encoded as a zero-length field.</param>
    /// <param name="cleanSession">Sets the Clean Session connect flag.</param>
    /// <param name="keepAlive">Keep-alive value, written big-endian into the variable header.</param>
    /// <param name="willTopic">When non-null, sets the Will flag and encodes the will topic and message.</param>
    /// <param name="willMessage">Will payload; encoded as empty when a will topic is given and this is null.</param>
    /// <param name="willQos">Will QoS (only the low 2 bits are used); ignored without a will topic.</param>
    /// <param name="willRetain">Sets the Will Retain flag; ignored without a will topic.</param>
    /// <param name="username">When non-null, sets the Username flag and encodes the value.</param>
    /// <param name="password">When non-null, sets the Password flag and encodes the value.</param>
    /// <returns>The encoded packet bytes.</returns>
    private static byte[] BuildConnectPacket(
        string clientId = "test-client",
        bool cleanSession = true,
        ushort keepAlive = 60,
        string? willTopic = null,
        byte[]? willMessage = null,
        byte willQos = 0,
        bool willRetain = false,
        string? username = null,
        string? password = null)
    {
        var body = new List<byte>();

        // Variable header: protocol name, protocol level, connect flags, keep-alive.
        AppendMqttString(body, "MQTT");
        body.Add(0x04);

        byte flags = 0;
        if (cleanSession) flags |= MqttConnectFlag.CleanSession;
        if (willTopic != null)
        {
            flags |= MqttConnectFlag.WillFlag;
            flags |= (byte)((willQos & 0x03) << 3); // Will QoS occupies bits 3-4 of the flags byte.
            if (willRetain) flags |= MqttConnectFlag.WillRetain;
        }
        if (username != null) flags |= MqttConnectFlag.UsernameFlag;
        if (password != null) flags |= MqttConnectFlag.PasswordFlag;
        body.Add(flags);

        body.Add((byte)(keepAlive >> 8));
        body.Add((byte)(keepAlive & 0xFF));

        // Payload, in wire order: client ID, will topic + message, username, password.
        AppendMqttString(body, clientId);
        if (willTopic != null)
        {
            AppendMqttString(body, willTopic);
            AppendLengthPrefixed(body, willMessage ?? []);
        }
        if (username != null) AppendMqttString(body, username);
        if (password != null) AppendMqttString(body, password);

        // Fixed header: packet type, then the remaining length as a variable-length integer
        // (7 bits per byte; the high bit signals that another length byte follows).
        var packet = new List<byte> { MqttPacket.Connect };
        var remaining = body.Count;
        do
        {
            var digit = (byte)(remaining & 0x7F);
            remaining >>= 7;
            if (remaining > 0) digit |= 0x80;
            packet.Add(digit);
        } while (remaining > 0);

        packet.AddRange(body);
        return packet.ToArray();
    }

    /// <summary>
    /// Returns everything after the fixed header of an encoded packet, i.e. after the
    /// packet-type byte and the variable-length remaining-length field.
    /// </summary>
    private static byte[] StripFixedHeader(byte[] packet)
    {
        // Index 0 is the packet type. Remaining-length bytes start at index 1 and continue
        // until a byte with the continuation bit (0x80) clear.
        var lastLengthByte = 1;
        while ((packet[lastLengthByte] & 0x80) != 0)
        {
            lastLengthByte++;
        }
        return packet[(lastLengthByte + 1)..];
    }

    /// <summary>
    /// Creates a client-kind connection backed by an in-memory stream, with MQTT handling initialised.
    /// </summary>
    private static ClientConnection CreateMqttClient()
    {
        var ms = new MemoryStream();
        var c = new ClientConnection(ClientKind.Client, nc: ms);
        c.InitMqtt(new MqttHandler());
        return c;
    }

    /// <summary>
    /// Returns the bytes written to the connection's underlying stream.
    /// </summary>
    /// <remarks>
    /// Reads the private <c>_nc</c> field via reflection. The assertions make a rename or a
    /// changed stream type fail with an explicit message rather than a NullReferenceException.
    /// </remarks>
    private static byte[] GetWrittenBytes(ClientConnection c)
    {
        var field = typeof(ClientConnection).GetField(
            "_nc",
            System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
        field.ShouldNotBeNull("ClientConnection has no private '_nc' field; update GetWrittenBytes.");
        var stream = field!.GetValue(c).ShouldBeOfType<MemoryStream>();
        return stream.ToArray();
    }

    // =========================================================================
    // ParseConnect tests
    // =========================================================================

    [Fact]
    public void ParseConnect_ValidMinimal_ShouldSucceed()
    {
        var r = new MqttReader();
        // ParseConnect consumes only the variable header + payload; the parser strips the fixed header.
        r.Reset(StripFixedHeader(BuildConnectPacket()));

        var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);

        err.ShouldBeNull();
        rc.ShouldBe(MqttConnAckRc.Accepted);
        cp.ShouldNotBeNull();
        cp!.ClientId.ShouldBe("test-client");
        cp.CleanSession.ShouldBeTrue();
        cp.KeepAlive.ShouldBe((ushort)60);
        cp.Will.ShouldBeNull();
        cp.Username.ShouldBeEmpty();
        cp.Password.ShouldBeNull();
    }

    [Fact]
    public void ParseConnect_WithWill_ShouldParseCorrectly()
    {
        var buf = BuildConnectPacket(
            willTopic: "test/will",
            willMessage: Encoding.UTF8.GetBytes("goodbye"),
            willQos: 1,
            willRetain: true);
        var r = new MqttReader();
        r.Reset(StripFixedHeader(buf));

        var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);

        err.ShouldBeNull();
        rc.ShouldBe(MqttConnAckRc.Accepted);
        cp.ShouldNotBeNull();
        cp!.Will.ShouldNotBeNull();
        cp.Will!.Topic.ShouldBe("test/will");
        cp.Will.Subject.ShouldNotBeEmpty(); // NATS-converted subject
        cp.Will.Msg.ShouldNotBeNull();
        Encoding.UTF8.GetString(cp.Will.Msg!).ShouldBe("goodbye");
        cp.Will.Qos.ShouldBe((byte)1);
        cp.Will.Retain.ShouldBeTrue();
    }

    [Fact]
    public void ParseConnect_WithAuth_ShouldParseCorrectly()
    {
        var r = new MqttReader();
        r.Reset(StripFixedHeader(BuildConnectPacket(username: "user1", password: "pass1")));

        var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);

        err.ShouldBeNull();
        rc.ShouldBe(MqttConnAckRc.Accepted);
        cp.ShouldNotBeNull();
        cp!.Username.ShouldBe("user1");
        cp.Password.ShouldNotBeNull();
        Encoding.UTF8.GetString(cp.Password!).ShouldBe("pass1");
    }

    [Fact]
    public void ParseConnect_WrongProtocolName_ShouldRejectWithCode()
    {
        // Variable header + payload using the MQTT 3.1 protocol name "MQIsdp" at level 3.
        var body = new List<byte>();
        AppendMqttString(body, "MQIsdp");
        body.Add(0x03); // level
        body.Add(0x02); // clean session
        body.AddRange(new byte[] { 0x00, 0x3C }); // keepalive=60
        AppendMqttString(body, "AB"); // clientId
        var r = new MqttReader();
        r.Reset(body.ToArray());

        var (rc, _, err) = MqttPacketHandlers.ParseConnect(r);

        rc.ShouldBe(MqttConnAckRc.UnacceptableProtocol);
        err.ShouldNotBeNull();
    }

    [Fact]
    public void ParseConnect_EmptyClientIdWithoutCleanSession_ShouldReject()
    {
        var r = new MqttReader();
        r.Reset(StripFixedHeader(BuildConnectPacket(clientId: "", cleanSession: false)));

        var (rc, _, err) = MqttPacketHandlers.ParseConnect(r);

        rc.ShouldBe(MqttConnAckRc.IdentifierRejected);
        err.ShouldNotBeNull();
    }

    [Fact]
    public void ParseConnect_EmptyClientIdWithCleanSession_ShouldAutoGenerate()
    {
        var r = new MqttReader();
        r.Reset(StripFixedHeader(BuildConnectPacket(clientId: "", cleanSession: true)));

        var (rc, cp, err) = MqttPacketHandlers.ParseConnect(r);

        err.ShouldBeNull();
        rc.ShouldBe(MqttConnAckRc.Accepted);
        cp.ShouldNotBeNull();
        cp!.ClientId.ShouldNotBeEmpty(); // Auto-generated
        cp.ClientId.Length.ShouldBe(32); // GUID "N" format
    }

    [Fact]
    public void ParseConnect_ReservedBitSet_ShouldError()
    {
        // Variable header + payload with bit 0 (reserved) of the connect flags set.
        var body = new List<byte>();
        AppendMqttString(body, "MQTT");
        body.Add(0x04); // level
        body.Add(0x03); // clean session + reserved bit!
        body.AddRange(new byte[] { 0x00, 0x00 }); // keepalive
        AppendMqttString(body, "AB"); // clientId
        var r = new MqttReader();
        r.Reset(body.ToArray());

        var (_, _, err) = MqttPacketHandlers.ParseConnect(r);

        err.ShouldNotBeNull();
        err.Message.ShouldContain("reserved bit");
    }

    // =========================================================================
    // ProcessConnect + CONNACK tests
    // =========================================================================

    [Fact]
    public void ProcessConnect_ShouldSetFlagsAndSendConnAck()
    {
        var c = CreateMqttClient();
        var cp = new MqttConnectProto
        {
            ClientId = "test-123",
            CleanSession = true,
            KeepAlive = 30,
        };

        var err = MqttPacketHandlers.ProcessConnect(c, cp);

        err.ShouldBeNull();
        // Connection state.
        c.Mqtt!.ClientId.ShouldBe("test-123");
        c.Mqtt.CleanSession.ShouldBeTrue();
        (c.Flags & ClientFlags.ConnectReceived).ShouldNotBe((ClientFlags)0);
        // CONNACK on the wire.
        var data = GetWrittenBytes(c);
        data.Length.ShouldBe(4);
        data[0].ShouldBe(MqttPacket.ConnectAck);
        data[1].ShouldBe((byte)0x02);
        data[2].ShouldBe((byte)0x00); // No session present
        data[3].ShouldBe(MqttConnAckRc.Accepted);
    }

    // =========================================================================
    // Full CONNECT via parser integration
    // =========================================================================

    [Fact]
    public void Parser_ConnectPacket_ShouldParseAndSendConnAck()
    {
        var c = CreateMqttClient();
        var buf = BuildConnectPacket(clientId: "mqtt-parser-test");

        var err = MqttParser.Parse(c, buf, buf.Length);

        err.ShouldBeNull();
        (c.Flags & ClientFlags.ConnectReceived).ShouldNotBe((ClientFlags)0);
        c.Mqtt!.ClientId.ShouldBe("mqtt-parser-test");
        var data = GetWrittenBytes(c);
        data.Length.ShouldBe(4);
        data[0].ShouldBe(MqttPacket.ConnectAck);
        data[3].ShouldBe(MqttConnAckRc.Accepted);
    }

    [Fact]
    public void Parser_ConnectThenPing_ShouldSucceed()
    {
        var c = CreateMqttClient();
        var connectBuf = BuildConnectPacket();
        var pingBuf = new byte[] { MqttPacket.Ping, 0x00 };
        // Both packets arrive in a single read.
        byte[] buf = [.. connectBuf, .. pingBuf];

        var err = MqttParser.Parse(c, buf, buf.Length);

        err.ShouldBeNull();
        var data = GetWrittenBytes(c);
        data.Length.ShouldBe(6); // 4 (CONNACK) + 2 (PINGRESP)
        data[0].ShouldBe(MqttPacket.ConnectAck);
        data[4].ShouldBe(MqttPacket.PingResp);
    }

    // =========================================================================
    // DISCONNECT tests
    // =========================================================================

    [Fact]
    public void Parser_Disconnect_ShouldClearWillAndClose()
    {
        var c = CreateMqttClient();
        // Establish a session that carries a will message.
        var connectBuf = BuildConnectPacket(
            willTopic: "test/will",
            willMessage: Encoding.UTF8.GetBytes("bye"));
        var err = MqttParser.Parse(c, connectBuf, connectBuf.Length);
        err.ShouldBeNull();
        c.Mqtt!.Will.ShouldNotBeNull();

        var disconnectBuf = new byte[] { MqttPacket.Disconnect, 0x00 };
        err = MqttParser.Parse(c, disconnectBuf, disconnectBuf.Length);

        err.ShouldBeNull();
        // A clean DISCONNECT must discard the will so it is never published.
        c.Mqtt.Will.ShouldBeNull();
    }
}