diff --git a/src/NATS.Server/NatsClient.cs b/src/NATS.Server/NatsClient.cs index 44cbc56..31b609a 100644 --- a/src/NATS.Server/NatsClient.cs +++ b/src/NATS.Server/NatsClient.cs @@ -169,7 +169,7 @@ public sealed class NatsClient : IDisposable case CommandType.Pub: case CommandType.HPub: - ProcessPub(cmd); + await ProcessPubAsync(cmd); break; } } @@ -220,11 +220,28 @@ public sealed class NatsClient : IDisposable sl.SubList.Remove(sub); } - private void ProcessPub(ParsedCommand cmd) + private async ValueTask ProcessPubAsync(ParsedCommand cmd) { Interlocked.Increment(ref InMsgs); Interlocked.Add(ref InBytes, cmd.Payload.Length); + // Max payload validation (always, hard close) + if (cmd.Payload.Length > _options.MaxPayload) + { + _logger.LogWarning("Client {ClientId} exceeded max payload: {Size} > {MaxPayload}", + Id, cmd.Payload.Length, _options.MaxPayload); + await SendErrAndCloseAsync(NatsProtocol.ErrMaxPayloadViolation); + return; + } + + // Pedantic mode: validate publish subject + if (ClientOpts?.Pedantic == true && !SubjectMatch.IsValidPublishSubject(cmd.Subject!)) + { + _logger.LogDebug("Client {ClientId} invalid publish subject: {Subject}", Id, cmd.Subject); + await SendErrAsync(NatsProtocol.ErrInvalidPublishSubject); + return; + } + ReadOnlyMemory headers = default; ReadOnlyMemory payload = cmd.Payload; diff --git a/src/NATS.Server/Protocol/NatsParser.cs b/src/NATS.Server/Protocol/NatsParser.cs index 8497c61..2689ec0 100644 --- a/src/NATS.Server/Protocol/NatsParser.cs +++ b/src/NATS.Server/Protocol/NatsParser.cs @@ -203,10 +203,10 @@ public sealed class NatsParser throw new ProtocolViolationException("Invalid PUB arguments"); } - if (size < 0 || size > _maxPayload) + if (size < 0) throw new ProtocolViolationException("Invalid payload size"); - // Now read payload + \r\n + // Now read payload + \r\n (max payload enforcement is done at the client level) buffer = buffer.Slice(afterLine); _awaitingPayload = true; _expectedPayloadSize = size; @@ -253,7 +253,7 @@ public sealed class NatsParser throw new ProtocolViolationException("Invalid HPUB arguments"); } - if (hdrSize < 0 || totalSize < 0 || hdrSize > totalSize || totalSize > _maxPayload) + if (hdrSize < 0 || totalSize < 0 || hdrSize > totalSize) throw new ProtocolViolationException("Invalid HPUB sizes"); buffer = buffer.Slice(afterLine); diff --git a/src/NATS.Server/Subscriptions/SubList.cs b/src/NATS.Server/Subscriptions/SubList.cs index 382689b..047bef2 100644 --- a/src/NATS.Server/Subscriptions/SubList.cs +++ b/src/NATS.Server/Subscriptions/SubList.cs @@ -15,8 +15,13 @@ public sealed class SubList : IDisposable private readonly TrieLevel _root = new(); private Dictionary? _cache = new(StringComparer.Ordinal); private uint _count; + private volatile bool _disposed; - public void Dispose() => _lock.Dispose(); + public void Dispose() + { + _disposed = true; + _lock.Dispose(); + } public uint Count { @@ -95,6 +100,7 @@ public sealed class SubList : IDisposable public void Remove(Subscription sub) { + if (_disposed) return; _lock.EnterWriteLock(); try { diff --git a/tests/NATS.Server.Tests/ServerTests.cs b/tests/NATS.Server.Tests/ServerTests.cs index e80bf7c..476e0c6 100644 --- a/tests/NATS.Server.Tests/ServerTests.cs +++ b/tests/NATS.Server.Tests/ServerTests.cs @@ -123,6 +123,94 @@ public class ServerTests : IAsyncLifetime msg.ShouldContain("MSG foo.bar 1 5\r\n"); } + + [Fact] + public async Task Server_pedantic_rejects_invalid_publish_subject() + { + using var pub = await ConnectClientAsync(); + using var sub = await ConnectClientAsync(); + + // Read INFO from both + await ReadLineAsync(pub); + await ReadLineAsync(sub); + + // Connect with pedantic mode ON + await pub.SendAsync(Encoding.ASCII.GetBytes( + "CONNECT {\"pedantic\":true}\r\nPING\r\n")); + var pong = await ReadUntilAsync(pub, "PONG"); + + // Subscribe on sub + await sub.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nSUB foo.* 1\r\nPING\r\n")); + await ReadUntilAsync(sub, "PONG"); + + // PUB with wildcard subject (invalid for publish) + await pub.SendAsync(Encoding.ASCII.GetBytes("PUB foo.* 5\r\nHello\r\n")); + + // Publisher should get -ERR + var errResponse = await ReadUntilAsync(pub, "-ERR", timeoutMs: 3000); + errResponse.ShouldContain("-ERR 'Invalid Publish Subject'"); + } + + [Fact] + public async Task Server_nonpedantic_allows_wildcard_publish_subject() + { + using var pub = await ConnectClientAsync(); + using var sub = await ConnectClientAsync(); + + await ReadLineAsync(pub); + await ReadLineAsync(sub); + + // Connect without pedantic mode (default) + await sub.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nSUB foo.* 1\r\nPING\r\n")); + await ReadUntilAsync(sub, "PONG"); + + await pub.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\nPUB foo.* 5\r\nHello\r\n")); + + // Sub should still receive the message (no validation in non-pedantic mode) + var msg = await ReadUntilAsync(sub, "Hello\r\n"); + msg.ShouldContain("MSG foo.* 1 5\r\nHello\r\n"); + } + + [Fact] + public async Task Server_rejects_max_payload_violation() + { + // Create server with tiny max payload + var port = GetFreePort(); + using var cts = new CancellationTokenSource(); + var server = new NatsServer(new NatsOptions { Port = port, MaxPayload = 10 }, NullLoggerFactory.Instance); + _ = server.StartAsync(cts.Token); + await server.WaitForReadyAsync(); + + try + { + var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client.ConnectAsync(IPAddress.Loopback, port); + + var buf = new byte[4096]; + await client.ReceiveAsync(buf, SocketFlags.None); // INFO + + await client.SendAsync(Encoding.ASCII.GetBytes("CONNECT {}\r\n")); + + // Send PUB with payload larger than MaxPayload (10 bytes) + await client.SendAsync(Encoding.ASCII.GetBytes("PUB foo 20\r\n12345678901234567890\r\n")); + + using var readCts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + var n = await client.ReceiveAsync(buf, SocketFlags.None, readCts.Token); + var response = Encoding.ASCII.GetString(buf, 0, n); + response.ShouldContain("-ERR 'Maximum Payload Violation'"); + + // Connection should be closed + n = await client.ReceiveAsync(buf, SocketFlags.None, readCts.Token); + n.ShouldBe(0); + + client.Dispose(); + } + finally + { + await cts.CancelAsync(); + server.Dispose(); + } + } } public class MaxConnectionsTests : IAsyncLifetime