refactor: extract NATS.Server.Transport.Tests project

Move TLS, OCSP, WebSocket, Networking, and IO test files from
NATS.Server.Tests into a dedicated NATS.Server.Transport.Tests
project. Update namespaces, replace private GetFreePort/ReadUntilAsync
with shared TestUtilities helpers, extract TestCertHelper to
TestUtilities, and replace Task.Delay polling loops with
PollHelper.WaitUntilAsync/YieldForAsync for proper synchronization.
This commit is contained in:
Joseph Doherty
2026-03-12 14:57:35 -04:00
parent 5c608f07e3
commit d2c04fcca5
36 changed files with 157 additions and 152 deletions

View File

@@ -0,0 +1,50 @@
using Shouldly;
using NATS.Server.WebSocket;
namespace NATS.Server.Transport.Tests.WebSocket;
public class WebSocketOptionsTests
{
[Fact]
public void DefaultOptions_PortIsNegativeOne_Disabled()
{
var opts = new WebSocketOptions();
opts.Port.ShouldBe(-1);
opts.Host.ShouldBe("0.0.0.0");
opts.Compression.ShouldBeFalse();
opts.NoTls.ShouldBeFalse();
opts.HandshakeTimeout.ShouldBe(TimeSpan.FromSeconds(2));
opts.AuthTimeout.ShouldBe(TimeSpan.FromSeconds(2));
}
[Fact]
public void NatsOptions_HasWebSocketProperty()
{
var opts = new NatsOptions();
opts.WebSocket.ShouldNotBeNull();
opts.WebSocket.Port.ShouldBe(-1);
}
[Fact]
public void WsAuthConfig_sets_auth_override_when_websocket_auth_fields_are_present()
{
var ws = new WebSocketOptions
{
Username = "u",
};
WsAuthConfig.Apply(ws);
ws.AuthOverride.ShouldBeTrue();
}
[Fact]
public void WsAuthConfig_keeps_auth_override_false_when_no_ws_auth_fields_are_present()
{
var ws = new WebSocketOptions();
WsAuthConfig.Apply(ws);
ws.AuthOverride.ShouldBeFalse();
}
}

View File

@@ -0,0 +1,172 @@
using NATS.Server.Auth;
using NATS.Server.WebSocket;
namespace NATS.Server.Transport.Tests.WebSocket;
public class WebSocketOptionsValidatorParityBatch2Tests
{
[Fact]
public void Validate_rejects_tls_listener_without_cert_key_when_not_no_tls()
{
var opts = new NatsOptions
{
WebSocket = new WebSocketOptions
{
Port = 8080,
NoTls = false,
},
};
var result = WebSocketOptionsValidator.Validate(opts);
result.IsValid.ShouldBeFalse();
result.Errors.ShouldContain(e => e.Contains("TLS", StringComparison.OrdinalIgnoreCase));
}
[Fact]
public void Validate_rejects_invalid_allowed_origins()
{
var opts = new NatsOptions
{
WebSocket = new WebSocketOptions
{
Port = 8080,
NoTls = true,
AllowedOrigins = ["not-a-uri"],
},
};
var result = WebSocketOptionsValidator.Validate(opts);
result.IsValid.ShouldBeFalse();
result.Errors.ShouldContain(e => e.Contains("allowed origin", StringComparison.OrdinalIgnoreCase));
}
[Fact]
public void Validate_rejects_no_auth_user_not_present_in_configured_users()
{
var opts = new NatsOptions
{
Users = [new User { Username = "alice", Password = "x" }],
WebSocket = new WebSocketOptions
{
Port = 8080,
NoTls = true,
NoAuthUser = "bob",
},
};
var result = WebSocketOptionsValidator.Validate(opts);
result.IsValid.ShouldBeFalse();
result.Errors.ShouldContain(e => e.Contains("NoAuthUser", StringComparison.OrdinalIgnoreCase));
}
[Fact]
public void Validate_rejects_username_or_token_when_users_or_nkeys_are_set()
{
var opts = new NatsOptions
{
Users = [new User { Username = "alice", Password = "x" }],
WebSocket = new WebSocketOptions
{
Port = 8080,
NoTls = true,
Username = "ws-user",
},
};
var result = WebSocketOptionsValidator.Validate(opts);
result.IsValid.ShouldBeFalse();
result.Errors.ShouldContain(e => e.Contains("users", StringComparison.OrdinalIgnoreCase));
}
[Fact]
public void Validate_rejects_jwt_cookie_without_trusted_operators()
{
var opts = new NatsOptions
{
WebSocket = new WebSocketOptions
{
Port = 8080,
NoTls = true,
JwtCookie = "jwt",
},
};
var result = WebSocketOptionsValidator.Validate(opts);
result.IsValid.ShouldBeFalse();
result.Errors.ShouldContain(e => e.Contains("JwtCookie", StringComparison.OrdinalIgnoreCase));
}
[Fact]
public void Validate_rejects_reserved_response_headers_override()
{
var opts = new NatsOptions
{
TrustedKeys = ["OP1"],
WebSocket = new WebSocketOptions
{
Port = 8080,
NoTls = true,
Headers = new Dictionary<string, string>
{
["Sec-WebSocket-Accept"] = "bad",
},
},
};
var result = WebSocketOptionsValidator.Validate(opts);
result.IsValid.ShouldBeFalse();
result.Errors.ShouldContain(e => e.Contains("reserved", StringComparison.OrdinalIgnoreCase));
}
[Fact]
public void Validate_rejects_tls_pinned_certs_when_websocket_tls_is_disabled()
{
var opts = new NatsOptions
{
TlsPinnedCerts = ["ABCDEF0123"],
WebSocket = new WebSocketOptions
{
Port = 8080,
NoTls = true,
},
};
var result = WebSocketOptionsValidator.Validate(opts);
result.IsValid.ShouldBeFalse();
result.Errors.ShouldContain(e => e.Contains("TLSPinnedCerts", StringComparison.OrdinalIgnoreCase));
}
[Fact]
public void Validate_accepts_valid_minimal_configuration()
{
var opts = new NatsOptions
{
TrustedKeys = ["OP1"],
Users = [new User { Username = "alice", Password = "x" }],
WebSocket = new WebSocketOptions
{
Port = 8080,
NoTls = true,
NoAuthUser = "alice",
AllowedOrigins = ["https://app.example.com"],
JwtCookie = "jwt",
Headers = new Dictionary<string, string>
{
["X-App-Version"] = "1",
},
},
};
var result = WebSocketOptionsValidator.Validate(opts);
result.IsValid.ShouldBeTrue();
result.Errors.ShouldBeEmpty();
}
}

View File

@@ -0,0 +1,119 @@
// Go reference: server/websocket.go — wsTLSConfig and related TLS handling.
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Transport.Tests.WebSocket;
/// <summary>
/// Tests for <see cref="WebSocketTlsConfig"/> — WebSocket-specific TLS configuration
/// with separate cert/key from the main NATS listener (Gap 15.1).
/// </summary>
public class WebSocketTlsTests
{
// ── IsConfigured ──────────────────────────────────────────────────────────
[Fact]
public void IsConfigured_WithCert_ReturnsTrue()
{
var cfg = new WebSocketTlsConfig { CertFile = "server.pem", KeyFile = "server-key.pem" };
cfg.IsConfigured.ShouldBeTrue();
}
[Fact]
public void IsConfigured_WithoutCert_ReturnsFalse()
{
var cfg = new WebSocketTlsConfig();
cfg.IsConfigured.ShouldBeFalse();
}
// ── Validate ──────────────────────────────────────────────────────────────
[Fact]
public void Validate_ValidConfig_NoErrors()
{
var cfg = new WebSocketTlsConfig { CertFile = "server.pem", KeyFile = "server-key.pem" };
var result = cfg.Validate();
result.IsValid.ShouldBeTrue();
result.Errors.ShouldBeEmpty();
}
[Fact]
public void Validate_KeyWithoutCert_HasError()
{
var cfg = new WebSocketTlsConfig { KeyFile = "server-key.pem" };
var result = cfg.Validate();
result.IsValid.ShouldBeFalse();
result.Errors.ShouldNotBeEmpty();
}
[Fact]
public void Validate_CertWithoutKey_HasError()
{
var cfg = new WebSocketTlsConfig { CertFile = "server.pem" };
var result = cfg.Validate();
result.IsValid.ShouldBeFalse();
result.Errors.ShouldNotBeEmpty();
}
[Fact]
public void Validate_EmptyConfig_Valid()
{
// An empty configuration means "no TLS" — that is a valid state.
var cfg = new WebSocketTlsConfig();
var result = cfg.Validate();
result.IsValid.ShouldBeTrue();
result.Errors.ShouldBeEmpty();
}
// ── HasChangedFrom ────────────────────────────────────────────────────────
[Fact]
public void HasChangedFrom_SameConfig_ReturnsFalse()
{
var a = new WebSocketTlsConfig
{
CertFile = "server.pem",
KeyFile = "server-key.pem",
CaFile = "ca.pem",
RequireClientCert = true,
InsecureSkipVerify = false,
};
var b = new WebSocketTlsConfig
{
CertFile = "server.pem",
KeyFile = "server-key.pem",
CaFile = "ca.pem",
RequireClientCert = true,
InsecureSkipVerify = false,
};
a.HasChangedFrom(b).ShouldBeFalse();
}
[Fact]
public void HasChangedFrom_DifferentCert_ReturnsTrue()
{
var a = new WebSocketTlsConfig { CertFile = "old.pem", KeyFile = "old-key.pem" };
var b = new WebSocketTlsConfig { CertFile = "new.pem", KeyFile = "new-key.pem" };
a.HasChangedFrom(b).ShouldBeTrue();
}
[Fact]
public void HasChangedFrom_NullOther_ReturnsTrue()
{
var cfg = new WebSocketTlsConfig { CertFile = "server.pem", KeyFile = "server-key.pem" };
cfg.HasChangedFrom(null).ShouldBeTrue();
}
// ── Defaults ──────────────────────────────────────────────────────────────
[Fact]
public void RequireClientCert_Default_False()
{
var cfg = new WebSocketTlsConfig();
cfg.RequireClientCert.ShouldBeFalse();
}
}

View File

@@ -0,0 +1,327 @@
// Tests for WebSocket permessage-deflate parameter negotiation (E10).
// Verifies RFC 7692 extension parameter parsing and negotiation during
// WebSocket upgrade handshake.
// Reference: golang/nats-server/server/websocket.go — wsPMCExtensionSupport (line 885).
using System.Text;
using NATS.Server.WebSocket;
namespace NATS.Server.Transport.Tests.WebSocket;
public class WsCompressionNegotiationTests
{
// ─── WsDeflateNegotiator.Negotiate tests ──────────────────────────────
[Fact]
public void Negotiate_NullHeader_ReturnsNull()
{
// Go parity: wsPMCExtensionSupport — no extension header means no compression
var result = WsDeflateNegotiator.Negotiate(null);
result.ShouldBeNull();
}
[Fact]
public void Negotiate_EmptyHeader_ReturnsNull()
{
var result = WsDeflateNegotiator.Negotiate("");
result.ShouldBeNull();
}
[Fact]
public void Negotiate_NoPermessageDeflate_ReturnsNull()
{
var result = WsDeflateNegotiator.Negotiate("x-webkit-deflate-frame");
result.ShouldBeNull();
}
[Fact]
public void Negotiate_BarePermessageDeflate_ReturnsDefaults()
{
// Go parity: wsPMCExtensionSupport — basic extension without parameters
var result = WsDeflateNegotiator.Negotiate("permessage-deflate");
result.ShouldNotBeNull();
// NATS always enforces no_context_takeover
result.Value.ServerNoContextTakeover.ShouldBeTrue();
result.Value.ClientNoContextTakeover.ShouldBeTrue();
result.Value.ServerMaxWindowBits.ShouldBe(15);
result.Value.ClientMaxWindowBits.ShouldBe(15);
}
[Fact]
public void Negotiate_WithServerNoContextTakeover()
{
// Go parity: wsPMCExtensionSupport — server_no_context_takeover parameter
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; server_no_context_takeover");
result.ShouldNotBeNull();
result.Value.ServerNoContextTakeover.ShouldBeTrue();
}
[Fact]
public void Negotiate_WithClientNoContextTakeover()
{
// Go parity: wsPMCExtensionSupport — client_no_context_takeover parameter
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; client_no_context_takeover");
result.ShouldNotBeNull();
result.Value.ClientNoContextTakeover.ShouldBeTrue();
}
[Fact]
public void Negotiate_WithBothNoContextTakeover()
{
// Go parity: wsPMCExtensionSupport — both no_context_takeover parameters
var result = WsDeflateNegotiator.Negotiate(
"permessage-deflate; server_no_context_takeover; client_no_context_takeover");
result.ShouldNotBeNull();
result.Value.ServerNoContextTakeover.ShouldBeTrue();
result.Value.ClientNoContextTakeover.ShouldBeTrue();
}
[Fact]
public void Negotiate_WithServerMaxWindowBits()
{
// RFC 7692 Section 7.1.2.1: server_max_window_bits parameter
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; server_max_window_bits=10");
result.ShouldNotBeNull();
result.Value.ServerMaxWindowBits.ShouldBe(10);
}
[Fact]
public void Negotiate_WithClientMaxWindowBits_Value()
{
// RFC 7692 Section 7.1.2.2: client_max_window_bits with explicit value
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; client_max_window_bits=12");
result.ShouldNotBeNull();
result.Value.ClientMaxWindowBits.ShouldBe(12);
}
[Fact]
public void Negotiate_WithClientMaxWindowBits_NoValue()
{
// RFC 7692 Section 7.1.2.2: client_max_window_bits with no value means
// client supports any value 8-15; defaults to 15
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; client_max_window_bits");
result.ShouldNotBeNull();
result.Value.ClientMaxWindowBits.ShouldBe(15);
}
[Fact]
public void Negotiate_WindowBits_ClampedToValidRange()
{
// RFC 7692: valid range is 8-15
var result = WsDeflateNegotiator.Negotiate(
"permessage-deflate; server_max_window_bits=5; client_max_window_bits=20");
result.ShouldNotBeNull();
result.Value.ServerMaxWindowBits.ShouldBe(8); // Clamped up from 5
result.Value.ClientMaxWindowBits.ShouldBe(15); // Clamped down from 20
}
[Fact]
public void Negotiate_FullParameters()
{
// All parameters specified
var result = WsDeflateNegotiator.Negotiate(
"permessage-deflate; server_no_context_takeover; client_no_context_takeover; server_max_window_bits=9; client_max_window_bits=11");
result.ShouldNotBeNull();
result.Value.ServerNoContextTakeover.ShouldBeTrue();
result.Value.ClientNoContextTakeover.ShouldBeTrue();
result.Value.ServerMaxWindowBits.ShouldBe(9);
result.Value.ClientMaxWindowBits.ShouldBe(11);
}
[Fact]
public void Negotiate_CaseInsensitive()
{
// RFC 7692 extension names are case-insensitive
var result = WsDeflateNegotiator.Negotiate("Permessage-Deflate; Server_No_Context_Takeover");
result.ShouldNotBeNull();
result.Value.ServerNoContextTakeover.ShouldBeTrue();
}
[Fact]
public void Negotiate_MultipleExtensions_PicksDeflate()
{
// Header may contain multiple comma-separated extensions
var result = WsDeflateNegotiator.Negotiate(
"x-custom-ext, permessage-deflate; server_no_context_takeover, other-ext");
result.ShouldNotBeNull();
result.Value.ServerNoContextTakeover.ShouldBeTrue();
}
[Fact]
public void Negotiate_WhitespaceHandling()
{
// Extra whitespace around parameters
var result = WsDeflateNegotiator.Negotiate(
" permessage-deflate ; server_no_context_takeover ; client_max_window_bits = 10 ");
result.ShouldNotBeNull();
result.Value.ServerNoContextTakeover.ShouldBeTrue();
result.Value.ClientMaxWindowBits.ShouldBe(10);
}
// ─── NatsAlwaysEnforcesNoContextTakeover ─────────────────────────────
[Fact]
public void Negotiate_AlwaysEnforcesNoContextTakeover()
{
// NATS Go server always returns server_no_context_takeover and
// client_no_context_takeover regardless of what the client requests
var result = WsDeflateNegotiator.Negotiate("permessage-deflate");
result.ShouldNotBeNull();
result.Value.ServerNoContextTakeover.ShouldBeTrue();
result.Value.ClientNoContextTakeover.ShouldBeTrue();
}
// ─── WsDeflateParams.ToResponseHeaderValue tests ────────────────────
[Fact]
public void DefaultParams_ResponseHeader_ContainsNoContextTakeover()
{
var header = WsDeflateParams.Default.ToResponseHeaderValue();
header.ShouldContain("permessage-deflate");
header.ShouldContain("server_no_context_takeover");
header.ShouldContain("client_no_context_takeover");
header.ShouldNotContain("server_max_window_bits");
header.ShouldNotContain("client_max_window_bits");
}
[Fact]
public void CustomWindowBits_ResponseHeader_IncludesValues()
{
var params_ = new WsDeflateParams(
ServerNoContextTakeover: true,
ClientNoContextTakeover: true,
ServerMaxWindowBits: 10,
ClientMaxWindowBits: 12);
var header = params_.ToResponseHeaderValue();
header.ShouldContain("server_max_window_bits=10");
header.ShouldContain("client_max_window_bits=12");
}
[Fact]
public void DefaultWindowBits_ResponseHeader_OmitsValues()
{
// RFC 7692: window bits of 15 is the default and should not be sent
var params_ = new WsDeflateParams(
ServerNoContextTakeover: true,
ClientNoContextTakeover: true,
ServerMaxWindowBits: 15,
ClientMaxWindowBits: 15);
var header = params_.ToResponseHeaderValue();
header.ShouldNotContain("server_max_window_bits");
header.ShouldNotContain("client_max_window_bits");
}
// ─── Integration with WsUpgrade ─────────────────────────────────────
[Fact]
public async Task Upgrade_WithDeflateParams_NegotiatesCompression()
{
// Go parity: WebSocket upgrade with permessage-deflate parameters
var request = BuildValidRequest(extraHeaders:
"Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover; server_max_window_bits=10\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true, Compression = true };
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
result.Success.ShouldBeTrue();
result.Compress.ShouldBeTrue();
result.DeflateParams.ShouldNotBeNull();
result.DeflateParams.Value.ServerNoContextTakeover.ShouldBeTrue();
result.DeflateParams.Value.ClientNoContextTakeover.ShouldBeTrue();
result.DeflateParams.Value.ServerMaxWindowBits.ShouldBe(10);
}
[Fact]
public async Task Upgrade_WithDeflateParams_ResponseIncludesNegotiatedParams()
{
var request = BuildValidRequest(extraHeaders:
"Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_max_window_bits=10\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true, Compression = true };
await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
var response = ReadResponse(outputStream);
response.ShouldContain("permessage-deflate");
response.ShouldContain("server_no_context_takeover");
response.ShouldContain("client_no_context_takeover");
response.ShouldContain("client_max_window_bits=10");
}
[Fact]
public async Task Upgrade_CompressionDisabled_NoDeflateParams()
{
var request = BuildValidRequest(extraHeaders:
"Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true, Compression = false };
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
result.Success.ShouldBeTrue();
result.Compress.ShouldBeFalse();
result.DeflateParams.ShouldBeNull();
}
[Fact]
public async Task Upgrade_NoExtensionHeader_NoCompression()
{
var request = BuildValidRequest();
var (inputStream, outputStream) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true, Compression = true };
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
result.Success.ShouldBeTrue();
result.Compress.ShouldBeFalse();
result.DeflateParams.ShouldBeNull();
}
// ─── Helpers ─────────────────────────────────────────────────────────
private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
{
var sb = new StringBuilder();
sb.Append($"GET {path} HTTP/1.1\r\n");
sb.Append("Host: localhost:4222\r\n");
sb.Append("Upgrade: websocket\r\n");
sb.Append("Connection: Upgrade\r\n");
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
sb.Append("Sec-WebSocket-Version: 13\r\n");
if (extraHeaders != null)
sb.Append(extraHeaders);
sb.Append("\r\n");
return sb.ToString();
}
private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
{
var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
return (new MemoryStream(inputBytes), new MemoryStream());
}
private static string ReadResponse(MemoryStream output)
{
output.Position = 0;
return Encoding.ASCII.GetString(output.ToArray());
}
}

View File

@@ -0,0 +1,58 @@
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Transport.Tests.WebSocket;
public class WsCompressionTests
{
[Fact]
public void CompressDecompress_RoundTrip()
{
var original = "Hello, WebSocket compression test! This is long enough to compress."u8.ToArray();
var compressed = WsCompression.Compress(original);
compressed.ShouldNotBeNull();
compressed.Length.ShouldBeGreaterThan(0);
var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096);
decompressed.ShouldBe(original);
}
[Fact]
public void Decompress_ExceedsMaxPayload_Throws()
{
var original = new byte[1000];
Random.Shared.NextBytes(original);
var compressed = WsCompression.Compress(original);
Should.Throw<InvalidOperationException>(() =>
WsCompression.Decompress([compressed], maxPayload: 100));
}
[Fact]
public void Compress_RemovesTrailing4Bytes()
{
var data = new byte[200];
Random.Shared.NextBytes(data);
var compressed = WsCompression.Compress(data);
// The compressed data should be valid for decompression when we add the trailer back
var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096);
decompressed.ShouldBe(data);
}
[Fact]
public void Decompress_MultipleBuffers()
{
var original = new byte[500];
Random.Shared.NextBytes(original);
var compressed = WsCompression.Compress(original);
// Split compressed data into multiple chunks
int mid = compressed.Length / 2;
var chunk1 = compressed[..mid];
var chunk2 = compressed[mid..];
var decompressed = WsCompression.Decompress([chunk1, chunk2], maxPayload: 4096);
decompressed.ShouldBe(original);
}
}

View File

@@ -0,0 +1,124 @@
using System.Buffers.Binary;
using NATS.Server.WebSocket;
namespace NATS.Server.Transport.Tests.WebSocket;
public class WsConnectionTests
{
[Fact]
public async Task ReadAsync_DecodesFrameAndReturnsPayload()
{
var payload = "SUB test 1\r\n"u8.ToArray();
var frame = BuildUnmaskedFrame(payload);
var inner = new MemoryStream(frame);
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var buf = new byte[256];
int n = await ws.ReadAsync(buf);
n.ShouldBe(payload.Length);
buf[..n].ShouldBe(payload);
}
[Fact]
public async Task WriteAsync_FramesPayload()
{
var inner = new MemoryStream();
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var payload = "MSG test 1 5\r\nHello\r\n"u8.ToArray();
await ws.WriteAsync(payload);
await ws.FlushAsync();
inner.Position = 0;
var written = inner.ToArray();
// First 2 bytes should be WS frame header
(written[0] & WsConstants.FinalBit).ShouldNotBe(0);
(written[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
int len = written[1] & 0x7F;
len.ShouldBe(payload.Length);
written[2..].ShouldBe(payload);
}
[Fact]
public async Task WriteAsync_WithCompression_CompressesLargePayload()
{
var inner = new MemoryStream();
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var payload = new byte[200];
Array.Fill<byte>(payload, 0x41); // 'A' repeated - very compressible
await ws.WriteAsync(payload);
await ws.FlushAsync();
inner.Position = 0;
var written = inner.ToArray();
// RSV1 bit should be set for compressed frame
(written[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
// Compressed size should be less than original
written.Length.ShouldBeLessThan(payload.Length + 10);
}
[Fact]
public async Task WriteAsync_SmallPayload_NotCompressedEvenWhenEnabled()
{
var inner = new MemoryStream();
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var payload = "Hi"u8.ToArray(); // Below CompressThreshold
await ws.WriteAsync(payload);
await ws.FlushAsync();
inner.Position = 0;
var written = inner.ToArray();
// RSV1 bit should NOT be set for small payloads
(written[0] & WsConstants.Rsv1Bit).ShouldBe(0);
}
[Fact]
public async Task ReadAsync_DecodesMaskedFrame()
{
var payload = "CONNECT {}\r\n"u8.ToArray();
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: true, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: payload.Length);
var maskKey = header[^4..];
WsFrameWriter.MaskBuf(maskKey, payload);
var frame = new byte[header.Length + payload.Length];
header.CopyTo(frame, 0);
payload.CopyTo(frame, header.Length);
var inner = new MemoryStream(frame);
var ws = new WsConnection(inner, compress: false, maskRead: true, maskWrite: false, browser: false, noCompFrag: false);
var buf = new byte[256];
int n = await ws.ReadAsync(buf);
n.ShouldBe("CONNECT {}\r\n".Length);
System.Text.Encoding.ASCII.GetString(buf, 0, n).ShouldBe("CONNECT {}\r\n");
}
[Fact]
public async Task ReadAsync_ReturnsZero_OnEndOfStream()
{
// Empty stream should return 0 (true end of stream)
var inner = new MemoryStream([]);
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
var buf = new byte[256];
int n = await ws.ReadAsync(buf);
n.ShouldBe(0);
}
private static byte[] BuildUnmaskedFrame(byte[] payload)
{
var header = new byte[2];
header[0] = (byte)(WsConstants.FinalBit | WsConstants.BinaryMessage);
header[1] = (byte)payload.Length;
var frame = new byte[2 + payload.Length];
header.CopyTo(frame, 0);
payload.CopyTo(frame, 2);
return frame;
}
}

View File

@@ -0,0 +1,53 @@
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Transport.Tests.WebSocket;
public class WsConstantsTests
{
[Fact]
public void OpCodes_MatchRfc6455()
{
WsConstants.TextMessage.ShouldBe(1);
WsConstants.BinaryMessage.ShouldBe(2);
WsConstants.CloseMessage.ShouldBe(8);
WsConstants.PingMessage.ShouldBe(9);
WsConstants.PongMessage.ShouldBe(10);
}
[Fact]
public void FrameBits_MatchRfc6455()
{
WsConstants.FinalBit.ShouldBe((byte)0x80);
WsConstants.Rsv1Bit.ShouldBe((byte)0x40);
WsConstants.MaskBit.ShouldBe((byte)0x80);
}
[Fact]
public void CloseStatusCodes_MatchRfc6455()
{
WsConstants.CloseStatusNormalClosure.ShouldBe(1000);
WsConstants.CloseStatusGoingAway.ShouldBe(1001);
WsConstants.CloseStatusProtocolError.ShouldBe(1002);
WsConstants.CloseStatusPolicyViolation.ShouldBe(1008);
WsConstants.CloseStatusMessageTooBig.ShouldBe(1009);
}
[Theory]
[InlineData(WsConstants.CloseMessage)]
[InlineData(WsConstants.PingMessage)]
[InlineData(WsConstants.PongMessage)]
public void IsControlFrame_True(int opcode)
{
WsConstants.IsControlFrame(opcode).ShouldBeTrue();
}
[Theory]
[InlineData(WsConstants.TextMessage)]
[InlineData(WsConstants.BinaryMessage)]
[InlineData(0)]
public void IsControlFrame_False(int opcode)
{
WsConstants.IsControlFrame(opcode).ShouldBeFalse();
}
}

View File

@@ -0,0 +1,163 @@
using System.Buffers.Binary;
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Transport.Tests.WebSocket;
public class WsFrameReadTests
{
/// <summary>Helper: build a single unmasked binary frame.</summary>
private static byte[] BuildFrame(byte[] payload, bool fin = true, bool compressed = false, int opcode = WsConstants.BinaryMessage, bool mask = false, byte[]? maskKey = null)
{
int payloadLen = payload.Length;
byte b0 = (byte)opcode;
if (fin) b0 |= WsConstants.FinalBit;
if (compressed) b0 |= WsConstants.Rsv1Bit;
byte b1 = 0;
if (mask) b1 |= WsConstants.MaskBit;
byte[] lenBytes;
if (payloadLen <= 125)
{
lenBytes = [(byte)(b1 | (byte)payloadLen)];
}
else if (payloadLen < 65536)
{
lenBytes = new byte[3];
lenBytes[0] = (byte)(b1 | 126);
BinaryPrimitives.WriteUInt16BigEndian(lenBytes.AsSpan(1), (ushort)payloadLen);
}
else
{
lenBytes = new byte[9];
lenBytes[0] = (byte)(b1 | 127);
BinaryPrimitives.WriteUInt64BigEndian(lenBytes.AsSpan(1), (ulong)payloadLen);
}
int totalLen = 1 + lenBytes.Length + (mask ? 4 : 0) + payloadLen;
var frame = new byte[totalLen];
frame[0] = b0;
lenBytes.CopyTo(frame.AsSpan(1));
int pos = 1 + lenBytes.Length;
if (mask && maskKey != null)
{
maskKey.CopyTo(frame.AsSpan(pos));
pos += 4;
var maskedPayload = payload.ToArray();
WsFrameWriter.MaskBuf(maskKey, maskedPayload);
maskedPayload.CopyTo(frame.AsSpan(pos));
}
else
{
payload.CopyTo(frame.AsSpan(pos));
}
return frame;
}
[Fact]
public void ReadSingleUnmaskedFrame()
{
var payload = "Hello"u8.ToArray();
var frame = BuildFrame(payload);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(1);
result[0].ShouldBe(payload);
}
[Fact]
public void ReadMaskedFrame()
{
var payload = "Hello"u8.ToArray();
byte[] key = [0x37, 0xFA, 0x21, 0x3D];
var frame = BuildFrame(payload, mask: true, maskKey: key);
var readInfo = new WsReadInfo(expectMask: true);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(1);
result[0].ShouldBe(payload);
}
[Fact]
public void Read16BitLengthFrame()
{
var payload = new byte[200];
Random.Shared.NextBytes(payload);
var frame = BuildFrame(payload);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(1);
result[0].ShouldBe(payload);
}
[Fact]
public void ReadPingFrame_ReturnsPongAction()
{
var frame = BuildFrame([], opcode: WsConstants.PingMessage);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(0); // control frames don't produce payload
readInfo.PendingControlFrames.Count.ShouldBe(1);
readInfo.PendingControlFrames[0].Opcode.ShouldBe(WsConstants.PongMessage);
}
[Fact]
public void ReadCloseFrame_ReturnsCloseAction()
{
var closePayload = new byte[2];
BinaryPrimitives.WriteUInt16BigEndian(closePayload, 1000);
var frame = BuildFrame(closePayload, opcode: WsConstants.CloseMessage);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(0);
readInfo.CloseReceived.ShouldBeTrue();
readInfo.CloseStatus.ShouldBe(1000);
}
[Fact]
public void ReadPongFrame_NoAction()
{
var frame = BuildFrame([], opcode: WsConstants.PongMessage);
var readInfo = new WsReadInfo(expectMask: false);
var stream = new MemoryStream(frame);
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
result.Count.ShouldBe(0);
readInfo.PendingControlFrames.Count.ShouldBe(0);
}
[Fact]
public void Unmask_Optimized_8ByteChunks()
{
byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
var original = new byte[32];
Random.Shared.NextBytes(original);
var masked = original.ToArray();
// Mask it
for (int i = 0; i < masked.Length; i++)
masked[i] ^= key[i & 3];
// Unmask using the state machine
var info = new WsReadInfo(expectMask: true);
info.SetMaskKey(key);
info.Unmask(masked);
masked.ShouldBe(original);
}
}

View File

@@ -0,0 +1,152 @@
using System.Buffers.Binary;
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Transport.Tests.WebSocket;
public class WsFrameWriterTests
{
[Fact]
public void CreateFrameHeader_SmallPayload_7BitLength()
{
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 100);
header.Length.ShouldBe(2);
(header[0] & WsConstants.FinalBit).ShouldNotBe(0); // FIN set
(header[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
(header[1] & 0x7F).ShouldBe(100);
}
[Fact]
public void CreateFrameHeader_MediumPayload_16BitLength()
{
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 1000);
header.Length.ShouldBe(4);
(header[1] & 0x7F).ShouldBe(126);
BinaryPrimitives.ReadUInt16BigEndian(header.AsSpan(2)).ShouldBe((ushort)1000);
}
[Fact]
public void CreateFrameHeader_LargePayload_64BitLength()
{
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 70000);
header.Length.ShouldBe(10);
(header[1] & 0x7F).ShouldBe(127);
BinaryPrimitives.ReadUInt64BigEndian(header.AsSpan(2)).ShouldBe(70000UL);
}
[Fact]
public void CreateFrameHeader_WithMasking_Adds4ByteKey()
{
var (header, key) = WsFrameWriter.CreateFrameHeader(
useMasking: true, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 10);
header.Length.ShouldBe(6); // 2 header + 4 mask key
(header[1] & WsConstants.MaskBit).ShouldNotBe(0);
key.ShouldNotBeNull();
key.Length.ShouldBe(4);
}
[Fact]
public void CreateFrameHeader_Compressed_SetsRsv1Bit()
{
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: true,
opcode: WsConstants.BinaryMessage, payloadLength: 10);
(header[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
}
[Fact]
public void MaskBuf_XorsCorrectly()
{
byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
byte[] data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
byte[] expected = new byte[data.Length];
for (int i = 0; i < data.Length; i++)
expected[i] = (byte)(data[i] ^ key[i & 3]);
WsFrameWriter.MaskBuf(key, data);
data.ShouldBe(expected);
}
[Fact]
public void MaskBuf_RoundTrip()
{
byte[] key = [0x12, 0x34, 0x56, 0x78];
byte[] original = "Hello, WebSocket!"u8.ToArray();
var data = original.ToArray();
WsFrameWriter.MaskBuf(key, data);
data.ShouldNotBe(original);
WsFrameWriter.MaskBuf(key, data);
data.ShouldBe(original);
}
[Fact]
public void CreateCloseMessage_WithStatusAndBody()
{
var msg = WsFrameWriter.CreateCloseMessage(1000, "normal closure");
msg.Length.ShouldBe(2 + "normal closure".Length);
BinaryPrimitives.ReadUInt16BigEndian(msg).ShouldBe((ushort)1000);
}
[Fact]
public void CreateCloseMessage_LongBody_Truncated()
{
var longBody = new string('x', 200);
var msg = WsFrameWriter.CreateCloseMessage(1000, longBody);
msg.Length.ShouldBeLessThanOrEqualTo(WsConstants.MaxControlPayloadSize);
}
[Fact]
public void MapCloseStatus_ClientClosed_NormalClosure()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.ClientClosed)
.ShouldBe(WsConstants.CloseStatusNormalClosure);
}
[Fact]
public void MapCloseStatus_AuthTimeout_PolicyViolation()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.AuthenticationTimeout)
.ShouldBe(WsConstants.CloseStatusPolicyViolation);
}
[Fact]
public void MapCloseStatus_ParseError_ProtocolError()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.ParseError)
.ShouldBe(WsConstants.CloseStatusProtocolError);
}
[Fact]
public void MapCloseStatus_MaxPayload_MessageTooBig()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.MaxPayloadExceeded)
.ShouldBe(WsConstants.CloseStatusMessageTooBig);
}
[Fact]
public void BuildControlFrame_PingNomask()
{
var frame = WsFrameWriter.BuildControlFrame(WsConstants.PingMessage, [], useMasking: false);
frame.Length.ShouldBe(2);
(frame[0] & WsConstants.FinalBit).ShouldNotBe(0);
(frame[0] & 0x0F).ShouldBe(WsConstants.PingMessage);
(frame[1] & 0x7F).ShouldBe(0);
}
[Fact]
public void BuildControlFrame_PongWithPayload()
{
byte[] payload = [1, 2, 3, 4];
var frame = WsFrameWriter.BuildControlFrame(WsConstants.PongMessage, payload, useMasking: false);
frame.Length.ShouldBe(2 + 4);
frame[2..].ShouldBe(payload);
}
}

View File

@@ -0,0 +1,782 @@
// Port of Go server/websocket_test.go — WebSocket protocol parity tests.
// Reference: golang/nats-server/server/websocket_test.go
//
// Tests cover: compression negotiation, JWT auth extraction (bearer/cookie/query),
// frame encoding/decoding, origin checking, upgrade handshake, and close messages.
using System.Buffers.Binary;
using System.Text;
using NATS.Server.WebSocket;
namespace NATS.Server.Transport.Tests.WebSocket;
/// <summary>
/// Parity tests ported from Go server/websocket_test.go exercising WebSocket
/// frame encoding, compression negotiation, origin checking, upgrade validation,
/// and JWT authentication extraction.
/// </summary>
public class WsGoParityTests
{
// ========================================================================
// TestWSIsControlFrame
// Go reference: websocket_test.go:TestWSIsControlFrame
// ========================================================================
[Theory]
[InlineData(WsConstants.CloseMessage, true)]
[InlineData(WsConstants.PingMessage, true)]
[InlineData(WsConstants.PongMessage, true)]
[InlineData(WsConstants.TextMessage, false)]
[InlineData(WsConstants.BinaryMessage, false)]
[InlineData(WsConstants.ContinuationFrame, false)]
public void IsControlFrame_CorrectClassification(int opcode, bool expected)
{
// Go: TestWSIsControlFrame websocket_test.go
WsConstants.IsControlFrame(opcode).ShouldBe(expected);
}
// ========================================================================
// TestWSUnmask
// Go reference: websocket_test.go:TestWSUnmask
// ========================================================================
[Fact]
public void Unmask_XorsWithKey()
{
// Go: TestWSUnmask — XOR unmasking with 4-byte key.
var ri = new WsReadInfo(expectMask: true);
var key = new byte[] { 0x12, 0x34, 0x56, 0x78 };
ri.SetMaskKey(key);
var data = new byte[] { 0x12 ^ (byte)'H', 0x34 ^ (byte)'e', 0x56 ^ (byte)'l', 0x78 ^ (byte)'l', 0x12 ^ (byte)'o' };
ri.Unmask(data);
Encoding.ASCII.GetString(data).ShouldBe("Hello");
}
[Fact]
public void Unmask_LargeBuffer_UsesOptimizedPath()
{
// Go: TestWSUnmask — optimized 8-byte chunk path for larger buffers.
var ri = new WsReadInfo(expectMask: true);
var key = new byte[] { 0xAA, 0xBB, 0xCC, 0xDD };
ri.SetMaskKey(key);
// Create a buffer large enough to trigger the optimized path (>= 16 bytes)
var original = new byte[32];
for (int i = 0; i < original.Length; i++)
original[i] = (byte)(i + 1);
// Mask it
var masked = new byte[original.Length];
for (int i = 0; i < masked.Length; i++)
masked[i] = (byte)(original[i] ^ key[i % 4]);
// Unmask
ri.Unmask(masked);
masked.ShouldBe(original);
}
// ========================================================================
// TestWSCreateCloseMessage
// Go reference: websocket_test.go:TestWSCreateCloseMessage
// ========================================================================
[Fact]
public void CreateCloseMessage_StatusAndBody()
{
// Go: TestWSCreateCloseMessage — close message has 2-byte status + body.
var msg = WsFrameWriter.CreateCloseMessage(
WsConstants.CloseStatusNormalClosure, "goodbye");
msg.Length.ShouldBeGreaterThan(2);
var status = BinaryPrimitives.ReadUInt16BigEndian(msg);
status.ShouldBe((ushort)WsConstants.CloseStatusNormalClosure);
Encoding.UTF8.GetString(msg.AsSpan(2)).ShouldBe("goodbye");
}
[Fact]
public void CreateCloseMessage_LongBody_Truncated()
{
// Go: TestWSCreateCloseMessage — body truncated to MaxControlPayloadSize.
var longBody = new string('x', 200);
var msg = WsFrameWriter.CreateCloseMessage(
WsConstants.CloseStatusGoingAway, longBody);
msg.Length.ShouldBeLessThanOrEqualTo(WsConstants.MaxControlPayloadSize);
// Should end with "..."
var body = Encoding.UTF8.GetString(msg.AsSpan(2));
body.ShouldEndWith("...");
}
// ========================================================================
// TestWSCreateFrameHeader
// Go reference: websocket_test.go:TestWSCreateFrameHeader
// ========================================================================
[Fact]
public void CreateFrameHeader_SmallPayload_2ByteHeader()
{
// Go: TestWSCreateFrameHeader — payload <= 125 uses 2-byte header.
var (header, key) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 50);
header.Length.ShouldBe(2);
(header[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
(header[0] & WsConstants.FinalBit).ShouldBe(WsConstants.FinalBit);
(header[1] & 0x7F).ShouldBe(50);
key.ShouldBeNull();
}
[Fact]
public void CreateFrameHeader_MediumPayload_4ByteHeader()
{
// Go: TestWSCreateFrameHeader — payload 126-65535 uses 4-byte header.
var (header, key) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 1000);
header.Length.ShouldBe(4);
(header[1] & 0x7F).ShouldBe(126);
var payloadLen = BinaryPrimitives.ReadUInt16BigEndian(header.AsSpan(2));
payloadLen.ShouldBe((ushort)1000);
key.ShouldBeNull();
}
[Fact]
public void CreateFrameHeader_LargePayload_10ByteHeader()
{
// Go: TestWSCreateFrameHeader — payload >= 65536 uses 10-byte header.
var (header, key) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 100000);
header.Length.ShouldBe(10);
(header[1] & 0x7F).ShouldBe(127);
var payloadLen = BinaryPrimitives.ReadUInt64BigEndian(header.AsSpan(2));
payloadLen.ShouldBe(100000UL);
key.ShouldBeNull();
}
[Fact]
public void CreateFrameHeader_WithMasking_Adds4ByteKey()
{
// Go: TestWSCreateFrameHeader — masking adds 4-byte key to header.
var (header, key) = WsFrameWriter.CreateFrameHeader(
useMasking: true, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: 50);
header.Length.ShouldBe(6); // 2 base + 4 mask key
(header[1] & WsConstants.MaskBit).ShouldBe(WsConstants.MaskBit);
key.ShouldNotBeNull();
key!.Length.ShouldBe(4);
}
[Fact]
public void CreateFrameHeader_Compressed_SetsRsv1()
{
// Go: TestWSCreateFrameHeader — compressed frames have RSV1 bit set.
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: false, compressed: true,
opcode: WsConstants.BinaryMessage, payloadLength: 50);
(header[0] & WsConstants.Rsv1Bit).ShouldBe(WsConstants.Rsv1Bit);
}
// ========================================================================
// TestWSCheckOrigin
// Go reference: websocket_test.go:TestWSCheckOrigin
// ========================================================================
[Fact]
public void OriginChecker_SameOrigin_Allowed()
{
// Go: TestWSCheckOrigin — same origin passes.
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin("http://localhost:4222", "localhost:4222", isTls: false).ShouldBeNull();
}
[Fact]
public void OriginChecker_SameOrigin_Rejected()
{
// Go: TestWSCheckOrigin — different origin fails.
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
var result = checker.CheckOrigin("http://evil.com", "localhost:4222", isTls: false);
result.ShouldNotBeNull();
result.ShouldContain("not same origin");
}
[Fact]
public void OriginChecker_AllowedList_Allowed()
{
// Go: TestWSCheckOrigin — allowed origins list.
var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: ["http://example.com"]);
checker.CheckOrigin("http://example.com", "localhost:4222", isTls: false).ShouldBeNull();
}
[Fact]
public void OriginChecker_AllowedList_Rejected()
{
// Go: TestWSCheckOrigin — origin not in allowed list.
var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: ["http://example.com"]);
var result = checker.CheckOrigin("http://evil.com", "localhost:4222", isTls: false);
result.ShouldNotBeNull();
result.ShouldContain("not in the allowed list");
}
[Fact]
public void OriginChecker_EmptyOrigin_Allowed()
{
// Go: TestWSCheckOrigin — empty origin (non-browser) is always allowed.
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin(null, "localhost:4222", isTls: false).ShouldBeNull();
checker.CheckOrigin("", "localhost:4222", isTls: false).ShouldBeNull();
}
[Fact]
public void OriginChecker_NoRestrictions_AllAllowed()
{
// Go: no restrictions means all origins pass.
var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: null);
checker.CheckOrigin("http://anything.com", "localhost:4222", isTls: false).ShouldBeNull();
}
[Fact]
public void OriginChecker_AllowedWithPort()
{
// Go: TestWSSetOriginOptions — origins with explicit ports.
var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: ["http://example.com:8080"]);
checker.CheckOrigin("http://example.com:8080", "localhost", isTls: false).ShouldBeNull();
checker.CheckOrigin("http://example.com", "localhost", isTls: false).ShouldNotBeNull(); // wrong port
}
// ========================================================================
// TestWSCompressNegotiation
// Go reference: websocket_test.go:TestWSCompressNegotiation
// ========================================================================
[Fact]
public void CompressNegotiation_FullParams()
{
// Go: TestWSCompressNegotiation — full parameter negotiation.
var result = WsDeflateNegotiator.Negotiate(
"permessage-deflate; server_no_context_takeover; client_no_context_takeover; server_max_window_bits=10; client_max_window_bits=12");
result.ShouldNotBeNull();
result.Value.ServerNoContextTakeover.ShouldBeTrue();
result.Value.ClientNoContextTakeover.ShouldBeTrue();
result.Value.ServerMaxWindowBits.ShouldBe(10);
result.Value.ClientMaxWindowBits.ShouldBe(12);
}
[Fact]
public void CompressNegotiation_NoExtension_ReturnsNull()
{
// Go: TestWSCompressNegotiation — no permessage-deflate in header.
WsDeflateNegotiator.Negotiate("x-webkit-deflate-frame").ShouldBeNull();
}
// ========================================================================
// WS Upgrade — JWT extraction (bearer, cookie, query parameter)
// Go reference: websocket_test.go:TestWSBasicAuth, TestWSBindToProperAccount
// ========================================================================
[Fact]
public async Task Upgrade_BearerJwt_ExtractedFromAuthHeader()
{
// Go: TestWSBasicAuth — JWT extracted from Authorization: Bearer header.
var request = BuildValidRequest(extraHeaders:
"Authorization: Bearer eyJhbGciOiJIUzI1NiJ9.test_jwt_token\r\n");
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true };
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeTrue();
result.Jwt.ShouldBe("eyJhbGciOiJIUzI1NiJ9.test_jwt_token");
}
[Fact]
public async Task Upgrade_CookieJwt_ExtractedFromCookie()
{
// Go: TestWSBindToProperAccount — JWT extracted from cookie when configured.
var request = BuildValidRequest(extraHeaders:
"Cookie: jwt=eyJhbGciOiJIUzI1NiJ9.cookie_jwt; other=value\r\n");
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt" };
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeTrue();
result.CookieJwt.ShouldBe("eyJhbGciOiJIUzI1NiJ9.cookie_jwt");
// Cookie JWT becomes fallback JWT
result.Jwt.ShouldBe("eyJhbGciOiJIUzI1NiJ9.cookie_jwt");
}
[Fact]
public async Task Upgrade_QueryJwt_ExtractedFromQueryParam()
{
// Go: JWT extracted from query parameter when no auth header or cookie.
var request = BuildValidRequest(
path: "/?jwt=eyJhbGciOiJIUzI1NiJ9.query_jwt");
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true };
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeTrue();
result.Jwt.ShouldBe("eyJhbGciOiJIUzI1NiJ9.query_jwt");
}
[Fact]
public async Task Upgrade_JwtPriority_BearerOverCookieOverQuery()
{
// Go: Authorization header takes priority over cookie and query.
var request = BuildValidRequest(
path: "/?jwt=query_token",
extraHeaders: "Authorization: Bearer bearer_token\r\nCookie: jwt=cookie_token\r\n");
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt" };
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeTrue();
result.Jwt.ShouldBe("bearer_token");
}
// ========================================================================
// TestWSXForwardedFor
// Go reference: websocket_test.go:TestWSXForwardedFor
// ========================================================================
[Fact]
public async Task Upgrade_XForwardedFor_ExtractsClientIp()
{
// Go: TestWSXForwardedFor — X-Forwarded-For header extracts first IP.
var request = BuildValidRequest(extraHeaders:
"X-Forwarded-For: 192.168.1.100, 10.0.0.1\r\n");
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true };
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeTrue();
result.ClientIp.ShouldBe("192.168.1.100");
}
// ========================================================================
// TestWSUpgradeValidationErrors
// Go reference: websocket_test.go:TestWSUpgradeValidationErrors
// ========================================================================
[Fact]
public async Task Upgrade_MissingHost_Fails()
{
// Go: TestWSUpgradeValidationErrors — missing Host header.
var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true };
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeFalse();
}
[Fact]
public async Task Upgrade_MissingUpgradeHeader_Fails()
{
// Go: TestWSUpgradeValidationErrors — missing Upgrade header.
var request = "GET / HTTP/1.1\r\nHost: localhost:4222\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true };
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeFalse();
}
[Fact]
public async Task Upgrade_MissingKey_Fails()
{
// Go: TestWSUpgradeValidationErrors — missing Sec-WebSocket-Key.
var request = "GET / HTTP/1.1\r\nHost: localhost:4222\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true };
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeFalse();
}
[Fact]
public async Task Upgrade_WrongVersion_Fails()
{
// Go: TestWSUpgradeValidationErrors — wrong WebSocket version.
var request = BuildValidRequest(versionOverride: "12");
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true };
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeFalse();
}
// ========================================================================
// TestWSSetHeader
// Go reference: websocket_test.go:TestWSSetHeader
// ========================================================================
[Fact]
public async Task Upgrade_CustomHeaders_IncludedInResponse()
{
// Go: TestWSSetHeader — custom headers added to upgrade response.
var request = BuildValidRequest();
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions
{
NoTls = true,
Headers = new Dictionary<string, string> { ["X-Custom"] = "test-value" },
};
await WsUpgrade.TryUpgradeAsync(input, output, opts);
var response = ReadResponse(output);
response.ShouldContain("X-Custom: test-value");
}
// ========================================================================
// TestWSWebrowserClient
// Go reference: websocket_test.go:TestWSWebrowserClient
// ========================================================================
[Fact]
public async Task Upgrade_BrowserUserAgent_DetectedAsBrowser()
{
// Go: TestWSWebrowserClient — Mozilla user-agent detected as browser.
var request = BuildValidRequest(extraHeaders:
"User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\r\n");
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true };
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeTrue();
result.Browser.ShouldBeTrue();
}
[Fact]
public async Task Upgrade_NonBrowserUserAgent_NotDetected()
{
// Go: non-browser user agent is not flagged.
var request = BuildValidRequest(extraHeaders:
"User-Agent: nats-client/1.0\r\n");
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true };
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeTrue();
result.Browser.ShouldBeFalse();
}
// ========================================================================
// TestWSCompressionBasic
// Go reference: websocket_test.go:TestWSCompressionBasic
// ========================================================================
[Fact]
public void Compression_RoundTrip()
{
// Go: TestWSCompressionBasic — compress then decompress returns original.
var original = "Hello, WebSocket compression test! This is a reasonably long string."u8.ToArray();
var compressed = WsCompression.Compress(original);
var decompressed = WsCompression.Decompress([compressed], maxPayload: 1024 * 1024);
decompressed.ShouldBe(original);
}
[Fact]
public void Compression_SmallData_StillWorks()
{
// Go: even very small data can be compressed/decompressed.
var original = "Hi"u8.ToArray();
var compressed = WsCompression.Compress(original);
var decompressed = WsCompression.Decompress([compressed], maxPayload: 1024);
decompressed.ShouldBe(original);
}
[Fact]
public void Compression_EmptyData()
{
var compressed = WsCompression.Compress(ReadOnlySpan<byte>.Empty);
var decompressed = WsCompression.Decompress([compressed], maxPayload: 1024);
decompressed.ShouldBeEmpty();
}
// ========================================================================
// TestWSDecompressLimit
// Go reference: websocket_test.go:TestWSDecompressLimit
// ========================================================================
[Fact]
public void Decompress_ExceedsMaxPayload_Throws()
{
// Go: TestWSDecompressLimit — decompressed data exceeding max payload throws.
// Create data larger than the limit
var large = new byte[10000];
for (int i = 0; i < large.Length; i++) large[i] = (byte)(i % 256);
var compressed = WsCompression.Compress(large);
Should.Throw<InvalidOperationException>(() =>
WsCompression.Decompress([compressed], maxPayload: 100));
}
// ========================================================================
// MaskBuf / MaskBufs
// Go reference: websocket_test.go TestWSFrameOutbound
// ========================================================================
[Fact]
public void MaskBuf_XorsInPlace()
{
// Go: TestWSFrameOutbound — masking XORs buffer with key.
var key = new byte[] { 0xAA, 0xBB, 0xCC, 0xDD };
var data = new byte[] { 0x01, 0x02, 0x03, 0x04, 0x05 };
var expected = new byte[] { 0x01 ^ 0xAA, 0x02 ^ 0xBB, 0x03 ^ 0xCC, 0x04 ^ 0xDD, 0x05 ^ 0xAA };
WsFrameWriter.MaskBuf(key, data);
data.ShouldBe(expected);
}
[Fact]
public void MaskBuf_DoubleApply_RestoresOriginal()
{
// Go: masking is its own inverse.
var key = new byte[] { 0x12, 0x34, 0x56, 0x78 };
var original = "Hello World"u8.ToArray();
var copy = original.ToArray();
WsFrameWriter.MaskBuf(key, copy);
copy.ShouldNotBe(original);
WsFrameWriter.MaskBuf(key, copy);
copy.ShouldBe(original);
}
// ========================================================================
// MapCloseStatus
// Go reference: websocket_test.go TestWSEnqueueCloseMsg
// ========================================================================
[Fact]
public void MapCloseStatus_ClientClosed_NormalClosure()
{
// Go: TestWSEnqueueCloseMsg — client-initiated close maps to 1000.
WsFrameWriter.MapCloseStatus(ClientClosedReason.ClientClosed)
.ShouldBe(WsConstants.CloseStatusNormalClosure);
}
[Fact]
public void MapCloseStatus_AuthViolation_PolicyViolation()
{
// Go: TestWSEnqueueCloseMsg — auth violation maps to 1008.
WsFrameWriter.MapCloseStatus(ClientClosedReason.AuthenticationViolation)
.ShouldBe(WsConstants.CloseStatusPolicyViolation);
}
[Fact]
public void MapCloseStatus_ProtocolError_ProtocolError()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.ProtocolViolation)
.ShouldBe(WsConstants.CloseStatusProtocolError);
}
[Fact]
public void MapCloseStatus_ServerShutdown_GoingAway()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.ServerShutdown)
.ShouldBe(WsConstants.CloseStatusGoingAway);
}
[Fact]
public void MapCloseStatus_MaxPayloadExceeded_MessageTooBig()
{
WsFrameWriter.MapCloseStatus(ClientClosedReason.MaxPayloadExceeded)
.ShouldBe(WsConstants.CloseStatusMessageTooBig);
}
// ========================================================================
// WsUpgrade.ComputeAcceptKey
// Go reference: websocket_test.go — RFC 6455 example
// ========================================================================
[Fact]
public void ComputeAcceptKey_Rfc6455Example()
{
// RFC 6455 Section 4.2.2 example
var accept = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
accept.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
}
// ========================================================================
// WsUpgrade — path-based client kind detection
// Go reference: websocket_test.go TestWSWebrowserClient
// ========================================================================
[Fact]
public async Task Upgrade_LeafNodePath_DetectedAsLeaf()
{
var request = BuildValidRequest(path: "/leafnode");
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true };
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Leaf);
}
[Fact]
public async Task Upgrade_MqttPath_DetectedAsMqtt()
{
var request = BuildValidRequest(path: "/mqtt");
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true };
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Mqtt);
}
[Fact]
public async Task Upgrade_RootPath_DetectedAsClient()
{
var request = BuildValidRequest(path: "/");
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true };
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Client);
}
// ========================================================================
// WsUpgrade — cookie extraction
// Go reference: websocket_test.go TestWSNoAuthUserValidation
// ========================================================================
[Fact]
public async Task Upgrade_Cookies_Extracted()
{
// Go: TestWSNoAuthUserValidation — username/password/token from cookies.
var request = BuildValidRequest(extraHeaders:
"Cookie: nats_user=admin; nats_pass=secret; nats_token=tok123\r\n");
var (input, output) = CreateStreamPair(request);
var opts = new WebSocketOptions
{
NoTls = true,
UsernameCookie = "nats_user",
PasswordCookie = "nats_pass",
TokenCookie = "nats_token",
};
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
result.Success.ShouldBeTrue();
result.CookieUsername.ShouldBe("admin");
result.CookiePassword.ShouldBe("secret");
result.CookieToken.ShouldBe("tok123");
}
// ========================================================================
// ExtractBearerToken
// Go reference: websocket_test.go — bearer token extraction
// ========================================================================
[Fact]
public void ExtractBearerToken_WithPrefix()
{
WsUpgrade.ExtractBearerToken("Bearer my-token").ShouldBe("my-token");
}
[Fact]
public void ExtractBearerToken_WithoutPrefix()
{
WsUpgrade.ExtractBearerToken("my-token").ShouldBe("my-token");
}
[Fact]
public void ExtractBearerToken_Empty_ReturnsNull()
{
WsUpgrade.ExtractBearerToken("").ShouldBeNull();
WsUpgrade.ExtractBearerToken(null).ShouldBeNull();
WsUpgrade.ExtractBearerToken(" ").ShouldBeNull();
}
// ========================================================================
// ParseQueryString
// Go reference: websocket_test.go — query parameter parsing
// ========================================================================
[Fact]
public void ParseQueryString_MultipleParams()
{
var result = WsUpgrade.ParseQueryString("?jwt=abc&user=admin&pass=secret");
result["jwt"].ShouldBe("abc");
result["user"].ShouldBe("admin");
result["pass"].ShouldBe("secret");
}
[Fact]
public void ParseQueryString_UrlEncoded()
{
var result = WsUpgrade.ParseQueryString("?key=hello%20world");
result["key"].ShouldBe("hello world");
}
[Fact]
public void ParseQueryString_NoQuestionMark()
{
var result = WsUpgrade.ParseQueryString("jwt=token123");
result["jwt"].ShouldBe("token123");
}
// ========================================================================
// Helpers
// ========================================================================
private static string BuildValidRequest(string path = "/", string? extraHeaders = null, string? versionOverride = null)
{
var sb = new StringBuilder();
sb.Append($"GET {path} HTTP/1.1\r\n");
sb.Append("Host: localhost:4222\r\n");
sb.Append("Upgrade: websocket\r\n");
sb.Append("Connection: Upgrade\r\n");
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
sb.Append($"Sec-WebSocket-Version: {versionOverride ?? "13"}\r\n");
if (extraHeaders != null)
sb.Append(extraHeaders);
sb.Append("\r\n");
return sb.ToString();
}
private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
{
var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
return (new MemoryStream(inputBytes), new MemoryStream());
}
private static string ReadResponse(MemoryStream output)
{
output.Position = 0;
return Encoding.ASCII.GetString(output.ToArray());
}
}

View File

@@ -0,0 +1,166 @@
using System.Buffers.Binary;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using NATS.Server.WebSocket;
namespace NATS.Server.Transport.Tests.WebSocket;
public class WsIntegrationTests : IAsyncLifetime
{
private NatsServer _server = null!;
private NatsOptions _options = null!;
public async Task InitializeAsync()
{
_options = new NatsOptions
{
Port = 0,
WebSocket = new WebSocketOptions { Port = 0, NoTls = true },
};
var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(b => { });
_server = new NatsServer(_options, loggerFactory);
_ = _server.StartAsync(CancellationToken.None);
await _server.WaitForReadyAsync();
}
public async Task DisposeAsync()
{
await _server.ShutdownAsync();
_server.Dispose();
}
[Fact]
public async Task WebSocket_ConnectAndReceiveInfo()
{
using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port));
using var stream = new NetworkStream(socket, ownsSocket: false);
await SendUpgradeRequest(stream);
var response = await ReadHttpResponse(stream);
response.ShouldContain("101");
var wsFrame = await ReadWsFrame(stream);
var info = Encoding.ASCII.GetString(wsFrame);
info.ShouldStartWith("INFO ");
}
[Fact]
public async Task WebSocket_ConnectAndPing()
{
using var client = await ConnectWsClient();
// Send CONNECT and PING together
await SendWsText(client, "CONNECT {}\r\nPING\r\n");
// Read PONG WS frame
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
var pong = await ReadWsFrameAsync(client, cts.Token);
Encoding.ASCII.GetString(pong).ShouldContain("PONG");
}
[Fact]
public async Task WebSocket_PubSub()
{
using var sub = await ConnectWsClient();
using var pub = await ConnectWsClient();
await SendWsText(sub, "CONNECT {}\r\nSUB test.ws 1\r\nPING\r\n");
// Wait for PONG to confirm subscription is registered
using var subCts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
var pong = await ReadWsFrameAsync(sub, subCts.Token);
Encoding.ASCII.GetString(pong).ShouldContain("PONG");
await SendWsText(pub, "CONNECT {}\r\nPUB test.ws 5\r\nHello\r\n");
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
var msg = await ReadWsFrameAsync(sub, cts.Token);
Encoding.ASCII.GetString(msg).ShouldContain("MSG test.ws 1 5");
}
private async Task<NetworkStream> ConnectWsClient()
{
var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port));
var stream = new NetworkStream(socket, ownsSocket: true);
await SendUpgradeRequest(stream);
var response = await ReadHttpResponse(stream);
response.ShouldContain("101");
await ReadWsFrame(stream); // Read INFO frame
return stream;
}
private static async Task SendUpgradeRequest(NetworkStream stream)
{
var keyBytes = new byte[16];
RandomNumberGenerator.Fill(keyBytes);
var key = Convert.ToBase64String(keyBytes);
var request = $"GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {key}\r\nSec-WebSocket-Version: 13\r\n\r\n";
await stream.WriteAsync(Encoding.ASCII.GetBytes(request));
await stream.FlushAsync();
}
private static async Task<string> ReadHttpResponse(NetworkStream stream)
{
// Read one byte at a time to avoid consuming WS frame bytes that follow the HTTP response
var sb = new StringBuilder();
var buf = new byte[1];
while (true)
{
int n = await stream.ReadAsync(buf);
if (n == 0) break;
sb.Append((char)buf[0]);
if (sb.Length >= 4 &&
sb[^4] == '\r' && sb[^3] == '\n' &&
sb[^2] == '\r' && sb[^1] == '\n')
break;
}
return sb.ToString();
}
private static Task<byte[]> ReadWsFrame(NetworkStream stream)
=> ReadWsFrameAsync(stream, CancellationToken.None);
private static async Task<byte[]> ReadWsFrameAsync(NetworkStream stream, CancellationToken ct)
{
var header = new byte[2];
await stream.ReadExactlyAsync(header, ct);
int len = header[1] & 0x7F;
if (len == 126)
{
var extLen = new byte[2];
await stream.ReadExactlyAsync(extLen, ct);
len = BinaryPrimitives.ReadUInt16BigEndian(extLen);
}
else if (len == 127)
{
var extLen = new byte[8];
await stream.ReadExactlyAsync(extLen, ct);
len = (int)BinaryPrimitives.ReadUInt64BigEndian(extLen);
}
var payload = new byte[len];
if (len > 0) await stream.ReadExactlyAsync(payload, ct);
return payload;
}
private static async Task SendWsText(NetworkStream stream, string text)
{
var payload = Encoding.ASCII.GetBytes(text);
var (header, _) = WsFrameWriter.CreateFrameHeader(
useMasking: true, compressed: false,
opcode: WsConstants.BinaryMessage, payloadLength: payload.Length);
var maskKey = header[^4..];
WsFrameWriter.MaskBuf(maskKey, payload);
await stream.WriteAsync(header);
await stream.WriteAsync(payload);
await stream.FlushAsync();
}
}

View File

@@ -0,0 +1,316 @@
// Tests for WebSocket JWT authentication during upgrade (E11).
// Verifies JWT extraction from Authorization header, cookie, and query parameter.
// Reference: golang/nats-server/server/websocket.go — cookie JWT extraction (line 856),
// websocket_test.go — TestWSReloadTLSConfig (line 4066).
using System.Text;
using NATS.Server.WebSocket;
namespace NATS.Server.Transport.Tests.WebSocket;
public class WsJwtAuthTests
{
// ─── Authorization header JWT extraction ─────────────────────────────
[Fact]
public async Task Upgrade_AuthorizationBearerHeader_ExtractsJwt()
{
// JWT from Authorization: Bearer <token> header (standard HTTP auth)
var jwt = "eyJhbGciOiJFZDI1NTE5IiwidHlwIjoiSldUIn0.test-payload.test-sig";
var request = BuildValidRequest(extraHeaders:
$"Authorization: Bearer {jwt}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Jwt.ShouldBe(jwt);
}
[Fact]
public async Task Upgrade_AuthorizationBearerCaseInsensitive()
{
// RFC 7235: "bearer" scheme is case-insensitive
var jwt = "my-jwt-token-123";
var request = BuildValidRequest(extraHeaders:
$"Authorization: bearer {jwt}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Jwt.ShouldBe(jwt);
}
[Fact]
public async Task Upgrade_AuthorizationBareToken_ExtractsJwt()
{
// Some clients send the token directly without "Bearer" prefix
var jwt = "raw-jwt-token-456";
var request = BuildValidRequest(extraHeaders:
$"Authorization: {jwt}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Jwt.ShouldBe(jwt);
}
// ─── Cookie JWT extraction ──────────────────────────────────────────
[Fact]
public async Task Upgrade_JwtCookie_ExtractsJwt()
{
// Go parity: websocket.go line 856 — JWT from configured cookie name
var jwt = "cookie-jwt-token-789";
var request = BuildValidRequest(extraHeaders:
$"Cookie: jwt={jwt}; other=value\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt" };
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
result.Success.ShouldBeTrue();
result.CookieJwt.ShouldBe(jwt);
// Cookie JWT is used as fallback when no Authorization header is present
result.Jwt.ShouldBe(jwt);
}
[Fact]
public async Task Upgrade_AuthorizationHeader_TakesPriorityOverCookie()
{
// Authorization header has higher priority than cookie
var headerJwt = "auth-header-jwt";
var cookieJwt = "cookie-jwt";
var request = BuildValidRequest(extraHeaders:
$"Authorization: Bearer {headerJwt}\r\n" +
$"Cookie: jwt={cookieJwt}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt" };
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
result.Success.ShouldBeTrue();
result.Jwt.ShouldBe(headerJwt);
result.CookieJwt.ShouldBe(cookieJwt); // Cookie value is still preserved
}
// ─── Query parameter JWT extraction ─────────────────────────────────
[Fact]
public async Task Upgrade_QueryParamJwt_ExtractsJwt()
{
// JWT from ?jwt= query parameter (useful for browser clients)
var jwt = "query-jwt-token-abc";
var request = BuildValidRequest(path: $"/?jwt={jwt}");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Jwt.ShouldBe(jwt);
}
[Fact]
public async Task Upgrade_QueryParamJwt_UrlEncoded()
{
// JWT value may be URL-encoded
var jwt = "eyJ0eXAiOiJKV1QifQ.payload.sig";
var encoded = Uri.EscapeDataString(jwt);
var request = BuildValidRequest(path: $"/?jwt={encoded}");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Jwt.ShouldBe(jwt);
}
[Fact]
public async Task Upgrade_AuthorizationHeader_TakesPriorityOverQueryParam()
{
// Authorization header > query parameter
var headerJwt = "auth-header-jwt";
var queryJwt = "query-jwt";
var request = BuildValidRequest(
path: $"/?jwt={queryJwt}",
extraHeaders: $"Authorization: Bearer {headerJwt}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Jwt.ShouldBe(headerJwt);
}
[Fact]
public async Task Upgrade_Cookie_TakesPriorityOverQueryParam()
{
// Cookie > query parameter
var cookieJwt = "cookie-jwt";
var queryJwt = "query-jwt";
var request = BuildValidRequest(
path: $"/?jwt={queryJwt}",
extraHeaders: $"Cookie: jwt_token={cookieJwt}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt_token" };
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
result.Success.ShouldBeTrue();
result.Jwt.ShouldBe(cookieJwt);
}
// ─── No JWT scenarios ───────────────────────────────────────────────
[Fact]
public async Task Upgrade_NoJwtAnywhere_JwtIsNull()
{
// No JWT in any source
var request = BuildValidRequest();
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Jwt.ShouldBeNull();
}
[Fact]
public async Task Upgrade_EmptyAuthorizationHeader_JwtIsEmpty()
{
// Empty authorization header should produce empty string (non-null)
var request = BuildValidRequest(extraHeaders: "Authorization: \r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
// Empty auth header is treated as null/no JWT
result.Jwt.ShouldBeNull();
}
// ─── ExtractBearerToken unit tests ──────────────────────────────────
[Fact]
public void ExtractBearerToken_BearerPrefix()
{
WsUpgrade.ExtractBearerToken("Bearer my-token").ShouldBe("my-token");
}
[Fact]
public void ExtractBearerToken_BearerPrefixLowerCase()
{
WsUpgrade.ExtractBearerToken("bearer my-token").ShouldBe("my-token");
}
[Fact]
public void ExtractBearerToken_BareToken()
{
WsUpgrade.ExtractBearerToken("raw-token").ShouldBe("raw-token");
}
[Fact]
public void ExtractBearerToken_Null()
{
WsUpgrade.ExtractBearerToken(null).ShouldBeNull();
}
[Fact]
public void ExtractBearerToken_Empty()
{
WsUpgrade.ExtractBearerToken("").ShouldBeNull();
}
[Fact]
public void ExtractBearerToken_Whitespace()
{
WsUpgrade.ExtractBearerToken(" ").ShouldBeNull();
}
// ─── ParseQueryString unit tests ────────────────────────────────────
[Fact]
public void ParseQueryString_SingleParam()
{
var result = WsUpgrade.ParseQueryString("?jwt=token123");
result["jwt"].ShouldBe("token123");
}
[Fact]
public void ParseQueryString_MultipleParams()
{
var result = WsUpgrade.ParseQueryString("?jwt=token&user=admin");
result["jwt"].ShouldBe("token");
result["user"].ShouldBe("admin");
}
[Fact]
public void ParseQueryString_UrlEncoded()
{
var result = WsUpgrade.ParseQueryString("?jwt=a%20b%3Dc");
result["jwt"].ShouldBe("a b=c");
}
[Fact]
public void ParseQueryString_NoQuestionMark()
{
var result = WsUpgrade.ParseQueryString("jwt=token");
result["jwt"].ShouldBe("token");
}
// ─── FailUnauthorizedAsync ──────────────────────────────────────────
[Fact]
public async Task FailUnauthorizedAsync_Returns401()
{
var output = new MemoryStream();
var result = await WsUpgrade.FailUnauthorizedAsync(output, "invalid JWT");
result.Success.ShouldBeFalse();
output.Position = 0;
var response = Encoding.ASCII.GetString(output.ToArray());
response.ShouldContain("401");
response.ShouldContain("invalid JWT");
}
// ─── Query param path routing still works with query strings ────────
[Fact]
public async Task Upgrade_PathWithQueryParam_StillRoutesCorrectly()
{
// /leafnode?jwt=token should still detect as leaf kind
var request = BuildValidRequest(path: "/leafnode?jwt=my-token");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Leaf);
result.Jwt.ShouldBe("my-token");
}
// ─── Helpers ─────────────────────────────────────────────────────────
private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
{
var sb = new StringBuilder();
sb.Append($"GET {path} HTTP/1.1\r\n");
sb.Append("Host: localhost:4222\r\n");
sb.Append("Upgrade: websocket\r\n");
sb.Append("Connection: Upgrade\r\n");
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
sb.Append("Sec-WebSocket-Version: 13\r\n");
if (extraHeaders != null)
sb.Append(extraHeaders);
sb.Append("\r\n");
return sb.ToString();
}
private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
{
var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
return (new MemoryStream(inputBytes), new MemoryStream());
}
}

View File

@@ -0,0 +1,82 @@
using NATS.Server.WebSocket;
using Shouldly;
namespace NATS.Server.Transport.Tests.WebSocket;
public class WsOriginCheckerTests
{
[Fact]
public void NoOriginHeader_Accepted()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin(origin: null, requestHost: "localhost:4222", isTls: false)
.ShouldBeNull();
}
[Fact]
public void NeitherSameNorList_AlwaysAccepted()
{
var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: null);
checker.CheckOrigin("https://evil.com", "localhost:4222", false)
.ShouldBeNull();
}
[Fact]
public void SameOrigin_Match()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin("http://localhost:4222", "localhost:4222", false)
.ShouldBeNull();
}
[Fact]
public void SameOrigin_Mismatch()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin("http://other:4222", "localhost:4222", false)
.ShouldNotBeNull();
}
[Fact]
public void SameOrigin_DefaultPort_Http()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin("http://localhost", "localhost:80", false)
.ShouldBeNull();
}
[Fact]
public void SameOrigin_DefaultPort_Https()
{
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
checker.CheckOrigin("https://localhost", "localhost:443", true)
.ShouldBeNull();
}
[Fact]
public void AllowedOrigins_Match()
{
var checker = new WsOriginChecker(sameOrigin: false,
allowedOrigins: ["https://app.example.com"]);
checker.CheckOrigin("https://app.example.com", "localhost:4222", false)
.ShouldBeNull();
}
[Fact]
public void AllowedOrigins_Mismatch()
{
var checker = new WsOriginChecker(sameOrigin: false,
allowedOrigins: ["https://app.example.com"]);
checker.CheckOrigin("https://evil.example.com", "localhost:4222", false)
.ShouldNotBeNull();
}
[Fact]
public void AllowedOrigins_SchemeMismatch()
{
var checker = new WsOriginChecker(sameOrigin: false,
allowedOrigins: ["https://app.example.com"]);
checker.CheckOrigin("http://app.example.com", "localhost:4222", false)
.ShouldNotBeNull();
}
}

View File

@@ -0,0 +1,66 @@
using System.Text;
using NATS.Server.WebSocket;
namespace NATS.Server.Transport.Tests.WebSocket;
public class WsUpgradeHelperParityBatch1Tests
{
[Fact]
public void MakeChallengeKey_returns_base64_of_16_random_bytes()
{
var key = WsUpgrade.MakeChallengeKey();
var decoded = Convert.FromBase64String(key);
decoded.Length.ShouldBe(16);
}
[Fact]
public void Url_helpers_match_ws_and_wss_schemes()
{
WsUpgrade.IsWsUrl("ws://localhost:8080").ShouldBeTrue();
WsUpgrade.IsWsUrl("wss://localhost:8443").ShouldBeFalse();
WsUpgrade.IsWsUrl("http://localhost").ShouldBeFalse();
WsUpgrade.IsWssUrl("wss://localhost:8443").ShouldBeTrue();
WsUpgrade.IsWssUrl("ws://localhost:8080").ShouldBeFalse();
WsUpgrade.IsWssUrl("https://localhost").ShouldBeFalse();
}
[Fact]
public async Task RejectNoMaskingForTest_forces_no_masking_handshake_rejection()
{
var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n");
using var input = new MemoryStream(Encoding.ASCII.GetBytes(request));
using var output = new MemoryStream();
try
{
WsUpgrade.RejectNoMaskingForTest = true;
var result = await WsUpgrade.TryUpgradeAsync(input, output, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
output.Position = 0;
var response = Encoding.ASCII.GetString(output.ToArray());
response.ShouldContain("400 Bad Request");
response.ShouldContain("invalid value for no-masking");
}
finally
{
WsUpgrade.RejectNoMaskingForTest = false;
}
}
private static string BuildValidRequest(string path = "/", string extraHeaders = "")
{
var sb = new StringBuilder();
sb.Append($"GET {path} HTTP/1.1\r\n");
sb.Append("Host: localhost:8080\r\n");
sb.Append("Upgrade: websocket\r\n");
sb.Append("Connection: Upgrade\r\n");
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
sb.Append("Sec-WebSocket-Version: 13\r\n");
sb.Append(extraHeaders);
sb.Append("\r\n");
return sb.ToString();
}
}

View File

@@ -0,0 +1,226 @@
using System.Text;
using NATS.Server.WebSocket;
namespace NATS.Server.Transport.Tests.WebSocket;
public class WsUpgradeTests
{
private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
{
var sb = new StringBuilder();
sb.Append($"GET {path} HTTP/1.1\r\n");
sb.Append("Host: localhost:4222\r\n");
sb.Append("Upgrade: websocket\r\n");
sb.Append("Connection: Upgrade\r\n");
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
sb.Append("Sec-WebSocket-Version: 13\r\n");
if (extraHeaders != null)
sb.Append(extraHeaders);
sb.Append("\r\n");
return sb.ToString();
}
[Fact]
public async Task ValidUpgrade_Returns101()
{
var request = BuildValidRequest();
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Client);
var response = ReadResponse(outputStream);
response.ShouldContain("HTTP/1.1 101");
response.ShouldContain("Upgrade: websocket");
response.ShouldContain("Sec-WebSocket-Accept:");
}
[Fact]
public async Task MissingUpgradeHeader_Returns400()
{
var request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
ReadResponse(outputStream).ShouldContain("400");
}
[Fact]
public async Task MissingHost_Returns400()
{
var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
}
[Fact]
public async Task WrongVersion_Returns400()
{
var request = "GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 12\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
}
[Fact]
public async Task LeafNodePath_ReturnsLeafKind()
{
var request = BuildValidRequest("/leafnode");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Leaf);
}
[Fact]
public async Task MqttPath_ReturnsMqttKind()
{
var request = BuildValidRequest("/mqtt");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Mqtt);
}
[Fact]
public async Task CompressionNegotiation_WhenEnabled()
{
var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}; {WsConstants.PmcSrvNoCtx}; {WsConstants.PmcCliNoCtx}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });
result.Success.ShouldBeTrue();
result.Compress.ShouldBeTrue();
ReadResponse(outputStream).ShouldContain("permessage-deflate");
}
[Fact]
public async Task CompressionNegotiation_WhenDisabled()
{
var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = false });
result.Success.ShouldBeTrue();
result.Compress.ShouldBeFalse();
}
[Fact]
public async Task NoMaskingHeader_ForLeaf()
{
var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.MaskRead.ShouldBeFalse();
}
[Fact]
public async Task BrowserDetection_Mozilla()
{
var request = BuildValidRequest(extraHeaders: "User-Agent: Mozilla/5.0 (Windows)\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Browser.ShouldBeTrue();
}
[Fact]
public async Task SafariDetection_NoCompFrag()
{
var request = BuildValidRequest(extraHeaders:
"User-Agent: Mozilla/5.0 (Macintosh) Version/15.0 Safari/605.1.15\r\n" +
$"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });
result.Success.ShouldBeTrue();
result.NoCompFrag.ShouldBeTrue();
}
[Fact]
public void AcceptKey_MatchesRfc6455Example()
{
// RFC 6455 Section 4.2.2 example
var key = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
key.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
}
[Fact]
public async Task CookieExtraction()
{
var request = BuildValidRequest(extraHeaders:
"Cookie: jwt_token=my-jwt; nats_user=admin; nats_pass=secret\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var opts = new WebSocketOptions
{
NoTls = true,
JwtCookie = "jwt_token",
UsernameCookie = "nats_user",
PasswordCookie = "nats_pass",
};
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
result.Success.ShouldBeTrue();
result.CookieJwt.ShouldBe("my-jwt");
result.CookieUsername.ShouldBe("admin");
result.CookiePassword.ShouldBe("secret");
}
[Fact]
public async Task XForwardedFor_ExtractsClientIp()
{
var request = BuildValidRequest(extraHeaders: "X-Forwarded-For: 192.168.1.100\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.ClientIp.ShouldBe("192.168.1.100");
}
[Fact]
public async Task PostMethod_Returns405()
{
var request = "POST / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
ReadResponse(outputStream).ShouldContain("405");
}
// Helper: create a readable input stream and writable output stream
private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
{
var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
return (new MemoryStream(inputBytes), new MemoryStream());
}
private static string ReadResponse(MemoryStream output)
{
output.Position = 0;
return Encoding.ASCII.GetString(output.ToArray());
}
}