feat: add WebSocket HTTP upgrade handshake
This commit is contained in:
259
src/NATS.Server/WebSocket/WsUpgrade.cs
Normal file
259
src/NATS.Server/WebSocket/WsUpgrade.cs
Normal file
@@ -0,0 +1,259 @@
|
||||
using System.Net;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
|
||||
namespace NATS.Server.WebSocket;
|
||||
|
||||
/// <summary>
|
||||
/// WebSocket HTTP upgrade handshake handler.
|
||||
/// Ported from golang/nats-server/server/websocket.go lines 731-917.
|
||||
/// </summary>
|
||||
public static class WsUpgrade
|
||||
{
|
||||
public static async Task<WsUpgradeResult> TryUpgradeAsync(
|
||||
Stream inputStream, Stream outputStream, WebSocketOptions options)
|
||||
{
|
||||
try
|
||||
{
|
||||
var (method, path, headers) = await ReadHttpRequestAsync(inputStream);
|
||||
|
||||
if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
|
||||
return await FailAsync(outputStream, 405, "request method must be GET");
|
||||
|
||||
if (!headers.ContainsKey("Host"))
|
||||
return await FailAsync(outputStream, 400, "'Host' missing in request");
|
||||
|
||||
if (!HeaderContains(headers, "Upgrade", "websocket"))
|
||||
return await FailAsync(outputStream, 400, "invalid value for header 'Upgrade'");
|
||||
|
||||
if (!HeaderContains(headers, "Connection", "Upgrade"))
|
||||
return await FailAsync(outputStream, 400, "invalid value for header 'Connection'");
|
||||
|
||||
if (!headers.TryGetValue("Sec-WebSocket-Key", out var key) || string.IsNullOrEmpty(key))
|
||||
return await FailAsync(outputStream, 400, "key missing");
|
||||
|
||||
if (!HeaderContains(headers, "Sec-WebSocket-Version", "13"))
|
||||
return await FailAsync(outputStream, 400, "invalid version");
|
||||
|
||||
var kind = path switch
|
||||
{
|
||||
_ when path.EndsWith("/leafnode") => WsClientKind.Leaf,
|
||||
_ when path.EndsWith("/mqtt") => WsClientKind.Mqtt,
|
||||
_ => WsClientKind.Client,
|
||||
};
|
||||
|
||||
// Origin checking
|
||||
if (options.SameOrigin || options.AllowedOrigins is { Count: > 0 })
|
||||
{
|
||||
var checker = new WsOriginChecker(options.SameOrigin, options.AllowedOrigins);
|
||||
headers.TryGetValue("Origin", out var origin);
|
||||
if (string.IsNullOrEmpty(origin))
|
||||
headers.TryGetValue("Sec-WebSocket-Origin", out origin);
|
||||
var originErr = checker.CheckOrigin(origin, headers.GetValueOrDefault("Host", ""), isTls: false);
|
||||
if (originErr != null)
|
||||
return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}");
|
||||
}
|
||||
|
||||
// Compression negotiation
|
||||
bool compress = options.Compression;
|
||||
if (compress)
|
||||
{
|
||||
compress = headers.TryGetValue("Sec-WebSocket-Extensions", out var ext) &&
|
||||
ext.Contains(WsConstants.PmcExtension, StringComparison.OrdinalIgnoreCase);
|
||||
}
|
||||
|
||||
// No-masking support
|
||||
bool noMasking = headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) &&
|
||||
string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase);
|
||||
|
||||
// Browser detection
|
||||
bool browser = false;
|
||||
bool noCompFrag = false;
|
||||
if (kind is WsClientKind.Client or WsClientKind.Mqtt &&
|
||||
headers.TryGetValue("User-Agent", out var ua) && ua.StartsWith("Mozilla/"))
|
||||
{
|
||||
browser = true;
|
||||
// Disable fragmentation of compressed frames for Safari browsers.
|
||||
// Safari has both "Version/" and "Safari/" in the user agent string,
|
||||
// while Chrome on macOS has "Safari/" but not "Version/".
|
||||
noCompFrag = compress && ua.Contains("Version/") && ua.Contains("Safari/");
|
||||
}
|
||||
|
||||
// Cookie extraction
|
||||
string? cookieJwt = null, cookieUsername = null, cookiePassword = null, cookieToken = null;
|
||||
if ((kind is WsClientKind.Client or WsClientKind.Mqtt) &&
|
||||
headers.TryGetValue("Cookie", out var cookieHeader))
|
||||
{
|
||||
var cookies = ParseCookies(cookieHeader);
|
||||
if (options.JwtCookie != null) cookies.TryGetValue(options.JwtCookie, out cookieJwt);
|
||||
if (options.UsernameCookie != null) cookies.TryGetValue(options.UsernameCookie, out cookieUsername);
|
||||
if (options.PasswordCookie != null) cookies.TryGetValue(options.PasswordCookie, out cookiePassword);
|
||||
if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken);
|
||||
}
|
||||
|
||||
// X-Forwarded-For client IP extraction
|
||||
string? clientIp = null;
|
||||
if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff))
|
||||
{
|
||||
var ip = xff.Split(',')[0].Trim();
|
||||
if (IPAddress.TryParse(ip, out _))
|
||||
clientIp = ip;
|
||||
}
|
||||
|
||||
// Build the 101 Switching Protocols response
|
||||
var response = new StringBuilder();
|
||||
response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ");
|
||||
response.Append(ComputeAcceptKey(key));
|
||||
response.Append("\r\n");
|
||||
if (compress)
|
||||
response.Append(WsConstants.PmcFullResponse);
|
||||
if (noMasking)
|
||||
response.Append(WsConstants.NoMaskingFullResponse);
|
||||
if (options.Headers != null)
|
||||
{
|
||||
foreach (var (k, v) in options.Headers)
|
||||
{
|
||||
response.Append(k);
|
||||
response.Append(": ");
|
||||
response.Append(v);
|
||||
response.Append("\r\n");
|
||||
}
|
||||
}
|
||||
|
||||
response.Append("\r\n");
|
||||
|
||||
var responseBytes = Encoding.ASCII.GetBytes(response.ToString());
|
||||
await outputStream.WriteAsync(responseBytes);
|
||||
await outputStream.FlushAsync();
|
||||
|
||||
return new WsUpgradeResult(
|
||||
Success: true, Compress: compress, Browser: browser, NoCompFrag: noCompFrag,
|
||||
MaskRead: !noMasking, MaskWrite: false,
|
||||
CookieJwt: cookieJwt, CookieUsername: cookieUsername,
|
||||
CookiePassword: cookiePassword, CookieToken: cookieToken,
|
||||
ClientIp: clientIp, Kind: kind);
|
||||
}
|
||||
catch (Exception)
|
||||
{
|
||||
return WsUpgradeResult.Failed;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Computes the Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2.
|
||||
/// </summary>
|
||||
public static string ComputeAcceptKey(string clientKey)
|
||||
{
|
||||
var combined = Encoding.ASCII.GetBytes(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
|
||||
var hash = SHA1.HashData(combined);
|
||||
return Convert.ToBase64String(hash);
|
||||
}
|
||||
|
||||
private static async Task<WsUpgradeResult> FailAsync(Stream output, int statusCode, string reason)
|
||||
{
|
||||
var statusText = statusCode switch
|
||||
{
|
||||
400 => "Bad Request",
|
||||
403 => "Forbidden",
|
||||
405 => "Method Not Allowed",
|
||||
_ => "Internal Server Error",
|
||||
};
|
||||
var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {reason.Length}\r\n\r\n{reason}";
|
||||
await output.WriteAsync(Encoding.ASCII.GetBytes(response));
|
||||
await output.FlushAsync();
|
||||
return WsUpgradeResult.Failed;
|
||||
}
|
||||
|
||||
private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(Stream stream)
|
||||
{
|
||||
var headerBytes = new List<byte>(4096);
|
||||
var buf = new byte[1];
|
||||
while (true)
|
||||
{
|
||||
int n = await stream.ReadAsync(buf);
|
||||
if (n == 0) throw new IOException("connection closed during handshake");
|
||||
headerBytes.Add(buf[0]);
|
||||
if (headerBytes.Count >= 4 &&
|
||||
headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
|
||||
headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
|
||||
break;
|
||||
if (headerBytes.Count > 8192)
|
||||
throw new InvalidOperationException("HTTP header too large");
|
||||
}
|
||||
|
||||
var text = Encoding.ASCII.GetString(headerBytes.ToArray());
|
||||
var lines = text.Split("\r\n", StringSplitOptions.None);
|
||||
if (lines.Length < 1) throw new InvalidOperationException("invalid HTTP request");
|
||||
|
||||
var parts = lines[0].Split(' ');
|
||||
if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line");
|
||||
var method = parts[0];
|
||||
var path = parts[1];
|
||||
|
||||
var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
|
||||
for (int i = 1; i < lines.Length; i++)
|
||||
{
|
||||
var line = lines[i];
|
||||
if (string.IsNullOrEmpty(line)) break;
|
||||
var colonIdx = line.IndexOf(':');
|
||||
if (colonIdx > 0)
|
||||
{
|
||||
var name = line[..colonIdx].Trim();
|
||||
var value = line[(colonIdx + 1)..].Trim();
|
||||
headers[name] = value;
|
||||
}
|
||||
}
|
||||
|
||||
return (method, path, headers);
|
||||
}
|
||||
|
||||
private static bool HeaderContains(Dictionary<string, string> headers, string name, string value)
|
||||
{
|
||||
if (!headers.TryGetValue(name, out var headerValue))
|
||||
return false;
|
||||
foreach (var token in headerValue.Split(','))
|
||||
{
|
||||
if (string.Equals(token.Trim(), value, StringComparison.OrdinalIgnoreCase))
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private static Dictionary<string, string> ParseCookies(string cookieHeader)
|
||||
{
|
||||
var cookies = new Dictionary<string, string>(StringComparer.Ordinal);
|
||||
foreach (var pair in cookieHeader.Split(';'))
|
||||
{
|
||||
var trimmed = pair.Trim();
|
||||
var eqIdx = trimmed.IndexOf('=');
|
||||
if (eqIdx > 0)
|
||||
cookies[trimmed[..eqIdx].Trim()] = trimmed[(eqIdx + 1)..].Trim();
|
||||
}
|
||||
|
||||
return cookies;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Result of a WebSocket upgrade handshake attempt.
|
||||
/// </summary>
|
||||
public readonly record struct WsUpgradeResult(
|
||||
bool Success,
|
||||
bool Compress,
|
||||
bool Browser,
|
||||
bool NoCompFrag,
|
||||
bool MaskRead,
|
||||
bool MaskWrite,
|
||||
string? CookieJwt,
|
||||
string? CookieUsername,
|
||||
string? CookiePassword,
|
||||
string? CookieToken,
|
||||
string? ClientIp,
|
||||
WsClientKind Kind)
|
||||
{
|
||||
public static readonly WsUpgradeResult Failed = new(
|
||||
Success: false, Compress: false, Browser: false, NoCompFrag: false,
|
||||
MaskRead: true, MaskWrite: false, CookieJwt: null, CookieUsername: null,
|
||||
CookiePassword: null, CookieToken: null, ClientIp: null, Kind: WsClientKind.Client);
|
||||
}
|
||||
Reference in New Issue
Block a user