diff --git a/lmxproxy/src/ZB.MOM.WW.LmxProxy.Host/Grpc/Services/ScadaGrpcService.cs b/lmxproxy/src/ZB.MOM.WW.LmxProxy.Host/Grpc/Services/ScadaGrpcService.cs index 7bf5fe2..7dec9f0 100644 --- a/lmxproxy/src/ZB.MOM.WW.LmxProxy.Host/Grpc/Services/ScadaGrpcService.cs +++ b/lmxproxy/src/ZB.MOM.WW.LmxProxy.Host/Grpc/Services/ScadaGrpcService.cs @@ -83,10 +83,13 @@ namespace ZB.MOM.WW.LmxProxy.Host.Grpc.Services { try { - // Clean up subscriptions for this session - _subscriptionManager.UnsubscribeClient(request.SessionId); - + // Terminate session first — prevents new Subscribe RPCs from passing + // session validation while we clean up subscriptions var terminated = _sessionManager.TerminateSession(request.SessionId); + + // Then clean up all subscriptions for this session + _subscriptionManager.UnsubscribeSession(request.SessionId); + return Task.FromResult(new Scada.DisconnectResponse { Success = terminated, @@ -361,7 +364,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Grpc.Services throw new RpcException(new GrpcStatus(StatusCode.Unauthenticated, "Invalid session")); } - var reader = await _subscriptionManager.SubscribeAsync( + var (reader, subscriptionId) = await _subscriptionManager.SubscribeAsync( request.SessionId, request.Tags, context.CancellationToken); try @@ -410,12 +413,14 @@ namespace ZB.MOM.WW.LmxProxy.Host.Grpc.Services } catch (Exception ex) { - Log.Error(ex, "Subscribe stream error for session {SessionId}", request.SessionId); + Log.Error(ex, "Subscribe stream error for session {SessionId} subscription {SubscriptionId}", + request.SessionId, subscriptionId); throw new RpcException(new GrpcStatus(StatusCode.Internal, ex.Message)); } finally { - _subscriptionManager.UnsubscribeClient(request.SessionId); + // Clean up THIS subscription only, not the entire session + _subscriptionManager.UnsubscribeSubscription(subscriptionId); } } diff --git a/lmxproxy/src/ZB.MOM.WW.LmxProxy.Host/LmxProxyService.cs b/lmxproxy/src/ZB.MOM.WW.LmxProxy.Host/LmxProxyService.cs index dd8cf43..c724e39 100644 --- a/lmxproxy/src/ZB.MOM.WW.LmxProxy.Host/LmxProxyService.cs +++ b/lmxproxy/src/ZB.MOM.WW.LmxProxy.Host/LmxProxyService.cs @@ -115,7 +115,7 @@ namespace ZB.MOM.WW.LmxProxy.Host _sessionManager.OnSessionScavenged(sessionId => { Log.Information("Cleaning up subscriptions for scavenged session {SessionId}", sessionId); - _subscriptionManager.UnsubscribeClient(sessionId); + _subscriptionManager.UnsubscribeSession(sessionId); }); // 9. Create performance metrics diff --git a/lmxproxy/src/ZB.MOM.WW.LmxProxy.Host/Subscriptions/SubscriptionManager.cs b/lmxproxy/src/ZB.MOM.WW.LmxProxy.Host/Subscriptions/SubscriptionManager.cs index 4a9194f..24c859e 100644 --- a/lmxproxy/src/ZB.MOM.WW.LmxProxy.Host/Subscriptions/SubscriptionManager.cs +++ b/lmxproxy/src/ZB.MOM.WW.LmxProxy.Host/Subscriptions/SubscriptionManager.cs @@ -22,7 +22,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions private readonly int _channelCapacity; private readonly BoundedChannelFullMode _channelFullMode; - // Client ID -> ClientSubscription + // Subscription ID -> ClientSubscription private readonly ConcurrentDictionary _clientSubscriptions = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase); @@ -30,6 +30,10 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions private readonly ConcurrentDictionary _tagSubscriptions = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase); + // Session ID -> set of subscription IDs owned by that session + private readonly ConcurrentDictionary> _sessionSubscriptions + = new ConcurrentDictionary>(StringComparer.OrdinalIgnoreCase); + private readonly ReaderWriterLockSlim _rwLock = new ReaderWriterLockSlim(); public SubscriptionManager(IScadaClient scadaClient, int channelCapacity = 1000, @@ -41,13 +45,15 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions } /// - /// Creates a subscription for a client. Returns a ChannelReader to stream from. + /// Creates a subscription for a session. Returns a ChannelReader and unique + /// subscription ID. Multiple subscriptions per session are supported. /// Awaits COM subscription creation so the initial OnDataChange callback /// is not missed. /// - public async Task> SubscribeAsync( - string clientId, IEnumerable addresses, CancellationToken ct) + public async Task<(ChannelReader<(string address, Vtq vtq)> Reader, string SubscriptionId)> SubscribeAsync( + string sessionId, IEnumerable addresses, CancellationToken ct) { + var subscriptionId = Guid.NewGuid().ToString("N"); var channel = Channel.CreateBounded<(string address, Vtq vtq)>( new BoundedChannelOptions(_channelCapacity) { @@ -58,8 +64,14 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions var addressSet = new HashSet(addresses, StringComparer.OrdinalIgnoreCase); - var clientSub = new ClientSubscription(clientId, channel, addressSet); - _clientSubscriptions[clientId] = clientSub; + var clientSub = new ClientSubscription(subscriptionId, sessionId, channel, addressSet); + _clientSubscriptions[subscriptionId] = clientSub; + + // Track which session owns this subscription + _sessionSubscriptions.AddOrUpdate( + sessionId, + _ => new HashSet(StringComparer.OrdinalIgnoreCase) { subscriptionId }, + (_, set) => { lock (set) { set.Add(subscriptionId); } return set; }); var newTags = new List(); @@ -70,12 +82,12 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions { if (_tagSubscriptions.TryGetValue(address, out var tagSub)) { - tagSub.ClientIds.Add(clientId); + tagSub.ClientIds.Add(subscriptionId); } else { _tagSubscriptions[address] = new TagSubscription(address, - new HashSet(StringComparer.OrdinalIgnoreCase) { clientId }); + new HashSet(StringComparer.OrdinalIgnoreCase) { subscriptionId }); newTags.Add(address); } } @@ -94,12 +106,12 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions await CreateMxAccessSubscriptionsAsync(newTags); } - // Register cancellation cleanup - ct.Register(() => UnsubscribeClient(clientId)); + // Register cancellation cleanup for this subscription only + ct.Register(() => UnsubscribeSubscription(subscriptionId)); - Log.Information("Client {ClientId} subscribed to {Count} tags ({NewCount} new MxAccess subscriptions)", - clientId, addressSet.Count, newTags.Count); - return channel.Reader; + Log.Information("Session {SessionId} subscription {SubscriptionId} subscribed to {Count} tags ({NewCount} new MxAccess subscriptions)", + sessionId, subscriptionId, addressSet.Count, newTags.Count); + return (channel.Reader, subscriptionId); } private async Task CreateMxAccessSubscriptionsAsync(List addresses) @@ -157,31 +169,42 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions } /// - /// Removes a client's subscriptions and cleans up tag subscriptions - /// when the last client unsubscribes. + /// Removes a single subscription and cleans up its tag refs. + /// Called when an individual Subscribe stream ends. /// - public void UnsubscribeClient(string clientId) + public void UnsubscribeSubscription(string subscriptionId) { - if (!_clientSubscriptions.TryRemove(clientId, out var clientSub)) + if (!_clientSubscriptions.TryRemove(subscriptionId, out var clientSub)) return; + // Remove from session tracking + if (_sessionSubscriptions.TryGetValue(clientSub.SessionId, out var subIds)) + { + lock (subIds) + { + subIds.Remove(subscriptionId); + if (subIds.Count == 0) + { + _sessionSubscriptions.TryRemove(clientSub.SessionId, out _); + } + } + } + var tagsToDispose = new List(); _rwLock.EnterWriteLock(); try { - // Scan all tag subscriptions — not just clientSub.Addresses — because - // a client may have called Subscribe multiple times (one tag per RPC), - // each overwriting the ClientSubscription. The last one's Addresses - // only has the final batch, but earlier tags still reference this client. - foreach (var kvp in _tagSubscriptions) + foreach (var address in clientSub.Addresses) { - if (kvp.Value.ClientIds.Remove(clientId)) + if (_tagSubscriptions.TryGetValue(address, out var tagSub)) { - if (kvp.Value.ClientIds.Count == 0) + tagSub.ClientIds.Remove(subscriptionId); + + if (tagSub.ClientIds.Count == 0) { - _tagSubscriptions.TryRemove(kvp.Key, out _); - tagsToDispose.Add(kvp.Key); + _tagSubscriptions.TryRemove(address, out _); + tagsToDispose.Add(address); } } } @@ -191,7 +214,6 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions _rwLock.ExitWriteLock(); } - // Unsubscribe tags with no remaining clients via address-based API if (tagsToDispose.Count > 0) { try @@ -204,11 +226,34 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions } } - // Complete the channel (signals end of stream to the gRPC handler) clientSub.Channel.Writer.TryComplete(); - Log.Information("Client {ClientId} unsubscribed ({Delivered} delivered, {Dropped} dropped)", - clientId, clientSub.DeliveredCount, clientSub.DroppedCount); + Log.Information("Subscription {SubscriptionId} removed ({Delivered} delivered, {Dropped} dropped)", + subscriptionId, clientSub.DeliveredCount, clientSub.DroppedCount); + } + + /// + /// Removes ALL subscriptions for a session. + /// Called on explicit Disconnect or session scavenging. + /// + public void UnsubscribeSession(string sessionId) + { + if (!_sessionSubscriptions.TryRemove(sessionId, out var subscriptionIds)) + return; + + List ids; + lock (subscriptionIds) + { + ids = subscriptionIds.ToList(); + } + + foreach (var subId in ids) + { + UnsubscribeSubscription(subId); + } + + Log.Information("All subscriptions for session {SessionId} removed ({Count} subscriptions)", + sessionId, ids.Count); } /// @@ -252,7 +297,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions } return new SubscriptionStats( - _clientSubscriptions.Count, + _sessionSubscriptions.Count, _tagSubscriptions.Count, _clientSubscriptions.Values.Sum(c => c.Addresses.Count), totalDelivered, @@ -266,6 +311,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions kvp.Value.Channel.Writer.TryComplete(); } _clientSubscriptions.Clear(); + _sessionSubscriptions.Clear(); _tagSubscriptions.Clear(); _rwLock.Dispose(); } @@ -274,16 +320,18 @@ namespace ZB.MOM.WW.LmxProxy.Host.Subscriptions private class ClientSubscription { - public ClientSubscription(string clientId, + public ClientSubscription(string subscriptionId, string sessionId, Channel<(string address, Vtq vtq)> channel, HashSet addresses) { - ClientId = clientId; + SubscriptionId = subscriptionId; + SessionId = sessionId; Channel = channel; Addresses = addresses; } - public string ClientId { get; } + public string SubscriptionId { get; } + public string SessionId { get; } public Channel<(string address, Vtq vtq)> Channel { get; } public HashSet Addresses { get; } diff --git a/lmxproxy/tests/ZB.MOM.WW.LmxProxy.Host.Tests/Subscriptions/SubscriptionManagerTests.cs b/lmxproxy/tests/ZB.MOM.WW.LmxProxy.Host.Tests/Subscriptions/SubscriptionManagerTests.cs index 6607327..733b37d 100644 --- a/lmxproxy/tests/ZB.MOM.WW.LmxProxy.Host.Tests/Subscriptions/SubscriptionManagerTests.cs +++ b/lmxproxy/tests/ZB.MOM.WW.LmxProxy.Host.Tests/Subscriptions/SubscriptionManagerTests.cs @@ -51,8 +51,9 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions { using var sm = new SubscriptionManager(new FakeScadaClient()); using var cts = new CancellationTokenSource(); - var reader = await sm.SubscribeAsync("client1", new[] { "Tag1", "Tag2" }, cts.Token); + var (reader, subscriptionId) = await sm.SubscribeAsync("client1", new[] { "Tag1", "Tag2" }, cts.Token); reader.Should().NotBeNull(); + subscriptionId.Should().NotBeNullOrEmpty(); } [Fact] @@ -60,7 +61,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions { using var sm = new SubscriptionManager(new FakeScadaClient()); using var cts = new CancellationTokenSource(); - var reader = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); + var (reader, _) = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); var vtq = Vtq.Good(42.0); sm.OnTagValueChanged("Motor.Speed", vtq); @@ -76,8 +77,8 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions { using var sm = new SubscriptionManager(new FakeScadaClient()); using var cts = new CancellationTokenSource(); - var reader1 = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); - var reader2 = await sm.SubscribeAsync("client2", new[] { "Motor.Speed" }, cts.Token); + var (reader1, _) = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); + var (reader2, _) = await sm.SubscribeAsync("client2", new[] { "Motor.Speed" }, cts.Token); sm.OnTagValueChanged("Motor.Speed", Vtq.Good(99.0)); @@ -92,7 +93,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions { using var sm = new SubscriptionManager(new FakeScadaClient()); using var cts = new CancellationTokenSource(); - var reader = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); + var (reader, _) = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); sm.OnTagValueChanged("Motor.Torque", Vtq.Good(10.0)); @@ -101,26 +102,26 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions } [Fact] - public async Task UnsubscribeClient_CompletesChannel() + public async Task UnsubscribeSubscription_CompletesChannel() { using var sm = new SubscriptionManager(new FakeScadaClient()); using var cts = new CancellationTokenSource(); - var reader = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); + var (reader, subscriptionId) = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); - sm.UnsubscribeClient("client1"); + sm.UnsubscribeSubscription(subscriptionId); // Channel should be completed reader.Completion.IsCompleted.Should().BeTrue(); } [Fact] - public async Task UnsubscribeClient_RemovesFromTagSubscriptions() + public async Task UnsubscribeSession_RemovesAllSubscriptions() { using var sm = new SubscriptionManager(new FakeScadaClient()); using var cts = new CancellationTokenSource(); await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); - sm.UnsubscribeClient("client1"); + sm.UnsubscribeSession("client1"); var stats = sm.GetStats(); stats.TotalClients.Should().Be(0); @@ -128,20 +129,20 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions } [Fact] - public async Task RefCounting_LastClientUnsubscribeRemovesTag() + public async Task RefCounting_LastSubscriptionUnsubscribeRemovesTag() { using var sm = new SubscriptionManager(new FakeScadaClient()); using var cts = new CancellationTokenSource(); - await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); - await sm.SubscribeAsync("client2", new[] { "Motor.Speed" }, cts.Token); + var (_, subId1) = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); + var (_, subId2) = await sm.SubscribeAsync("client2", new[] { "Motor.Speed" }, cts.Token); sm.GetStats().TotalTags.Should().Be(1); - sm.UnsubscribeClient("client1"); + sm.UnsubscribeSubscription(subId1); sm.GetStats().TotalTags.Should().Be(1); // client2 still subscribed - sm.UnsubscribeClient("client2"); - sm.GetStats().TotalTags.Should().Be(0); // last client gone + sm.UnsubscribeSubscription(subId2); + sm.GetStats().TotalTags.Should().Be(0); // last subscription gone } [Fact] @@ -149,7 +150,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions { using var sm = new SubscriptionManager(new FakeScadaClient()); using var cts = new CancellationTokenSource(); - var reader = await sm.SubscribeAsync("client1", new[] { "Motor.Speed", "Motor.Torque" }, cts.Token); + var (reader, _) = await sm.SubscribeAsync("client1", new[] { "Motor.Speed", "Motor.Torque" }, cts.Token); sm.NotifyDisconnection(); @@ -165,7 +166,7 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions { using var sm = new SubscriptionManager(new FakeScadaClient(), channelCapacity: 3); using var cts = new CancellationTokenSource(); - var reader = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); + var (reader, _) = await sm.SubscribeAsync("client1", new[] { "Motor.Speed" }, cts.Token); // Fill the channel beyond capacity for (int i = 0; i < 10; i++) @@ -184,8 +185,8 @@ namespace ZB.MOM.WW.LmxProxy.Host.Tests.Subscriptions { using var sm = new SubscriptionManager(new FakeScadaClient()); using var cts = new CancellationTokenSource(); - await sm.SubscribeAsync("c1", new[] { "Tag1", "Tag2" }, cts.Token); - await sm.SubscribeAsync("c2", new[] { "Tag2", "Tag3" }, cts.Token); + var (_, _) = await sm.SubscribeAsync("c1", new[] { "Tag1", "Tag2" }, cts.Token); + var (_, _) = await sm.SubscribeAsync("c2", new[] { "Tag2", "Tag3" }, cts.Token); var stats = sm.GetStats(); stats.TotalClients.Should().Be(2);