using System.Buffers;
using System.Text;
using Microsoft.Extensions.Logging;

namespace NATS.Server.Protocol;

/// <summary>
/// Protocol operations that <see cref="NatsParser"/> recognises on the wire.
/// </summary>
public enum CommandType
{
    Ping,
    Pong,
    Connect,
    Info,
    Pub,
    HPub,
    Sub,
    Unsub,
    Ok,
    Err,
}

/// <summary>
/// A single protocol command produced by <see cref="NatsParser.TryParse"/>.
/// Which properties are populated depends on <see cref="Type"/>.
/// </summary>
public readonly struct ParsedCommand
{
    /// <summary>The operation that was parsed.</summary>
    public CommandType Type { get; init; }

    /// <summary>Subject of a PUB, HPUB or SUB.</summary>
    public string? Subject { get; init; }

    /// <summary>Optional reply subject of a PUB or HPUB.</summary>
    public string? ReplyTo { get; init; }

    /// <summary>Optional queue group of a SUB.</summary>
    public string? Queue { get; init; }

    /// <summary>Subscription id of a SUB or UNSUB.</summary>
    public string? Sid { get; init; }

    /// <summary>
    /// UNSUB message limit. Every command built by the parser sets this to -1 unless an UNSUB
    /// carried an explicit limit.
    /// </summary>
    public int MaxMessages { get; init; }

    /// <summary>
    /// HPUB: number of header bytes at the start of <see cref="Payload"/>. PUB: -1.
    /// Left at 0 for commands that carry no message payload.
    /// </summary>
    public int HeaderSize { get; init; }

    /// <summary>
    /// PUB/HPUB: the message bytes (without the trailing CRLF).
    /// CONNECT/INFO: the raw argument text that follows the keyword.
    /// </summary>
    public ReadOnlyMemory<byte> Payload { get; init; }

    /// <summary>Creates an argument-less command (PING, PONG, +OK, -ERR).</summary>
    public static ParsedCommand Simple(CommandType type) => new() { Type = type, MaxMessages = -1 };
}

/// <summary>
/// Incremental parser for the NATS text protocol.
/// </summary>
/// <remarks>
/// Call <see cref="TryParse"/> repeatedly with the unconsumed input. Each successful call returns
/// one command and slices the consumed bytes off the front of the buffer. When a PUB/HPUB control
/// line has been consumed but its payload has not fully arrived, the pending message is held in
/// instance fields. A parser instance must therefore only ever be fed a single input stream.
/// </remarks>
public sealed class NatsParser
{
    private static readonly byte[] CrLfBytes = "\r\n"u8.ToArray();

    // NOTE(review): assigned but never read in this class. PUB/HPUB leave max-payload enforcement
    // to the caller. Confirm the caller bounds it; otherwise TryParse keeps returning false until
    // the full announced size (ParseSize allows up to 999,999,999 bytes) has been buffered.
    private readonly int _maxPayload;

    private ILogger? _logger;

    /// <summary>Replaces the logger used for protocol tracing; <c>null</c> disables tracing.</summary>
    public ILogger? Logger
    {
        set => _logger = value;
    }

    // State carried across TryParse calls while a PUB/HPUB payload is still incomplete.
    private bool _awaitingPayload;
    private int _expectedPayloadSize;
    private string? _pendingSubject;
    private string? _pendingReplyTo;
    private int _pendingHeaderSize;
    private CommandType _pendingType;

    /// <summary>Creates a parser.</summary>
    /// <param name="maxPayload">Maximum payload size. Stored, but not currently enforced by the parser.</param>
    /// <param name="logger">Optional logger for trace-level protocol logging.</param>
    public NatsParser(int maxPayload = NatsProtocol.MaxPayloadSize, ILogger? logger = null)
    {
        _maxPayload = maxPayload;
        _logger = logger;
    }

    /// <summary>Logs an inbound operation at Trace level, decoding <paramref name="arg"/> as ASCII.</summary>
    private void TraceInOp(string op, ReadOnlySpan<byte> arg = default)
    {
        if (_logger == null || !_logger.IsEnabled(LogLevel.Trace))
            return;

        if (arg.IsEmpty)
            _logger.LogTrace("<<- {Op}", op);
        else
            _logger.LogTrace("<<- {Op} {Arg}", op, Encoding.ASCII.GetString(arg));
    }

    /// <summary>
    /// Attempts to parse one command from the front of <paramref name="buffer"/>.
    /// </summary>
    /// <param name="buffer">
    /// Unconsumed input. On success it is advanced past the parsed command. It can also be advanced
    /// when the method returns <c>false</c>: a PUB/HPUB control line is consumed as soon as it is
    /// parsed, and the parser then waits for the payload. Callers must always keep the updated value.
    /// </param>
    /// <param name="command">The parsed command when the method returns <c>true</c>; otherwise <c>default</c>.</param>
    /// <returns><c>true</c> if a complete command was parsed; <c>false</c> if more data is needed.</returns>
    /// <exception cref="ProtocolViolationException">The input violates the protocol.</exception>
    public bool TryParse(ref ReadOnlySequence<byte> buffer, out ParsedCommand command)
    {
        command = default;

        if (_awaitingPayload)
            return TryReadPayload(ref buffer, out command);

        var reader = new SequenceReader<byte>(buffer);
        if (!reader.TryReadTo(out ReadOnlySequence<byte> line, CrLfBytes.AsSpan()))
        {
            // No complete control line yet. Everything buffered belongs to that one line, and only a
            // final '\r' could still be part of its terminator. If the line is already too long,
            // reject it now instead of buffering without bound while waiting for a CRLF.
            if (buffer.Length - 1 > NatsProtocol.MaxControlLineSize)
                throw new ProtocolViolationException("Maximum control line exceeded");
            return false;
        }

        if (line.Length > NatsProtocol.MaxControlLineSize)
            throw new ProtocolViolationException("Maximum control line exceeded");

        // Contiguous copy of the control line. Its size is bounded by the check above
        // (assumes MaxControlLineSize is small enough for stackalloc).
        Span<byte> lineSpan = stackalloc byte[(int)line.Length];
        line.CopyTo(lineSpan);
        SequencePosition afterLine = reader.Position;

        if (lineSpan.IsEmpty)
            throw new ProtocolViolationException("Unknown protocol operation");

        // Dispatch on the case-folded first byte, then match the whole keyword. This rejects
        // truncated or misspelled operations (e.g. "PU", "PIZZA") instead of misreading them.
        switch ((byte)(lineSpan[0] | 0x20))
        {
            case (byte)'p':
                if (MatchesOp(lineSpan, "ping"u8))
                    return Accept(ParsedCommand.Simple(CommandType.Ping), ref buffer, afterLine, "PING", default, out command);
                if (MatchesOp(lineSpan, "pong"u8))
                    return Accept(ParsedCommand.Simple(CommandType.Pong), ref buffer, afterLine, "PONG", default, out command);
                if (MatchesOp(lineSpan, "pub"u8))
                    return ParsePub(OpArgs(lineSpan, 3, "Invalid PUB arguments"), ref buffer, afterLine, out command);
                break;

            case (byte)'h':
                if (MatchesOp(lineSpan, "hpub"u8))
                    return ParseHPub(OpArgs(lineSpan, 4, "Invalid HPUB arguments"), ref buffer, afterLine, out command);
                break;

            case (byte)'s':
                if (MatchesOp(lineSpan, "sub"u8))
                {
                    Span<byte> args = OpArgs(lineSpan, 3, "Invalid SUB arguments");
                    return Accept(ParseSub(args), ref buffer, afterLine, "SUB", args, out command);
                }
                break;

            case (byte)'u':
                if (MatchesOp(lineSpan, "unsub"u8))
                {
                    Span<byte> args = OpArgs(lineSpan, 5, "Invalid UNSUB arguments");
                    return Accept(ParseUnsub(args), ref buffer, afterLine, "UNSUB", args, out command);
                }
                break;

            case (byte)'c':
                if (MatchesOp(lineSpan, "connect"u8))
                {
                    Span<byte> args = OpArgs(lineSpan, 7, "Invalid CONNECT");
                    // The JSON argument is not traced; CONNECT options may include credentials.
                    return Accept(ParseConnect(args), ref buffer, afterLine, "CONNECT", default, out command);
                }
                break;

            case (byte)'i':
                if (MatchesOp(lineSpan, "info"u8))
                {
                    Span<byte> args = OpArgs(lineSpan, 4, "Invalid INFO");
                    return Accept(ParseInfo(args), ref buffer, afterLine, "INFO", default, out command);
                }
                break;

            case (byte)'+':
                if (MatchesOp(lineSpan, "+ok"u8))
                    return Accept(ParsedCommand.Simple(CommandType.Ok), ref buffer, afterLine, "+OK", default, out command);
                break;

            case (byte)'-':
                // Any text after -ERR is not kept in the command.
                if (MatchesOp(lineSpan, "-err"u8))
                    return Accept(ParsedCommand.Simple(CommandType.Err), ref buffer, afterLine, "-ERR", default, out command);
                break;
        }

        throw new ProtocolViolationException("Unknown protocol operation");
    }

    /// <summary>
    /// Finishes a command that consists of its control line only: publishes it, consumes the line
    /// and traces it.
    /// </summary>
    private bool Accept(
        ParsedCommand parsed,
        ref ReadOnlySequence<byte> buffer,
        SequencePosition afterLine,
        string op,
        ReadOnlySpan<byte> traceArg,
        out ParsedCommand command)
    {
        command = parsed;
        buffer = buffer.Slice(afterLine);
        TraceInOp(op, traceArg);
        return true;
    }

    /// <summary>
    /// Returns true when <paramref name="line"/> starts with <paramref name="op"/>, which must be
    /// given in lower case. ASCII letters in the line are compared case-insensitively; every other
    /// byte must match exactly.
    /// </summary>
    private static bool MatchesOp(ReadOnlySpan<byte> line, ReadOnlySpan<byte> op)
    {
        if (line.Length < op.Length)
            return false;

        for (int i = 0; i < op.Length; i++)
        {
            byte b = line[i];
            // Fold only A-Z, so that control bytes such as '\v' (0x0B | 0x20 == '+') cannot alias punctuation.
            if (b is >= (byte)'A' and <= (byte)'Z')
                b = (byte)(b | 0x20);
            if (b != op[i])
                return false;
        }

        return true;
    }

    /// <summary>
    /// Returns the arguments that follow an operation keyword of <paramref name="opLength"/> bytes.
    /// </summary>
    /// <exception cref="ProtocolViolationException">
    /// Thrown with <paramref name="error"/> when there is no space or tab right after the keyword.
    /// </exception>
    private static Span<byte> OpArgs(Span<byte> line, int opLength, string error)
    {
        if (line.Length <= opLength || line[opLength] is not ((byte)' ' or (byte)'\t'))
            throw new ProtocolViolationException(error);
        return line[(opLength + 1)..];
    }

    /// <summary>
    /// Parses the arguments of <c>PUB subject [reply] size</c>. It consumes the control line, records
    /// the pending message and then tries to read the payload that follows.
    /// </summary>
    private bool ParsePub(
        Span<byte> args,
        ref ReadOnlySequence<byte> buffer,
        SequencePosition afterLine,
        out ParsedCommand command)
    {
        Span<Range> ranges = stackalloc Range[4];
        int argCount = SplitArgs(args, ranges);

        string subject;
        string? reply = null;
        int size;
        switch (argCount)
        {
            case 2:
                subject = Encoding.ASCII.GetString(args[ranges[0]]);
                size = ParseSize(args[ranges[1]]);
                break;
            case 3:
                subject = Encoding.ASCII.GetString(args[ranges[0]]);
                reply = Encoding.ASCII.GetString(args[ranges[1]]);
                size = ParseSize(args[ranges[2]]);
                break;
            default:
                throw new ProtocolViolationException("Invalid PUB arguments");
        }

        if (size < 0)
            throw new ProtocolViolationException("Invalid payload size");

        // Max payload enforcement is left to the caller (see the note on _maxPayload).
        buffer = buffer.Slice(afterLine);
        _awaitingPayload = true;
        _expectedPayloadSize = size;
        _pendingSubject = subject;
        _pendingReplyTo = reply;
        _pendingHeaderSize = -1;
        _pendingType = CommandType.Pub;
        TraceInOp("PUB", args);
        return TryReadPayload(ref buffer, out command);
    }

    /// <summary>
    /// Parses the arguments of <c>HPUB subject [reply] hdr_size total_size</c>. It consumes the
    /// control line, records the pending message and then tries to read the payload that follows.
    /// </summary>
    private bool ParseHPub(
        Span<byte> args,
        ref ReadOnlySequence<byte> buffer,
        SequencePosition afterLine,
        out ParsedCommand command)
    {
        Span<Range> ranges = stackalloc Range[5];
        int argCount = SplitArgs(args, ranges);

        string subject;
        string? reply = null;
        int hdrSize;
        int totalSize;
        switch (argCount)
        {
            case 3:
                subject = Encoding.ASCII.GetString(args[ranges[0]]);
                hdrSize = ParseSize(args[ranges[1]]);
                totalSize = ParseSize(args[ranges[2]]);
                break;
            case 4:
                subject = Encoding.ASCII.GetString(args[ranges[0]]);
                reply = Encoding.ASCII.GetString(args[ranges[1]]);
                hdrSize = ParseSize(args[ranges[2]]);
                totalSize = ParseSize(args[ranges[3]]);
                break;
            default:
                throw new ProtocolViolationException("Invalid HPUB arguments");
        }

        if (hdrSize < 0 || totalSize < 0 || hdrSize > totalSize)
            throw new ProtocolViolationException("Invalid HPUB sizes");

        buffer = buffer.Slice(afterLine);
        _awaitingPayload = true;
        _expectedPayloadSize = totalSize;
        _pendingSubject = subject;
        _pendingReplyTo = reply;
        _pendingHeaderSize = hdrSize;
        _pendingType = CommandType.HPub;
        TraceInOp("HPUB", args);
        return TryReadPayload(ref buffer, out command);
    }

    /// <summary>
    /// Completes a pending PUB/HPUB once <c>_expectedPayloadSize</c> bytes plus the trailing CRLF are
    /// available at the front of <paramref name="buffer"/>. Returns false, leaving the buffer
    /// untouched, while they are not.
    /// </summary>
    /// <exception cref="ProtocolViolationException">The payload is not followed by CRLF.</exception>
    private bool TryReadPayload(ref ReadOnlySequence<byte> buffer, out ParsedCommand command)
    {
        command = default;

        long needed = (long)_expectedPayloadSize + 2; // payload + "\r\n"
        if (buffer.Length < needed)
            return false;

        var payload = new byte[_expectedPayloadSize];
        buffer.Slice(0, _expectedPayloadSize).CopyTo(payload);

        Span<byte> trailer = stackalloc byte[2];
        buffer.Slice(_expectedPayloadSize, 2).CopyTo(trailer);
        if (trailer[0] != (byte)'\r' || trailer[1] != (byte)'\n')
            throw new ProtocolViolationException("Expected \\r\\n after payload");

        command = new ParsedCommand
        {
            Type = _pendingType,
            Subject = _pendingSubject,
            ReplyTo = _pendingReplyTo,
            Payload = payload,
            HeaderSize = _pendingHeaderSize,
            MaxMessages = -1,
        };

        buffer = buffer.Slice(needed);
        _awaitingPayload = false;
        // Drop references to the completed message's strings.
        _pendingSubject = null;
        _pendingReplyTo = null;
        return true;
    }

    /// <summary>Parses the arguments of <c>SUB subject [queue] sid</c>.</summary>
    private static ParsedCommand ParseSub(Span<byte> args)
    {
        Span<Range> ranges = stackalloc Range[4];
        int argCount = SplitArgs(args, ranges);

        return argCount switch
        {
            2 => new ParsedCommand
            {
                Type = CommandType.Sub,
                Subject = Encoding.ASCII.GetString(args[ranges[0]]),
                Sid = Encoding.ASCII.GetString(args[ranges[1]]),
                MaxMessages = -1,
            },
            3 => new ParsedCommand
            {
                Type = CommandType.Sub,
                Subject = Encoding.ASCII.GetString(args[ranges[0]]),
                Queue = Encoding.ASCII.GetString(args[ranges[1]]),
                Sid = Encoding.ASCII.GetString(args[ranges[2]]),
                MaxMessages = -1,
            },
            _ => throw new ProtocolViolationException("Invalid SUB arguments"),
        };
    }

    /// <summary>Parses the arguments of <c>UNSUB sid [max_msgs]</c>.</summary>
    private static ParsedCommand ParseUnsub(Span<byte> args)
    {
        Span<Range> ranges = stackalloc Range[3];
        int argCount = SplitArgs(args, ranges);
        if (argCount is not (1 or 2))
            throw new ProtocolViolationException("Invalid UNSUB arguments");

        int maxMessages = -1;
        if (argCount == 2)
        {
            maxMessages = ParseSize(args[ranges[1]]);
            // ParseSize reports malformed input as -1. That value would otherwise be read as
            // "no limit", silently dropping the client's max_msgs.
            if (maxMessages < 0)
                throw new ProtocolViolationException("Invalid UNSUB arguments");
        }

        return new ParsedCommand
        {
            Type = CommandType.Unsub,
            Sid = Encoding.ASCII.GetString(args[ranges[0]]),
            MaxMessages = maxMessages,
        };
    }

    /// <summary>Wraps the raw CONNECT argument (expected to be JSON) as the command payload.</summary>
    private static ParsedCommand ParseConnect(Span<byte> json) => new()
    {
        Type = CommandType.Connect,
        Payload = json.ToArray(),
        MaxMessages = -1,
    };

    /// <summary>Wraps the raw INFO argument (expected to be JSON) as the command payload.</summary>
    private static ParsedCommand ParseInfo(Span<byte> json) => new()
    {
        Type = CommandType.Info,
        Payload = json.ToArray(),
        MaxMessages = -1,
    };

    /// <summary>
    /// Parses a non-negative decimal integer written as 1 to 9 ASCII digits.
    /// </summary>
    /// <returns>The value, or -1 if the input is empty, longer than 9 bytes, or contains a non-digit.</returns>
    internal static int ParseSize(Span<byte> data)
    {
        // Nine digits (at most 999,999,999) always fit in an int, so no overflow check is needed.
        if (data.Length == 0 || data.Length > 9)
            return -1;

        int n = 0;
        foreach (byte b in data)
        {
            if (b < (byte)'0' || b > (byte)'9')
                return -1;
            n = n * 10 + (b - '0');
        }

        return n;
    }

    /// <summary>
    /// Splits <paramref name="data"/> on runs of spaces and tabs, writing each argument's range into
    /// <paramref name="ranges"/> without allocating.
    /// </summary>
    /// <returns>The number of arguments found.</returns>
    /// <exception cref="ProtocolViolationException">There are more arguments than <paramref name="ranges"/> can hold.</exception>
    internal static int SplitArgs(Span<byte> data, Span<Range> ranges)
    {
        int count = 0;
        int start = -1;

        for (int i = 0; i < data.Length; i++)
        {
            byte b = data[i];
            if (b is (byte)' ' or (byte)'\t')
            {
                if (start >= 0)
                {
                    if (count >= ranges.Length)
                        throw new ProtocolViolationException("Too many arguments");
                    ranges[count++] = start..i;
                    start = -1;
                }
            }
            else if (start < 0)
            {
                start = i;
            }
        }

        if (start >= 0)
        {
            if (count >= ranges.Length)
                throw new ProtocolViolationException("Too many arguments");
            ranges[count++] = start..data.Length;
        }

        return count;
    }
}

/// <summary>
/// Thrown by <see cref="NatsParser"/> when the input does not follow the NATS protocol.
/// </summary>
public class ProtocolViolationException : Exception
{
    public ProtocolViolationException(string message)
        : base(message)
    {
    }
}