diff --git a/src/NATS.Server/Auth/AuthService.cs b/src/NATS.Server/Auth/AuthService.cs index c828fd5..9a687e7 100644 --- a/src/NATS.Server/Auth/AuthService.cs +++ b/src/NATS.Server/Auth/AuthService.cs @@ -149,6 +149,19 @@ public sealed class AuthService return raw.ToArray(); } + public static bool ValidateMqttCredentials( + string? configuredUsername, + string? configuredPassword, + string? providedUsername, + string? providedPassword) + { + if (string.IsNullOrEmpty(configuredUsername) && string.IsNullOrEmpty(configuredPassword)) + return true; + + return string.Equals(configuredUsername, providedUsername, StringComparison.Ordinal) + && string.Equals(configuredPassword, providedPassword, StringComparison.Ordinal); + } + public string EncodeNonce(byte[] nonce) { return Convert.ToBase64String(nonce) diff --git a/src/NATS.Server/Mqtt/MqttConnection.cs b/src/NATS.Server/Mqtt/MqttConnection.cs index 29da2af..22a3119 100644 --- a/src/NATS.Server/Mqtt/MqttConnection.cs +++ b/src/NATS.Server/Mqtt/MqttConnection.cs @@ -12,6 +12,7 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA private readonly SemaphoreSlim _writeGate = new(1, 1); private string _clientId = string.Empty; private bool _cleanSession = true; + private TimeSpan _idleTimeout = Timeout.InfiniteTimeSpan; public async Task RunAsync(CancellationToken ct) { @@ -20,7 +21,7 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA string line; try { - line = await ReadLineAsync(ct); + line = await ReadLineAsync(ct, _idleTimeout); } catch { @@ -31,8 +32,15 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA switch (packet.Type) { case MqttPacketType.Connect: + if (!_listener.TryAuthenticate(packet.Username, packet.Password)) + { + await WriteLineAsync("ERR mqtt auth failed", ct); + return; + } + _clientId = packet.ClientId; _cleanSession = packet.CleanSession; + _idleTimeout = _listener.ResolveKeepAliveTimeout(packet.KeepAliveSeconds); var pending = _listener.OpenSession(_clientId, _cleanSession); await WriteLineAsync("CONNACK", ct); foreach (var redelivery in pending) @@ -83,19 +91,39 @@ public sealed class MqttConnection(TcpClient client, MqttListener listener) : IA } } - private async Task ReadLineAsync(CancellationToken ct) + private async Task ReadLineAsync(CancellationToken ct, TimeSpan idleTimeout) { + CancellationToken token = ct; + CancellationTokenSource? timeoutCts = null; + if (idleTimeout != Timeout.InfiniteTimeSpan && idleTimeout > TimeSpan.Zero) + { + timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + timeoutCts.CancelAfter(idleTimeout); + token = timeoutCts.Token; + } + var bytes = new List(64); var single = new byte[1]; - while (true) + try { - var read = await _stream.ReadAsync(single.AsMemory(0, 1), ct); - if (read == 0) - throw new IOException("mqtt closed"); - if (single[0] == (byte)'\n') - break; - if (single[0] != (byte)'\r') - bytes.Add(single[0]); + while (true) + { + var read = await _stream.ReadAsync(single.AsMemory(0, 1), token); + if (read == 0) + throw new IOException("mqtt closed"); + if (single[0] == (byte)'\n') + break; + if (single[0] != (byte)'\r') + bytes.Add(single[0]); + } + } + catch (OperationCanceledException) when (timeoutCts != null && !ct.IsCancellationRequested) + { + throw new IOException("mqtt keepalive timeout"); + } + finally + { + timeoutCts?.Dispose(); } return Encoding.UTF8.GetString([.. bytes]); diff --git a/src/NATS.Server/Mqtt/MqttListener.cs b/src/NATS.Server/Mqtt/MqttListener.cs index 1ad743f..19dacfa 100644 --- a/src/NATS.Server/Mqtt/MqttListener.cs +++ b/src/NATS.Server/Mqtt/MqttListener.cs @@ -1,13 +1,20 @@ using System.Collections.Concurrent; using System.Net; using System.Net.Sockets; +using NATS.Server.Auth; namespace NATS.Server.Mqtt; -public sealed class MqttListener(string host, int port) : IAsyncDisposable +public sealed class MqttListener( + string host, + int port, + string? requiredUsername = null, + string? requiredPassword = null) : IAsyncDisposable { private readonly string _host = host; private int _port = port; + private readonly string? _requiredUsername = requiredUsername; + private readonly string? _requiredPassword = requiredPassword; private readonly ConcurrentDictionary _connections = new(); private readonly ConcurrentDictionary> _subscriptions = new(StringComparer.Ordinal); private readonly ConcurrentDictionary _sessions = new(StringComparer.Ordinal); @@ -85,6 +92,17 @@ public sealed class MqttListener(string host, int port) : IAsyncDisposable session.Pending.TryRemove(packetId, out _); } + internal bool TryAuthenticate(string? username, string? password) + => AuthService.ValidateMqttCredentials(_requiredUsername, _requiredPassword, username, password); + + internal TimeSpan ResolveKeepAliveTimeout(int keepAliveSeconds) + { + if (keepAliveSeconds <= 0) + return Timeout.InfiniteTimeSpan; + + return TimeSpan.FromSeconds(Math.Max(keepAliveSeconds * 1.5, 1)); + } + internal void Unregister(MqttConnection connection) { _connections.TryRemove(connection, out _); diff --git a/src/NATS.Server/Mqtt/MqttProtocolParser.cs b/src/NATS.Server/Mqtt/MqttProtocolParser.cs index 698365b..8b91e45 100644 --- a/src/NATS.Server/Mqtt/MqttProtocolParser.cs +++ b/src/NATS.Server/Mqtt/MqttProtocolParser.cs @@ -16,7 +16,10 @@ public sealed record MqttPacket( string Payload, string ClientId, int PacketId = 0, - bool CleanSession = true); + bool CleanSession = true, + string? Username = null, + string? Password = null, + int KeepAliveSeconds = 0); public sealed class MqttProtocolParser { @@ -39,6 +42,9 @@ public sealed class MqttProtocolParser return new MqttPacket(MqttPacketType.Unknown, string.Empty, string.Empty, string.Empty); var cleanSession = true; + string? username = null; + string? password = null; + var keepAliveSeconds = 0; for (var i = 2; i < parts.Length; i++) { if (parts[i].StartsWith("clean=", StringComparison.OrdinalIgnoreCase) @@ -46,6 +52,19 @@ public sealed class MqttProtocolParser { cleanSession = parsedClean; } + + if (parts[i].StartsWith("user=", StringComparison.OrdinalIgnoreCase)) + username = parts[i]["user=".Length..]; + + if (parts[i].StartsWith("pass=", StringComparison.OrdinalIgnoreCase)) + password = parts[i]["pass=".Length..]; + + if (parts[i].StartsWith("keepalive=", StringComparison.OrdinalIgnoreCase) + && int.TryParse(parts[i]["keepalive=".Length..], out var parsedKeepAlive) + && parsedKeepAlive >= 0) + { + keepAliveSeconds = parsedKeepAlive; + } } return new MqttPacket( @@ -53,7 +72,10 @@ public sealed class MqttProtocolParser string.Empty, string.Empty, parts[1], - CleanSession: cleanSession); + CleanSession: cleanSession, + Username: username, + Password: password, + KeepAliveSeconds: keepAliveSeconds); } if (trimmed.StartsWith("SUB ", StringComparison.Ordinal)) diff --git a/src/NATS.Server/NatsServer.cs b/src/NATS.Server/NatsServer.cs index 66bb1e6..72b0212 100644 --- a/src/NATS.Server/NatsServer.cs +++ b/src/NATS.Server/NatsServer.cs @@ -553,7 +553,11 @@ public sealed class NatsServer : IMessageRouter, ISubListAccess, IDisposable if (_options.Mqtt is { Port: > 0 } mqttOptions) { var mqttHost = string.IsNullOrWhiteSpace(mqttOptions.Host) ? _options.Host : mqttOptions.Host; - _mqttListener = new MqttListener(mqttHost, mqttOptions.Port); + _mqttListener = new MqttListener( + mqttHost, + mqttOptions.Port, + mqttOptions.Username, + mqttOptions.Password); await _mqttListener.StartAsync(linked.Token); } if (_jetStreamService != null) diff --git a/tests/NATS.Server.Tests/Mqtt/MqttAuthIntegrationTests.cs b/tests/NATS.Server.Tests/Mqtt/MqttAuthIntegrationTests.cs new file mode 100644 index 0000000..b315908 --- /dev/null +++ b/tests/NATS.Server.Tests/Mqtt/MqttAuthIntegrationTests.cs @@ -0,0 +1,24 @@ +using System.Net; +using System.Net.Sockets; +using NATS.Server.Mqtt; + +namespace NATS.Server.Tests.Mqtt; + +public class MqttAuthIntegrationTests +{ + [Fact] + public async Task Invalid_mqtt_credentials_or_keepalive_timeout_close_session_with_protocol_error() + { + await using var listener = new MqttListener("127.0.0.1", 0, requiredUsername: "mqtt", requiredPassword: "secret"); + using var cts = new CancellationTokenSource(); + await listener.StartAsync(cts.Token); + + using var client = new TcpClient(); + await client.ConnectAsync(IPAddress.Loopback, listener.Port); + var stream = client.GetStream(); + + await MqttRuntimeWire.WriteLineAsync(stream, "CONNECT auth-client user=bad pass=wrong"); + (await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("ERR mqtt auth failed"); + (await MqttRuntimeWire.ReadRawAsync(stream, 1000)).ShouldBeNull(); + } +} diff --git a/tests/NATS.Server.Tests/Mqtt/MqttKeepAliveTests.cs b/tests/NATS.Server.Tests/Mqtt/MqttKeepAliveTests.cs new file mode 100644 index 0000000..d57954a --- /dev/null +++ b/tests/NATS.Server.Tests/Mqtt/MqttKeepAliveTests.cs @@ -0,0 +1,26 @@ +using System.Net; +using System.Net.Sockets; +using NATS.Server.Mqtt; + +namespace NATS.Server.Tests.Mqtt; + +public class MqttKeepAliveTests +{ + [Fact] + public async Task Invalid_mqtt_credentials_or_keepalive_timeout_close_session_with_protocol_error() + { + await using var listener = new MqttListener("127.0.0.1", 0); + using var cts = new CancellationTokenSource(); + await listener.StartAsync(cts.Token); + + using var client = new TcpClient(); + await client.ConnectAsync(IPAddress.Loopback, listener.Port); + var stream = client.GetStream(); + + await MqttRuntimeWire.WriteLineAsync(stream, "CONNECT keepalive-client keepalive=1"); + (await MqttRuntimeWire.ReadLineAsync(stream, 1000)).ShouldBe("CONNACK"); + + await Task.Delay(2000); + (await MqttRuntimeWire.ReadRawAsync(stream, 1000)).ShouldBeNull(); + } +} diff --git a/tests/NATS.Server.Tests/Mqtt/MqttSessionRuntimeTests.cs b/tests/NATS.Server.Tests/Mqtt/MqttSessionRuntimeTests.cs index 7c6f905..d3ecbcb 100644 --- a/tests/NATS.Server.Tests/Mqtt/MqttSessionRuntimeTests.cs +++ b/tests/NATS.Server.Tests/Mqtt/MqttSessionRuntimeTests.cs @@ -68,4 +68,22 @@ internal static class MqttRuntimeWire return Encoding.UTF8.GetString([.. bytes]); } + + public static async Task ReadRawAsync(NetworkStream stream, int timeoutMs) + { + using var timeout = new CancellationTokenSource(timeoutMs); + var one = new byte[1]; + try + { + var read = await stream.ReadAsync(one.AsMemory(0, 1), timeout.Token); + if (read == 0) + return null; + + return Encoding.UTF8.GetString(one, 0, read); + } + catch (OperationCanceledException) + { + return "__timeout__"; + } + } }