Files
natsdotnet/tests/NATS.Server.Mqtt.Tests/Mqtt/MqttBinaryProtocolTests.cs
Joseph Doherty 845441b32c feat: implement full MQTT Go parity across 5 phases — binary protocol, auth/TLS, cross-protocol bridging, monitoring, and JetStream persistence
Phase 1: Binary MQTT 3.1.1 wire protocol with PipeReader-based parsing,
full packet type dispatch, and MQTT 3.1.1 compliance checks.

Phase 2: Auth pipeline routing MQTT CONNECT through AuthService,
TLS transport with SslStream wrapping, pinned cert validation.

Phase 3: IMessageRouter refactor (NatsClient → INatsClient),
MqttNatsClientAdapter for cross-protocol bridging, MqttTopicMapper
with full Go-parity topic/subject translation.

Phase 4: /connz mqtt_client field population, /varz actual MQTT port.

Phase 5: JetStream persistence — MqttStreamInitializer creates 5
internal streams, MqttConsumerManager for QoS 1/2 consumers,
subject-keyed session/retained lookups replacing linear scans.

All 503 MQTT tests and 1589 Core tests pass.
2026-03-13 10:09:40 -04:00

866 lines
32 KiB
C#

using System.Buffers;
using System.Net;
using System.Net.Sockets;
using System.Text;
using NATS.Server.Mqtt;
namespace NATS.Server.Mqtt.Tests;
/// <summary>
/// Tests for the binary MQTT 3.1.1 wire protocol implementation.
/// Covers: TryRead, ParseUnsubscribe, new WriteXxx methods, PipeReader-based
/// connection handling, and MQTT 3.1.1 compliance rules.
/// </summary>
public class MqttBinaryProtocolTests
{
// -----------------------------------------------------------------------
// MqttPacketReader.TryRead tests
// -----------------------------------------------------------------------
/// <summary>
/// A fully buffered CONNECT packet parses in a single call, and the reported
/// consumed position sits at the end of the input bytes.
/// </summary>
[Fact]
public void TryRead_complete_connect_packet_succeeds()
{
    var bytes = MqttPacketWriter.Write(
        MqttControlPacketType.Connect,
        BuildConnectPayload("test-client"));
    var sequence = new ReadOnlySequence<byte>(bytes);

    var parsed = MqttPacketReader.TryRead(sequence, out var result, out var end);

    parsed.ShouldBeTrue();
    result.ShouldNotBeNull();
    result.Type.ShouldBe(MqttControlPacketType.Connect);
    sequence.GetOffset(end).ShouldBe(bytes.Length);
}
/// <summary>
/// A single header byte with no remaining-length byte is reported as incomplete.
/// </summary>
[Fact]
public void TryRead_returns_false_on_partial_fixed_header()
{
    // Only the first fixed-header byte is present; the remaining-length field is missing.
    byte[] headerOnly = [0x10];
    var sequence = new ReadOnlySequence<byte>(headerOnly);

    var parsed = MqttPacketReader.TryRead(sequence, out var result, out _);

    parsed.ShouldBeFalse();
    result.ShouldBeNull();
}
/// <summary>
/// A packet whose header announces more bytes than are buffered is reported as incomplete.
/// </summary>
[Fact]
public void TryRead_returns_false_on_partial_payload()
{
    // Header declares a remaining length of 10, but only 3 payload bytes follow.
    byte[] truncated = [0x10, 10, 0x00, 0x04, 0x4D];
    var sequence = new ReadOnlySequence<byte>(truncated);

    var parsed = MqttPacketReader.TryRead(sequence, out var result, out _);

    parsed.ShouldBeFalse();
    result.ShouldBeNull();
}
/// <summary>
/// A 200-byte payload (remaining length encoded in two bytes) is parsed and
/// the decoded remaining length matches the payload size.
/// </summary>
[Fact]
public void TryRead_handles_multi_byte_remaining_length()
{
    const int payloadSize = 200; // requires 2 bytes to encode
    var bytes = MqttPacketWriter.Write(MqttControlPacketType.Publish, new byte[payloadSize], flags: 0x00);
    var sequence = new ReadOnlySequence<byte>(bytes);

    var parsed = MqttPacketReader.TryRead(sequence, out var result, out _);

    parsed.ShouldBeTrue();
    result.ShouldNotBeNull();
    result.Type.ShouldBe(MqttControlPacketType.Publish);
    result.RemainingLength.ShouldBe(payloadSize);
}
/// <summary>
/// A CONNECT packet split across two sequence segments is still parsed as one packet.
/// </summary>
[Fact]
public void TryRead_handles_segmented_sequence()
{
    var bytes = MqttPacketWriter.Write(
        MqttControlPacketType.Connect,
        BuildConnectPayload("seg-client"));

    // Cut the packet in half and chain the halves as two linked segments.
    var splitAt = bytes.Length / 2;
    var tailMemory = new ReadOnlyMemory<byte>(bytes, splitAt, bytes.Length - splitAt);
    var head = new MemorySegment<byte>(new ReadOnlyMemory<byte>(bytes, 0, splitAt));
    var tail = head.Append(tailMemory);
    var sequence = new ReadOnlySequence<byte>(head, 0, tail, tailMemory.Length);

    var parsed = MqttPacketReader.TryRead(sequence, out var result, out _);

    parsed.ShouldBeTrue();
    result.ShouldNotBeNull();
    result.Type.ShouldBe(MqttControlPacketType.Connect);
}
/// <summary>
/// Three back-to-back PINGREQ packets in one buffer are read one at a time by
/// slicing the sequence at each reported consumed position.
/// </summary>
[Fact]
public void TryRead_reads_multiple_packets_from_buffer()
{
    var ping = MqttPacketWriter.Write(MqttControlPacketType.PingReq, []);
    byte[] combined = [.. ping, .. ping, .. ping];

    var remaining = new ReadOnlySequence<byte>(combined);
    var seen = 0;
    while (MqttPacketReader.TryRead(remaining, out var current, out var next))
    {
        current!.Type.ShouldBe(MqttControlPacketType.PingReq);
        remaining = remaining.Slice(next);
        seen++;
    }

    seen.ShouldBe(3);
}
/// <summary>
/// A PINGREQ written with an empty payload is two bytes long and parses with a
/// remaining length of zero.
/// </summary>
[Fact]
public void TryRead_zero_remaining_length_packet()
{
    var bytes = MqttPacketWriter.Write(MqttControlPacketType.PingReq, []);
    bytes.Length.ShouldBe(2); // 1 header byte + 1 remaining-length byte (0)

    var parsed = MqttPacketReader.TryRead(new ReadOnlySequence<byte>(bytes), out var result, out _);

    parsed.ShouldBeTrue();
    result!.Type.ShouldBe(MqttControlPacketType.PingReq);
    result.RemainingLength.ShouldBe(0);
}
// -----------------------------------------------------------------------
// MqttBinaryDecoder.ParseUnsubscribe tests
// -----------------------------------------------------------------------
/// <summary>
/// An UNSUBSCRIBE payload with packet id 10 and one length-prefixed filter
/// decodes to that id and that single filter.
/// </summary>
[Fact]
public void ParseUnsubscribe_single_filter()
{
    var topicBytes = Encoding.UTF8.GetBytes("sensor/temp");
    byte[] payload =
    [
        0x00, 0x0A, // packet id = 10
        (byte)(topicBytes.Length >> 8), (byte)(topicBytes.Length & 0xFF), // big-endian length prefix
        .. topicBytes,
    ];

    var result = MqttBinaryDecoder.ParseUnsubscribe(payload);

    result.PacketId.ShouldBe((ushort)10);
    result.Filters.Count.ShouldBe(1);
    result.Filters[0].ShouldBe("sensor/temp");
}
/// <summary>
/// An UNSUBSCRIBE payload with packet id 1 and three length-prefixed filters
/// decodes to the same filters in the same order.
/// </summary>
[Fact]
public void ParseUnsubscribe_multiple_filters()
{
    string[] topics = ["a/b", "c/d", "e/f"];
    List<byte> payload = [0x00, 0x01]; // packet id = 1
    foreach (var topic in topics)
    {
        var encoded = Encoding.UTF8.GetBytes(topic);
        payload.Add((byte)(encoded.Length >> 8));
        payload.Add((byte)(encoded.Length & 0xFF));
        payload.AddRange(encoded);
    }

    var result = MqttBinaryDecoder.ParseUnsubscribe([.. payload]);

    result.PacketId.ShouldBe((ushort)1);
    result.Filters.Count.ShouldBe(3);
    for (var i = 0; i < topics.Length; i++)
    {
        result.Filters[i].ShouldBe(topics[i]);
    }
}
/// <summary>
/// Parsing an otherwise well-formed UNSUBSCRIBE payload with flags 0x00 throws <see cref="FormatException"/>.
/// </summary>
[Fact]
public void ParseUnsubscribe_rejects_invalid_flags()
{
    // Packet id 1 followed by the one-character filter "a".
    byte[] wellFormedPayload = [0x00, 0x01, 0x00, 0x01, (byte)'a'];

    Should.Throw<FormatException>(() => MqttBinaryDecoder.ParseUnsubscribe(wellFormedPayload, flags: 0x00));
}
/// <summary>
/// An UNSUBSCRIBE payload that contains only a packet id and no filters throws <see cref="FormatException"/>.
/// </summary>
[Fact]
public void ParseUnsubscribe_rejects_empty_filter_list()
{
    byte[] packetIdOnly = [0x00, 0x01];

    Should.Throw<FormatException>(() => MqttBinaryDecoder.ParseUnsubscribe(packetIdOnly));
}
// -----------------------------------------------------------------------
// MqttPacketWriter response helper tests
// -----------------------------------------------------------------------
/// <summary>
/// A CONNACK written with (0x01, 0x00) reads back as a two-byte packet carrying
/// those two values in order.
/// </summary>
[Fact]
public void WriteConnAck_encodes_correctly()
{
    var encoded = MqttPacketWriter.WriteConnAck(0x01, 0x00);

    var decoded = MqttPacketReader.Read(encoded);

    decoded.Type.ShouldBe(MqttControlPacketType.ConnAck);
    decoded.RemainingLength.ShouldBe(2);
    var body = decoded.Payload.Span;
    body[0].ShouldBe((byte)0x01); // session present
    body[1].ShouldBe((byte)0x00); // accepted
}
/// <summary>
/// A PUBACK written for packet id 42 reads back with the same big-endian id.
/// </summary>
[Fact]
public void WritePubAck_round_trips_packet_id()
{
    var decoded = MqttPacketReader.Read(MqttPacketWriter.WritePubAck(42));

    decoded.Type.ShouldBe(MqttControlPacketType.PubAck);
    var body = decoded.Payload.Span;
    var packetId = (ushort)((body[0] << 8) | body[1]);
    packetId.ShouldBe((ushort)42);
}
/// <summary>
/// A SUBACK for packet id 99 carries the id followed by each granted-QoS byte in order.
/// </summary>
[Fact]
public void WriteSubAck_encodes_granted_qos()
{
    byte[] granted = [0, 1, 2, 0x80]; // 0x80 = failure

    var decoded = MqttPacketReader.Read(MqttPacketWriter.WriteSubAck(99, granted));

    decoded.Type.ShouldBe(MqttControlPacketType.SubAck);
    var body = decoded.Payload.Span;

    // First two bytes: packet id.
    var packetId = (ushort)((body[0] << 8) | body[1]);
    packetId.ShouldBe((ushort)99);

    // Remaining bytes: one return code per requested filter.
    body[2].ShouldBe((byte)0);
    body[3].ShouldBe((byte)1);
    body[4].ShouldBe((byte)2);
    body[5].ShouldBe((byte)0x80);
}
/// <summary>
/// An UNSUBACK written for packet id 7 reads back with the same big-endian id.
/// </summary>
[Fact]
public void WriteUnsubAck_round_trips_packet_id()
{
    var decoded = MqttPacketReader.Read(MqttPacketWriter.WriteUnsubAck(7));

    decoded.Type.ShouldBe(MqttControlPacketType.UnsubAck);
    var body = decoded.Payload.Span;
    var packetId = (ushort)((body[0] << 8) | body[1]);
    packetId.ShouldBe((ushort)7);
}
/// <summary>
/// A PINGRESP reads back with the PingResp type and an empty body.
/// </summary>
[Fact]
public void WritePingResp_is_correct()
{
    var decoded = MqttPacketReader.Read(MqttPacketWriter.WritePingResp());

    decoded.Type.ShouldBe(MqttControlPacketType.PingResp);
    decoded.RemainingLength.ShouldBe(0);
}
/// <summary>
/// A PUBREC written for packet id 100 reads back with the same big-endian id.
/// </summary>
[Fact]
public void WritePubRec_round_trips_packet_id()
{
    var decoded = MqttPacketReader.Read(MqttPacketWriter.WritePubRec(100));

    decoded.Type.ShouldBe(MqttControlPacketType.PubRec);
    var body = decoded.Payload.Span;
    var packetId = (ushort)((body[0] << 8) | body[1]);
    packetId.ShouldBe((ushort)100);
}
/// <summary>
/// A PUBREL reads back with the PubRel type and fixed-header flags 0x02.
/// </summary>
[Fact]
public void WritePubRel_has_correct_flags()
{
    var decoded = MqttPacketReader.Read(MqttPacketWriter.WritePubRel(50));

    decoded.Type.ShouldBe(MqttControlPacketType.PubRel);
    decoded.Flags.ShouldBe((byte)0x02); // PUBREL must have flags 0x02
}
/// <summary>
/// A PUBCOMP written for packet id 200 reads back with the same big-endian id.
/// </summary>
[Fact]
public void WritePubComp_round_trips_packet_id()
{
    var decoded = MqttPacketReader.Read(MqttPacketWriter.WritePubComp(200));

    decoded.Type.ShouldBe(MqttControlPacketType.PubComp);
    var body = decoded.Payload.Span;
    var packetId = (ushort)((body[0] << 8) | body[1]);
    packetId.ShouldBe((ushort)200);
}
/// <summary>
/// A QoS 0 PUBLISH round-trips its topic and payload, and parses with QoS 0
/// and a packet id of 0.
/// </summary>
[Fact]
public void WritePublish_qos0_no_packet_id()
{
    var encoded = MqttPacketWriter.WritePublish("test/topic", "hello"u8, qos: 0);

    var decoded = MqttPacketReader.Read(encoded);
    decoded.Type.ShouldBe(MqttControlPacketType.Publish);

    var message = MqttBinaryDecoder.ParsePublish(decoded.Payload.Span, decoded.Flags);
    message.Topic.ShouldBe("test/topic");
    message.QoS.ShouldBe((byte)0);
    message.PacketId.ShouldBe((ushort)0);
    var text = Encoding.UTF8.GetString(message.Payload.Span);
    text.ShouldBe("hello");
}
/// <summary>
/// A QoS 1 PUBLISH written with retain, dup and packet id 5 parses back with
/// all four values intact.
/// </summary>
[Fact]
public void WritePublish_qos1_with_flags()
{
    var encoded = MqttPacketWriter.WritePublish("a/b", "data"u8, qos: 1, retain: true, dup: true, packetId: 5);

    var decoded = MqttPacketReader.Read(encoded);
    var message = MqttBinaryDecoder.ParsePublish(decoded.Payload.Span, decoded.Flags);

    message.QoS.ShouldBe((byte)1);
    message.Retain.ShouldBeTrue();
    message.Dup.ShouldBeTrue();
    message.PacketId.ShouldBe((ushort)5);
}
// -----------------------------------------------------------------------
// Enum completeness
// -----------------------------------------------------------------------
/// <summary>
/// Each listed <see cref="MqttControlPacketType"/> member has the expected numeric value.
/// </summary>
[Theory]
[InlineData(MqttControlPacketType.PubRec, 5)]
[InlineData(MqttControlPacketType.PubRel, 6)]
[InlineData(MqttControlPacketType.PubComp, 7)]
[InlineData(MqttControlPacketType.Unsubscribe, 10)]
[InlineData(MqttControlPacketType.UnsubAck, 11)]
public void Enum_has_all_mqtt_packet_types(MqttControlPacketType type, byte expectedValue)
{
    var numericValue = (byte)type;

    numericValue.ShouldBe(expectedValue);
}
// -----------------------------------------------------------------------
// Binary connection integration tests (MQTT 3.1.1 compliance)
// -----------------------------------------------------------------------
/// <summary>
/// After an accepted CONNECT, a PINGREQ is answered with a PINGRESP.
/// </summary>
[Fact]
public async Task Binary_connect_and_ping_pong()
{
    await using var listener = new MqttListener("127.0.0.1", 0);
    await listener.StartAsync(CancellationToken.None);

    using var client = new TcpClient();
    await client.ConnectAsync(IPAddress.Loopback, listener.Port);
    var wire = client.GetStream();

    // CONNECT -> CONNACK (accepted)
    await SendMqttPacketAsync(wire, BuildConnectPacket("ping-client"));
    var ack = await ReadMqttPacketAsync(wire);
    ack.Type.ShouldBe(MqttControlPacketType.ConnAck);
    ack.Payload.Span[1].ShouldBe(MqttProtocolConstants.ConnAckAccepted);

    // PINGREQ -> PINGRESP
    var pingRequest = MqttPacketWriter.Write(MqttControlPacketType.PingReq, []);
    await SendMqttPacketAsync(wire, pingRequest);
    var pong = await ReadMqttPacketAsync(wire);
    pong.Type.ShouldBe(MqttControlPacketType.PingResp);
}
/// <summary>
/// Sending PINGREQ before CONNECT yields no reply packet within 500 ms.
/// </summary>
[Fact]
public async Task Binary_first_packet_must_be_connect()
{
    await using var listener = new MqttListener("127.0.0.1", 0);
    await listener.StartAsync(CancellationToken.None);

    using var client = new TcpClient();
    await client.ConnectAsync(IPAddress.Loopback, listener.Port);
    var wire = client.GetStream();

    // The first packet is a PINGREQ instead of the mandatory CONNECT.
    var pingRequest = MqttPacketWriter.Write(MqttControlPacketType.PingReq, []);
    await SendMqttPacketAsync(wire, pingRequest);

    // The helper returns null when the read times out or the connection is closed.
    var reply = await ReadWithTimeoutAsync(wire, 500);
    reply.ShouldBeNull();
}
/// <summary>
/// A CONNECT carrying protocol level 5 is answered with a CONNACK whose return
/// code is "unacceptable protocol version".
/// </summary>
[Fact]
public async Task Binary_reject_bad_protocol_level()
{
    await using var listener = new MqttListener("127.0.0.1", 0);
    await listener.StartAsync(CancellationToken.None);

    using var client = new TcpClient();
    await client.ConnectAsync(IPAddress.Loopback, listener.Port);
    var wire = client.GetStream();

    // Protocol level 5 instead of 4.
    await SendMqttPacketAsync(wire, BuildConnectPacket("bad-level", protocolLevel: 5));

    var ack = await ReadMqttPacketAsync(wire);
    ack.Type.ShouldBe(MqttControlPacketType.ConnAck);
    ack.Payload.Span[1].ShouldBe(MqttProtocolConstants.ConnAckUnacceptableProtocolVersion);
}
/// <summary>
/// A CONNECT with an empty client id and the clean-session flag set is accepted.
/// </summary>
[Fact]
public async Task Binary_empty_clientid_clean_session_generates_id()
{
    await using var listener = new MqttListener("127.0.0.1", 0);
    await listener.StartAsync(CancellationToken.None);

    using var client = new TcpClient();
    await client.ConnectAsync(IPAddress.Loopback, listener.Port);
    var wire = client.GetStream();

    // Empty client id + clean session.
    await SendMqttPacketAsync(wire, BuildConnectPacket("", cleanSession: true));

    var ack = await ReadMqttPacketAsync(wire);
    ack.Payload.Span[1].ShouldBe(MqttProtocolConstants.ConnAckAccepted);
}
/// <summary>
/// A CONNECT with an empty client id and clean session cleared is answered
/// with the "identifier rejected" return code.
/// </summary>
[Fact]
public async Task Binary_empty_clientid_persistent_session_rejected()
{
    await using var listener = new MqttListener("127.0.0.1", 0);
    await listener.StartAsync(CancellationToken.None);

    using var client = new TcpClient();
    await client.ConnectAsync(IPAddress.Loopback, listener.Port);
    var wire = client.GetStream();

    // Empty client id + persistent session.
    await SendMqttPacketAsync(wire, BuildConnectPacket("", cleanSession: false));

    var ack = await ReadMqttPacketAsync(wire);
    ack.Payload.Span[1].ShouldBe(MqttProtocolConstants.ConnAckIdentifierRejected);
}
/// <summary>
/// With the listener requiring admin/pass, a CONNECT using other credentials
/// is answered with the "not authorized" return code.
/// </summary>
[Fact]
public async Task Binary_auth_failure_returns_not_authorized()
{
    await using var listener = new MqttListener(
        "127.0.0.1",
        0,
        requiredUsername: "admin",
        requiredPassword: "pass");
    await listener.StartAsync(CancellationToken.None);

    using var client = new TcpClient();
    await client.ConnectAsync(IPAddress.Loopback, listener.Port);
    var wire = client.GetStream();

    // Credentials that do not match the listener's required pair.
    await SendMqttPacketAsync(wire, BuildConnectPacket("auth-fail", username: "wrong", password: "creds"));

    var ack = await ReadMqttPacketAsync(wire);
    ack.Payload.Span[1].ShouldBe(MqttProtocolConstants.ConnAckNotAuthorized);
}
/// <summary>
/// With the listener requiring admin/secret, a CONNECT using exactly those
/// credentials is accepted.
/// </summary>
[Fact]
public async Task Binary_auth_success_with_credentials()
{
    await using var listener = new MqttListener(
        "127.0.0.1",
        0,
        requiredUsername: "admin",
        requiredPassword: "secret");
    await listener.StartAsync(CancellationToken.None);

    using var client = new TcpClient();
    await client.ConnectAsync(IPAddress.Loopback, listener.Port);
    var wire = client.GetStream();

    await SendMqttPacketAsync(wire, BuildConnectPacket("auth-ok", username: "admin", password: "secret"));

    var ack = await ReadMqttPacketAsync(wire);
    ack.Payload.Span[1].ShouldBe(MqttProtocolConstants.ConnAckAccepted);
}
/// <summary>
/// A QoS 0 PUBLISH from one client is delivered to another client subscribed
/// to the same topic, with topic and payload intact.
/// </summary>
[Fact]
public async Task Binary_subscribe_and_publish_qos0()
{
    const string topic = "test/topic";

    await using var listener = new MqttListener("127.0.0.1", 0);
    await listener.StartAsync(CancellationToken.None);

    // Subscriber: connect, subscribe at QoS 0, wait for the SUBACK.
    using var subscriber = new TcpClient();
    await subscriber.ConnectAsync(IPAddress.Loopback, listener.Port);
    var subscriberWire = subscriber.GetStream();
    await ConnectAsync(subscriberWire, "sub-client");
    await SendMqttPacketAsync(subscriberWire, BuildSubscribePacket(1, topic, 0));
    var subAck = await ReadMqttPacketAsync(subscriberWire);
    subAck.Type.ShouldBe(MqttControlPacketType.SubAck);

    // Publisher: connect and publish to the subscribed topic.
    using var publisher = new TcpClient();
    await publisher.ConnectAsync(IPAddress.Loopback, listener.Port);
    var publisherWire = publisher.GetStream();
    await ConnectAsync(publisherWire, "pub-client");
    await SendMqttPacketAsync(publisherWire, MqttPacketWriter.WritePublish(topic, "hello binary"u8));

    // The subscriber receives the forwarded PUBLISH.
    var delivered = await ReadMqttPacketAsync(subscriberWire);
    delivered.Type.ShouldBe(MqttControlPacketType.Publish);
    var message = MqttBinaryDecoder.ParsePublish(delivered.Payload.Span, delivered.Flags);
    message.Topic.ShouldBe(topic);
    Encoding.UTF8.GetString(message.Payload.Span).ShouldBe("hello binary");
}
/// <summary>
/// A QoS 1 PUBLISH with packet id 42 is acknowledged by a PUBACK carrying the same id.
/// </summary>
[Fact]
public async Task Binary_publish_qos1_gets_puback()
{
    await using var listener = new MqttListener("127.0.0.1", 0);
    await listener.StartAsync(CancellationToken.None);

    using var client = new TcpClient();
    await client.ConnectAsync(IPAddress.Loopback, listener.Port);
    var wire = client.GetStream();
    await ConnectAsync(wire, "qos1-pub");

    var publish = MqttPacketWriter.WritePublish("qos1/topic", "msg"u8, qos: 1, packetId: 42);
    await SendMqttPacketAsync(wire, publish);

    var ack = await ReadMqttPacketAsync(wire);
    ack.Type.ShouldBe(MqttControlPacketType.PubAck);
    var ackedId = (ushort)((ack.Payload.Span[0] << 8) | ack.Payload.Span[1]);
    ackedId.ShouldBe((ushort)42);
}
/// <summary>
/// Walks the QoS 2 exchange for packet id 10: PUBLISH, PUBREC, PUBREL, then a
/// PUBCOMP carrying the same id.
/// </summary>
[Fact]
public async Task Binary_publish_qos2_full_flow()
{
    await using var listener = new MqttListener("127.0.0.1", 0);
    await listener.StartAsync(CancellationToken.None);

    using var client = new TcpClient();
    await client.ConnectAsync(IPAddress.Loopback, listener.Port);
    var wire = client.GetStream();
    await ConnectAsync(wire, "qos2-pub");

    // 1. PUBLISH (QoS 2) -> 2. PUBREC
    var publish = MqttPacketWriter.WritePublish("qos2/topic", "msg"u8, qos: 2, packetId: 10);
    await SendMqttPacketAsync(wire, publish);
    var received = await ReadMqttPacketAsync(wire);
    received.Type.ShouldBe(MqttControlPacketType.PubRec);

    // 3. PUBREL -> 4. PUBCOMP
    await SendMqttPacketAsync(wire, MqttPacketWriter.WritePubRel(10));
    var complete = await ReadMqttPacketAsync(wire);
    complete.Type.ShouldBe(MqttControlPacketType.PubComp);
    var completedId = (ushort)((complete.Payload.Span[0] << 8) | complete.Payload.Span[1]);
    completedId.ShouldBe((ushort)10);
}
/// <summary>
/// An UNSUBSCRIBE with packet id 2 is answered by an UNSUBACK carrying the same id.
/// </summary>
[Fact]
public async Task Binary_unsubscribe_returns_unsuback()
{
    await using var listener = new MqttListener("127.0.0.1", 0);
    await listener.StartAsync(CancellationToken.None);

    using var client = new TcpClient();
    await client.ConnectAsync(IPAddress.Loopback, listener.Port);
    var wire = client.GetStream();
    await ConnectAsync(wire, "unsub-client");

    // Subscribe first and drain the SUBACK.
    await SendMqttPacketAsync(wire, BuildSubscribePacket(1, "test/unsub", 0));
    _ = await ReadMqttPacketAsync(wire);

    // Unsubscribe and inspect the acknowledgement.
    await SendMqttPacketAsync(wire, BuildUnsubscribePacket(2, "test/unsub"));
    var ack = await ReadMqttPacketAsync(wire);
    ack.Type.ShouldBe(MqttControlPacketType.UnsubAck);
    var ackedId = (ushort)((ack.Payload.Span[0] << 8) | ack.Payload.Span[1]);
    ackedId.ShouldBe((ushort)2);
}
/// <summary>
/// After a client subscribes and then unsubscribes from a topic, a later
/// PUBLISH to that topic produces no packet on its connection within 200 ms.
/// </summary>
[Fact]
public async Task Binary_unsubscribe_stops_message_delivery()
{
    const string topic = "nosub/topic";

    await using var listener = new MqttListener("127.0.0.1", 0);
    await listener.StartAsync(CancellationToken.None);

    // Subscriber: subscribe, then immediately unsubscribe.
    using var subscriber = new TcpClient();
    await subscriber.ConnectAsync(IPAddress.Loopback, listener.Port);
    var subscriberWire = subscriber.GetStream();
    await ConnectAsync(subscriberWire, "unsub-recv");
    await SendMqttPacketAsync(subscriberWire, BuildSubscribePacket(1, topic, 0));
    _ = await ReadMqttPacketAsync(subscriberWire); // SUBACK
    await SendMqttPacketAsync(subscriberWire, BuildUnsubscribePacket(2, topic));
    _ = await ReadMqttPacketAsync(subscriberWire); // UNSUBACK

    // Publisher: publish to the topic that is no longer subscribed.
    using var publisher = new TcpClient();
    await publisher.ConnectAsync(IPAddress.Loopback, listener.Port);
    var publisherWire = publisher.GetStream();
    await ConnectAsync(publisherWire, "unsub-pub");
    await SendMqttPacketAsync(publisherWire, MqttPacketWriter.WritePublish(topic, "invisible"u8));

    // Nothing should reach the former subscriber.
    var delivered = await ReadWithTimeoutAsync(subscriberWire, 200);
    delivered.ShouldBeNull();
}
/// <summary>
/// A client that registered a will and then sends a clean DISCONNECT must not
/// have its will published: a subscriber on the will topic sees no packet
/// within 300 ms after the client's connection is closed.
/// </summary>
[Fact]
public async Task Binary_disconnect_clears_will_message()
{
    await using var listener = new MqttListener("127.0.0.1", 0);
    await listener.StartAsync(CancellationToken.None);

    // Subscriber for will topic
    using var subTcp = new TcpClient();
    await subTcp.ConnectAsync(IPAddress.Loopback, listener.Port);
    var subStream = subTcp.GetStream();
    await ConnectAsync(subStream, "will-sub");
    await SendMqttPacketAsync(subStream, BuildSubscribePacket(1, "will/topic", 0));
    _ = await ReadMqttPacketAsync(subStream); // SUBACK

    // Client with will
    using var willTcp = new TcpClient();
    await willTcp.ConnectAsync(IPAddress.Loopback, listener.Port);
    var willStream = willTcp.GetStream();
    var connectPayload = BuildConnectPayload("will-client",
        willTopic: "will/topic", willMessage: "oops");
    await SendMqttPacketAsync(willStream, MqttPacketWriter.Write(MqttControlPacketType.Connect, connectPayload));
    _ = await ReadMqttPacketAsync(willStream); // CONNACK

    // Clean DISCONNECT — should clear will
    await SendMqttPacketAsync(willStream,
        MqttPacketWriter.Write(MqttControlPacketType.Disconnect, []));

    // Close the client's TCP connection explicitly. A will is only a candidate
    // for publication once the network connection is gone, so without this the
    // check below depends on the server closing the socket on its own.
    // DISCONNECT was written first, so the server sees it before the close.
    willTcp.Close();

    // Wait a bit and check that will was NOT published
    var result = await ReadWithTimeoutAsync(subStream, 300);
    result.ShouldBeNull();
}
/// <summary>
/// When a second connection uses an already-connected client id, the first
/// connection yields no packet within 500 ms while the second still gets a
/// PINGRESP for its PINGREQ.
/// </summary>
[Fact]
public async Task Binary_duplicate_clientid_takeover()
{
    const string sharedClientId = "dup-client";

    await using var listener = new MqttListener("127.0.0.1", 0);
    await listener.StartAsync(CancellationToken.None);

    // Original session.
    using var original = new TcpClient();
    await original.ConnectAsync(IPAddress.Loopback, listener.Port);
    var originalWire = original.GetStream();
    await ConnectAsync(originalWire, sharedClientId);

    // Second session with the same client id takes over.
    using var takeover = new TcpClient();
    await takeover.ConnectAsync(IPAddress.Loopback, listener.Port);
    var takeoverWire = takeover.GetStream();
    await ConnectAsync(takeoverWire, sharedClientId);

    // The original connection should be closed (helper returns null on close or timeout).
    var fromOriginal = await ReadWithTimeoutAsync(originalWire, 500);
    fromOriginal.ShouldBeNull();

    // The new connection keeps working.
    var pingRequest = MqttPacketWriter.Write(MqttControlPacketType.PingReq, []);
    await SendMqttPacketAsync(takeoverWire, pingRequest);
    var pong = await ReadMqttPacketAsync(takeoverWire);
    pong.Type.ShouldBe(MqttControlPacketType.PingResp);
}
/// <summary>
/// A SUBSCRIBE sent with fixed-header flags 0x00 (instead of 0x02) yields no
/// reply packet within 500 ms.
/// </summary>
[Fact]
public async Task Binary_subscribe_flags_validation()
{
    await using var listener = new MqttListener("127.0.0.1", 0);
    await listener.StartAsync(CancellationToken.None);

    using var client = new TcpClient();
    await client.ConnectAsync(IPAddress.Loopback, listener.Port);
    var wire = client.GetStream();
    await ConnectAsync(wire, "bad-sub-flags");

    // Valid SUBSCRIBE payload, but framed with the wrong flags.
    var malformed = MqttPacketWriter.Write(
        MqttControlPacketType.Subscribe,
        BuildSubscribePayload(1, "test/topic", 0),
        flags: 0x00);
    await SendMqttPacketAsync(wire, malformed);

    // The helper returns null when the read times out or the connection is closed.
    var reply = await ReadWithTimeoutAsync(wire, 500);
    reply.ShouldBeNull();
}
/// <summary>
/// A retained PUBLISH becomes visible through <c>GetRetainedMessage</c>, and a
/// subsequent retained PUBLISH with an empty payload removes it again.
/// </summary>
[Fact]
public async Task Binary_retained_message_tombstone()
{
    const string topic = "retain/topic";

    await using var listener = new MqttListener("127.0.0.1", 0);
    await listener.StartAsync(CancellationToken.None);

    using var client = new TcpClient();
    await client.ConnectAsync(IPAddress.Loopback, listener.Port);
    var wire = client.GetStream();
    await ConnectAsync(wire, "retain-client");

    // Store a retained message.
    await SendMqttPacketAsync(wire, MqttPacketWriter.WritePublish(topic, "kept"u8, retain: true));

    // Poll (up to 20 x 25 ms) until the server has recorded it.
    for (var attempt = 0; attempt < 20 && listener.GetRetainedMessage(topic) == null; attempt++)
    {
        await Task.Delay(25);
    }

    listener.GetRetainedMessage(topic).ShouldBe("kept");

    // An empty retained payload acts as a tombstone.
    await SendMqttPacketAsync(wire, MqttPacketWriter.WritePublish(topic, ReadOnlySpan<byte>.Empty, retain: true));

    // Poll (up to 20 x 25 ms) until the retained entry is gone.
    for (var attempt = 0; attempt < 20 && listener.GetRetainedMessage(topic) != null; attempt++)
    {
        await Task.Delay(25);
    }

    listener.GetRetainedMessage(topic).ShouldBeNull();
}
// -----------------------------------------------------------------------
// Helpers
// -----------------------------------------------------------------------
/// <summary>
/// Sends a default CONNECT for <paramref name="clientId"/> and asserts that
/// the reply is a CONNACK with the "accepted" return code.
/// </summary>
/// <param name="stream">Open connection to the listener under test.</param>
/// <param name="clientId">Client identifier placed in the CONNECT packet.</param>
private static async Task ConnectAsync(NetworkStream stream, string clientId)
{
    var connect = BuildConnectPacket(clientId);
    await SendMqttPacketAsync(stream, connect);

    var reply = await ReadMqttPacketAsync(stream);
    reply.Type.ShouldBe(MqttControlPacketType.ConnAck);
    reply.Payload.Span[1].ShouldBe(MqttProtocolConstants.ConnAckAccepted);
}
/// <summary>
/// Builds a complete CONNECT packet: the payload from
/// <see cref="BuildConnectPayload"/> wrapped by <c>MqttPacketWriter.Write</c>.
/// </summary>
/// <returns>The encoded CONNECT packet bytes.</returns>
private static byte[] BuildConnectPacket(string clientId, string? username = null, string? password = null,
    bool cleanSession = true, byte protocolLevel = 4, string? willTopic = null, string? willMessage = null)
    => MqttPacketWriter.Write(
        MqttControlPacketType.Connect,
        BuildConnectPayload(clientId, username, password, cleanSession, protocolLevel, willTopic, willMessage));
/// <summary>
/// Builds the CONNECT body without its fixed header: protocol name, protocol
/// level, connect flags, a 60-second keep-alive, client id, and then the
/// optional will, username and password fields.
/// </summary>
/// <param name="clientId">Client identifier; may be empty.</param>
/// <param name="username">Optional username; sets flag 0x80 when non-null.</param>
/// <param name="password">Optional password; sets flag 0x40 when non-null.</param>
/// <param name="cleanSession">Sets flag 0x02 when true.</param>
/// <param name="protocolLevel">Protocol level byte written after the protocol name.</param>
/// <param name="willTopic">Optional will topic; sets flag 0x04 when non-null.</param>
/// <param name="willMessage">Will payload; an empty payload is written when null.</param>
/// <returns>The encoded CONNECT body.</returns>
private static byte[] BuildConnectPayload(string clientId, string? username = null, string? password = null,
    bool cleanSession = true, byte protocolLevel = 4, string? willTopic = null, string? willMessage = null)
{
    var hasWill = willTopic != null;

    // Will QoS and will retain bits are left at 0.
    var connectFlags =
        (cleanSession ? 0x02 : 0)
        | (hasWill ? 0x04 : 0)
        | (password != null ? 0x40 : 0)
        | (username != null ? 0x80 : 0);

    var bytes = new List<byte>();
    bytes.AddRange(MqttPacketWriter.WriteString("MQTT")); // protocol name
    bytes.Add(protocolLevel);
    bytes.Add((byte)connectFlags);
    bytes.Add(0x00); // keep-alive MSB
    bytes.Add(0x3C); // keep-alive LSB (60 seconds)
    bytes.AddRange(MqttPacketWriter.WriteString(clientId));

    if (hasWill)
    {
        bytes.AddRange(MqttPacketWriter.WriteString(willTopic!));
        bytes.AddRange(MqttPacketWriter.WriteBytes(Encoding.UTF8.GetBytes(willMessage ?? "")));
    }

    if (username != null)
    {
        bytes.AddRange(MqttPacketWriter.WriteString(username));
    }

    if (password != null)
    {
        bytes.AddRange(MqttPacketWriter.WriteString(password));
    }

    return [.. bytes];
}
/// <summary>
/// Builds a complete single-filter SUBSCRIBE packet with fixed-header flags 0x02.
/// </summary>
/// <returns>The encoded SUBSCRIBE packet bytes.</returns>
private static byte[] BuildSubscribePacket(ushort packetId, string topic, byte qos)
    => MqttPacketWriter.Write(
        MqttControlPacketType.Subscribe,
        BuildSubscribePayload(packetId, topic, qos),
        flags: 0x02);
/// <summary>
/// Builds a SUBSCRIBE body: big-endian packet id, one encoded topic filter,
/// and the requested QoS byte.
/// </summary>
/// <returns>The encoded SUBSCRIBE body without a fixed header.</returns>
private static byte[] BuildSubscribePayload(ushort packetId, string topic, byte qos)
{
    List<byte> bytes = [(byte)(packetId >> 8), (byte)(packetId & 0xFF)];
    bytes.AddRange(MqttPacketWriter.WriteString(topic));
    bytes.Add(qos);
    return [.. bytes];
}
/// <summary>
/// Builds a complete single-filter UNSUBSCRIBE packet (big-endian packet id
/// followed by the encoded topic filter) with fixed-header flags 0x02.
/// </summary>
/// <returns>The encoded UNSUBSCRIBE packet bytes.</returns>
private static byte[] BuildUnsubscribePacket(ushort packetId, string topic)
{
    List<byte> body = [(byte)(packetId >> 8), (byte)(packetId & 0xFF)];
    body.AddRange(MqttPacketWriter.WriteString(topic));
    return MqttPacketWriter.Write(MqttControlPacketType.Unsubscribe, [.. body], flags: 0x02);
}
/// <summary>
/// Writes an already-encoded packet to the stream and flushes it.
/// </summary>
/// <param name="stream">Connection to write to.</param>
/// <param name="packet">Encoded packet bytes.</param>
private static async Task SendMqttPacketAsync(NetworkStream stream, byte[] packet)
{
    ReadOnlyMemory<byte> bytes = packet;
    await stream.WriteAsync(bytes);
    await stream.FlushAsync();
}
/// <summary>
/// Reads exactly one MQTT control packet from the stream.
/// </summary>
/// <remarks>
/// The fixed header is read byte by byte so the remaining length is known
/// before the body is read. No bytes belonging to a following packet are
/// consumed, and the packet size is not capped by a fixed buffer.
/// </remarks>
/// <param name="stream">Connection to read from.</param>
/// <param name="timeoutMs">Overall time budget for the whole packet, in milliseconds.</param>
/// <returns>The parsed packet.</returns>
/// <exception cref="IOException">
/// The connection closed mid-packet, the remaining-length field is longer
/// than four bytes, or the reader did not accept the framed bytes.
/// </exception>
/// <exception cref="OperationCanceledException">The timeout elapsed first.</exception>
private static async Task<MqttControlPacket> ReadMqttPacketAsync(NetworkStream stream, int timeoutMs = 2000)
{
    using var cts = new CancellationTokenSource(timeoutMs);

    // Fills the target completely, looping over short reads.
    async Task FillAsync(Memory<byte> target)
    {
        while (!target.IsEmpty)
        {
            var read = await stream.ReadAsync(target, cts.Token);
            if (read == 0)
                throw new IOException("Connection closed while reading MQTT packet");
            target = target[read..];
        }
    }

    // Fixed header: 1 type/flags byte followed by 1-4 remaining-length bytes.
    const int maxHeaderLength = 5;
    var header = new byte[maxHeaderLength];
    await FillAsync(header.AsMemory(0, 1));
    var headerLength = 1;

    // Remaining length: 7 data bits per byte, least significant group first;
    // the high bit marks that another length byte follows.
    var remainingLength = 0;
    var shift = 0;
    while (true)
    {
        if (headerLength == maxHeaderLength)
            throw new IOException("Malformed MQTT packet: remaining length exceeds 4 bytes");

        await FillAsync(header.AsMemory(headerLength, 1));
        var lengthByte = header[headerLength++];
        remainingLength |= (lengthByte & 0x7F) << shift;
        shift += 7;
        if ((lengthByte & 0x80) == 0)
            break;
    }

    // Assemble header + body and hand the single framed packet to the reader.
    var frame = new byte[headerLength + remainingLength];
    Array.Copy(header, frame, headerLength);
    await FillAsync(frame.AsMemory(headerLength));

    if (!MqttPacketReader.TryRead(new ReadOnlySequence<byte>(frame), out var packet, out _))
        throw new IOException("MQTT packet reader did not accept a fully framed packet");

    return packet!;
}
/// <summary>
/// Tries to read one packet, returning <c>null</c> when none arrives because
/// the read timed out or the connection failed or closed.
/// </summary>
/// <remarks>
/// Only timeout (<see cref="OperationCanceledException"/>) and I/O failure
/// (<see cref="IOException"/>) are treated as "no packet". Any other exception
/// propagates so an unexpected failure is not mistaken for a silent connection.
/// </remarks>
/// <param name="stream">Connection to read from.</param>
/// <param name="timeoutMs">Time budget in milliseconds.</param>
/// <returns>The packet that was read, or <c>null</c>.</returns>
private static async Task<MqttControlPacket?> ReadWithTimeoutAsync(NetworkStream stream, int timeoutMs)
{
    try
    {
        return await ReadMqttPacketAsync(stream, timeoutMs);
    }
    catch (Exception ex) when (ex is OperationCanceledException or IOException)
    {
        return null;
    }
}
/// <summary>
/// Minimal linked segment used to build a multi-segment
/// <see cref="ReadOnlySequence{T}"/> for the split-packet tests.
/// </summary>
private sealed class MemorySegment<T> : ReadOnlySequenceSegment<T>
{
    /// <summary>Creates a segment that wraps <paramref name="memory"/>.</summary>
    public MemorySegment(ReadOnlyMemory<T> memory) => Memory = memory;

    /// <summary>
    /// Links a new segment after this one. Its running index is this segment's
    /// running index plus this segment's length.
    /// </summary>
    /// <param name="memory">Contents of the new segment.</param>
    /// <returns>The newly appended segment.</returns>
    public MemorySegment<T> Append(ReadOnlyMemory<T> memory)
    {
        var appended = new MemorySegment<T>(memory);
        appended.RunningIndex = RunningIndex + Memory.Length;
        Next = appended;
        return appended;
    }
}
}