fix(lmxproxy): support multiple subscriptions per session

Key subscriptions by unique subscriptionId instead of sessionId to prevent
overwrites when the same session calls Subscribe multiple times (e.g. DCL
StaleTagMonitor). Add session-to-subscription reverse lookup for cleanup.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Joseph Doherty
2026-03-24 16:27:35 -04:00
parent b3076e18db
commit 6df2cbdf90
4 changed files with 115 additions and 61 deletions

View File

@@ -83,10 +83,13 @@ namespace ZB.MOM.WW.LmxProxy.Host.Grpc.Services
{ {
try try
{ {
// Clean up subscriptions for this session // Terminate session first — prevents new Subscribe RPCs from passing
_subscriptionManager.UnsubscribeClient(request.SessionId); // session validation while we clean up subscriptions
var terminated = _sessionManager.TerminateSession(request.SessionId); var terminated = _sessionManager.TerminateSession(request.SessionId);
// Then clean up all subscriptions for this session
_subscriptionManager.UnsubscribeSession(request.SessionId);
return Task.FromResult(new Scada.DisconnectResponse return Task.FromResult(new Scada.DisconnectResponse
{ {
Success = terminated, Success = terminated,
@@ -361,7 +364,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Grpc.Services
throw new RpcException(new GrpcStatus(StatusCode.Unauthenticated, "Invalid session")); throw new RpcException(new GrpcStatus(StatusCode.Unauthenticated, "Invalid session"));
} }
var reader = await _subscriptionManager.SubscribeAsync( var (reader, subscriptionId) = await _subscriptionManager.SubscribeAsync(
request.SessionId, request.Tags, context.CancellationToken); request.SessionId, request.Tags, context.CancellationToken);
try try
@@ -410,12 +413,14 @@ namespace ZB.MOM.WW.LmxProxy.Host.Grpc.Services
} }
catch (Exception ex) catch (Exception ex)
{ {
Log.Error(ex, "Subscribe stream error for session {SessionId}", request.SessionId); Log.Error(ex, "Subscribe stream error for session {SessionId} subscription {SubscriptionId}",
request.SessionId, subscriptionId);
throw new RpcException(new GrpcStatus(StatusCode.Internal, ex.Message)); throw new RpcException(new GrpcStatus(StatusCode.Internal, ex.Message));
} }
finally finally
{ {
_subscriptionManager.UnsubscribeClient(request.SessionId); // Clean up THIS subscription only, not the entire session
_subscriptionManager.UnsubscribeSubscription(subscriptionId);
} }
} }

View File

@@ -115,7 +115,7 @@ namespace ZB.MOM.WW.LmxProxy.Host
_sessionManager.OnSessionScavenged(sessionId => _sessionManager.OnSessionScavenged(sessionId =>
{ {
Log.Information("Cleaning up subscriptions for scavenged session {SessionId}", sessionId); Log.Information("Cleaning up subscriptions for scavenged session {SessionId}", sessionId);
_subscriptionManager.UnsubscribeClient(sessionId); _subscriptionManager.UnsubscribeSession(sessionId);
}); });
// 9. Create performance metrics // 9. Create performance metrics

View File

@@ -22,7 +22,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions
private readonly int _channelCapacity; private readonly int _channelCapacity;
private readonly BoundedChannelFullMode _channelFullMode; private readonly BoundedChannelFullMode _channelFullMode;
// Client ID -> ClientSubscription // Subscription ID -> ClientSubscription
private readonly ConcurrentDictionary<string, ClientSubscription> _clientSubscriptions private readonly ConcurrentDictionary<string, ClientSubscription> _clientSubscriptions
= new ConcurrentDictionary<string, ClientSubscription>(StringComparer.OrdinalIgnoreCase); = new ConcurrentDictionary<string, ClientSubscription>(StringComparer.OrdinalIgnoreCase);
@@ -30,6 +30,10 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions
private readonly ConcurrentDictionary<string, TagSubscription> _tagSubscriptions private readonly ConcurrentDictionary<string, TagSubscription> _tagSubscriptions
= new ConcurrentDictionary<string, TagSubscription>(StringComparer.OrdinalIgnoreCase); = new ConcurrentDictionary<string, TagSubscription>(StringComparer.OrdinalIgnoreCase);
// Session ID -> set of subscription IDs owned by that session
private readonly ConcurrentDictionary<string, HashSet<string>> _sessionSubscriptions
= new ConcurrentDictionary<string, HashSet<string>>(StringComparer.OrdinalIgnoreCase);
private readonly ReaderWriterLockSlim _rwLock = new ReaderWriterLockSlim(); private readonly ReaderWriterLockSlim _rwLock = new ReaderWriterLockSlim();
public SubscriptionManager(IScadaClient scadaClient, int channelCapacity = 1000, public SubscriptionManager(IScadaClient scadaClient, int channelCapacity = 1000,
@@ -41,13 +45,15 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions
} }
/// <summary> /// <summary>
/// Creates a subscription for a client. Returns a ChannelReader to stream from. /// Creates a subscription for a session. Returns a ChannelReader and unique
/// subscription ID. Multiple subscriptions per session are supported.
/// Awaits COM subscription creation so the initial OnDataChange callback /// Awaits COM subscription creation so the initial OnDataChange callback
/// is not missed. /// is not missed.
/// </summary> /// </summary>
public async Task<ChannelReader<(string address, Vtq vtq)>> SubscribeAsync( public async Task<(ChannelReader<(string address, Vtq vtq)> Reader, string SubscriptionId)> SubscribeAsync(
string clientId, IEnumerable<string> addresses, CancellationToken ct) string sessionId, IEnumerable<string> addresses, CancellationToken ct)
{ {
var subscriptionId = Guid.NewGuid().ToString("N");
var channel = Channel.CreateBounded<(string address, Vtq vtq)>( var channel = Channel.CreateBounded<(string address, Vtq vtq)>(
new BoundedChannelOptions(_channelCapacity) new BoundedChannelOptions(_channelCapacity)
{ {
@@ -58,8 +64,14 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions
var addressSet = new HashSet<string>(addresses, StringComparer.OrdinalIgnoreCase); var addressSet = new HashSet<string>(addresses, StringComparer.OrdinalIgnoreCase);
var clientSub = new ClientSubscription(clientId, channel, addressSet); var clientSub = new ClientSubscription(subscriptionId, sessionId, channel, addressSet);
_clientSubscriptions[clientId] = clientSub; _clientSubscriptions[subscriptionId] = clientSub;
// Track which session owns this subscription
_sessionSubscriptions.AddOrUpdate(
sessionId,
_ => new HashSet<string>(StringComparer.OrdinalIgnoreCase) { subscriptionId },
(_, set) => { lock (set) { set.Add(subscriptionId); } return set; });
var newTags = new List<string>(); var newTags = new List<string>();
@@ -70,12 +82,12 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions
{ {
if (_tagSubscriptions.TryGetValue(address, out var tagSub)) if (_tagSubscriptions.TryGetValue(address, out var tagSub))
{ {
tagSub.ClientIds.Add(clientId); tagSub.ClientIds.Add(subscriptionId);
} }
else else
{ {
_tagSubscriptions[address] = new TagSubscription(address, _tagSubscriptions[address] = new TagSubscription(address,
new HashSet<string>(StringComparer.OrdinalIgnoreCase) { clientId }); new HashSet<string>(StringComparer.OrdinalIgnoreCase) { subscriptionId });
newTags.Add(address); newTags.Add(address);
} }
} }
@@ -94,12 +106,12 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions
await CreateMxAccessSubscriptionsAsync(newTags); await CreateMxAccessSubscriptionsAsync(newTags);
} }
// Register cancellation cleanup // Register cancellation cleanup for this subscription only
ct.Register(() => UnsubscribeClient(clientId)); ct.Register(() => UnsubscribeSubscription(subscriptionId));
Log.Information("Client {ClientId} subscribed to {Count} tags ({NewCount} new MxAccess subscriptions)", Log.Information("Session {SessionId} subscription {SubscriptionId} subscribed to {Count} tags ({NewCount} new MxAccess subscriptions)",
clientId, addressSet.Count, newTags.Count); sessionId, subscriptionId, addressSet.Count, newTags.Count);
return channel.Reader; return (channel.Reader, subscriptionId);
} }
private async Task CreateMxAccessSubscriptionsAsync(List<string> addresses) private async Task CreateMxAccessSubscriptionsAsync(List<string> addresses)
@@ -157,31 +169,42 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions
} }
/// <summary> /// <summary>
/// Removes a client's subscriptions and cleans up tag subscriptions /// Removes a single subscription and cleans up its tag refs.
/// when the last client unsubscribes. /// Called when an individual Subscribe stream ends.
/// </summary> /// </summary>
public void UnsubscribeClient(string clientId) public void UnsubscribeSubscription(string subscriptionId)
{ {
if (!_clientSubscriptions.TryRemove(clientId, out var clientSub)) if (!_clientSubscriptions.TryRemove(subscriptionId, out var clientSub))
return; return;
// Remove from session tracking
if (_sessionSubscriptions.TryGetValue(clientSub.SessionId, out var subIds))
{
lock (subIds)
{
subIds.Remove(subscriptionId);
if (subIds.Count == 0)
{
_sessionSubscriptions.TryRemove(clientSub.SessionId, out _);
}
}
}
var tagsToDispose = new List<string>(); var tagsToDispose = new List<string>();
_rwLock.EnterWriteLock(); _rwLock.EnterWriteLock();
try try
{ {
// Scan all tag subscriptions — not just clientSub.Addresses — because foreach (var address in clientSub.Addresses)
// a client may have called Subscribe multiple times (one tag per RPC),
// each overwriting the ClientSubscription. The last one's Addresses
// only has the final batch, but earlier tags still reference this client.
foreach (var kvp in _tagSubscriptions)
{ {
if (kvp.Value.ClientIds.Remove(clientId)) if (_tagSubscriptions.TryGetValue(address, out var tagSub))
{ {
if (kvp.Value.ClientIds.Count == 0) tagSub.ClientIds.Remove(subscriptionId);
if (tagSub.ClientIds.Count == 0)
{ {
_tagSubscriptions.TryRemove(kvp.Key, out _); _tagSubscriptions.TryRemove(address, out _);
tagsToDispose.Add(kvp.Key); tagsToDispose.Add(address);
} }
} }
} }
@@ -191,7 +214,6 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions
_rwLock.ExitWriteLock(); _rwLock.ExitWriteLock();
} }
// Unsubscribe tags with no remaining clients via address-based API
if (tagsToDispose.Count > 0) if (tagsToDispose.Count > 0)
{ {
try try
@@ -204,11 +226,34 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions
} }
} }
// Complete the channel (signals end of stream to the gRPC handler)
clientSub.Channel.Writer.TryComplete(); clientSub.Channel.Writer.TryComplete();
Log.Information("Client {ClientId} unsubscribed ({Delivered} delivered, {Dropped} dropped)", Log.Information("Subscription {SubscriptionId} removed ({Delivered} delivered, {Dropped} dropped)",
clientId, clientSub.DeliveredCount, clientSub.DroppedCount); subscriptionId, clientSub.DeliveredCount, clientSub.DroppedCount);
}
/// <summary>
/// Removes ALL subscriptions for a session.
/// Called on explicit Disconnect or session scavenging.
/// </summary>
public void UnsubscribeSession(string sessionId)
{
if (!_sessionSubscriptions.TryRemove(sessionId, out var subscriptionIds))
return;
List<string> ids;
lock (subscriptionIds)
{
ids = subscriptionIds.ToList();
}
foreach (var subId in ids)
{
UnsubscribeSubscription(subId);
}
Log.Information("All subscriptions for session {SessionId} removed ({Count} subscriptions)",
sessionId, ids.Count);
} }
/// <summary> /// <summary>
@@ -252,7 +297,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions
} }
return new SubscriptionStats( return new SubscriptionStats(
_clientSubscriptions.Count, _sessionSubscriptions.Count,
_tagSubscriptions.Count, _tagSubscriptions.Count,
_clientSubscriptions.Values.Sum(c => c.Addresses.Count), _clientSubscriptions.Values.Sum(c => c.Addresses.Count),
totalDelivered, totalDelivered,
@@ -266,6 +311,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions
kvp.Value.Channel.Writer.TryComplete(); kvp.Value.Channel.Writer.TryComplete();
} }
_clientSubscriptions.Clear(); _clientSubscriptions.Clear();
_sessionSubscriptions.Clear();
_tagSubscriptions.Clear(); _tagSubscriptions.Clear();
_rwLock.Dispose(); _rwLock.Dispose();
} }
@@ -274,16 +320,18 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions
private class ClientSubscription private class ClientSubscription
{ {
public ClientSubscription(string clientId, public ClientSubscription(string subscriptionId, string sessionId,
Channel<(string address, Vtq vtq)> channel, Channel<(string address, Vtq vtq)> channel,
HashSet<string> addresses) HashSet<string> addresses)
{ {
ClientId = clientId; SubscriptionId = subscriptionId;
SessionId = sessionId;
Channel = channel; Channel = channel;
Addresses = addresses; Addresses = addresses;
} }
public string ClientId { get; } public string SubscriptionId { get; }
public string SessionId { get; }
public Channel<(string address, Vtq vtq)> Channel { get; } public Channel<(string address, Vtq vtq)> Channel { get; }
public HashSet<string> Addresses { get; } public HashSet<string> Addresses { get; }

View File

@@ -51,8 +51,9 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions
{ {
using var sm = new SubscriptionManager(new FakeScadaClient()); using var sm = new SubscriptionManager(new FakeScadaClient());
using var cts = new CancellationTokenSource(); using var cts = new CancellationTokenSource();
var reader = await sm.SubscribeAsync("client1", new[] { "Tag1", "Tag2" }, cts.Token); var (reader, subscriptionId) = await sm.SubscribeAsync("client1", new[] { "Tag1", "Tag2" }, cts.Token);
reader.Should().NotBeNull(); reader.Should().NotBeNull();
subscriptionId.Should().NotBeNullOrEmpty();
} }
[Fact] [Fact]
@@ -60,7 +61,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions
{ {
using var sm = new SubscriptionManager(new FakeScadaClient()); using var sm = new SubscriptionManager(new FakeScadaClient());
using var cts = new CancellationTokenSource(); using var cts = new CancellationTokenSource();
var reader = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); var (reader, _) = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token);
var vtq = Vtq.Good(42.0); var vtq = Vtq.Good(42.0);
sm.OnTagValueChanged("Motor.Speed", vtq); sm.OnTagValueChanged("Motor.Speed", vtq);
@@ -76,8 +77,8 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions
{ {
using var sm = new SubscriptionManager(new FakeScadaClient()); using var sm = new SubscriptionManager(new FakeScadaClient());
using var cts = new CancellationTokenSource(); using var cts = new CancellationTokenSource();
var reader1 = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); var (reader1, _) = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token);
var reader2 = await sm.SubscribeAsync("client2", new[] { "Motor.Speed" }, cts.Token); var (reader2, _) = await sm.SubscribeAsync("client2", new[] { "Motor.Speed" }, cts.Token);
sm.OnTagValueChanged("Motor.Speed", Vtq.Good(99.0)); sm.OnTagValueChanged("Motor.Speed", Vtq.Good(99.0));
@@ -92,7 +93,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions
{ {
using var sm = new SubscriptionManager(new FakeScadaClient()); using var sm = new SubscriptionManager(new FakeScadaClient());
using var cts = new CancellationTokenSource(); using var cts = new CancellationTokenSource();
var reader = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); var (reader, _) = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token);
sm.OnTagValueChanged("Motor.Torque", Vtq.Good(10.0)); sm.OnTagValueChanged("Motor.Torque", Vtq.Good(10.0));
@@ -101,26 +102,26 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions
} }
[Fact] [Fact]
public async Task UnsubscribeClient_CompletesChannel() public async Task UnsubscribeSubscription_CompletesChannel()
{ {
using var sm = new SubscriptionManager(new FakeScadaClient()); using var sm = new SubscriptionManager(new FakeScadaClient());
using var cts = new CancellationTokenSource(); using var cts = new CancellationTokenSource();
var reader = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); var (reader, subscriptionId) = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token);
sm.UnsubscribeClient("client1"); sm.UnsubscribeSubscription(subscriptionId);
// Channel should be completed // Channel should be completed
reader.Completion.IsCompleted.Should().BeTrue(); reader.Completion.IsCompleted.Should().BeTrue();
} }
[Fact] [Fact]
public async Task UnsubscribeClient_RemovesFromTagSubscriptions() public async Task UnsubscribeSession_RemovesAllSubscriptions()
{ {
using var sm = new SubscriptionManager(new FakeScadaClient()); using var sm = new SubscriptionManager(new FakeScadaClient());
using var cts = new CancellationTokenSource(); using var cts = new CancellationTokenSource();
await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token);
sm.UnsubscribeClient("client1"); sm.UnsubscribeSession("client1");
var stats = sm.GetStats(); var stats = sm.GetStats();
stats.TotalClients.Should().Be(0); stats.TotalClients.Should().Be(0);
@@ -128,20 +129,20 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions
} }
[Fact] [Fact]
public async Task RefCounting_LastClientUnsubscribeRemovesTag() public async Task RefCounting_LastSubscriptionUnsubscribeRemovesTag()
{ {
using var sm = new SubscriptionManager(new FakeScadaClient()); using var sm = new SubscriptionManager(new FakeScadaClient());
using var cts = new CancellationTokenSource(); using var cts = new CancellationTokenSource();
await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); var (_, subId1) = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token);
await sm.SubscribeAsync("client2", new[] { "Motor.Speed" }, cts.Token); var (_, subId2) = await sm.SubscribeAsync("client2", new[] { "Motor.Speed" }, cts.Token);
sm.GetStats().TotalTags.Should().Be(1); sm.GetStats().TotalTags.Should().Be(1);
sm.UnsubscribeClient("client1"); sm.UnsubscribeSubscription(subId1);
sm.GetStats().TotalTags.Should().Be(1); // client2 still subscribed sm.GetStats().TotalTags.Should().Be(1); // client2 still subscribed
sm.UnsubscribeClient("client2"); sm.UnsubscribeSubscription(subId2);
sm.GetStats().TotalTags.Should().Be(0); // last client gone sm.GetStats().TotalTags.Should().Be(0); // last subscription gone
} }
[Fact] [Fact]
@@ -149,7 +150,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions
{ {
using var sm = new SubscriptionManager(new FakeScadaClient()); using var sm = new SubscriptionManager(new FakeScadaClient());
using var cts = new CancellationTokenSource(); using var cts = new CancellationTokenSource();
var reader = await sm.SubscribeAsync("client1", new[] { "Motor.Speed", "Motor.Torque" }, cts.Token); var (reader, _) = await sm.SubscribeAsync("client1", new[] { "Motor.Speed", "Motor.Torque" }, cts.Token);
sm.NotifyDisconnection(); sm.NotifyDisconnection();
@@ -165,7 +166,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions
{ {
using var sm = new SubscriptionManager(new FakeScadaClient(), channelCapacity: 3); using var sm = new SubscriptionManager(new FakeScadaClient(), channelCapacity: 3);
using var cts = new CancellationTokenSource(); using var cts = new CancellationTokenSource();
var reader = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); var (reader, _) = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token);
// Fill the channel beyond capacity // Fill the channel beyond capacity
for (int i = 0; i < 10; i++) for (int i = 0; i < 10; i++)
@@ -184,8 +185,8 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions
{ {
using var sm = new SubscriptionManager(new FakeScadaClient()); using var sm = new SubscriptionManager(new FakeScadaClient());
using var cts = new CancellationTokenSource(); using var cts = new CancellationTokenSource();
await sm.SubscribeAsync("c1", new[] { "Tag1", "Tag2" }, cts.Token); var (_, _) = await sm.SubscribeAsync("c1", new[] { "Tag1", "Tag2" }, cts.Token);
await sm.SubscribeAsync("c2", new[] { "Tag2", "Tag3" }, cts.Token); var (_, _) = await sm.SubscribeAsync("c2", new[] { "Tag2", "Tag3" }, cts.Token);
var stats = sm.GetStats(); var stats = sm.GetStats();
stats.TotalClients.Should().Be(2); stats.TotalClients.Should().Be(2);