Files
natsnet/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/ImplBacklog/WebSocketHandlerTests.cs
2026-02-28 21:49:41 -05:00

506 lines
19 KiB
C#

using System.Buffers.Binary;
using System.IO.Compression;
using System.Reflection;
using System.Text;
using Shouldly;
using ZB.MOM.NatsNet.Server;
using ZB.MOM.NatsNet.Server.Internal;
using ZB.MOM.NatsNet.Server.WebSocket;
namespace ZB.MOM.NatsNet.Server.Tests.ImplBacklog;
/// <summary>
/// Tests for the WebSocket framing helpers (<see cref="WebSocketHelpers"/>, <see cref="WsReadInfo"/>)
/// and the WebSocket-specific paths of <see cref="ClientConnection"/> and <see cref="NatsServer"/>.
/// The <c>T:nnnn</c> comments are the backlog ids the individual tests are tracked under.
/// </summary>
public sealed partial class WebSocketHandlerTests
{
    [Fact] // T:3075
    public void WSIsControlFrame_ShouldSucceed()
    {
        // Data opcodes are not control frames; ping, pong and close are.
        WebSocketHelpers.WsIsControlFrame(WsOpCode.Binary).ShouldBeFalse();
        WebSocketHelpers.WsIsControlFrame(WsOpCode.Text).ShouldBeFalse();
        WebSocketHelpers.WsIsControlFrame(WsOpCode.Ping).ShouldBeTrue();
        WebSocketHelpers.WsIsControlFrame(WsOpCode.Pong).ShouldBeTrue();
        WebSocketHelpers.WsIsControlFrame(WsOpCode.Close).ShouldBeTrue();
    }

    [Fact] // T:3076
    public void WSUnmask_ShouldSucceed()
    {
        var key = new byte[] { 1, 2, 3, 4 };
        var clear = Encoding.ASCII.GetBytes("this is a clear text");

        // Reference XOR mask, independent of the code under test.
        static void Mask(byte[] k, byte[] buf)
        {
            for (var i = 0; i < buf.Length; i++)
                buf[i] ^= k[i & 3];
        }

        var masked = clear.ToArray();
        Mask(key, masked);
        var readInfo = new WsReadInfo { Mask = true };
        readInfo.Init();
        key.CopyTo(readInfo.MaskKey, 0);
        readInfo.Unmask(masked);
        masked.ShouldBe(clear);

        // Unmasking in uneven pieces must give the same result, i.e. the mask key
        // position has to carry over between Unmask calls.
        masked = clear.ToArray();
        Mask(key, masked);
        readInfo.MaskKeyPosition = 0;
        readInfo.Unmask(masked.AsSpan(0, 3));
        readInfo.Unmask(masked.AsSpan(3, 8));
        readInfo.Unmask(masked.AsSpan(11));
        masked.ShouldBe(clear);
    }

    [Fact] // T:3077
    public void WSCreateCloseMessage_ShouldSucceed()
    {
        // An over-long reason is truncated to the control-frame limit and suffixed with "...".
        var payload = new string('A', WsConstants.MaxControlPayloadSize + 10);
        var closeMessage = WebSocketHelpers.WsCreateCloseMessage(WsConstants.CloseProtocolError, payload);
        BinaryPrimitives.ReadUInt16BigEndian(closeMessage.AsSpan(0, 2)).ShouldBe((ushort)WsConstants.CloseProtocolError);
        closeMessage.Length.ShouldBe(WsConstants.MaxControlPayloadSize);
        Encoding.UTF8.GetString(closeMessage.AsSpan(2)).ShouldEndWith("...");
    }

    [Fact] // T:3078
    public void WSCreateFrameHeader_ShouldSucceed()
    {
        // 7-bit length form.
        var (small, _) = WebSocketHelpers.WsCreateFrameHeader(useMasking: false, compressed: false, WsOpCode.Binary, 10);
        small.Length.ShouldBe(2);
        small[0].ShouldBe((byte)((byte)WsOpCode.Binary | WsConstants.FinalBit));
        small[1].ShouldBe((byte)10);

        // 16-bit extended length form (marker 126), with RSV1 set for compression.
        var (medium, _) = WebSocketHelpers.WsCreateFrameHeader(useMasking: false, compressed: true, WsOpCode.Text, 600);
        medium.Length.ShouldBe(4);
        medium[0].ShouldBe((byte)((byte)WsOpCode.Text | WsConstants.FinalBit | WsConstants.Rsv1Bit));
        medium[1].ShouldBe((byte)126);
        BinaryPrimitives.ReadUInt16BigEndian(medium.AsSpan(2)).ShouldBe((ushort)600);

        // 64-bit extended length form (marker 127).
        var (large, _) = WebSocketHelpers.WsCreateFrameHeader(useMasking: false, compressed: false, WsOpCode.Text, 100_000);
        large.Length.ShouldBe(10);
        large[1].ShouldBe((byte)127);
        BinaryPrimitives.ReadUInt64BigEndian(large.AsSpan(2)).ShouldBe(100_000ul);
    }

    [Fact] // T:3079
    public void WSReadUncompressedFrames_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var first = CreateMaskedClientFrame(WsOpCode.Binary, frameNum: 1, final: true, compressed: false, Encoding.ASCII.GetBytes("first message"));
        var second = CreateMaskedClientFrame(WsOpCode.Binary, frameNum: 1, final: true, compressed: false, Encoding.ASCII.GetBytes("second message"));
        var source = first.Concat(second).ToArray();

        var bufs = client.WsRead(readInfo, new MemoryStream(Array.Empty<byte>()), source);

        bufs.Count.ShouldBe(2);
        Encoding.ASCII.GetString(bufs[0]).ShouldBe("first message");
        Encoding.ASCII.GetString(bufs[1]).ShouldBe("second message");
    }

    [Fact] // T:3080
    public void WSReadCompressedFrames_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var clear = Encoding.ASCII.GetBytes("this is the uncompress data");
        var compressed = CreateMaskedClientFrame(WsOpCode.Binary, frameNum: 1, final: true, compressed: true, clear);

        var bufs = client.WsRead(readInfo, new MemoryStream(Array.Empty<byte>()), compressed);

        // The whole message must round-trip; Compress produces a correctly sync-flushed payload.
        bufs.Count.ShouldBe(1);
        Encoding.ASCII.GetString(bufs[0]).ShouldBe("this is the uncompress data");
    }

    [Fact] // T:3082
    public void WSReadVariousFrameSizes_ShouldSucceed()
    {
        // Sizes cover the 7-bit, 16-bit and 64-bit payload-length encodings.
        foreach (var size in new[] { 100, 1_000, 70_000 })
        {
            var client = CreateWsClient();
            var readInfo = CreateReadInfo();
            var payload = Enumerable.Range(0, size).Select(i => (byte)('A' + (i % 26))).ToArray();
            var frame = CreateMaskedClientFrame(WsOpCode.Binary, 1, final: true, compressed: false, payload);

            var bufs = client.WsRead(readInfo, new MemoryStream(Array.Empty<byte>()), frame);

            bufs.Count.ShouldBe(1);
            bufs[0].ShouldBe(payload);
        }
    }

    [Fact] // T:3083
    public void WSReadFragmentedFrames_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var f1 = CreateMaskedClientFrame(WsOpCode.Binary, 1, final: false, compressed: false, Encoding.ASCII.GetBytes("first"));
        var f2 = CreateMaskedClientFrame(WsOpCode.Binary, 2, final: false, compressed: false, Encoding.ASCII.GetBytes("second"));
        var f3 = CreateMaskedClientFrame(WsOpCode.Binary, 3, final: true, compressed: false, Encoding.ASCII.GetBytes("third"));

        var bufs = client.WsRead(readInfo, new MemoryStream(Array.Empty<byte>()), f1.Concat(f2).Concat(f3).ToArray());

        bufs.Count.ShouldBe(3);
        Encoding.ASCII.GetString(bufs[0]).ShouldBe("first");
        Encoding.ASCII.GetString(bufs[1]).ShouldBe("second");
        Encoding.ASCII.GetString(bufs[2]).ShouldBe("third");
    }

    [Fact] // T:3084
    public void WSReadPartialFrameHeaderAtEndOfReadBuffer_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var first = CreateMaskedClientFrame(WsOpCode.Binary, 1, final: true, compressed: false, Encoding.ASCII.GetBytes("msg1"));
        var second = CreateMaskedClientFrame(WsOpCode.Binary, 1, final: true, compressed: false, Encoding.ASCII.GetBytes("msg2"));
        var source = first.Concat(second).ToArray();

        // The read buffer ends one byte into the second frame's header.
        var initial = source[..(first.Length + 1)];
        using var remainder = new MemoryStream(source[(first.Length + 1)..]);

        var bufs = client.WsRead(readInfo, remainder, initial);

        bufs.Count.ShouldBe(1);
        Encoding.ASCII.GetString(bufs[0]).ShouldBe("msg1");
        // 5 = rest of the second frame's header (length byte + 4-byte mask key);
        // its 4 payload bytes are left unread in the stream.
        remainder.Position.ShouldBe(5);
    }

    [Fact] // T:3085
    public void WSReadPingFrame_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var ping = CreateMaskedClientFrame(WsOpCode.Ping, 1, final: true, compressed: false, Encoding.ASCII.GetBytes("optional payload"));

        var bufs = client.WsRead(readInfo, new MemoryStream(Array.Empty<byte>()), ping);

        // A ping yields no application data but queues a pong.
        bufs.ShouldBeEmpty();
        lock (GetClientLock(client))
        {
            var (chunks, _) = client.CollapsePtoNB();
            chunks.Count.ShouldBe(1);
            chunks[0].Buffer[0].ShouldBe((byte)((byte)WsOpCode.Pong | WsConstants.FinalBit));
        }
    }

    [Fact] // T:3086
    public void WSReadPongFrame_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var pong = CreateMaskedClientFrame(WsOpCode.Pong, 1, final: true, compressed: false, Encoding.ASCII.GetBytes("optional payload"));

        var bufs = client.WsRead(readInfo, new MemoryStream(Array.Empty<byte>()), pong);

        // A pong yields no application data and queues nothing.
        bufs.ShouldBeEmpty();
        lock (GetClientLock(client))
        {
            var (chunks, _) = client.CollapsePtoNB();
            chunks.ShouldBeEmpty();
        }
    }

    [Fact] // T:3087
    public void WSReadCloseFrame_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var payload = new byte[2 + "optional payload"u8.Length];
        BinaryPrimitives.WriteUInt16BigEndian(payload.AsSpan(0, 2), (ushort)WsConstants.CloseNormalClosure);
        Encoding.ASCII.GetBytes("optional payload").CopyTo(payload.AsSpan(2));
        var msg = CreateMaskedClientFrame(WsOpCode.Binary, 1, final: true, compressed: false, Encoding.ASCII.GetBytes("msg"));
        var close = CreateMaskedClientFrame(WsOpCode.Close, 1, final: true, compressed: false, payload);

        // Reading a close frame ends the read with EndOfStreamException.
        Should.Throw<EndOfStreamException>(() => client.WsRead(readInfo, new MemoryStream(Array.Empty<byte>()), msg.Concat(close).ToArray()));
    }

    // NOTE: "Betweeb" typo kept intentionally — renaming would change the test's identity
    // (presumably it mirrors the name of the upstream test it was ported from).
    [Fact] // T:3088
    public void WSReadControlFrameBetweebFragmentedFrames_ShouldSucceed()
    {
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var frag1 = CreateMaskedClientFrame(WsOpCode.Binary, 1, final: false, compressed: false, Encoding.ASCII.GetBytes("first"));
        var ctrl = CreateMaskedClientFrame(WsOpCode.Pong, 1, final: true, compressed: false, Array.Empty<byte>());
        var frag2 = CreateMaskedClientFrame(WsOpCode.Binary, 2, final: true, compressed: false, Encoding.ASCII.GetBytes("second"));

        var bufs = client.WsRead(readInfo, new MemoryStream(Array.Empty<byte>()), frag1.Concat(ctrl).Concat(frag2).ToArray());

        // The interleaved control frame must not disturb the fragmented data frames.
        bufs.Count.ShouldBe(2);
        Encoding.ASCII.GetString(bufs[0]).ShouldBe("first");
        Encoding.ASCII.GetString(bufs[1]).ShouldBe("second");
    }

    [Fact] // T:3089
    public void WSCloseFrameWithPartialOrInvalid_ShouldSucceed()
    {
        var payloadText = Encoding.ASCII.GetBytes("hello");
        var payload = new byte[2 + payloadText.Length];
        BinaryPrimitives.WriteUInt16BigEndian(payload.AsSpan(0, 2), (ushort)WsConstants.CloseNormalClosure);
        payloadText.CopyTo(payload.AsSpan(2));

        // Case 1: valid close frame split across the read buffer and the stream;
        // the queued close reply echoes the status code and reason.
        var client = CreateWsClient();
        var readInfo = CreateReadInfo();
        var closeFrame = CreateMaskedClientFrame(WsOpCode.Close, 1, final: true, compressed: false, payload);
        var initial = new[] { closeFrame[0] };
        using var remainder = new MemoryStream(closeFrame[1..]);
        Should.Throw<EndOfStreamException>(() => client.WsRead(readInfo, remainder, initial));
        lock (GetClientLock(client))
        {
            var (chunks, _) = client.CollapsePtoNB();
            chunks.Count.ShouldBe(1);
            chunks[0].Buffer.Length.ShouldBe(2 + 2 + payloadText.Length);
            chunks[0].Buffer[0].ShouldBe((byte)((byte)WsOpCode.Close | WsConstants.FinalBit));
            BinaryPrimitives.ReadUInt16BigEndian(chunks[0].Buffer.AsSpan(2, 2)).ShouldBe((ushort)WsConstants.CloseNormalClosure);
            chunks[0].Buffer.AsSpan(4).ToArray().ShouldBe(payloadText);
        }

        // Case 2: a 1-byte close payload (not a valid status code); the reply has an empty payload.
        client = CreateWsClient();
        readInfo = CreateReadInfo();
        closeFrame = CreateMaskedClientFrame(WsOpCode.Close, 1, final: true, compressed: false, payload[..1]);
        var partialHeader = new[] { closeFrame[0] };
        using var invalidRemainder = new MemoryStream(closeFrame[1..]);
        Should.Throw<EndOfStreamException>(() => client.WsRead(readInfo, invalidRemainder, partialHeader));
        lock (GetClientLock(client))
        {
            var (chunks, _) = client.CollapsePtoNB();
            chunks.Count.ShouldBe(1);
            chunks[0].Buffer.Length.ShouldBe(2);
            chunks[0].Buffer[0].ShouldBe((byte)((byte)WsOpCode.Close | WsConstants.FinalBit));
        }
    }

    [Fact] // T:3093
    public void WSEnqueueCloseMsg_ShouldSucceed()
    {
        var client = CreateWsClient();
        lock (GetClientLock(client))
        {
            client.WsEnqueueCloseMessage(ClosedState.ProtocolViolation);
            client.Ws!.CloseSent.ShouldBeTrue();
            client.Ws.CloseMessage.ShouldNotBeNull();
            client.Ws.CloseMessage![0].ShouldBe((byte)((byte)WsOpCode.Close | WsConstants.FinalBit));
        }
    }

    [Fact] // T:3097
    public void WSUpgradeConnDeadline_ShouldSucceed()
    {
        var options = new ServerOptions();
        var errors = new List<Exception>();
        var warnings = new List<Exception>();

        var parseError = ServerOptions.ParseWebsocket(
            new Dictionary<string, object?>
            {
                ["handshake_timeout"] = "1ms",
            },
            options,
            errors,
            warnings);

        parseError.ShouldBeNull();
        errors.ShouldBeEmpty();
        options.Websocket.HandshakeTimeout.ShouldBe(TimeSpan.FromMilliseconds(1));
    }

    [Fact] // T:3098
    public void WSCompressNegotiation_ShouldSucceed()
    {
        var headers = new System.Collections.Specialized.NameValueCollection
        {
            ["Sec-WebSocket-Extensions"] = "permessage-deflate; server_no_context_takeover; client_no_context_takeover",
        };

        var (supported, noContext) = NatsServer.WsPMCExtensionSupport(headers, checkNoContextTakeOver: true);

        supported.ShouldBeTrue();
        noContext.ShouldBeTrue();
    }

    [Fact] // T:3099
    public void WSSetHeader_ShouldSucceed()
    {
        var opts = new ServerOptions();
        opts.Websocket.Headers["X-Test"] = "one";
        opts.Websocket.Headers["X-Trace"] = "two";
        var server = CreateWsServer(opts);

        // WsSetHeadersOptions and the _websocket state are private; reach them via reflection.
        var setHeaders = typeof(NatsServer).GetMethod("WsSetHeadersOptions", BindingFlags.Instance | BindingFlags.NonPublic);
        setHeaders.ShouldNotBeNull();
        setHeaders!.Invoke(server, null);
        var wsField = typeof(NatsServer).GetField("_websocket", BindingFlags.Instance | BindingFlags.NonPublic);
        wsField.ShouldNotBeNull();
        var state = wsField!.GetValue(server);
        state.ShouldNotBeNull();
        var rawHeadersProp = state!.GetType().GetProperty("RawHeaders", BindingFlags.Instance | BindingFlags.Public);
        rawHeadersProp.ShouldNotBeNull();
        var rawHeaders = rawHeadersProp!.GetValue(state) as string;

        rawHeaders.ShouldNotBeNull();
        rawHeaders.ShouldContain("X-Test: one");
        rawHeaders.ShouldContain("X-Trace: two");
    }

    [Fact] // T:3102
    public void WSSetOriginOptions_ShouldSucceed()
    {
        var opts = new ServerOptions();
        opts.Websocket.SameOrigin = true;
        opts.Websocket.AllowedOrigins.Add("http://example.com:8080");
        var server = CreateWsServer(opts);

        // WsSetOriginOptions and the _websocket state are private; reach them via reflection.
        var setOrigins = typeof(NatsServer).GetMethod("WsSetOriginOptions", BindingFlags.Instance | BindingFlags.NonPublic);
        setOrigins.ShouldNotBeNull();
        setOrigins!.Invoke(server, null);
        var wsField = typeof(NatsServer).GetField("_websocket", BindingFlags.Instance | BindingFlags.NonPublic);
        wsField.ShouldNotBeNull();
        var state = wsField!.GetValue(server);
        state.ShouldNotBeNull();

        var sameOriginProp = state!.GetType().GetProperty("SameOrigin", BindingFlags.Instance | BindingFlags.Public);
        ((bool)sameOriginProp!.GetValue(state)!).ShouldBeTrue();
        var allowedOriginsProp = state.GetType().GetProperty("AllowedOrigins", BindingFlags.Instance | BindingFlags.Public);
        var allowedOrigins = allowedOriginsProp!.GetValue(state) as System.Collections.IDictionary;
        allowedOrigins.ShouldNotBeNull();
        allowedOrigins!.Contains("example.com").ShouldBeTrue();
    }

    [Fact] // T:3113
    public void WSFrameOutbound_ShouldSucceed()
    {
        var client = CreateWsClient();
        lock (GetClientLock(client))
        {
            client.WsEnqueueControlMessageLocked(WsOpCode.Pong, Encoding.ASCII.GetBytes("abc"));
            var (chunks, attempted) = client.CollapsePtoNB();
            chunks.Count.ShouldBe(1);
            attempted.ShouldBe(chunks[0].Count);
        }
    }

    [Fact] // T:3117
    public void WSCompressionFrameSizeLimit_ShouldSucceed()
    {
        // 2048 bytes of decompressed data must be rejected under a 128-byte limit.
        var readInfo = CreateReadInfo();
        readInfo.CompressedBuffers.Add(Compress(Encoding.ASCII.GetBytes(new string('x', 2048))));
        Should.Throw<Exception>(() => readInfo.Decompress(128));
    }

    [Fact] // T:3132
    public void WSNoCorruptionWithFrameSizeLimit_ShouldSucceed()
    {
        var key = new byte[] { 1, 2, 3, 4 };
        var buffers = new List<byte[]>
        {
            Encoding.ASCII.GetBytes("hello"),
            Encoding.ASCII.GetBytes("world"),
        };
        var original = buffers.SelectMany(b => b).ToArray();

        // XOR masking is an involution: masking twice restores the original bytes.
        WebSocketHelpers.WsMaskBufs(key, buffers);
        WebSocketHelpers.WsMaskBufs(key, buffers);

        buffers.SelectMany(b => b).ToArray().ShouldBe(original);
    }

    /// <summary>Creates a server from <paramref name="options"/> (defaults if null), asserting no error.</summary>
    private static NatsServer CreateWsServer(ServerOptions? options = null)
    {
        var (server, err) = NatsServer.NewServer(options ?? new ServerOptions());
        err.ShouldBeNull();
        server.ShouldNotBeNull();
        return server!;
    }

    /// <summary>
    /// Creates a server-less client connection with WebSocket state that expects masked
    /// inbound frames and writes unmasked outbound frames.
    /// </summary>
    private static ClientConnection CreateWsClient()
    {
        var client = new ClientConnection(ClientKind.Client, server: null, nc: new MemoryStream())
        {
            Ws = new WebsocketConnection { MaskRead = true, MaskWrite = false },
        };
        return client;
    }

    /// <summary>Creates an initialized <see cref="WsReadInfo"/> that expects masked frames.</summary>
    private static WsReadInfo CreateReadInfo()
    {
        var readInfo = new WsReadInfo { Mask = true };
        readInfo.Init();
        return readInfo;
    }

    /// <summary>Returns the client's private <c>_mu</c> lock object (via reflection).</summary>
    private static object GetClientLock(ClientConnection client)
    {
        var muField = typeof(ClientConnection).GetField("_mu", BindingFlags.Instance | BindingFlags.NonPublic);
        muField.ShouldNotBeNull();
        return muField!.GetValue(client)!;
    }

    /// <summary>
    /// Builds a client-to-server frame masked with the fixed key {1,2,3,4}.
    /// Only the first fragment (<paramref name="frameNum"/> == 1) carries <paramref name="frameType"/>;
    /// later fragments get opcode 0 (continuation). When <paramref name="compressed"/> is set,
    /// the payload is deflated and RSV1 is set.
    /// </summary>
    private static byte[] CreateMaskedClientFrame(WsOpCode frameType, int frameNum, bool final, bool compressed, byte[] payload)
    {
        if (compressed)
            payload = Compress(payload);

        var frame = new byte[WsConstants.MaxFrameHeaderSize + payload.Length];
        if (frameNum == 1)
            frame[0] = (byte)frameType;
        if (final)
            frame[0] |= WsConstants.FinalBit;
        if (compressed)
            frame[0] |= WsConstants.Rsv1Bit;

        // Payload length in the 7-bit, 16-bit (126) or 64-bit (127) form, always with the mask bit.
        var pos = 1;
        if (payload.Length <= 125)
        {
            frame[pos++] = (byte)(payload.Length | WsConstants.MaskBit);
        }
        else if (payload.Length < 65536)
        {
            frame[pos++] = (byte)(126 | WsConstants.MaskBit);
            BinaryPrimitives.WriteUInt16BigEndian(frame.AsSpan(pos, 2), (ushort)payload.Length);
            pos += 2;
        }
        else
        {
            frame[pos++] = (byte)(127 | WsConstants.MaskBit);
            BinaryPrimitives.WriteUInt64BigEndian(frame.AsSpan(pos, 8), (ulong)payload.Length);
            pos += 8;
        }

        var key = new byte[] { 1, 2, 3, 4 };
        key.CopyTo(frame, pos);
        pos += 4;
        payload.CopyTo(frame, pos);
        WebSocketHelpers.WsMaskBuf(key, frame.AsSpan(pos, payload.Length));
        pos += payload.Length;

        // Trim the unused part of the worst-case header allocation.
        return frame[..pos];
    }

    /// <summary>The empty stored block a deflate sync flush ends with (RFC 7692 §7.2.1).</summary>
    private static ReadOnlySpan<byte> SyncFlushTrailer => new byte[] { 0x00, 0x00, 0xFF, 0xFF };

    /// <summary>
    /// Deflates <paramref name="payload"/> the way a permessage-deflate sender does:
    /// a sync flush, with the trailing <c>00 00 FF FF</c> removed.
    /// </summary>
    /// <remarks>
    /// The output must be captured after <see cref="DeflateStream.Flush"/> and before disposal.
    /// Disposing writes a *final* block, which does not end in the sync-flush trailer, so
    /// chopping 4 bytes off it would corrupt the data.
    /// </remarks>
    private static byte[] Compress(byte[] payload)
    {
        using var memory = new MemoryStream();
        byte[] compressed;
        using (var compressor = new DeflateStream(memory, CompressionLevel.Fastest, leaveOpen: true))
        {
            compressor.Write(payload, 0, payload.Length);
            compressor.Flush();
            compressed = memory.ToArray();
        }

        ReadOnlySpan<byte> output = compressed;
        return output.EndsWith(SyncFlushTrailer) ? compressed[..^4] : compressed;
    }
}