refactor: extract NATS.Server.Transport.Tests project
Move TLS, OCSP, WebSocket, Networking, and IO test files from NATS.Server.Tests into a dedicated NATS.Server.Transport.Tests project. Update namespaces, replace private GetFreePort/ReadUntilAsync with shared TestUtilities helpers, extract TestCertHelper to TestUtilities, and replace Task.Delay polling loops with PollHelper.WaitUntilAsync/YieldForAsync for proper synchronization.
This commit is contained in:
@@ -0,0 +1,126 @@
|
||||
using NATS.Server.IO;
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.IO;
|
||||
|
||||
/// <summary>
|
||||
/// Tests for the consecutive short-read counter in AdaptiveReadBuffer.
|
||||
/// Go reference: server/client.go — readLoop buffer sizing with short-read counter.
|
||||
/// </summary>
|
||||
public class AdaptiveReadBufferShortReadTests
|
||||
{
|
||||
[Fact]
|
||||
public void Initial_size_is_4096()
|
||||
{
|
||||
var b = new AdaptiveReadBuffer();
|
||||
b.CurrentSize.ShouldBe(4096);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Full_read_doubles_size()
|
||||
{
|
||||
var b = new AdaptiveReadBuffer();
|
||||
b.RecordRead(4096);
|
||||
b.CurrentSize.ShouldBe(8192);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Single_short_read_does_not_shrink()
|
||||
{
|
||||
// A short read is less than target/4 = 4096/4 = 1024
|
||||
var b = new AdaptiveReadBuffer();
|
||||
b.RecordRead(100);
|
||||
b.CurrentSize.ShouldBe(4096);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Three_short_reads_do_not_shrink()
|
||||
{
|
||||
var b = new AdaptiveReadBuffer();
|
||||
b.RecordRead(100);
|
||||
b.RecordRead(100);
|
||||
b.RecordRead(100);
|
||||
b.CurrentSize.ShouldBe(4096);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Four_short_reads_triggers_shrink()
|
||||
{
|
||||
var b = new AdaptiveReadBuffer();
|
||||
b.RecordRead(100);
|
||||
b.RecordRead(100);
|
||||
b.RecordRead(100);
|
||||
b.RecordRead(100);
|
||||
b.CurrentSize.ShouldBe(2048);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Short_read_counter_resets_on_full_read()
|
||||
{
|
||||
// 3 short reads, then a full read resets the counter — subsequent short read should not shrink
|
||||
var b = new AdaptiveReadBuffer();
|
||||
b.RecordRead(100); // short
|
||||
b.RecordRead(100); // short
|
||||
b.RecordRead(100); // short (3 total, not yet at threshold)
|
||||
b.RecordRead(4096); // full read — doubles size and resets counter
|
||||
b.RecordRead(512); // short (relative to new size 8192; 512 < 8192/4=2048) — only 1 consecutive
|
||||
b.CurrentSize.ShouldBe(8192);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Short_read_counter_resets_on_medium_read()
|
||||
{
|
||||
// A medium read is >= target/4 but < target
|
||||
// For target 4096: medium range is [1024, 4096)
|
||||
var b = new AdaptiveReadBuffer();
|
||||
b.RecordRead(100); // short — counter = 1
|
||||
b.RecordRead(100); // short — counter = 2
|
||||
b.RecordRead(100); // short — counter = 3
|
||||
b.RecordRead(2000); // medium (>= 4096/4=1024, < 4096) — resets counter
|
||||
b.RecordRead(100); // short — counter = 1, should not shrink
|
||||
b.CurrentSize.ShouldBe(4096);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Short_read_counter_resets_after_shrink()
|
||||
{
|
||||
// After 4 short reads trigger a shrink, counter resets to 0
|
||||
var b = new AdaptiveReadBuffer();
|
||||
b.RecordRead(100); // short — counter = 1
|
||||
b.RecordRead(100); // short — counter = 2
|
||||
b.RecordRead(100); // short — counter = 3
|
||||
b.RecordRead(100); // short — counter = 4 → shrinks to 2048, resets counter to 0
|
||||
b.ConsecutiveShortReads.ShouldBe(0);
|
||||
// One more short read should be counter = 1 (not triggering another shrink)
|
||||
b.RecordRead(50); // short relative to 2048 (50 < 2048/4=512) — counter = 1
|
||||
b.ConsecutiveShortReads.ShouldBe(1);
|
||||
b.CurrentSize.ShouldBe(2048);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Size_never_goes_below_512()
|
||||
{
|
||||
// Force the buffer down to 512 then attempt to shrink further
|
||||
var b = new AdaptiveReadBuffer();
|
||||
|
||||
// Drive target down to 512 via repeated shrink cycles
|
||||
for (var i = 0; i < 4; i++) b.RecordRead(1); // target 4096 → 2048
|
||||
for (var i = 0; i < 4; i++) b.RecordRead(1); // 2048 → 1024
|
||||
for (var i = 0; i < 4; i++) b.RecordRead(1); // 1024 → 512
|
||||
|
||||
b.CurrentSize.ShouldBe(512);
|
||||
|
||||
// Now try to shrink again — should stay at 512
|
||||
for (var i = 0; i < 4; i++) b.RecordRead(1);
|
||||
b.CurrentSize.ShouldBe(512);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ConsecutiveShortReads_property_reflects_count()
|
||||
{
|
||||
var b = new AdaptiveReadBuffer();
|
||||
b.RecordRead(100); // short — counter = 1
|
||||
b.RecordRead(100); // short — counter = 2
|
||||
b.ConsecutiveShortReads.ShouldBe(2);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
using NATS.Server.IO;
|
||||
|
||||
namespace NATS.Server.Transport.Tests;
|
||||
|
||||
public class AdaptiveReadBufferTests
|
||||
{
|
||||
[Fact]
|
||||
public void Read_buffer_scales_between_512_and_65536_based_on_recent_payload_pattern()
|
||||
{
|
||||
var b = new AdaptiveReadBuffer();
|
||||
b.RecordRead(512);
|
||||
b.RecordRead(4096);
|
||||
b.RecordRead(32000);
|
||||
b.CurrentSize.ShouldBeGreaterThan(4096);
|
||||
b.CurrentSize.ShouldBeLessThanOrEqualTo(64 * 1024);
|
||||
}
|
||||
}
|
||||
179
tests/NATS.Server.Transport.Tests/IO/DynamicBufferPoolTests.cs
Normal file
179
tests/NATS.Server.Transport.Tests/IO/DynamicBufferPoolTests.cs
Normal file
@@ -0,0 +1,179 @@
|
||||
using System.Text;
|
||||
using NATS.Server.IO;
|
||||
using Shouldly;
|
||||
|
||||
// Go reference: client.go — dynamic buffer sizing and broadcast flush coalescing for fan-out.
|
||||
|
||||
namespace NATS.Server.Transport.Tests.IO;
|
||||
|
||||
public class DynamicBufferPoolTests
|
||||
{
|
||||
// -----------------------------------------------------------------------
|
||||
// Rent (IMemoryOwner<byte>)
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
[Fact]
|
||||
public void Rent_returns_buffer_of_requested_size_or_larger()
|
||||
{
|
||||
// Go ref: client.go — dynamic buffer sizing (512 → 65536).
|
||||
var pool = new OutboundBufferPool();
|
||||
using var owner = pool.Rent(100);
|
||||
owner.Memory.Length.ShouldBeGreaterThanOrEqualTo(100);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// RentBuffer — tier sizing
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
[Fact]
|
||||
public void RentBuffer_returns_small_buffer()
|
||||
{
|
||||
// Go ref: client.go — initial 512 B write buffer per connection.
|
||||
var pool = new OutboundBufferPool();
|
||||
var buf = pool.RentBuffer(100);
|
||||
buf.Length.ShouldBeGreaterThanOrEqualTo(512);
|
||||
pool.ReturnBuffer(buf);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void RentBuffer_returns_medium_buffer()
|
||||
{
|
||||
// Go ref: client.go — 4 KiB write buffer growth step.
|
||||
var pool = new OutboundBufferPool();
|
||||
var buf = pool.RentBuffer(1000);
|
||||
buf.Length.ShouldBeGreaterThanOrEqualTo(4096);
|
||||
pool.ReturnBuffer(buf);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void RentBuffer_returns_large_buffer()
|
||||
{
|
||||
// Go ref: client.go — max 64 KiB write buffer per connection.
|
||||
var pool = new OutboundBufferPool();
|
||||
var buf = pool.RentBuffer(10000);
|
||||
buf.Length.ShouldBeGreaterThanOrEqualTo(65536);
|
||||
pool.ReturnBuffer(buf);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// ReturnBuffer + reuse
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
[Fact]
|
||||
public void ReturnBuffer_and_reuse()
|
||||
{
|
||||
// Verifies that a returned buffer is available for reuse on the next
|
||||
// RentBuffer call of the same tier.
|
||||
// Go ref: client.go — buffer pooling to avoid GC pressure.
|
||||
var pool = new OutboundBufferPool();
|
||||
|
||||
var first = pool.RentBuffer(100); // small tier → 512 B
|
||||
first.Length.ShouldBe(512);
|
||||
pool.ReturnBuffer(first);
|
||||
|
||||
var second = pool.RentBuffer(100); // should reuse the returned buffer
|
||||
second.Length.ShouldBe(512);
|
||||
// ReferenceEquals confirms the exact same array instance was reused.
|
||||
ReferenceEquals(first, second).ShouldBeTrue();
|
||||
pool.ReturnBuffer(second);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// BroadcastDrain — coalescing
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
[Fact]
|
||||
public void BroadcastDrain_coalesces_writes()
|
||||
{
|
||||
// Go ref: client.go — broadcast flush for fan-out publish.
|
||||
var pool = new OutboundBufferPool();
|
||||
|
||||
var p1 = Encoding.UTF8.GetBytes("Hello");
|
||||
var p2 = Encoding.UTF8.GetBytes(", ");
|
||||
var p3 = Encoding.UTF8.GetBytes("World");
|
||||
|
||||
IReadOnlyList<ReadOnlyMemory<byte>> pending =
|
||||
[
|
||||
p1.AsMemory(),
|
||||
p2.AsMemory(),
|
||||
p3.AsMemory(),
|
||||
];
|
||||
|
||||
var dest = new byte[OutboundBufferPool.CalculateBroadcastSize(pending)];
|
||||
pool.BroadcastDrain(pending, dest);
|
||||
|
||||
Encoding.UTF8.GetString(dest).ShouldBe("Hello, World");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BroadcastDrain_returns_correct_byte_count()
|
||||
{
|
||||
// Go ref: client.go — total bytes written during coalesced drain.
|
||||
var pool = new OutboundBufferPool();
|
||||
|
||||
IReadOnlyList<ReadOnlyMemory<byte>> pending =
|
||||
[
|
||||
new byte[10].AsMemory(),
|
||||
new byte[20].AsMemory(),
|
||||
new byte[30].AsMemory(),
|
||||
];
|
||||
|
||||
var dest = new byte[60];
|
||||
var written = pool.BroadcastDrain(pending, dest);
|
||||
|
||||
written.ShouldBe(60);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// CalculateBroadcastSize
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
[Fact]
|
||||
public void CalculateBroadcastSize_sums_all_writes()
|
||||
{
|
||||
// Go ref: client.go — pre-check buffer capacity before coalesced drain.
|
||||
IReadOnlyList<ReadOnlyMemory<byte>> pending =
|
||||
[
|
||||
new byte[7].AsMemory(),
|
||||
new byte[13].AsMemory(),
|
||||
];
|
||||
|
||||
OutboundBufferPool.CalculateBroadcastSize(pending).ShouldBe(20);
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Stats counters
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
[Fact]
|
||||
public void RentCount_increments()
|
||||
{
|
||||
// Go ref: client.go — observability for buffer allocation rate.
|
||||
var pool = new OutboundBufferPool();
|
||||
|
||||
pool.RentCount.ShouldBe(0L);
|
||||
|
||||
using var _ = pool.Rent(100);
|
||||
pool.RentBuffer(200);
|
||||
|
||||
pool.RentCount.ShouldBe(2L);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BroadcastCount_increments()
|
||||
{
|
||||
// Go ref: client.go — observability for fan-out drain operations.
|
||||
var pool = new OutboundBufferPool();
|
||||
|
||||
pool.BroadcastCount.ShouldBe(0L);
|
||||
|
||||
IReadOnlyList<ReadOnlyMemory<byte>> pending = [new byte[4].AsMemory()];
|
||||
var dest = new byte[4];
|
||||
|
||||
pool.BroadcastDrain(pending, dest);
|
||||
pool.BroadcastDrain(pending, dest);
|
||||
pool.BroadcastDrain(pending, dest);
|
||||
|
||||
pool.BroadcastCount.ShouldBe(3L);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
using NATS.Server.IO;
|
||||
|
||||
namespace NATS.Server.Transport.Tests;
|
||||
|
||||
public class OutboundBufferPoolTests
|
||||
{
|
||||
[Theory]
|
||||
[InlineData(100, 512)]
|
||||
[InlineData(1000, 4096)]
|
||||
[InlineData(10000, 64 * 1024)]
|
||||
public void Rent_uses_three_tier_buffer_buckets(int requested, int expectedMinimum)
|
||||
{
|
||||
var pool = new OutboundBufferPool();
|
||||
using var owner = pool.Rent(requested);
|
||||
owner.Memory.Length.ShouldBeGreaterThanOrEqualTo(expectedMinimum);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
<PropertyGroup>
|
||||
<IsPackable>false</IsPackable>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="coverlet.collector" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" />
|
||||
<PackageReference Include="NATS.Client.Core" />
|
||||
<PackageReference Include="NSubstitute" />
|
||||
<PackageReference Include="Shouldly" />
|
||||
<PackageReference Include="xunit" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" />
|
||||
<PackageReference Include="Serilog.Sinks.File" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<Using Include="Xunit" />
|
||||
<Using Include="Shouldly" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\src\NATS.Server\NATS.Server.csproj" />
|
||||
<ProjectReference Include="..\NATS.Server.TestUtilities\NATS.Server.TestUtilities.csproj" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
File diff suppressed because it is too large
Load Diff
90
tests/NATS.Server.Transport.Tests/OcspConfigTests.cs
Normal file
90
tests/NATS.Server.Transport.Tests/OcspConfigTests.cs
Normal file
@@ -0,0 +1,90 @@
|
||||
using NATS.Server.Tls;
|
||||
|
||||
namespace NATS.Server.Transport.Tests;
|
||||
|
||||
public class OcspConfigTests
|
||||
{
|
||||
[Fact]
|
||||
public void OcspMode_Auto_has_value_zero()
|
||||
{
|
||||
((int)OcspMode.Auto).ShouldBe(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspMode_Always_has_value_one()
|
||||
{
|
||||
((int)OcspMode.Always).ShouldBe(1);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspMode_Must_has_value_two()
|
||||
{
|
||||
((int)OcspMode.Must).ShouldBe(2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspMode_Never_has_value_three()
|
||||
{
|
||||
((int)OcspMode.Never).ShouldBe(3);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspConfig_default_mode_is_Auto()
|
||||
{
|
||||
var config = new OcspConfig();
|
||||
config.Mode.ShouldBe(OcspMode.Auto);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspConfig_OverrideUrls_defaults_to_empty_array()
|
||||
{
|
||||
var config = new OcspConfig();
|
||||
config.OverrideUrls.ShouldNotBeNull();
|
||||
config.OverrideUrls.ShouldBeEmpty();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspConfig_Mode_can_be_set_via_init()
|
||||
{
|
||||
var config = new OcspConfig { Mode = OcspMode.Must };
|
||||
config.Mode.ShouldBe(OcspMode.Must);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspConfig_OverrideUrls_can_be_set_via_init()
|
||||
{
|
||||
var urls = new[] { "http://ocsp.example.com", "http://backup.example.com" };
|
||||
var config = new OcspConfig { OverrideUrls = urls };
|
||||
config.OverrideUrls.ShouldBe(urls);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void NatsOptions_OcspConfig_defaults_to_null()
|
||||
{
|
||||
var opts = new NatsOptions();
|
||||
opts.OcspConfig.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void NatsOptions_OcspPeerVerify_defaults_to_false()
|
||||
{
|
||||
var opts = new NatsOptions();
|
||||
opts.OcspPeerVerify.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void NatsOptions_OcspConfig_can_be_assigned()
|
||||
{
|
||||
var config = new OcspConfig { Mode = OcspMode.Always };
|
||||
var opts = new NatsOptions { OcspConfig = config };
|
||||
opts.OcspConfig.ShouldNotBeNull();
|
||||
opts.OcspConfig!.Mode.ShouldBe(OcspMode.Always);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void NatsOptions_OcspPeerVerify_can_be_set_to_true()
|
||||
{
|
||||
var opts = new NatsOptions { OcspPeerVerify = true };
|
||||
opts.OcspPeerVerify.ShouldBeTrue();
|
||||
}
|
||||
}
|
||||
97
tests/NATS.Server.Transport.Tests/OcspStaplingTests.cs
Normal file
97
tests/NATS.Server.Transport.Tests/OcspStaplingTests.cs
Normal file
@@ -0,0 +1,97 @@
|
||||
using NATS.Server.Tls;
|
||||
|
||||
namespace NATS.Server.Transport.Tests;
|
||||
|
||||
public class OcspStaplingTests
|
||||
{
|
||||
[Fact]
|
||||
public void OcspMode_Must_is_strictest()
|
||||
{
|
||||
var config = new OcspConfig { Mode = OcspMode.Must };
|
||||
config.Mode.ShouldBe(OcspMode.Must);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspMode_Never_disables_all()
|
||||
{
|
||||
var config = new OcspConfig { Mode = OcspMode.Never };
|
||||
config.Mode.ShouldBe(OcspMode.Never);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspPeerVerify_default_is_false()
|
||||
{
|
||||
var options = new NatsOptions();
|
||||
options.OcspPeerVerify.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspConfig_default_mode_is_Auto()
|
||||
{
|
||||
var config = new OcspConfig();
|
||||
config.Mode.ShouldBe(OcspMode.Auto);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspConfig_default_OverrideUrls_is_empty()
|
||||
{
|
||||
var config = new OcspConfig();
|
||||
config.OverrideUrls.ShouldBeEmpty();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildCertificateContext_returns_null_when_no_tls()
|
||||
{
|
||||
var options = new NatsOptions
|
||||
{
|
||||
OcspConfig = new OcspConfig { Mode = OcspMode.Always },
|
||||
};
|
||||
// HasTls is false because TlsCert and TlsKey are not set
|
||||
options.HasTls.ShouldBeFalse();
|
||||
var context = TlsHelper.BuildCertificateContext(options);
|
||||
context.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildCertificateContext_returns_null_when_mode_is_Never()
|
||||
{
|
||||
var options = new NatsOptions
|
||||
{
|
||||
TlsCert = "server.pem",
|
||||
TlsKey = "server-key.pem",
|
||||
OcspConfig = new OcspConfig { Mode = OcspMode.Never },
|
||||
};
|
||||
// OcspMode.Never must short-circuit even when TLS cert paths are set
|
||||
var context = TlsHelper.BuildCertificateContext(options);
|
||||
context.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildCertificateContext_returns_null_when_OcspConfig_is_null()
|
||||
{
|
||||
var options = new NatsOptions
|
||||
{
|
||||
TlsCert = "server.pem",
|
||||
TlsKey = "server-key.pem",
|
||||
OcspConfig = null,
|
||||
};
|
||||
var context = TlsHelper.BuildCertificateContext(options);
|
||||
context.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspPeerVerify_can_be_enabled()
|
||||
{
|
||||
var options = new NatsOptions { OcspPeerVerify = true };
|
||||
options.OcspPeerVerify.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspMode_values_have_correct_ordinals()
|
||||
{
|
||||
((int)OcspMode.Auto).ShouldBe(0);
|
||||
((int)OcspMode.Always).ShouldBe(1);
|
||||
((int)OcspMode.Must).ShouldBe(2);
|
||||
((int)OcspMode.Never).ShouldBe(3);
|
||||
}
|
||||
}
|
||||
255
tests/NATS.Server.Transport.Tests/TlsConnectionWrapperTests.cs
Normal file
255
tests/NATS.Server.Transport.Tests/TlsConnectionWrapperTests.cs
Normal file
@@ -0,0 +1,255 @@
|
||||
using System.Net;
|
||||
using System.Net.Security;
|
||||
using System.Net.Sockets;
|
||||
using System.Security.Cryptography;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using NATS.Server;
|
||||
using NATS.Server.Protocol;
|
||||
using NATS.Server.TestUtilities;
|
||||
using NATS.Server.Tls;
|
||||
|
||||
namespace NATS.Server.Transport.Tests;
|
||||
|
||||
public class TlsConnectionWrapperTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task NoTls_returns_plain_stream()
|
||||
{
|
||||
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||
using var serverStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||
using var clientStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||
|
||||
var opts = new NatsOptions(); // No TLS configured
|
||||
var serverInfo = CreateServerInfo();
|
||||
|
||||
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
|
||||
serverSocket, serverStream, opts, null, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||
|
||||
stream.ShouldBe(serverStream); // Same stream, no wrapping
|
||||
infoSent.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TlsRequired_upgrades_to_ssl()
|
||||
{
|
||||
var (cert, _) = TestCertHelper.GenerateTestCert();
|
||||
|
||||
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||
|
||||
var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy" };
|
||||
var sslOpts = new SslServerAuthenticationOptions
|
||||
{
|
||||
ServerCertificate = cert,
|
||||
};
|
||||
var serverInfo = CreateServerInfo();
|
||||
|
||||
// Client side: read INFO then start TLS
|
||||
var clientTask = Task.Run(async () =>
|
||||
{
|
||||
// Read INFO line
|
||||
var buf = new byte[4096];
|
||||
var read = await clientNetStream.ReadAsync(buf);
|
||||
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldStartWith("INFO ");
|
||||
|
||||
// Upgrade to TLS
|
||||
var sslClient = new SslStream(clientNetStream, true,
|
||||
(_, _, _, _) => true); // Trust all for testing
|
||||
await sslClient.AuthenticateAsClientAsync("localhost");
|
||||
return sslClient;
|
||||
});
|
||||
|
||||
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
|
||||
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||
|
||||
stream.ShouldBeOfType<SslStream>();
|
||||
infoSent.ShouldBeTrue();
|
||||
|
||||
var clientSsl = await clientTask;
|
||||
|
||||
// Verify encrypted communication works
|
||||
await stream.WriteAsync("PING\r\n"u8.ToArray());
|
||||
await stream.FlushAsync();
|
||||
|
||||
var readBuf = new byte[64];
|
||||
var bytesRead = await clientSsl.ReadAsync(readBuf);
|
||||
var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead);
|
||||
msg.ShouldBe("PING\r\n");
|
||||
|
||||
stream.Dispose();
|
||||
clientSsl.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MixedMode_allows_plaintext_when_AllowNonTls()
|
||||
{
|
||||
var (cert, _) = TestCertHelper.GenerateTestCert();
|
||||
|
||||
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||
|
||||
var opts = new NatsOptions
|
||||
{
|
||||
TlsCert = "dummy",
|
||||
TlsKey = "dummy",
|
||||
AllowNonTls = true,
|
||||
TlsTimeout = TimeSpan.FromSeconds(2),
|
||||
};
|
||||
var sslOpts = new SslServerAuthenticationOptions
|
||||
{
|
||||
ServerCertificate = cert,
|
||||
};
|
||||
var serverInfo = CreateServerInfo();
|
||||
|
||||
// Client side: read INFO then send plaintext (not TLS)
|
||||
var clientTask = Task.Run(async () =>
|
||||
{
|
||||
var buf = new byte[4096];
|
||||
var read = await clientNetStream.ReadAsync(buf);
|
||||
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldStartWith("INFO ");
|
||||
|
||||
// Send plaintext CONNECT (not a TLS handshake)
|
||||
var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n");
|
||||
await clientNetStream.WriteAsync(connectLine);
|
||||
await clientNetStream.FlushAsync();
|
||||
});
|
||||
|
||||
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
|
||||
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||
|
||||
await clientTask;
|
||||
|
||||
// In mixed mode with plaintext client, we get a PeekableStream, not SslStream
|
||||
stream.ShouldBeOfType<PeekableStream>();
|
||||
infoSent.ShouldBeTrue();
|
||||
|
||||
stream.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TlsRequired_rejects_plaintext()
|
||||
{
|
||||
var (cert, _) = TestCertHelper.GenerateTestCert();
|
||||
|
||||
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||
|
||||
var opts = new NatsOptions
|
||||
{
|
||||
TlsCert = "dummy",
|
||||
TlsKey = "dummy",
|
||||
AllowNonTls = false,
|
||||
TlsTimeout = TimeSpan.FromSeconds(2),
|
||||
};
|
||||
var sslOpts = new SslServerAuthenticationOptions
|
||||
{
|
||||
ServerCertificate = cert,
|
||||
};
|
||||
var serverInfo = CreateServerInfo();
|
||||
|
||||
// Client side: read INFO then send plaintext
|
||||
var clientTask = Task.Run(async () =>
|
||||
{
|
||||
var buf = new byte[4096];
|
||||
var read = await clientNetStream.ReadAsync(buf);
|
||||
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldStartWith("INFO ");
|
||||
|
||||
// Send plaintext data (first byte is 'C', not 0x16 TLS marker)
|
||||
var connectLine = System.Text.Encoding.ASCII.GetBytes("CONNECT {}\r\n");
|
||||
await clientNetStream.WriteAsync(connectLine);
|
||||
await clientNetStream.FlushAsync();
|
||||
});
|
||||
|
||||
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||
|
||||
await Should.ThrowAsync<InvalidOperationException>(async () =>
|
||||
{
|
||||
await TlsConnectionWrapper.NegotiateAsync(
|
||||
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||
});
|
||||
|
||||
await clientTask;
|
||||
serverNetStream.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TlsFirst_handshakes_before_sending_info()
|
||||
{
|
||||
var (cert, _) = TestCertHelper.GenerateTestCert();
|
||||
|
||||
var (serverSocket, clientSocket) = await CreateSocketPairAsync();
|
||||
using var clientNetStream = new NetworkStream(clientSocket, ownsSocket: true);
|
||||
|
||||
var opts = new NatsOptions { TlsCert = "dummy", TlsKey = "dummy", TlsHandshakeFirst = true };
|
||||
var sslOpts = new SslServerAuthenticationOptions
|
||||
{
|
||||
ServerCertificate = cert,
|
||||
};
|
||||
var serverInfo = CreateServerInfo();
|
||||
|
||||
// Client side: immediately start TLS (no INFO first)
|
||||
var clientTask = Task.Run(async () =>
|
||||
{
|
||||
var sslClient = new SslStream(clientNetStream, true, (_, _, _, _) => true);
|
||||
await sslClient.AuthenticateAsClientAsync("localhost");
|
||||
|
||||
// After TLS, read INFO over encrypted stream
|
||||
var buf = new byte[4096];
|
||||
var read = await sslClient.ReadAsync(buf);
|
||||
var info = System.Text.Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldStartWith("INFO ");
|
||||
|
||||
return sslClient;
|
||||
});
|
||||
|
||||
var serverNetStream = new NetworkStream(serverSocket, ownsSocket: true);
|
||||
var (stream, infoSent) = await TlsConnectionWrapper.NegotiateAsync(
|
||||
serverSocket, serverNetStream, opts, sslOpts, serverInfo, NullLogger.Instance, CancellationToken.None);
|
||||
|
||||
stream.ShouldBeOfType<SslStream>();
|
||||
infoSent.ShouldBeTrue();
|
||||
|
||||
var clientSsl = await clientTask;
|
||||
|
||||
// Verify encrypted communication works
|
||||
await stream.WriteAsync("PING\r\n"u8.ToArray());
|
||||
await stream.FlushAsync();
|
||||
|
||||
var readBuf = new byte[64];
|
||||
var bytesRead = await clientSsl.ReadAsync(readBuf);
|
||||
var msg = System.Text.Encoding.ASCII.GetString(readBuf, 0, bytesRead);
|
||||
msg.ShouldBe("PING\r\n");
|
||||
|
||||
stream.Dispose();
|
||||
clientSsl.Dispose();
|
||||
}
|
||||
|
||||
private static ServerInfo CreateServerInfo() => new()
|
||||
{
|
||||
ServerId = "TEST",
|
||||
ServerName = "test",
|
||||
Version = NatsProtocol.Version,
|
||||
Host = "127.0.0.1",
|
||||
Port = 4222,
|
||||
};
|
||||
|
||||
private static async Task<(Socket server, Socket client)> CreateSocketPairAsync()
|
||||
{
|
||||
using var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
|
||||
listener.Listen(1);
|
||||
var port = ((IPEndPoint)listener.LocalEndPoint!).Port;
|
||||
|
||||
var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, port));
|
||||
var server = await listener.AcceptAsync();
|
||||
|
||||
return (server, client);
|
||||
}
|
||||
}
|
||||
133
tests/NATS.Server.Transport.Tests/TlsHelperTests.cs
Normal file
133
tests/NATS.Server.Transport.Tests/TlsHelperTests.cs
Normal file
@@ -0,0 +1,133 @@
|
||||
using System.Net;
|
||||
using System.Security.Cryptography;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using NATS.Server;
|
||||
using NATS.Server.TestUtilities;
|
||||
using NATS.Server.Tls;
|
||||
|
||||
namespace NATS.Server.Transport.Tests;
|
||||
|
||||
public class TlsHelperTests
|
||||
{
|
||||
[Fact]
|
||||
public void LoadCertificate_loads_pem_cert_and_key()
|
||||
{
|
||||
var (certPath, keyPath) = TestCertHelper.GenerateTestCertFiles();
|
||||
try
|
||||
{
|
||||
var cert = TlsHelper.LoadCertificate(certPath, keyPath);
|
||||
cert.ShouldNotBeNull();
|
||||
cert.HasPrivateKey.ShouldBeTrue();
|
||||
}
|
||||
finally { File.Delete(certPath); File.Delete(keyPath); }
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildServerAuthOptions_creates_valid_options()
|
||||
{
|
||||
var (certPath, keyPath) = TestCertHelper.GenerateTestCertFiles();
|
||||
try
|
||||
{
|
||||
var opts = new NatsOptions { TlsCert = certPath, TlsKey = keyPath };
|
||||
var authOpts = TlsHelper.BuildServerAuthOptions(opts);
|
||||
authOpts.ShouldNotBeNull();
|
||||
authOpts.ServerCertificate.ShouldNotBeNull();
|
||||
}
|
||||
finally { File.Delete(certPath); File.Delete(keyPath); }
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void LoadCaCertificates_rejects_non_certificate_pem_block()
|
||||
{
|
||||
var (_, key) = TestCertHelper.GenerateTestCert();
|
||||
var pemPath = Path.GetTempFileName();
|
||||
try
|
||||
{
|
||||
File.WriteAllText(pemPath, key.ExportPkcs8PrivateKeyPem());
|
||||
Should.Throw<InvalidDataException>(() => TlsHelper.LoadCaCertificates(pemPath));
|
||||
}
|
||||
finally
|
||||
{
|
||||
File.Delete(pemPath);
|
||||
key.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void LoadCaCertificates_loads_multiple_certificate_blocks()
|
||||
{
|
||||
var (certA, keyA) = GenerateTestCert();
|
||||
var (certB, keyB) = GenerateTestCert();
|
||||
var pemPath = Path.GetTempFileName();
|
||||
try
|
||||
{
|
||||
File.WriteAllText(pemPath, certA.ExportCertificatePem() + certB.ExportCertificatePem());
|
||||
var collection = TlsHelper.LoadCaCertificates(pemPath);
|
||||
collection.Count.ShouldBe(2);
|
||||
}
|
||||
finally
|
||||
{
|
||||
File.Delete(pemPath);
|
||||
certA.Dispose();
|
||||
certB.Dispose();
|
||||
keyA.Dispose();
|
||||
keyB.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MatchesPinnedCert_matches_correct_hash()
|
||||
{
|
||||
var (cert, _) = GenerateTestCert();
|
||||
var hash = TlsHelper.GetCertificateHash(cert);
|
||||
var pinned = new HashSet<string> { hash };
|
||||
TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MatchesPinnedCert_rejects_wrong_hash()
|
||||
{
|
||||
var (cert, _) = GenerateTestCert();
|
||||
var pinned = new HashSet<string> { "0000000000000000000000000000000000000000000000000000000000000000" };
|
||||
TlsHelper.MatchesPinnedCert(cert, pinned).ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task PeekableStream_peeks_and_replays()
|
||||
{
|
||||
var data = "Hello, World!"u8.ToArray();
|
||||
using var ms = new MemoryStream(data);
|
||||
using var peekable = new PeekableStream(ms);
|
||||
|
||||
var peeked = await peekable.PeekAsync(1);
|
||||
peeked.Length.ShouldBe(1);
|
||||
peeked[0].ShouldBe((byte)'H');
|
||||
|
||||
var buf = new byte[data.Length];
|
||||
int total = 0;
|
||||
while (total < data.Length)
|
||||
{
|
||||
var read = await peekable.ReadAsync(buf.AsMemory(total));
|
||||
if (read == 0) break;
|
||||
total += read;
|
||||
}
|
||||
total.ShouldBe(data.Length);
|
||||
buf.ShouldBe(data);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task TlsRateLimiter_allows_within_limit()
|
||||
{
|
||||
using var limiter = new TlsRateLimiter(10);
|
||||
var cts = new CancellationTokenSource(TimeSpan.FromSeconds(2));
|
||||
for (int i = 0; i < 5; i++)
|
||||
await limiter.WaitAsync(cts.Token);
|
||||
}
|
||||
|
||||
// Delegate to shared TestCertHelper in TestUtilities
|
||||
public static (string certPath, string keyPath) GenerateTestCertFiles()
|
||||
=> TestCertHelper.GenerateTestCertFiles();
|
||||
|
||||
public static (X509Certificate2 cert, RSA key) GenerateTestCert()
|
||||
=> TestCertHelper.GenerateTestCert();
|
||||
}
|
||||
134
tests/NATS.Server.Transport.Tests/TlsMapAuthenticatorTests.cs
Normal file
134
tests/NATS.Server.Transport.Tests/TlsMapAuthenticatorTests.cs
Normal file
@@ -0,0 +1,134 @@
|
||||
using System.Security.Cryptography;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using NATS.Server.Auth;
|
||||
|
||||
namespace NATS.Server.Transport.Tests;
|
||||
|
||||
public class TlsMapAuthenticatorTests
|
||||
{
|
||||
private static X509Certificate2 CreateSelfSignedCert(string cn)
|
||||
{
|
||||
using var rsa = RSA.Create(2048);
|
||||
var req = new CertificateRequest($"CN={cn}", rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
|
||||
return req.CreateSelfSigned(DateTimeOffset.UtcNow, DateTimeOffset.UtcNow.AddYears(1));
|
||||
}
|
||||
|
||||
private static X509Certificate2 CreateCertWithDn(string dn)
|
||||
{
|
||||
using var rsa = RSA.Create(2048);
|
||||
var req = new CertificateRequest(dn, rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
|
||||
return req.CreateSelfSigned(DateTimeOffset.UtcNow, DateTimeOffset.UtcNow.AddYears(1));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Matches_user_by_cn()
|
||||
{
|
||||
var users = new List<User>
|
||||
{
|
||||
new() { Username = "alice", Password = "" },
|
||||
};
|
||||
var auth = new TlsMapAuthenticator(users);
|
||||
var cert = CreateSelfSignedCert("alice");
|
||||
|
||||
var ctx = new ClientAuthContext
|
||||
{
|
||||
Opts = new Protocol.ClientOptions(),
|
||||
Nonce = [],
|
||||
ClientCertificate = cert,
|
||||
};
|
||||
|
||||
var result = auth.Authenticate(ctx);
|
||||
result.ShouldNotBeNull();
|
||||
result.Identity.ShouldBe("alice");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Returns_null_when_no_cert()
|
||||
{
|
||||
var users = new List<User>
|
||||
{
|
||||
new() { Username = "alice", Password = "" },
|
||||
};
|
||||
var auth = new TlsMapAuthenticator(users);
|
||||
|
||||
var ctx = new ClientAuthContext
|
||||
{
|
||||
Opts = new Protocol.ClientOptions(),
|
||||
Nonce = [],
|
||||
ClientCertificate = null,
|
||||
};
|
||||
|
||||
var result = auth.Authenticate(ctx);
|
||||
result.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Returns_null_when_cn_doesnt_match()
|
||||
{
|
||||
var users = new List<User>
|
||||
{
|
||||
new() { Username = "alice", Password = "" },
|
||||
};
|
||||
var auth = new TlsMapAuthenticator(users);
|
||||
var cert = CreateSelfSignedCert("bob");
|
||||
|
||||
var ctx = new ClientAuthContext
|
||||
{
|
||||
Opts = new Protocol.ClientOptions(),
|
||||
Nonce = [],
|
||||
ClientCertificate = cert,
|
||||
};
|
||||
|
||||
var result = auth.Authenticate(ctx);
|
||||
result.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Matches_by_full_dn_string()
|
||||
{
|
||||
var users = new List<User>
|
||||
{
|
||||
new() { Username = "CN=alice, O=TestOrg", Password = "" },
|
||||
};
|
||||
var auth = new TlsMapAuthenticator(users);
|
||||
var cert = CreateCertWithDn("CN=alice, O=TestOrg");
|
||||
|
||||
var ctx = new ClientAuthContext
|
||||
{
|
||||
Opts = new Protocol.ClientOptions(),
|
||||
Nonce = [],
|
||||
ClientCertificate = cert,
|
||||
};
|
||||
|
||||
var result = auth.Authenticate(ctx);
|
||||
result.ShouldNotBeNull();
|
||||
result.Identity.ShouldBe("CN=alice, O=TestOrg");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Returns_permissions_from_matched_user()
|
||||
{
|
||||
var perms = new Permissions
|
||||
{
|
||||
Publish = new SubjectPermission { Allow = ["foo.>"] },
|
||||
};
|
||||
var users = new List<User>
|
||||
{
|
||||
new() { Username = "alice", Password = "", Permissions = perms },
|
||||
};
|
||||
var auth = new TlsMapAuthenticator(users);
|
||||
var cert = CreateSelfSignedCert("alice");
|
||||
|
||||
var ctx = new ClientAuthContext
|
||||
{
|
||||
Opts = new Protocol.ClientOptions(),
|
||||
Nonce = [],
|
||||
ClientCertificate = cert,
|
||||
};
|
||||
|
||||
var result = auth.Authenticate(ctx);
|
||||
result.ShouldNotBeNull();
|
||||
result.Permissions.ShouldNotBeNull();
|
||||
result.Permissions.Publish!.Allow!.ShouldContain("foo.>");
|
||||
}
|
||||
}
|
||||
134
tests/NATS.Server.Transport.Tests/TlsOcspParityBatch1Tests.cs
Normal file
134
tests/NATS.Server.Transport.Tests/TlsOcspParityBatch1Tests.cs
Normal file
@@ -0,0 +1,134 @@
|
||||
using System.Security.Cryptography;
|
||||
using System.Text.Json;
|
||||
using NATS.Server.Configuration;
|
||||
using NATS.Server.TestUtilities;
|
||||
using NATS.Server.Tls;
|
||||
|
||||
namespace NATS.Server.Transport.Tests;
|
||||
|
||||
public class TlsOcspParityBatch1Tests
|
||||
{
|
||||
[Fact]
|
||||
public void OCSPPeerConfig_defaults_match_go_reference()
|
||||
{
|
||||
var cfg = OCSPPeerConfig.NewOCSPPeerConfig();
|
||||
|
||||
cfg.Verify.ShouldBeFalse();
|
||||
cfg.Timeout.ShouldBe(2d);
|
||||
cfg.ClockSkew.ShouldBe(30d);
|
||||
cfg.WarnOnly.ShouldBeFalse();
|
||||
cfg.UnknownIsGood.ShouldBeFalse();
|
||||
cfg.AllowWhenCAUnreachable.ShouldBeFalse();
|
||||
cfg.TTLUnsetNextUpdate.ShouldBe(3600d);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OCSPPeerConfig_parse_map_parses_supported_fields()
|
||||
{
|
||||
var cfg = OCSPPeerConfig.Parse(new Dictionary<string, object?>
|
||||
{
|
||||
["verify"] = true,
|
||||
["allowed_clockskew"] = "45s",
|
||||
["ca_timeout"] = 1.5d,
|
||||
["cache_ttl_when_next_update_unset"] = 120L,
|
||||
["warn_only"] = true,
|
||||
["unknown_is_good"] = true,
|
||||
["allow_when_ca_unreachable"] = true,
|
||||
});
|
||||
|
||||
cfg.Verify.ShouldBeTrue();
|
||||
cfg.ClockSkew.ShouldBe(45d);
|
||||
cfg.Timeout.ShouldBe(1.5d);
|
||||
cfg.TTLUnsetNextUpdate.ShouldBe(120d);
|
||||
cfg.WarnOnly.ShouldBeTrue();
|
||||
cfg.UnknownIsGood.ShouldBeTrue();
|
||||
cfg.AllowWhenCAUnreachable.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OCSPPeerConfig_parse_unknown_field_throws()
|
||||
{
|
||||
var ex = Should.Throw<FormatException>(() =>
|
||||
OCSPPeerConfig.Parse(new Dictionary<string, object?> { ["bogus"] = true }));
|
||||
|
||||
ex.Message.ShouldContain("unknown field [bogus]");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ConfigProcessor_parses_ocsp_peer_short_form()
|
||||
{
|
||||
var opts = ConfigProcessor.ProcessConfig("""
|
||||
tls {
|
||||
ocsp_peer: true
|
||||
}
|
||||
""");
|
||||
|
||||
opts.OcspPeerVerify.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ConfigProcessor_parses_ocsp_peer_long_form_verify()
|
||||
{
|
||||
var opts = ConfigProcessor.ProcessConfig("""
|
||||
tls {
|
||||
ocsp_peer {
|
||||
verify: true
|
||||
ca_timeout: 2s
|
||||
allowed_clockskew: 30s
|
||||
}
|
||||
}
|
||||
""");
|
||||
|
||||
opts.OcspPeerVerify.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GenerateFingerprint_uses_raw_certificate_sha256()
|
||||
{
|
||||
var (cert, _) = TestCertHelper.GenerateTestCert();
|
||||
|
||||
var expected = Convert.ToBase64String(SHA256.HashData(cert.RawData));
|
||||
TlsHelper.GenerateFingerprint(cert).ShouldBe(expected);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetWebEndpoints_filters_non_web_uris()
|
||||
{
|
||||
var urls = TlsHelper.GetWebEndpoints(
|
||||
["http://a.example", "https://b.example", "ftp://bad.example", "not a uri"]);
|
||||
|
||||
urls.Count.ShouldBe(2);
|
||||
urls[0].Scheme.ShouldBe(Uri.UriSchemeHttp);
|
||||
urls[1].Scheme.ShouldBe(Uri.UriSchemeHttps);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Subject_and_issuer_dn_helpers_return_values_and_empty_for_null()
|
||||
{
|
||||
var (cert, _) = TestCertHelper.GenerateTestCert();
|
||||
|
||||
TlsHelper.GetSubjectDNForm(cert).ShouldNotBeNullOrWhiteSpace();
|
||||
TlsHelper.GetIssuerDNForm(cert).ShouldNotBeNullOrWhiteSpace();
|
||||
TlsHelper.GetSubjectDNForm(null).ShouldBe(string.Empty);
|
||||
TlsHelper.GetIssuerDNForm(null).ShouldBe(string.Empty);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void StatusAssertion_json_converter_uses_string_values_and_unknown_fallback()
|
||||
{
|
||||
var revokedJson = JsonSerializer.Serialize(StatusAssertion.Revoked);
|
||||
revokedJson.ShouldBe("\"revoked\"");
|
||||
|
||||
var unknown = JsonSerializer.Deserialize<StatusAssertion>("\"nonsense\"");
|
||||
unknown.ShouldBe(StatusAssertion.Unknown);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspPeer_messages_match_go_literals()
|
||||
{
|
||||
OcspPeerMessages.MsgTLSClientRejectConnection.ShouldBe("client not OCSP valid");
|
||||
OcspPeerMessages.MsgTLSServerRejectConnection.ShouldBe("server not OCSP valid");
|
||||
OcspPeerMessages.MsgCacheOnline.ShouldBe("OCSP peer cache online, type [%s]");
|
||||
OcspPeerMessages.MsgCacheOffline.ShouldBe("OCSP peer cache offline, type [%s]");
|
||||
}
|
||||
}
|
||||
166
tests/NATS.Server.Transport.Tests/TlsOcspParityBatch2Tests.cs
Normal file
166
tests/NATS.Server.Transport.Tests/TlsOcspParityBatch2Tests.cs
Normal file
@@ -0,0 +1,166 @@
|
||||
using System.Formats.Asn1;
|
||||
using System.Security.Cryptography;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using NATS.Server.TestUtilities;
|
||||
using NATS.Server.Tls;
|
||||
|
||||
namespace NATS.Server.Transport.Tests;
|
||||
|
||||
public class TlsOcspParityBatch2Tests
|
||||
{
|
||||
[Fact]
|
||||
public void CertOCSPEligible_returns_true_and_populates_endpoints_for_http_ocsp_aia()
|
||||
{
|
||||
using var cert = CreateLeafWithOcspAia("http://ocsp.example.test");
|
||||
var link = new ChainLink { Leaf = cert };
|
||||
|
||||
var eligible = TlsHelper.CertOCSPEligible(link);
|
||||
|
||||
eligible.ShouldBeTrue();
|
||||
link.OCSPWebEndpoints.ShouldNotBeNull();
|
||||
link.OCSPWebEndpoints!.Count.ShouldBe(1);
|
||||
link.OCSPWebEndpoints[0].ToString().ShouldBe("http://ocsp.example.test/");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CertOCSPEligible_returns_false_when_leaf_has_no_ocsp_servers()
|
||||
{
|
||||
var (leaf, _) = TestCertHelper.GenerateTestCert();
|
||||
var link = new ChainLink { Leaf = leaf };
|
||||
|
||||
TlsHelper.CertOCSPEligible(link).ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetLeafIssuerCert_returns_positional_issuer_or_null()
|
||||
{
|
||||
using var root = CreateRootCertificate();
|
||||
using var leaf = CreateLeafSignedBy(root);
|
||||
|
||||
var chain = new[] { leaf, root };
|
||||
TlsHelper.GetLeafIssuerCert(chain, 0).ShouldBe(root);
|
||||
TlsHelper.GetLeafIssuerCert(chain, 1).ShouldBeNull();
|
||||
TlsHelper.GetLeafIssuerCert(chain, -1).ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetLeafIssuer_returns_verified_issuer_from_chain()
|
||||
{
|
||||
using var root = CreateRootCertificate();
|
||||
using var leaf = CreateLeafSignedBy(root);
|
||||
|
||||
using var issuer = TlsHelper.GetLeafIssuer(leaf, root);
|
||||
|
||||
issuer.ShouldNotBeNull();
|
||||
issuer!.Thumbprint.ShouldBe(root.Thumbprint);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspResponseCurrent_applies_skew_and_ttl_rules()
|
||||
{
|
||||
var opts = OCSPPeerConfig.NewOCSPPeerConfig();
|
||||
var now = DateTime.UtcNow;
|
||||
|
||||
TlsHelper.OcspResponseCurrent(new OcspResponseInfo
|
||||
{
|
||||
ThisUpdate = now.AddMinutes(-1),
|
||||
NextUpdate = now.AddMinutes(5),
|
||||
}, opts).ShouldBeTrue();
|
||||
|
||||
TlsHelper.OcspResponseCurrent(new OcspResponseInfo
|
||||
{
|
||||
ThisUpdate = now.AddHours(-2),
|
||||
NextUpdate = null,
|
||||
}, opts).ShouldBeFalse();
|
||||
|
||||
TlsHelper.OcspResponseCurrent(new OcspResponseInfo
|
||||
{
|
||||
ThisUpdate = now.AddMinutes(2),
|
||||
NextUpdate = now.AddHours(1),
|
||||
}, opts).ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ValidDelegationCheck_accepts_direct_and_ocsp_signing_delegate()
|
||||
{
|
||||
using var issuer = CreateRootCertificate();
|
||||
using var delegateCert = CreateOcspSigningDelegate(issuer);
|
||||
|
||||
TlsHelper.ValidDelegationCheck(issuer, null).ShouldBeTrue();
|
||||
TlsHelper.ValidDelegationCheck(issuer, issuer).ShouldBeTrue();
|
||||
TlsHelper.ValidDelegationCheck(issuer, delegateCert).ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OcspPeerMessages_exposes_error_and_debug_constants()
|
||||
{
|
||||
OcspPeerMessages.ErrIllegalPeerOptsConfig.ShouldContain("expected map to define OCSP peer options");
|
||||
OcspPeerMessages.ErrNoAvailOCSPServers.ShouldBe("no available OCSP servers");
|
||||
OcspPeerMessages.DbgPlugTLSForKind.ShouldBe("Plugging TLS OCSP peer for [%s]");
|
||||
OcspPeerMessages.DbgCacheSaved.ShouldBe("Saved OCSP peer cache successfully (%d bytes)");
|
||||
OcspPeerMessages.MsgFailedOCSPResponseFetch.ShouldBe("Failed OCSP response fetch");
|
||||
}
|
||||
|
||||
private static X509Certificate2 CreateLeafWithOcspAia(string ocspUri)
|
||||
{
|
||||
using var key = RSA.Create(2048);
|
||||
var req = new CertificateRequest("CN=leaf-with-ocsp", key, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
|
||||
req.CertificateExtensions.Add(new X509BasicConstraintsExtension(false, false, 0, false));
|
||||
req.CertificateExtensions.Add(CreateOcspAiaExtension(ocspUri));
|
||||
return req.CreateSelfSigned(DateTimeOffset.UtcNow.AddDays(-1), DateTimeOffset.UtcNow.AddDays(30));
|
||||
}
|
||||
|
||||
private static X509Extension CreateOcspAiaExtension(string ocspUri)
|
||||
{
|
||||
var writer = new AsnWriter(AsnEncodingRules.DER);
|
||||
writer.PushSequence();
|
||||
writer.PushSequence();
|
||||
writer.WriteObjectIdentifier("1.3.6.1.5.5.7.48.1");
|
||||
writer.WriteCharacterString(UniversalTagNumber.IA5String, ocspUri, new Asn1Tag(TagClass.ContextSpecific, 6));
|
||||
writer.PopSequence();
|
||||
writer.PopSequence();
|
||||
return new X509Extension("1.3.6.1.5.5.7.1.1", writer.Encode(), false);
|
||||
}
|
||||
|
||||
private static X509Certificate2 CreateRootCertificate()
|
||||
{
|
||||
using var rootKey = RSA.Create(2048);
|
||||
var req = new CertificateRequest("CN=Root", rootKey, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
|
||||
req.CertificateExtensions.Add(new X509BasicConstraintsExtension(true, false, 0, true));
|
||||
req.CertificateExtensions.Add(new X509KeyUsageExtension(X509KeyUsageFlags.KeyCertSign | X509KeyUsageFlags.CrlSign, true));
|
||||
return req.CreateSelfSigned(DateTimeOffset.UtcNow.AddDays(-1), DateTimeOffset.UtcNow.AddYears(5));
|
||||
}
|
||||
|
||||
private static X509Certificate2 CreateLeafSignedBy(X509Certificate2 issuer)
|
||||
{
|
||||
using var leafKey = RSA.Create(2048);
|
||||
var req = new CertificateRequest("CN=Leaf", leafKey, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
|
||||
req.CertificateExtensions.Add(new X509BasicConstraintsExtension(false, false, 0, true));
|
||||
req.CertificateExtensions.Add(new X509SubjectKeyIdentifierExtension(req.PublicKey, false));
|
||||
|
||||
var cert = req.Create(
|
||||
issuer,
|
||||
DateTimeOffset.UtcNow.AddDays(-1),
|
||||
DateTimeOffset.UtcNow.AddYears(1),
|
||||
Guid.NewGuid().ToByteArray());
|
||||
|
||||
return cert.CopyWithPrivateKey(leafKey);
|
||||
}
|
||||
|
||||
private static X509Certificate2 CreateOcspSigningDelegate(X509Certificate2 issuer)
|
||||
{
|
||||
using var key = RSA.Create(2048);
|
||||
var req = new CertificateRequest("CN=OCSP Delegate", key, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
|
||||
req.CertificateExtensions.Add(new X509BasicConstraintsExtension(false, false, 0, true));
|
||||
req.CertificateExtensions.Add(new X509EnhancedKeyUsageExtension(
|
||||
[new Oid("1.3.6.1.5.5.7.3.9")], true));
|
||||
|
||||
var cert = req.Create(
|
||||
issuer,
|
||||
DateTimeOffset.UtcNow.AddDays(-1),
|
||||
DateTimeOffset.UtcNow.AddYears(1),
|
||||
Guid.NewGuid().ToByteArray());
|
||||
|
||||
return cert.CopyWithPrivateKey(key);
|
||||
}
|
||||
}
|
||||
50
tests/NATS.Server.Transport.Tests/TlsRateLimiterTests.cs
Normal file
50
tests/NATS.Server.Transport.Tests/TlsRateLimiterTests.cs
Normal file
@@ -0,0 +1,50 @@
|
||||
using NATS.Server.TestUtilities;
|
||||
using NATS.Server.Tls;
|
||||
|
||||
namespace NATS.Server.Transport.Tests;
|
||||
|
||||
public class TlsRateLimiterTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task Rate_limiter_allows_configured_tokens_per_second()
|
||||
{
|
||||
using var limiter = new TlsRateLimiter(5);
|
||||
|
||||
// Should allow 5 tokens immediately
|
||||
for (int i = 0; i < 5; i++)
|
||||
{
|
||||
using var cts = new CancellationTokenSource(100);
|
||||
await limiter.WaitAsync(cts.Token); // Should not throw
|
||||
}
|
||||
|
||||
// 6th token should block (no refill yet)
|
||||
using var blockCts = new CancellationTokenSource(200);
|
||||
var blocked = false;
|
||||
try
|
||||
{
|
||||
await limiter.WaitAsync(blockCts.Token);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
blocked = true;
|
||||
}
|
||||
blocked.ShouldBeTrue("6th token should be blocked before refill");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Rate_limiter_refills_after_one_second()
|
||||
{
|
||||
using var limiter = new TlsRateLimiter(2);
|
||||
|
||||
// Consume all tokens
|
||||
await limiter.WaitAsync(CancellationToken.None);
|
||||
await limiter.WaitAsync(CancellationToken.None);
|
||||
|
||||
// Wait for refill (rate limiter refills tokens after 1 second)
|
||||
await PollHelper.YieldForAsync(1200);
|
||||
|
||||
// Should have tokens again
|
||||
using var cts = new CancellationTokenSource(200);
|
||||
await limiter.WaitAsync(cts.Token); // Should not throw
|
||||
}
|
||||
}
|
||||
198
tests/NATS.Server.Transport.Tests/TlsServerTests.cs
Normal file
198
tests/NATS.Server.Transport.Tests/TlsServerTests.cs
Normal file
@@ -0,0 +1,198 @@
|
||||
using System.Net;
|
||||
using System.Net.Security;
|
||||
using System.Net.Sockets;
|
||||
using System.Text;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using NATS.Server;
|
||||
using NATS.Server.TestUtilities;
|
||||
|
||||
namespace NATS.Server.Transport.Tests;
|
||||
|
||||
public class TlsServerTests : IAsyncLifetime
|
||||
{
|
||||
private readonly NatsServer _server;
|
||||
private readonly int _port;
|
||||
private readonly CancellationTokenSource _cts = new();
|
||||
private readonly string _certPath;
|
||||
private readonly string _keyPath;
|
||||
|
||||
public TlsServerTests()
|
||||
{
|
||||
_port = TestPortAllocator.GetFreePort();
|
||||
(_certPath, _keyPath) = TestCertHelper.GenerateTestCertFiles();
|
||||
_server = new NatsServer(
|
||||
new NatsOptions
|
||||
{
|
||||
Port = _port,
|
||||
TlsCert = _certPath,
|
||||
TlsKey = _keyPath,
|
||||
},
|
||||
NullLoggerFactory.Instance);
|
||||
}
|
||||
|
||||
public async Task InitializeAsync()
|
||||
{
|
||||
_ = _server.StartAsync(_cts.Token);
|
||||
await _server.WaitForReadyAsync();
|
||||
}
|
||||
|
||||
public async Task DisposeAsync()
|
||||
{
|
||||
await _cts.CancelAsync();
|
||||
_server.Dispose();
|
||||
File.Delete(_certPath);
|
||||
File.Delete(_keyPath);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Tls_client_connects_and_receives_info()
|
||||
{
|
||||
using var tcp = new TcpClient();
|
||||
await tcp.ConnectAsync(IPAddress.Loopback, _port);
|
||||
using var netStream = tcp.GetStream();
|
||||
|
||||
// Read INFO (sent before TLS upgrade in Mode 2)
|
||||
var buf = new byte[4096];
|
||||
var read = await netStream.ReadAsync(buf);
|
||||
var info = Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldStartWith("INFO ");
|
||||
info.ShouldContain("\"tls_required\":true");
|
||||
|
||||
// Upgrade to TLS
|
||||
using var sslStream = new SslStream(netStream, false, (_, _, _, _) => true);
|
||||
await sslStream.AuthenticateAsClientAsync("localhost");
|
||||
|
||||
// Send CONNECT + PING over TLS
|
||||
await sslStream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
|
||||
await sslStream.FlushAsync();
|
||||
|
||||
// Read PONG
|
||||
var pongBuf = new byte[256];
|
||||
read = await sslStream.ReadAsync(pongBuf);
|
||||
var pong = Encoding.ASCII.GetString(pongBuf, 0, read);
|
||||
pong.ShouldContain("PONG");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Tls_pubsub_works_over_encrypted_connection()
|
||||
{
|
||||
using var tcp1 = new TcpClient();
|
||||
await tcp1.ConnectAsync(IPAddress.Loopback, _port);
|
||||
using var ssl1 = await UpgradeToTlsAsync(tcp1);
|
||||
|
||||
using var tcp2 = new TcpClient();
|
||||
await tcp2.ConnectAsync(IPAddress.Loopback, _port);
|
||||
using var ssl2 = await UpgradeToTlsAsync(tcp2);
|
||||
|
||||
// Sub on client 1
|
||||
await ssl1.WriteAsync("CONNECT {}\r\nSUB test 1\r\nPING\r\n"u8.ToArray());
|
||||
await ssl1.FlushAsync();
|
||||
|
||||
// Wait for PONG to confirm subscription is registered
|
||||
var pongBuf = new byte[256];
|
||||
var pongRead = await ssl1.ReadAsync(pongBuf);
|
||||
var pongStr = Encoding.ASCII.GetString(pongBuf, 0, pongRead);
|
||||
pongStr.ShouldContain("PONG");
|
||||
|
||||
// Pub on client 2
|
||||
await ssl2.WriteAsync("CONNECT {}\r\nPUB test 5\r\nhello\r\nPING\r\n"u8.ToArray());
|
||||
await ssl2.FlushAsync();
|
||||
|
||||
// Client 1 should receive MSG (may arrive across multiple TLS records)
|
||||
var msg = await SocketTestHelper.ReadUntilAsync(ssl1, "hello");
|
||||
msg.ShouldContain("MSG test 1 5");
|
||||
msg.ShouldContain("hello");
|
||||
}
|
||||
|
||||
private static async Task<SslStream> UpgradeToTlsAsync(TcpClient tcp)
|
||||
{
|
||||
var netStream = tcp.GetStream();
|
||||
var buf = new byte[4096];
|
||||
_ = await netStream.ReadAsync(buf); // Read INFO (discard)
|
||||
|
||||
var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
|
||||
await ssl.AuthenticateAsClientAsync("localhost");
|
||||
return ssl;
|
||||
}
|
||||
}
|
||||
|
||||
public class TlsMixedModeTests : IAsyncLifetime
|
||||
{
|
||||
private readonly NatsServer _server;
|
||||
private readonly int _port;
|
||||
private readonly CancellationTokenSource _cts = new();
|
||||
private readonly string _certPath;
|
||||
private readonly string _keyPath;
|
||||
|
||||
public TlsMixedModeTests()
|
||||
{
|
||||
_port = TestPortAllocator.GetFreePort();
|
||||
(_certPath, _keyPath) = TestCertHelper.GenerateTestCertFiles();
|
||||
_server = new NatsServer(
|
||||
new NatsOptions
|
||||
{
|
||||
Port = _port,
|
||||
TlsCert = _certPath,
|
||||
TlsKey = _keyPath,
|
||||
AllowNonTls = true,
|
||||
},
|
||||
NullLoggerFactory.Instance);
|
||||
}
|
||||
|
||||
public async Task InitializeAsync()
|
||||
{
|
||||
_ = _server.StartAsync(_cts.Token);
|
||||
await _server.WaitForReadyAsync();
|
||||
}
|
||||
|
||||
public async Task DisposeAsync()
|
||||
{
|
||||
await _cts.CancelAsync();
|
||||
_server.Dispose();
|
||||
File.Delete(_certPath);
|
||||
File.Delete(_keyPath);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Mixed_mode_accepts_plain_client()
|
||||
{
|
||||
using var sock = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await sock.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _port));
|
||||
using var stream = new NetworkStream(sock);
|
||||
|
||||
var buf = new byte[4096];
|
||||
var read = await stream.ReadAsync(buf);
|
||||
var info = Encoding.ASCII.GetString(buf, 0, read);
|
||||
info.ShouldContain("\"tls_available\":true");
|
||||
|
||||
await stream.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
|
||||
await stream.FlushAsync();
|
||||
|
||||
var pongBuf = new byte[64];
|
||||
read = await stream.ReadAsync(pongBuf);
|
||||
var pong = Encoding.ASCII.GetString(pongBuf, 0, read);
|
||||
pong.ShouldContain("PONG");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Mixed_mode_accepts_tls_client()
|
||||
{
|
||||
using var tcp = new TcpClient();
|
||||
await tcp.ConnectAsync(IPAddress.Loopback, _port);
|
||||
using var netStream = tcp.GetStream();
|
||||
|
||||
var buf = new byte[4096];
|
||||
_ = await netStream.ReadAsync(buf); // Read INFO
|
||||
|
||||
using var ssl = new SslStream(netStream, false, (_, _, _, _) => true);
|
||||
await ssl.AuthenticateAsClientAsync("localhost");
|
||||
|
||||
await ssl.WriteAsync("CONNECT {}\r\nPING\r\n"u8.ToArray());
|
||||
await ssl.FlushAsync();
|
||||
|
||||
var pongBuf = new byte[64];
|
||||
var read = await ssl.ReadAsync(pongBuf);
|
||||
var pong = Encoding.ASCII.GetString(pongBuf, 0, read);
|
||||
pong.ShouldContain("PONG");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
using Shouldly;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
public class WebSocketOptionsTests
|
||||
{
|
||||
[Fact]
|
||||
public void DefaultOptions_PortIsNegativeOne_Disabled()
|
||||
{
|
||||
var opts = new WebSocketOptions();
|
||||
opts.Port.ShouldBe(-1);
|
||||
opts.Host.ShouldBe("0.0.0.0");
|
||||
opts.Compression.ShouldBeFalse();
|
||||
opts.NoTls.ShouldBeFalse();
|
||||
opts.HandshakeTimeout.ShouldBe(TimeSpan.FromSeconds(2));
|
||||
opts.AuthTimeout.ShouldBe(TimeSpan.FromSeconds(2));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void NatsOptions_HasWebSocketProperty()
|
||||
{
|
||||
var opts = new NatsOptions();
|
||||
opts.WebSocket.ShouldNotBeNull();
|
||||
opts.WebSocket.Port.ShouldBe(-1);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void WsAuthConfig_sets_auth_override_when_websocket_auth_fields_are_present()
|
||||
{
|
||||
var ws = new WebSocketOptions
|
||||
{
|
||||
Username = "u",
|
||||
};
|
||||
|
||||
WsAuthConfig.Apply(ws);
|
||||
|
||||
ws.AuthOverride.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void WsAuthConfig_keeps_auth_override_false_when_no_ws_auth_fields_are_present()
|
||||
{
|
||||
var ws = new WebSocketOptions();
|
||||
|
||||
WsAuthConfig.Apply(ws);
|
||||
|
||||
ws.AuthOverride.ShouldBeFalse();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
using NATS.Server.Auth;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
public class WebSocketOptionsValidatorParityBatch2Tests
|
||||
{
|
||||
[Fact]
|
||||
public void Validate_rejects_tls_listener_without_cert_key_when_not_no_tls()
|
||||
{
|
||||
var opts = new NatsOptions
|
||||
{
|
||||
WebSocket = new WebSocketOptions
|
||||
{
|
||||
Port = 8080,
|
||||
NoTls = false,
|
||||
},
|
||||
};
|
||||
|
||||
var result = WebSocketOptionsValidator.Validate(opts);
|
||||
|
||||
result.IsValid.ShouldBeFalse();
|
||||
result.Errors.ShouldContain(e => e.Contains("TLS", StringComparison.OrdinalIgnoreCase));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Validate_rejects_invalid_allowed_origins()
|
||||
{
|
||||
var opts = new NatsOptions
|
||||
{
|
||||
WebSocket = new WebSocketOptions
|
||||
{
|
||||
Port = 8080,
|
||||
NoTls = true,
|
||||
AllowedOrigins = ["not-a-uri"],
|
||||
},
|
||||
};
|
||||
|
||||
var result = WebSocketOptionsValidator.Validate(opts);
|
||||
|
||||
result.IsValid.ShouldBeFalse();
|
||||
result.Errors.ShouldContain(e => e.Contains("allowed origin", StringComparison.OrdinalIgnoreCase));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Validate_rejects_no_auth_user_not_present_in_configured_users()
|
||||
{
|
||||
var opts = new NatsOptions
|
||||
{
|
||||
Users = [new User { Username = "alice", Password = "x" }],
|
||||
WebSocket = new WebSocketOptions
|
||||
{
|
||||
Port = 8080,
|
||||
NoTls = true,
|
||||
NoAuthUser = "bob",
|
||||
},
|
||||
};
|
||||
|
||||
var result = WebSocketOptionsValidator.Validate(opts);
|
||||
|
||||
result.IsValid.ShouldBeFalse();
|
||||
result.Errors.ShouldContain(e => e.Contains("NoAuthUser", StringComparison.OrdinalIgnoreCase));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Validate_rejects_username_or_token_when_users_or_nkeys_are_set()
|
||||
{
|
||||
var opts = new NatsOptions
|
||||
{
|
||||
Users = [new User { Username = "alice", Password = "x" }],
|
||||
WebSocket = new WebSocketOptions
|
||||
{
|
||||
Port = 8080,
|
||||
NoTls = true,
|
||||
Username = "ws-user",
|
||||
},
|
||||
};
|
||||
|
||||
var result = WebSocketOptionsValidator.Validate(opts);
|
||||
|
||||
result.IsValid.ShouldBeFalse();
|
||||
result.Errors.ShouldContain(e => e.Contains("users", StringComparison.OrdinalIgnoreCase));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Validate_rejects_jwt_cookie_without_trusted_operators()
|
||||
{
|
||||
var opts = new NatsOptions
|
||||
{
|
||||
WebSocket = new WebSocketOptions
|
||||
{
|
||||
Port = 8080,
|
||||
NoTls = true,
|
||||
JwtCookie = "jwt",
|
||||
},
|
||||
};
|
||||
|
||||
var result = WebSocketOptionsValidator.Validate(opts);
|
||||
|
||||
result.IsValid.ShouldBeFalse();
|
||||
result.Errors.ShouldContain(e => e.Contains("JwtCookie", StringComparison.OrdinalIgnoreCase));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Validate_rejects_reserved_response_headers_override()
|
||||
{
|
||||
var opts = new NatsOptions
|
||||
{
|
||||
TrustedKeys = ["OP1"],
|
||||
WebSocket = new WebSocketOptions
|
||||
{
|
||||
Port = 8080,
|
||||
NoTls = true,
|
||||
Headers = new Dictionary<string, string>
|
||||
{
|
||||
["Sec-WebSocket-Accept"] = "bad",
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
var result = WebSocketOptionsValidator.Validate(opts);
|
||||
|
||||
result.IsValid.ShouldBeFalse();
|
||||
result.Errors.ShouldContain(e => e.Contains("reserved", StringComparison.OrdinalIgnoreCase));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Validate_rejects_tls_pinned_certs_when_websocket_tls_is_disabled()
|
||||
{
|
||||
var opts = new NatsOptions
|
||||
{
|
||||
TlsPinnedCerts = ["ABCDEF0123"],
|
||||
WebSocket = new WebSocketOptions
|
||||
{
|
||||
Port = 8080,
|
||||
NoTls = true,
|
||||
},
|
||||
};
|
||||
|
||||
var result = WebSocketOptionsValidator.Validate(opts);
|
||||
|
||||
result.IsValid.ShouldBeFalse();
|
||||
result.Errors.ShouldContain(e => e.Contains("TLSPinnedCerts", StringComparison.OrdinalIgnoreCase));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Validate_accepts_valid_minimal_configuration()
|
||||
{
|
||||
var opts = new NatsOptions
|
||||
{
|
||||
TrustedKeys = ["OP1"],
|
||||
Users = [new User { Username = "alice", Password = "x" }],
|
||||
WebSocket = new WebSocketOptions
|
||||
{
|
||||
Port = 8080,
|
||||
NoTls = true,
|
||||
NoAuthUser = "alice",
|
||||
AllowedOrigins = ["https://app.example.com"],
|
||||
JwtCookie = "jwt",
|
||||
Headers = new Dictionary<string, string>
|
||||
{
|
||||
["X-App-Version"] = "1",
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
var result = WebSocketOptionsValidator.Validate(opts);
|
||||
|
||||
result.IsValid.ShouldBeTrue();
|
||||
result.Errors.ShouldBeEmpty();
|
||||
}
|
||||
}
|
||||
119
tests/NATS.Server.Transport.Tests/WebSocket/WebSocketTlsTests.cs
Normal file
119
tests/NATS.Server.Transport.Tests/WebSocket/WebSocketTlsTests.cs
Normal file
@@ -0,0 +1,119 @@
|
||||
// Go reference: server/websocket.go — wsTLSConfig and related TLS handling.
|
||||
|
||||
using NATS.Server.WebSocket;
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// Tests for <see cref="WebSocketTlsConfig"/> — WebSocket-specific TLS configuration
|
||||
/// with separate cert/key from the main NATS listener (Gap 15.1).
|
||||
/// </summary>
|
||||
public class WebSocketTlsTests
|
||||
{
|
||||
// ── IsConfigured ──────────────────────────────────────────────────────────
|
||||
|
||||
[Fact]
|
||||
public void IsConfigured_WithCert_ReturnsTrue()
|
||||
{
|
||||
var cfg = new WebSocketTlsConfig { CertFile = "server.pem", KeyFile = "server-key.pem" };
|
||||
cfg.IsConfigured.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void IsConfigured_WithoutCert_ReturnsFalse()
|
||||
{
|
||||
var cfg = new WebSocketTlsConfig();
|
||||
cfg.IsConfigured.ShouldBeFalse();
|
||||
}
|
||||
|
||||
// ── Validate ──────────────────────────────────────────────────────────────
|
||||
|
||||
[Fact]
|
||||
public void Validate_ValidConfig_NoErrors()
|
||||
{
|
||||
var cfg = new WebSocketTlsConfig { CertFile = "server.pem", KeyFile = "server-key.pem" };
|
||||
var result = cfg.Validate();
|
||||
result.IsValid.ShouldBeTrue();
|
||||
result.Errors.ShouldBeEmpty();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Validate_KeyWithoutCert_HasError()
|
||||
{
|
||||
var cfg = new WebSocketTlsConfig { KeyFile = "server-key.pem" };
|
||||
var result = cfg.Validate();
|
||||
result.IsValid.ShouldBeFalse();
|
||||
result.Errors.ShouldNotBeEmpty();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Validate_CertWithoutKey_HasError()
|
||||
{
|
||||
var cfg = new WebSocketTlsConfig { CertFile = "server.pem" };
|
||||
var result = cfg.Validate();
|
||||
result.IsValid.ShouldBeFalse();
|
||||
result.Errors.ShouldNotBeEmpty();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Validate_EmptyConfig_Valid()
|
||||
{
|
||||
// An empty configuration means "no TLS" — that is a valid state.
|
||||
var cfg = new WebSocketTlsConfig();
|
||||
var result = cfg.Validate();
|
||||
result.IsValid.ShouldBeTrue();
|
||||
result.Errors.ShouldBeEmpty();
|
||||
}
|
||||
|
||||
// ── HasChangedFrom ────────────────────────────────────────────────────────
|
||||
|
||||
[Fact]
|
||||
public void HasChangedFrom_SameConfig_ReturnsFalse()
|
||||
{
|
||||
var a = new WebSocketTlsConfig
|
||||
{
|
||||
CertFile = "server.pem",
|
||||
KeyFile = "server-key.pem",
|
||||
CaFile = "ca.pem",
|
||||
RequireClientCert = true,
|
||||
InsecureSkipVerify = false,
|
||||
};
|
||||
|
||||
var b = new WebSocketTlsConfig
|
||||
{
|
||||
CertFile = "server.pem",
|
||||
KeyFile = "server-key.pem",
|
||||
CaFile = "ca.pem",
|
||||
RequireClientCert = true,
|
||||
InsecureSkipVerify = false,
|
||||
};
|
||||
|
||||
a.HasChangedFrom(b).ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void HasChangedFrom_DifferentCert_ReturnsTrue()
|
||||
{
|
||||
var a = new WebSocketTlsConfig { CertFile = "old.pem", KeyFile = "old-key.pem" };
|
||||
var b = new WebSocketTlsConfig { CertFile = "new.pem", KeyFile = "new-key.pem" };
|
||||
|
||||
a.HasChangedFrom(b).ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void HasChangedFrom_NullOther_ReturnsTrue()
|
||||
{
|
||||
var cfg = new WebSocketTlsConfig { CertFile = "server.pem", KeyFile = "server-key.pem" };
|
||||
cfg.HasChangedFrom(null).ShouldBeTrue();
|
||||
}
|
||||
|
||||
// ── Defaults ──────────────────────────────────────────────────────────────
|
||||
|
||||
[Fact]
|
||||
public void RequireClientCert_Default_False()
|
||||
{
|
||||
var cfg = new WebSocketTlsConfig();
|
||||
cfg.RequireClientCert.ShouldBeFalse();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,327 @@
|
||||
// Tests for WebSocket permessage-deflate parameter negotiation (E10).
|
||||
// Verifies RFC 7692 extension parameter parsing and negotiation during
|
||||
// WebSocket upgrade handshake.
|
||||
// Reference: golang/nats-server/server/websocket.go — wsPMCExtensionSupport (line 885).
|
||||
|
||||
using System.Text;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
public class WsCompressionNegotiationTests
|
||||
{
|
||||
// ─── WsDeflateNegotiator.Negotiate tests ──────────────────────────────
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_NullHeader_ReturnsNull()
|
||||
{
|
||||
// Go parity: wsPMCExtensionSupport — no extension header means no compression
|
||||
var result = WsDeflateNegotiator.Negotiate(null);
|
||||
result.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_EmptyHeader_ReturnsNull()
|
||||
{
|
||||
var result = WsDeflateNegotiator.Negotiate("");
|
||||
result.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_NoPermessageDeflate_ReturnsNull()
|
||||
{
|
||||
var result = WsDeflateNegotiator.Negotiate("x-webkit-deflate-frame");
|
||||
result.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_BarePermessageDeflate_ReturnsDefaults()
|
||||
{
|
||||
// Go parity: wsPMCExtensionSupport — basic extension without parameters
|
||||
var result = WsDeflateNegotiator.Negotiate("permessage-deflate");
|
||||
|
||||
result.ShouldNotBeNull();
|
||||
// NATS always enforces no_context_takeover
|
||||
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||
result.Value.ClientNoContextTakeover.ShouldBeTrue();
|
||||
result.Value.ServerMaxWindowBits.ShouldBe(15);
|
||||
result.Value.ClientMaxWindowBits.ShouldBe(15);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_WithServerNoContextTakeover()
|
||||
{
|
||||
// Go parity: wsPMCExtensionSupport — server_no_context_takeover parameter
|
||||
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; server_no_context_takeover");
|
||||
|
||||
result.ShouldNotBeNull();
|
||||
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_WithClientNoContextTakeover()
|
||||
{
|
||||
// Go parity: wsPMCExtensionSupport — client_no_context_takeover parameter
|
||||
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; client_no_context_takeover");
|
||||
|
||||
result.ShouldNotBeNull();
|
||||
result.Value.ClientNoContextTakeover.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_WithBothNoContextTakeover()
|
||||
{
|
||||
// Go parity: wsPMCExtensionSupport — both no_context_takeover parameters
|
||||
var result = WsDeflateNegotiator.Negotiate(
|
||||
"permessage-deflate; server_no_context_takeover; client_no_context_takeover");
|
||||
|
||||
result.ShouldNotBeNull();
|
||||
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||
result.Value.ClientNoContextTakeover.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_WithServerMaxWindowBits()
|
||||
{
|
||||
// RFC 7692 Section 7.1.2.1: server_max_window_bits parameter
|
||||
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; server_max_window_bits=10");
|
||||
|
||||
result.ShouldNotBeNull();
|
||||
result.Value.ServerMaxWindowBits.ShouldBe(10);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_WithClientMaxWindowBits_Value()
|
||||
{
|
||||
// RFC 7692 Section 7.1.2.2: client_max_window_bits with explicit value
|
||||
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; client_max_window_bits=12");
|
||||
|
||||
result.ShouldNotBeNull();
|
||||
result.Value.ClientMaxWindowBits.ShouldBe(12);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_WithClientMaxWindowBits_NoValue()
|
||||
{
|
||||
// RFC 7692 Section 7.1.2.2: client_max_window_bits with no value means
|
||||
// client supports any value 8-15; defaults to 15
|
||||
var result = WsDeflateNegotiator.Negotiate("permessage-deflate; client_max_window_bits");
|
||||
|
||||
result.ShouldNotBeNull();
|
||||
result.Value.ClientMaxWindowBits.ShouldBe(15);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_WindowBits_ClampedToValidRange()
|
||||
{
|
||||
// RFC 7692: valid range is 8-15
|
||||
var result = WsDeflateNegotiator.Negotiate(
|
||||
"permessage-deflate; server_max_window_bits=5; client_max_window_bits=20");
|
||||
|
||||
result.ShouldNotBeNull();
|
||||
result.Value.ServerMaxWindowBits.ShouldBe(8); // Clamped up from 5
|
||||
result.Value.ClientMaxWindowBits.ShouldBe(15); // Clamped down from 20
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_FullParameters()
|
||||
{
|
||||
// All parameters specified
|
||||
var result = WsDeflateNegotiator.Negotiate(
|
||||
"permessage-deflate; server_no_context_takeover; client_no_context_takeover; server_max_window_bits=9; client_max_window_bits=11");
|
||||
|
||||
result.ShouldNotBeNull();
|
||||
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||
result.Value.ClientNoContextTakeover.ShouldBeTrue();
|
||||
result.Value.ServerMaxWindowBits.ShouldBe(9);
|
||||
result.Value.ClientMaxWindowBits.ShouldBe(11);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_CaseInsensitive()
|
||||
{
|
||||
// RFC 7692 extension names are case-insensitive
|
||||
var result = WsDeflateNegotiator.Negotiate("Permessage-Deflate; Server_No_Context_Takeover");
|
||||
|
||||
result.ShouldNotBeNull();
|
||||
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_MultipleExtensions_PicksDeflate()
|
||||
{
|
||||
// Header may contain multiple comma-separated extensions
|
||||
var result = WsDeflateNegotiator.Negotiate(
|
||||
"x-custom-ext, permessage-deflate; server_no_context_takeover, other-ext");
|
||||
|
||||
result.ShouldNotBeNull();
|
||||
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_WhitespaceHandling()
|
||||
{
|
||||
// Extra whitespace around parameters
|
||||
var result = WsDeflateNegotiator.Negotiate(
|
||||
" permessage-deflate ; server_no_context_takeover ; client_max_window_bits = 10 ");
|
||||
|
||||
result.ShouldNotBeNull();
|
||||
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||
result.Value.ClientMaxWindowBits.ShouldBe(10);
|
||||
}
|
||||
|
||||
// ─── NatsAlwaysEnforcesNoContextTakeover ─────────────────────────────
|
||||
|
||||
[Fact]
|
||||
public void Negotiate_AlwaysEnforcesNoContextTakeover()
|
||||
{
|
||||
// NATS Go server always returns server_no_context_takeover and
|
||||
// client_no_context_takeover regardless of what the client requests
|
||||
var result = WsDeflateNegotiator.Negotiate("permessage-deflate");
|
||||
|
||||
result.ShouldNotBeNull();
|
||||
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||
result.Value.ClientNoContextTakeover.ShouldBeTrue();
|
||||
}
|
||||
|
||||
// ─── WsDeflateParams.ToResponseHeaderValue tests ────────────────────
|
||||
|
||||
[Fact]
|
||||
public void DefaultParams_ResponseHeader_ContainsNoContextTakeover()
|
||||
{
|
||||
var header = WsDeflateParams.Default.ToResponseHeaderValue();
|
||||
|
||||
header.ShouldContain("permessage-deflate");
|
||||
header.ShouldContain("server_no_context_takeover");
|
||||
header.ShouldContain("client_no_context_takeover");
|
||||
header.ShouldNotContain("server_max_window_bits");
|
||||
header.ShouldNotContain("client_max_window_bits");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CustomWindowBits_ResponseHeader_IncludesValues()
|
||||
{
|
||||
var params_ = new WsDeflateParams(
|
||||
ServerNoContextTakeover: true,
|
||||
ClientNoContextTakeover: true,
|
||||
ServerMaxWindowBits: 10,
|
||||
ClientMaxWindowBits: 12);
|
||||
|
||||
var header = params_.ToResponseHeaderValue();
|
||||
|
||||
header.ShouldContain("server_max_window_bits=10");
|
||||
header.ShouldContain("client_max_window_bits=12");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void DefaultWindowBits_ResponseHeader_OmitsValues()
|
||||
{
|
||||
// RFC 7692: window bits of 15 is the default and should not be sent
|
||||
var params_ = new WsDeflateParams(
|
||||
ServerNoContextTakeover: true,
|
||||
ClientNoContextTakeover: true,
|
||||
ServerMaxWindowBits: 15,
|
||||
ClientMaxWindowBits: 15);
|
||||
|
||||
var header = params_.ToResponseHeaderValue();
|
||||
|
||||
header.ShouldNotContain("server_max_window_bits");
|
||||
header.ShouldNotContain("client_max_window_bits");
|
||||
}
|
||||
|
||||
// ─── Integration with WsUpgrade ─────────────────────────────────────
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_WithDeflateParams_NegotiatesCompression()
|
||||
{
|
||||
// Go parity: WebSocket upgrade with permessage-deflate parameters
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
"Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover; server_max_window_bits=10\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true, Compression = true };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Compress.ShouldBeTrue();
|
||||
result.DeflateParams.ShouldNotBeNull();
|
||||
result.DeflateParams.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||
result.DeflateParams.Value.ClientNoContextTakeover.ShouldBeTrue();
|
||||
result.DeflateParams.Value.ServerMaxWindowBits.ShouldBe(10);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_WithDeflateParams_ResponseIncludesNegotiatedParams()
|
||||
{
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
"Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_max_window_bits=10\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true, Compression = true };
|
||||
await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||
|
||||
var response = ReadResponse(outputStream);
|
||||
response.ShouldContain("permessage-deflate");
|
||||
response.ShouldContain("server_no_context_takeover");
|
||||
response.ShouldContain("client_no_context_takeover");
|
||||
response.ShouldContain("client_max_window_bits=10");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_CompressionDisabled_NoDeflateParams()
|
||||
{
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
"Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true, Compression = false };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Compress.ShouldBeFalse();
|
||||
result.DeflateParams.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_NoExtensionHeader_NoCompression()
|
||||
{
|
||||
var request = BuildValidRequest();
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true, Compression = true };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Compress.ShouldBeFalse();
|
||||
result.DeflateParams.ShouldBeNull();
|
||||
}
|
||||
|
||||
// ─── Helpers ─────────────────────────────────────────────────────────
|
||||
|
||||
private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
sb.Append($"GET {path} HTTP/1.1\r\n");
|
||||
sb.Append("Host: localhost:4222\r\n");
|
||||
sb.Append("Upgrade: websocket\r\n");
|
||||
sb.Append("Connection: Upgrade\r\n");
|
||||
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
|
||||
sb.Append("Sec-WebSocket-Version: 13\r\n");
|
||||
if (extraHeaders != null)
|
||||
sb.Append(extraHeaders);
|
||||
sb.Append("\r\n");
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
|
||||
{
|
||||
var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
|
||||
return (new MemoryStream(inputBytes), new MemoryStream());
|
||||
}
|
||||
|
||||
private static string ReadResponse(MemoryStream output)
|
||||
{
|
||||
output.Position = 0;
|
||||
return Encoding.ASCII.GetString(output.ToArray());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
using NATS.Server.WebSocket;
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
public class WsCompressionTests
|
||||
{
|
||||
[Fact]
|
||||
public void CompressDecompress_RoundTrip()
|
||||
{
|
||||
var original = "Hello, WebSocket compression test! This is long enough to compress."u8.ToArray();
|
||||
var compressed = WsCompression.Compress(original);
|
||||
compressed.ShouldNotBeNull();
|
||||
compressed.Length.ShouldBeGreaterThan(0);
|
||||
|
||||
var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096);
|
||||
decompressed.ShouldBe(original);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Decompress_ExceedsMaxPayload_Throws()
|
||||
{
|
||||
var original = new byte[1000];
|
||||
Random.Shared.NextBytes(original);
|
||||
var compressed = WsCompression.Compress(original);
|
||||
|
||||
Should.Throw<InvalidOperationException>(() =>
|
||||
WsCompression.Decompress([compressed], maxPayload: 100));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Compress_RemovesTrailing4Bytes()
|
||||
{
|
||||
var data = new byte[200];
|
||||
Random.Shared.NextBytes(data);
|
||||
var compressed = WsCompression.Compress(data);
|
||||
|
||||
// The compressed data should be valid for decompression when we add the trailer back
|
||||
var decompressed = WsCompression.Decompress([compressed], maxPayload: 4096);
|
||||
decompressed.ShouldBe(data);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Decompress_MultipleBuffers()
|
||||
{
|
||||
var original = new byte[500];
|
||||
Random.Shared.NextBytes(original);
|
||||
var compressed = WsCompression.Compress(original);
|
||||
|
||||
// Split compressed data into multiple chunks
|
||||
int mid = compressed.Length / 2;
|
||||
var chunk1 = compressed[..mid];
|
||||
var chunk2 = compressed[mid..];
|
||||
|
||||
var decompressed = WsCompression.Decompress([chunk1, chunk2], maxPayload: 4096);
|
||||
decompressed.ShouldBe(original);
|
||||
}
|
||||
}
|
||||
124
tests/NATS.Server.Transport.Tests/WebSocket/WsConnectionTests.cs
Normal file
124
tests/NATS.Server.Transport.Tests/WebSocket/WsConnectionTests.cs
Normal file
@@ -0,0 +1,124 @@
|
||||
using System.Buffers.Binary;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
public class WsConnectionTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task ReadAsync_DecodesFrameAndReturnsPayload()
|
||||
{
|
||||
var payload = "SUB test 1\r\n"u8.ToArray();
|
||||
var frame = BuildUnmaskedFrame(payload);
|
||||
var inner = new MemoryStream(frame);
|
||||
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var buf = new byte[256];
|
||||
int n = await ws.ReadAsync(buf);
|
||||
|
||||
n.ShouldBe(payload.Length);
|
||||
buf[..n].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WriteAsync_FramesPayload()
|
||||
{
|
||||
var inner = new MemoryStream();
|
||||
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var payload = "MSG test 1 5\r\nHello\r\n"u8.ToArray();
|
||||
await ws.WriteAsync(payload);
|
||||
await ws.FlushAsync();
|
||||
|
||||
inner.Position = 0;
|
||||
var written = inner.ToArray();
|
||||
// First 2 bytes should be WS frame header
|
||||
(written[0] & WsConstants.FinalBit).ShouldNotBe(0);
|
||||
(written[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
|
||||
int len = written[1] & 0x7F;
|
||||
len.ShouldBe(payload.Length);
|
||||
written[2..].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WriteAsync_WithCompression_CompressesLargePayload()
|
||||
{
|
||||
var inner = new MemoryStream();
|
||||
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var payload = new byte[200];
|
||||
Array.Fill<byte>(payload, 0x41); // 'A' repeated - very compressible
|
||||
await ws.WriteAsync(payload);
|
||||
await ws.FlushAsync();
|
||||
|
||||
inner.Position = 0;
|
||||
var written = inner.ToArray();
|
||||
// RSV1 bit should be set for compressed frame
|
||||
(written[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
|
||||
// Compressed size should be less than original
|
||||
written.Length.ShouldBeLessThan(payload.Length + 10);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WriteAsync_SmallPayload_NotCompressedEvenWhenEnabled()
|
||||
{
|
||||
var inner = new MemoryStream();
|
||||
var ws = new WsConnection(inner, compress: true, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var payload = "Hi"u8.ToArray(); // Below CompressThreshold
|
||||
await ws.WriteAsync(payload);
|
||||
await ws.FlushAsync();
|
||||
|
||||
inner.Position = 0;
|
||||
var written = inner.ToArray();
|
||||
// RSV1 bit should NOT be set for small payloads
|
||||
(written[0] & WsConstants.Rsv1Bit).ShouldBe(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadAsync_DecodesMaskedFrame()
|
||||
{
|
||||
var payload = "CONNECT {}\r\n"u8.ToArray();
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: true, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: payload.Length);
|
||||
var maskKey = header[^4..];
|
||||
WsFrameWriter.MaskBuf(maskKey, payload);
|
||||
|
||||
var frame = new byte[header.Length + payload.Length];
|
||||
header.CopyTo(frame, 0);
|
||||
payload.CopyTo(frame, header.Length);
|
||||
|
||||
var inner = new MemoryStream(frame);
|
||||
var ws = new WsConnection(inner, compress: false, maskRead: true, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var buf = new byte[256];
|
||||
int n = await ws.ReadAsync(buf);
|
||||
|
||||
n.ShouldBe("CONNECT {}\r\n".Length);
|
||||
System.Text.Encoding.ASCII.GetString(buf, 0, n).ShouldBe("CONNECT {}\r\n");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReadAsync_ReturnsZero_OnEndOfStream()
|
||||
{
|
||||
// Empty stream should return 0 (true end of stream)
|
||||
var inner = new MemoryStream([]);
|
||||
var ws = new WsConnection(inner, compress: false, maskRead: false, maskWrite: false, browser: false, noCompFrag: false);
|
||||
|
||||
var buf = new byte[256];
|
||||
int n = await ws.ReadAsync(buf);
|
||||
n.ShouldBe(0);
|
||||
}
|
||||
|
||||
private static byte[] BuildUnmaskedFrame(byte[] payload)
|
||||
{
|
||||
var header = new byte[2];
|
||||
header[0] = (byte)(WsConstants.FinalBit | WsConstants.BinaryMessage);
|
||||
header[1] = (byte)payload.Length;
|
||||
var frame = new byte[2 + payload.Length];
|
||||
header.CopyTo(frame, 0);
|
||||
payload.CopyTo(frame, 2);
|
||||
return frame;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
using NATS.Server.WebSocket;
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
public class WsConstantsTests
|
||||
{
|
||||
[Fact]
|
||||
public void OpCodes_MatchRfc6455()
|
||||
{
|
||||
WsConstants.TextMessage.ShouldBe(1);
|
||||
WsConstants.BinaryMessage.ShouldBe(2);
|
||||
WsConstants.CloseMessage.ShouldBe(8);
|
||||
WsConstants.PingMessage.ShouldBe(9);
|
||||
WsConstants.PongMessage.ShouldBe(10);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void FrameBits_MatchRfc6455()
|
||||
{
|
||||
WsConstants.FinalBit.ShouldBe((byte)0x80);
|
||||
WsConstants.Rsv1Bit.ShouldBe((byte)0x40);
|
||||
WsConstants.MaskBit.ShouldBe((byte)0x80);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CloseStatusCodes_MatchRfc6455()
|
||||
{
|
||||
WsConstants.CloseStatusNormalClosure.ShouldBe(1000);
|
||||
WsConstants.CloseStatusGoingAway.ShouldBe(1001);
|
||||
WsConstants.CloseStatusProtocolError.ShouldBe(1002);
|
||||
WsConstants.CloseStatusPolicyViolation.ShouldBe(1008);
|
||||
WsConstants.CloseStatusMessageTooBig.ShouldBe(1009);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(WsConstants.CloseMessage)]
|
||||
[InlineData(WsConstants.PingMessage)]
|
||||
[InlineData(WsConstants.PongMessage)]
|
||||
public void IsControlFrame_True(int opcode)
|
||||
{
|
||||
WsConstants.IsControlFrame(opcode).ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(WsConstants.TextMessage)]
|
||||
[InlineData(WsConstants.BinaryMessage)]
|
||||
[InlineData(0)]
|
||||
public void IsControlFrame_False(int opcode)
|
||||
{
|
||||
WsConstants.IsControlFrame(opcode).ShouldBeFalse();
|
||||
}
|
||||
}
|
||||
163
tests/NATS.Server.Transport.Tests/WebSocket/WsFrameReadTests.cs
Normal file
163
tests/NATS.Server.Transport.Tests/WebSocket/WsFrameReadTests.cs
Normal file
@@ -0,0 +1,163 @@
|
||||
using System.Buffers.Binary;
|
||||
using NATS.Server.WebSocket;
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
public class WsFrameReadTests
|
||||
{
|
||||
/// <summary>Helper: build a single unmasked binary frame.</summary>
|
||||
private static byte[] BuildFrame(byte[] payload, bool fin = true, bool compressed = false, int opcode = WsConstants.BinaryMessage, bool mask = false, byte[]? maskKey = null)
|
||||
{
|
||||
int payloadLen = payload.Length;
|
||||
byte b0 = (byte)opcode;
|
||||
if (fin) b0 |= WsConstants.FinalBit;
|
||||
if (compressed) b0 |= WsConstants.Rsv1Bit;
|
||||
byte b1 = 0;
|
||||
if (mask) b1 |= WsConstants.MaskBit;
|
||||
|
||||
byte[] lenBytes;
|
||||
if (payloadLen <= 125)
|
||||
{
|
||||
lenBytes = [(byte)(b1 | (byte)payloadLen)];
|
||||
}
|
||||
else if (payloadLen < 65536)
|
||||
{
|
||||
lenBytes = new byte[3];
|
||||
lenBytes[0] = (byte)(b1 | 126);
|
||||
BinaryPrimitives.WriteUInt16BigEndian(lenBytes.AsSpan(1), (ushort)payloadLen);
|
||||
}
|
||||
else
|
||||
{
|
||||
lenBytes = new byte[9];
|
||||
lenBytes[0] = (byte)(b1 | 127);
|
||||
BinaryPrimitives.WriteUInt64BigEndian(lenBytes.AsSpan(1), (ulong)payloadLen);
|
||||
}
|
||||
|
||||
int totalLen = 1 + lenBytes.Length + (mask ? 4 : 0) + payloadLen;
|
||||
var frame = new byte[totalLen];
|
||||
frame[0] = b0;
|
||||
lenBytes.CopyTo(frame.AsSpan(1));
|
||||
int pos = 1 + lenBytes.Length;
|
||||
if (mask && maskKey != null)
|
||||
{
|
||||
maskKey.CopyTo(frame.AsSpan(pos));
|
||||
pos += 4;
|
||||
var maskedPayload = payload.ToArray();
|
||||
WsFrameWriter.MaskBuf(maskKey, maskedPayload);
|
||||
maskedPayload.CopyTo(frame.AsSpan(pos));
|
||||
}
|
||||
else
|
||||
{
|
||||
payload.CopyTo(frame.AsSpan(pos));
|
||||
}
|
||||
return frame;
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadSingleUnmaskedFrame()
|
||||
{
|
||||
var payload = "Hello"u8.ToArray();
|
||||
var frame = BuildFrame(payload);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(1);
|
||||
result[0].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadMaskedFrame()
|
||||
{
|
||||
var payload = "Hello"u8.ToArray();
|
||||
byte[] key = [0x37, 0xFA, 0x21, 0x3D];
|
||||
var frame = BuildFrame(payload, mask: true, maskKey: key);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: true);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(1);
|
||||
result[0].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Read16BitLengthFrame()
|
||||
{
|
||||
var payload = new byte[200];
|
||||
Random.Shared.NextBytes(payload);
|
||||
var frame = BuildFrame(payload);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(1);
|
||||
result[0].ShouldBe(payload);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadPingFrame_ReturnsPongAction()
|
||||
{
|
||||
var frame = BuildFrame([], opcode: WsConstants.PingMessage);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(0); // control frames don't produce payload
|
||||
readInfo.PendingControlFrames.Count.ShouldBe(1);
|
||||
readInfo.PendingControlFrames[0].Opcode.ShouldBe(WsConstants.PongMessage);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadCloseFrame_ReturnsCloseAction()
|
||||
{
|
||||
var closePayload = new byte[2];
|
||||
BinaryPrimitives.WriteUInt16BigEndian(closePayload, 1000);
|
||||
var frame = BuildFrame(closePayload, opcode: WsConstants.CloseMessage);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(0);
|
||||
readInfo.CloseReceived.ShouldBeTrue();
|
||||
readInfo.CloseStatus.ShouldBe(1000);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadPongFrame_NoAction()
|
||||
{
|
||||
var frame = BuildFrame([], opcode: WsConstants.PongMessage);
|
||||
|
||||
var readInfo = new WsReadInfo(expectMask: false);
|
||||
var stream = new MemoryStream(frame);
|
||||
var result = WsReadInfo.ReadFrames(readInfo, stream, frame.Length, maxPayload: 1024);
|
||||
|
||||
result.Count.ShouldBe(0);
|
||||
readInfo.PendingControlFrames.Count.ShouldBe(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Unmask_Optimized_8ByteChunks()
|
||||
{
|
||||
byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
|
||||
var original = new byte[32];
|
||||
Random.Shared.NextBytes(original);
|
||||
var masked = original.ToArray();
|
||||
|
||||
// Mask it
|
||||
for (int i = 0; i < masked.Length; i++)
|
||||
masked[i] ^= key[i & 3];
|
||||
|
||||
// Unmask using the state machine
|
||||
var info = new WsReadInfo(expectMask: true);
|
||||
info.SetMaskKey(key);
|
||||
info.Unmask(masked);
|
||||
|
||||
masked.ShouldBe(original);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
using System.Buffers.Binary;
|
||||
using NATS.Server.WebSocket;
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
public class WsFrameWriterTests
|
||||
{
|
||||
[Fact]
|
||||
public void CreateFrameHeader_SmallPayload_7BitLength()
|
||||
{
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 100);
|
||||
header.Length.ShouldBe(2);
|
||||
(header[0] & WsConstants.FinalBit).ShouldNotBe(0); // FIN set
|
||||
(header[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
|
||||
(header[1] & 0x7F).ShouldBe(100);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_MediumPayload_16BitLength()
|
||||
{
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 1000);
|
||||
header.Length.ShouldBe(4);
|
||||
(header[1] & 0x7F).ShouldBe(126);
|
||||
BinaryPrimitives.ReadUInt16BigEndian(header.AsSpan(2)).ShouldBe((ushort)1000);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_LargePayload_64BitLength()
|
||||
{
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 70000);
|
||||
header.Length.ShouldBe(10);
|
||||
(header[1] & 0x7F).ShouldBe(127);
|
||||
BinaryPrimitives.ReadUInt64BigEndian(header.AsSpan(2)).ShouldBe(70000UL);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_WithMasking_Adds4ByteKey()
|
||||
{
|
||||
var (header, key) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: true, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 10);
|
||||
header.Length.ShouldBe(6); // 2 header + 4 mask key
|
||||
(header[1] & WsConstants.MaskBit).ShouldNotBe(0);
|
||||
key.ShouldNotBeNull();
|
||||
key.Length.ShouldBe(4);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_Compressed_SetsRsv1Bit()
|
||||
{
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: true,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 10);
|
||||
(header[0] & WsConstants.Rsv1Bit).ShouldNotBe(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MaskBuf_XorsCorrectly()
|
||||
{
|
||||
byte[] key = [0xAA, 0xBB, 0xCC, 0xDD];
|
||||
byte[] data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
|
||||
byte[] expected = new byte[data.Length];
|
||||
for (int i = 0; i < data.Length; i++)
|
||||
expected[i] = (byte)(data[i] ^ key[i & 3]);
|
||||
|
||||
WsFrameWriter.MaskBuf(key, data);
|
||||
data.ShouldBe(expected);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MaskBuf_RoundTrip()
|
||||
{
|
||||
byte[] key = [0x12, 0x34, 0x56, 0x78];
|
||||
byte[] original = "Hello, WebSocket!"u8.ToArray();
|
||||
var data = original.ToArray();
|
||||
|
||||
WsFrameWriter.MaskBuf(key, data);
|
||||
data.ShouldNotBe(original);
|
||||
WsFrameWriter.MaskBuf(key, data);
|
||||
data.ShouldBe(original);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateCloseMessage_WithStatusAndBody()
|
||||
{
|
||||
var msg = WsFrameWriter.CreateCloseMessage(1000, "normal closure");
|
||||
msg.Length.ShouldBe(2 + "normal closure".Length);
|
||||
BinaryPrimitives.ReadUInt16BigEndian(msg).ShouldBe((ushort)1000);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateCloseMessage_LongBody_Truncated()
|
||||
{
|
||||
var longBody = new string('x', 200);
|
||||
var msg = WsFrameWriter.CreateCloseMessage(1000, longBody);
|
||||
msg.Length.ShouldBeLessThanOrEqualTo(WsConstants.MaxControlPayloadSize);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_ClientClosed_NormalClosure()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.ClientClosed)
|
||||
.ShouldBe(WsConstants.CloseStatusNormalClosure);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_AuthTimeout_PolicyViolation()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.AuthenticationTimeout)
|
||||
.ShouldBe(WsConstants.CloseStatusPolicyViolation);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_ParseError_ProtocolError()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.ParseError)
|
||||
.ShouldBe(WsConstants.CloseStatusProtocolError);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_MaxPayload_MessageTooBig()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.MaxPayloadExceeded)
|
||||
.ShouldBe(WsConstants.CloseStatusMessageTooBig);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildControlFrame_PingNomask()
|
||||
{
|
||||
var frame = WsFrameWriter.BuildControlFrame(WsConstants.PingMessage, [], useMasking: false);
|
||||
frame.Length.ShouldBe(2);
|
||||
(frame[0] & WsConstants.FinalBit).ShouldNotBe(0);
|
||||
(frame[0] & 0x0F).ShouldBe(WsConstants.PingMessage);
|
||||
(frame[1] & 0x7F).ShouldBe(0);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void BuildControlFrame_PongWithPayload()
|
||||
{
|
||||
byte[] payload = [1, 2, 3, 4];
|
||||
var frame = WsFrameWriter.BuildControlFrame(WsConstants.PongMessage, payload, useMasking: false);
|
||||
frame.Length.ShouldBe(2 + 4);
|
||||
frame[2..].ShouldBe(payload);
|
||||
}
|
||||
}
|
||||
782
tests/NATS.Server.Transport.Tests/WebSocket/WsGoParityTests.cs
Normal file
782
tests/NATS.Server.Transport.Tests/WebSocket/WsGoParityTests.cs
Normal file
@@ -0,0 +1,782 @@
|
||||
// Port of Go server/websocket_test.go — WebSocket protocol parity tests.
|
||||
// Reference: golang/nats-server/server/websocket_test.go
|
||||
//
|
||||
// Tests cover: compression negotiation, JWT auth extraction (bearer/cookie/query),
|
||||
// frame encoding/decoding, origin checking, upgrade handshake, and close messages.
|
||||
|
||||
using System.Buffers.Binary;
|
||||
using System.Text;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// Parity tests ported from Go server/websocket_test.go exercising WebSocket
|
||||
/// frame encoding, compression negotiation, origin checking, upgrade validation,
|
||||
/// and JWT authentication extraction.
|
||||
/// </summary>
|
||||
public class WsGoParityTests
|
||||
{
|
||||
// ========================================================================
|
||||
// TestWSIsControlFrame
|
||||
// Go reference: websocket_test.go:TestWSIsControlFrame
|
||||
// ========================================================================
|
||||
|
||||
[Theory]
|
||||
[InlineData(WsConstants.CloseMessage, true)]
|
||||
[InlineData(WsConstants.PingMessage, true)]
|
||||
[InlineData(WsConstants.PongMessage, true)]
|
||||
[InlineData(WsConstants.TextMessage, false)]
|
||||
[InlineData(WsConstants.BinaryMessage, false)]
|
||||
[InlineData(WsConstants.ContinuationFrame, false)]
|
||||
public void IsControlFrame_CorrectClassification(int opcode, bool expected)
|
||||
{
|
||||
// Go: TestWSIsControlFrame websocket_test.go
|
||||
WsConstants.IsControlFrame(opcode).ShouldBe(expected);
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// TestWSUnmask
|
||||
// Go reference: websocket_test.go:TestWSUnmask
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public void Unmask_XorsWithKey()
|
||||
{
|
||||
// Go: TestWSUnmask — XOR unmasking with 4-byte key.
|
||||
var ri = new WsReadInfo(expectMask: true);
|
||||
var key = new byte[] { 0x12, 0x34, 0x56, 0x78 };
|
||||
ri.SetMaskKey(key);
|
||||
|
||||
var data = new byte[] { 0x12 ^ (byte)'H', 0x34 ^ (byte)'e', 0x56 ^ (byte)'l', 0x78 ^ (byte)'l', 0x12 ^ (byte)'o' };
|
||||
ri.Unmask(data);
|
||||
|
||||
Encoding.ASCII.GetString(data).ShouldBe("Hello");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Unmask_LargeBuffer_UsesOptimizedPath()
|
||||
{
|
||||
// Go: TestWSUnmask — optimized 8-byte chunk path for larger buffers.
|
||||
var ri = new WsReadInfo(expectMask: true);
|
||||
var key = new byte[] { 0xAA, 0xBB, 0xCC, 0xDD };
|
||||
ri.SetMaskKey(key);
|
||||
|
||||
// Create a buffer large enough to trigger the optimized path (>= 16 bytes)
|
||||
var original = new byte[32];
|
||||
for (int i = 0; i < original.Length; i++)
|
||||
original[i] = (byte)(i + 1);
|
||||
|
||||
// Mask it
|
||||
var masked = new byte[original.Length];
|
||||
for (int i = 0; i < masked.Length; i++)
|
||||
masked[i] = (byte)(original[i] ^ key[i % 4]);
|
||||
|
||||
// Unmask
|
||||
ri.Unmask(masked);
|
||||
masked.ShouldBe(original);
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// TestWSCreateCloseMessage
|
||||
// Go reference: websocket_test.go:TestWSCreateCloseMessage
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public void CreateCloseMessage_StatusAndBody()
|
||||
{
|
||||
// Go: TestWSCreateCloseMessage — close message has 2-byte status + body.
|
||||
var msg = WsFrameWriter.CreateCloseMessage(
|
||||
WsConstants.CloseStatusNormalClosure, "goodbye");
|
||||
|
||||
msg.Length.ShouldBeGreaterThan(2);
|
||||
var status = BinaryPrimitives.ReadUInt16BigEndian(msg);
|
||||
status.ShouldBe((ushort)WsConstants.CloseStatusNormalClosure);
|
||||
Encoding.UTF8.GetString(msg.AsSpan(2)).ShouldBe("goodbye");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateCloseMessage_LongBody_Truncated()
|
||||
{
|
||||
// Go: TestWSCreateCloseMessage — body truncated to MaxControlPayloadSize.
|
||||
var longBody = new string('x', 200);
|
||||
var msg = WsFrameWriter.CreateCloseMessage(
|
||||
WsConstants.CloseStatusGoingAway, longBody);
|
||||
|
||||
msg.Length.ShouldBeLessThanOrEqualTo(WsConstants.MaxControlPayloadSize);
|
||||
// Should end with "..."
|
||||
var body = Encoding.UTF8.GetString(msg.AsSpan(2));
|
||||
body.ShouldEndWith("...");
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// TestWSCreateFrameHeader
|
||||
// Go reference: websocket_test.go:TestWSCreateFrameHeader
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_SmallPayload_2ByteHeader()
|
||||
{
|
||||
// Go: TestWSCreateFrameHeader — payload <= 125 uses 2-byte header.
|
||||
var (header, key) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 50);
|
||||
|
||||
header.Length.ShouldBe(2);
|
||||
(header[0] & 0x0F).ShouldBe(WsConstants.BinaryMessage);
|
||||
(header[0] & WsConstants.FinalBit).ShouldBe(WsConstants.FinalBit);
|
||||
(header[1] & 0x7F).ShouldBe(50);
|
||||
key.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_MediumPayload_4ByteHeader()
|
||||
{
|
||||
// Go: TestWSCreateFrameHeader — payload 126-65535 uses 4-byte header.
|
||||
var (header, key) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 1000);
|
||||
|
||||
header.Length.ShouldBe(4);
|
||||
(header[1] & 0x7F).ShouldBe(126);
|
||||
var payloadLen = BinaryPrimitives.ReadUInt16BigEndian(header.AsSpan(2));
|
||||
payloadLen.ShouldBe((ushort)1000);
|
||||
key.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_LargePayload_10ByteHeader()
|
||||
{
|
||||
// Go: TestWSCreateFrameHeader — payload >= 65536 uses 10-byte header.
|
||||
var (header, key) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 100000);
|
||||
|
||||
header.Length.ShouldBe(10);
|
||||
(header[1] & 0x7F).ShouldBe(127);
|
||||
var payloadLen = BinaryPrimitives.ReadUInt64BigEndian(header.AsSpan(2));
|
||||
payloadLen.ShouldBe(100000UL);
|
||||
key.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_WithMasking_Adds4ByteKey()
|
||||
{
|
||||
// Go: TestWSCreateFrameHeader — masking adds 4-byte key to header.
|
||||
var (header, key) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: true, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 50);
|
||||
|
||||
header.Length.ShouldBe(6); // 2 base + 4 mask key
|
||||
(header[1] & WsConstants.MaskBit).ShouldBe(WsConstants.MaskBit);
|
||||
key.ShouldNotBeNull();
|
||||
key!.Length.ShouldBe(4);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CreateFrameHeader_Compressed_SetsRsv1()
|
||||
{
|
||||
// Go: TestWSCreateFrameHeader — compressed frames have RSV1 bit set.
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: false, compressed: true,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: 50);
|
||||
|
||||
(header[0] & WsConstants.Rsv1Bit).ShouldBe(WsConstants.Rsv1Bit);
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// TestWSCheckOrigin
|
||||
// Go reference: websocket_test.go:TestWSCheckOrigin
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public void OriginChecker_SameOrigin_Allowed()
|
||||
{
|
||||
// Go: TestWSCheckOrigin — same origin passes.
|
||||
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
|
||||
checker.CheckOrigin("http://localhost:4222", "localhost:4222", isTls: false).ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OriginChecker_SameOrigin_Rejected()
|
||||
{
|
||||
// Go: TestWSCheckOrigin — different origin fails.
|
||||
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
|
||||
var result = checker.CheckOrigin("http://evil.com", "localhost:4222", isTls: false);
|
||||
result.ShouldNotBeNull();
|
||||
result.ShouldContain("not same origin");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OriginChecker_AllowedList_Allowed()
|
||||
{
|
||||
// Go: TestWSCheckOrigin — allowed origins list.
|
||||
var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: ["http://example.com"]);
|
||||
checker.CheckOrigin("http://example.com", "localhost:4222", isTls: false).ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OriginChecker_AllowedList_Rejected()
|
||||
{
|
||||
// Go: TestWSCheckOrigin — origin not in allowed list.
|
||||
var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: ["http://example.com"]);
|
||||
var result = checker.CheckOrigin("http://evil.com", "localhost:4222", isTls: false);
|
||||
result.ShouldNotBeNull();
|
||||
result.ShouldContain("not in the allowed list");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OriginChecker_EmptyOrigin_Allowed()
|
||||
{
|
||||
// Go: TestWSCheckOrigin — empty origin (non-browser) is always allowed.
|
||||
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
|
||||
checker.CheckOrigin(null, "localhost:4222", isTls: false).ShouldBeNull();
|
||||
checker.CheckOrigin("", "localhost:4222", isTls: false).ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OriginChecker_NoRestrictions_AllAllowed()
|
||||
{
|
||||
// Go: no restrictions means all origins pass.
|
||||
var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: null);
|
||||
checker.CheckOrigin("http://anything.com", "localhost:4222", isTls: false).ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void OriginChecker_AllowedWithPort()
|
||||
{
|
||||
// Go: TestWSSetOriginOptions — origins with explicit ports.
|
||||
var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: ["http://example.com:8080"]);
|
||||
checker.CheckOrigin("http://example.com:8080", "localhost", isTls: false).ShouldBeNull();
|
||||
checker.CheckOrigin("http://example.com", "localhost", isTls: false).ShouldNotBeNull(); // wrong port
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// TestWSCompressNegotiation
|
||||
// Go reference: websocket_test.go:TestWSCompressNegotiation
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public void CompressNegotiation_FullParams()
|
||||
{
|
||||
// Go: TestWSCompressNegotiation — full parameter negotiation.
|
||||
var result = WsDeflateNegotiator.Negotiate(
|
||||
"permessage-deflate; server_no_context_takeover; client_no_context_takeover; server_max_window_bits=10; client_max_window_bits=12");
|
||||
|
||||
result.ShouldNotBeNull();
|
||||
result.Value.ServerNoContextTakeover.ShouldBeTrue();
|
||||
result.Value.ClientNoContextTakeover.ShouldBeTrue();
|
||||
result.Value.ServerMaxWindowBits.ShouldBe(10);
|
||||
result.Value.ClientMaxWindowBits.ShouldBe(12);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void CompressNegotiation_NoExtension_ReturnsNull()
|
||||
{
|
||||
// Go: TestWSCompressNegotiation — no permessage-deflate in header.
|
||||
WsDeflateNegotiator.Negotiate("x-webkit-deflate-frame").ShouldBeNull();
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// WS Upgrade — JWT extraction (bearer, cookie, query parameter)
|
||||
// Go reference: websocket_test.go:TestWSBasicAuth, TestWSBindToProperAccount
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_BearerJwt_ExtractedFromAuthHeader()
|
||||
{
|
||||
// Go: TestWSBasicAuth — JWT extracted from Authorization: Bearer header.
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
"Authorization: Bearer eyJhbGciOiJIUzI1NiJ9.test_jwt_token\r\n");
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Jwt.ShouldBe("eyJhbGciOiJIUzI1NiJ9.test_jwt_token");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_CookieJwt_ExtractedFromCookie()
|
||||
{
|
||||
// Go: TestWSBindToProperAccount — JWT extracted from cookie when configured.
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
"Cookie: jwt=eyJhbGciOiJIUzI1NiJ9.cookie_jwt; other=value\r\n");
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt" };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.CookieJwt.ShouldBe("eyJhbGciOiJIUzI1NiJ9.cookie_jwt");
|
||||
// Cookie JWT becomes fallback JWT
|
||||
result.Jwt.ShouldBe("eyJhbGciOiJIUzI1NiJ9.cookie_jwt");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_QueryJwt_ExtractedFromQueryParam()
|
||||
{
|
||||
// Go: JWT extracted from query parameter when no auth header or cookie.
|
||||
var request = BuildValidRequest(
|
||||
path: "/?jwt=eyJhbGciOiJIUzI1NiJ9.query_jwt");
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Jwt.ShouldBe("eyJhbGciOiJIUzI1NiJ9.query_jwt");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_JwtPriority_BearerOverCookieOverQuery()
|
||||
{
|
||||
// Go: Authorization header takes priority over cookie and query.
|
||||
var request = BuildValidRequest(
|
||||
path: "/?jwt=query_token",
|
||||
extraHeaders: "Authorization: Bearer bearer_token\r\nCookie: jwt=cookie_token\r\n");
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt" };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Jwt.ShouldBe("bearer_token");
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// TestWSXForwardedFor
|
||||
// Go reference: websocket_test.go:TestWSXForwardedFor
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_XForwardedFor_ExtractsClientIp()
|
||||
{
|
||||
// Go: TestWSXForwardedFor — X-Forwarded-For header extracts first IP.
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
"X-Forwarded-For: 192.168.1.100, 10.0.0.1\r\n");
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.ClientIp.ShouldBe("192.168.1.100");
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// TestWSUpgradeValidationErrors
|
||||
// Go reference: websocket_test.go:TestWSUpgradeValidationErrors
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_MissingHost_Fails()
|
||||
{
|
||||
// Go: TestWSUpgradeValidationErrors — missing Host header.
|
||||
var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_MissingUpgradeHeader_Fails()
|
||||
{
|
||||
// Go: TestWSUpgradeValidationErrors — missing Upgrade header.
|
||||
var request = "GET / HTTP/1.1\r\nHost: localhost:4222\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_MissingKey_Fails()
|
||||
{
|
||||
// Go: TestWSUpgradeValidationErrors — missing Sec-WebSocket-Key.
|
||||
var request = "GET / HTTP/1.1\r\nHost: localhost:4222\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\n\r\n";
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_WrongVersion_Fails()
|
||||
{
|
||||
// Go: TestWSUpgradeValidationErrors — wrong WebSocket version.
|
||||
var request = BuildValidRequest(versionOverride: "12");
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeFalse();
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// TestWSSetHeader
|
||||
// Go reference: websocket_test.go:TestWSSetHeader
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_CustomHeaders_IncludedInResponse()
|
||||
{
|
||||
// Go: TestWSSetHeader — custom headers added to upgrade response.
|
||||
var request = BuildValidRequest();
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions
|
||||
{
|
||||
NoTls = true,
|
||||
Headers = new Dictionary<string, string> { ["X-Custom"] = "test-value" },
|
||||
};
|
||||
await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
var response = ReadResponse(output);
|
||||
response.ShouldContain("X-Custom: test-value");
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// TestWSWebrowserClient
|
||||
// Go reference: websocket_test.go:TestWSWebrowserClient
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_BrowserUserAgent_DetectedAsBrowser()
|
||||
{
|
||||
// Go: TestWSWebrowserClient — Mozilla user-agent detected as browser.
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
"User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36\r\n");
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Browser.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_NonBrowserUserAgent_NotDetected()
|
||||
{
|
||||
// Go: non-browser user agent is not flagged.
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
"User-Agent: nats-client/1.0\r\n");
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Browser.ShouldBeFalse();
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// TestWSCompressionBasic
|
||||
// Go reference: websocket_test.go:TestWSCompressionBasic
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public void Compression_RoundTrip()
|
||||
{
|
||||
// Go: TestWSCompressionBasic — compress then decompress returns original.
|
||||
var original = "Hello, WebSocket compression test! This is a reasonably long string."u8.ToArray();
|
||||
|
||||
var compressed = WsCompression.Compress(original);
|
||||
var decompressed = WsCompression.Decompress([compressed], maxPayload: 1024 * 1024);
|
||||
|
||||
decompressed.ShouldBe(original);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Compression_SmallData_StillWorks()
|
||||
{
|
||||
// Go: even very small data can be compressed/decompressed.
|
||||
var original = "Hi"u8.ToArray();
|
||||
|
||||
var compressed = WsCompression.Compress(original);
|
||||
var decompressed = WsCompression.Decompress([compressed], maxPayload: 1024);
|
||||
|
||||
decompressed.ShouldBe(original);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Compression_EmptyData()
|
||||
{
|
||||
var compressed = WsCompression.Compress(ReadOnlySpan<byte>.Empty);
|
||||
var decompressed = WsCompression.Decompress([compressed], maxPayload: 1024);
|
||||
decompressed.ShouldBeEmpty();
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// TestWSDecompressLimit
|
||||
// Go reference: websocket_test.go:TestWSDecompressLimit
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public void Decompress_ExceedsMaxPayload_Throws()
|
||||
{
|
||||
// Go: TestWSDecompressLimit — decompressed data exceeding max payload throws.
|
||||
// Create data larger than the limit
|
||||
var large = new byte[10000];
|
||||
for (int i = 0; i < large.Length; i++) large[i] = (byte)(i % 256);
|
||||
|
||||
var compressed = WsCompression.Compress(large);
|
||||
|
||||
Should.Throw<InvalidOperationException>(() =>
|
||||
WsCompression.Decompress([compressed], maxPayload: 100));
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// MaskBuf / MaskBufs
|
||||
// Go reference: websocket_test.go TestWSFrameOutbound
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public void MaskBuf_XorsInPlace()
|
||||
{
|
||||
// Go: TestWSFrameOutbound — masking XORs buffer with key.
|
||||
var key = new byte[] { 0xAA, 0xBB, 0xCC, 0xDD };
|
||||
var data = new byte[] { 0x01, 0x02, 0x03, 0x04, 0x05 };
|
||||
var expected = new byte[] { 0x01 ^ 0xAA, 0x02 ^ 0xBB, 0x03 ^ 0xCC, 0x04 ^ 0xDD, 0x05 ^ 0xAA };
|
||||
|
||||
WsFrameWriter.MaskBuf(key, data);
|
||||
data.ShouldBe(expected);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MaskBuf_DoubleApply_RestoresOriginal()
|
||||
{
|
||||
// Go: masking is its own inverse.
|
||||
var key = new byte[] { 0x12, 0x34, 0x56, 0x78 };
|
||||
var original = "Hello World"u8.ToArray();
|
||||
var copy = original.ToArray();
|
||||
|
||||
WsFrameWriter.MaskBuf(key, copy);
|
||||
copy.ShouldNotBe(original);
|
||||
|
||||
WsFrameWriter.MaskBuf(key, copy);
|
||||
copy.ShouldBe(original);
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// MapCloseStatus
|
||||
// Go reference: websocket_test.go TestWSEnqueueCloseMsg
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_ClientClosed_NormalClosure()
|
||||
{
|
||||
// Go: TestWSEnqueueCloseMsg — client-initiated close maps to 1000.
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.ClientClosed)
|
||||
.ShouldBe(WsConstants.CloseStatusNormalClosure);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_AuthViolation_PolicyViolation()
|
||||
{
|
||||
// Go: TestWSEnqueueCloseMsg — auth violation maps to 1008.
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.AuthenticationViolation)
|
||||
.ShouldBe(WsConstants.CloseStatusPolicyViolation);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_ProtocolError_ProtocolError()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.ProtocolViolation)
|
||||
.ShouldBe(WsConstants.CloseStatusProtocolError);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_ServerShutdown_GoingAway()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.ServerShutdown)
|
||||
.ShouldBe(WsConstants.CloseStatusGoingAway);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapCloseStatus_MaxPayloadExceeded_MessageTooBig()
|
||||
{
|
||||
WsFrameWriter.MapCloseStatus(ClientClosedReason.MaxPayloadExceeded)
|
||||
.ShouldBe(WsConstants.CloseStatusMessageTooBig);
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// WsUpgrade.ComputeAcceptKey
|
||||
// Go reference: websocket_test.go — RFC 6455 example
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public void ComputeAcceptKey_Rfc6455Example()
|
||||
{
|
||||
// RFC 6455 Section 4.2.2 example
|
||||
var accept = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
|
||||
accept.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// WsUpgrade — path-based client kind detection
|
||||
// Go reference: websocket_test.go TestWSWebrowserClient
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_LeafNodePath_DetectedAsLeaf()
|
||||
{
|
||||
var request = BuildValidRequest(path: "/leafnode");
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Kind.ShouldBe(WsClientKind.Leaf);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_MqttPath_DetectedAsMqtt()
|
||||
{
|
||||
var request = BuildValidRequest(path: "/mqtt");
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Kind.ShouldBe(WsClientKind.Mqtt);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_RootPath_DetectedAsClient()
|
||||
{
|
||||
var request = BuildValidRequest(path: "/");
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Kind.ShouldBe(WsClientKind.Client);
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// WsUpgrade — cookie extraction
|
||||
// Go reference: websocket_test.go TestWSNoAuthUserValidation
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_Cookies_Extracted()
|
||||
{
|
||||
// Go: TestWSNoAuthUserValidation — username/password/token from cookies.
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
"Cookie: nats_user=admin; nats_pass=secret; nats_token=tok123\r\n");
|
||||
var (input, output) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions
|
||||
{
|
||||
NoTls = true,
|
||||
UsernameCookie = "nats_user",
|
||||
PasswordCookie = "nats_pass",
|
||||
TokenCookie = "nats_token",
|
||||
};
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.CookieUsername.ShouldBe("admin");
|
||||
result.CookiePassword.ShouldBe("secret");
|
||||
result.CookieToken.ShouldBe("tok123");
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// ExtractBearerToken
|
||||
// Go reference: websocket_test.go — bearer token extraction
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public void ExtractBearerToken_WithPrefix()
|
||||
{
|
||||
WsUpgrade.ExtractBearerToken("Bearer my-token").ShouldBe("my-token");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ExtractBearerToken_WithoutPrefix()
|
||||
{
|
||||
WsUpgrade.ExtractBearerToken("my-token").ShouldBe("my-token");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ExtractBearerToken_Empty_ReturnsNull()
|
||||
{
|
||||
WsUpgrade.ExtractBearerToken("").ShouldBeNull();
|
||||
WsUpgrade.ExtractBearerToken(null).ShouldBeNull();
|
||||
WsUpgrade.ExtractBearerToken(" ").ShouldBeNull();
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// ParseQueryString
|
||||
// Go reference: websocket_test.go — query parameter parsing
|
||||
// ========================================================================
|
||||
|
||||
[Fact]
|
||||
public void ParseQueryString_MultipleParams()
|
||||
{
|
||||
var result = WsUpgrade.ParseQueryString("?jwt=abc&user=admin&pass=secret");
|
||||
|
||||
result["jwt"].ShouldBe("abc");
|
||||
result["user"].ShouldBe("admin");
|
||||
result["pass"].ShouldBe("secret");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParseQueryString_UrlEncoded()
|
||||
{
|
||||
var result = WsUpgrade.ParseQueryString("?key=hello%20world");
|
||||
result["key"].ShouldBe("hello world");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParseQueryString_NoQuestionMark()
|
||||
{
|
||||
var result = WsUpgrade.ParseQueryString("jwt=token123");
|
||||
result["jwt"].ShouldBe("token123");
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Helpers
|
||||
// ========================================================================
|
||||
|
||||
private static string BuildValidRequest(string path = "/", string? extraHeaders = null, string? versionOverride = null)
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
sb.Append($"GET {path} HTTP/1.1\r\n");
|
||||
sb.Append("Host: localhost:4222\r\n");
|
||||
sb.Append("Upgrade: websocket\r\n");
|
||||
sb.Append("Connection: Upgrade\r\n");
|
||||
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
|
||||
sb.Append($"Sec-WebSocket-Version: {versionOverride ?? "13"}\r\n");
|
||||
if (extraHeaders != null)
|
||||
sb.Append(extraHeaders);
|
||||
sb.Append("\r\n");
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
|
||||
{
|
||||
var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
|
||||
return (new MemoryStream(inputBytes), new MemoryStream());
|
||||
}
|
||||
|
||||
private static string ReadResponse(MemoryStream output)
|
||||
{
|
||||
output.Position = 0;
|
||||
return Encoding.ASCII.GetString(output.ToArray());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
using System.Buffers.Binary;
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
public class WsIntegrationTests : IAsyncLifetime
|
||||
{
|
||||
private NatsServer _server = null!;
|
||||
private NatsOptions _options = null!;
|
||||
|
||||
public async Task InitializeAsync()
|
||||
{
|
||||
_options = new NatsOptions
|
||||
{
|
||||
Port = 0,
|
||||
WebSocket = new WebSocketOptions { Port = 0, NoTls = true },
|
||||
};
|
||||
var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(b => { });
|
||||
_server = new NatsServer(_options, loggerFactory);
|
||||
_ = _server.StartAsync(CancellationToken.None);
|
||||
await _server.WaitForReadyAsync();
|
||||
}
|
||||
|
||||
public async Task DisposeAsync()
|
||||
{
|
||||
await _server.ShutdownAsync();
|
||||
_server.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WebSocket_ConnectAndReceiveInfo()
|
||||
{
|
||||
using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port));
|
||||
using var stream = new NetworkStream(socket, ownsSocket: false);
|
||||
|
||||
await SendUpgradeRequest(stream);
|
||||
var response = await ReadHttpResponse(stream);
|
||||
response.ShouldContain("101");
|
||||
|
||||
var wsFrame = await ReadWsFrame(stream);
|
||||
var info = Encoding.ASCII.GetString(wsFrame);
|
||||
info.ShouldStartWith("INFO ");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WebSocket_ConnectAndPing()
|
||||
{
|
||||
using var client = await ConnectWsClient();
|
||||
|
||||
// Send CONNECT and PING together
|
||||
await SendWsText(client, "CONNECT {}\r\nPING\r\n");
|
||||
|
||||
// Read PONG WS frame
|
||||
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
|
||||
var pong = await ReadWsFrameAsync(client, cts.Token);
|
||||
Encoding.ASCII.GetString(pong).ShouldContain("PONG");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WebSocket_PubSub()
|
||||
{
|
||||
using var sub = await ConnectWsClient();
|
||||
using var pub = await ConnectWsClient();
|
||||
|
||||
await SendWsText(sub, "CONNECT {}\r\nSUB test.ws 1\r\nPING\r\n");
|
||||
|
||||
// Wait for PONG to confirm subscription is registered
|
||||
using var subCts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
|
||||
var pong = await ReadWsFrameAsync(sub, subCts.Token);
|
||||
Encoding.ASCII.GetString(pong).ShouldContain("PONG");
|
||||
|
||||
await SendWsText(pub, "CONNECT {}\r\nPUB test.ws 5\r\nHello\r\n");
|
||||
|
||||
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
|
||||
var msg = await ReadWsFrameAsync(sub, cts.Token);
|
||||
Encoding.ASCII.GetString(msg).ShouldContain("MSG test.ws 1 5");
|
||||
}
|
||||
|
||||
private async Task<NetworkStream> ConnectWsClient()
|
||||
{
|
||||
var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
|
||||
await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port));
|
||||
var stream = new NetworkStream(socket, ownsSocket: true);
|
||||
|
||||
await SendUpgradeRequest(stream);
|
||||
var response = await ReadHttpResponse(stream);
|
||||
response.ShouldContain("101");
|
||||
|
||||
await ReadWsFrame(stream); // Read INFO frame
|
||||
return stream;
|
||||
}
|
||||
|
||||
private static async Task SendUpgradeRequest(NetworkStream stream)
|
||||
{
|
||||
var keyBytes = new byte[16];
|
||||
RandomNumberGenerator.Fill(keyBytes);
|
||||
var key = Convert.ToBase64String(keyBytes);
|
||||
|
||||
var request = $"GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {key}\r\nSec-WebSocket-Version: 13\r\n\r\n";
|
||||
await stream.WriteAsync(Encoding.ASCII.GetBytes(request));
|
||||
await stream.FlushAsync();
|
||||
}
|
||||
|
||||
private static async Task<string> ReadHttpResponse(NetworkStream stream)
|
||||
{
|
||||
// Read one byte at a time to avoid consuming WS frame bytes that follow the HTTP response
|
||||
var sb = new StringBuilder();
|
||||
var buf = new byte[1];
|
||||
while (true)
|
||||
{
|
||||
int n = await stream.ReadAsync(buf);
|
||||
if (n == 0) break;
|
||||
sb.Append((char)buf[0]);
|
||||
if (sb.Length >= 4 &&
|
||||
sb[^4] == '\r' && sb[^3] == '\n' &&
|
||||
sb[^2] == '\r' && sb[^1] == '\n')
|
||||
break;
|
||||
}
|
||||
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
private static Task<byte[]> ReadWsFrame(NetworkStream stream)
|
||||
=> ReadWsFrameAsync(stream, CancellationToken.None);
|
||||
|
||||
private static async Task<byte[]> ReadWsFrameAsync(NetworkStream stream, CancellationToken ct)
|
||||
{
|
||||
var header = new byte[2];
|
||||
await stream.ReadExactlyAsync(header, ct);
|
||||
int len = header[1] & 0x7F;
|
||||
if (len == 126)
|
||||
{
|
||||
var extLen = new byte[2];
|
||||
await stream.ReadExactlyAsync(extLen, ct);
|
||||
len = BinaryPrimitives.ReadUInt16BigEndian(extLen);
|
||||
}
|
||||
else if (len == 127)
|
||||
{
|
||||
var extLen = new byte[8];
|
||||
await stream.ReadExactlyAsync(extLen, ct);
|
||||
len = (int)BinaryPrimitives.ReadUInt64BigEndian(extLen);
|
||||
}
|
||||
|
||||
var payload = new byte[len];
|
||||
if (len > 0) await stream.ReadExactlyAsync(payload, ct);
|
||||
return payload;
|
||||
}
|
||||
|
||||
private static async Task SendWsText(NetworkStream stream, string text)
|
||||
{
|
||||
var payload = Encoding.ASCII.GetBytes(text);
|
||||
var (header, _) = WsFrameWriter.CreateFrameHeader(
|
||||
useMasking: true, compressed: false,
|
||||
opcode: WsConstants.BinaryMessage, payloadLength: payload.Length);
|
||||
var maskKey = header[^4..];
|
||||
WsFrameWriter.MaskBuf(maskKey, payload);
|
||||
await stream.WriteAsync(header);
|
||||
await stream.WriteAsync(payload);
|
||||
await stream.FlushAsync();
|
||||
}
|
||||
}
|
||||
316
tests/NATS.Server.Transport.Tests/WebSocket/WsJwtAuthTests.cs
Normal file
316
tests/NATS.Server.Transport.Tests/WebSocket/WsJwtAuthTests.cs
Normal file
@@ -0,0 +1,316 @@
|
||||
// Tests for WebSocket JWT authentication during upgrade (E11).
|
||||
// Verifies JWT extraction from Authorization header, cookie, and query parameter.
|
||||
// Reference: golang/nats-server/server/websocket.go — cookie JWT extraction (line 856),
|
||||
// websocket_test.go — TestWSReloadTLSConfig (line 4066).
|
||||
|
||||
using System.Text;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
public class WsJwtAuthTests
|
||||
{
|
||||
// ─── Authorization header JWT extraction ─────────────────────────────
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_AuthorizationBearerHeader_ExtractsJwt()
|
||||
{
|
||||
// JWT from Authorization: Bearer <token> header (standard HTTP auth)
|
||||
var jwt = "eyJhbGciOiJFZDI1NTE5IiwidHlwIjoiSldUIn0.test-payload.test-sig";
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
$"Authorization: Bearer {jwt}\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Jwt.ShouldBe(jwt);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_AuthorizationBearerCaseInsensitive()
|
||||
{
|
||||
// RFC 7235: "bearer" scheme is case-insensitive
|
||||
var jwt = "my-jwt-token-123";
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
$"Authorization: bearer {jwt}\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Jwt.ShouldBe(jwt);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_AuthorizationBareToken_ExtractsJwt()
|
||||
{
|
||||
// Some clients send the token directly without "Bearer" prefix
|
||||
var jwt = "raw-jwt-token-456";
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
$"Authorization: {jwt}\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Jwt.ShouldBe(jwt);
|
||||
}
|
||||
|
||||
// ─── Cookie JWT extraction ──────────────────────────────────────────
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_JwtCookie_ExtractsJwt()
|
||||
{
|
||||
// Go parity: websocket.go line 856 — JWT from configured cookie name
|
||||
var jwt = "cookie-jwt-token-789";
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
$"Cookie: jwt={jwt}; other=value\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt" };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.CookieJwt.ShouldBe(jwt);
|
||||
// Cookie JWT is used as fallback when no Authorization header is present
|
||||
result.Jwt.ShouldBe(jwt);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_AuthorizationHeader_TakesPriorityOverCookie()
|
||||
{
|
||||
// Authorization header has higher priority than cookie
|
||||
var headerJwt = "auth-header-jwt";
|
||||
var cookieJwt = "cookie-jwt";
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
$"Authorization: Bearer {headerJwt}\r\n" +
|
||||
$"Cookie: jwt={cookieJwt}\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt" };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Jwt.ShouldBe(headerJwt);
|
||||
result.CookieJwt.ShouldBe(cookieJwt); // Cookie value is still preserved
|
||||
}
|
||||
|
||||
// ─── Query parameter JWT extraction ─────────────────────────────────
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_QueryParamJwt_ExtractsJwt()
|
||||
{
|
||||
// JWT from ?jwt= query parameter (useful for browser clients)
|
||||
var jwt = "query-jwt-token-abc";
|
||||
var request = BuildValidRequest(path: $"/?jwt={jwt}");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Jwt.ShouldBe(jwt);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_QueryParamJwt_UrlEncoded()
|
||||
{
|
||||
// JWT value may be URL-encoded
|
||||
var jwt = "eyJ0eXAiOiJKV1QifQ.payload.sig";
|
||||
var encoded = Uri.EscapeDataString(jwt);
|
||||
var request = BuildValidRequest(path: $"/?jwt={encoded}");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Jwt.ShouldBe(jwt);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_AuthorizationHeader_TakesPriorityOverQueryParam()
|
||||
{
|
||||
// Authorization header > query parameter
|
||||
var headerJwt = "auth-header-jwt";
|
||||
var queryJwt = "query-jwt";
|
||||
var request = BuildValidRequest(
|
||||
path: $"/?jwt={queryJwt}",
|
||||
extraHeaders: $"Authorization: Bearer {headerJwt}\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Jwt.ShouldBe(headerJwt);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_Cookie_TakesPriorityOverQueryParam()
|
||||
{
|
||||
// Cookie > query parameter
|
||||
var cookieJwt = "cookie-jwt";
|
||||
var queryJwt = "query-jwt";
|
||||
var request = BuildValidRequest(
|
||||
path: $"/?jwt={queryJwt}",
|
||||
extraHeaders: $"Cookie: jwt_token={cookieJwt}\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions { NoTls = true, JwtCookie = "jwt_token" };
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Jwt.ShouldBe(cookieJwt);
|
||||
}
|
||||
|
||||
// ─── No JWT scenarios ───────────────────────────────────────────────
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_NoJwtAnywhere_JwtIsNull()
|
||||
{
|
||||
// No JWT in any source
|
||||
var request = BuildValidRequest();
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Jwt.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_EmptyAuthorizationHeader_JwtIsEmpty()
|
||||
{
|
||||
// Empty authorization header should produce empty string (non-null)
|
||||
var request = BuildValidRequest(extraHeaders: "Authorization: \r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
// Empty auth header is treated as null/no JWT
|
||||
result.Jwt.ShouldBeNull();
|
||||
}
|
||||
|
||||
// ─── ExtractBearerToken unit tests ──────────────────────────────────
|
||||
|
||||
[Fact]
|
||||
public void ExtractBearerToken_BearerPrefix()
|
||||
{
|
||||
WsUpgrade.ExtractBearerToken("Bearer my-token").ShouldBe("my-token");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ExtractBearerToken_BearerPrefixLowerCase()
|
||||
{
|
||||
WsUpgrade.ExtractBearerToken("bearer my-token").ShouldBe("my-token");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ExtractBearerToken_BareToken()
|
||||
{
|
||||
WsUpgrade.ExtractBearerToken("raw-token").ShouldBe("raw-token");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ExtractBearerToken_Null()
|
||||
{
|
||||
WsUpgrade.ExtractBearerToken(null).ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ExtractBearerToken_Empty()
|
||||
{
|
||||
WsUpgrade.ExtractBearerToken("").ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ExtractBearerToken_Whitespace()
|
||||
{
|
||||
WsUpgrade.ExtractBearerToken(" ").ShouldBeNull();
|
||||
}
|
||||
|
||||
// ─── ParseQueryString unit tests ────────────────────────────────────
|
||||
|
||||
[Fact]
|
||||
public void ParseQueryString_SingleParam()
|
||||
{
|
||||
var result = WsUpgrade.ParseQueryString("?jwt=token123");
|
||||
result["jwt"].ShouldBe("token123");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParseQueryString_MultipleParams()
|
||||
{
|
||||
var result = WsUpgrade.ParseQueryString("?jwt=token&user=admin");
|
||||
result["jwt"].ShouldBe("token");
|
||||
result["user"].ShouldBe("admin");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParseQueryString_UrlEncoded()
|
||||
{
|
||||
var result = WsUpgrade.ParseQueryString("?jwt=a%20b%3Dc");
|
||||
result["jwt"].ShouldBe("a b=c");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParseQueryString_NoQuestionMark()
|
||||
{
|
||||
var result = WsUpgrade.ParseQueryString("jwt=token");
|
||||
result["jwt"].ShouldBe("token");
|
||||
}
|
||||
|
||||
// ─── FailUnauthorizedAsync ──────────────────────────────────────────
|
||||
|
||||
[Fact]
|
||||
public async Task FailUnauthorizedAsync_Returns401()
|
||||
{
|
||||
var output = new MemoryStream();
|
||||
var result = await WsUpgrade.FailUnauthorizedAsync(output, "invalid JWT");
|
||||
|
||||
result.Success.ShouldBeFalse();
|
||||
output.Position = 0;
|
||||
var response = Encoding.ASCII.GetString(output.ToArray());
|
||||
response.ShouldContain("401");
|
||||
response.ShouldContain("invalid JWT");
|
||||
}
|
||||
|
||||
// ─── Query param path routing still works with query strings ────────
|
||||
|
||||
[Fact]
|
||||
public async Task Upgrade_PathWithQueryParam_StillRoutesCorrectly()
|
||||
{
|
||||
// /leafnode?jwt=token should still detect as leaf kind
|
||||
var request = BuildValidRequest(path: "/leafnode?jwt=my-token");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Kind.ShouldBe(WsClientKind.Leaf);
|
||||
result.Jwt.ShouldBe("my-token");
|
||||
}
|
||||
|
||||
// ─── Helpers ─────────────────────────────────────────────────────────
|
||||
|
||||
private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
sb.Append($"GET {path} HTTP/1.1\r\n");
|
||||
sb.Append("Host: localhost:4222\r\n");
|
||||
sb.Append("Upgrade: websocket\r\n");
|
||||
sb.Append("Connection: Upgrade\r\n");
|
||||
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
|
||||
sb.Append("Sec-WebSocket-Version: 13\r\n");
|
||||
if (extraHeaders != null)
|
||||
sb.Append(extraHeaders);
|
||||
sb.Append("\r\n");
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
|
||||
{
|
||||
var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
|
||||
return (new MemoryStream(inputBytes), new MemoryStream());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
using NATS.Server.WebSocket;
|
||||
using Shouldly;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
public class WsOriginCheckerTests
|
||||
{
|
||||
[Fact]
|
||||
public void NoOriginHeader_Accepted()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
|
||||
checker.CheckOrigin(origin: null, requestHost: "localhost:4222", isTls: false)
|
||||
.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void NeitherSameNorList_AlwaysAccepted()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: false, allowedOrigins: null);
|
||||
checker.CheckOrigin("https://evil.com", "localhost:4222", false)
|
||||
.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SameOrigin_Match()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
|
||||
checker.CheckOrigin("http://localhost:4222", "localhost:4222", false)
|
||||
.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SameOrigin_Mismatch()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
|
||||
checker.CheckOrigin("http://other:4222", "localhost:4222", false)
|
||||
.ShouldNotBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SameOrigin_DefaultPort_Http()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
|
||||
checker.CheckOrigin("http://localhost", "localhost:80", false)
|
||||
.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void SameOrigin_DefaultPort_Https()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: true, allowedOrigins: null);
|
||||
checker.CheckOrigin("https://localhost", "localhost:443", true)
|
||||
.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void AllowedOrigins_Match()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: false,
|
||||
allowedOrigins: ["https://app.example.com"]);
|
||||
checker.CheckOrigin("https://app.example.com", "localhost:4222", false)
|
||||
.ShouldBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void AllowedOrigins_Mismatch()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: false,
|
||||
allowedOrigins: ["https://app.example.com"]);
|
||||
checker.CheckOrigin("https://evil.example.com", "localhost:4222", false)
|
||||
.ShouldNotBeNull();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void AllowedOrigins_SchemeMismatch()
|
||||
{
|
||||
var checker = new WsOriginChecker(sameOrigin: false,
|
||||
allowedOrigins: ["https://app.example.com"]);
|
||||
checker.CheckOrigin("http://app.example.com", "localhost:4222", false)
|
||||
.ShouldNotBeNull();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
using System.Text;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
public class WsUpgradeHelperParityBatch1Tests
|
||||
{
|
||||
[Fact]
|
||||
public void MakeChallengeKey_returns_base64_of_16_random_bytes()
|
||||
{
|
||||
var key = WsUpgrade.MakeChallengeKey();
|
||||
var decoded = Convert.FromBase64String(key);
|
||||
|
||||
decoded.Length.ShouldBe(16);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Url_helpers_match_ws_and_wss_schemes()
|
||||
{
|
||||
WsUpgrade.IsWsUrl("ws://localhost:8080").ShouldBeTrue();
|
||||
WsUpgrade.IsWsUrl("wss://localhost:8443").ShouldBeFalse();
|
||||
WsUpgrade.IsWsUrl("http://localhost").ShouldBeFalse();
|
||||
|
||||
WsUpgrade.IsWssUrl("wss://localhost:8443").ShouldBeTrue();
|
||||
WsUpgrade.IsWssUrl("ws://localhost:8080").ShouldBeFalse();
|
||||
WsUpgrade.IsWssUrl("https://localhost").ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RejectNoMaskingForTest_forces_no_masking_handshake_rejection()
|
||||
{
|
||||
var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n");
|
||||
using var input = new MemoryStream(Encoding.ASCII.GetBytes(request));
|
||||
using var output = new MemoryStream();
|
||||
|
||||
try
|
||||
{
|
||||
WsUpgrade.RejectNoMaskingForTest = true;
|
||||
var result = await WsUpgrade.TryUpgradeAsync(input, output, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeFalse();
|
||||
output.Position = 0;
|
||||
var response = Encoding.ASCII.GetString(output.ToArray());
|
||||
response.ShouldContain("400 Bad Request");
|
||||
response.ShouldContain("invalid value for no-masking");
|
||||
}
|
||||
finally
|
||||
{
|
||||
WsUpgrade.RejectNoMaskingForTest = false;
|
||||
}
|
||||
}
|
||||
|
||||
private static string BuildValidRequest(string path = "/", string extraHeaders = "")
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
sb.Append($"GET {path} HTTP/1.1\r\n");
|
||||
sb.Append("Host: localhost:8080\r\n");
|
||||
sb.Append("Upgrade: websocket\r\n");
|
||||
sb.Append("Connection: Upgrade\r\n");
|
||||
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
|
||||
sb.Append("Sec-WebSocket-Version: 13\r\n");
|
||||
sb.Append(extraHeaders);
|
||||
sb.Append("\r\n");
|
||||
return sb.ToString();
|
||||
}
|
||||
}
|
||||
226
tests/NATS.Server.Transport.Tests/WebSocket/WsUpgradeTests.cs
Normal file
226
tests/NATS.Server.Transport.Tests/WebSocket/WsUpgradeTests.cs
Normal file
@@ -0,0 +1,226 @@
|
||||
using System.Text;
|
||||
using NATS.Server.WebSocket;
|
||||
|
||||
namespace NATS.Server.Transport.Tests.WebSocket;
|
||||
|
||||
public class WsUpgradeTests
|
||||
{
|
||||
private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
sb.Append($"GET {path} HTTP/1.1\r\n");
|
||||
sb.Append("Host: localhost:4222\r\n");
|
||||
sb.Append("Upgrade: websocket\r\n");
|
||||
sb.Append("Connection: Upgrade\r\n");
|
||||
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
|
||||
sb.Append("Sec-WebSocket-Version: 13\r\n");
|
||||
if (extraHeaders != null)
|
||||
sb.Append(extraHeaders);
|
||||
sb.Append("\r\n");
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ValidUpgrade_Returns101()
|
||||
{
|
||||
var request = BuildValidRequest();
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Kind.ShouldBe(WsClientKind.Client);
|
||||
var response = ReadResponse(outputStream);
|
||||
response.ShouldContain("HTTP/1.1 101");
|
||||
response.ShouldContain("Upgrade: websocket");
|
||||
response.ShouldContain("Sec-WebSocket-Accept:");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MissingUpgradeHeader_Returns400()
|
||||
{
|
||||
var request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeFalse();
|
||||
ReadResponse(outputStream).ShouldContain("400");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MissingHost_Returns400()
|
||||
{
|
||||
var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WrongVersion_Returns400()
|
||||
{
|
||||
var request = "GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 12\r\n\r\n";
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task LeafNodePath_ReturnsLeafKind()
|
||||
{
|
||||
var request = BuildValidRequest("/leafnode");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Kind.ShouldBe(WsClientKind.Leaf);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task MqttPath_ReturnsMqttKind()
|
||||
{
|
||||
var request = BuildValidRequest("/mqtt");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Kind.ShouldBe(WsClientKind.Mqtt);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CompressionNegotiation_WhenEnabled()
|
||||
{
|
||||
var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}; {WsConstants.PmcSrvNoCtx}; {WsConstants.PmcCliNoCtx}\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Compress.ShouldBeTrue();
|
||||
ReadResponse(outputStream).ShouldContain("permessage-deflate");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CompressionNegotiation_WhenDisabled()
|
||||
{
|
||||
var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = false });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Compress.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task NoMaskingHeader_ForLeaf()
|
||||
{
|
||||
var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.MaskRead.ShouldBeFalse();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task BrowserDetection_Mozilla()
|
||||
{
|
||||
var request = BuildValidRequest(extraHeaders: "User-Agent: Mozilla/5.0 (Windows)\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.Browser.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SafariDetection_NoCompFrag()
|
||||
{
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
"User-Agent: Mozilla/5.0 (Macintosh) Version/15.0 Safari/605.1.15\r\n" +
|
||||
$"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.NoCompFrag.ShouldBeTrue();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void AcceptKey_MatchesRfc6455Example()
|
||||
{
|
||||
// RFC 6455 Section 4.2.2 example
|
||||
var key = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
|
||||
key.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CookieExtraction()
|
||||
{
|
||||
var request = BuildValidRequest(extraHeaders:
|
||||
"Cookie: jwt_token=my-jwt; nats_user=admin; nats_pass=secret\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var opts = new WebSocketOptions
|
||||
{
|
||||
NoTls = true,
|
||||
JwtCookie = "jwt_token",
|
||||
UsernameCookie = "nats_user",
|
||||
PasswordCookie = "nats_pass",
|
||||
};
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.CookieJwt.ShouldBe("my-jwt");
|
||||
result.CookieUsername.ShouldBe("admin");
|
||||
result.CookiePassword.ShouldBe("secret");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task XForwardedFor_ExtractsClientIp()
|
||||
{
|
||||
var request = BuildValidRequest(extraHeaders: "X-Forwarded-For: 192.168.1.100\r\n");
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeTrue();
|
||||
result.ClientIp.ShouldBe("192.168.1.100");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task PostMethod_Returns405()
|
||||
{
|
||||
var request = "POST / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
|
||||
var (inputStream, outputStream) = CreateStreamPair(request);
|
||||
|
||||
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
|
||||
|
||||
result.Success.ShouldBeFalse();
|
||||
ReadResponse(outputStream).ShouldContain("405");
|
||||
}
|
||||
|
||||
// Helper: create a readable input stream and writable output stream
|
||||
private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
|
||||
{
|
||||
var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
|
||||
return (new MemoryStream(inputBytes), new MemoryStream());
|
||||
}
|
||||
|
||||
private static string ReadResponse(MemoryStream output)
|
||||
{
|
||||
output.Position = 0;
|
||||
return Encoding.ASCII.GetString(output.ToArray());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user