using System.Text;
using NATS.Server.WebSocket;

namespace NATS.Server.Transport.Tests.WebSocket;

/// <summary>
/// Tests for <c>WsUpgrade.TryUpgradeAsync</c> and <c>WsUpgrade.ComputeAcceptKey</c>.
/// The tests cover HTTP upgrade validation, client-kind selection by request path,
/// compression negotiation, browser detection, cookie and X-Forwarded-For extraction,
/// and the status codes written back on rejection.
/// </summary>
public class WsUpgradeTests
{
    /// <summary>
    /// Builds a GET upgrade request containing Host, Upgrade, Connection,
    /// Sec-WebSocket-Key and Sec-WebSocket-Version: 13 headers.
    /// </summary>
    /// <param name="path">Request target placed on the request line.</param>
    /// <param name="extraHeaders">
    /// Optional raw header lines appended verbatim. Each line must already end in CRLF.
    /// </param>
    /// <returns>The complete request text, terminated by the blank line that ends the headers.</returns>
    private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
    {
        var sb = new StringBuilder();
        sb.Append($"GET {path} HTTP/1.1\r\n");
        sb.Append("Host: localhost:4222\r\n");
        sb.Append("Upgrade: websocket\r\n");
        sb.Append("Connection: Upgrade\r\n");
        sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
        sb.Append("Sec-WebSocket-Version: 13\r\n");
        if (extraHeaders is not null)
        {
            sb.Append(extraHeaders);
        }

        // Empty line terminates the header section.
        sb.Append("\r\n");
        return sb.ToString();
    }

    [Fact]
    public async Task ValidUpgrade_Returns101()
    {
        var request = BuildValidRequest();
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.Kind.ShouldBe(WsClientKind.Client);
        var response = ReadResponse(outputStream);
        response.ShouldContain("HTTP/1.1 101");
        response.ShouldContain("Upgrade: websocket");
        response.ShouldContain("Sec-WebSocket-Accept:");
    }

    [Fact]
    public async Task MissingUpgradeHeader_Returns400()
    {
        var request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeFalse();
        ReadResponse(outputStream).ShouldContain("400");
    }

    [Fact]
    public async Task MissingHost_Returns400()
    {
        var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeFalse();
        // The test name promises a 400; verify the status was actually written back.
        ReadResponse(outputStream).ShouldContain("400");
    }

    [Fact]
    public async Task WrongVersion_Returns400()
    {
        var request = "GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 12\r\n\r\n";
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeFalse();
        // The test name promises a 400; verify the status was actually written back.
        ReadResponse(outputStream).ShouldContain("400");
    }

    [Fact]
    public async Task LeafNodePath_ReturnsLeafKind()
    {
        var request = BuildValidRequest("/leafnode");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.Kind.ShouldBe(WsClientKind.Leaf);
    }

    [Fact]
    public async Task MqttPath_ReturnsMqttKind()
    {
        var request = BuildValidRequest("/mqtt");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.Kind.ShouldBe(WsClientKind.Mqtt);
    }

    [Fact]
    public async Task CompressionNegotiation_WhenEnabled()
    {
        var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}; {WsConstants.PmcSrvNoCtx}; {WsConstants.PmcCliNoCtx}\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });

        result.Success.ShouldBeTrue();
        result.Compress.ShouldBeTrue();
        ReadResponse(outputStream).ShouldContain("permessage-deflate");
    }

    [Fact]
    public async Task CompressionNegotiation_WhenDisabled()
    {
        var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = false });

        result.Success.ShouldBeTrue();
        result.Compress.ShouldBeFalse();
    }

    [Fact]
    public async Task NoMaskingHeader_ForLeaf()
    {
        var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.MaskRead.ShouldBeFalse();
    }

    [Fact]
    public async Task BrowserDetection_Mozilla()
    {
        var request = BuildValidRequest(extraHeaders: "User-Agent: Mozilla/5.0 (Windows)\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.Browser.ShouldBeTrue();
    }

    [Fact]
    public async Task SafariDetection_NoCompFrag()
    {
        var request = BuildValidRequest(extraHeaders:
            "User-Agent: Mozilla/5.0 (Macintosh) Version/15.0 Safari/605.1.15\r\n" +
            $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });

        result.Success.ShouldBeTrue();
        result.NoCompFrag.ShouldBeTrue();
    }

    [Fact]
    public void AcceptKey_MatchesRfc6455Example()
    {
        // RFC 6455 Section 1.3 opening-handshake example (key -> Sec-WebSocket-Accept).
        var key = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");

        key.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    [Fact]
    public async Task CookieExtraction()
    {
        var request = BuildValidRequest(extraHeaders: "Cookie: jwt_token=my-jwt; nats_user=admin; nats_pass=secret\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);
        var opts = new WebSocketOptions
        {
            NoTls = true,
            JwtCookie = "jwt_token",
            UsernameCookie = "nats_user",
            PasswordCookie = "nats_pass",
        };

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);

        result.Success.ShouldBeTrue();
        result.CookieJwt.ShouldBe("my-jwt");
        result.CookieUsername.ShouldBe("admin");
        result.CookiePassword.ShouldBe("secret");
    }

    [Fact]
    public async Task XForwardedFor_ExtractsClientIp()
    {
        var request = BuildValidRequest(extraHeaders: "X-Forwarded-For: 192.168.1.100\r\n");
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeTrue();
        result.ClientIp.ShouldBe("192.168.1.100");
    }

    [Fact]
    public async Task PostMethod_Returns405()
    {
        var request = "POST / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
        var (inputStream, outputStream) = CreateStreamPair(request);

        var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });

        result.Success.ShouldBeFalse();
        ReadResponse(outputStream).ShouldContain("405");
    }

    /// <summary>
    /// Creates an input stream preloaded with the ASCII bytes of <paramref name="httpRequest"/>,
    /// plus an empty output stream that captures whatever the upgrade writes back.
    /// </summary>
    private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
    {
        var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
        return (new MemoryStream(inputBytes), new MemoryStream());
    }

    /// <summary>
    /// Decodes everything written to <paramref name="output"/> as ASCII.
    /// </summary>
    /// <remarks>
    /// <see cref="MemoryStream.ToArray"/> copies the whole buffer regardless of
    /// <see cref="MemoryStream.Position"/> and still works after disposal, so the stream is not rewound.
    /// </remarks>
    private static string ReadResponse(MemoryStream output) => Encoding.ASCII.GetString(output.ToArray());
}