// Source: mxaccessgw/src/ZB.MOM.WW.MxGateway.Server/Sessions/SessionManager.cs
// (file-viewer listing: 511 lines, 19 KiB, C#)

using System.Diagnostics.CodeAnalysis;
using System.Security.Cryptography;
using Google.Protobuf.WellKnownTypes;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using ZB.MOM.WW.MxGateway.Contracts;
using ZB.MOM.WW.MxGateway.Contracts.Proto;
using ZB.MOM.WW.MxGateway.Server.Configuration;
using ZB.MOM.WW.MxGateway.Server.Metrics;
using ZB.MOM.WW.MxGateway.Server.Workers;
namespace ZB.MOM.WW.MxGateway.Server.Sessions;
public sealed class SessionManager : ISessionManager
{
/// <summary>Close reason passed to the session when a caller closes it through <see cref="CloseSessionAsync"/>.</summary>
public const string DefaultCloseReason = "client-close";
/// <summary>Close/kill reason used by <see cref="ShutdownAsync"/> for every session it tears down.</summary>
public const string GatewayShutdownReason = "gateway-shutdown";
/// <summary>Close reason used by <see cref="CloseExpiredLeasesAsync"/> for sessions whose lease has expired.</summary>
public const string LeaseExpiredReason = "lease-expired";
// Registry the manager adds sessions to, looks them up in, and removes them from.
private readonly ISessionRegistry _registry;
// Creates the worker client that OpenSessionAsync attaches to a newly created session.
private readonly ISessionWorkerClientFactory _workerClientFactory;
// Receives session opened/closed/removed notifications and fault codes.
private readonly GatewayMetrics _metrics;
// Clock used for the session's opened-at timestamp; also handed to each session's event streaming.
private readonly TimeProvider _timeProvider;
private readonly ILogger<SessionManager> _logger;
// Value of the injected IOptions<GatewayOptions>, read once in the constructor.
private readonly GatewayOptions _options;
// Counting semaphore sized to Sessions.MaxSessions: one slot is taken (without blocking)
// per opened session in EnsureSessionCapacity and returned in ReleaseSessionSlot.
private readonly SemaphoreSlim _sessionSlots;
// Mapper and logger passed to the SessionEventStreaming built for every new session.
private readonly Grpc.MxAccessGrpcMapper _eventMapper;
private readonly ILogger<SessionEventDistributor> _distributorLogger;
/// <summary>
/// Creates a <see cref="SessionManager"/> and sizes its session-slot semaphore from
/// <c>Sessions.MaxSessions</c> in the supplied options.
/// </summary>
/// <param name="registry">Registry that stores the live sessions. Required.</param>
/// <param name="workerClientFactory">Factory that connects a worker client for a new session. Required.</param>
/// <param name="options">Gateway options wrapper; its <c>Value</c> is read once here. Required.</param>
/// <param name="metrics">Gateway metrics sink. Required.</param>
/// <param name="timeProvider">Clock for timestamps; <see cref="TimeProvider.System"/> when omitted.</param>
/// <param name="logger">Manager logger; a null logger when omitted.</param>
/// <param name="eventMapper">Mapper handed to each session's event streaming; a new mapper when omitted.</param>
/// <param name="distributorLogger">Logger handed to each session's event streaming; a null logger when omitted.</param>
/// <exception cref="ArgumentNullException">A required dependency is <see langword="null"/>.</exception>
public SessionManager(
    ISessionRegistry registry,
    ISessionWorkerClientFactory workerClientFactory,
    IOptions<GatewayOptions> options,
    GatewayMetrics metrics,
    TimeProvider? timeProvider = null,
    ILogger<SessionManager>? logger = null,
    Grpc.MxAccessGrpcMapper? eventMapper = null,
    ILogger<SessionEventDistributor>? distributorLogger = null)
{
    // Validate the required dependencies up front, in declaration order.
    ArgumentNullException.ThrowIfNull(registry);
    ArgumentNullException.ThrowIfNull(workerClientFactory);
    ArgumentNullException.ThrowIfNull(options);
    ArgumentNullException.ThrowIfNull(metrics);

    _registry = registry;
    _workerClientFactory = workerClientFactory;
    _metrics = metrics;

    // Optional collaborators fall back to inert or default implementations.
    _timeProvider = timeProvider ?? TimeProvider.System;
    _logger = logger ?? NullLogger<SessionManager>.Instance;
    _eventMapper = eventMapper ?? new Grpc.MxAccessGrpcMapper();
    _distributorLogger = distributorLogger ?? NullLogger<SessionEventDistributor>.Instance;

    _options = options.Value;

    // Every slot starts out free: initial count == maximum count.
    int maxSessions = _options.Sessions.MaxSessions;
    _sessionSlots = new SemaphoreSlim(maxSessions, maxSessions);
}
/// <summary>
/// Opens a new gateway session: reserves a session slot, creates and registers the
/// session, connects its worker client, and marks the session ready.
/// </summary>
/// <param name="request">Session open request.</param>
/// <param name="clientIdentity">Client authentication identity.</param>
/// <param name="ownerKeyId">API key identifier of the caller creating the session.</param>
/// <param name="cancellationToken">Cancellation token passed to the worker client factory.</param>
/// <returns>The opened gateway session.</returns>
/// <exception cref="ArgumentNullException"><paramref name="request"/> is <see langword="null"/>.</exception>
/// <exception cref="SessionManagerException">
/// <see cref="SessionManagerErrorCode.SessionLimitExceeded"/> when no session slot is free;
/// <see cref="SessionManagerErrorCode.OpenFailed"/> when any step of opening the session fails
/// (the original failure is the inner exception).
/// </exception>
public async Task<GatewaySession> OpenSessionAsync(
    SessionOpenRequest request,
    string? clientIdentity,
    string? ownerKeyId,
    CancellationToken cancellationToken)
{
    ArgumentNullException.ThrowIfNull(request);
    EnsureSessionCapacity();
    GatewaySession? session = null;
    // True only once THIS session has been added to the registry. On an id collision the
    // registry entry belongs to a different, live session and must not be removed by the
    // failure cleanup below.
    bool registered = false;
    bool sessionOpenedRecorded = false;
    try
    {
        session = CreateSession(request, clientIdentity, ownerKeyId);
        if (!_registry.TryAdd(session))
        {
            throw new SessionManagerException(
                SessionManagerErrorCode.OpenFailed,
                $"Session id collision while opening session {session.SessionId}.");
        }
        registered = true;
        session.TransitionTo(SessionState.StartingWorker);
        IWorkerClient workerClient = await _workerClientFactory
            .CreateAsync(session, cancellationToken)
            .ConfigureAwait(false);
        session.AttachWorkerClient(workerClient);
        session.MarkReady();
        _metrics.SessionOpened();
        sessionOpenedRecorded = true;
        return session;
    }
    catch (Exception exception)
    {
        try
        {
            if (session is not null)
            {
                session.MarkFaulted(exception.Message);
                if (registered)
                {
                    _registry.TryRemove(session.SessionId, out _);
                }
                await session.DisposeAsync().ConfigureAwait(false);
            }
        }
        finally
        {
            // Runs even when the cleanup above throws, so neither the open-session gauge
            // nor the capacity slot reserved by EnsureSessionCapacity can leak.
            // If SessionOpened() already incremented the open-session gauge,
            // a failure after that point must decrement it again so
            // mxgateway.sessions.open does not leak.
            if (sessionOpenedRecorded)
            {
                _metrics.SessionRemoved();
            }
            ReleaseSessionSlot();
        }
        _metrics.Fault(SessionManagerErrorCode.OpenFailed.ToString());
        _logger.LogWarning(
            exception,
            "Failed to open gateway session {SessionId}.",
            session?.SessionId ?? "<not-created>");
        throw new SessionManagerException(
            SessionManagerErrorCode.OpenFailed,
            session is null ? "Failed to create session." : $"Failed to open session {session.SessionId}.",
            exception);
    }
}
/// <summary>
/// Looks a session up in the registry by its identifier.
/// </summary>
/// <param name="sessionId">Identifier of the session to find.</param>
/// <param name="session">Receives the session when the lookup succeeds.</param>
/// <returns><see langword="true"/> when the registry holds the session; otherwise <see langword="false"/>.</returns>
public bool TryGetSession(
    string sessionId,
    [MaybeNullWhen(false)] out GatewaySession session)
    => _registry.TryGet(sessionId, out session);
/// <summary>
/// Sends a worker command to the given session and returns the worker's reply.
/// </summary>
/// <param name="sessionId">Identifier of the target session.</param>
/// <param name="command">Command to forward to the session.</param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>The reply produced for the command.</returns>
/// <exception cref="SessionManagerException">The session is not registered, or the session itself raised one.</exception>
public async Task<WorkerCommandReply> InvokeAsync(
    string sessionId,
    WorkerCommand command,
    CancellationToken cancellationToken)
{
    GatewaySession target = GetRequiredSession(sessionId);
    try
    {
        WorkerCommandReply reply = await target.InvokeAsync(command, cancellationToken).ConfigureAwait(false);
        return reply;
    }
    catch (Exception failure) when (failure is not SessionManagerException)
    {
        // SessionManagerException passes through untouched. Any other failure marks the
        // session faulted, but only when the worker client itself reports a faulted state.
        bool workerFaulted = target.WorkerClient?.State == WorkerClientState.Faulted;
        if (workerFaulted)
        {
            target.MarkFaulted(failure.Message);
        }
        throw;
    }
}
/// <summary>
/// Returns the event stream of the given session.
/// </summary>
/// <param name="sessionId">Identifier of the session whose events are read.</param>
/// <param name="cancellationToken">Cancellation token forwarded to the session's event reader.</param>
/// <returns>Async sequence of worker events produced by the session.</returns>
/// <exception cref="SessionManagerException">The session is not registered (thrown at call time, before enumeration).</exception>
public IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
    string sessionId,
    CancellationToken cancellationToken)
    => GetRequiredSession(sessionId).ReadEventsAsync(cancellationToken);
/// <summary>
/// Closes the given session on behalf of a client, using <see cref="DefaultCloseReason"/>.
/// </summary>
/// <param name="sessionId">Identifier of the session to close.</param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>Result describing the close.</returns>
/// <exception cref="SessionManagerException">The session is not registered, or the close failed.</exception>
public async Task<SessionCloseResult> CloseSessionAsync(
    string sessionId,
    CancellationToken cancellationToken)
{
    GatewaySession target = GetRequiredSession(sessionId);
    return await CloseSessionCoreAsync(target, DefaultCloseReason, cancellationToken).ConfigureAwait(false);
}
/// <summary>
/// Forcefully terminates a session's worker without a graceful shutdown attempt.
/// Performs the same registry/metrics cleanup that <see cref="CloseSessionCoreAsync"/>
/// does after a successful close, but skips the <c>WorkerClient.ShutdownAsync</c>
/// step that <see cref="GatewaySession.CloseAsync"/> would otherwise attempt.
/// </summary>
/// <param name="sessionId">Identifier of the session whose worker is killed.</param>
/// <param name="reason">Reason recorded for the kill; must not be null or whitespace.</param>
/// <param name="cancellationToken">Cancellation token; checked before the session lookup.</param>
/// <returns>Close result in the <see cref="SessionState.Closed"/> state.</returns>
/// <exception cref="SessionManagerException">The session is not registered, or the kill failed.</exception>
public async Task<SessionCloseResult> KillWorkerAsync(
    string sessionId,
    string reason,
    CancellationToken cancellationToken)
{
    ArgumentException.ThrowIfNullOrWhiteSpace(reason);
    cancellationToken.ThrowIfCancellationRequested();
    GatewaySession target = GetRequiredSession(sessionId);

    // Server-045: the kill goes through the per-session close lock so that concurrent
    // kill/close attempts are serialized. The returned flag says whether the session was
    // already Closed when the lock was acquired, which keeps the closed counter from
    // being incremented more than once across concurrent callers.
    bool alreadyClosed;
    try
    {
        alreadyClosed = await target
            .KillWorkerWithCloseGateAsync(reason, cancellationToken)
            .ConfigureAwait(false);
    }
    catch (Exception killFailure)
    {
        target.MarkFaulted(killFailure.Message);
        _metrics.Fault(SessionManagerErrorCode.CloseFailed.ToString());

        // Server-044: every session that reaches this method had SessionOpened recorded
        // in OpenSessionAsync, so a failing kill must decrement the open-session gauge
        // here or mxgateway.sessions.open leaks (same idea as the Server-006 fix on
        // OpenSessionAsync).
        _metrics.SessionRemoved();
        await RemoveSessionAsync(target).ConfigureAwait(false);
        throw new SessionManagerException(
            SessionManagerErrorCode.CloseFailed,
            $"Failed to kill worker for session {sessionId}.",
            killFailure);
    }

    if (!alreadyClosed)
    {
        _metrics.SessionClosed();
    }

    await RemoveSessionAsync(target).ConfigureAwait(false);
    _logger.LogInformation(
        "Worker for session {SessionId} killed; reason={Reason}.",
        sessionId,
        reason);
    return new SessionCloseResult(sessionId, SessionState.Closed, AlreadyClosed: alreadyClosed);
}
/// <summary>
/// Walks a snapshot of the registry and closes every session whose lease has expired
/// at <paramref name="now"/>, using <see cref="LeaseExpiredReason"/>.
/// </summary>
/// <param name="now">Instant against which each session's lease is checked.</param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>Number of expired sessions for which the close call completed.</returns>
public async Task<int> CloseExpiredLeasesAsync(
    DateTimeOffset now,
    CancellationToken cancellationToken)
{
    int expiredClosed = 0;
    foreach (GatewaySession candidate in _registry.Snapshot())
    {
        if (candidate.IsLeaseExpired(now))
        {
            await CloseSessionCoreAsync(candidate, LeaseExpiredReason, cancellationToken).ConfigureAwait(false);
            expiredClosed += 1;
        }
    }

    return expiredClosed;
}
/// <summary>
/// Closes every session in a snapshot of the registry with
/// <see cref="GatewayShutdownReason"/>, falling back to killing the worker when the
/// graceful close fails.
/// </summary>
/// <param name="cancellationToken">Cancellation token passed to each close and kill.</param>
/// <returns>Task that completes once every session in the snapshot has been processed.</returns>
public async Task ShutdownAsync(CancellationToken cancellationToken)
{
    foreach (GatewaySession active in _registry.Snapshot())
    {
        try
        {
            await CloseSessionCoreAsync(active, GatewayShutdownReason, cancellationToken).ConfigureAwait(false);
        }
        catch (Exception closeFailure)
        {
            _logger.LogWarning(
                closeFailure,
                "Graceful shutdown failed for session {SessionId}; killing worker.",
                active.SessionId);

            // Defensive fallback. CloseSessionCoreAsync's SessionCloseStartedException
            // handler normally removes the session and accounts the close itself
            // (Server-046), so only sessions that are still registered need the kill.
            // Going through KillWorkerAsync keeps the bookkeeping identical to the
            // dashboard kill path.
            bool stillRegistered = _registry.TryGet(active.SessionId, out _);
            if (!stillRegistered)
            {
                continue;
            }

            try
            {
                await KillWorkerAsync(active.SessionId, GatewayShutdownReason, cancellationToken).ConfigureAwait(false);
            }
            catch (SessionManagerException killFailure)
            {
                _logger.LogWarning(
                    killFailure,
                    "Worker kill fallback failed for session {SessionId}.",
                    active.SessionId);
            }
        }
    }
}
/// <summary>
/// Shared close path: closes the session through <see cref="GatewaySession.CloseAsync"/>,
/// records the close in the metrics at most once, and removes the session from the registry.
/// </summary>
/// <param name="session">Session to close.</param>
/// <param name="reason">Close reason handed to the session.</param>
/// <param name="cancellationToken">Cancellation token; cancellation observed through it is rethrown as-is.</param>
/// <returns>The result returned by the session's close.</returns>
/// <exception cref="SessionManagerException">
/// <see cref="SessionManagerErrorCode.CloseFailed"/> when the session raised
/// <see cref="SessionCloseStartedException"/>; the session is still removed in that case.
/// </exception>
private async Task<SessionCloseResult> CloseSessionCoreAsync(
    GatewaySession session,
    string reason,
    CancellationToken cancellationToken)
{
    // Captured before the close so an already-closed session is not counted again.
    bool closedBeforeCall = session.State == SessionState.Closed;
    try
    {
        SessionCloseResult closeResult = await session.CloseAsync(reason, cancellationToken).ConfigureAwait(false);
        bool countedElsewhere = closedBeforeCall || closeResult.AlreadyClosed;
        if (!countedElsewhere)
        {
            _metrics.SessionClosed();
        }

        await RemoveSessionAsync(session).ConfigureAwait(false);
        return closeResult;
    }
    catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
    {
        throw;
    }
    catch (SessionCloseStartedException closeFailure)
    {
        session.MarkFaulted(closeFailure.Message);
        if (!closedBeforeCall)
        {
            // Server-046: record this as SessionClosed (open-session gauge down AND
            // sessions.closed counter up) rather than SessionRemoved. The session is
            // removed from the registry just below, and a gauge-only decrement
            // under-counted the closed counter.
            _metrics.SessionClosed();
        }

        _metrics.Fault(SessionManagerErrorCode.CloseFailed.ToString());
        await RemoveSessionAsync(session).ConfigureAwait(false);
        throw new SessionManagerException(
            SessionManagerErrorCode.CloseFailed,
            $"Failed to close session {session.SessionId}.",
            closeFailure);
    }
}
/// <summary>
/// Resolves a session by identifier or fails when the registry does not hold it.
/// </summary>
/// <param name="sessionId">Identifier of the session to resolve.</param>
/// <returns>The registered session.</returns>
/// <exception cref="SessionManagerException">
/// <see cref="SessionManagerErrorCode.SessionNotFound"/> when no such session is registered.
/// </exception>
private GatewaySession GetRequiredSession(string sessionId)
{
    bool found = _registry.TryGet(sessionId, out GatewaySession? session);
    if (found && session is not null)
    {
        return session;
    }

    throw new SessionManagerException(
        SessionManagerErrorCode.SessionNotFound,
        $"Session {sessionId} was not found.");
}
/// <summary>
/// Reserves one session slot without blocking, or fails when every slot is in use.
/// </summary>
/// <exception cref="SessionManagerException">
/// <see cref="SessionManagerErrorCode.SessionLimitExceeded"/> when no slot is free.
/// </exception>
private void EnsureSessionCapacity()
{
    // Zero timeout: take a slot only if one is free right now.
    bool slotAcquired = _sessionSlots.Wait(0);
    if (slotAcquired)
    {
        return;
    }

    throw new SessionManagerException(
        SessionManagerErrorCode.SessionLimitExceeded,
        $"Gateway session limit {_options.Sessions.MaxSessions} has been reached.");
}
/// <summary>
/// Removes the session from the registry and, only when this call actually removed it,
/// drops its event metrics, returns its session slot, and disposes it.
/// </summary>
/// <param name="session">Session to remove.</param>
/// <returns>Task that completes after the removed session has been disposed.</returns>
private async Task RemoveSessionAsync(GatewaySession session)
{
    // A failed TryRemove means there is nothing left to clean up for this id, so the
    // slot release and dispose below run at most once per registry entry.
    if (_registry.TryRemove(session.SessionId, out GatewaySession? evicted))
    {
        _metrics.RemoveSessionEvents(session.SessionId);
        ReleaseSessionSlot();
        await evicted.DisposeAsync().ConfigureAwait(false);
    }
}
/// <summary>
/// Returns one session slot to the capacity semaphore. Never throws: an over-release is
/// logged and otherwise ignored so cleanup paths cannot fail because of slot accounting.
/// </summary>
private void ReleaseSessionSlot()
{
    try
    {
        _sessionSlots.Release();
    }
    catch (SemaphoreFullException exception)
    {
        // The semaphore is already at its maximum count, i.e. a slot was released more
        // often than it was acquired. That points at a bookkeeping bug elsewhere, so make
        // it visible instead of swallowing it silently.
        _logger.LogWarning(
            exception,
            "Session slot released while all {MaxSessions} slots were already free; ignoring the extra release.",
            _options.Sessions.MaxSessions);
    }
}
/// <summary>
/// Builds a new, not yet registered <see cref="GatewaySession"/> from the open request
/// and the gateway options: fresh session id, pipe name, nonce, timeouts, lease
/// duration, opened-at timestamp, and event streaming.
/// </summary>
/// <param name="request">Session open request.</param>
/// <param name="clientIdentity">Client authentication identity.</param>
/// <param name="ownerKeyId">API key identifier of the caller creating the session.</param>
/// <returns>The newly constructed session.</returns>
private GatewaySession CreateSession(
    SessionOpenRequest request,
    string? clientIdentity,
    string? ownerKeyId)
{
    string sessionId = CreateSessionId();

    // A blank or missing backend request selects the contract's default backend.
    string? requestedBackend = request.RequestedBackend;
    string backendName;
    if (string.IsNullOrWhiteSpace(requestedBackend))
    {
        backendName = GatewayContractInfo.DefaultBackendName;
    }
    else
    {
        backendName = requestedBackend;
    }

    TimeSpan commandTimeout = ResolveCommandTimeout(request.CommandTimeout);

    var workerOptions = _options.Worker;
    TimeSpan startupTimeout = TimeSpan.FromSeconds(workerOptions.StartupTimeoutSeconds);
    TimeSpan shutdownTimeout = TimeSpan.FromSeconds(workerOptions.ShutdownTimeoutSeconds);
    TimeSpan leaseDuration = TimeSpan.FromSeconds(_options.Sessions.DefaultLeaseSeconds);

    // The pipe name combines the gateway process id with the session id.
    string pipeName = $"mxaccess-gateway-{Environment.ProcessId}-{sessionId}";
    string nonce = CreateNonce();
    DateTimeOffset openedAt = _timeProvider.GetUtcNow();
    string clientCorrelationId = CreateClientCorrelationId(request.ClientSessionName, sessionId);

    var eventStreaming = new SessionEventStreaming(
        _eventMapper,
        _options.Events,
        _distributorLogger,
        _timeProvider,
        _metrics);

    return new GatewaySession(
        sessionId,
        backendName,
        pipeName,
        nonce,
        clientIdentity,
        ownerKeyId,
        request.ClientSessionName,
        clientCorrelationId,
        commandTimeout,
        startupTimeout,
        shutdownTimeout,
        leaseDuration,
        openedAt,
        eventStreaming);
}
/// <summary>
/// Builds the client correlation id as <c>{clientName}-{sessionId}</c>, substituting
/// <c>client</c> when the client session name is null, empty, or whitespace.
/// </summary>
/// <param name="clientSessionName">Optional client-supplied session name.</param>
/// <param name="sessionId">Gateway session identifier.</param>
/// <returns>The correlation id string.</returns>
private static string CreateClientCorrelationId(
    string? clientSessionName,
    string sessionId)
{
    if (string.IsNullOrWhiteSpace(clientSessionName))
    {
        return $"client-{sessionId}";
    }

    return $"{clientSessionName}-{sessionId}";
}
/// <summary>
/// Picks the command timeout for a session: the requested timeout when it is present
/// and strictly positive, otherwise <c>Sessions.DefaultCommandTimeoutSeconds</c> from
/// the gateway options.
/// </summary>
/// <param name="requestedTimeout">Timeout requested by the client, if any.</param>
/// <returns>The effective command timeout.</returns>
private TimeSpan ResolveCommandTimeout(Duration? requestedTimeout)
{
    if (requestedTimeout is not null)
    {
        TimeSpan requested = requestedTimeout.ToTimeSpan();
        if (requested > TimeSpan.Zero)
        {
            return requested;
        }
    }

    // Missing, zero, or negative request: use the configured default.
    return TimeSpan.FromSeconds(_options.Sessions.DefaultCommandTimeoutSeconds);
}
/// <summary>
/// Generates a session identifier: the prefix <c>session-</c> followed by a new GUID in
/// "N" format (32 hexadecimal digits, no hyphens).
/// </summary>
/// <returns>The new session identifier.</returns>
private static string CreateSessionId()
{
    string guidDigits = Guid.NewGuid().ToString("N");
    return "session-" + guidDigits;
}
/// <summary>
/// Generates a nonce from 32 cryptographically random bytes, encoded as Base64.
/// </summary>
/// <returns>The Base64-encoded nonce.</returns>
private static string CreateNonce()
{
    byte[] entropy = RandomNumberGenerator.GetBytes(32);
    return Convert.ToBase64String(entropy);
}
}