260 lines
10 KiB
C#
260 lines
10 KiB
C#
using System.Net;
|
|
using System.Security.Cryptography;
|
|
using System.Text;
|
|
|
|
namespace NATS.Server.WebSocket;
|
|
|
|
/// <summary>
/// WebSocket HTTP upgrade handshake handler.
/// Ported from golang/nats-server/server/websocket.go lines 731-917.
/// </summary>
public static class WsUpgrade
{
    // Upper bound on the size of the HTTP request head (request line plus header fields).
    private const int MaxRequestHeadBytes = 8192;

    // Fixed GUID appended to the client key when computing Sec-WebSocket-Accept (RFC 6455 §1.3).
    private const string AcceptKeyGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    // Subprotocol negotiated by MQTT-over-WebSocket clients.
    private const string MqttSubprotocol = "mqtt";

    /// <summary>
    /// Reads an HTTP/1.1 upgrade request from <paramref name="inputStream"/> and validates it
    /// per RFC 6455 §4.2.1. It then writes either a <c>101 Switching Protocols</c> response or an
    /// HTTP error response to <paramref name="outputStream"/>.
    /// </summary>
    /// <param name="inputStream">
    /// Stream the request is read from. Only the request head is consumed, so any WebSocket
    /// frame bytes that follow it remain unread.
    /// </param>
    /// <param name="outputStream">Stream the handshake (or error) response is written to.</param>
    /// <param name="options">Listener options: origin policy, compression, cookie names, extra headers.</param>
    /// <returns>
    /// The negotiated connection parameters on success. Returns <see cref="WsUpgradeResult.Failed"/>
    /// when the request is rejected or any read/parse/write error occurs.
    /// </returns>
    public static async Task<WsUpgradeResult> TryUpgradeAsync(
        Stream inputStream, Stream outputStream, WebSocketOptions options)
    {
        try
        {
            var (method, path, headers) = await ReadHttpRequestAsync(inputStream);

            // RFC 6455 §4.2.1, mandatory points 1-6.
            if (!string.Equals(method, "GET", StringComparison.OrdinalIgnoreCase))
                return await FailAsync(outputStream, 405, "request method must be GET");

            // An empty Host value is treated the same as a missing one.
            if (!headers.TryGetValue("Host", out var host) || string.IsNullOrEmpty(host))
                return await FailAsync(outputStream, 400, "'Host' missing in request");

            if (!HeaderContains(headers, "Upgrade", "websocket"))
                return await FailAsync(outputStream, 400, "invalid value for header 'Upgrade'");

            if (!HeaderContains(headers, "Connection", "Upgrade"))
                return await FailAsync(outputStream, 400, "invalid value for header 'Connection'");

            if (!headers.TryGetValue("Sec-WebSocket-Key", out var key) || string.IsNullOrEmpty(key))
                return await FailAsync(outputStream, 400, "key missing");

            if (!HeaderContains(headers, "Sec-WebSocket-Version", "13"))
                return await FailAsync(outputStream, 400, "invalid version");

            var kind = ClassifyRequestTarget(path);
            bool isClientOrMqtt = kind is WsClientKind.Client or WsClientKind.Mqtt;

            // Origin checking (point 7) only runs when an origin policy is configured.
            if (options.SameOrigin || options.AllowedOrigins is { Count: > 0 })
            {
                var checker = new WsOriginChecker(options.SameOrigin, options.AllowedOrigins);
                headers.TryGetValue("Origin", out var origin);
                if (string.IsNullOrEmpty(origin))
                    headers.TryGetValue("Sec-WebSocket-Origin", out origin);
                // NOTE(review): this method does not know whether the connection is TLS, so the
                // check always runs with isTls: false. Confirm this is correct for TLS listeners.
                var originErr = checker.CheckOrigin(origin, host, isTls: false);
                if (originErr is not null)
                    return await FailAsync(outputStream, 403, $"origin not allowed: {originErr}");
            }

            // Compression: enabled only when configured AND the client offers permessage-deflate.
            bool compress = options.Compression &&
                headers.TryGetValue("Sec-WebSocket-Extensions", out var extensions) &&
                OffersExtension(extensions, WsConstants.PmcExtension);

            // No-masking support: the client asks that incoming frames not be expected masked.
            bool noMasking = headers.TryGetValue(WsConstants.NoMaskingHeader, out var nmVal) &&
                string.Equals(nmVal.Trim(), WsConstants.NoMaskingValue, StringComparison.OrdinalIgnoreCase);

            // MQTT clients offer the "mqtt" subprotocol. Browsers fail the connection when an
            // offered subprotocol is not selected in the response, so echo it back.
            bool selectMqttProtocol = kind == WsClientKind.Mqtt &&
                HeaderContains(headers, "Sec-WebSocket-Protocol", MqttSubprotocol);

            // Browser detection
            bool browser = false;
            bool noCompFrag = false;
            if (isClientOrMqtt &&
                headers.TryGetValue("User-Agent", out var ua) &&
                ua.StartsWith("Mozilla/", StringComparison.Ordinal))
            {
                browser = true;
                // Disable fragmentation of compressed frames for Safari browsers.
                // Safari has both "Version/" and "Safari/" in the user agent string,
                // while Chrome on macOS has "Safari/" but not "Version/".
                noCompFrag = compress &&
                    ua.Contains("Version/", StringComparison.Ordinal) &&
                    ua.Contains("Safari/", StringComparison.Ordinal);
            }

            // Cookie extraction (client and MQTT connections only).
            string? cookieJwt = null, cookieUsername = null, cookiePassword = null, cookieToken = null;
            if (isClientOrMqtt && headers.TryGetValue("Cookie", out var cookieHeader))
            {
                var cookies = ParseCookies(cookieHeader);
                if (options.JwtCookie != null) cookies.TryGetValue(options.JwtCookie, out cookieJwt);
                if (options.UsernameCookie != null) cookies.TryGetValue(options.UsernameCookie, out cookieUsername);
                if (options.PasswordCookie != null) cookies.TryGetValue(options.PasswordCookie, out cookiePassword);
                if (options.TokenCookie != null) cookies.TryGetValue(options.TokenCookie, out cookieToken);
            }

            // X-Forwarded-For: use the first (left-most) entry only if it parses as an IP address.
            string? clientIp = null;
            if (headers.TryGetValue(WsConstants.XForwardedForHeader, out var xff))
            {
                var ip = xff.Split(',')[0].Trim();
                if (IPAddress.TryParse(ip, out _))
                    clientIp = ip;
            }

            // Build the 101 Switching Protocols response (RFC 6455 §4.2.2).
            var response = new StringBuilder();
            response.Append("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ");
            response.Append(ComputeAcceptKey(key)).Append("\r\n");
            if (compress)
                response.Append(WsConstants.PmcFullResponse);
            if (noMasking)
                response.Append(WsConstants.NoMaskingFullResponse);
            if (selectMqttProtocol)
                response.Append("Sec-WebSocket-Protocol: ").Append(MqttSubprotocol).Append("\r\n");
            if (options.Headers != null)
            {
                foreach (var (k, v) in options.Headers)
                    response.Append(k).Append(": ").Append(v).Append("\r\n");
            }
            response.Append("\r\n");

            var responseBytes = Encoding.ASCII.GetBytes(response.ToString());
            await outputStream.WriteAsync(responseBytes);
            await outputStream.FlushAsync();

            return new WsUpgradeResult(
                Success: true, Compress: compress, Browser: browser, NoCompFrag: noCompFrag,
                MaskRead: !noMasking, MaskWrite: false,
                CookieJwt: cookieJwt, CookieUsername: cookieUsername,
                CookiePassword: cookiePassword, CookieToken: cookieToken,
                ClientIp: clientIp, Kind: kind);
        }
        catch (Exception)
        {
            // Deliberate best effort. Any failure fails the upgrade without propagating: client
            // disconnect mid-handshake, an oversized or malformed request head, or a write error.
            return WsUpgradeResult.Failed;
        }
    }

    /// <summary>
    /// Computes the Sec-WebSocket-Accept value per RFC 6455 Section 4.2.2. The value is the
    /// base64-encoded SHA-1 hash of the client key concatenated with the protocol GUID.
    /// </summary>
    /// <param name="clientKey">Value of the client's Sec-WebSocket-Key header.</param>
    /// <returns>The value to send in the Sec-WebSocket-Accept response header.</returns>
    public static string ComputeAcceptKey(string clientKey)
    {
        var combined = Encoding.ASCII.GetBytes(clientKey + AcceptKeyGuid);
        return Convert.ToBase64String(SHA1.HashData(combined));
    }

    // Writes a plain-text HTTP error response and returns the shared failed result.
    private static async Task<WsUpgradeResult> FailAsync(Stream output, int statusCode, string reason)
    {
        var statusText = statusCode switch
        {
            400 => "Bad Request",
            403 => "Forbidden",
            405 => "Method Not Allowed",
            _ => "Internal Server Error",
        };
        // Content-Length counts encoded bytes, not UTF-16 chars. They can differ when the reason
        // (e.g. an origin error message) contains characters ASCII cannot represent.
        var contentLength = Encoding.ASCII.GetByteCount(reason);
        var response = $"HTTP/1.1 {statusCode} {statusText}\r\nSec-WebSocket-Version: 13\r\nContent-Type: text/plain\r\nContent-Length: {contentLength}\r\n\r\n{reason}";
        await output.WriteAsync(Encoding.ASCII.GetBytes(response));
        await output.FlushAsync();
        return WsUpgradeResult.Failed;
    }

    // Reads the HTTP request head and returns the method, request target and header fields.
    // Header names are case-insensitive. Repeated fields are merged into one value.
    private static async Task<(string method, string path, Dictionary<string, string> headers)> ReadHttpRequestAsync(Stream stream)
    {
        // Read one byte at a time so nothing past the blank line that ends the request head is
        // consumed from the stream.
        var headerBytes = new List<byte>(4096);
        var buf = new byte[1];
        while (true)
        {
            int n = await stream.ReadAsync(buf);
            if (n == 0) throw new IOException("connection closed during handshake");
            headerBytes.Add(buf[0]);
            if (headerBytes.Count >= 4 &&
                headerBytes[^4] == '\r' && headerBytes[^3] == '\n' &&
                headerBytes[^2] == '\r' && headerBytes[^1] == '\n')
                break;
            if (headerBytes.Count > MaxRequestHeadBytes)
                throw new InvalidOperationException("HTTP header too large");
        }

        var text = Encoding.ASCII.GetString(headerBytes.ToArray());
        // Split always yields at least one element, so lines[0] is the request line.
        var lines = text.Split("\r\n", StringSplitOptions.None);

        var parts = lines[0].Split(' ');
        if (parts.Length < 3) throw new InvalidOperationException("invalid HTTP request line");
        var method = parts[0];
        var path = parts[1];

        var headers = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
        for (int i = 1; i < lines.Length; i++)
        {
            var line = lines[i];
            if (line.Length == 0) break;
            var colonIdx = line.IndexOf(':');
            if (colonIdx <= 0) continue;
            var name = line[..colonIdx].Trim();
            var value = line[(colonIdx + 1)..].Trim();
            headers[name] = headers.TryGetValue(name, out var existing)
                ? CombineHeaderValues(name, existing, value)
                : value;
        }

        return (method, path, headers);
    }

    // Merges a repeated header field into a single value. List-valued fields are joined with ", "
    // (RFC 7230 §3.2.2). Cookie fields are joined with "; " so ParseCookies still splits them.
    private static string CombineHeaderValues(string name, string existing, string value)
    {
        if (existing.Length == 0) return value;
        if (value.Length == 0) return existing;
        var separator = string.Equals(name, "Cookie", StringComparison.OrdinalIgnoreCase) ? "; " : ", ";
        return existing + separator + value;
    }

    // Maps the request target to the connecting client's kind by its path suffix. Any query
    // string is ignored, and the comparison is ordinal because the path is untrusted input.
    private static WsClientKind ClassifyRequestTarget(string requestTarget)
    {
        var queryIdx = requestTarget.IndexOf('?');
        var path = queryIdx >= 0 ? requestTarget[..queryIdx] : requestTarget;
        if (path.EndsWith("/leafnode", StringComparison.Ordinal))
            return WsClientKind.Leaf;
        if (path.EndsWith("/mqtt", StringComparison.Ordinal))
            return WsClientKind.Mqtt;
        return WsClientKind.Client;
    }

    // True when the comma-separated Sec-WebSocket-Extensions value offers the named extension.
    // Only the extension name before any ";" parameters is compared (RFC 6455 §9.1).
    private static bool OffersExtension(string extensionsHeader, string extensionName)
    {
        foreach (var offer in extensionsHeader.Split(','))
        {
            var semicolonIdx = offer.IndexOf(';');
            var offeredName = semicolonIdx >= 0 ? offer[..semicolonIdx] : offer;
            if (string.Equals(offeredName.Trim(), extensionName, StringComparison.OrdinalIgnoreCase))
                return true;
        }

        return false;
    }

    // True when the comma-separated header value contains the token (case-insensitive).
    private static bool HeaderContains(Dictionary<string, string> headers, string name, string value)
    {
        if (!headers.TryGetValue(name, out var headerValue))
            return false;
        foreach (var token in headerValue.Split(','))
        {
            if (string.Equals(token.Trim(), value, StringComparison.OrdinalIgnoreCase))
                return true;
        }

        return false;
    }

    // Parses a Cookie header ("a=1; b=2") into a name -> value map with case-sensitive names.
    // The first occurrence of a name wins, matching Go's http.Request.Cookie. Surrounding double
    // quotes are removed from values (RFC 6265 §4.1.1).
    private static Dictionary<string, string> ParseCookies(string cookieHeader)
    {
        var cookies = new Dictionary<string, string>(StringComparer.Ordinal);
        foreach (var pair in cookieHeader.Split(';'))
        {
            var trimmed = pair.Trim();
            var eqIdx = trimmed.IndexOf('=');
            if (eqIdx <= 0)
                continue;
            var name = trimmed[..eqIdx].Trim();
            var value = trimmed[(eqIdx + 1)..].Trim();
            if (value.Length >= 2 && value[0] == '"' && value[^1] == '"')
                value = value[1..^1];
            cookies.TryAdd(name, value);
        }

        return cookies;
    }
}
|
|
|
|
/// <summary>
/// Outcome of a WebSocket upgrade handshake, together with the settings negotiated with the client.
/// </summary>
/// <param name="Success">Whether the upgrade completed.</param>
/// <param name="Compress">Whether per-message compression was negotiated.</param>
/// <param name="Browser">Whether the client appears to be a web browser.</param>
/// <param name="NoCompFrag">Whether compressed frames must not be fragmented for this client.</param>
/// <param name="MaskRead">Whether frames read from the client are expected to be masked.</param>
/// <param name="MaskWrite">Whether frames written to the client are masked.</param>
/// <param name="CookieJwt">JWT taken from the configured cookie, if any.</param>
/// <param name="CookieUsername">Username taken from the configured cookie, if any.</param>
/// <param name="CookiePassword">Password taken from the configured cookie, if any.</param>
/// <param name="CookieToken">Token taken from the configured cookie, if any.</param>
/// <param name="ClientIp">Client IP reported by a forwarding header, if any.</param>
/// <param name="Kind">Kind of client that requested the upgrade.</param>
public readonly record struct WsUpgradeResult(
    bool Success,
    bool Compress,
    bool Browser,
    bool NoCompFrag,
    bool MaskRead,
    bool MaskWrite,
    string? CookieJwt,
    string? CookieUsername,
    string? CookiePassword,
    string? CookieToken,
    string? ClientIp,
    WsClientKind Kind)
{
    /// <summary>
    /// Shared result for a rejected or aborted handshake. Every flag is off and every string is
    /// null, except that reads stay masked and the kind is <see cref="WsClientKind.Client"/>.
    /// </summary>
    public static readonly WsUpgradeResult Failed = default(WsUpgradeResult) with
    {
        MaskRead = true,
        Kind = WsClientKind.Client,
    };
}
|