diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 620f275..53b0f44 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -29,6 +29,7 @@ public interface ISubListAccess public sealed class NatsClient : IDisposable { + private static readonly ClientCommandMatrix CommandMatrix = new(); private readonly Socket _socket; private readonly Stream _stream; private readonly NatsOptions _options; @@ -46,6 +47,7 @@ public sealed class NatsClient : IDisposable private readonly ServerStats _serverStats; public ulong Id { get; } + public ClientKind Kind { get; } public ClientOptions? ClientOpts { get; private set; } public IMessageRouter? Router { get; set; } public Account? Account { get; private set; } @@ -103,9 +105,11 @@ public sealed class NatsClient : IDisposable public IReadOnlyDictionary Subscriptions => _subs; public NatsClient(ulong id, Stream stream, Socket socket, NatsOptions options, ServerInfo serverInfo, - AuthService authService, byte[]? nonce, ILogger logger, ServerStats serverStats) + AuthService authService, byte[]? nonce, ILogger logger, ServerStats serverStats, + ClientKind kind = ClientKind.Client) { Id = id; + Kind = kind; _socket = socket; _stream = stream; _options = options; @@ -311,6 +315,13 @@ public sealed class NatsClient : IDisposable { Interlocked.Exchange(ref _lastActivityTicks, DateTime.UtcNow.Ticks); + if (!CommandMatrix.IsAllowed(Kind, cmd.Operation)) + { + _logger.LogDebug("Command {Command} is not allowed for client kind {ClientKind}", cmd.Operation, Kind); + await SendErrAndCloseAsync("Parser Error"); + return; + } + // If auth is required and CONNECT hasn't been received yet, // only allow CONNECT and PING commands if (_authService.IsAuthRequired && !ConnectReceived) diff --git a/src/NATS.Server/Protocol/ClientCommandMatrix.cs b/src/NATS.Server/Protocol/ClientCommandMatrix.cs new file mode 100644 index 0000000..973e151 --- /dev/null +++ b/src/NATS.Server/Protocol/ClientCommandMatrix.cs @@ -0,0 +1,17 @@ +namespace NATS.Server.Protocol; + +public sealed class ClientCommandMatrix +{ + public bool IsAllowed(ClientKind kind, string? op) + { + if (string.IsNullOrWhiteSpace(op)) + return true; + + return (kind, op.ToUpperInvariant()) switch + { + (ClientKind.Router, "RS+") => true, + (_, "RS+") => false, + _ => true, + }; + } +} diff --git a/src/NATS.Server/Protocol/ClientKind.cs b/src/NATS.Server/Protocol/ClientKind.cs new file mode 100644 index 0000000..d0d9973 --- /dev/null +++ b/src/NATS.Server/Protocol/ClientKind.cs @@ -0,0 +1,12 @@ +namespace NATS.Server.Protocol; + +public enum ClientKind +{ + Client, + Router, + Gateway, + Leaf, + System, + JetStream, + Account, +} diff --git a/src/NATS.Server/Protocol/NatsParser.cs b/src/NATS.Server/Protocol/NatsParser.cs index 7ba0aad..b8df1e8 100644 --- a/src/NATS.Server/Protocol/NatsParser.cs +++ b/src/NATS.Server/Protocol/NatsParser.cs @@ -21,6 +21,7 @@ public enum CommandType public readonly struct ParsedCommand { public CommandType Type { get; init; } + public string? Operation { get; init; } public string? Subject { get; init; } public string? ReplyTo { get; init; } public string? Queue { get; init; } @@ -29,7 +30,8 @@ public readonly struct ParsedCommand public int HeaderSize { get; init; } public ReadOnlyMemory Payload { get; init; } - public static ParsedCommand Simple(CommandType type) => new() { Type = type, MaxMessages = -1 }; + public static ParsedCommand Simple(CommandType type, string operation) => + new() { Type = type, Operation = operation, MaxMessages = -1 }; } public sealed class NatsParser @@ -46,6 +48,7 @@ public sealed class NatsParser private string? _pendingReplyTo; private int _pendingHeaderSize; private CommandType _pendingType; + private string _pendingOperation = string.Empty; public NatsParser(int maxPayload = NatsProtocol.MaxPayloadSize, ILogger? logger = null) { @@ -103,7 +106,7 @@ public sealed class NatsParser case (byte)'p': if (b1 == (byte)'i') // PING { - command = ParsedCommand.Simple(CommandType.Ping); + command = ParsedCommand.Simple(CommandType.Ping, "PING"); buffer = buffer.Slice(reader.Position); TraceInOp("PING"); return true; @@ -111,7 +114,7 @@ public sealed class NatsParser if (b1 == (byte)'o') // PONG { - command = ParsedCommand.Simple(CommandType.Pong); + command = ParsedCommand.Simple(CommandType.Pong, "PONG"); buffer = buffer.Slice(reader.Position); TraceInOp("PONG"); return true; @@ -177,13 +180,13 @@ public sealed class NatsParser break; case (byte)'+': // +OK - command = ParsedCommand.Simple(CommandType.Ok); + command = ParsedCommand.Simple(CommandType.Ok, "+OK"); buffer = buffer.Slice(reader.Position); TraceInOp("+OK"); return true; case (byte)'-': // -ERR - command = ParsedCommand.Simple(CommandType.Err); + command = ParsedCommand.Simple(CommandType.Err, "-ERR"); buffer = buffer.Slice(reader.Position); TraceInOp("-ERR"); return true; @@ -236,6 +239,7 @@ public sealed class NatsParser _pendingReplyTo = reply; _pendingHeaderSize = -1; _pendingType = CommandType.Pub; + _pendingOperation = "PUB"; TraceInOp("PUB", argsSpan); return TryReadPayload(ref buffer, out command); @@ -286,6 +290,7 @@ public sealed class NatsParser _pendingReplyTo = reply; _pendingHeaderSize = hdrSize; _pendingType = CommandType.HPub; + _pendingOperation = "HPUB"; TraceInOp("HPUB", argsSpan); return TryReadPayload(ref buffer, out command); @@ -315,6 +320,7 @@ public sealed class NatsParser command = new ParsedCommand { Type = _pendingType, + Operation = _pendingOperation, Subject = _pendingSubject, ReplyTo = _pendingReplyTo, Payload = payload, @@ -339,6 +345,7 @@ public sealed class NatsParser 2 => new ParsedCommand { Type = CommandType.Sub, + Operation = "SUB", Subject = Encoding.ASCII.GetString(argsSpan[ranges[0]]), Sid = Encoding.ASCII.GetString(argsSpan[ranges[1]]), MaxMessages = -1, @@ -346,6 +353,7 @@ public sealed class NatsParser 3 => new ParsedCommand { Type = CommandType.Sub, + Operation = "SUB", Subject = Encoding.ASCII.GetString(argsSpan[ranges[0]]), Queue = Encoding.ASCII.GetString(argsSpan[ranges[1]]), Sid = Encoding.ASCII.GetString(argsSpan[ranges[2]]), @@ -367,12 +375,14 @@ public sealed class NatsParser 1 => new ParsedCommand { Type = CommandType.Unsub, + Operation = "UNSUB", Sid = Encoding.ASCII.GetString(argsSpan[ranges[0]]), MaxMessages = -1, }, 2 => new ParsedCommand { Type = CommandType.Unsub, + Operation = "UNSUB", Sid = Encoding.ASCII.GetString(argsSpan[ranges[0]]), MaxMessages = ParseSize(argsSpan[ranges[1]]), }, @@ -391,6 +401,7 @@ public sealed class NatsParser return new ParsedCommand { Type = CommandType.Connect, + Operation = "CONNECT", Payload = json.ToArray(), MaxMessages = -1, }; @@ -407,6 +418,7 @@ public sealed class NatsParser return new ParsedCommand { Type = CommandType.Info, + Operation = "INFO", Payload = json.ToArray(), MaxMessages = -1, }; diff --git a/tests/NATS.Server.Tests/ClientKindCommandMatrixTests.cs b/tests/NATS.Server.Tests/ClientKindCommandMatrixTests.cs new file mode 100644 index 0000000..c47ea6d --- /dev/null +++ b/tests/NATS.Server.Tests/ClientKindCommandMatrixTests.cs @@ -0,0 +1,14 @@ +using NATS.Server.Protocol; + +namespace NATS.Server.Tests; + +public class ClientKindCommandMatrixTests +{ + [Fact] + public void Router_only_commands_are_rejected_for_client_kind() + { + var matrix = new ClientCommandMatrix(); + matrix.IsAllowed(ClientKind.Client, "RS+").ShouldBeFalse(); + matrix.IsAllowed(ClientKind.Router, "RS+").ShouldBeTrue(); + } +}