using System.Buffers.Binary;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using NATS.Server.WebSocket;

namespace NATS.Server.Tests.WebSocket;

/// <summary>
/// Integration tests that talk to a real <see cref="NatsServer"/> over its WebSocket
/// listener using a raw TCP socket and a hand-rolled HTTP upgrade / frame codec.
/// </summary>
public class WsIntegrationTests : IAsyncLifetime
{
    /// <summary>Upper bound for any single network step so a stuck server fails the test instead of hanging it.</summary>
    private static readonly TimeSpan IoTimeout = TimeSpan.FromSeconds(5);

    private NatsServer _server = null!;
    private NatsOptions _options = null!;
    private Microsoft.Extensions.Logging.ILoggerFactory _loggerFactory = null!;

    /// <summary>
    /// Starts a server with plain (non-TLS) WebSocket support and waits until it reports ready.
    /// </summary>
    public async Task InitializeAsync()
    {
        // Both ports are configured as 0; the tests later read the WebSocket port back
        // from _options (presumably the server writes the bound port there - confirm).
        _options = new NatsOptions
        {
            Port = 0,
            WebSocket = new WebSocketOptions { Port = 0, NoTls = true },
        };

        _loggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(b => { });
        _server = new NatsServer(_options, _loggerFactory);

        // StartAsync is intentionally not awaited here; readiness is observed
        // through WaitForReadyAsync instead.
        _ = _server.StartAsync(CancellationToken.None);
        await _server.WaitForReadyAsync();
    }

    /// <summary>
    /// Shuts the server down and releases the server and the logger factory.
    /// </summary>
    public async Task DisposeAsync()
    {
        await _server.ShutdownAsync();
        _server.Dispose();
        _loggerFactory.Dispose();
    }

    /// <summary>
    /// After a successful upgrade the first WebSocket frame must carry the NATS INFO line.
    /// </summary>
    [Fact]
    public async Task WebSocket_ConnectAndReceiveInfo()
    {
        using var cts = new CancellationTokenSource(IoTimeout);
        using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port), cts.Token);
        using var stream = new NetworkStream(socket, ownsSocket: false);

        await SendUpgradeRequest(stream, cts.Token);
        var response = await ReadHttpResponse(stream, cts.Token);
        response.ShouldContain("101");

        var wsFrame = await ReadWsFrameAsync(stream, cts.Token);
        var info = Encoding.ASCII.GetString(wsFrame);
        info.ShouldStartWith("INFO ");
    }

    /// <summary>
    /// A PING sent over the WebSocket connection must be answered with a PONG frame.
    /// </summary>
    [Fact]
    public async Task WebSocket_ConnectAndPing()
    {
        using var client = await ConnectWsClient();

        // CONNECT and PING travel in a single frame.
        await SendWsText(client, "CONNECT {}\r\nPING\r\n");

        using var cts = new CancellationTokenSource(IoTimeout);
        var pong = await ReadWsFrameAsync(client, cts.Token);
        Encoding.ASCII.GetString(pong).ShouldContain("PONG");
    }

    /// <summary>
    /// A message published by one WebSocket client must be delivered to another
    /// WebSocket client subscribed to the same subject.
    /// </summary>
    [Fact]
    public async Task WebSocket_PubSub()
    {
        using var sub = await ConnectWsClient();
        using var pub = await ConnectWsClient();

        // Use a PING/PONG round-trip instead of a fixed sleep to know the SUB has been
        // handled before publishing. Assumes the server processes commands of one
        // connection in order, as the NATS protocol describes - confirm for this server.
        await SendWsText(sub, "CONNECT {}\r\nSUB test.ws 1\r\nPING\r\n");
        using (var flushCts = new CancellationTokenSource(IoTimeout))
        {
            var flushed = await ReadWsFrameAsync(sub, flushCts.Token);
            Encoding.ASCII.GetString(flushed).ShouldContain("PONG");
        }

        await SendWsText(pub, "CONNECT {}\r\nPUB test.ws 5\r\nHello\r\n");

        using var msgCts = new CancellationTokenSource(IoTimeout);
        var msg = await ReadWsFrameAsync(sub, msgCts.Token);
        Encoding.ASCII.GetString(msg).ShouldContain("MSG test.ws 1 5");
    }

    /// <summary>
    /// Opens a TCP connection to the WebSocket port, performs the HTTP upgrade and
    /// consumes the initial INFO frame.
    /// </summary>
    /// <returns>A stream that owns the underlying socket; the caller must dispose it.</returns>
    private async Task<NetworkStream> ConnectWsClient()
    {
        using var cts = new CancellationTokenSource(IoTimeout);
        var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        NetworkStream? stream = null;
        try
        {
            await socket.ConnectAsync(new IPEndPoint(IPAddress.Loopback, _options.WebSocket.Port), cts.Token);
            stream = new NetworkStream(socket, ownsSocket: true);

            await SendUpgradeRequest(stream, cts.Token);
            var response = await ReadHttpResponse(stream, cts.Token);
            response.ShouldContain("101");

            await ReadWsFrameAsync(stream, cts.Token); // Read INFO frame
            return stream;
        }
        catch
        {
            // Ownership never reached the caller, so release everything here.
            stream?.Dispose();
            socket.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Writes a minimal WebSocket upgrade request with a random Sec-WebSocket-Key.
    /// </summary>
    /// <param name="stream">Connected stream to write the request to.</param>
    /// <param name="ct">Token used to abort the write.</param>
    private static async Task SendUpgradeRequest(NetworkStream stream, CancellationToken ct = default)
    {
        var keyBytes = new byte[16];
        RandomNumberGenerator.Fill(keyBytes);
        var key = Convert.ToBase64String(keyBytes);

        var request = $"GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {key}\r\nSec-WebSocket-Version: 13\r\n\r\n";
        await stream.WriteAsync(Encoding.ASCII.GetBytes(request), ct);
        await stream.FlushAsync(ct);
    }

    /// <summary>
    /// Reads the HTTP response head up to and including the terminating blank line.
    /// </summary>
    /// <param name="stream">Stream positioned at the start of the HTTP response.</param>
    /// <param name="ct">Token used to abort the read.</param>
    /// <returns>
    /// The response head as text. If the peer closes the connection early, whatever
    /// was received so far is returned.
    /// </returns>
    private static async Task<string> ReadHttpResponse(NetworkStream stream, CancellationToken ct = default)
    {
        // Read one byte at a time to avoid consuming WS frame bytes that follow the HTTP response.
        var sb = new StringBuilder();
        var buf = new byte[1];
        while (true)
        {
            int n = await stream.ReadAsync(buf, ct);
            if (n == 0)
            {
                break;
            }

            sb.Append((char)buf[0]);
            if (sb.Length >= 4 && sb[^4] == '\r' && sb[^3] == '\n' && sb[^2] == '\r' && sb[^1] == '\n')
            {
                break;
            }
        }

        return sb.ToString();
    }

    /// <summary>
    /// Reads one unmasked WebSocket frame and returns its payload.
    /// The opcode and FIN bit in the first header byte are not inspected.
    /// </summary>
    /// <param name="stream">Stream positioned at the start of a frame.</param>
    /// <param name="ct">Token used to abort the read.</param>
    /// <returns>The frame payload (possibly empty).</returns>
    /// <exception cref="InvalidDataException">
    /// The frame is masked, or its declared length does not fit in a single array.
    /// </exception>
    private static async Task<byte[]> ReadWsFrameAsync(NetworkStream stream, CancellationToken ct)
    {
        var header = new byte[2];
        await stream.ReadExactlyAsync(header, ct);

        // RFC 6455 section 5.1: a server must not mask frames. A set MASK bit would also
        // shift the payload by four key bytes, so fail loudly rather than misparse.
        if ((header[1] & 0x80) != 0)
        {
            throw new InvalidDataException("Received a masked WebSocket frame from the server.");
        }

        int len = header[1] & 0x7F;
        if (len == 126)
        {
            // 16-bit extended payload length.
            var extLen = new byte[2];
            await stream.ReadExactlyAsync(extLen, ct);
            len = BinaryPrimitives.ReadUInt16BigEndian(extLen);
        }
        else if (len == 127)
        {
            // 64-bit extended payload length; reject values that would overflow an int.
            var extLen = new byte[8];
            await stream.ReadExactlyAsync(extLen, ct);
            ulong wideLen = BinaryPrimitives.ReadUInt64BigEndian(extLen);
            if (wideLen > (ulong)Array.MaxLength)
            {
                throw new InvalidDataException($"WebSocket frame length {wideLen} is too large to buffer.");
            }

            len = (int)wideLen;
        }

        var payload = new byte[len];
        if (len > 0)
        {
            await stream.ReadExactlyAsync(payload, ct);
        }

        return payload;
    }

    /// <summary>
    /// Sends <paramref name="text"/> as a single masked WebSocket frame.
    /// Despite the method name, the frame uses the binary opcode.
    /// </summary>
    /// <param name="stream">Upgraded WebSocket connection.</param>
    /// <param name="text">ASCII protocol text to send.</param>
    private static async Task SendWsText(NetworkStream stream, string text)
    {
        var payload = Encoding.ASCII.GetBytes(text);
        var (header, _) = WsFrameWriter.CreateFrameHeader(
            useMasking: true,
            compressed: false,
            opcode: WsConstants.BinaryMessage,
            payloadLength: payload.Length);

        // The last four header bytes are used as the masking key
        // (presumably where CreateFrameHeader places it - confirm).
        var maskKey = header[^4..];
        WsFrameWriter.MaskBuf(maskKey, payload);

        await stream.WriteAsync(header);
        await stream.WriteAsync(payload);
        await stream.FlushAsync();
    }
}