perf: optimize MQTT cross-protocol path (0.30x → 0.78x Go)

Replace per-message async fire-and-forget with direct-buffer write loop
mirroring NatsClient pattern: SpinLock-guarded buffer append, double-
buffer swap, single WriteAsync per batch.

- MqttConnection: add _directBuf/_writeBuf + RunMqttWriteLoopAsync
- MqttConnection: add EnqueuePublishNoFlush (zero-alloc PUBLISH format)
- MqttPacketWriter: add WritePublishTo(Span<byte>) + MeasurePublish
- MqttTopicMapper: add NatsToMqttBytes with bounded ConcurrentDictionary
- MqttNatsClientAdapter: synchronous SendMessageNoFlush + SignalFlush
- Skip FlushAsync on plain TCP sockets (TCP auto-flushes)
This commit is contained in:
Joseph Doherty
2026-03-13 14:25:13 -04:00
parent 699449da6a
commit 11e01b9026
14 changed files with 1113 additions and 10 deletions

View File

@@ -5,3 +5,12 @@ public class BenchmarkCoreCollection : ICollectionFixture<CoreServerPairFixture>
[CollectionDefinition("Benchmark-JetStream")]
public class BenchmarkJetStreamCollection : ICollectionFixture<JetStreamServerPairFixture>;
[CollectionDefinition("Benchmark-Mqtt")]
public class BenchmarkMqttCollection : ICollectionFixture<MqttServerFixture>;
[CollectionDefinition("Benchmark-Tls")]
public class BenchmarkTlsCollection : ICollectionFixture<TlsServerFixture>;
[CollectionDefinition("Benchmark-WebSocket")]
public class BenchmarkWebSocketCollection : ICollectionFixture<WebSocketServerFixture>;

View File

@@ -0,0 +1,93 @@
using NATS.Client.Core;
namespace NATS.Server.Benchmark.Tests.Infrastructure;
/// <summary>
/// Starts both a Go and .NET NATS server with MQTT and JetStream enabled for MQTT benchmarks.
/// Shared across all tests in the "Benchmark-Mqtt" collection.
/// </summary>
public sealed class MqttServerFixture : IAsyncLifetime
{
private GoServerProcess? _goServer;
private DotNetServerProcess? _dotNetServer;
private string? _goStoreDir;
private string? _dotNetStoreDir;
public int GoNatsPort => _goServer?.Port ?? throw new InvalidOperationException("Go server not started");
public int GoMqttPort { get; private set; }
public int DotNetNatsPort => _dotNetServer?.Port ?? throw new InvalidOperationException(".NET server not started");
public int DotNetMqttPort { get; private set; }
public bool GoAvailable => _goServer is not null;
public async Task InitializeAsync()
{
DotNetMqttPort = PortAllocator.AllocateFreePort();
_dotNetStoreDir = Path.Combine(Path.GetTempPath(), "nats-bench-dotnet-mqtt-" + Guid.NewGuid().ToString("N")[..8]);
Directory.CreateDirectory(_dotNetStoreDir);
var dotNetConfig = $$"""
jetstream {
store_dir: "{{_dotNetStoreDir}}"
max_mem_store: 64mb
max_file_store: 256mb
}
mqtt {
listen: 127.0.0.1:{{DotNetMqttPort}}
}
""";
_dotNetServer = new DotNetServerProcess(dotNetConfig);
var dotNetTask = _dotNetServer.StartAsync();
if (GoServerProcess.IsAvailable())
{
GoMqttPort = PortAllocator.AllocateFreePort();
_goStoreDir = Path.Combine(Path.GetTempPath(), "nats-bench-go-mqtt-" + Guid.NewGuid().ToString("N")[..8]);
Directory.CreateDirectory(_goStoreDir);
var goConfig = $$"""
jetstream {
store_dir: "{{_goStoreDir}}"
max_mem_store: 64mb
max_file_store: 256mb
}
mqtt {
listen: 127.0.0.1:{{GoMqttPort}}
}
""";
_goServer = new GoServerProcess(goConfig);
await Task.WhenAll(dotNetTask, _goServer.StartAsync());
}
else
{
await dotNetTask;
}
}
public async Task DisposeAsync()
{
if (_goServer is not null)
await _goServer.DisposeAsync();
if (_dotNetServer is not null)
await _dotNetServer.DisposeAsync();
CleanupDir(_goStoreDir);
CleanupDir(_dotNetStoreDir);
}
public NatsConnection CreateGoNatsClient()
=> new(new NatsOpts { Url = $"nats://127.0.0.1:{GoNatsPort}" });
public NatsConnection CreateDotNetNatsClient()
=> new(new NatsOpts { Url = $"nats://127.0.0.1:{DotNetNatsPort}" });
private static void CleanupDir(string? dir)
{
if (dir is not null && Directory.Exists(dir))
{
try { Directory.Delete(dir, recursive: true); }
catch { /* best-effort cleanup */ }
}
}
}

View File

@@ -0,0 +1,122 @@
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using NATS.Client.Core;
namespace NATS.Server.Benchmark.Tests.Infrastructure;
/// <summary>
/// Starts both a Go and .NET NATS server with TLS enabled for transport overhead benchmarks.
/// Shared across all tests in the "Benchmark-Tls" collection.
/// </summary>
public sealed class TlsServerFixture : IAsyncLifetime
{
private GoServerProcess? _goServer;
private DotNetServerProcess? _dotNetServer;
private string? _tempDir;
public int GoPort => _goServer?.Port ?? throw new InvalidOperationException("Go server not started");
public int DotNetPort => _dotNetServer?.Port ?? throw new InvalidOperationException(".NET server not started");
public bool GoAvailable => _goServer is not null;
public async Task InitializeAsync()
{
_tempDir = Path.Combine(Path.GetTempPath(), $"nats-bench-tls-{Guid.NewGuid():N}");
Directory.CreateDirectory(_tempDir);
var caCertPath = Path.Combine(_tempDir, "ca.pem");
var serverCertPath = Path.Combine(_tempDir, "server-cert.pem");
var serverKeyPath = Path.Combine(_tempDir, "server-key.pem");
GenerateCertificates(caCertPath, serverCertPath, serverKeyPath);
var config = $$"""
tls {
cert_file: "{{serverCertPath}}"
key_file: "{{serverKeyPath}}"
ca_file: "{{caCertPath}}"
}
""";
_dotNetServer = new DotNetServerProcess(config);
var dotNetTask = _dotNetServer.StartAsync();
if (GoServerProcess.IsAvailable())
{
_goServer = new GoServerProcess(config);
await Task.WhenAll(dotNetTask, _goServer.StartAsync());
}
else
{
await dotNetTask;
}
}
public async Task DisposeAsync()
{
if (_goServer is not null)
await _goServer.DisposeAsync();
if (_dotNetServer is not null)
await _dotNetServer.DisposeAsync();
if (_tempDir is not null && Directory.Exists(_tempDir))
{
try { Directory.Delete(_tempDir, recursive: true); }
catch { /* best-effort cleanup */ }
}
}
public NatsConnection CreateGoTlsClient()
=> CreateTlsClient(GoPort);
public NatsConnection CreateDotNetTlsClient()
=> CreateTlsClient(DotNetPort);
private static NatsConnection CreateTlsClient(int port)
{
var opts = new NatsOpts
{
Url = $"nats://127.0.0.1:{port}",
TlsOpts = new NatsTlsOpts
{
Mode = TlsMode.Require,
InsecureSkipVerify = true,
},
};
return new NatsConnection(opts);
}
private static void GenerateCertificates(string caCertPath, string serverCertPath, string serverKeyPath)
{
using var caKey = RSA.Create(2048);
var caReq = new CertificateRequest(
"CN=Benchmark Test CA",
caKey,
HashAlgorithmName.SHA256,
RSASignaturePadding.Pkcs1);
caReq.CertificateExtensions.Add(
new X509BasicConstraintsExtension(certificateAuthority: true, hasPathLengthConstraint: false, pathLengthConstraint: 0, critical: true));
var now = DateTimeOffset.UtcNow;
using var caCert = caReq.CreateSelfSigned(now.AddMinutes(-5), now.AddDays(1));
using var serverKey = RSA.Create(2048);
var serverReq = new CertificateRequest(
"CN=localhost",
serverKey,
HashAlgorithmName.SHA256,
RSASignaturePadding.Pkcs1);
var sanBuilder = new SubjectAlternativeNameBuilder();
sanBuilder.AddIpAddress(System.Net.IPAddress.Loopback);
sanBuilder.AddDnsName("localhost");
serverReq.CertificateExtensions.Add(sanBuilder.Build());
serverReq.CertificateExtensions.Add(
new X509BasicConstraintsExtension(certificateAuthority: false, hasPathLengthConstraint: false, pathLengthConstraint: 0, critical: false));
using var serverCert = serverReq.Create(caCert, now.AddMinutes(-5), now.AddDays(1), [1, 2, 3, 4]);
File.WriteAllText(caCertPath, caCert.ExportCertificatePem());
File.WriteAllText(serverCertPath, serverCert.ExportCertificatePem());
File.WriteAllText(serverKeyPath, serverKey.ExportRSAPrivateKeyPem());
}
}

View File

@@ -0,0 +1,67 @@
using NATS.Client.Core;
namespace NATS.Server.Benchmark.Tests.Infrastructure;
/// <summary>
/// Starts both a Go and .NET NATS server with WebSocket enabled for transport overhead benchmarks.
/// Shared across all tests in the "Benchmark-WebSocket" collection.
/// </summary>
public sealed class WebSocketServerFixture : IAsyncLifetime
{
private GoServerProcess? _goServer;
private DotNetServerProcess? _dotNetServer;
public int GoNatsPort => _goServer?.Port ?? throw new InvalidOperationException("Go server not started");
public int GoWsPort { get; private set; }
public int DotNetNatsPort => _dotNetServer?.Port ?? throw new InvalidOperationException(".NET server not started");
public int DotNetWsPort { get; private set; }
public bool GoAvailable => _goServer is not null;
public async Task InitializeAsync()
{
DotNetWsPort = PortAllocator.AllocateFreePort();
var dotNetConfig = $$"""
websocket {
listen: 127.0.0.1:{{DotNetWsPort}}
no_tls: true
}
""";
_dotNetServer = new DotNetServerProcess(dotNetConfig);
var dotNetTask = _dotNetServer.StartAsync();
if (GoServerProcess.IsAvailable())
{
GoWsPort = PortAllocator.AllocateFreePort();
var goConfig = $$"""
websocket {
listen: 127.0.0.1:{{GoWsPort}}
no_tls: true
}
""";
_goServer = new GoServerProcess(goConfig);
await Task.WhenAll(dotNetTask, _goServer.StartAsync());
}
else
{
await dotNetTask;
}
}
public async Task DisposeAsync()
{
if (_goServer is not null)
await _goServer.DisposeAsync();
if (_dotNetServer is not null)
await _dotNetServer.DisposeAsync();
}
public NatsConnection CreateGoNatsClient()
=> new(new NatsOpts { Url = $"nats://127.0.0.1:{GoNatsPort}" });
public NatsConnection CreateDotNetNatsClient()
=> new(new NatsOpts { Url = $"nats://127.0.0.1:{DotNetNatsPort}" });
}

View File

@@ -0,0 +1,184 @@
using MQTTnet;
using MQTTnet.Client;
using NATS.Client.Core;
using NATS.Server.Benchmark.Tests.Harness;
using NATS.Server.Benchmark.Tests.Infrastructure;
using Xunit.Abstractions;
namespace NATS.Server.Benchmark.Tests.Mqtt;
[Collection("Benchmark-Mqtt")]
public class MqttThroughputTests(MqttServerFixture fixture, ITestOutputHelper output)
{
[Fact]
[Trait("Category", "Benchmark")]
public async Task MqttPubSub_128B()
{
const int payloadSize = 128;
const int messageCount = 5_000;
var dotnetResult = await RunMqttPubSub("MQTT PubSub (128B)", "DotNet", fixture.DotNetMqttPort, payloadSize, messageCount);
if (fixture.GoAvailable)
{
var goResult = await RunMqttPubSub("MQTT PubSub (128B)", "Go", fixture.GoMqttPort, payloadSize, messageCount);
BenchmarkResultWriter.WriteComparison(output, goResult, dotnetResult);
}
else
{
BenchmarkResultWriter.WriteSingle(output, dotnetResult);
}
}
[Fact]
[Trait("Category", "Benchmark")]
public async Task MqttCrossProtocol_NatsPub_MqttSub_128B()
{
const int payloadSize = 128;
const int messageCount = 5_000;
var dotnetResult = await RunCrossProtocol("Cross-Protocol NATS→MQTT (128B)", "DotNet", fixture.DotNetMqttPort, fixture.CreateDotNetNatsClient, payloadSize, messageCount);
if (fixture.GoAvailable)
{
var goResult = await RunCrossProtocol("Cross-Protocol NATS→MQTT (128B)", "Go", fixture.GoMqttPort, fixture.CreateGoNatsClient, payloadSize, messageCount);
BenchmarkResultWriter.WriteComparison(output, goResult, dotnetResult);
}
else
{
BenchmarkResultWriter.WriteSingle(output, dotnetResult);
}
}
private static async Task<BenchmarkResult> RunMqttPubSub(string name, string serverType, int mqttPort, int payloadSize, int messageCount)
{
var payload = new byte[payloadSize];
var topic = $"bench/mqtt/pubsub/{Guid.NewGuid():N}";
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(60));
var factory = new MqttFactory();
using var subscriber = factory.CreateMqttClient();
using var publisher = factory.CreateMqttClient();
var subOpts = new MqttClientOptionsBuilder()
.WithTcpServer("127.0.0.1", mqttPort)
.WithClientId($"bench-sub-{Guid.NewGuid():N}")
.WithProtocolVersion(MQTTnet.Formatter.MqttProtocolVersion.V311)
.Build();
var pubOpts = new MqttClientOptionsBuilder()
.WithTcpServer("127.0.0.1", mqttPort)
.WithClientId($"bench-pub-{Guid.NewGuid():N}")
.WithProtocolVersion(MQTTnet.Formatter.MqttProtocolVersion.V311)
.Build();
await subscriber.ConnectAsync(subOpts, cts.Token);
await publisher.ConnectAsync(pubOpts, cts.Token);
var received = 0;
var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
subscriber.ApplicationMessageReceivedAsync += _ =>
{
if (Interlocked.Increment(ref received) >= messageCount)
tcs.TrySetResult();
return Task.CompletedTask;
};
await subscriber.SubscribeAsync(
factory.CreateSubscribeOptionsBuilder()
.WithTopicFilter(topic)
.Build(),
cts.Token);
await Task.Delay(200, cts.Token);
var sw = System.Diagnostics.Stopwatch.StartNew();
for (var i = 0; i < messageCount; i++)
{
await publisher.PublishAsync(
new MqttApplicationMessageBuilder()
.WithTopic(topic)
.WithPayload(payload)
.Build(),
cts.Token);
}
await tcs.Task.WaitAsync(cts.Token);
sw.Stop();
await subscriber.DisconnectAsync(cancellationToken: cts.Token);
await publisher.DisconnectAsync(cancellationToken: cts.Token);
return new BenchmarkResult
{
Name = name,
ServerType = serverType,
TotalMessages = messageCount,
TotalBytes = (long)messageCount * payloadSize,
Duration = sw.Elapsed,
};
}
private static async Task<BenchmarkResult> RunCrossProtocol(string name, string serverType, int mqttPort, Func<NatsConnection> createNatsClient, int payloadSize, int messageCount)
{
var payload = new byte[payloadSize];
var natsSubject = $"bench.mqtt.cross.{Guid.NewGuid():N}";
var mqttTopic = natsSubject.Replace('.', '/');
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(60));
var factory = new MqttFactory();
using var mqttSub = factory.CreateMqttClient();
var subOpts = new MqttClientOptionsBuilder()
.WithTcpServer("127.0.0.1", mqttPort)
.WithClientId($"bench-cross-sub-{Guid.NewGuid():N}")
.WithProtocolVersion(MQTTnet.Formatter.MqttProtocolVersion.V311)
.Build();
await mqttSub.ConnectAsync(subOpts, cts.Token);
var received = 0;
var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
mqttSub.ApplicationMessageReceivedAsync += _ =>
{
if (Interlocked.Increment(ref received) >= messageCount)
tcs.TrySetResult();
return Task.CompletedTask;
};
await mqttSub.SubscribeAsync(
factory.CreateSubscribeOptionsBuilder()
.WithTopicFilter(mqttTopic)
.Build(),
cts.Token);
await Task.Delay(200, cts.Token);
await using var natsPub = createNatsClient();
await natsPub.ConnectAsync();
await natsPub.PingAsync(cts.Token);
var sw = System.Diagnostics.Stopwatch.StartNew();
for (var i = 0; i < messageCount; i++)
await natsPub.PublishAsync(natsSubject, payload, cancellationToken: cts.Token);
await natsPub.PingAsync(cts.Token);
await tcs.Task.WaitAsync(cts.Token);
sw.Stop();
await mqttSub.DisconnectAsync(cancellationToken: cts.Token);
return new BenchmarkResult
{
Name = name,
ServerType = serverType,
TotalMessages = messageCount,
TotalBytes = (long)messageCount * payloadSize,
Duration = sw.Elapsed,
};
}
}

View File

@@ -8,6 +8,7 @@
<PackageReference Include="coverlet.collector" />
<PackageReference Include="Microsoft.NET.Test.Sdk" />
<PackageReference Include="NATS.Client.Core" />
<PackageReference Include="MQTTnet" />
<PackageReference Include="NATS.Client.JetStream" />
<PackageReference Include="Shouldly" />
<PackageReference Include="xunit" />

View File

@@ -25,6 +25,12 @@ dotnet test tests/NATS.Server.Benchmark.Tests --filter "Category=Benchmark&Fully
# JetStream only
dotnet test tests/NATS.Server.Benchmark.Tests --filter "Category=Benchmark&FullyQualifiedName~JetStream" -v normal
# MQTT benchmarks
dotnet test tests/NATS.Server.Benchmark.Tests --filter "Category=Benchmark&FullyQualifiedName~Mqtt" -v normal
# Transport benchmarks (TLS + WebSocket)
dotnet test tests/NATS.Server.Benchmark.Tests --filter "Category=Benchmark&FullyQualifiedName~Transport" -v normal
# A single benchmark by name
dotnet test tests/NATS.Server.Benchmark.Tests --filter "FullyQualifiedName=NATS.Server.Benchmark.Tests.CorePubSub.SinglePublisherThroughputTests.PubNoSub_16B" -v normal
```
@@ -50,6 +56,12 @@ Use `-v normal` or `--logger "console;verbosity=detailed"` to see the comparison
| `FileStoreAppendBenchmarks` | `FileStore_PurgeEx_Trim_Overhead` | FileStore purge/trim maintenance overhead under repeated updates |
| `OrderedConsumerTests` | `JSOrderedConsumer_Throughput` | JetStream ordered ephemeral consumer read throughput |
| `DurableConsumerFetchTests` | `JSDurableFetch_Throughput` | JetStream durable consumer fetch-in-batches throughput |
| `MqttThroughputTests` | `MqttPubSub_128B` | MQTT pub/sub throughput, 128-byte payload, QoS 0 |
| `MqttThroughputTests` | `MqttCrossProtocol_NatsPub_MqttSub_128B` | Cross-protocol NATS→MQTT routing throughput |
| `TlsPubSubTests` | `TlsPubSub1To1_128B` | TLS pub/sub 1:1 throughput, 128-byte payload |
| `TlsPubSubTests` | `TlsPubNoSub_128B` | TLS publish-only throughput, 128-byte payload |
| `WebSocketPubSubTests` | `WsPubSub1To1_128B` | WebSocket pub/sub 1:1 throughput, 128-byte payload |
| `WebSocketPubSubTests` | `WsPubNoSub_128B` | WebSocket publish-only throughput, 128-byte payload |
## Output Format
@@ -97,6 +109,9 @@ Infrastructure/
GoServerProcess.cs # Builds + launches golang/nats-server
CoreServerPairFixture.cs # IAsyncLifetime: Go + .NET servers for core tests
JetStreamServerPairFixture # IAsyncLifetime: Go + .NET servers with JetStream
MqttServerFixture.cs # IAsyncLifetime: .NET server with MQTT + JetStream
TlsServerFixture.cs # IAsyncLifetime: .NET server with TLS
WebSocketServerFixture.cs # IAsyncLifetime: .NET server with WebSocket
Collections.cs # xUnit collection definitions
Harness/

View File

@@ -0,0 +1,116 @@
using NATS.Client.Core;
using NATS.Server.Benchmark.Tests.Harness;
using NATS.Server.Benchmark.Tests.Infrastructure;
using Xunit.Abstractions;
namespace NATS.Server.Benchmark.Tests.Transport;
[Collection("Benchmark-Tls")]
public class TlsPubSubTests(TlsServerFixture fixture, ITestOutputHelper output)
{
[Fact]
[Trait("Category", "Benchmark")]
public async Task TlsPubSub1To1_128B()
{
const int payloadSize = 128;
const int messageCount = 10_000;
var dotnetResult = await RunTlsPubSub("TLS PubSub 1:1 (128B)", "DotNet", fixture.CreateDotNetTlsClient, payloadSize, messageCount);
if (fixture.GoAvailable)
{
var goResult = await RunTlsPubSub("TLS PubSub 1:1 (128B)", "Go", fixture.CreateGoTlsClient, payloadSize, messageCount);
BenchmarkResultWriter.WriteComparison(output, goResult, dotnetResult);
}
else
{
BenchmarkResultWriter.WriteSingle(output, dotnetResult);
}
}
[Fact]
[Trait("Category", "Benchmark")]
public async Task TlsPubNoSub_128B()
{
const int payloadSize = 128;
var dotnetResult = await RunTlsPubOnly("TLS Pub-Only (128B)", "DotNet", fixture.CreateDotNetTlsClient, payloadSize);
if (fixture.GoAvailable)
{
var goResult = await RunTlsPubOnly("TLS Pub-Only (128B)", "Go", fixture.CreateGoTlsClient, payloadSize);
BenchmarkResultWriter.WriteComparison(output, goResult, dotnetResult);
}
else
{
BenchmarkResultWriter.WriteSingle(output, dotnetResult);
}
}
private static async Task<BenchmarkResult> RunTlsPubSub(string name, string serverType, Func<NatsConnection> createClient, int payloadSize, int messageCount)
{
var payload = new byte[payloadSize];
var subject = $"bench.tls.pubsub.{Guid.NewGuid():N}";
await using var pubClient = createClient();
await using var subClient = createClient();
await pubClient.ConnectAsync();
await subClient.ConnectAsync();
var received = 0;
var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
var sub = await subClient.SubscribeCoreAsync<byte[]>(subject);
await subClient.PingAsync();
await pubClient.PingAsync();
var subTask = Task.Run(async () =>
{
await foreach (var msg in sub.Msgs.ReadAllAsync())
{
if (Interlocked.Increment(ref received) >= messageCount)
{
tcs.TrySetResult();
return;
}
}
});
var sw = System.Diagnostics.Stopwatch.StartNew();
for (var i = 0; i < messageCount; i++)
await pubClient.PublishAsync(subject, payload);
await pubClient.PingAsync();
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(60));
await tcs.Task.WaitAsync(cts.Token);
sw.Stop();
await sub.UnsubscribeAsync();
return new BenchmarkResult
{
Name = name,
ServerType = serverType,
TotalMessages = messageCount,
TotalBytes = (long)messageCount * payloadSize,
Duration = sw.Elapsed,
};
}
private static async Task<BenchmarkResult> RunTlsPubOnly(string name, string serverType, Func<NatsConnection> createClient, int payloadSize)
{
var subject = $"bench.tls.pubonly.{Guid.NewGuid():N}";
await using var client = createClient();
await client.ConnectAsync();
var runner = new BenchmarkRunner { WarmupCount = 1_000, MeasurementCount = 100_000 };
return await runner.MeasureThroughputAsync(
name,
serverType,
payloadSize,
async _ => await client.PublishAsync(subject, new byte[payloadSize]));
}
}

View File

@@ -0,0 +1,209 @@
using System.Net.WebSockets;
using System.Text;
using NATS.Client.Core;
using NATS.Server.Benchmark.Tests.Harness;
using NATS.Server.Benchmark.Tests.Infrastructure;
using Xunit.Abstractions;
namespace NATS.Server.Benchmark.Tests.Transport;
[Collection("Benchmark-WebSocket")]
public class WebSocketPubSubTests(WebSocketServerFixture fixture, ITestOutputHelper output)
{
[Fact]
[Trait("Category", "Benchmark")]
public async Task WsPubSub1To1_128B()
{
const int payloadSize = 128;
const int messageCount = 5_000;
var dotnetResult = await RunWsPubSub("WebSocket PubSub 1:1 (128B)", "DotNet", fixture.DotNetWsPort, fixture.CreateDotNetNatsClient, payloadSize, messageCount);
if (fixture.GoAvailable)
{
var goResult = await RunWsPubSub("WebSocket PubSub 1:1 (128B)", "Go", fixture.GoWsPort, fixture.CreateGoNatsClient, payloadSize, messageCount);
BenchmarkResultWriter.WriteComparison(output, goResult, dotnetResult);
}
else
{
BenchmarkResultWriter.WriteSingle(output, dotnetResult);
}
}
[Fact]
[Trait("Category", "Benchmark")]
public async Task WsPubNoSub_128B()
{
const int payloadSize = 128;
const int messageCount = 10_000;
var dotnetResult = await RunWsPubOnly("WebSocket Pub-Only (128B)", "DotNet", fixture.DotNetWsPort, payloadSize, messageCount);
if (fixture.GoAvailable)
{
var goResult = await RunWsPubOnly("WebSocket Pub-Only (128B)", "Go", fixture.GoWsPort, payloadSize, messageCount);
BenchmarkResultWriter.WriteComparison(output, goResult, dotnetResult);
}
else
{
BenchmarkResultWriter.WriteSingle(output, dotnetResult);
}
}
private static async Task<BenchmarkResult> RunWsPubSub(string name, string serverType, int wsPort, Func<NatsConnection> createNatsClient, int payloadSize, int messageCount)
{
var payload = new byte[payloadSize];
var subject = $"bench.ws.pubsub.{Guid.NewGuid():N}";
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(60));
using var ws = new ClientWebSocket();
await ws.ConnectAsync(new Uri($"ws://127.0.0.1:{wsPort}"), cts.Token);
var reader = new WsLineReader(ws);
// Read INFO
await reader.ReadLineAsync(cts.Token);
// Send CONNECT + SUB + PING
await WsSend(ws, "CONNECT {\"verbose\":false,\"protocol\":1}\r\n", cts.Token);
await WsSend(ws, $"SUB {subject} 1\r\n", cts.Token);
await WsSend(ws, "PING\r\n", cts.Token);
await WaitForPong(reader, cts.Token);
// NATS publisher
await using var natsPub = createNatsClient();
await natsPub.ConnectAsync();
await natsPub.PingAsync(cts.Token);
var sw = System.Diagnostics.Stopwatch.StartNew();
for (var i = 0; i < messageCount; i++)
await natsPub.PublishAsync(subject, payload, cancellationToken: cts.Token);
await natsPub.PingAsync(cts.Token);
// Read all MSG responses from WebSocket
var received = 0;
while (received < messageCount)
{
var line = await reader.ReadLineAsync(cts.Token);
if (line.StartsWith("MSG ", StringComparison.Ordinal))
{
await reader.ReadLineAsync(cts.Token);
received++;
}
}
sw.Stop();
await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, null, cts.Token);
return new BenchmarkResult
{
Name = name,
ServerType = serverType,
TotalMessages = messageCount,
TotalBytes = (long)messageCount * payloadSize,
Duration = sw.Elapsed,
};
}
private static async Task<BenchmarkResult> RunWsPubOnly(string name, string serverType, int wsPort, int payloadSize, int messageCount)
{
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(60));
using var ws = new ClientWebSocket();
await ws.ConnectAsync(new Uri($"ws://127.0.0.1:{wsPort}"), cts.Token);
var reader = new WsLineReader(ws);
// Read INFO
await reader.ReadLineAsync(cts.Token);
// Send CONNECT
await WsSend(ws, "CONNECT {\"verbose\":false,\"protocol\":1}\r\n", cts.Token);
await WsSend(ws, "PING\r\n", cts.Token);
await WaitForPong(reader, cts.Token);
// Build a PUB command with raw binary payload
var subject = $"bench.ws.pubonly.{Guid.NewGuid():N}";
var pubLine = $"PUB {subject} {payloadSize}\r\n";
var pubPayload = new byte[payloadSize];
var pubCmd = Encoding.ASCII.GetBytes(pubLine)
.Concat(pubPayload)
.Concat(Encoding.ASCII.GetBytes("\r\n"))
.ToArray();
var sw = System.Diagnostics.Stopwatch.StartNew();
for (var i = 0; i < messageCount; i++)
await ws.SendAsync(pubCmd, WebSocketMessageType.Binary, true, cts.Token);
// Flush with PING/PONG
await WsSend(ws, "PING\r\n", cts.Token);
await WaitForPong(reader, cts.Token);
sw.Stop();
await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, null, cts.Token);
return new BenchmarkResult
{
Name = name,
ServerType = serverType,
TotalMessages = messageCount,
TotalBytes = (long)messageCount * payloadSize,
Duration = sw.Elapsed,
};
}
/// <summary>
/// Reads lines until PONG is received, skipping any INFO lines
/// (Go server sends a second INFO after CONNECT with connect_info:true).
/// </summary>
private static async Task WaitForPong(WsLineReader reader, CancellationToken ct)
{
while (true)
{
var line = await reader.ReadLineAsync(ct);
if (line == "PONG")
return;
}
}
private static async Task WsSend(ClientWebSocket ws, string data, CancellationToken ct)
{
var bytes = Encoding.ASCII.GetBytes(data);
await ws.SendAsync(bytes, WebSocketMessageType.Binary, true, ct);
}
/// <summary>
/// Buffers incoming WebSocket frames and returns one NATS protocol line at a time.
/// </summary>
private sealed class WsLineReader(ClientWebSocket ws)
{
private readonly byte[] _recvBuffer = new byte[65536];
private readonly StringBuilder _pending = new();
public async Task<string> ReadLineAsync(CancellationToken ct)
{
while (true)
{
var full = _pending.ToString();
var crlfIdx = full.IndexOf("\r\n", StringComparison.Ordinal);
if (crlfIdx >= 0)
{
var line = full[..crlfIdx];
_pending.Clear();
_pending.Append(full[(crlfIdx + 2)..]);
return line;
}
var result = await ws.ReceiveAsync(_recvBuffer, ct);
if (result.MessageType == WebSocketMessageType.Close)
throw new InvalidOperationException("WebSocket closed unexpectedly while reading");
var chunk = Encoding.ASCII.GetString(_recvBuffer, 0, result.Count);
_pending.Append(chunk);
}
}
}
}