diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.WebSocket.cs b/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.WebSocket.cs index 5d0406c..6200289 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.WebSocket.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/NatsServer.WebSocket.cs @@ -82,6 +82,15 @@ public sealed partial class NatsServer internal static Exception WsReturnHTTPError(int statusCode, string message) => new WsHttpError(statusCode, message); + private static bool wsHeaderContains(NameValueCollection headers, string key, string expected) => + WsHeaderContains(headers, key, expected); + + private static (bool supported, bool noContext) wsPMCExtensionSupport(NameValueCollection headers, bool checkNoContextTakeOver) => + WsPMCExtensionSupport(headers, checkNoContextTakeOver); + + private static Exception wsReturnHTTPError(int statusCode, string message) => + WsReturnHTTPError(statusCode, message); + internal static string WsGetHostAndPort(string hostPort, out int port) { port = 0; @@ -103,6 +112,9 @@ public sealed partial class NatsServer return hostPort; } + private static string wsGetHostAndPort(string hostPort, out int port) => + WsGetHostAndPort(hostPort, out port); + internal static byte[] WsMakeChallengeKey(string key) { ArgumentNullException.ThrowIfNull(key); @@ -116,6 +128,10 @@ public sealed partial class NatsServer return Convert.ToBase64String(digest); } + private static byte[] wsMakeChallengeKey(string key) => WsMakeChallengeKey(key); + + private static string wsAcceptKey(string key) => WsAcceptKey(key); + internal static Exception? ValidateWebsocketOptions(WebsocketOpts options) { if (options.Port < 0 || options.Port > 65535) @@ -127,6 +143,8 @@ public sealed partial class NatsServer return null; } + private static Exception? validateWebsocketOptions(WebsocketOpts options) => ValidateWebsocketOptions(options); + private void WsSetOriginOptions() { lock (_websocket.Mu) @@ -304,4 +322,11 @@ public sealed partial class NatsServer return false; return uri.Scheme.Equals(WsConstants.SchemePrefixTls, StringComparison.OrdinalIgnoreCase); } + + private static bool isWSURL(string url) => IsWSURL(url); + + private static bool isWSSURL(string url) => IsWSSURL(url); + + private System.Net.Security.SslServerAuthenticationOptions? wsGetTLSConfig() => + GetOpts().Websocket.TlsConfig; } diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/WebSocket/WebSocketHandler.cs b/dotnet/src/ZB.MOM.NatsNet.Server/WebSocket/WebSocketHandler.cs index 6930ca0..d09d4e8 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/WebSocket/WebSocketHandler.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/WebSocket/WebSocketHandler.cs @@ -25,4 +25,31 @@ internal sealed class WebSocketHandler public static byte[] WsCreateCloseMessage(int status, string body) => WebSocketHelpers.WsCreateCloseMessage(status, body); + + public static bool WsHeaderContains(System.Collections.Specialized.NameValueCollection headers, string key, string expected) => + NatsServer.WsHeaderContains(headers, key, expected); + + public static (bool supported, bool noContext) WsPMCExtensionSupport(System.Collections.Specialized.NameValueCollection headers, bool checkNoContextTakeOver) => + NatsServer.WsPMCExtensionSupport(headers, checkNoContextTakeOver); + + public static Exception WsReturnHTTPError(int statusCode, string message) => + NatsServer.WsReturnHTTPError(statusCode, message); + + public static string WsGetHostAndPort(string hostPort, out int port) => + NatsServer.WsGetHostAndPort(hostPort, out port); + + public static string WsAcceptKey(string key) => + NatsServer.WsAcceptKey(key); + + public static byte[] WsMakeChallengeKey(string key) => + NatsServer.WsMakeChallengeKey(key); + + public static Exception? ValidateWebsocketOptions(WebsocketOpts options) => + NatsServer.ValidateWebsocketOptions(options); + + public static bool IsWSURL(string url) => + NatsServer.IsWSURL(url); + + public static bool IsWSSURL(string url) => + NatsServer.IsWSSURL(url); } diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/WebSocket/WebSocketTypes.cs b/dotnet/src/ZB.MOM.NatsNet.Server/WebSocket/WebSocketTypes.cs index 7d41f5d..6fd4a64 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/WebSocket/WebSocketTypes.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/WebSocket/WebSocketTypes.cs @@ -203,6 +203,31 @@ internal sealed class SrvWebsocket public bool Compression { get; set; } public string Host { get; set; } = string.Empty; public int Port { get; set; } + + public Exception? checkOrigin(string requestHost, string? origin) + { + if (!SameOrigin && AllowedOrigins.Count == 0) + return null; + if (string.IsNullOrWhiteSpace(origin)) + return new InvalidOperationException("origin header missing"); + if (!Uri.TryCreate(origin, UriKind.Absolute, out var uri)) + return new InvalidOperationException("invalid origin"); + if (SameOrigin && !string.Equals(uri.Host, requestHost, StringComparison.OrdinalIgnoreCase)) + return new InvalidOperationException("origin not same as host"); + if (AllowedOrigins.Count == 0) + return null; + if (!AllowedOrigins.TryGetValue(uri.Host, out var allowed)) + return new InvalidOperationException("origin host not allowed"); + + var port = uri.IsDefaultPort + ? uri.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase) ? "443" : "80" + : uri.Port.ToString(); + if (!string.Equals(allowed.Scheme, uri.Scheme, StringComparison.OrdinalIgnoreCase)) + return new InvalidOperationException("origin scheme not allowed"); + if (!string.Equals(allowed.Port, port, StringComparison.Ordinal)) + return new InvalidOperationException("origin port not allowed"); + return null; + } } internal readonly record struct AllowedOrigin(string Scheme, string Port); diff --git a/porting.db b/porting.db index dd42a6d..baa47b0 100644 Binary files a/porting.db and b/porting.db differ