Add WebSocket listener support to NatsServer alongside the existing TCP listener. When WebSocketOptions.Port >= 0, the server binds a second socket, performs the HTTP upgrade via WsUpgrade.TryUpgradeAsync, wraps the connection in WsConnection for transparent framing/deframing, and hands it to the standard NatsClient pipeline.

Changes:
- NatsClient: add IsWebSocket and WsInfo properties
- NatsServer: add RunWebSocketAcceptLoopAsync and AcceptWebSocketClientAsync, and manage the WebSocket listener lifecycle in StartAsync/ShutdownAsync/Dispose
- NatsOptions: change the WebSocketOptions.Port default from 0 to -1 (disabled)
- WsConnection.ReadAsync: fix a premature end-of-stream when ReadFrames returns no payloads by looping until data is available
- Add WsIntegration tests (connect, ping, pub/sub over WebSocket)
- Add WsConnection masked-frame and end-of-stream unit tests
163 lines
5.5 KiB
C#
163 lines
5.5 KiB
C#
using System.Buffers.Binary;
|
|
using System.Net;
|
|
using System.Net.Sockets;
|
|
using System.Security.Cryptography;
|
|
using System.Text;
|
|
using NATS.Server.WebSocket;
|
|
|
|
namespace NATS.Server.Tests.WebSocket;
|
|
|
|
/// <summary>
/// End-to-end tests that drive a live <see cref="NatsServer"/> through its WebSocket listener
/// using a raw socket, a hand-written HTTP upgrade request and hand-built WebSocket frames.
/// </summary>
public class WsIntegrationTests : IAsyncLifetime
{
    // Upper bound for every network wait, so a misbehaving server fails the test instead of hanging the run.
    private static readonly TimeSpan IoTimeout = TimeSpan.FromSeconds(5);

    // RFC 6455 GUID appended to Sec-WebSocket-Key when deriving Sec-WebSocket-Accept.
    private const string WsAcceptGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    // Cap on the size of the upgrade response, so a peer that never sends CRLFCRLF cannot make us read forever.
    private const int MaxHttpResponseLength = 8 * 1024;

    private NatsServer _server = null!;
    private NatsOptions _options = null!;
    private Microsoft.Extensions.Logging.ILoggerFactory _loggerFactory = null!;

    /// <summary>
    /// Starts a server with TCP and WebSocket ports set to 0 and TLS disabled on the WebSocket
    /// listener, then waits until the server reports ready.
    /// </summary>
    public async Task InitializeAsync()
    {
        // NOTE(review): the tests connect to _options.WebSocket.Port after startup. This assumes the
        // server writes the actually-bound port back into the options when 0 is requested. Confirm in NatsServer.
        _options = new NatsOptions
        {
            Port = 0,
            WebSocket = new WebSocketOptions { Port = 0, NoTls = true },
        };

        _loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(_ => { });
        _server = new NatsServer(_options, _loggerFactory);

        // StartAsync is deliberately not awaited; readiness is signalled through WaitForReadyAsync.
        _ = _server.StartAsync(CancellationToken.None);
        await _server.WaitForReadyAsync();
    }

    /// <summary>Shuts the server down, disposes it, and releases the logger factory created for it.</summary>
    public async Task DisposeAsync()
    {
        await _server.ShutdownAsync();
        _server.Dispose();
        _loggerFactory.Dispose();
    }

    [Fact]
    public async Task WebSocket_ConnectAndReceiveInfo()
    {
        using var cts = new CancellationTokenSource(IoTimeout);
        using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port), cts.Token);
        using var stream = new NetworkStream(socket, ownsSocket: false);

        await PerformUpgradeAsync(stream, cts.Token);

        var info = Encoding.ASCII.GetString(await ReadWsFrameAsync(stream, cts.Token));
        info.ShouldStartWith("INFO ");
    }

    [Fact]
    public async Task WebSocket_ConnectAndPing()
    {
        using var client = await ConnectWsClient();
        using var cts = new CancellationTokenSource(IoTimeout);

        // CONNECT and PING travel together in a single masked frame.
        await SendWsText(client, "CONNECT {}\r\nPING\r\n");

        var pong = await ReadWsFrameAsync(client, cts.Token);
        Encoding.ASCII.GetString(pong).ShouldContain("PONG");
    }

    [Fact]
    public async Task WebSocket_PubSub()
    {
        using var sub = await ConnectWsClient();
        using var pub = await ConnectWsClient();
        using var cts = new CancellationTokenSource(IoTimeout);

        // A trailing PING acts as a barrier instead of a fixed sleep. Waiting for the PONG means the
        // preceding SUB has been handled before anything is published (the standard NATS flush idiom).
        // This relies on the server processing a connection's commands in order.
        await SendWsText(sub, "CONNECT {}\r\nSUB test.ws 1\r\nPING\r\n");
        var pong = await ReadWsFrameAsync(sub, cts.Token);
        Encoding.ASCII.GetString(pong).ShouldContain("PONG");

        await SendWsText(pub, "CONNECT {}\r\nPUB test.ws 5\r\nHello\r\n");

        var msg = await ReadWsFrameAsync(sub, cts.Token);
        Encoding.ASCII.GetString(msg).ShouldContain("MSG test.ws 1 5");
    }

    /// <summary>
    /// Opens a WebSocket connection to the server, completes the upgrade handshake and consumes the
    /// initial INFO frame.
    /// </summary>
    /// <returns>A stream that owns its socket; the caller must dispose it.</returns>
    private async Task<NetworkStream> ConnectWsClient()
    {
        using var cts = new CancellationTokenSource(IoTimeout);
        var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        NetworkStream? stream = null;
        try
        {
            await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port), cts.Token);
            stream = new NetworkStream(socket, ownsSocket: true);

            await PerformUpgradeAsync(stream, cts.Token);
            await ReadWsFrameAsync(stream, cts.Token); // discard the INFO frame
            return stream;
        }
        catch
        {
            // The caller never receives the stream on failure, so release it here.
            // Socket.Dispose is idempotent, which covers failures before the stream existed.
            stream?.Dispose();
            socket.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Sends the HTTP upgrade request, reads the response headers and asserts that the server switched
    /// protocols with the Sec-WebSocket-Accept value derived from our key.
    /// </summary>
    private static async Task PerformUpgradeAsync(NetworkStream stream, CancellationToken ct)
    {
        var key = await SendUpgradeRequest(stream, ct);
        var response = await ReadHttpResponse(stream, ct);

        response.ShouldContain("101");
        response.ShouldContain(ComputeAcceptKey(key));
    }

    /// <summary>Writes a WebSocket upgrade request carrying a fresh random Sec-WebSocket-Key.</summary>
    /// <returns>The base64 key that was sent, used to validate the server's accept value.</returns>
    private static async Task<string> SendUpgradeRequest(NetworkStream stream, CancellationToken ct)
    {
        var keyBytes = new byte[16];
        RandomNumberGenerator.Fill(keyBytes);
        var key = Convert.ToBase64String(keyBytes);

        var request = $"GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {key}\r\nSec-WebSocket-Version: 13\r\n\r\n";
        await stream.WriteAsync(Encoding.ASCII.GetBytes(request), ct);
        await stream.FlushAsync(ct);
        return key;
    }

    /// <summary>The Sec-WebSocket-Accept value: base64(SHA-1(key + RFC 6455 GUID)).</summary>
    private static string ComputeAcceptKey(string key) =>
        Convert.ToBase64String(SHA1.HashData(Encoding.ASCII.GetBytes(key + WsAcceptGuid)));

    /// <summary>
    /// Reads the HTTP response up to and including the blank line (CRLFCRLF) that ends the headers.
    /// </summary>
    /// <returns>
    /// The response text. It may be partial if the peer closed early or
    /// <see cref="MaxHttpResponseLength"/> was reached.
    /// </returns>
    private static async Task<string> ReadHttpResponse(NetworkStream stream, CancellationToken ct)
    {
        // Read one byte at a time to avoid consuming WS frame bytes that follow the HTTP response.
        var sb = new StringBuilder();
        var buf = new byte[1];
        while (sb.Length < MaxHttpResponseLength)
        {
            int n = await stream.ReadAsync(buf, ct);
            if (n == 0)
            {
                break; // peer closed before the headers were complete
            }

            sb.Append((char)buf[0]);
            if (sb.Length >= 4 &&
                sb[^4] == '\r' && sb[^3] == '\n' &&
                sb[^2] == '\r' && sb[^1] == '\n')
            {
                break;
            }
        }

        return sb.ToString();
    }

    /// <summary>
    /// Reads a single server-to-client WebSocket frame and returns its payload.
    /// FIN and opcode are not inspected.
    /// </summary>
    /// <exception cref="OverflowException">The 64-bit length field exceeds <see cref="int.MaxValue"/>.</exception>
    private static async Task<byte[]> ReadWsFrameAsync(NetworkStream stream, CancellationToken ct)
    {
        var header = new byte[2];
        await stream.ReadExactlyAsync(header, ct);

        // RFC 6455 section 5.1: a server must not mask frames it sends to a client.
        // A masked frame here would also make us misread the 4-byte mask key as payload.
        bool masked = (header[1] & 0x80) != 0;
        masked.ShouldBeFalse("server-to-client WebSocket frames must not be masked");

        int len = header[1] & 0x7F;
        if (len == 126)
        {
            var extLen = new byte[2];
            await stream.ReadExactlyAsync(extLen, ct);
            len = BinaryPrimitives.ReadUInt16BigEndian(extLen);
        }
        else if (len == 127)
        {
            var extLen = new byte[8];
            await stream.ReadExactlyAsync(extLen, ct);
            // checked: fail loudly rather than silently truncating an oversized length.
            len = checked((int)BinaryPrimitives.ReadUInt64BigEndian(extLen));
        }

        var payload = new byte[len];
        if (len > 0)
        {
            await stream.ReadExactlyAsync(payload, ct);
        }

        return payload;
    }

    /// <summary>
    /// Sends <paramref name="text"/> as a single masked binary frame, as RFC 6455 requires of clients.
    /// </summary>
    private static async Task SendWsText(NetworkStream stream, string text)
    {
        var payload = Encoding.ASCII.GetBytes(text);
        var (header, _) = WsFrameWriter.CreateFrameHeader(
            useMasking: true, compressed: false,
            opcode: WsConstants.BinaryMessage, payloadLength: payload.Length);

        // The mask key occupies the last four bytes of a masked frame header.
        var maskKey = header[^4..];
        WsFrameWriter.MaskBuf(maskKey, payload);

        await stream.WriteAsync(header);
        await stream.WriteAsync(payload);
        await stream.FlushAsync();
    }
}
|