diff --git a/src/NATS.Server/WebSocket/WsOriginChecker.cs b/src/NATS.Server/WebSocket/WsOriginChecker.cs new file mode 100644 index 0000000..c11d1ce --- /dev/null +++ b/src/NATS.Server/WebSocket/WsOriginChecker.cs @@ -0,0 +1,81 @@ +namespace NATS.Server.WebSocket; + +/// +/// Validates WebSocket Origin headers per RFC 6455 Section 10.2. +/// Ported from golang/nats-server/server/websocket.go lines 933-1000. +/// +public sealed class WsOriginChecker +{ + private readonly bool _sameOrigin; + private readonly Dictionary? _allowedOrigins; + + public WsOriginChecker(bool sameOrigin, List? allowedOrigins) + { + _sameOrigin = sameOrigin; + if (allowedOrigins is { Count: > 0 }) + { + _allowedOrigins = new Dictionary(StringComparer.OrdinalIgnoreCase); + foreach (var ao in allowedOrigins) + { + if (Uri.TryCreate(ao, UriKind.Absolute, out var uri)) + { + var (host, port) = GetHostAndPort(uri.Scheme == "https", uri.Host, uri.Port); + _allowedOrigins[host] = new AllowedOrigin(uri.Scheme, port); + } + } + } + } + + /// + /// Returns null if origin is allowed, or an error message if rejected. + /// + public string? CheckOrigin(string? origin, string requestHost, bool isTls) + { + if (!_sameOrigin && _allowedOrigins == null) + return null; + + if (string.IsNullOrEmpty(origin)) + return null; + + if (!Uri.TryCreate(origin, UriKind.Absolute, out var originUri)) + return $"invalid origin: {origin}"; + + var (oh, op) = GetHostAndPort(originUri.Scheme == "https", originUri.Host, originUri.Port); + + if (_sameOrigin) + { + var (rh, rp) = ParseHostPort(requestHost, isTls); + if (!string.Equals(oh, rh, StringComparison.OrdinalIgnoreCase) || op != rp) + return "not same origin"; + } + + if (_allowedOrigins != null) + { + if (!_allowedOrigins.TryGetValue(oh, out var allowed) || + !string.Equals(originUri.Scheme, allowed.Scheme, StringComparison.OrdinalIgnoreCase) || + op != allowed.Port) + { + return "not in the allowed list"; + } + } + + return null; + } + + private static (string host, int port) GetHostAndPort(bool tls, string host, int port) + { + if (port <= 0) + port = tls ? 443 : 80; + return (host.ToLowerInvariant(), port); + } + + private static (string host, int port) ParseHostPort(string hostPort, bool isTls) + { + var colonIdx = hostPort.LastIndexOf(':'); + if (colonIdx > 0 && int.TryParse(hostPort.AsSpan(colonIdx + 1), out var port)) + return (hostPort[..colonIdx].ToLowerInvariant(), port); + return (hostPort.ToLowerInvariant(), isTls ? 443 : 80); + } + + private readonly record struct AllowedOrigin(string Scheme, int Port); +} diff --git a/tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs b/tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs new file mode 100644 index 0000000..ebd3531 --- /dev/null +++ b/tests/NATS.Server.Tests/WebSocket/WsOriginCheckerTests.cs @@ -0,0 +1,82 @@ +using NATS.Server.WebSocket; +using Shouldly; + +namespace NATS.Server.Tests.WebSocket; + +public class WsOriginCheckerTests +{ + [Fact] + public void NoOriginHeader_Accepted() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + checker.CheckOrigin(origin: null, requestHost: "localhost:4222", isTls: false) + .ShouldBeNull(); + } + + [Fact] + public void NeitherSameNorList_AlwaysAccepted() + { + var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: null); + checker.CheckOrigin("https://evil.com", "localhost:4222", false) + .ShouldBeNull(); + } + + [Fact] + public void SameOrigin_Match() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + checker.CheckOrigin("http://localhost:4222", "localhost:4222", false) + .ShouldBeNull(); + } + + [Fact] + public void SameOrigin_Mismatch() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + checker.CheckOrigin("http://other:4222", "localhost:4222", false) + .ShouldNotBeNull(); + } + + [Fact] + public void SameOrigin_DefaultPort_Http() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + checker.CheckOrigin("http://localhost", "localhost:80", false) + .ShouldBeNull(); + } + + [Fact] + public void SameOrigin_DefaultPort_Https() + { + var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null); + checker.CheckOrigin("https://localhost", "localhost:443", true) + .ShouldBeNull(); + } + + [Fact] + public void AllowedOrigins_Match() + { + var checker = new WsOriginChecker(sameOrigin: false, + allowedOrigins: ["https://app.example.com"]); + checker.CheckOrigin("https://app.example.com", "localhost:4222", false) + .ShouldBeNull(); + } + + [Fact] + public void AllowedOrigins_Mismatch() + { + var checker = new WsOriginChecker(sameOrigin: false, + allowedOrigins: ["https://app.example.com"]); + checker.CheckOrigin("https://evil.example.com", "localhost:4222", false) + .ShouldNotBeNull(); + } + + [Fact] + public void AllowedOrigins_SchemeMismatch() + { + var checker = new WsOriginChecker(sameOrigin: false, + allowedOrigins: ["https://app.example.com"]); + checker.CheckOrigin("http://app.example.com", "localhost:4222", false) + .ShouldNotBeNull(); + } +}