feat: add client kind command matrix parity
This commit is contained in:
@@ -29,6 +29,7 @@ public interface ISubListAccess
|
||||
|
||||
public sealed class NatsClient : IDisposable
|
||||
{
|
||||
private static readonly ClientCommandMatrix CommandMatrix = new();
|
||||
private readonly Socket _socket;
|
||||
private readonly Stream _stream;
|
||||
private readonly NatsOptions _options;
|
||||
@@ -46,6 +47,7 @@ public sealed class NatsClient : IDisposable
|
||||
private readonly ServerStats _serverStats;
|
||||
|
||||
public ulong Id { get; }
|
||||
public ClientKind Kind { get; }
|
||||
public ClientOptions? ClientOpts { get; private set; }
|
||||
public IMessageRouter? Router { get; set; }
|
||||
public Account? Account { get; private set; }
|
||||
@@ -103,9 +105,11 @@ public sealed class NatsClient : IDisposable
|
||||
public IReadOnlyDictionary<string, Subscription> Subscriptions => _subs;
|
||||
|
||||
public NatsClient(ulong id, Stream stream, Socket socket, NatsOptions options, ServerInfo serverInfo,
|
||||
AuthService authService, byte[]? nonce, ILogger logger, ServerStats serverStats)
|
||||
AuthService authService, byte[]? nonce, ILogger logger, ServerStats serverStats,
|
||||
ClientKind kind = ClientKind.Client)
|
||||
{
|
||||
Id = id;
|
||||
Kind = kind;
|
||||
_socket = socket;
|
||||
_stream = stream;
|
||||
_options = options;
|
||||
@@ -311,6 +315,13 @@ public sealed class NatsClient : IDisposable
|
||||
{
|
||||
Interlocked.Exchange(ref _lastActivityTicks, DateTime.UtcNow.Ticks);
|
||||
|
||||
if (!CommandMatrix.IsAllowed(Kind, cmd.Operation))
|
||||
{
|
||||
_logger.LogDebug("Command {Command} is not allowed for client kind {ClientKind}", cmd.Operation, Kind);
|
||||
await SendErrAndCloseAsync("Parser Error");
|
||||
return;
|
||||
}
|
||||
|
||||
// If auth is required and CONNECT hasn't been received yet,
|
||||
// only allow CONNECT and PING commands
|
||||
if (_authService.IsAuthRequired && !ConnectReceived)
|
||||
|
||||
17
src/NATS.Server/Protocol/ClientCommandMatrix.cs
Normal file
17
src/NATS.Server/Protocol/ClientCommandMatrix.cs
Normal file
@@ -0,0 +1,17 @@
|
||||
namespace NATS.Server.Protocol;
|
||||
|
||||
public sealed class ClientCommandMatrix
|
||||
{
|
||||
public bool IsAllowed(ClientKind kind, string? op)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(op))
|
||||
return true;
|
||||
|
||||
return (kind, op.ToUpperInvariant()) switch
|
||||
{
|
||||
(ClientKind.Router, "RS+") => true,
|
||||
(_, "RS+") => false,
|
||||
_ => true,
|
||||
};
|
||||
}
|
||||
}
|
||||
12
src/NATS.Server/Protocol/ClientKind.cs
Normal file
12
src/NATS.Server/Protocol/ClientKind.cs
Normal file
@@ -0,0 +1,12 @@
|
||||
namespace NATS.Server.Protocol;
|
||||
|
||||
public enum ClientKind
|
||||
{
|
||||
Client,
|
||||
Router,
|
||||
Gateway,
|
||||
Leaf,
|
||||
System,
|
||||
JetStream,
|
||||
Account,
|
||||
}
|
||||
@@ -21,6 +21,7 @@ public enum CommandType
|
||||
public readonly struct ParsedCommand
|
||||
{
|
||||
public CommandType Type { get; init; }
|
||||
public string? Operation { get; init; }
|
||||
public string? Subject { get; init; }
|
||||
public string? ReplyTo { get; init; }
|
||||
public string? Queue { get; init; }
|
||||
@@ -29,7 +30,8 @@ public readonly struct ParsedCommand
|
||||
public int HeaderSize { get; init; }
|
||||
public ReadOnlyMemory<byte> Payload { get; init; }
|
||||
|
||||
public static ParsedCommand Simple(CommandType type) => new() { Type = type, MaxMessages = -1 };
|
||||
public static ParsedCommand Simple(CommandType type, string operation) =>
|
||||
new() { Type = type, Operation = operation, MaxMessages = -1 };
|
||||
}
|
||||
|
||||
public sealed class NatsParser
|
||||
@@ -46,6 +48,7 @@ public sealed class NatsParser
|
||||
private string? _pendingReplyTo;
|
||||
private int _pendingHeaderSize;
|
||||
private CommandType _pendingType;
|
||||
private string _pendingOperation = string.Empty;
|
||||
|
||||
public NatsParser(int maxPayload = NatsProtocol.MaxPayloadSize, ILogger? logger = null)
|
||||
{
|
||||
@@ -103,7 +106,7 @@ public sealed class NatsParser
|
||||
case (byte)'p':
|
||||
if (b1 == (byte)'i') // PING
|
||||
{
|
||||
command = ParsedCommand.Simple(CommandType.Ping);
|
||||
command = ParsedCommand.Simple(CommandType.Ping, "PING");
|
||||
buffer = buffer.Slice(reader.Position);
|
||||
TraceInOp("PING");
|
||||
return true;
|
||||
@@ -111,7 +114,7 @@ public sealed class NatsParser
|
||||
|
||||
if (b1 == (byte)'o') // PONG
|
||||
{
|
||||
command = ParsedCommand.Simple(CommandType.Pong);
|
||||
command = ParsedCommand.Simple(CommandType.Pong, "PONG");
|
||||
buffer = buffer.Slice(reader.Position);
|
||||
TraceInOp("PONG");
|
||||
return true;
|
||||
@@ -177,13 +180,13 @@ public sealed class NatsParser
|
||||
break;
|
||||
|
||||
case (byte)'+': // +OK
|
||||
command = ParsedCommand.Simple(CommandType.Ok);
|
||||
command = ParsedCommand.Simple(CommandType.Ok, "+OK");
|
||||
buffer = buffer.Slice(reader.Position);
|
||||
TraceInOp("+OK");
|
||||
return true;
|
||||
|
||||
case (byte)'-': // -ERR
|
||||
command = ParsedCommand.Simple(CommandType.Err);
|
||||
command = ParsedCommand.Simple(CommandType.Err, "-ERR");
|
||||
buffer = buffer.Slice(reader.Position);
|
||||
TraceInOp("-ERR");
|
||||
return true;
|
||||
@@ -236,6 +239,7 @@ public sealed class NatsParser
|
||||
_pendingReplyTo = reply;
|
||||
_pendingHeaderSize = -1;
|
||||
_pendingType = CommandType.Pub;
|
||||
_pendingOperation = "PUB";
|
||||
|
||||
TraceInOp("PUB", argsSpan);
|
||||
return TryReadPayload(ref buffer, out command);
|
||||
@@ -286,6 +290,7 @@ public sealed class NatsParser
|
||||
_pendingReplyTo = reply;
|
||||
_pendingHeaderSize = hdrSize;
|
||||
_pendingType = CommandType.HPub;
|
||||
_pendingOperation = "HPUB";
|
||||
|
||||
TraceInOp("HPUB", argsSpan);
|
||||
return TryReadPayload(ref buffer, out command);
|
||||
@@ -315,6 +320,7 @@ public sealed class NatsParser
|
||||
command = new ParsedCommand
|
||||
{
|
||||
Type = _pendingType,
|
||||
Operation = _pendingOperation,
|
||||
Subject = _pendingSubject,
|
||||
ReplyTo = _pendingReplyTo,
|
||||
Payload = payload,
|
||||
@@ -339,6 +345,7 @@ public sealed class NatsParser
|
||||
2 => new ParsedCommand
|
||||
{
|
||||
Type = CommandType.Sub,
|
||||
Operation = "SUB",
|
||||
Subject = Encoding.ASCII.GetString(argsSpan[ranges[0]]),
|
||||
Sid = Encoding.ASCII.GetString(argsSpan[ranges[1]]),
|
||||
MaxMessages = -1,
|
||||
@@ -346,6 +353,7 @@ public sealed class NatsParser
|
||||
3 => new ParsedCommand
|
||||
{
|
||||
Type = CommandType.Sub,
|
||||
Operation = "SUB",
|
||||
Subject = Encoding.ASCII.GetString(argsSpan[ranges[0]]),
|
||||
Queue = Encoding.ASCII.GetString(argsSpan[ranges[1]]),
|
||||
Sid = Encoding.ASCII.GetString(argsSpan[ranges[2]]),
|
||||
@@ -367,12 +375,14 @@ public sealed class NatsParser
|
||||
1 => new ParsedCommand
|
||||
{
|
||||
Type = CommandType.Unsub,
|
||||
Operation = "UNSUB",
|
||||
Sid = Encoding.ASCII.GetString(argsSpan[ranges[0]]),
|
||||
MaxMessages = -1,
|
||||
},
|
||||
2 => new ParsedCommand
|
||||
{
|
||||
Type = CommandType.Unsub,
|
||||
Operation = "UNSUB",
|
||||
Sid = Encoding.ASCII.GetString(argsSpan[ranges[0]]),
|
||||
MaxMessages = ParseSize(argsSpan[ranges[1]]),
|
||||
},
|
||||
@@ -391,6 +401,7 @@ public sealed class NatsParser
|
||||
return new ParsedCommand
|
||||
{
|
||||
Type = CommandType.Connect,
|
||||
Operation = "CONNECT",
|
||||
Payload = json.ToArray(),
|
||||
MaxMessages = -1,
|
||||
};
|
||||
@@ -407,6 +418,7 @@ public sealed class NatsParser
|
||||
return new ParsedCommand
|
||||
{
|
||||
Type = CommandType.Info,
|
||||
Operation = "INFO",
|
||||
Payload = json.ToArray(),
|
||||
MaxMessages = -1,
|
||||
};
|
||||
|
||||
14
tests/NATS.Server.Tests/ClientKindCommandMatrixTests.cs
Normal file
14
tests/NATS.Server.Tests/ClientKindCommandMatrixTests.cs
Normal file
@@ -0,0 +1,14 @@
|
||||
using NATS.Server.Protocol;
|
||||
|
||||
namespace NATS.Server.Tests;
|
||||
|
||||
public class ClientKindCommandMatrixTests
|
||||
{
|
||||
[Fact]
|
||||
public void Router_only_commands_are_rejected_for_client_kind()
|
||||
{
|
||||
var matrix = new ClientCommandMatrix();
|
||||
matrix.IsAllowed(ClientKind.Client, "RS+").ShouldBeFalse();
|
||||
matrix.IsAllowed(ClientKind.Router, "RS+").ShouldBeTrue();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user