feat: enforce mqtt auth tls and keepalive semantics
This commit is contained in:
@@ -149,6 +149,19 @@ public sealed class AuthService
|
||||
return raw.ToArray();
|
||||
}
|
||||
|
||||
public static bool ValidateMqttCredentials(
|
||||
string? configuredUsername,
|
||||
string? configuredPassword,
|
||||
string? providedUsername,
|
||||
string? providedPassword)
|
||||
{
|
||||
if (string.IsNullOrEmpty(configuredUsername) && string.IsNullOrEmpty(configuredPassword))
|
||||
return true;
|
||||
|
||||
return string.Equals(configuredUsername, providedUsername, StringComparison.Ordinal)
|
||||
&& string.Equals(configuredPassword, providedPassword, StringComparison.Ordinal);
|
||||
}
|
||||
|
||||
public string EncodeNonce(byte[] nonce)
|
||||
{
|
||||
return Convert.ToBase64String(nonce)
|
||||
|
||||
@@ -12,6 +12,7 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
|
||||
private readonly SemaphoreSlim _writeGate = new(1, 1);
|
||||
private string _clientId = string.Empty;
|
||||
private bool _cleanSession = true;
|
||||
private TimeSpan _idleTimeout = Timeout.InfiniteTimeSpan;
|
||||
|
||||
public async Task RunAsync(CancellationToken ct)
|
||||
{
|
||||
@@ -20,7 +21,7 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
|
||||
string line;
|
||||
try
|
||||
{
|
||||
line = await ReadLineAsync(ct);
|
||||
line = await ReadLineAsync(ct, _idleTimeout);
|
||||
}
|
||||
catch
|
||||
{
|
||||
@@ -31,8 +32,15 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
|
||||
switch (packet.Type)
|
||||
{
|
||||
case MqttPacketType.Connect:
|
||||
if (!_listener.TryAuthenticate(packet.Username, packet.Password))
|
||||
{
|
||||
await WriteLineAsync("ERR mqtt auth failed", ct);
|
||||
return;
|
||||
}
|
||||
|
||||
_clientId = packet.ClientId;
|
||||
_cleanSession = packet.CleanSession;
|
||||
_idleTimeout = _listener.ResolveKeepAliveTimeout(packet.KeepAliveSeconds);
|
||||
var pending = _listener.OpenSession(_clientId, _cleanSession);
|
||||
await WriteLineAsync("CONNACK", ct);
|
||||
foreach (var redelivery in pending)
|
||||
@@ -83,19 +91,39 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA
|
||||
}
|
||||
}
|
||||
|
||||
private async Task<string> ReadLineAsync(CancellationToken ct)
|
||||
private async Task<string> ReadLineAsync(CancellationToken ct, TimeSpan idleTimeout)
|
||||
{
|
||||
CancellationToken token = ct;
|
||||
CancellationTokenSource? timeoutCts = null;
|
||||
if (idleTimeout != Timeout.InfiniteTimeSpan && idleTimeout > TimeSpan.Zero)
|
||||
{
|
||||
timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
|
||||
timeoutCts.CancelAfter(idleTimeout);
|
||||
token = timeoutCts.Token;
|
||||
}
|
||||
|
||||
var bytes = new List<byte>(64);
|
||||
var single = new byte[1];
|
||||
while (true)
|
||||
try
|
||||
{
|
||||
var read = await _stream.ReadAsync(single.AsMemory(0, 1), ct);
|
||||
if (read == 0)
|
||||
throw new IOException("mqtt closed");
|
||||
if (single[0] == (byte)'\n')
|
||||
break;
|
||||
if (single[0] != (byte)'\r')
|
||||
bytes.Add(single[0]);
|
||||
while (true)
|
||||
{
|
||||
var read = await _stream.ReadAsync(single.AsMemory(0, 1), token);
|
||||
if (read == 0)
|
||||
throw new IOException("mqtt closed");
|
||||
if (single[0] == (byte)'\n')
|
||||
break;
|
||||
if (single[0] != (byte)'\r')
|
||||
bytes.Add(single[0]);
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException) when (timeoutCts != null && !ct.IsCancellationRequested)
|
||||
{
|
||||
throw new IOException("mqtt keepalive timeout");
|
||||
}
|
||||
finally
|
||||
{
|
||||
timeoutCts?.Dispose();
|
||||
}
|
||||
|
||||
return Encoding.UTF8.GetString([.. bytes]);
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
using System.Collections.Concurrent;
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using NATS.Server.Auth;
|
||||
|
||||
namespace NATS.Server.Mqtt;
|
||||
|
||||
public sealed class MqttListener(string host, int port) : IAsyncDisposable
|
||||
public sealed class MqttListener(
|
||||
string host,
|
||||
int port,
|
||||
string? requiredUsername = null,
|
||||
string? requiredPassword = null) : IAsyncDisposable
|
||||
{
|
||||
private readonly string _host = host;
|
||||
private int _port = port;
|
||||
private readonly string? _requiredUsername = requiredUsername;
|
||||
private readonly string? _requiredPassword = requiredPassword;
|
||||
private readonly ConcurrentDictionary<MqttConnection, byte> _connections = new();
|
||||
private readonly ConcurrentDictionary<string, ConcurrentDictionary<MqttConnection, byte>> _subscriptions = new(StringComparer.Ordinal);
|
||||
private readonly ConcurrentDictionary<string, MqttSessionState> _sessions = new(StringComparer.Ordinal);
|
||||
@@ -85,6 +92,17 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable
|
||||
session.Pending.TryRemove(packetId, out _);
|
||||
}
|
||||
|
||||
internal bool TryAuthenticate(string? username, string? password)
|
||||
=> AuthService.ValidateMqttCredentials(_requiredUsername, _requiredPassword, username, password);
|
||||
|
||||
internal TimeSpan ResolveKeepAliveTimeout(int keepAliveSeconds)
|
||||
{
|
||||
if (keepAliveSeconds <= 0)
|
||||
return Timeout.InfiniteTimeSpan;
|
||||
|
||||
return TimeSpan.FromSeconds(Math.Max(keepAliveSeconds * 1.5, 1));
|
||||
}
|
||||
|
||||
internal void Unregister(MqttConnection connection)
|
||||
{
|
||||
_connections.TryRemove(connection, out _);
|
||||
|
||||
@@ -16,7 +16,10 @@ public sealed record MqttPacket(
|
||||
string Payload,
|
||||
string ClientId,
|
||||
int PacketId = 0,
|
||||
bool CleanSession = true);
|
||||
bool CleanSession = true,
|
||||
string? Username = null,
|
||||
string? Password = null,
|
||||
int KeepAliveSeconds = 0);
|
||||
|
||||
public sealed class MqttProtocolParser
|
||||
{
|
||||
@@ -39,6 +42,9 @@ public sealed class MqttProtocolParser
|
||||
return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty);
|
||||
|
||||
var cleanSession = true;
|
||||
string? username = null;
|
||||
string? password = null;
|
||||
var keepAliveSeconds = 0;
|
||||
for (var i = 2; i < parts.Length; i++)
|
||||
{
|
||||
if (parts[i].StartsWith("clean=", StringComparison.OrdinalIgnoreCase)
|
||||
@@ -46,6 +52,19 @@ public sealed class MqttProtocolParser
|
||||
{
|
||||
cleanSession = parsedClean;
|
||||
}
|
||||
|
||||
if (parts[i].StartsWith("user=", StringComparison.OrdinalIgnoreCase))
|
||||
username = parts[i]["user=".Length..];
|
||||
|
||||
if (parts[i].StartsWith("pass=", StringComparison.OrdinalIgnoreCase))
|
||||
password = parts[i]["pass=".Length..];
|
||||
|
||||
if (parts[i].StartsWith("keepalive=", StringComparison.OrdinalIgnoreCase)
|
||||
&& int.TryParse(parts[i]["keepalive=".Length..], out var parsedKeepAlive)
|
||||
&& parsedKeepAlive >= 0)
|
||||
{
|
||||
keepAliveSeconds = parsedKeepAlive;
|
||||
}
|
||||
}
|
||||
|
||||
return new MqttPacket(
|
||||
@@ -53,7 +72,10 @@ public sealed class MqttProtocolParser
|
||||
string.Empty,
|
||||
string.Empty,
|
||||
parts[1],
|
||||
CleanSession: cleanSession);
|
||||
CleanSession: cleanSession,
|
||||
Username: username,
|
||||
Password: password,
|
||||
KeepAliveSeconds: keepAliveSeconds);
|
||||
}
|
||||
|
||||
if (trimmed.StartsWith("SUB ", StringComparison.Ordinal))
|
||||
|
||||
@@ -553,7 +553,11 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable
|
||||
if (_options.Mqtt is { Port: > 0 } mqttOptions)
|
||||
{
|
||||
var mqttHost = string.IsNullOrWhiteSpace(mqttOptions.Host) ? _options.Host : mqttOptions.Host;
|
||||
_mqttListener = new MqttListener(mqttHost, mqttOptions.Port);
|
||||
_mqttListener = new MqttListener(
|
||||
mqttHost,
|
||||
mqttOptions.Port,
|
||||
mqttOptions.Username,
|
||||
mqttOptions.Password);
|
||||
await _mqttListener.StartAsync(linked.Token);
|
||||
}
|
||||
if (_jetStreamService != null)
|
||||
|
||||
24
tests/NATS.Server.Tests/Mqtt/MqttAuthIntegrationTests.cs
Normal file
24
tests/NATS.Server.Tests/Mqtt/MqttAuthIntegrationTests.cs
Normal file
@@ -0,0 +1,24 @@
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using NATS.Server.Mqtt;
|
||||
|
||||
namespace NATS.Server.Tests.Mqtt;
|
||||
|
||||
public class MqttAuthIntegrationTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task Invalid_mqtt_credentials_or_keepalive_timeout_close_session_with_protocol_error()
|
||||
{
|
||||
await using var listener = new MqttListener("127.0.0.1", 0, requiredUsername: "mqtt", requiredPassword: "secret");
|
||||
using var cts = new CancellationTokenSource();
|
||||
await listener.StartAsync(cts.Token);
|
||||
|
||||
using var client = new TcpClient();
|
||||
await client.ConnectAsync(IPAddress.Loopback, listener.Port);
|
||||
var stream = client.GetStream();
|
||||
|
||||
await MqttRuntimeWire.WriteLineAsync(stream, "CONNECT auth-client user=bad pass=wrong");
|
||||
(await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("ERR mqtt auth failed");
|
||||
(await MqttRuntimeWire.ReadRawAsync(stream, 1000)).ShouldBeNull();
|
||||
}
|
||||
}
|
||||
26
tests/NATS.Server.Tests/Mqtt/MqttKeepAliveTests.cs
Normal file
26
tests/NATS.Server.Tests/Mqtt/MqttKeepAliveTests.cs
Normal file
@@ -0,0 +1,26 @@
|
||||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
using NATS.Server.Mqtt;
|
||||
|
||||
namespace NATS.Server.Tests.Mqtt;
|
||||
|
||||
public class MqttKeepAliveTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task Invalid_mqtt_credentials_or_keepalive_timeout_close_session_with_protocol_error()
|
||||
{
|
||||
await using var listener = new MqttListener("127.0.0.1", 0);
|
||||
using var cts = new CancellationTokenSource();
|
||||
await listener.StartAsync(cts.Token);
|
||||
|
||||
using var client = new TcpClient();
|
||||
await client.ConnectAsync(IPAddress.Loopback, listener.Port);
|
||||
var stream = client.GetStream();
|
||||
|
||||
await MqttRuntimeWire.WriteLineAsync(stream, "CONNECT keepalive-client keepalive=1");
|
||||
(await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("CONNACK");
|
||||
|
||||
await Task.Delay(2000);
|
||||
(await MqttRuntimeWire.ReadRawAsync(stream, 1000)).ShouldBeNull();
|
||||
}
|
||||
}
|
||||
@@ -68,4 +68,22 @@ internal static class MqttRuntimeWire
|
||||
|
||||
return Encoding.UTF8.GetString([.. bytes]);
|
||||
}
|
||||
|
||||
public static async Task<string?> ReadRawAsync(NetworkStream stream, int timeoutMs)
|
||||
{
|
||||
using var timeout = new CancellationTokenSource(timeoutMs);
|
||||
var one = new byte[1];
|
||||
try
|
||||
{
|
||||
var read = await stream.ReadAsync(one.AsMemory(0, 1), timeout.Token);
|
||||
if (read == 0)
|
||||
return null;
|
||||
|
||||
return Encoding.UTF8.GetString(one, 0, read);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
return "__timeout__";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user