diff --git a/src/NATS.Server/Protocol/NatsParser.cs b/src/NATS.Server/Protocol/NatsParser.cs
new file mode 100644
index 0000000..8497c61
--- /dev/null
+++ b/src/NATS.Server/Protocol/NatsParser.cs
@@ -0,0 +1,471 @@
+using System.Buffers;
+using System.Text;
+
+namespace NATS.Server.Protocol;
+
+/// <summary>
+/// Protocol operations recognised by <see cref="NatsParser"/>.
+/// </summary>
+public enum CommandType
+{
+    Ping,
+    Pong,
+    Connect,
+    Info,
+    Pub,
+    HPub,
+    Sub,
+    Unsub,
+    Ok,
+    Err,
+}
+
+/// <summary>
+/// One parsed protocol command. Which properties carry data depends on
+/// <see cref="Type"/>.
+/// </summary>
+public readonly struct ParsedCommand
+{
+    public CommandType Type { get; init; }
+    public string? Subject { get; init; }
+    public string? ReplyTo { get; init; }
+    public string? Queue { get; init; }
+    public string? Sid { get; init; }
+
+    /// <summary>UNSUB message limit; -1 when none was given or it could not be parsed.</summary>
+    public int MaxMessages { get; init; }
+
+    /// <summary>HPUB header byte count; -1 for PUB.</summary>
+    public int HeaderSize { get; init; }
+
+    /// <summary>PUB/HPUB payload bytes, or the raw JSON argument of CONNECT/INFO.</summary>
+    public ReadOnlyMemory<byte> Payload { get; init; }
+
+    /// <summary>Creates a command that carries no arguments, such as PING or +OK.</summary>
+    public static ParsedCommand Simple(CommandType type) => new() { Type = type, MaxMessages = -1 };
+}
+
+/// <summary>
+/// Incremental parser for the NATS text protocol. State is kept between
+/// <see cref="TryParse"/> calls so that a PUB/HPUB payload may arrive after its
+/// control line.
+/// </summary>
+public sealed class NatsParser
+{
+    private static readonly byte[] CrLfBytes = "\r\n"u8.ToArray();
+    private readonly int _maxPayload;
+
+    // State for split-packet payload reading
+    private bool _awaitingPayload;
+    private int _expectedPayloadSize;
+    private string? _pendingSubject;
+    private string? _pendingReplyTo;
+    private int _pendingHeaderSize;
+    private CommandType _pendingType;
+
+    public NatsParser(int maxPayload = NatsProtocol.MaxPayloadSize)
+    {
+        _maxPayload = maxPayload;
+    }
+
+    /// <summary>
+    /// Tries to parse one command from the front of <paramref name="buffer"/>.
+    /// </summary>
+    /// <param name="buffer">
+    /// Unconsumed input. It is advanced past whatever was consumed, which
+    /// includes a PUB/HPUB control line even when false is returned because the
+    /// payload is still incomplete.
+    /// </param>
+    /// <param name="command">The parsed command when true is returned.</param>
+    /// <returns>True when a command was parsed; false when more data is needed.</returns>
+    /// <exception cref="ProtocolViolationException">The input violates the protocol.</exception>
+    public bool TryParse(ref ReadOnlySequence<byte> buffer, out ParsedCommand command)
+    {
+        command = default;
+
+        if (_awaitingPayload)
+            return TryReadPayload(ref buffer, out command);
+
+        // Look for \r\n to find control line
+        var reader = new SequenceReader<byte>(buffer);
+        if (!reader.TryReadTo(out ReadOnlySequence<byte> line, CrLfBytes.AsSpan()))
+        {
+            // A legal control line plus a pending '\r' needs at most
+            // MaxControlLineSize + 1 bytes. More than that without a terminator
+            // can never become valid, so fail instead of buffering forever.
+            if (buffer.Length > NatsProtocol.MaxControlLineSize + 1)
+                throw new ProtocolViolationException("Maximum control line exceeded");
+
+            return false;
+        }
+
+        // Control line size check
+        if (line.Length > NatsProtocol.MaxControlLineSize)
+            throw new ProtocolViolationException("Maximum control line exceeded");
+
+        // Get line as contiguous span
+        Span<byte> lineSpan = stackalloc byte[(int)line.Length];
+        line.CopyTo(lineSpan);
+
+        // The terminator was found, so the line is complete. Returning false for
+        // a short line (e.g. a lone "+") would leave it unconsumed and the same
+        // bytes would be re-parsed forever; reject it instead.
+        if (lineSpan.Length < 2)
+            throw new ProtocolViolationException("Unknown protocol operation");
+
+        // Only the first two bytes select the operation.
+        byte b0 = (byte)(lineSpan[0] | 0x20); // lowercase
+        byte b1 = (byte)(lineSpan[1] | 0x20);
+        SequencePosition afterLine = reader.Position;
+
+        switch (b0)
+        {
+            case (byte)'p' when b1 == (byte)'i': // PING
+                command = ParsedCommand.Simple(CommandType.Ping);
+                break;
+
+            case (byte)'p' when b1 == (byte)'o': // PONG
+                command = ParsedCommand.Simple(CommandType.Pong);
+                break;
+
+            case (byte)'p' when b1 == (byte)'u': // PUB
+                return ParsePub(lineSpan, ref buffer, afterLine, out command);
+
+            case (byte)'h' when b1 == (byte)'p': // HPUB
+                return ParseHPub(lineSpan, ref buffer, afterLine, out command);
+
+            case (byte)'s' when b1 == (byte)'u': // SUB
+                command = ParseSub(lineSpan);
+                break;
+
+            case (byte)'u' when b1 == (byte)'n': // UNSUB
+                command = ParseUnsub(lineSpan);
+                break;
+
+            case (byte)'c' when b1 == (byte)'o': // CONNECT
+                command = ParseJsonArg(lineSpan, CommandType.Connect, "Invalid CONNECT");
+                break;
+
+            case (byte)'i' when b1 == (byte)'n': // INFO
+                command = ParseJsonArg(lineSpan, CommandType.Info, "Invalid INFO");
+                break;
+
+            // '+' and '-' are matched on the raw byte: OR-ing 0x20 also maps
+            // the control bytes 0x0B and 0x0D onto them.
+            case (byte)'+' when lineSpan[0] == (byte)'+': // +OK
+                command = ParsedCommand.Simple(CommandType.Ok);
+                break;
+
+            case (byte)'-' when lineSpan[0] == (byte)'-': // -ERR
+                command = ParsedCommand.Simple(CommandType.Err);
+                break;
+
+            default:
+                throw new ProtocolViolationException("Unknown protocol operation");
+        }
+
+        buffer = buffer.Slice(afterLine);
+        return true;
+    }
+
+    /// <summary>
+    /// Parses "PUB subject [reply] size" and starts reading its payload.
+    /// </summary>
+    private bool ParsePub(
+        Span<byte> line,
+        ref ReadOnlySequence<byte> buffer,
+        SequencePosition afterLine,
+        out ParsedCommand command)
+    {
+        Span<Range> ranges = stackalloc Range[4];
+        var argsSpan = ArgsAfterVerb(line, "PUB"u8);
+        int argCount = SplitArgs(argsSpan, ranges);
+
+        // The size is always last; the reply subject is optional.
+        if (argCount is not (2 or 3))
+            throw new ProtocolViolationException("Invalid PUB arguments");
+
+        string subject = Encoding.ASCII.GetString(argsSpan[ranges[0]]);
+        string? reply = argCount == 3 ? Encoding.ASCII.GetString(argsSpan[ranges[1]]) : null;
+        int size = ParseSize(argsSpan[ranges[argCount - 1]]);
+
+        if (size < 0 || size > _maxPayload)
+            throw new ProtocolViolationException("Invalid payload size");
+
+        return BeginPayload(CommandType.Pub, subject, reply, -1, size, ref buffer, afterLine, out command);
+    }
+
+    /// <summary>
+    /// Parses "HPUB subject [reply] hdr_size total_size" and starts reading its payload.
+    /// </summary>
+    private bool ParseHPub(
+        Span<byte> line,
+        ref ReadOnlySequence<byte> buffer,
+        SequencePosition afterLine,
+        out ParsedCommand command)
+    {
+        Span<Range> ranges = stackalloc Range[5];
+        var argsSpan = ArgsAfterVerb(line, "HPUB"u8);
+        int argCount = SplitArgs(argsSpan, ranges);
+
+        // The two sizes are always last; the reply subject is optional.
+        if (argCount is not (3 or 4))
+            throw new ProtocolViolationException("Invalid HPUB arguments");
+
+        string subject = Encoding.ASCII.GetString(argsSpan[ranges[0]]);
+        string? reply = argCount == 4 ? Encoding.ASCII.GetString(argsSpan[ranges[1]]) : null;
+        int hdrSize = ParseSize(argsSpan[ranges[argCount - 2]]);
+        int totalSize = ParseSize(argsSpan[ranges[argCount - 1]]);
+
+        if (hdrSize < 0 || totalSize < 0 || hdrSize > totalSize || totalSize > _maxPayload)
+            throw new ProtocolViolationException("Invalid HPUB sizes");
+
+        return BeginPayload(CommandType.HPub, subject, reply, hdrSize, totalSize, ref buffer, afterLine, out command);
+    }
+
+    /// <summary>
+    /// Consumes the control line, records the pending PUB/HPUB, and tries to read
+    /// its payload from the bytes that are already buffered.
+    /// </summary>
+    private bool BeginPayload(
+        CommandType type,
+        string subject,
+        string? replyTo,
+        int headerSize,
+        int payloadSize,
+        ref ReadOnlySequence<byte> buffer,
+        SequencePosition afterLine,
+        out ParsedCommand command)
+    {
+        buffer = buffer.Slice(afterLine);
+        _awaitingPayload = true;
+        _expectedPayloadSize = payloadSize;
+        _pendingSubject = subject;
+        _pendingReplyTo = replyTo;
+        _pendingHeaderSize = headerSize;
+        _pendingType = type;
+
+        return TryReadPayload(ref buffer, out command);
+    }
+
+    /// <summary>
+    /// Completes the pending PUB/HPUB once its payload and trailing \r\n are buffered.
+    /// </summary>
+    private bool TryReadPayload(ref ReadOnlySequence<byte> buffer, out ParsedCommand command)
+    {
+        command = default;
+
+        // Need: _expectedPayloadSize bytes + \r\n (widened first so the sum cannot overflow)
+        long needed = (long)_expectedPayloadSize + 2;
+        if (buffer.Length < needed)
+            return false;
+
+        // Verify \r\n after payload before allocating the payload copy
+        var trailer = buffer.Slice(_expectedPayloadSize, 2);
+        Span<byte> trailerBytes = stackalloc byte[2];
+        trailer.CopyTo(trailerBytes);
+        if (trailerBytes[0] != (byte)'\r' || trailerBytes[1] != (byte)'\n')
+            throw new ProtocolViolationException("Expected \\r\\n after payload");
+
+        // Extract payload
+        var payloadSlice = buffer.Slice(0, _expectedPayloadSize);
+        var payload = new byte[_expectedPayloadSize];
+        payloadSlice.CopyTo(payload);
+
+        command = new ParsedCommand
+        {
+            Type = _pendingType,
+            Subject = _pendingSubject,
+            ReplyTo = _pendingReplyTo,
+            Payload = payload,
+            HeaderSize = _pendingHeaderSize,
+            MaxMessages = -1,
+        };
+
+        buffer = buffer.Slice(needed);
+        _awaitingPayload = false;
+        return true;
+    }
+
+    /// <summary>
+    /// Checks that <paramref name="line"/> starts with <paramref name="verb"/>
+    /// (ASCII case-insensitive) and returns the bytes after the separator that
+    /// follows it, or an empty span for a bare verb.
+    /// </summary>
+    /// <param name="line">Control line without its \r\n terminator.</param>
+    /// <param name="verb">Upper-case ASCII letters of the expected verb.</param>
+    /// <exception cref="ProtocolViolationException">
+    /// The verb is misspelled or is not followed by a space or tab.
+    /// </exception>
+    private static Span<byte> ArgsAfterVerb(Span<byte> line, ReadOnlySpan<byte> verb)
+    {
+        if (line.Length < verb.Length)
+            throw new ProtocolViolationException("Unknown protocol operation");
+
+        for (int i = 0; i < verb.Length; i++)
+        {
+            // Clearing bit 5 upper-cases an ASCII letter.
+            if ((line[i] & 0xDF) != verb[i])
+                throw new ProtocolViolationException("Unknown protocol operation");
+        }
+
+        // A bare verb has no arguments; callers reject the empty argument list.
+        if (line.Length == verb.Length)
+            return default;
+
+        byte separator = line[verb.Length];
+        if (separator != (byte)' ' && separator != (byte)'\t')
+            throw new ProtocolViolationException("Unknown protocol operation");
+
+        return line[(verb.Length + 1)..];
+    }
+
+    /// <summary>
+    /// Parses "SUB subject [queue] sid".
+    /// </summary>
+    private static ParsedCommand ParseSub(Span<byte> line)
+    {
+        Span<Range> ranges = stackalloc Range[4];
+        var argsSpan = ArgsAfterVerb(line, "SUB"u8);
+        int argCount = SplitArgs(argsSpan, ranges);
+
+        return argCount switch
+        {
+            2 => new ParsedCommand
+            {
+                Type = CommandType.Sub,
+                Subject = Encoding.ASCII.GetString(argsSpan[ranges[0]]),
+                Sid = Encoding.ASCII.GetString(argsSpan[ranges[1]]),
+                MaxMessages = -1,
+            },
+            3 => new ParsedCommand
+            {
+                Type = CommandType.Sub,
+                Subject = Encoding.ASCII.GetString(argsSpan[ranges[0]]),
+                Queue = Encoding.ASCII.GetString(argsSpan[ranges[1]]),
+                Sid = Encoding.ASCII.GetString(argsSpan[ranges[2]]),
+                MaxMessages = -1,
+            },
+            _ => throw new ProtocolViolationException("Invalid SUB arguments"),
+        };
+    }
+
+    /// <summary>
+    /// Parses "UNSUB sid [max_msgs]". A max_msgs value that is not a valid
+    /// size is reported as -1, the same value used when it is absent.
+    /// </summary>
+    private static ParsedCommand ParseUnsub(Span<byte> line)
+    {
+        Span<Range> ranges = stackalloc Range[3];
+        var argsSpan = ArgsAfterVerb(line, "UNSUB"u8);
+        int argCount = SplitArgs(argsSpan, ranges);
+
+        return argCount switch
+        {
+            1 => new ParsedCommand
+            {
+                Type = CommandType.Unsub,
+                Sid = Encoding.ASCII.GetString(argsSpan[ranges[0]]),
+                MaxMessages = -1,
+            },
+            2 => new ParsedCommand
+            {
+                Type = CommandType.Unsub,
+                Sid = Encoding.ASCII.GetString(argsSpan[ranges[0]]),
+                MaxMessages = ParseSize(argsSpan[ranges[1]]),
+            },
+            _ => throw new ProtocolViolationException("Invalid UNSUB arguments"),
+        };
+    }
+
+    /// <summary>
+    /// Parses "VERB {json}" (CONNECT/INFO): everything after the first space
+    /// is copied into the payload without being validated.
+    /// </summary>
+    private static ParsedCommand ParseJsonArg(Span<byte> line, CommandType type, string error)
+    {
+        int spaceIdx = line.IndexOf((byte)' ');
+        if (spaceIdx < 0)
+            throw new ProtocolViolationException(error);
+
+        return new ParsedCommand
+        {
+            Type = type,
+            Payload = line[(spaceIdx + 1)..].ToArray(),
+            MaxMessages = -1,
+        };
+    }
+
+    /// <summary>
+    /// Parse a decimal integer from ASCII bytes. Returns -1 on error.
+    /// </summary>
+    internal static int ParseSize(Span<byte> data)
+    {
+        // At most 9 digits, so the accumulator cannot overflow an int.
+        if (data.Length == 0 || data.Length > 9)
+            return -1;
+
+        int n = 0;
+        foreach (byte b in data)
+        {
+            if (b < (byte)'0' || b > (byte)'9')
+                return -1;
+            n = n * 10 + (b - '0');
+        }
+
+        return n;
+    }
+
+    /// <summary>
+    /// Split by spaces/tabs into argument ranges. Returns the number of arguments found.
+    /// Uses Span&lt;Range&gt; for zero-allocation argument splitting.
+    /// </summary>
+    /// <exception cref="ProtocolViolationException">
+    /// There are more arguments than <paramref name="ranges"/> can hold.
+    /// </exception>
+    internal static int SplitArgs(Span<byte> data, Span<Range> ranges)
+    {
+        int count = 0;
+        int start = -1;
+
+        for (int i = 0; i < data.Length; i++)
+        {
+            byte b = data[i];
+            if (b is (byte)' ' or (byte)'\t')
+            {
+                if (start >= 0)
+                {
+                    if (count >= ranges.Length)
+                        throw new ProtocolViolationException("Too many arguments");
+                    ranges[count++] = start..i;
+                    start = -1;
+                }
+            }
+            else
+            {
+                if (start < 0)
+                    start = i;
+            }
+        }
+
+        if (start >= 0)
+        {
+            if (count >= ranges.Length)
+                throw new ProtocolViolationException("Too many arguments");
+            ranges[count++] = start..data.Length;
+        }
+
+        return count;
+    }
+}
+
+/// <summary>
+/// Thrown when input does not conform to the NATS protocol.
+/// </summary>
+public class ProtocolViolationException : Exception
+{
+    public ProtocolViolationException(string message)
+        : base(message)
+    {
+    }
+}
diff --git a/tests/NATS.Server.Tests/ParserTests.cs b/tests/NATS.Server.Tests/ParserTests.cs
new file mode 100644
index 0000000..a108f23
--- /dev/null
+++ b/tests/NATS.Server.Tests/ParserTests.cs
@@ -0,0 +1,222 @@
+using System.Buffers;
+using System.IO.Pipelines;
+using System.Text;
+using NATS.Server.Protocol;
+
+namespace NATS.Server.Tests;
+
+public class ParserTests
+{
+    /// <summary>
+    /// Feeds <paramref name="input"/> through a pipe and collects every command the parser yields.
+    /// </summary>
+    private static async Task<List<ParsedCommand>> ParseAsync(string input)
+    {
+        var pipe = new Pipe();
+        var commands = new List<ParsedCommand>();
+
+        // Write input to pipe
+        var bytes = Encoding.ASCII.GetBytes(input);
+        await pipe.Writer.WriteAsync(bytes);
+        pipe.Writer.Complete();
+
+        // Parse from pipe
+        var parser = new NatsParser(maxPayload: NatsProtocol.MaxPayloadSize);
+        while (true)
+        {
+            var result = await pipe.Reader.ReadAsync();
+            var buffer = result.Buffer;
+
+            while (parser.TryParse(ref buffer, out var cmd))
+                commands.Add(cmd);
+
+            pipe.Reader.AdvanceTo(buffer.Start, buffer.End);
+
+            if (result.IsCompleted)
+                break;
+        }
+
+        // Complete the reader now that the input is drained.
+        pipe.Reader.Complete();
+        return commands;
+    }
+
+    [Fact]
+    public async Task Parse_PING()
+    {
+        var cmds = await ParseAsync("PING\r\n");
+        Assert.Single(cmds);
+        Assert.Equal(CommandType.Ping, cmds[0].Type);
+    }
+
+    [Fact]
+    public async Task Parse_PONG()
+    {
+        var cmds = await ParseAsync("PONG\r\n");
+        Assert.Single(cmds);
+        Assert.Equal(CommandType.Pong, cmds[0].Type);
+    }
+
+    [Fact]
+    public async Task Parse_CONNECT()
+    {
+        var cmds = await ParseAsync("CONNECT {\"verbose\":false,\"echo\":true}\r\n");
+        Assert.Single(cmds);
+        Assert.Equal(CommandType.Connect, cmds[0].Type);
+        Assert.Contains("verbose", Encoding.ASCII.GetString(cmds[0].Payload.ToArray()));
+    }
+
+    [Fact]
+    public async Task Parse_SUB_without_queue()
+    {
+        var cmds = await ParseAsync("SUB foo 1\r\n");
+        Assert.Single(cmds);
+        Assert.Equal(CommandType.Sub, cmds[0].Type);
+        Assert.Equal("foo", cmds[0].Subject);
+        Assert.Null(cmds[0].Queue);
+        Assert.Equal("1", cmds[0].Sid);
+    }
+
+    [Fact]
+    public async Task Parse_SUB_with_queue()
+    {
+        var cmds = await ParseAsync("SUB foo workers 1\r\n");
+        Assert.Single(cmds);
+        Assert.Equal(CommandType.Sub, cmds[0].Type);
+        Assert.Equal("foo", cmds[0].Subject);
+        Assert.Equal("workers", cmds[0].Queue);
+        Assert.Equal("1", cmds[0].Sid);
+    }
+
+    [Fact]
+    public async Task Parse_UNSUB()
+    {
+        var cmds = await ParseAsync("UNSUB 1\r\n");
+        Assert.Single(cmds);
+        Assert.Equal(CommandType.Unsub, cmds[0].Type);
+        Assert.Equal("1", cmds[0].Sid);
+        Assert.Equal(-1, cmds[0].MaxMessages);
+    }
+
+    [Fact]
+    public async Task Parse_UNSUB_with_max()
+    {
+        var cmds = await ParseAsync("UNSUB 1 5\r\n");
+        Assert.Single(cmds);
+        Assert.Equal(CommandType.Unsub, cmds[0].Type);
+        Assert.Equal("1", cmds[0].Sid);
+        Assert.Equal(5, cmds[0].MaxMessages);
+    }
+
+    [Fact]
+    public async Task Parse_PUB_with_payload()
+    {
+        var cmds = await ParseAsync("PUB foo 5\r\nHello\r\n");
+        Assert.Single(cmds);
+        Assert.Equal(CommandType.Pub, cmds[0].Type);
+        Assert.Equal("foo", cmds[0].Subject);
+        Assert.Null(cmds[0].ReplyTo);
+        Assert.Equal("Hello", Encoding.ASCII.GetString(cmds[0].Payload.ToArray()));
+    }
+
+    [Fact]
+    public async Task Parse_PUB_with_reply()
+    {
+        var cmds = await ParseAsync("PUB foo reply 5\r\nHello\r\n");
+        Assert.Single(cmds);
+        Assert.Equal(CommandType.Pub, cmds[0].Type);
+        Assert.Equal("foo", cmds[0].Subject);
+        Assert.Equal("reply", cmds[0].ReplyTo);
+        Assert.Equal("Hello", Encoding.ASCII.GetString(cmds[0].Payload.ToArray()));
+    }
+
+    [Fact]
+    public async Task Parse_multiple_commands()
+    {
+        var cmds = await ParseAsync("PING\r\nPONG\r\nSUB foo 1\r\n");
+        Assert.Equal(3, cmds.Count);
+        Assert.Equal(CommandType.Ping, cmds[0].Type);
+        Assert.Equal(CommandType.Pong, cmds[1].Type);
+        Assert.Equal(CommandType.Sub, cmds[2].Type);
+    }
+
+    [Fact]
+    public async Task Parse_PUB_zero_payload()
+    {
+        var cmds = await ParseAsync("PUB foo 0\r\n\r\n");
+        Assert.Single(cmds);
+        Assert.Equal(CommandType.Pub, cmds[0].Type);
+        Assert.Empty(cmds[0].Payload.ToArray());
+    }
+
+    [Fact]
+    public async Task Parse_case_insensitive()
+    {
+        var cmds = await ParseAsync("ping\r\npub FOO 3\r\nabc\r\n");
+        Assert.Equal(2, cmds.Count);
+        Assert.Equal(CommandType.Ping, cmds[0].Type);
+        Assert.Equal(CommandType.Pub, cmds[1].Type);
+    }
+
+    [Fact]
+    public async Task Parse_HPUB()
+    {
+        // HPUB subject 12 17\r\nNATS/1.0\r\n\r\nHello\r\n
+        var header = "NATS/1.0\r\n\r\n";
+        var payload = "Hello";
+        var total = header.Length + payload.Length;
+        var cmds = await ParseAsync($"HPUB foo {header.Length} {total}\r\n{header}{payload}\r\n");
+        Assert.Single(cmds);
+        Assert.Equal(CommandType.HPub, cmds[0].Type);
+        Assert.Equal("foo", cmds[0].Subject);
+        Assert.Equal(header.Length, cmds[0].HeaderSize);
+    }
+
+    [Fact]
+    public async Task Parse_INFO()
+    {
+        var cmds = await ParseAsync("INFO {\"server_id\":\"test\"}\r\n");
+        Assert.Single(cmds);
+        Assert.Equal(CommandType.Info, cmds[0].Type);
+    }
+
+    [Theory]
+    [InlineData("+\r\n")]
+    [InlineData("PU\r\n")]
+    [InlineData("PUB\r\n")]
+    [InlineData("PU foo 5\r\n")]
+    [InlineData("UNSUB\r\n")]
+    public void Parse_malformed_control_line_throws(string input)
+    {
+        var parser = new NatsParser();
+        var buffer = new ReadOnlySequence<byte>(Encoding.ASCII.GetBytes(input));
+
+        Assert.Throws<ProtocolViolationException>(() => { parser.TryParse(ref buffer, out _); });
+    }
+
+    [Fact]
+    public void Parse_unterminated_oversized_control_line_throws()
+    {
+        var parser = new NatsParser();
+        var buffer = new ReadOnlySequence<byte>(new byte[(int)NatsProtocol.MaxControlLineSize + 2]);
+
+        Assert.Throws<ProtocolViolationException>(() => { parser.TryParse(ref buffer, out _); });
+    }
+
+    [Fact]
+    public void Parse_PUB_payload_split_across_reads()
+    {
+        var parser = new NatsParser();
+
+        // The control line is consumed even though the payload is incomplete.
+        var first = new ReadOnlySequence<byte>(Encoding.ASCII.GetBytes("PUB foo 5\r\nHel"));
+        Assert.False(parser.TryParse(ref first, out _));
+        Assert.Equal(3L, first.Length);
+
+        var second = new ReadOnlySequence<byte>(Encoding.ASCII.GetBytes("Hello\r\n"));
+        Assert.True(parser.TryParse(ref second, out var cmd));
+        Assert.Equal(CommandType.Pub, cmd.Type);
+        Assert.Equal("Hello", Encoding.ASCII.GetString(cmd.Payload.ToArray()));
+        Assert.True(second.IsEmpty);
+    }
+}