using Google.Protobuf.WellKnownTypes;
using Microsoft.Extensions.Options;
using MxGateway.Contracts.Proto;
using MxGateway.Server.Configuration;
using MxGateway.Server.Metrics;
using MxGateway.Server.Sessions;
using MxGateway.Server.Workers;

namespace MxGateway.Tests.Gateway.Sessions;

/// <summary>
/// Tests for <see cref="SessionManager"/>, driven through the fake worker clients and
/// worker-client factories declared at the bottom of this class.
/// </summary>
public sealed class SessionManagerTests
{
    /// <summary>
    /// Opening a session registers it, leaves it in <see cref="SessionState.Ready"/>, and bumps the
    /// open-session metrics. Also pins the sequence of states the factory observes while it runs.
    /// </summary>
    [Fact]
    public async Task OpenSessionAsync_WithWorkerReady_RegistersReadySession()
    {
        FakeWorkerClient workerClient = new();
        FakeSessionWorkerClientFactory factory = new(workerClient)
        {
            ApplyLifecycleTransitions = true,
        };
        using GatewayMetrics metrics = new();
        SessionManager manager = CreateManager(factory, metrics: metrics);

        GatewaySession session = await manager.OpenSessionAsync(CreateOpenRequest(), "client-1", CancellationToken.None);

        Assert.True(manager.TryGetSession(session.SessionId, out GatewaySession registered));
        Assert.Same(session, registered);
        Assert.Equal(SessionState.Ready, session.State);
        Assert.Equal("client-1", session.ClientIdentity);
        Assert.Equal(["StartingWorker", "WaitingForPipe", "Handshaking", "InitializingWorker"], factory.ObservedStates);
        Assert.Equal(1, metrics.GetSnapshot().OpenSessions);
        Assert.Equal(1, metrics.GetSnapshot().SessionsOpened);
    }

    /// <summary>
    /// A command invoked on a freshly opened session reaches the worker client exactly once and the
    /// worker's reply is returned to the caller.
    /// </summary>
    [Fact]
    public async Task InvokeAsync_WhenSessionReady_ForwardsCommandToWorker()
    {
        FakeWorkerClient workerClient = new();
        SessionManager manager = CreateManager(new FakeSessionWorkerClientFactory(workerClient));
        GatewaySession session = await manager.OpenSessionAsync(CreateOpenRequest(), "client-1", CancellationToken.None);

        WorkerCommandReply reply = await manager.InvokeAsync(
            session.SessionId,
            CreateCommand(MxCommandKind.Ping),
            CancellationToken.None);

        Assert.Equal(1, workerClient.InvokeCount);
        Assert.Equal(MxCommandKind.Ping, reply.Reply.Kind);
    }

    /// <summary>
    /// Once a session has been marked faulted, invoking a command throws
    /// <see cref="SessionManagerException"/> with <see cref="SessionManagerErrorCode.SessionNotReady"/>
    /// and the worker client is never called.
    /// </summary>
    [Fact]
    public async Task InvokeAsync_WhenSessionFaulted_RejectsCommand()
    {
        FakeWorkerClient workerClient = new();
        SessionManager manager = CreateManager(new FakeSessionWorkerClientFactory(workerClient));
        GatewaySession session = await manager.OpenSessionAsync(CreateOpenRequest(), "client-1", CancellationToken.None);
        session.MarkFaulted("test fault");

        SessionManagerException exception = await Assert.ThrowsAsync<SessionManagerException>(
            async () => await manager.InvokeAsync(
                session.SessionId,
                CreateCommand(MxCommandKind.Ping),
                CancellationToken.None));

        Assert.Equal(SessionManagerErrorCode.SessionNotReady, exception.ErrorCode);
        Assert.Equal(0, workerClient.InvokeCount);
    }

    /// <summary>
    /// Closing the same session twice succeeds both times: the second result reports
    /// <c>AlreadyClosed</c>, the worker is shut down only once, and the closed-session metric is
    /// counted only once.
    /// </summary>
    [Fact]
    public async Task CloseSessionAsync_WhenCalledTwice_IsIdempotent()
    {
        FakeWorkerClient workerClient = new();
        using GatewayMetrics metrics = new();
        SessionManager manager = CreateManager(new FakeSessionWorkerClientFactory(workerClient), metrics: metrics);
        GatewaySession session = await manager.OpenSessionAsync(CreateOpenRequest(), "client-1", CancellationToken.None);

        SessionCloseResult firstClose = await manager.CloseSessionAsync(session.SessionId, CancellationToken.None);
        SessionCloseResult secondClose = await manager.CloseSessionAsync(session.SessionId, CancellationToken.None);

        Assert.False(firstClose.AlreadyClosed);
        Assert.True(secondClose.AlreadyClosed);
        Assert.Equal(SessionState.Closed, firstClose.FinalState);
        Assert.Equal(SessionState.Closed, secondClose.FinalState);
        Assert.Equal(1, workerClient.ShutdownCount);
        Assert.Equal(1, metrics.GetSnapshot().SessionsClosed);
        Assert.Equal(0, metrics.GetSnapshot().OpenSessions);
    }

    /// <summary>
    /// When the worker-client factory throws, opening fails with
    /// <see cref="SessionManagerErrorCode.OpenFailed"/>, the registry ends up empty, no session is
    /// counted as opened, and one fault is recorded.
    /// </summary>
    [Fact]
    public async Task OpenSessionAsync_WhenWorkerCreationFails_RemovesSessionFromRegistry()
    {
        SessionRegistry registry = new();
        using GatewayMetrics metrics = new();
        SessionManager manager = CreateManager(
            new FailingSessionWorkerClientFactory(),
            registry,
            metrics);

        SessionManagerException exception = await Assert.ThrowsAsync<SessionManagerException>(
            async () => await manager.OpenSessionAsync(CreateOpenRequest(), "client-1", CancellationToken.None));

        Assert.Equal(SessionManagerErrorCode.OpenFailed, exception.ErrorCode);
        Assert.Equal(0, registry.Count);
        Assert.Equal(0, metrics.GetSnapshot().SessionsOpened);
        Assert.Equal(1, metrics.GetSnapshot().Faults);
    }

    /// <summary>
    /// Of two open sessions, only the one whose lease ends before the supplied instant is closed;
    /// the other stays ready and its worker is not shut down.
    /// </summary>
    [Fact]
    public async Task CloseExpiredLeasesAsync_ClosesExpiredSessionsOnly()
    {
        FakeWorkerClient expiredClient = new();
        FakeWorkerClient activeClient = new();
        QueueingSessionWorkerClientFactory factory = new(expiredClient, activeClient);
        SessionManager manager = CreateManager(factory);
        GatewaySession expiredSession = await manager.OpenSessionAsync(CreateOpenRequest(), "client-1", CancellationToken.None);
        GatewaySession activeSession = await manager.OpenSessionAsync(CreateOpenRequest(), "client-2", CancellationToken.None);

        // A single captured instant is used both to set the leases and as the sweep time, so the
        // outcome does not depend on how long the test takes to run.
        DateTimeOffset now = DateTimeOffset.UtcNow;
        expiredSession.ExtendLease(now.AddSeconds(-1));
        activeSession.ExtendLease(now.AddMinutes(5));

        int closedCount = await manager.CloseExpiredLeasesAsync(now, CancellationToken.None);

        Assert.Equal(1, closedCount);
        Assert.Equal(SessionState.Closed, expiredSession.State);
        Assert.Equal(SessionState.Ready, activeSession.State);
        Assert.Equal(1, expiredClient.ShutdownCount);
        Assert.Equal(0, activeClient.ShutdownCount);
    }

    /// <summary>
    /// Shutting the manager down closes every open session, shuts each worker down once, and
    /// leaves the open-session metric at zero.
    /// </summary>
    [Fact]
    public async Task ShutdownAsync_ClosesAllRegisteredSessions()
    {
        FakeWorkerClient firstClient = new();
        FakeWorkerClient secondClient = new();
        QueueingSessionWorkerClientFactory factory = new(firstClient, secondClient);
        using GatewayMetrics metrics = new();
        SessionManager manager = CreateManager(factory, metrics: metrics);
        GatewaySession firstSession = await manager.OpenSessionAsync(CreateOpenRequest(), "client-1", CancellationToken.None);
        GatewaySession secondSession = await manager.OpenSessionAsync(CreateOpenRequest(), "client-2", CancellationToken.None);

        await manager.ShutdownAsync(CancellationToken.None);

        Assert.Equal(SessionState.Closed, firstSession.State);
        Assert.Equal(SessionState.Closed, secondSession.State);
        Assert.Equal(1, firstClient.ShutdownCount);
        Assert.Equal(1, secondClient.ShutdownCount);
        Assert.Equal(2, metrics.GetSnapshot().SessionsClosed);
        Assert.Equal(0, metrics.GetSnapshot().OpenSessions);
    }

    /// <summary>
    /// Builds a <see cref="SessionManager"/> around the given factory, substituting a fresh
    /// registry, default options, and fresh metrics for any argument left null.
    /// </summary>
    /// <param name="factory">Factory the manager uses to obtain worker clients.</param>
    /// <param name="registry">Registry to use, or null for a new <see cref="SessionRegistry"/>.</param>
    /// <param name="metrics">Metrics to use, or null for a new <see cref="GatewayMetrics"/>.</param>
    /// <param name="options">Options to use, or null for <see cref="CreateOptions"/>.</param>
    /// <returns>The configured manager.</returns>
    private static SessionManager CreateManager(
        ISessionWorkerClientFactory factory,
        ISessionRegistry? registry = null,
        GatewayMetrics? metrics = null,
        GatewayOptions? options = null)
    {
        return new SessionManager(
            registry ?? new SessionRegistry(),
            factory,
            Options.Create(options ?? CreateOptions()),
            metrics ?? new GatewayMetrics());
    }

    /// <summary>Returns the gateway options shared by tests that do not supply their own.</summary>
    private static GatewayOptions CreateOptions()
    {
        return new GatewayOptions
        {
            Sessions = new SessionOptions
            {
                DefaultCommandTimeoutSeconds = 30,
                MaxSessions = 64,
            },
            Worker = new WorkerOptions
            {
                StartupTimeoutSeconds = 30,
                ShutdownTimeoutSeconds = 10,
            },
        };
    }

    /// <summary>Returns a session-open request with fixed names and a five-second command timeout.</summary>
    private static SessionOpenRequest CreateOpenRequest()
    {
        return new SessionOpenRequest(
            RequestedBackend: null,
            ClientSessionName: "test-session",
            ClientCorrelationId: "client-correlation-1",
            CommandTimeout: Duration.FromTimeSpan(TimeSpan.FromSeconds(5)));
    }

    /// <summary>Wraps an <see cref="MxCommand"/> of the given kind in a <see cref="WorkerCommand"/>.</summary>
    /// <param name="kind">Command kind to place on the inner command.</param>
    private static WorkerCommand CreateCommand(MxCommandKind kind)
    {
        return new WorkerCommand
        {
            Command = new MxCommand
            {
                Kind = kind,
            },
        };
    }

    /// <summary>
    /// Factory that always hands back the same worker client and records the session states it
    /// sees, optionally walking the session through its startup transitions first.
    /// </summary>
    private sealed class FakeSessionWorkerClientFactory(IWorkerClient workerClient) : ISessionWorkerClientFactory
    {
        /// <summary>Session state names captured during <see cref="CreateAsync"/>, in order.</summary>
        public List<string> ObservedStates { get; } = [];

        /// <summary>
        /// When true, <see cref="CreateAsync"/> moves the session through WaitingForPipe,
        /// Handshaking and InitializingWorker, recording the state after each transition.
        /// </summary>
        public bool ApplyLifecycleTransitions { get; init; }

        /// <summary>Records the session's state on entry and returns the configured worker client.</summary>
        public Task<IWorkerClient> CreateAsync(
            GatewaySession session,
            CancellationToken cancellationToken)
        {
            ObservedStates.Add(session.State.ToString());
            if (ApplyLifecycleTransitions)
            {
                session.TransitionTo(SessionState.WaitingForPipe);
                ObservedStates.Add(session.State.ToString());
                session.TransitionTo(SessionState.Handshaking);
                ObservedStates.Add(session.State.ToString());
                session.TransitionTo(SessionState.InitializingWorker);
                ObservedStates.Add(session.State.ToString());
            }

            return Task.FromResult(workerClient);
        }
    }

    /// <summary>
    /// Factory that hands out the supplied worker clients one per call, in the order given.
    /// Calling it more times than there are clients throws from <c>Queue.Dequeue</c>.
    /// </summary>
    private sealed class QueueingSessionWorkerClientFactory : ISessionWorkerClientFactory
    {
        private readonly Queue<IWorkerClient> _workerClients;

        public QueueingSessionWorkerClientFactory(params IWorkerClient[] workerClients)
        {
            _workerClients = new Queue<IWorkerClient>(workerClients);
        }

        /// <summary>Returns the next queued worker client.</summary>
        public Task<IWorkerClient> CreateAsync(
            GatewaySession session,
            CancellationToken cancellationToken)
        {
            return Task.FromResult(_workerClients.Dequeue());
        }
    }

    /// <summary>Factory whose <see cref="CreateAsync"/> always throws, simulating a worker that fails to start.</summary>
    private sealed class FailingSessionWorkerClientFactory : ISessionWorkerClientFactory
    {
        /// <summary>Throws <see cref="InvalidOperationException"/> synchronously instead of returning a task.</summary>
        public Task<IWorkerClient> CreateAsync(
            GatewaySession session,
            CancellationToken cancellationToken)
        {
            throw new InvalidOperationException("worker startup failed");
        }
    }

    /// <summary>
    /// In-memory <see cref="IWorkerClient"/> that counts invoke, shutdown and kill calls and
    /// echoes the kind of each command back in its reply.
    /// </summary>
    private sealed class FakeWorkerClient : IWorkerClient
    {
        public string SessionId { get; init; } = "session-1";

        public int? ProcessId { get; init; } = 1234;

        public WorkerClientState State { get; set; } = WorkerClientState.Ready;

        public DateTimeOffset LastHeartbeatAt { get; init; } = DateTimeOffset.UtcNow;

        /// <summary>Number of times <see cref="InvokeAsync"/> has been called.</summary>
        public int InvokeCount { get; private set; }

        /// <summary>Number of times <see cref="ShutdownAsync"/> has been called.</summary>
        public int ShutdownCount { get; private set; }

        /// <summary>Number of times <see cref="Kill"/> has been called.</summary>
        public int KillCount { get; private set; }

        /// <summary>No-op start; completes immediately.</summary>
        public Task StartAsync(CancellationToken cancellationToken)
        {
            return Task.CompletedTask;
        }

        /// <summary>
        /// Counts the call and returns a reply carrying the command's kind, or
        /// <see cref="MxCommandKind.Unspecified"/> when the command has no inner command.
        /// </summary>
        public Task<WorkerCommandReply> InvokeAsync(
            WorkerCommand command,
            TimeSpan timeout,
            CancellationToken cancellationToken)
        {
            InvokeCount++;
            MxCommandKind kind = command.Command?.Kind ?? MxCommandKind.Unspecified;

            return Task.FromResult(new WorkerCommandReply
            {
                Reply = new MxCommandReply
                {
                    SessionId = SessionId,
                    CorrelationId = "correlation-1",
                    Kind = kind,
                },
            });
        }

        /// <summary>Yields no events; the stream completes immediately.</summary>
        // NOTE(review): the element type was missing from the reviewed copy of this file and
        // cannot be derived from anything visible here. WorkerEvent is an assumption; confirm it
        // against the IWorkerClient.ReadEventsAsync declaration.
        public async IAsyncEnumerable<WorkerEvent> ReadEventsAsync(
            [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken)
        {
            // Awaiting a completed task keeps the iterator async without producing any items.
            await Task.CompletedTask;
            yield break;
        }

        /// <summary>Counts the call and moves the fake to <see cref="WorkerClientState.Closed"/>.</summary>
        public Task ShutdownAsync(
            TimeSpan timeout,
            CancellationToken cancellationToken)
        {
            ShutdownCount++;
            State = WorkerClientState.Closed;
            return Task.CompletedTask;
        }

        /// <summary>Counts the call and moves the fake to <see cref="WorkerClientState.Faulted"/>.</summary>
        public void Kill(string reason)
        {
            KillCount++;
            State = WorkerClientState.Faulted;
        }

        /// <summary>No-op dispose; there is nothing to release.</summary>
        public ValueTask DisposeAsync()
        {
            return ValueTask.CompletedTask;
        }
    }
}