feat: add WebSocket HTTP upgrade handshake

This commit is contained in:
Joseph Doherty
2026-02-23 04:53:21 -05:00
parent bd29c529a8
commit 1c948b5b0f
2 changed files with 485 additions and 0 deletions

View File

@@ -0,0 +1,259 @@
using System.Net;
using System.Security.Cryptography;
using System.Text;
namespace NATS.Server.WebSocket;
/// <summary>
/// WebSocket HTTP upgrade handshake handler.
/// Ported from golang/nats-server/server/websocket.go lines 731-917.
/// </summary>
public static class WsUpgrade
{
public static async Task<WsUpgradeResult> TryUpgradeAsync(
Stream inputStream, Stream outputStream, WebSocketOptions options)
{
try
{
var (method, path, headers) = await ReadHttpRequestAsync(inputStream);
if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
return await FailAsync(outputStream, 405, "request method must be GET");
if (!headers.ContainsKey("Host"))
return await FailAsync(outputStream, 400, "'Host' missing in request");
if (!HeaderContains(headers, "Upgrade", "websocket"))
return await FailAsync(outputStream, 400, "invalid value for header 'Upgrade'");
if (!HeaderContains(headers, "Connection", "Upgrade"))
return await FailAsync(outputStream, 400, "invalid value for header 'Connection'");
if (!headers.TryGetValue("Sec-WebSocket-Key", out var key) || string.IsNullOrEmpty(key))
return await FailAsync(outputStream, 400, "key missing");
if (!HeaderContains(headers, "Sec-WebSocket-Version", "13"))
return await FailAsync(outputStream, 400, "invalid version");
var kind = path switch
{
_ when path.EndsWith("/leafnode") => WsClientKind.Leaf,
_ when path.EndsWith("/mqtt") => WsClientKind.Mqtt,
_ => WsClientKind.Client,
};
// Origin checking
if (options.SameOrigin || options.AllowedOrigins is { Count: > 0 })
{
var checker = new WsOriginChecker(options.SameOrigin, options.AllowedOrigins);
headers.TryGetValue("Origin", out var origin);
if (string.IsNullOrEmpty(origin))
headers.TryGetValue("Sec-WebSocket-Origin", out origin);
var originErr = checker.CheckOrigin(origin, headers.GetValueOrDefault("Host", ""), isTls: false);
if (originErr != null)
return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}");
}
// Compression negotiation
bool compress = options.Compression;
if (compress)
{
compress = headers.TryGetValue("Sec-WebSocket-Extensions", out var ext) &&
ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase);
}
// No-masking support
bool noMasking = headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) &&
string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase);
// Browser detection
bool browser = false;
bool noCompFrag = false;
if (kind is WsClientKind.Client or WsClientKind.Mqtt &&
headers.TryGetValue("User-Agent", out var ua) && ua.StartsWith("Mozilla/"))
{
browser = true;
// Disable fragmentation of compressed frames for Safari browsers.
// Safari has both "Version/" and "Safari/" in the user agent string,
// while Chrome on macOS has "Safari/" but not "Version/".
noCompFrag = compress && ua.Contains("Version/") && ua.Contains("Safari/");
}
// Cookie extraction
string? cookieJwt = null, cookieUsername = null, cookiePassword = null, cookieToken = null;
if ((kind is WsClientKind.Client or WsClientKind.Mqtt) &&
headers.TryGetValue("Cookie", out var cookieHeader))
{
var cookies = ParseCookies(cookieHeader);
if (options.JwtCookie != null) cookies.TryGetValue(options.JwtCookie, out cookieJwt);
if (options.UsernameCookie != null) cookies.TryGetValue(options.UsernameCookie, out cookieUsername);
if (options.PasswordCookie != null) cookies.TryGetValue(options.PasswordCookie, out cookiePassword);
if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken);
}
// X-Forwarded-For client IP extraction
string? clientIp = null;
if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff))
{
var ip = xff.Split(',')[0].Trim();
if (IPAddress.TryParse(ip, out _))
clientIp = ip;
}
// Build the 101 Switching Protocols response
var response = new StringBuilder();
response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ");
response.Append(ComputeAcceptKey(key));
response.Append("\r\n");
if (compress)
response.Append(WsConstants.PmcFullResponse);
if (noMasking)
response.Append(WsConstants.NoMaskingFullResponse);
if (options.Headers != null)
{
foreach (var (k, v) in options.Headers)
{
response.Append(k);
response.Append(": ");
response.Append(v);
response.Append("\r\n");
}
}
response.Append("\r\n");
var responseBytes = Encoding.ASCII.GetBytes(response.ToString());
await outputStream.WriteAsync(responseBytes);
await outputStream.FlushAsync();
return new WsUpgradeResult(
Success: true, Compress: compress, Browser: browser, NoCompFrag: noCompFrag,
MaskRead: !noMasking, MaskWrite: false,
CookieJwt: cookieJwt, CookieUsername: cookieUsername,
CookiePassword: cookiePassword, CookieToken: cookieToken,
ClientIp: clientIp, Kind: kind);
}
catch (Exception)
{
return WsUpgradeResult.Failed;
}
}
/// <summary>
/// Computes the Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2.
/// </summary>
public static string ComputeAcceptKey(string clientKey)
{
var combined = Encoding.ASCII.GetBytes(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
var hash = SHA1.HashData(combined);
return Convert.ToBase64String(hash);
}
private static async Task<WsUpgradeResult> FailAsync(Stream output, int statusCode, string reason)
{
var statusText = statusCode switch
{
400 => "Bad Request",
403 => "Forbidden",
405 => "Method Not Allowed",
_ => "Internal Server Error",
};
var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}";
await output.WriteAsync(Encoding.ASCII.GetBytes(response));
await output.FlushAsync();
return WsUpgradeResult.Failed;
}
private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(Stream stream)
{
var headerBytes = new List<byte>(4096);
var buf = new byte[1];
while (true)
{
int n = await stream.ReadAsync(buf);
if (n == 0) throw new IOException("connection closed during handshake");
headerBytes.Add(buf[0]);
if (headerBytes.Count >= 4 &&
headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
break;
if (headerBytes.Count > 8192)
throw new InvalidOperationException("HTTP header too large");
}
var text = Encoding.ASCII.GetString(headerBytes.ToArray());
var lines = text.Split("\r\n", StringSplitOptions.None);
if (lines.Length < 1) throw new InvalidOperationException("invalid HTTP request");
var parts = lines[0].Split(' ');
if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line");
var method = parts[0];
var path = parts[1];
var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
for (int i = 1; i < lines.Length; i++)
{
var line = lines[i];
if (string.IsNullOrEmpty(line)) break;
var colonIdx = line.IndexOf(':');
if (colonIdx > 0)
{
var name = line[..colonIdx].Trim();
var value = line[(colonIdx + 1)..].Trim();
headers[name] = value;
}
}
return (method, path, headers);
}
private static bool HeaderContains(Dictionary<string, string> headers, string name, string value)
{
if (!headers.TryGetValue(name, out var headerValue))
return false;
foreach (var token in headerValue.Split(','))
{
if (string.Equals(token.Trim(), value, StringComparison.OrdinalIgnoreCase))
return true;
}
return false;
}
private static Dictionary<string, string> ParseCookies(string cookieHeader)
{
var cookies = new Dictionary<string, string>(StringComparer.Ordinal);
foreach (var pair in cookieHeader.Split(';'))
{
var trimmed = pair.Trim();
var eqIdx = trimmed.IndexOf('=');
if (eqIdx > 0)
cookies[trimmed[..eqIdx].Trim()] = trimmed[(eqIdx + 1)..].Trim();
}
return cookies;
}
}
/// <summary>
/// Result of a WebSocket upgrade handshake attempt.
/// </summary>
public readonly record struct WsUpgradeResult(
bool Success,
bool Compress,
bool Browser,
bool NoCompFrag,
bool MaskRead,
bool MaskWrite,
string? CookieJwt,
string? CookieUsername,
string? CookiePassword,
string? CookieToken,
string? ClientIp,
WsClientKind Kind)
{
public static readonly WsUpgradeResult Failed = new(
Success: false, Compress: false, Browser: false, NoCompFrag: false,
MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null,
CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client);
}

View File

@@ -0,0 +1,226 @@
using System.Text;
using NATS.Server.WebSocket;
namespace NATS.Server.Tests.WebSocket;
public class WsUpgradeTests
{
private static string BuildValidRequest(string path = "/", string? extraHeaders = null)
{
var sb = new StringBuilder();
sb.Append($"GET {path} HTTP/1.1\r\n");
sb.Append("Host: localhost:4222\r\n");
sb.Append("Upgrade: websocket\r\n");
sb.Append("Connection: Upgrade\r\n");
sb.Append("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n");
sb.Append("Sec-WebSocket-Version: 13\r\n");
if (extraHeaders != null)
sb.Append(extraHeaders);
sb.Append("\r\n");
return sb.ToString();
}
[Fact]
public async Task ValidUpgrade_Returns101()
{
var request = BuildValidRequest();
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Client);
var response = ReadResponse(outputStream);
response.ShouldContain("HTTP/1.1 101");
response.ShouldContain("Upgrade: websocket");
response.ShouldContain("Sec-WebSocket-Accept:");
}
[Fact]
public async Task MissingUpgradeHeader_Returns400()
{
var request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
ReadResponse(outputStream).ShouldContain("400");
}
[Fact]
public async Task MissingHost_Returns400()
{
var request = "GET / HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
}
[Fact]
public async Task WrongVersion_Returns400()
{
var request = "GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 12\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
}
[Fact]
public async Task LeafNodePath_ReturnsLeafKind()
{
var request = BuildValidRequest("/leafnode");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Leaf);
}
[Fact]
public async Task MqttPath_ReturnsMqttKind()
{
var request = BuildValidRequest("/mqtt");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Kind.ShouldBe(WsClientKind.Mqtt);
}
[Fact]
public async Task CompressionNegotiation_WhenEnabled()
{
var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}; {WsConstants.PmcSrvNoCtx}; {WsConstants.PmcCliNoCtx}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });
result.Success.ShouldBeTrue();
result.Compress.ShouldBeTrue();
ReadResponse(outputStream).ShouldContain("permessage-deflate");
}
[Fact]
public async Task CompressionNegotiation_WhenDisabled()
{
var request = BuildValidRequest(extraHeaders: $"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = false });
result.Success.ShouldBeTrue();
result.Compress.ShouldBeFalse();
}
[Fact]
public async Task NoMaskingHeader_ForLeaf()
{
var request = BuildValidRequest("/leafnode", "Nats-No-Masking: true\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.MaskRead.ShouldBeFalse();
}
[Fact]
public async Task BrowserDetection_Mozilla()
{
var request = BuildValidRequest(extraHeaders: "User-Agent: Mozilla/5.0 (Windows)\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.Browser.ShouldBeTrue();
}
[Fact]
public async Task SafariDetection_NoCompFrag()
{
var request = BuildValidRequest(extraHeaders:
"User-Agent: Mozilla/5.0 (Macintosh) Version/15.0 Safari/605.1.15\r\n" +
$"Sec-WebSocket-Extensions: {WsConstants.PmcExtension}\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true, Compression = true });
result.Success.ShouldBeTrue();
result.NoCompFrag.ShouldBeTrue();
}
[Fact]
public void AcceptKey_MatchesRfc6455Example()
{
// RFC 6455 Section 4.2.2 example
var key = WsUpgrade.ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
key.ShouldBe("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
}
[Fact]
public async Task CookieExtraction()
{
var request = BuildValidRequest(extraHeaders:
"Cookie: jwt_token=my-jwt; nats_user=admin; nats_pass=secret\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var opts = new WebSocketOptions
{
NoTls = true,
JwtCookie = "jwt_token",
UsernameCookie = "nats_user",
PasswordCookie = "nats_pass",
};
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, opts);
result.Success.ShouldBeTrue();
result.CookieJwt.ShouldBe("my-jwt");
result.CookieUsername.ShouldBe("admin");
result.CookiePassword.ShouldBe("secret");
}
[Fact]
public async Task XForwardedFor_ExtractsClientIp()
{
var request = BuildValidRequest(extraHeaders: "X-Forwarded-For: 192.168.1.100\r\n");
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeTrue();
result.ClientIp.ShouldBe("192.168.1.100");
}
[Fact]
public async Task PostMethod_Returns405()
{
var request = "POST / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n";
var (inputStream, outputStream) = CreateStreamPair(request);
var result = await WsUpgrade.TryUpgradeAsync(inputStream, outputStream, new WebSocketOptions { NoTls = true });
result.Success.ShouldBeFalse();
ReadResponse(outputStream).ShouldContain("405");
}
// Helper: create a readable input stream and writable output stream
private static (Stream input, MemoryStream output) CreateStreamPair(string httpRequest)
{
var inputBytes = Encoding.ASCII.GetBytes(httpRequest);
return (new MemoryStream(inputBytes), new MemoryStream());
}
private static string ReadResponse(MemoryStream output)
{
output.Position = 0;
return Encoding.ASCII.GetString(output.ToArray());
}
}