From 7faf42c588eb13e8060db0285c82adc2771a9775 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Mon, 23 Feb 2026 14:41:23 -0500 Subject: [PATCH] feat: implement mqtt packet-level parser and writer --- src/NATS.Server/Mqtt/MqttPacketReader.cs | 63 +++++++++++++++++++ src/NATS.Server/Mqtt/MqttPacketWriter.cs | 38 +++++++++++ src/NATS.Server/Mqtt/MqttProtocolParser.cs | 6 ++ .../Mqtt/MqttPacketParserTests.cs | 26 ++++++++ .../Mqtt/MqttPacketWriterTests.cs | 20 ++++++ 5 files changed, 153 insertions(+) create mode 100644 src/NATS.Server/Mqtt/MqttPacketReader.cs create mode 100644 src/NATS.Server/Mqtt/MqttPacketWriter.cs create mode 100644 tests/NATS.Server.Tests/Mqtt/MqttPacketParserTests.cs create mode 100644 tests/NATS.Server.Tests/Mqtt/MqttPacketWriterTests.cs diff --git a/src/NATS.Server/Mqtt/MqttPacketReader.cs b/src/NATS.Server/Mqtt/MqttPacketReader.cs new file mode 100644 index 0000000..e188097 --- /dev/null +++ b/src/NATS.Server/Mqtt/MqttPacketReader.cs @@ -0,0 +1,63 @@ +namespace NATS.Server.Mqtt; + +public enum MqttControlPacketType : byte +{ + Reserved = 0, + Connect = 1, + ConnAck = 2, + Publish = 3, + PubAck = 4, + Subscribe = 8, + SubAck = 9, + PingReq = 12, + PingResp = 13, + Disconnect = 14, +} + +public sealed record MqttControlPacket( + MqttControlPacketType Type, + byte Flags, + int RemainingLength, + ReadOnlyMemory Payload); + +public static class MqttPacketReader +{ + public static MqttControlPacket Read(ReadOnlySpan buffer) + { + if (buffer.Length < 2) + throw new FormatException("MQTT packet is shorter than fixed header."); + + var first = buffer[0]; + var type = (MqttControlPacketType)(first >> 4); + var flags = (byte)(first & 0x0F); + var remainingLength = DecodeRemainingLength(buffer[1..], out var consumed); + var payloadStart = 1 + consumed; + var totalLength = payloadStart + remainingLength; + if (remainingLength < 0 || totalLength > buffer.Length) + throw new FormatException("MQTT packet remaining length exceeds available bytes."); + + var payload = buffer[payloadStart..totalLength].ToArray(); + return new MqttControlPacket(type, flags, remainingLength, payload); + } + + internal static int DecodeRemainingLength(ReadOnlySpan encoded, out int consumed) + { + var multiplier = 1; + var value = 0; + consumed = 0; + + for (var i = 0; i < encoded.Length && i < 4; i++) + { + var digit = encoded[i]; + consumed++; + value += (digit & 0x7F) * multiplier; + + if ((digit & 0x80) == 0) + return value; + + multiplier *= 128; + } + + throw new FormatException("Invalid MQTT remaining length encoding."); + } +} diff --git a/src/NATS.Server/Mqtt/MqttPacketWriter.cs b/src/NATS.Server/Mqtt/MqttPacketWriter.cs new file mode 100644 index 0000000..e459010 --- /dev/null +++ b/src/NATS.Server/Mqtt/MqttPacketWriter.cs @@ -0,0 +1,38 @@ +namespace NATS.Server.Mqtt; + +public static class MqttPacketWriter +{ + public static byte[] Write(MqttControlPacketType type, ReadOnlySpan payload, byte flags = 0) + { + if (type == MqttControlPacketType.Reserved) + throw new ArgumentOutOfRangeException(nameof(type), "MQTT control packet type must be non-zero."); + + var remainingLength = payload.Length; + var encodedRemainingLength = EncodeRemainingLength(remainingLength); + var buffer = new byte[1 + encodedRemainingLength.Length + remainingLength]; + buffer[0] = (byte)(((byte)type << 4) | (flags & 0x0F)); + encodedRemainingLength.CopyTo(buffer.AsSpan(1)); + payload.CopyTo(buffer.AsSpan(1 + encodedRemainingLength.Length)); + return buffer; + } + + internal static byte[] EncodeRemainingLength(int value) + { + if (value < 0 || value > 268_435_455) + throw new ArgumentOutOfRangeException(nameof(value), "MQTT remaining length must be between 0 and 268435455."); + + Span scratch = stackalloc byte[4]; + var index = 0; + + do + { + var digit = (byte)(value % 128); + value /= 128; + if (value > 0) + digit |= 0x80; + scratch[index++] = digit; + } while (value > 0); + + return scratch[..index].ToArray(); + } +} diff --git a/src/NATS.Server/Mqtt/MqttProtocolParser.cs b/src/NATS.Server/Mqtt/MqttProtocolParser.cs index 8a10201..db8e7bb 100644 --- a/src/NATS.Server/Mqtt/MqttProtocolParser.cs +++ b/src/NATS.Server/Mqtt/MqttProtocolParser.cs @@ -12,6 +12,12 @@ public sealed record MqttPacket(MqttPacketType Type, string Topic, string Payloa public sealed class MqttProtocolParser { + public MqttControlPacket ParsePacket(ReadOnlySpan packet) + => MqttPacketReader.Read(packet); + + public byte[] WritePacket(MqttControlPacketType type, ReadOnlySpan payload, byte flags = 0) + => MqttPacketWriter.Write(type, payload, flags); + public MqttPacket ParseLine(string line) { var trimmed = line.Trim(); diff --git a/tests/NATS.Server.Tests/Mqtt/MqttPacketParserTests.cs b/tests/NATS.Server.Tests/Mqtt/MqttPacketParserTests.cs new file mode 100644 index 0000000..3b9c814 --- /dev/null +++ b/tests/NATS.Server.Tests/Mqtt/MqttPacketParserTests.cs @@ -0,0 +1,26 @@ +using NATS.Server.Mqtt; + +namespace NATS.Server.Tests.Mqtt; + +public class MqttPacketParserTests +{ + [Fact] + public void Connect_packet_fixed_header_and_remaining_length_parse_correctly() + { + var packet = MqttPacketReader.Read(ConnectPacketBytes.Sample); + packet.Type.ShouldBe(MqttControlPacketType.Connect); + packet.RemainingLength.ShouldBe(12); + packet.Payload.Length.ShouldBe(12); + } + + private static class ConnectPacketBytes + { + public static readonly byte[] Sample = + [ + 0x10, 0x0C, // CONNECT + remaining length + 0x00, 0x04, (byte)'M', (byte)'Q', (byte)'T', (byte)'T', + 0x04, 0x02, 0x00, 0x3C, // protocol level/flags/keepalive + 0x00, 0x00, // empty client id + ]; + } +} diff --git a/tests/NATS.Server.Tests/Mqtt/MqttPacketWriterTests.cs b/tests/NATS.Server.Tests/Mqtt/MqttPacketWriterTests.cs new file mode 100644 index 0000000..b5ac4ff --- /dev/null +++ b/tests/NATS.Server.Tests/Mqtt/MqttPacketWriterTests.cs @@ -0,0 +1,20 @@ +using NATS.Server.Mqtt; + +namespace NATS.Server.Tests.Mqtt; + +public class MqttPacketWriterTests +{ + [Fact] + public void Writer_emits_fixed_header_and_round_trips_with_reader() + { + byte[] payload = Enumerable.Repeat((byte)0xAB, 130).ToArray(); + + var encoded = MqttPacketWriter.Write(MqttControlPacketType.Publish, payload); + encoded[0].ShouldBe((byte)0x30); // PUBLISH type with default flags + + var decoded = MqttPacketReader.Read(encoded); + decoded.Type.ShouldBe(MqttControlPacketType.Publish); + decoded.RemainingLength.ShouldBe(payload.Length); + decoded.Payload.ToArray().ShouldBe(payload); + } +}