Files
natsdotnet/tests/NATS.Server.Tests/WebSocket/WsIntegrationTests.cs
Joseph Doherty ca88036126 feat: integrate WebSocket accept loop into NatsServer and NatsClient
Add WebSocket listener support to NatsServer alongside the existing TCP
listener. When WebSocketOptions.Port >= 0, the server binds a second
socket, performs the HTTP upgrade via WsUpgrade.TryUpgradeAsync, wraps the
connection in WsConnection for transparent framing/deframing, and hands it
to the standard NatsClient pipeline.

Changes:
- NatsClient: add IsWebSocket and WsInfo properties
- NatsServer: add RunWebSocketAcceptLoopAsync and AcceptWebSocketClientAsync,
  WS listener lifecycle in StartAsync/ShutdownAsync/Dispose
- NatsOptions: change WebSocketOptions.Port default from 0 to -1 (disabled)
- WsConnection.ReadAsync: fix premature end-of-stream when ReadFrames
  returns no payloads by looping until data is available
- Add WsIntegration tests (connect, ping, pub/sub over WebSocket)
- Add WsConnection masked frame and end-of-stream unit tests
2026-02-23 05:16:57 -05:00

163 lines
5.5 KiB
C#

using System.Buffers.Binary;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using NATS.Server.WebSocket;
namespace NATS.Server.Tests.WebSocket;
/// <summary>
/// End-to-end tests for the server's WebSocket listener: HTTP upgrade, the initial
/// INFO frame, PING/PONG, and pub/sub carried inside WebSocket frames.
/// A fresh server (TCP and WebSocket both on ephemeral port 0) is started per test.
/// </summary>
public class WsIntegrationTests : IAsyncLifetime
{
    // Upper bound for any single network wait so a misbehaving server fails the
    // test instead of hanging the whole test run.
    private static readonly TimeSpan IoTimeout = TimeSpan.FromSeconds(5);

    private NatsServer _server = null!;
    private NatsOptions _options = null!;

    public async Task InitializeAsync()
    {
        _options = new NatsOptions
        {
            Port = 0,
            WebSocket = new WebSocketOptions { Port = 0, NoTls = true },
        };
        var loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(b => { });
        _server = new NatsServer(_options, loggerFactory);
        // StartAsync runs the accept loops; readiness is awaited separately below.
        _ = _server.StartAsync(CancellationToken.None);
        await _server.WaitForReadyAsync();
    }

    public async Task DisposeAsync()
    {
        await _server.ShutdownAsync();
        _server.Dispose();
    }

    [Fact]
    public async Task WebSocket_ConnectAndReceiveInfo()
    {
        using var cts = new CancellationTokenSource(IoTimeout);
        using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port), cts.Token);
        using var stream = new NetworkStream(socket, ownsSocket: false);

        await SendUpgradeRequest(stream);
        var response = await ReadHttpResponse(stream, cts.Token);
        response.ShouldContain("101");

        var wsFrame = await ReadWsFrameAsync(stream, cts.Token);
        var info = Encoding.ASCII.GetString(wsFrame);
        info.ShouldStartWith("INFO ");
    }

    [Fact]
    public async Task WebSocket_ConnectAndPing()
    {
        using var client = await ConnectWsClient();

        // Send CONNECT and PING together in one frame.
        await SendWsText(client, "CONNECT {}\r\nPING\r\n");

        using var cts = new CancellationTokenSource(IoTimeout);
        var pong = await ReadWsFrameAsync(client, cts.Token);
        Encoding.ASCII.GetString(pong).ShouldContain("PONG");
    }

    [Fact]
    public async Task WebSocket_PubSub()
    {
        using var sub = await ConnectWsClient();
        using var pub = await ConnectWsClient();
        using var cts = new CancellationTokenSource(IoTimeout);

        // The server handles a connection's commands in order, so receiving the PONG
        // proves the SUB is registered. This replaces a fixed Task.Delay that could
        // race on a slow machine.
        await SendWsText(sub, "CONNECT {}\r\nSUB test.ws 1\r\nPING\r\n");
        var pong = await ReadWsFrameAsync(sub, cts.Token);
        Encoding.ASCII.GetString(pong).ShouldContain("PONG");

        await SendWsText(pub, "CONNECT {}\r\nPUB test.ws 5\r\nHello\r\n");

        var msg = await ReadWsFrameAsync(sub, cts.Token);
        Encoding.ASCII.GetString(msg).ShouldContain("MSG test.ws 1 5");
    }

    /// <summary>
    /// Opens a TCP connection to the WebSocket listener, performs the HTTP upgrade,
    /// and consumes the initial INFO frame.
    /// </summary>
    /// <returns>A stream that owns the underlying socket; the caller must dispose it.</returns>
    private async Task<NetworkStream> ConnectWsClient()
    {
        using var cts = new CancellationTokenSource(IoTimeout);
        var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        NetworkStream? stream = null;
        try
        {
            await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port), cts.Token);
            stream = new NetworkStream(socket, ownsSocket: true);

            await SendUpgradeRequest(stream);
            var response = await ReadHttpResponse(stream, cts.Token);
            response.ShouldContain("101");

            await ReadWsFrameAsync(stream, cts.Token); // Read INFO frame
            return stream;
        }
        catch
        {
            // Don't leak the socket when the handshake or INFO read fails.
            if (stream is not null)
            {
                await stream.DisposeAsync();
            }
            else
            {
                socket.Dispose();
            }
            throw;
        }
    }

    /// <summary>Writes a minimal RFC 6455 client upgrade request with a random Sec-WebSocket-Key.</summary>
    private static async Task SendUpgradeRequest(NetworkStream stream)
    {
        var keyBytes = new byte[16];
        RandomNumberGenerator.Fill(keyBytes);
        var key = Convert.ToBase64String(keyBytes);
        var request = $"GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {key}\r\nSec-WebSocket-Version: 13\r\n\r\n";
        await stream.WriteAsync(Encoding.ASCII.GetBytes(request));
        await stream.FlushAsync();
    }

    /// <summary>
    /// Reads the HTTP response head up to and including the blank line (CRLFCRLF),
    /// or until the peer closes the connection.
    /// </summary>
    private static async Task<string> ReadHttpResponse(NetworkStream stream, CancellationToken ct = default)
    {
        // Read one byte at a time to avoid consuming WS frame bytes that follow the HTTP response.
        var sb = new StringBuilder();
        var buf = new byte[1];
        while (true)
        {
            int n = await stream.ReadAsync(buf, ct);
            if (n == 0)
            {
                break;
            }
            sb.Append((char)buf[0]);
            if (sb.Length >= 4 &&
                sb[^4] == '\r' && sb[^3] == '\n' &&
                sb[^2] == '\r' && sb[^1] == '\n')
            {
                break;
            }
        }
        return sb.ToString();
    }

    /// <summary>
    /// Reads a single WebSocket frame and returns its payload, unmasked if the MASK
    /// bit was set. The opcode and FIN bit are not inspected, so a message split
    /// across several frames is not reassembled.
    /// </summary>
    private static async Task<byte[]> ReadWsFrameAsync(NetworkStream stream, CancellationToken ct)
    {
        var header = new byte[2];
        await stream.ReadExactlyAsync(header, ct);

        bool masked = (header[1] & 0x80) != 0;
        int len = header[1] & 0x7F;
        if (len == 126)
        {
            var extLen = new byte[2];
            await stream.ReadExactlyAsync(extLen, ct);
            len = BinaryPrimitives.ReadUInt16BigEndian(extLen);
        }
        else if (len == 127)
        {
            var extLen = new byte[8];
            await stream.ReadExactlyAsync(extLen, ct);
            // checked: fail loudly instead of wrapping to a negative length.
            len = checked((int)BinaryPrimitives.ReadUInt64BigEndian(extLen));
        }

        // Per RFC 6455 §5.2 the 4-byte masking key follows the (extended) length.
        byte[]? maskKey = null;
        if (masked)
        {
            maskKey = new byte[4];
            await stream.ReadExactlyAsync(maskKey, ct);
        }

        var payload = new byte[len];
        if (len > 0)
        {
            await stream.ReadExactlyAsync(payload, ct);
        }

        if (maskKey is not null)
        {
            for (int i = 0; i < payload.Length; i++)
            {
                payload[i] ^= maskKey[i & 3];
            }
        }
        return payload;
    }

    /// <summary>
    /// Sends <paramref name="text"/> as one client-masked WebSocket frame.
    /// Despite the name, the frame uses the binary-message opcode (WsConstants.BinaryMessage).
    /// </summary>
    private static async Task SendWsText(NetworkStream stream, string text)
    {
        var payload = Encoding.ASCII.GetBytes(text);
        var (header, _) = WsFrameWriter.CreateFrameHeader(
            useMasking: true, compressed: false,
            opcode: WsConstants.BinaryMessage, payloadLength: payload.Length);
        // NOTE(review): assumes the 4-byte mask key is the tail of the returned header
        // when useMasking is true — confirm against WsFrameWriter.CreateFrameHeader.
        var maskKey = header[^4..];
        WsFrameWriter.MaskBuf(maskKey, payload);
        await stream.WriteAsync(header);
        await stream.WriteAsync(payload);
        await stream.FlushAsync();
    }
}