using System.Buffers.Binary;
using System.IO.Compression;
using System.Reflection;
using System.Text;
using Shouldly;
using ZB.MOM.NatsNet.Server;
using ZB.MOM.NatsNet.Server.Internal;
using ZB.MOM.NatsNet.Server.WebSocket;

namespace ZB.MOM.NatsNet.Server.Tests.ImplBacklog;

/// <summary>
/// WebSocket handler tests: control-frame classification, masking, close-message and
/// frame-header construction, frame parsing (plain, compressed, fragmented, control),
/// close handling, and server-side websocket option wiring.
/// </summary>
public sealed partial class WebSocketHandlerTests
{
    [Fact] // T:3075
    public void WSIsControlFrame_ShouldSucceed()
    {
        WebSocketHelpers.WsIsControlFrame(WsOpCode.Binary).ShouldBeFalse();
        WebSocketHelpers.WsIsControlFrame(WsOpCode.Text).ShouldBeFalse();
        WebSocketHelpers.WsIsControlFrame(WsOpCode.Ping).ShouldBeTrue();
        WebSocketHelpers.WsIsControlFrame(WsOpCode.Pong).ShouldBeTrue();
        WebSocketHelpers.WsIsControlFrame(WsOpCode.Close).ShouldBeTrue();
    }

    [Fact] // T:3076
    public void WSUnmask_ShouldSucceed()
    {
        var key = new byte[] { 1, 2, 3, 4 };
        var clear = Encoding.ASCII.GetBytes("this is a clear text");

        // Independent reference masker: XOR every byte with the key, cycling every 4 bytes.
        static void Mask(byte[] k, byte[] buf)
        {
            for (var i = 0; i < buf.Length; i++)
            {
                buf[i] ^= k[i & 3];
            }
        }

        var masked = clear.ToArray();
        Mask(key, masked);

        var readInfo = new WsReadInfo { Mask = true };
        readInfo.Init();
        key.CopyTo(readInfo.MaskKey, 0);
        readInfo.Unmask(masked);
        masked.ShouldBe(clear);

        // Unmasking in uneven slices only works if the key position carries over between calls.
        masked = clear.ToArray();
        Mask(key, masked);
        readInfo.MaskKeyPosition = 0;
        readInfo.Unmask(masked.AsSpan(0, 3));
        readInfo.Unmask(masked.AsSpan(3, 8));
        readInfo.Unmask(masked.AsSpan(11));
        masked.ShouldBe(clear);
    }

    [Fact] // T:3077
    public void WSCreateCloseMessage_ShouldSucceed()
    {
        var payload = new string('A', WsConstants.MaxControlPayloadSize + 10);

        var closeMessage = WebSocketHelpers.WsCreateCloseMessage(WsConstants.CloseProtocolError, payload);

        BinaryPrimitives.ReadUInt16BigEndian(closeMessage.AsSpan(0, 2)).ShouldBe((ushort)WsConstants.CloseProtocolError);
        // An oversized reason is cut so the whole message fits the control-payload limit, ending in "...".
        closeMessage.Length.ShouldBe(WsConstants.MaxControlPayloadSize);
        Encoding.UTF8.GetString(closeMessage.AsSpan(2)).ShouldEndWith("...");
    }

    [Fact] // T:3078
    public void WSCreateFrameHeader_ShouldSucceed()
    {
        // Payload <= 125: length fits in the second header byte.
        var (small, _) = WebSocketHelpers.WsCreateFrameHeader(useMasking: false, compressed: false, WsOpCode.Binary, 10);
        small.Length.ShouldBe(2);
        small[0].ShouldBe((byte)((byte)WsOpCode.Binary | WsConstants.FinalBit));
        small[1].ShouldBe((byte)10);

        // Payload < 64KiB: marker 126 followed by a 16-bit big-endian length.
        var (medium, _) = WebSocketHelpers.WsCreateFrameHeader(useMasking: false, compressed: true, WsOpCode.Text, 600);
        medium.Length.ShouldBe(4);
        medium[0].ShouldBe((byte)((byte)WsOpCode.Text | WsConstants.FinalBit | WsConstants.Rsv1Bit));
        medium[1].ShouldBe((byte)126);
        BinaryPrimitives.ReadUInt16BigEndian(medium.AsSpan(2)).ShouldBe((ushort)600);

        // Larger payloads: marker 127 followed by a 64-bit big-endian length.
        var (large, _) = WebSocketHelpers.WsCreateFrameHeader(useMasking: false, compressed: false, WsOpCode.Text, 100_000);
        large.Length.ShouldBe(10);
        large[1].ShouldBe((byte)127);
        BinaryPrimitives.ReadUInt64BigEndian(large.AsSpan(2)).ShouldBe(100_000ul);
    }

    [Fact] // T:3079
    public void WSReadUncompressedFrames_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var first = CreateMaskedClientFrame(WsOpCode.Binary, frameNum: 1, final: true, compressed: false, Encoding.ASCII.GetBytes("first message"));
        var second = CreateMaskedClientFrame(WsOpCode.Binary, frameNum: 1, final: true, compressed: false, Encoding.ASCII.GetBytes("second message"));
        var source = first.Concat(second).ToArray();

        var bufs = client.WsRead(readInfo, EmptyReader(), source);

        bufs.Count.ShouldBe(2);
        Encoding.ASCII.GetString(bufs[0]).ShouldBe("first message");
        Encoding.ASCII.GetString(bufs[1]).ShouldBe("second message");
    }

    [Fact] // T:3080
    public void WSReadCompressedFrames_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var clear = Encoding.ASCII.GetBytes("this is the uncompress data");
        var compressed = CreateMaskedClientFrame(WsOpCode.Binary, frameNum: 1, final: true, compressed: true, clear);

        var bufs = client.WsRead(readInfo, EmptyReader(), compressed);

        bufs.Count.ShouldBe(1);
        // Compress() produces a complete sync-flushed stream, so the entire payload must
        // round-trip, not just a prefix of it.
        Encoding.ASCII.GetString(bufs[0]).ShouldBe("this is the uncompress data");
    }

    [Fact] // T:3082
    public void WSReadVariousFrameSizes_ShouldSucceed()
    {
        // Sizes chosen to exercise the 7-bit, 16-bit and 64-bit payload-length encodings.
        foreach (var size in new[] { 100, 1_000, 70_000 })
        {
            var client = CreateWsClient();
            var readInfo = CreateReadInfo();
            var payload = Enumerable.Range(0, size).Select(i => (byte)('A' + (i % 26))).ToArray();
            var frame = CreateMaskedClientFrame(WsOpCode.Binary, 1, final: true, compressed: false, payload);

            var bufs = client.WsRead(readInfo, EmptyReader(), frame);

            bufs.Count.ShouldBe(1);
            bufs[0].ShouldBe(payload);
        }
    }

    [Fact] // T:3083
    public void WSReadFragmentedFrames_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var f1 = CreateMaskedClientFrame(WsOpCode.Binary, 1, final: false, compressed: false, Encoding.ASCII.GetBytes("first"));
        var f2 = CreateMaskedClientFrame(WsOpCode.Binary, 2, final: false, compressed: false, Encoding.ASCII.GetBytes("second"));
        var f3 = CreateMaskedClientFrame(WsOpCode.Binary, 3, final: true, compressed: false, Encoding.ASCII.GetBytes("third"));

        var bufs = client.WsRead(readInfo, EmptyReader(), f1.Concat(f2).Concat(f3).ToArray());

        bufs.Count.ShouldBe(3);
        Encoding.ASCII.GetString(bufs[0]).ShouldBe("first");
        Encoding.ASCII.GetString(bufs[1]).ShouldBe("second");
        Encoding.ASCII.GetString(bufs[2]).ShouldBe("third");
    }

    [Fact] // T:3084
    public void WSReadPartialFrameHeaderAtEndOfReadBuffer_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var first = CreateMaskedClientFrame(WsOpCode.Binary, 1, final: true, compressed: false, Encoding.ASCII.GetBytes("msg1"));
        var second = CreateMaskedClientFrame(WsOpCode.Binary, 1, final: true, compressed: false, Encoding.ASCII.GetBytes("msg2"));
        var source = first.Concat(second).ToArray();

        // Initial buffer = first frame + the first header byte of the second frame;
        // everything else is only available from the stream.
        var initial = source[..(first.Length + 1)];
        using var remainder = new MemoryStream(source[(first.Length + 1)..]);

        var bufs = client.WsRead(readInfo, remainder, initial);

        bufs.Count.ShouldBe(1);
        Encoding.ASCII.GetString(bufs[0]).ShouldBe("msg1");
        // Only the rest of the second header (1 length byte + 4 mask-key bytes) should be pulled.
        remainder.Position.ShouldBe(5);
    }

    [Fact] // T:3085
    public void WSReadPingFrame_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var ping = CreateMaskedClientFrame(WsOpCode.Ping, 1, final: true, compressed: false, Encoding.ASCII.GetBytes("optional payload"));

        var bufs = client.WsRead(readInfo, EmptyReader(), ping);

        // A ping yields no application data but queues exactly one pong.
        bufs.ShouldBeEmpty();
        lock (GetClientLock(client))
        {
            var (chunks, _) = client.CollapsePtoNB();
            chunks.Count.ShouldBe(1);
            chunks[0].Buffer[0].ShouldBe((byte)((byte)WsOpCode.Pong | WsConstants.FinalBit));
        }
    }

    [Fact] // T:3086
    public void WSReadPongFrame_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var pong = CreateMaskedClientFrame(WsOpCode.Pong, 1, final: true, compressed: false, Encoding.ASCII.GetBytes("optional payload"));

        var bufs = client.WsRead(readInfo, EmptyReader(), pong);

        // A pong yields no application data and queues nothing.
        bufs.ShouldBeEmpty();
        lock (GetClientLock(client))
        {
            var (chunks, _) = client.CollapsePtoNB();
            chunks.ShouldBeEmpty();
        }
    }

    [Fact] // T:3087
    public void WSReadCloseFrame_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();

        // Close payload: 2-byte big-endian status code followed by the reason text.
        var reason = "optional payload"u8;
        var payload = new byte[2 + reason.Length];
        BinaryPrimitives.WriteUInt16BigEndian(payload.AsSpan(0, 2), (ushort)WsConstants.CloseNormalClosure);
        reason.CopyTo(payload.AsSpan(2));

        var msg = CreateMaskedClientFrame(WsOpCode.Binary, 1, final: true, compressed: false, Encoding.ASCII.GetBytes("msg"));
        var close = CreateMaskedClientFrame(WsOpCode.Close, 1, final: true, compressed: false, payload);

        Should.Throw(() => client.WsRead(readInfo, EmptyReader(), msg.Concat(close).ToArray()));
    }

    [Fact] // T:3088
    public void WSReadControlFrameBetweebFragmentedFrames_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var frag1 = CreateMaskedClientFrame(WsOpCode.Binary, 1, final: false, compressed: false, Encoding.ASCII.GetBytes("first"));
        var ctrl = CreateMaskedClientFrame(WsOpCode.Pong, 1, final: true, compressed: false, Array.Empty<byte>());
        var frag2 = CreateMaskedClientFrame(WsOpCode.Binary, 2, final: true, compressed: false, Encoding.ASCII.GetBytes("second"));

        var bufs = client.WsRead(readInfo, EmptyReader(), frag1.Concat(ctrl).Concat(frag2).ToArray());

        // The interleaved control frame must not break reassembly of the fragmented message.
        bufs.Count.ShouldBe(2);
        Encoding.ASCII.GetString(bufs[0]).ShouldBe("first");
        Encoding.ASCII.GetString(bufs[1]).ShouldBe("second");
    }

    [Fact] // T:3089
    public void WSCloseFrameWithPartialOrInvalid_ShouldSucceed()
    {
        var payloadText = Encoding.ASCII.GetBytes("hello");
        var payload = new byte[2 + payloadText.Length];
        BinaryPrimitives.WriteUInt16BigEndian(payload.AsSpan(0, 2), (ushort)WsConstants.CloseNormalClosure);
        payloadText.CopyTo(payload.AsSpan(2));

        // Close frame whose header is split between the initial buffer and the stream:
        // the queued reply echoes the status code and reason.
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var closeFrame = CreateMaskedClientFrame(WsOpCode.Close, 1, final: true, compressed: false, payload);
        var initial = new[] { closeFrame[0] };
        using var remainder = new MemoryStream(closeFrame[1..]);

        Should.Throw(() => client.WsRead(readInfo, remainder, initial));

        lock (GetClientLock(client))
        {
            var (chunks, _) = client.CollapsePtoNB();
            chunks.Count.ShouldBe(1);
            chunks[0].Buffer.Length.ShouldBe(2 + 2 + payloadText.Length);
            chunks[0].Buffer[0].ShouldBe((byte)((byte)WsOpCode.Close | WsConstants.FinalBit));
            BinaryPrimitives.ReadUInt16BigEndian(chunks[0].Buffer.AsSpan(2, 2)).ShouldBe((ushort)WsConstants.CloseNormalClosure);
            chunks[0].Buffer.AsSpan(4).ToArray().ShouldBe(payloadText);
        }

        // A one-byte close payload cannot hold a status code: the reply is a bare close frame.
        client = CreateWsClient();
        readInfo = CreateReadInfo();
        closeFrame = CreateMaskedClientFrame(WsOpCode.Close, 1, final: true, compressed: false, payload[..1]);
        var partialHeader = new[] { closeFrame[0] };
        using var invalidRemainder = new MemoryStream(closeFrame[1..]);

        Should.Throw(() => client.WsRead(readInfo, invalidRemainder, partialHeader));

        lock (GetClientLock(client))
        {
            var (chunks, _) = client.CollapsePtoNB();
            chunks.Count.ShouldBe(1);
            chunks[0].Buffer.Length.ShouldBe(2);
            chunks[0].Buffer[0].ShouldBe((byte)((byte)WsOpCode.Close | WsConstants.FinalBit));
        }
    }

    [Fact] // T:3093
    public void WSEnqueueCloseMsg_ShouldSucceed()
    {
        var client = CreateWsClient();

        lock (GetClientLock(client))
        {
            client.WsEnqueueCloseMessage(ClosedState.ProtocolViolation);

            client.Ws!.CloseSent.ShouldBeTrue();
            client.Ws.CloseMessage.ShouldNotBeNull();
            client.Ws.CloseMessage![0].ShouldBe((byte)((byte)WsOpCode.Close | WsConstants.FinalBit));
        }
    }

    [Fact] // T:3097
    public void WSUpgradeConnDeadline_ShouldSucceed()
    {
        var options = new ServerOptions();
        var errors = new List();
        var warnings = new List();

        var parseError = ServerOptions.ParseWebsocket(
            new Dictionary
            {
                ["handshake_timeout"] = "1ms",
            },
            options,
            errors,
            warnings);

        parseError.ShouldBeNull();
        errors.ShouldBeEmpty();
        options.Websocket.HandshakeTimeout.ShouldBe(TimeSpan.FromMilliseconds(1));
    }

    [Fact] // T:3098
    public void WSCompressNegotiation_ShouldSucceed()
    {
        var headers = new System.Collections.Specialized.NameValueCollection
        {
            ["Sec-WebSocket-Extensions"] = "permessage-deflate; server_no_context_takeover; client_no_context_takeover",
        };

        var (supported, noContext) = NatsServer.WsPMCExtensionSupport(headers, checkNoContextTakeOver: true);

        supported.ShouldBeTrue();
        noContext.ShouldBeTrue();
    }

    [Fact] // T:3099
    public void WSSetHeader_ShouldSucceed()
    {
        var opts = new ServerOptions();
        opts.Websocket.Headers["X-Test"] = "one";
        opts.Websocket.Headers["X-Trace"] = "two";
        var server = CreateWsServer(opts);

        // The setter and resulting state are non-public, so they are reached via reflection.
        var setHeaders = typeof(NatsServer).GetMethod("WsSetHeadersOptions", BindingFlags.Instance | BindingFlags.NonPublic);
        setHeaders.ShouldNotBeNull();
        setHeaders!.Invoke(server, null);

        var wsField = typeof(NatsServer).GetField("_websocket", BindingFlags.Instance | BindingFlags.NonPublic);
        wsField.ShouldNotBeNull();
        var state = wsField!.GetValue(server);
        state.ShouldNotBeNull();

        var rawHeadersProp = state!.GetType().GetProperty("RawHeaders", BindingFlags.Instance | BindingFlags.Public);
        rawHeadersProp.ShouldNotBeNull();
        var rawHeaders = rawHeadersProp!.GetValue(state) as string;
        rawHeaders.ShouldNotBeNull();
        rawHeaders.ShouldContain("X-Test: one");
        rawHeaders.ShouldContain("X-Trace: two");
    }

    [Fact] // T:3102
    public void WSSetOriginOptions_ShouldSucceed()
    {
        var opts = new ServerOptions();
        opts.Websocket.SameOrigin = true;
        opts.Websocket.AllowedOrigins.Add("http://example.com:8080");
        var server = CreateWsServer(opts);

        // The setter and resulting state are non-public, so they are reached via reflection.
        var setOrigins = typeof(NatsServer).GetMethod("WsSetOriginOptions", BindingFlags.Instance | BindingFlags.NonPublic);
        setOrigins.ShouldNotBeNull();
        setOrigins!.Invoke(server, null);

        var wsField = typeof(NatsServer).GetField("_websocket", BindingFlags.Instance | BindingFlags.NonPublic);
        wsField.ShouldNotBeNull();
        var state = wsField!.GetValue(server);
        state.ShouldNotBeNull();

        // Assert each reflected member exists so a rename fails with a clear message rather than an NRE.
        var sameOriginProp = state!.GetType().GetProperty("SameOrigin", BindingFlags.Instance | BindingFlags.Public);
        sameOriginProp.ShouldNotBeNull();
        ((bool)sameOriginProp!.GetValue(state)!).ShouldBeTrue();

        var allowedOriginsProp = state.GetType().GetProperty("AllowedOrigins", BindingFlags.Instance | BindingFlags.Public);
        allowedOriginsProp.ShouldNotBeNull();
        var allowedOrigins = allowedOriginsProp!.GetValue(state) as System.Collections.IDictionary;
        allowedOrigins.ShouldNotBeNull();
        allowedOrigins!.Contains("example.com").ShouldBeTrue();
    }

    [Fact] // T:3113
    public void WSFrameOutbound_ShouldSucceed()
    {
        var client = CreateWsClient();

        lock (GetClientLock(client))
        {
            client.WsEnqueueControlMessageLocked(WsOpCode.Pong, Encoding.ASCII.GetBytes("abc"));

            var (chunks, attempted) = client.CollapsePtoNB();
            chunks.Count.ShouldBe(1);
            attempted.ShouldBe(chunks[0].Count);
        }
    }

    [Fact] // T:3117
    public void WSCompressionFrameSizeLimit_ShouldSucceed()
    {
        var readInfo = CreateReadInfo();
        readInfo.CompressedBuffers.Add(Compress(Encoding.ASCII.GetBytes(new string('x', 2048))));

        // 2048 decompressed bytes exceed the 128-byte limit.
        Should.Throw(() => readInfo.Decompress(128));
    }

    [Fact] // T:3132
    public void WSNoCorruptionWithFrameSizeLimit_ShouldSucceed()
    {
        var key = new byte[] { 1, 2, 3, 4 };
        var buffers = new List
        {
            Encoding.ASCII.GetBytes("hello"),
            Encoding.ASCII.GetBytes("world"),
        };
        var original = buffers.SelectMany(b => b).ToArray();

        // XOR masking is an involution: masking twice must restore the original bytes.
        WebSocketHelpers.WsMaskBufs(key, buffers);
        WebSocketHelpers.WsMaskBufs(key, buffers);

        buffers.SelectMany(b => b).ToArray().ShouldBe(original);
    }

    /// <summary>Creates a server from <paramref name="options"/> (defaults when null), asserting success.</summary>
    private static NatsServer CreateWsServer(ServerOptions? options = null)
    {
        var (server, err) = NatsServer.NewServer(options ?? new ServerOptions());
        err.ShouldBeNull();
        server.ShouldNotBeNull();
        return server!;
    }

    /// <summary>Creates a client connection with websocket state that expects masked reads and sends unmasked writes.</summary>
    private static ClientConnection CreateWsClient()
    {
        var client = new ClientConnection(ClientKind.Client, server: null, nc: new MemoryStream())
        {
            Ws = new WebsocketConnection { MaskRead = true, MaskWrite = false },
        };
        return client;
    }

    /// <summary>Creates an initialized read state that expects masked frames.</summary>
    private static WsReadInfo CreateReadInfo()
    {
        var readInfo = new WsReadInfo { Mask = true };
        readInfo.Init();
        return readInfo;
    }

    /// <summary>
    /// Empty continuation stream for reads whose frames are fully contained in the initial buffer.
    /// </summary>
    private static MemoryStream EmptyReader() => new(Array.Empty<byte>());

    /// <summary>Returns the client's private <c>_mu</c> lock object so tests can call *Locked APIs.</summary>
    private static object GetClientLock(ClientConnection client)
    {
        var muField = typeof(ClientConnection).GetField("_mu", BindingFlags.Instance | BindingFlags.NonPublic);
        muField.ShouldNotBeNull();
        return muField!.GetValue(client)!;
    }

    /// <summary>
    /// Builds a client-to-server frame masked with the fixed key {1, 2, 3, 4}.
    /// </summary>
    /// <param name="frameType">Opcode; written only when <paramref name="frameNum"/> is 1 (later fragments carry opcode 0).</param>
    /// <param name="frameNum">1-based fragment index.</param>
    /// <param name="final">Sets the FIN bit.</param>
    /// <param name="compressed">Deflates the payload via <see cref="Compress"/> and sets RSV1.</param>
    /// <param name="payload">Unmasked payload bytes.</param>
    /// <returns>The encoded frame, trimmed to its actual length.</returns>
    private static byte[] CreateMaskedClientFrame(WsOpCode frameType, int frameNum, bool final, bool compressed, byte[] payload)
    {
        if (compressed)
        {
            payload = Compress(payload);
        }

        var frame = new byte[WsConstants.MaxFrameHeaderSize + payload.Length];
        if (frameNum == 1)
        {
            frame[0] = (byte)frameType;
        }

        if (final)
        {
            frame[0] |= WsConstants.FinalBit;
        }

        if (compressed)
        {
            frame[0] |= WsConstants.Rsv1Bit;
        }

        var pos = 1;
        if (payload.Length <= 125)
        {
            frame[pos++] = (byte)(payload.Length | WsConstants.MaskBit);
        }
        else if (payload.Length < 65536)
        {
            frame[pos++] = (byte)(126 | WsConstants.MaskBit);
            BinaryPrimitives.WriteUInt16BigEndian(frame.AsSpan(pos, 2), (ushort)payload.Length);
            pos += 2;
        }
        else
        {
            frame[pos++] = (byte)(127 | WsConstants.MaskBit);
            BinaryPrimitives.WriteUInt64BigEndian(frame.AsSpan(pos, 8), (ulong)payload.Length);
            pos += 8;
        }

        var key = new byte[] { 1, 2, 3, 4 };
        key.CopyTo(frame, pos);
        pos += 4;

        payload.CopyTo(frame, pos);
        WebSocketHelpers.WsMaskBuf(key, frame.AsSpan(pos, payload.Length));
        pos += payload.Length;

        return frame[..pos];
    }

    /// <summary>Trailing empty-stored-block marker emitted by a DEFLATE sync flush.</summary>
    private static ReadOnlySpan<byte> DeflateSyncFlushTail => new byte[] { 0x00, 0x00, 0xFF, 0xFF };

    /// <summary>
    /// Compresses <paramref name="payload"/> the way a permessage-deflate sender does:
    /// raw DEFLATE, sync-flushed, with the trailing <c>00 00 FF FF</c> removed (RFC 7692 §7.2.1).
    /// </summary>
    /// <remarks>
    /// The output is captured after <see cref="DeflateStream.Flush"/> and before dispose.
    /// Disposing first would emit a final (BFINAL) block, and chopping four bytes off that
    /// block would truncate compressed data instead of removing the sync marker.
    /// </remarks>
    private static byte[] Compress(byte[] payload)
    {
        using var memory = new MemoryStream();
        byte[] compressed;
        using (var compressor = new DeflateStream(memory, CompressionLevel.Fastest, leaveOpen: true))
        {
            compressor.Write(payload, 0, payload.Length);
            compressor.Flush();
            // Snapshot now: disposing the compressor appends a final block we must not send.
            compressed = memory.ToArray();
        }

        // Strip the marker only when it is actually there, never blindly.
        return compressed.AsSpan().EndsWith(DeflateSyncFlushTail)
            ? compressed[..^DeflateSyncFlushTail.Length]
            : compressed;
    }
}