feat: add WebSocket origin checker
This commit is contained in:
81
src/NATS.Server/WebSocket/WsOriginChecker.cs
Normal file
81
src/NATS.Server/WebSocket/WsOriginChecker.cs
Normal file
@@ -0,0 +1,81 @@
|
||||
namespace NATS.Server.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// Validates WebSocket Origin headers per RFC 6455 Section 10.2.
|
||||
/// Ported from golang/nats-server/server/websocket.go lines 933-1000.
|
||||
/// </summary>
|
||||
public sealed class WsOriginChecker
|
||||
{
|
||||
private readonly bool _sameOrigin;
|
||||
private readonly Dictionary<string, AllowedOrigin>? _allowedOrigins;
|
||||
|
||||
public WsOriginChecker(bool sameOrigin, List<string>? allowedOrigins)
|
||||
{
|
||||
_sameOrigin = sameOrigin;
|
||||
if (allowedOrigins is { Count: > 0 })
|
||||
{
|
||||
_allowedOrigins = new Dictionary<string, AllowedOrigin>(StringComparer.OrdinalIgnoreCase);
|
||||
foreach (var ao in allowedOrigins)
|
||||
{
|
||||
if (Uri.TryCreate(ao, UriKind.Absolute, out var uri))
|
||||
{
|
||||
var (host, port) = GetHostAndPort(uri.Scheme == "https", uri.Host, uri.Port);
|
||||
_allowedOrigins[host] = new AllowedOrigin(uri.Scheme, port);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Returns null if origin is allowed, or an error message if rejected.
|
||||
/// </summary>
|
||||
public string? CheckOrigin(string? origin, string requestHost, bool isTls)
|
||||
{
|
||||
if (!_sameOrigin && _allowedOrigins == null)
|
||||
return null;
|
||||
|
||||
if (string.IsNullOrEmpty(origin))
|
||||
return null;
|
||||
|
||||
if (!Uri.TryCreate(origin, UriKind.Absolute, out var originUri))
|
||||
return $"invalid origin: {origin}";
|
||||
|
||||
var (oh, op) = GetHostAndPort(originUri.Scheme == "https", originUri.Host, originUri.Port);
|
||||
|
||||
if (_sameOrigin)
|
||||
{
|
||||
var (rh, rp) = ParseHostPort(requestHost, isTls);
|
||||
if (!string.Equals(oh, rh, StringComparison.OrdinalIgnoreCase) || op != rp)
|
||||
return "not same origin";
|
||||
}
|
||||
|
||||
if (_allowedOrigins != null)
|
||||
{
|
||||
if (!_allowedOrigins.TryGetValue(oh, out var allowed) ||
|
||||
!string.Equals(originUri.Scheme, allowed.Scheme, StringComparison.OrdinalIgnoreCase) ||
|
||||
op != allowed.Port)
|
||||
{
|
||||
return "not in the allowed list";
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private static (string host, int port) GetHostAndPort(bool tls, string host, int port)
|
||||
{
|
||||
if (port <= 0)
|
||||
port = tls ? 443 : 80;
|
||||
return (host.ToLowerInvariant(), port);
|
||||
}
|
||||
|
||||
private static (string host, int port) ParseHostPort(string hostPort, bool isTls)
|
||||
{
|
||||
var colonIdx = hostPort.LastIndexOf(':');
|
||||
if (colonIdx > 0 && int.TryParse(hostPort.AsSpan(colonIdx + 1), out var port))
|
||||
return (hostPort[..colonIdx].ToLowerInvariant(), port);
|
||||
return (hostPort.ToLowerInvariant(), isTls ? 443 : 80);
|
||||
}
|
||||
|
||||
private readonly record struct AllowedOrigin(string Scheme, int Port);
|
||||
}
|
||||
Reference in New Issue
Block a user