using System.Diagnostics.CodeAnalysis;
using System.Security.Cryptography;
using Google.Protobuf.WellKnownTypes;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using ZB.MOM.WW.MxGateway.Contracts;
using ZB.MOM.WW.MxGateway.Contracts.Proto;
using ZB.MOM.WW.MxGateway.Server.Configuration;
using ZB.MOM.WW.MxGateway.Server.Metrics;
using ZB.MOM.WW.MxGateway.Server.Workers;

namespace ZB.MOM.WW.MxGateway.Server.Sessions;

/// <summary>
/// Manages gateway session lifecycles: admission against the configured session limit,
/// worker start-up, command dispatch, event streaming, close/kill, lease expiry and shutdown.
/// </summary>
/// <remarks>
/// Capacity is tracked by a <see cref="SemaphoreSlim"/> sized to <c>Sessions.MaxSessions</c>.
/// A slot is taken when a session is opened and must be returned exactly once: by the
/// failed-open cleanup in <c>OpenSessionAsync</c> when the session never became (or is no
/// longer) owned by the registry, or by <c>RemoveSessionAsync</c> when whoever removes the
/// session from the registry does so.
/// </remarks>
public sealed class SessionManager : ISessionManager
{
    /// <summary>Close reason recorded when a client closes its own session.</summary>
    public const string DefaultCloseReason = "client-close";

    /// <summary>Close/kill reason recorded when the gateway itself is shutting down.</summary>
    public const string GatewayShutdownReason = "gateway-shutdown";

    /// <summary>Close reason recorded when a session is closed because its lease expired.</summary>
    public const string LeaseExpiredReason = "lease-expired";

    private readonly ISessionRegistry _registry;
    private readonly ISessionWorkerClientFactory _workerClientFactory;
    private readonly GatewayMetrics _metrics;
    private readonly TimeProvider _timeProvider;
    private readonly ILogger _logger;
    private readonly GatewayOptions _options;

    // Free session slots; initial count == max count == Sessions.MaxSessions.
    private readonly SemaphoreSlim _sessionSlots;

    /// <summary>
    /// Initializes a new instance of <see cref="SessionManager"/>.
    /// </summary>
    /// <param name="registry">Session registry.</param>
    /// <param name="workerClientFactory">Worker client factory.</param>
    /// <param name="options">Gateway options.</param>
    /// <param name="metrics">Gateway metrics.</param>
    /// <param name="timeProvider">Time provider for timestamps; defaults to <see cref="TimeProvider.System"/>.</param>
    /// <param name="logger">Logger; defaults to a no-op logger.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="registry"/>, <paramref name="workerClientFactory"/>,
    /// <paramref name="options"/> or <paramref name="metrics"/> is <see langword="null"/>.
    /// </exception>
    public SessionManager(
        ISessionRegistry registry,
        ISessionWorkerClientFactory workerClientFactory,
        IOptions options,
        GatewayMetrics metrics,
        TimeProvider? timeProvider = null,
        ILogger? logger = null)
    {
        _registry = registry ?? throw new ArgumentNullException(nameof(registry));
        _workerClientFactory = workerClientFactory ?? throw new ArgumentNullException(nameof(workerClientFactory));
        ArgumentNullException.ThrowIfNull(options);
        _metrics = metrics ?? throw new ArgumentNullException(nameof(metrics));
        _timeProvider = timeProvider ?? TimeProvider.System;
        _logger = logger ?? NullLogger.Instance;
        _options = options.Value;
        _sessionSlots = new SemaphoreSlim(_options.Sessions.MaxSessions, _options.Sessions.MaxSessions);
    }

    /// <summary>
    /// Opens a new gateway session and connects it to a worker.
    /// </summary>
    /// <param name="request">Session open request.</param>
    /// <param name="clientIdentity">Client authentication identity.</param>
    /// <param name="ownerKeyId">API key identifier of the caller creating the session.</param>
    /// <param name="cancellationToken">Cancellation token forwarded to worker creation.</param>
    /// <returns>The opened gateway session, in the ready state.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="request"/> is <see langword="null"/>.</exception>
    /// <exception cref="SessionManagerException">
    /// <see cref="SessionManagerErrorCode.SessionLimitExceeded"/> when no session slot is free;
    /// <see cref="SessionManagerErrorCode.OpenFailed"/> when any later step fails. The original
    /// failure is the inner exception.
    /// </exception>
    public async Task OpenSessionAsync(
        SessionOpenRequest request,
        string? clientIdentity,
        string? ownerKeyId,
        CancellationToken cancellationToken)
    {
        ArgumentNullException.ThrowIfNull(request);
        EnsureSessionCapacity();

        GatewaySession? session = null;

        // True only once *this* call has added the session to the registry. On an id
        // collision the registry entry belongs to another live session and must not be touched.
        bool registered = false;
        bool sessionOpenedRecorded = false;
        try
        {
            session = CreateSession(request, clientIdentity, ownerKeyId);
            if (!_registry.TryAdd(session))
            {
                throw new SessionManagerException(
                    SessionManagerErrorCode.OpenFailed,
                    $"Session id collision while opening session {session.SessionId}.");
            }

            registered = true;
            session.TransitionTo(SessionState.StartingWorker);
            IWorkerClient workerClient = await _workerClientFactory
                .CreateAsync(session, cancellationToken)
                .ConfigureAwait(false);
            session.AttachWorkerClient(workerClient);
            session.MarkReady();
            _metrics.SessionOpened();
            sessionOpenedRecorded = true;
            return session;
        }
        catch (Exception exception)
        {
            // The slot acquired above belongs to this call unless a concurrent close already
            // removed the registered session. RemoveSessionAsync releases the slot in that case,
            // so releasing it again here would inflate capacity.
            bool ownsSlot = true;
            if (session is not null)
            {
                MarkFaultedAfterFailedOpen(session, exception.Message);
                if (registered)
                {
                    ownsSlot = _registry.TryRemove(session.SessionId, out _);
                }

                await DisposeAfterFailedOpenAsync(session).ConfigureAwait(false);
            }

            // If SessionOpened() already incremented the open-session gauge,
            // a failure after that point (e.g. auto-subscribe rejection) must
            // decrement it again so mxgateway.sessions.open does not leak.
            if (sessionOpenedRecorded)
            {
                _metrics.SessionRemoved();
            }

            if (ownsSlot)
            {
                ReleaseSessionSlot();
            }

            _metrics.Fault(SessionManagerErrorCode.OpenFailed.ToString());
            _logger.LogWarning(
                exception,
                "Failed to open gateway session {SessionId}.",
                session?.SessionId ?? "");
            throw new SessionManagerException(
                SessionManagerErrorCode.OpenFailed,
                session is null
                    ? "Failed to create session."
                    : $"Failed to open session {session.SessionId}.",
                exception);
        }
    }

    /// <summary>
    /// Attempts to retrieve a session by ID.
    /// </summary>
    /// <param name="sessionId">Session identifier.</param>
    /// <param name="session">The session if found.</param>
    /// <returns><see langword="true"/> if the session was found; otherwise <see langword="false"/>.</returns>
    public bool TryGetSession(
        string sessionId,
        [MaybeNullWhen(false)] out GatewaySession session)
    {
        return _registry.TryGet(sessionId, out session);
    }

    /// <summary>
    /// Invokes a worker command on a session.
    /// </summary>
    /// <param name="sessionId">Session identifier.</param>
    /// <param name="command">Worker command.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>The command reply from the session.</returns>
    /// <exception cref="SessionManagerException">
    /// The session does not exist, or the session itself raised a <see cref="SessionManagerException"/>.
    /// </exception>
    /// <remarks>
    /// Any other failure is rethrown unchanged. If the worker client reports
    /// <see cref="WorkerClientState.Faulted"/> at that point, the session is marked faulted first.
    /// </remarks>
    public async Task InvokeAsync(
        string sessionId,
        WorkerCommand command,
        CancellationToken cancellationToken)
    {
        GatewaySession session = GetRequiredSession(sessionId);
        try
        {
            return await session.InvokeAsync(command, cancellationToken).ConfigureAwait(false);
        }
        catch (SessionManagerException)
        {
            throw;
        }
        catch (Exception exception)
        {
            if (session.WorkerClient?.State == WorkerClientState.Faulted)
            {
                session.MarkFaulted(exception.Message);
            }

            throw;
        }
    }

    /// <summary>
    /// Returns the event stream of a session.
    /// </summary>
    /// <param name="sessionId">Session identifier.</param>
    /// <param name="cancellationToken">Cancellation token passed to the session's event reader.</param>
    /// <returns>Async enumerable of worker events.</returns>
    /// <exception cref="SessionManagerException">The session does not exist (thrown eagerly, before enumeration).</exception>
    public IAsyncEnumerable ReadEventsAsync(
        string sessionId,
        CancellationToken cancellationToken)
    {
        GatewaySession session = GetRequiredSession(sessionId);
        return session.ReadEventsAsync(cancellationToken);
    }

    /// <summary>
    /// Closes a gateway session with <see cref="DefaultCloseReason"/>.
    /// </summary>
    /// <param name="sessionId">Session identifier.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Session close result.</returns>
    /// <exception cref="SessionManagerException">The session does not exist or its close failed.</exception>
    public async Task CloseSessionAsync(
        string sessionId,
        CancellationToken cancellationToken)
    {
        GatewaySession session = GetRequiredSession(sessionId);
        SessionCloseResult result = await CloseSessionCoreAsync(
            session,
            DefaultCloseReason,
            cancellationToken).ConfigureAwait(false);
        return result;
    }

    /// <summary>
    /// Forcefully terminates a session's worker without attempting graceful shutdown.
    /// Mirrors the registry/metrics cleanup that <see cref="CloseSessionAsync"/> performs after
    /// a successful close, but skips the graceful worker shutdown step a close would attempt.
    /// </summary>
    /// <param name="sessionId">Session identifier.</param>
    /// <param name="reason">Reason recorded for the kill; must not be null or whitespace.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>
    /// Session close result in <see cref="SessionState.Closed"/>. <c>AlreadyClosed</c> is set when
    /// the session was already closed when the close gate was acquired.
    /// </returns>
    /// <exception cref="ArgumentException"><paramref name="reason"/> is null or whitespace.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was already cancelled.</exception>
    /// <exception cref="SessionManagerException">The session does not exist or the kill failed.</exception>
    public async Task KillWorkerAsync(
        string sessionId,
        string reason,
        CancellationToken cancellationToken)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(reason);
        cancellationToken.ThrowIfCancellationRequested();

        GatewaySession session = GetRequiredSession(sessionId);

        // Serialize concurrent kill/close attempts on this session by routing through the
        // per-session close lock (Server-045). Returns whether the session was already in
        // Closed state when the lock was acquired so the metric counter is incremented at
        // most once across concurrent callers.
        bool wasClosed;
        try
        {
            wasClosed = await session.KillWorkerWithCloseGateAsync(reason, cancellationToken).ConfigureAwait(false);
        }
        catch (Exception exception)
        {
            session.MarkFaulted(exception.Message);
            _metrics.Fault(SessionManagerErrorCode.CloseFailed.ToString());

            // Server-044: the open-session gauge was incremented in OpenSessionAsync;
            // every session reaching KillWorkerAsync had SessionOpened recorded. If the
            // kill path throws, decrement the gauge here so mxgateway.sessions.open
            // does not leak — mirroring the Server-006 fix on OpenSessionAsync.
            _metrics.SessionRemoved();
            await RemoveSessionAsync(session).ConfigureAwait(false);
            throw new SessionManagerException(
                SessionManagerErrorCode.CloseFailed,
                $"Failed to kill worker for session {sessionId}.",
                exception);
        }

        if (!wasClosed)
        {
            _metrics.SessionClosed();
        }

        await RemoveSessionAsync(session).ConfigureAwait(false);

        _logger.LogInformation(
            "Worker for session {SessionId} killed; reason={Reason}.",
            sessionId,
            reason);

        return new SessionCloseResult(sessionId, SessionState.Closed, AlreadyClosed: wasClosed);
    }

    /// <summary>
    /// Closes every registered session whose lease has expired at <paramref name="now"/>.
    /// </summary>
    /// <param name="now">Current time for the lease expiration check.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>Number of expired sessions whose close completed without throwing.</returns>
    /// <remarks>
    /// A failure closing one session is logged and does not stop the sweep. Otherwise a single
    /// session whose close keeps failing would block lease expiry for every session after it.
    /// Cancellation of <paramref name="cancellationToken"/> is still propagated.
    /// </remarks>
    public async Task CloseExpiredLeasesAsync(
        DateTimeOffset now,
        CancellationToken cancellationToken)
    {
        int closedCount = 0;
        foreach (GatewaySession session in _registry.Snapshot())
        {
            if (!session.IsLeaseExpired(now))
            {
                continue;
            }

            try
            {
                await CloseSessionCoreAsync(session, LeaseExpiredReason, cancellationToken).ConfigureAwait(false);
                closedCount++;
            }
            catch (Exception exception) when (exception is not OperationCanceledException || !cancellationToken.IsCancellationRequested)
            {
                _logger.LogWarning(
                    exception,
                    "Failed to close session {SessionId} after lease expiry.",
                    session.SessionId);
            }
        }

        return closedCount;
    }

    /// <summary>
    /// Closes all registered sessions with <see cref="GatewayShutdownReason"/>. Any session whose
    /// graceful close fails and that is still registered falls back to <see cref="KillWorkerAsync"/>.
    /// </summary>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>A task that completes when every session has been processed.</returns>
    public async Task ShutdownAsync(CancellationToken cancellationToken)
    {
        foreach (GatewaySession session in _registry.Snapshot())
        {
            try
            {
                await CloseSessionCoreAsync(session, GatewayShutdownReason, cancellationToken).ConfigureAwait(false);
            }
            catch (Exception exception)
            {
                _logger.LogWarning(
                    exception,
                    "Graceful shutdown failed for session {SessionId}; killing worker.",
                    session.SessionId);

                // Defensive fallback: CloseSessionCoreAsync's inner SessionCloseStartedException
                // catch normally removes the session and accounts the close (Server-046). The
                // outer fallback only fires for sessions still in the registry — route through
                // KillWorkerAsync so the bookkeeping is identical to the dashboard kill path.
                if (_registry.TryGet(session.SessionId, out _))
                {
                    try
                    {
                        await KillWorkerAsync(session.SessionId, GatewayShutdownReason, cancellationToken).ConfigureAwait(false);
                    }
                    catch (SessionManagerException killException)
                    {
                        _logger.LogWarning(
                            killException,
                            "Worker kill fallback failed for session {SessionId}.",
                            session.SessionId);
                    }
                }
            }
        }
    }

    /// <summary>
    /// Closes <paramref name="session"/>, updates close metrics and removes it from the registry.
    /// </summary>
    /// <param name="session">Session to close.</param>
    /// <param name="reason">Close reason forwarded to the session.</param>
    /// <param name="cancellationToken">Cancellation token.</param>
    /// <returns>The session's close result.</returns>
    /// <exception cref="SessionManagerException">
    /// <see cref="SessionManagerErrorCode.CloseFailed"/> when the session reports a
    /// <see cref="SessionCloseStartedException"/>. The session is removed before throwing.
    /// Other exceptions propagate unchanged and leave the session registered.
    /// </exception>
    private async Task CloseSessionCoreAsync(
        GatewaySession session,
        string reason,
        CancellationToken cancellationToken)
    {
        // Snapshot before closing so the closed metric is recorded at most once per session.
        bool wasClosed = session.State == SessionState.Closed;
        try
        {
            SessionCloseResult result = await session.CloseAsync(reason, cancellationToken).ConfigureAwait(false);
            if (!wasClosed && !result.AlreadyClosed)
            {
                _metrics.SessionClosed();
            }

            await RemoveSessionAsync(session).ConfigureAwait(false);
            return result;
        }
        catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
        {
            throw;
        }
        catch (SessionCloseStartedException exception)
        {
            session.MarkFaulted(exception.Message);
            if (!wasClosed)
            {
                // Server-046: account the close as a SessionClosed (decrements the open-session
                // gauge AND increments the sessions.closed counter), not just SessionRemoved.
                // The session is being removed from the registry below; treating this as a
                // half-finished close that only decremented the gauge under-counted the closed
                // counter.
                _metrics.SessionClosed();
            }

            _metrics.Fault(SessionManagerErrorCode.CloseFailed.ToString());
            await RemoveSessionAsync(session).ConfigureAwait(false);
            throw new SessionManagerException(
                SessionManagerErrorCode.CloseFailed,
                $"Failed to close session {session.SessionId}.",
                exception);
        }
    }

    /// <summary>
    /// Looks up a registered session or throws
    /// <see cref="SessionManagerErrorCode.SessionNotFound"/>.
    /// </summary>
    private GatewaySession GetRequiredSession(string sessionId)
    {
        if (!_registry.TryGet(sessionId, out GatewaySession? session) || session is null)
        {
            throw new SessionManagerException(
                SessionManagerErrorCode.SessionNotFound,
                $"Session {sessionId} was not found.");
        }

        return session;
    }

    /// <summary>
    /// Takes one session slot without waiting, or throws
    /// <see cref="SessionManagerErrorCode.SessionLimitExceeded"/> when none is free.
    /// </summary>
    private void EnsureSessionCapacity()
    {
        if (!_sessionSlots.Wait(0))
        {
            throw new SessionManagerException(
                SessionManagerErrorCode.SessionLimitExceeded,
                $"Gateway session limit {_options.Sessions.MaxSessions} has been reached.");
        }
    }

    /// <summary>
    /// Removes the session from the registry and, only if this call performed the removal,
    /// drops its event metrics, returns its slot and disposes it. Removal therefore releases
    /// the slot at most once even under concurrent callers.
    /// </summary>
    private async Task RemoveSessionAsync(GatewaySession session)
    {
        if (!_registry.TryRemove(session.SessionId, out GatewaySession? removedSession))
        {
            return;
        }

        _metrics.RemoveSessionEvents(session.SessionId);
        ReleaseSessionSlot();
        await removedSession.DisposeAsync().ConfigureAwait(false);
    }

    /// <summary>
    /// Returns one session slot. Best effort: an over-release is logged instead of thrown.
    /// </summary>
    private void ReleaseSessionSlot()
    {
        try
        {
            _sessionSlots.Release();
        }
        catch (SemaphoreFullException exception)
        {
            // All slots are already free, so some path released twice. That is an accounting
            // bug worth surfacing, but the caller's cleanup must not fail because of it.
            _logger.LogWarning(
                exception,
                "Session slot released while all {MaxSessions} slots were already free.",
                _options.Sessions.MaxSessions);
        }
    }

    /// <summary>
    /// Marks a session faulted during failed-open cleanup. A failure here is logged so the
    /// remaining cleanup (registry removal, slot release, metrics) still runs.
    /// </summary>
    private void MarkFaultedAfterFailedOpen(GatewaySession session, string reason)
    {
        try
        {
            session.MarkFaulted(reason);
        }
        catch (Exception exception)
        {
            _logger.LogWarning(
                exception,
                "Failed to mark session {SessionId} faulted after open failure.",
                session.SessionId);
        }
    }

    /// <summary>
    /// Disposes a session during failed-open cleanup. A failure here is logged so the
    /// original open failure is still the one reported to the caller.
    /// </summary>
    private async Task DisposeAfterFailedOpenAsync(GatewaySession session)
    {
        try
        {
            await session.DisposeAsync().ConfigureAwait(false);
        }
        catch (Exception exception)
        {
            _logger.LogWarning(
                exception,
                "Failed to dispose session {SessionId} after open failure.",
                session.SessionId);
        }
    }

    /// <summary>
    /// Builds a new, not-yet-registered session from the open request and configured options.
    /// </summary>
    private GatewaySession CreateSession(
        SessionOpenRequest request,
        string? clientIdentity,
        string? ownerKeyId)
    {
        string sessionId = CreateSessionId();
        string backendName = string.IsNullOrWhiteSpace(request.RequestedBackend)
            ? GatewayContractInfo.DefaultBackendName
            : request.RequestedBackend!;
        TimeSpan commandTimeout = ResolveCommandTimeout(request.CommandTimeout);
        TimeSpan startupTimeout = TimeSpan.FromSeconds(_options.Worker.StartupTimeoutSeconds);
        TimeSpan shutdownTimeout = TimeSpan.FromSeconds(_options.Worker.ShutdownTimeoutSeconds);
        TimeSpan leaseDuration = TimeSpan.FromSeconds(_options.Sessions.DefaultLeaseSeconds);
        string pipeName = $"mxaccess-gateway-{Environment.ProcessId}-{sessionId}";
        string nonce = CreateNonce();
        DateTimeOffset openedAt = _timeProvider.GetUtcNow();
        string clientCorrelationId = CreateClientCorrelationId(request.ClientSessionName, sessionId);

        return new GatewaySession(
            sessionId,
            backendName,
            pipeName,
            nonce,
            clientIdentity,
            ownerKeyId,
            request.ClientSessionName,
            clientCorrelationId,
            commandTimeout,
            startupTimeout,
            shutdownTimeout,
            leaseDuration,
            openedAt);
    }

    /// <summary>
    /// Formats <c>{clientName}-{sessionId}</c>, using <c>"client"</c> when no usable name was supplied.
    /// </summary>
    private static string CreateClientCorrelationId(
        string? clientSessionName,
        string sessionId)
    {
        string clientName = string.IsNullOrWhiteSpace(clientSessionName)
            ? "client"
            : clientSessionName!;
        return $"{clientName}-{sessionId}";
    }

    /// <summary>
    /// Uses the requested command timeout when it is present and positive; otherwise falls back to
    /// <c>Sessions.DefaultCommandTimeoutSeconds</c>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): <c>Duration.ToTimeSpan</c> presumably throws for non-normalized or
    /// out-of-range values. This runs inside <c>OpenSessionAsync</c>'s try block, so such a
    /// failure surfaces as <see cref="SessionManagerErrorCode.OpenFailed"/>.
    /// </remarks>
    private TimeSpan ResolveCommandTimeout(Duration? requestedTimeout)
    {
        if (requestedTimeout is null)
        {
            return TimeSpan.FromSeconds(_options.Sessions.DefaultCommandTimeoutSeconds);
        }

        TimeSpan timeout = requestedTimeout.ToTimeSpan();
        return timeout <= TimeSpan.Zero
            ? TimeSpan.FromSeconds(_options.Sessions.DefaultCommandTimeoutSeconds)
            : timeout;
    }

    /// <summary>Creates a new session id of the form <c>session-{32 hex digits}</c>.</summary>
    private static string CreateSessionId()
    {
        return $"session-{Guid.NewGuid():N}";
    }

    /// <summary>Creates a Base64-encoded 32-byte nonce from a cryptographic RNG.</summary>
    private static string CreateNonce()
    {
        Span bytes = stackalloc byte[32];
        RandomNumberGenerator.Fill(bytes);
        return Convert.ToBase64String(bytes);
    }
}