diff --git a/clients/dotnet/MxGateway.Client.Cli/MxGatewayClientCli.cs b/clients/dotnet/MxGateway.Client.Cli/MxGatewayClientCli.cs index 4460b3d..2d67fef 100644 --- a/clients/dotnet/MxGateway.Client.Cli/MxGatewayClientCli.cs +++ b/clients/dotnet/MxGateway.Client.Cli/MxGatewayClientCli.cs @@ -86,6 +86,10 @@ public static class MxGatewayClientCli .ConfigureAwait(false), "advise" => await AdviseAsync(arguments, client, standardOutput, cancellation.Token) .ConfigureAwait(false), + "subscribe-bulk" => await SubscribeBulkAsync(arguments, client, standardOutput, cancellation.Token) + .ConfigureAwait(false), + "unsubscribe-bulk" => await UnsubscribeBulkAsync(arguments, client, standardOutput, cancellation.Token) + .ConfigureAwait(false), "stream-events" => await StreamEventsAsync(arguments, client, standardOutput, cancellation.Token) .ConfigureAwait(false), "write" => await WriteAsync(arguments, client, standardOutput, cancellation.Token) @@ -289,6 +293,54 @@ public static class MxGatewayClientCli cancellationToken); } + private static Task SubscribeBulkAsync( + CliArguments arguments, + IMxGatewayCliClient client, + TextWriter output, + CancellationToken cancellationToken) + { + SubscribeBulkCommand command = new() + { + ServerHandle = arguments.GetInt32("server-handle"), + }; + command.TagAddresses.Add(ParseStringList(arguments.GetRequired("items"))); + + return InvokeAndWriteAsync( + arguments, + client, + output, + new MxCommand + { + Kind = MxCommandKind.SubscribeBulk, + SubscribeBulk = command, + }, + cancellationToken); + } + + private static Task UnsubscribeBulkAsync( + CliArguments arguments, + IMxGatewayCliClient client, + TextWriter output, + CancellationToken cancellationToken) + { + UnsubscribeBulkCommand command = new() + { + ServerHandle = arguments.GetInt32("server-handle"), + }; + command.ItemHandles.Add(ParseInt32List(arguments.GetRequired("item-handles"))); + + return InvokeAndWriteAsync( + arguments, + client, + output, + new MxCommand + { + Kind = MxCommandKind.UnsubscribeBulk, + UnsubscribeBulk = command, + }, + cancellationToken); + } + private static Task WriteAsync( CliArguments arguments, IMxGatewayCliClient client, @@ -736,12 +788,40 @@ public static class MxGatewayClientCli or "register" or "add-item" or "advise" + or "subscribe-bulk" + or "unsubscribe-bulk" or "stream-events" or "write" or "write2" or "smoke"; } + private static IReadOnlyList ParseStringList(string value) + { + string[] items = value + .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + if (items.Length is 0) + { + throw new ArgumentException("At least one item is required."); + } + + return items; + } + + private static IReadOnlyList ParseInt32List(string value) + { + string[] items = value + .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + if (items.Length is 0) + { + throw new ArgumentException("At least one item handle is required."); + } + + return items + .Select(item => int.Parse(item, CultureInfo.InvariantCulture)) + .ToArray(); + } + private static string CreateCorrelationId() { return Guid.NewGuid().ToString("N"); @@ -756,6 +836,8 @@ public static class MxGatewayClientCli writer.WriteLine("mxgw-dotnet register --session-id --client-name [--json]"); writer.WriteLine("mxgw-dotnet add-item --session-id --server-handle --item [--json]"); writer.WriteLine("mxgw-dotnet advise --session-id --server-handle --item-handle [--json]"); + writer.WriteLine("mxgw-dotnet subscribe-bulk --session-id --server-handle --items [--json]"); + writer.WriteLine("mxgw-dotnet unsubscribe-bulk --session-id --server-handle --item-handles [--json]"); writer.WriteLine("mxgw-dotnet stream-events --session-id [--max-events ] [--json]"); writer.WriteLine("mxgw-dotnet write --session-id --server-handle --item-handle --type --value [--json]"); writer.WriteLine("mxgw-dotnet write2 --session-id --server-handle --item-handle --type --value [--timestamp ] [--json]"); diff --git a/clients/go/README.md b/clients/go/README.md index 4f73d97..47ca554 100644 --- a/clients/go/README.md +++ b/clients/go/README.md @@ -76,10 +76,13 @@ client, err := mxgateway.Dial(ctx, mxgateway.Options{ ``` `Client.OpenSession` returns a `Session` with helpers for `Register`, -`AddItem`, `AddItem2`, `Advise`, `Write`, `Events`, and `Close`. Raw protobuf -messages remain available through the `mxgateway` package aliases and the -`Raw` helper methods. Typed errors support `errors.As` for `GatewayError`, -`CommandError`, and `MxAccessError`; command errors preserve the raw reply. +`AddItem`, `AddItem2`, `Advise`, `Write`, `Events`, and `Close`. Prefer +`SubscribeEvents` or `SubscribeEventsAfter` for long-running streams because the +returned subscription owns cancellation and exposes `Close` for deterministic +goroutine cleanup. Raw protobuf messages remain available through the +`mxgateway` package aliases and the `Raw` helper methods. Typed errors support +`errors.As` for `GatewayError`, `CommandError`, and `MxAccessError`; command +errors preserve the raw reply. ## CLI diff --git a/clients/go/cmd/mxgw-go/main.go b/clients/go/cmd/mxgw-go/main.go index b74346c..b8555c0 100644 --- a/clients/go/cmd/mxgw-go/main.go +++ b/clients/go/cmd/mxgw-go/main.go @@ -9,6 +9,7 @@ import ( "io" "os" "strconv" + "strings" "time" "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/mxgateway" @@ -77,6 +78,10 @@ func runWithIO(ctx context.Context, args []string, stdout, stderr io.Writer) err return runAddItem(ctx, args[1:], stdout, stderr) case "advise": return runAdvise(ctx, args[1:], stdout, stderr) + case "subscribe-bulk": + return runSubscribeBulk(ctx, args[1:], stdout, stderr) + case "unsubscribe-bulk": + return runUnsubscribeBulk(ctx, args[1:], stdout, stderr) case "write": return runWrite(ctx, args[1:], stdout, stderr) case "stream-events": @@ -268,6 +273,60 @@ func runAdvise(ctx context.Context, args []string, stdout, stderr io.Writer) err return writeCommandOutput(stdout, *jsonOutput, "advise", options, reply, err) } +func runSubscribeBulk(ctx context.Context, args []string, stdout, stderr io.Writer) error { + flags := flag.NewFlagSet("subscribe-bulk", flag.ContinueOnError) + flags.SetOutput(stderr) + common := bindCommonFlags(flags) + jsonOutput := flags.Bool("json", false, "write JSON output") + sessionID := flags.String("session-id", "", "gateway session id") + serverHandle := flags.Int("server-handle", 0, "MXAccess server handle") + items := flags.String("items", "", "comma-separated item definitions") + + if err := flags.Parse(args); err != nil { + return err + } + if *sessionID == "" || *items == "" { + return errors.New("session-id and items are required") + } + + client, options, err := dialForCommand(ctx, common) + if err != nil { + return err + } + defer client.Close() + + session := mxgateway.NewSessionForID(client, *sessionID) + results, err := session.SubscribeBulk(ctx, int32(*serverHandle), parseStringList(*items)) + return writeBulkOutput(stdout, *jsonOutput, "subscribe-bulk", options, results, err) +} + +func runUnsubscribeBulk(ctx context.Context, args []string, stdout, stderr io.Writer) error { + flags := flag.NewFlagSet("unsubscribe-bulk", flag.ContinueOnError) + flags.SetOutput(stderr) + common := bindCommonFlags(flags) + jsonOutput := flags.Bool("json", false, "write JSON output") + sessionID := flags.String("session-id", "", "gateway session id") + serverHandle := flags.Int("server-handle", 0, "MXAccess server handle") + itemHandles := flags.String("item-handles", "", "comma-separated item handles") + + if err := flags.Parse(args); err != nil { + return err + } + if *sessionID == "" || *itemHandles == "" { + return errors.New("session-id and item-handles are required") + } + + client, options, err := dialForCommand(ctx, common) + if err != nil { + return err + } + defer client.Close() + + session := mxgateway.NewSessionForID(client, *sessionID) + results, err := session.UnsubscribeBulk(ctx, int32(*serverHandle), parseInt32List(*itemHandles)) + return writeBulkOutput(stdout, *jsonOutput, "unsubscribe-bulk", options, results, err) +} + func runWrite(ctx context.Context, args []string, stdout, stderr io.Writer) error { flags := flag.NewFlagSet("write", flag.ContinueOnError) flags.SetOutput(stderr) @@ -328,10 +387,12 @@ func runStreamEvents(ctx context.Context, args []string, stdout, stderr io.Write session := mxgateway.NewSessionForID(client, *sessionID) streamCtx, cancelStream := context.WithCancel(ctx) defer cancelStream() - events, err := session.EventsAfter(streamCtx, *after) + subscription, err := session.SubscribeEventsAfter(streamCtx, *after) if err != nil { return err } + defer subscription.Close() + events := subscription.Events() count := 0 for result := range events { @@ -426,6 +487,35 @@ func closeSmokeSession(ctx context.Context, session *mxgateway.Session, primaryE return closeErr } +func parseStringList(value string) []string { + parts := strings.Split(value, ",") + items := make([]string, 0, len(parts)) + for _, part := range parts { + item := strings.TrimSpace(part) + if item != "" { + items = append(items, item) + } + } + return items +} + +func parseInt32List(value string) []int32 { + parts := strings.Split(value, ",") + items := make([]int32, 0, len(parts)) + for _, part := range parts { + item := strings.TrimSpace(part) + if item == "" { + continue + } + parsed, err := strconv.ParseInt(item, 10, 32) + if err != nil { + panic(err) + } + items = append(items, int32(parsed)) + } + return items +} + func bindCommonFlags(flags *flag.FlagSet) *commonOptions { common := &commonOptions{} flags.StringVar(&common.Endpoint, "endpoint", "localhost:5000", "gateway endpoint") @@ -527,6 +617,21 @@ func writeCommandOutput(stdout io.Writer, jsonOutput bool, command string, optio return nil } +func writeBulkOutput(stdout io.Writer, jsonOutput bool, command string, options commonOptions, results []*mxgateway.SubscribeResult, err error) error { + if err != nil { + return err + } + if jsonOutput { + return writeJSON(stdout, map[string]any{ + "command": command, + "options": options, + "results": results, + }) + } + fmt.Fprintln(stdout, len(results)) + return nil +} + func writeJSON(writer io.Writer, value any) error { encoder := json.NewEncoder(writer) encoder.SetIndent("", " ") @@ -546,5 +651,5 @@ type protojsonMessage interface { } func writeUsage(writer io.Writer) { - fmt.Fprintln(writer, "usage: mxgw-go ") + fmt.Fprintln(writer, "usage: mxgw-go ") } diff --git a/clients/go/mxgateway/client_session_test.go b/clients/go/mxgateway/client_session_test.go index 48f5290..46ffb5f 100644 --- a/clients/go/mxgateway/client_session_test.go +++ b/clients/go/mxgateway/client_session_test.go @@ -77,6 +77,42 @@ func TestStreamEventsAttachesAuthMetadataAndClosesOnCancellation(t *testing.T) { } } +func TestEventSubscriptionCloseStopsStream(t *testing.T) { + fake := &fakeGatewayServer{ + streamStarted: make(chan struct{}), + streamDone: make(chan struct{}), + } + client, cleanup := newBufconnClient(t, fake) + defer cleanup() + session := NewSessionForID(client, "session-1") + + subscription, err := session.SubscribeEvents(context.Background()) + if err != nil { + t.Fatalf("SubscribeEvents() error = %v", err) + } + <-fake.streamStarted + first := <-subscription.Events() + if first.Err != nil { + t.Fatalf("first event error = %v", first.Err) + } + + subscription.Close() + + select { + case <-fake.streamDone: + case <-time.After(2 * time.Second): + t.Fatal("event stream did not stop after subscription close") + } + select { + case _, ok := <-subscription.Events(): + if ok { + t.Fatal("subscription channel remained open after close") + } + case <-time.After(2 * time.Second): + t.Fatal("subscription channel did not close") + } +} + func TestSessionHelpersBuildCommandsAndExposeRawReply(t *testing.T) { fake := &fakeGatewayServer{ invokeReply: &pb.MxCommandReply{ @@ -235,6 +271,7 @@ type fakeGatewayServer struct { openAuth string streamAuth string streamStarted chan struct{} + streamDone chan struct{} invokeReply *pb.MxCommandReply invokeRequest *pb.MxCommandRequest } @@ -277,6 +314,9 @@ func (s *fakeGatewayServer) Invoke(ctx context.Context, req *pb.MxCommandRequest func (s *fakeGatewayServer) StreamEvents(req *pb.StreamEventsRequest, stream grpc.ServerStreamingServer[pb.MxEvent]) error { s.streamAuth = authorizationFromContext(stream.Context()) + if s.streamDone != nil { + defer close(s.streamDone) + } if s.streamStarted != nil { close(s.streamStarted) } diff --git a/clients/go/mxgateway/session.go b/clients/go/mxgateway/session.go index 3f1fc16..81be344 100644 --- a/clients/go/mxgateway/session.go +++ b/clients/go/mxgateway/session.go @@ -22,6 +22,30 @@ type EventResult struct { Err error } +// EventSubscription owns a running gateway event stream. +type EventSubscription struct { + results <-chan EventResult + cancel context.CancelFunc + done <-chan struct{} + once sync.Once +} + +// Events returns the stream results channel. +func (s *EventSubscription) Events() <-chan EventResult { + return s.results +} + +// Close cancels the stream and waits for the receive goroutine to stop. +func (s *EventSubscription) Close() { + if s == nil { + return + } + s.once.Do(func() { + s.cancel() + <-s.done + }) +} + // Session represents one gateway-backed MXAccess session. type Session struct { client *Client @@ -394,34 +418,56 @@ func (s *Session) Events(ctx context.Context) (<-chan EventResult, error) { // EventsAfter streams ordered session events after the given worker sequence. func (s *Session) EventsAfter(ctx context.Context, afterWorkerSequence uint64) (<-chan EventResult, error) { - stream, err := s.client.StreamEventsRaw(ctx, &pb.StreamEventsRequest{ + subscription, err := s.SubscribeEventsAfter(ctx, afterWorkerSequence) + if err != nil { + return nil, err + } + return subscription.Events(), nil +} + +// SubscribeEvents starts an owned event subscription. +func (s *Session) SubscribeEvents(ctx context.Context) (*EventSubscription, error) { + return s.SubscribeEventsAfter(ctx, 0) +} + +// SubscribeEventsAfter starts an owned event subscription after the given worker sequence. +func (s *Session) SubscribeEventsAfter(ctx context.Context, afterWorkerSequence uint64) (*EventSubscription, error) { + streamCtx, cancel := context.WithCancel(ctx) + stream, err := s.client.StreamEventsRaw(streamCtx, &pb.StreamEventsRequest{ SessionId: s.ID(), AfterWorkerSequence: afterWorkerSequence, }) if err != nil { + cancel() return nil, err } results := make(chan EventResult, 16) + done := make(chan struct{}) go func() { defer close(results) + defer close(done) for { event, err := stream.Recv() if err == nil { - if !sendEventResult(ctx, results, EventResult{Event: event}) { + if !sendEventResult(streamCtx, results, EventResult{Event: event}) { return } continue } - if err == io.EOF || status.Code(err) == codes.Canceled || ctx.Err() != nil { + if err == io.EOF || status.Code(err) == codes.Canceled || streamCtx.Err() != nil { return } - sendEventResult(ctx, results, EventResult{Err: &GatewayError{Op: "stream events", Err: err}}) + sendEventResult(streamCtx, results, EventResult{Err: &GatewayError{Op: "stream events", Err: err}}) return } }() - return results, nil + return &EventSubscription{ + results: results, + cancel: cancel, + done: done, + }, nil } func ensureBulkSize(name string, length int) error { diff --git a/clients/java/mxgateway-cli/src/main/java/com/dohertylan/mxgateway/cli/MxGatewayCli.java b/clients/java/mxgateway-cli/src/main/java/com/dohertylan/mxgateway/cli/MxGatewayCli.java index bf4a918..868760b 100644 --- a/clients/java/mxgateway-cli/src/main/java/com/dohertylan/mxgateway/cli/MxGatewayCli.java +++ b/clients/java/mxgateway-cli/src/main/java/com/dohertylan/mxgateway/cli/MxGatewayCli.java @@ -12,7 +12,9 @@ import com.google.protobuf.util.JsonFormat; import java.io.PrintWriter; import java.nio.file.Path; import java.time.Duration; +import java.util.Arrays; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.Callable; import mxaccess_gateway.v1.MxaccessGateway.CloseSessionRequest; @@ -20,6 +22,7 @@ import mxaccess_gateway.v1.MxaccessGateway.MxCommandReply; import mxaccess_gateway.v1.MxaccessGateway.MxEvent; import mxaccess_gateway.v1.MxaccessGateway.MxValue; import mxaccess_gateway.v1.MxaccessGateway.OpenSessionRequest; +import mxaccess_gateway.v1.MxaccessGateway.SubscribeResult; import picocli.CommandLine; import picocli.CommandLine.Command; import picocli.CommandLine.Mixin; @@ -75,6 +78,8 @@ public final class MxGatewayCli implements Callable { commandLine.addSubcommand("register", new RegisterCommand(clientFactory)); commandLine.addSubcommand("add-item", new AddItemCommand(clientFactory)); commandLine.addSubcommand("advise", new AdviseCommand(clientFactory)); + commandLine.addSubcommand("subscribe-bulk", new SubscribeBulkCommand(clientFactory)); + commandLine.addSubcommand("unsubscribe-bulk", new UnsubscribeBulkCommand(clientFactory)); commandLine.addSubcommand("write", new WriteCommand(clientFactory)); commandLine.addSubcommand("stream-events", new StreamEventsCommand(clientFactory)); commandLine.addSubcommand("smoke", new SmokeCommand(clientFactory)); @@ -243,6 +248,58 @@ public final class MxGatewayCli implements Callable { } } + @Command(name = "subscribe-bulk", description = "Invokes MXAccess SubscribeBulk.") + static final class SubscribeBulkCommand extends GatewayCommand { + @Option(names = "--session-id", required = true, description = "Gateway session id.") + String sessionId; + + @Option(names = "--server-handle", required = true, description = "MXAccess server handle.") + int serverHandle; + + @Option(names = "--items", required = true, description = "Comma-separated item definitions.") + String items; + + SubscribeBulkCommand(MxGatewayCliClientFactory clientFactory) { + super(clientFactory); + } + + @Override + public Integer call() { + try (MxGatewayCliClient client = clientFactory.connect(common.resolved())) { + List results = + client.session(sessionId).subscribeBulk(serverHandle, parseStringList(items)); + writeBulkOutput("subscribe-bulk", common, json, results); + } + return 0; + } + } + + @Command(name = "unsubscribe-bulk", description = "Invokes MXAccess UnsubscribeBulk.") + static final class UnsubscribeBulkCommand extends GatewayCommand { + @Option(names = "--session-id", required = true, description = "Gateway session id.") + String sessionId; + + @Option(names = "--server-handle", required = true, description = "MXAccess server handle.") + int serverHandle; + + @Option(names = "--item-handles", required = true, description = "Comma-separated item handles.") + String itemHandles; + + UnsubscribeBulkCommand(MxGatewayCliClientFactory clientFactory) { + super(clientFactory); + } + + @Override + public Integer call() { + try (MxGatewayCliClient client = clientFactory.connect(common.resolved())) { + List results = + client.session(sessionId).unsubscribeBulk(serverHandle, parseIntList(itemHandles)); + writeBulkOutput("unsubscribe-bulk", common, json, results); + } + return 0; + } + } + @Command(name = "write", description = "Invokes MXAccess Write.") static final class WriteCommand extends GatewayCommand { @Option(names = "--session-id", required = true, description = "Gateway session id.") @@ -454,6 +511,10 @@ public final class MxGatewayCli implements Callable { MxCommandReply writeRaw(int serverHandle, int itemHandle, MxValue value, int userId); + List subscribeBulk(int serverHandle, List items); + + List unsubscribeBulk(int serverHandle, List itemHandles); + MxEventStream streamEventsAfter(long afterWorkerSequence); } @@ -535,6 +596,16 @@ public final class MxGatewayCli implements Callable { return session.writeRaw(serverHandle, itemHandle, value, userId); } + @Override + public List subscribeBulk(int serverHandle, List items) { + return session.subscribeBulk(serverHandle, items); + } + + @Override + public List unsubscribeBulk(int serverHandle, List itemHandles) { + return session.unsubscribeBulk(serverHandle, itemHandles); + } + @Override public MxEventStream streamEventsAfter(long afterWorkerSequence) { return session.streamEventsAfter(afterWorkerSequence); @@ -559,6 +630,30 @@ public final class MxGatewayCli implements Callable { out.println(textSupplier.get()); } + private static void writeBulkOutput( + String command, CommonOptions common, boolean json, List results) { + PrintWriter out = common.spec.commandLine().getOut(); + if (json) { + Map output = new LinkedHashMap<>(); + output.put("command", command); + output.put("options", common.redactedJsonMap()); + output.put("results", results.stream().map(MxGatewayCli::subscribeResultMap).toList()); + out.println(jsonObject(output)); + return; + } + out.println(results.size()); + } + + private static Map subscribeResultMap(SubscribeResult result) { + Map values = new LinkedHashMap<>(); + values.put("serverHandle", result.getServerHandle()); + values.put("tagAddress", result.getTagAddress()); + values.put("itemHandle", result.getItemHandle()); + values.put("wasSuccessful", result.getWasSuccessful()); + values.put("errorMessage", result.getErrorMessage()); + return values; + } + private static MxValue parseValue(String type, String text) { return switch (type) { case "bool" -> MxValues.boolValue(Boolean.parseBoolean(text)); @@ -571,6 +666,17 @@ public final class MxGatewayCli implements Callable { }; } + private static List parseStringList(String value) { + return Arrays.stream(value.split(",")) + .map(String::trim) + .filter(item -> !item.isBlank()) + .toList(); + } + + private static List parseIntList(String value) { + return parseStringList(value).stream().map(Integer::parseInt).toList(); + } + private static Duration parseDuration(String value) { if (value == null || value.isBlank()) { return Duration.ofSeconds(30); @@ -630,6 +736,20 @@ public final class MxGatewayCli implements Callable { if (value instanceof Map map) { return jsonObject((Map) map); } + if (value instanceof Iterable iterable) { + StringBuilder builder = new StringBuilder(); + builder.append('['); + boolean first = true; + for (Object item : iterable) { + if (!first) { + builder.append(','); + } + first = false; + builder.append(jsonValue(item)); + } + builder.append(']'); + return builder.toString(); + } return jsonString(value.toString()); } diff --git a/clients/java/mxgateway-cli/src/test/java/com/dohertylan/mxgateway/cli/MxGatewayCliTests.java b/clients/java/mxgateway-cli/src/test/java/com/dohertylan/mxgateway/cli/MxGatewayCliTests.java index 6de9073..6f047fb 100644 --- a/clients/java/mxgateway-cli/src/test/java/com/dohertylan/mxgateway/cli/MxGatewayCliTests.java +++ b/clients/java/mxgateway-cli/src/test/java/com/dohertylan/mxgateway/cli/MxGatewayCliTests.java @@ -6,6 +6,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.io.PrintWriter; import java.io.StringWriter; +import java.util.ArrayList; +import java.util.List; import mxaccess_gateway.v1.MxaccessGateway.AddItemReply; import mxaccess_gateway.v1.MxaccessGateway.CloseSessionReply; import mxaccess_gateway.v1.MxaccessGateway.CloseSessionRequest; @@ -19,6 +21,7 @@ import mxaccess_gateway.v1.MxaccessGateway.ProtocolStatus; import mxaccess_gateway.v1.MxaccessGateway.ProtocolStatusCode; import mxaccess_gateway.v1.MxaccessGateway.RegisterReply; import mxaccess_gateway.v1.MxaccessGateway.SessionState; +import mxaccess_gateway.v1.MxaccessGateway.SubscribeResult; import org.junit.jupiter.api.Test; final class MxGatewayCliTests { @@ -100,6 +103,44 @@ final class MxGatewayCliTests { assertTrue(run.output().contains("\"itemHandle\":7")); } + @Test + void subscribeBulkCommandPrintsResults() { + CliRun run = execute( + new FakeClientFactory(), + "subscribe-bulk", + "--session-id", + "session-cli", + "--server-handle", + "42", + "--items", + "TestMachine_001.TestChangingInt,TestMachine_002.TestChangingInt", + "--json"); + + assertEquals(0, run.exitCode()); + assertTrue(run.output().contains("\"command\":\"subscribe-bulk\"")); + assertTrue(run.output().contains("\"itemHandle\":100")); + assertTrue(run.output().contains("\"tagAddress\":\"TestMachine_002.TestChangingInt\"")); + } + + @Test + void unsubscribeBulkCommandPrintsResults() { + CliRun run = execute( + new FakeClientFactory(), + "unsubscribe-bulk", + "--session-id", + "session-cli", + "--server-handle", + "42", + "--item-handles", + "100,101", + "--json"); + + assertEquals(0, run.exitCode()); + assertTrue(run.output().contains("\"command\":\"unsubscribe-bulk\"")); + assertTrue(run.output().contains("\"itemHandle\":101")); + assertTrue(run.output().contains("\"wasSuccessful\":true")); + } + private static CliRun execute(MxGatewayCli.MxGatewayCliClientFactory factory, String... args) { StringWriter output = new StringWriter(); StringWriter errors = new StringWriter(); @@ -227,6 +268,33 @@ final class MxGatewayCliTests { .build(); } + @Override + public List subscribeBulk(int serverHandle, List items) { + List results = new ArrayList<>(); + for (int index = 0; index < items.size(); index++) { + results.add(SubscribeResult.newBuilder() + .setServerHandle(serverHandle) + .setTagAddress(items.get(index)) + .setItemHandle(100 + index) + .setWasSuccessful(true) + .build()); + } + return results; + } + + @Override + public List unsubscribeBulk(int serverHandle, List itemHandles) { + List results = new ArrayList<>(); + for (Integer itemHandle : itemHandles) { + results.add(SubscribeResult.newBuilder() + .setServerHandle(serverHandle) + .setItemHandle(itemHandle) + .setWasSuccessful(true) + .build()); + } + return results; + } + @Override public com.dohertylan.mxgateway.client.MxEventStream streamEventsAfter(long afterWorkerSequence) { throw new UnsupportedOperationException("stream-events is covered by client tests"); diff --git a/clients/python/src/mxgateway_cli/commands.py b/clients/python/src/mxgateway_cli/commands.py index 10660c3..428e9d6 100644 --- a/clients/python/src/mxgateway_cli/commands.py +++ b/clients/python/src/mxgateway_cli/commands.py @@ -150,6 +150,40 @@ def advise(**kwargs: Any) -> None: _run(_advise(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs)) +@main.command("subscribe-bulk") +@gateway_options +@click.option("--session-id", required=True, help="Gateway session id.") +@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.") +@click.option("--items", required=True, help="Comma-separated MXAccess item definitions.") +@click.option("--correlation-id", default="", help="Client correlation id.") +@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.") +def subscribe_bulk(**kwargs: Any) -> None: + """Invoke MXAccess SubscribeBulk.""" + + _run( + _subscribe_bulk(**kwargs), + output_json=kwargs["output_json"], + secrets=_secrets(kwargs), + ) + + +@main.command("unsubscribe-bulk") +@gateway_options +@click.option("--session-id", required=True, help="Gateway session id.") +@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.") +@click.option("--item-handles", required=True, help="Comma-separated MXAccess item handles.") +@click.option("--correlation-id", default="", help="Client correlation id.") +@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.") +def unsubscribe_bulk(**kwargs: Any) -> None: + """Invoke MXAccess UnsubscribeBulk.""" + + _run( + _unsubscribe_bulk(**kwargs), + output_json=kwargs["output_json"], + secrets=_secrets(kwargs), + ) + + @main.command("stream-events") @gateway_options @click.option("--session-id", required=True, help="Gateway session id.") @@ -282,6 +316,28 @@ async def _advise(**kwargs: Any) -> dict[str, Any]: return {"ok": True} +async def _subscribe_bulk(**kwargs: Any) -> dict[str, Any]: + async with await _connect(kwargs) as client: + session = _session(client, kwargs["session_id"]) + results = await session.subscribe_bulk( + kwargs["server_handle"], + _parse_string_list(kwargs["items"]), + correlation_id=kwargs["correlation_id"], + ) + return {"results": [_message_dict(result) for result in results]} + + +async def _unsubscribe_bulk(**kwargs: Any) -> dict[str, Any]: + async with await _connect(kwargs) as client: + session = _session(client, kwargs["session_id"]) + results = await session.unsubscribe_bulk( + kwargs["server_handle"], + _parse_int_list(kwargs["item_handles"]), + correlation_id=kwargs["correlation_id"], + ) + return {"results": [_message_dict(result) for result in results]} + + async def _stream_events(**kwargs: Any) -> dict[str, Any]: async with await _connect(kwargs) as client: session = _session(client, kwargs["session_id"]) @@ -470,6 +526,20 @@ def _parse_datetime(raw_value: str) -> datetime: return parsed +def _parse_string_list(raw_value: str) -> list[str]: + values = [item.strip() for item in raw_value.split(",") if item.strip()] + if not values: + raise click.BadParameter("at least one item is required", param_hint="--items") + return values + + +def _parse_int_list(raw_value: str) -> list[int]: + values = [item.strip() for item in raw_value.split(",") if item.strip()] + if not values: + raise click.BadParameter("at least one item handle is required", param_hint="--item-handles") + return [int(item) for item in values] + + def _message_dict(message: Any) -> dict[str, Any]: return MessageToDict( message, diff --git a/clients/rust/crates/mxgw-cli/src/main.rs b/clients/rust/crates/mxgw-cli/src/main.rs index fd06350..5f95809 100644 --- a/clients/rust/crates/mxgw-cli/src/main.rs +++ b/clients/rust/crates/mxgw-cli/src/main.rs @@ -92,6 +92,30 @@ enum Command { #[arg(long)] json: bool, }, + SubscribeBulk { + #[command(flatten)] + connection: ConnectionArgs, + #[arg(long)] + session_id: String, + #[arg(long)] + server_handle: i32, + #[arg(long, value_delimiter = ',')] + items: Vec, + #[arg(long)] + json: bool, + }, + UnsubscribeBulk { + #[command(flatten)] + connection: ConnectionArgs, + #[arg(long)] + session_id: String, + #[arg(long)] + server_handle: i32, + #[arg(long, value_delimiter = ',')] + item_handles: Vec, + #[arg(long)] + json: bool, + }, StreamEvents { #[command(flatten)] connection: ConnectionArgs, @@ -103,6 +127,8 @@ enum Command { max_events: usize, #[arg(long)] json: bool, + #[arg(long)] + jsonl: bool, }, Write { #[command(flatten)] @@ -226,7 +252,7 @@ async fn main() -> ExitCode { async fn run(cli: Cli) -> Result<(), Error> { match cli.command { - Command::Version { json } => print_version(json), + Command::Version { json, .. } => print_version(json), Command::Ping { connection, message, @@ -323,6 +349,30 @@ async fn run(cli: Cli) -> Result<(), Error> { session.advise(server_handle, item_handle).await?; print_ok("advise", json); } + Command::SubscribeBulk { + connection, + session_id, + server_handle, + items, + json, + } => { + let session = session_for(connection, session_id).await?; + let results = session.subscribe_bulk(server_handle, items).await?; + print_bulk_results("subscribe-bulk", &results, json); + } + Command::UnsubscribeBulk { + connection, + session_id, + server_handle, + item_handles, + json, + } => { + let session = session_for(connection, session_id).await?; + let results = session + .unsubscribe_bulk(server_handle, item_handles) + .await?; + print_bulk_results("unsubscribe-bulk", &results, json); + } Command::StreamEvents { connection, session_id, @@ -527,6 +577,33 @@ fn print_ok(operation: &str, use_json: bool) { } } +fn print_bulk_results( + operation: &str, + results: &[mxgateway_client::generated::mxaccess_gateway::v1::SubscribeResult], + use_json: bool, +) { + if use_json { + let results_json: Vec<_> = results + .iter() + .map(|result| { + json!({ + "serverHandle": result.server_handle, + "tagAddress": result.tag_address, + "itemHandle": result.item_handle, + "wasSuccessful": result.was_successful, + "errorMessage": result.error_message, + }) + }) + .collect(); + println!( + "{}", + json!({ "operation": operation, "results": results_json }) + ); + } else { + println!("{}", results.len()); + } +} + fn parse_value(value_type: CliValueType, value: &str) -> Result { let parsed = match value_type { CliValueType::Bool => MxValue::bool(parse_cli_value(value)?), diff --git a/docs/GatewayConfiguration.md b/docs/GatewayConfiguration.md index dee9aa8..478cbc1 100644 --- a/docs/GatewayConfiguration.md +++ b/docs/GatewayConfiguration.md @@ -113,11 +113,12 @@ ordering and avoids competing consumers. | Option | Default | Description | |--------|---------|-------------| | `MxGateway:Events:QueueCapacity` | `10000` | Capacity for bounded per-session event queues used by the gateway worker event channel and the public gRPC event stream queue. | -| `MxGateway:Events:BackpressurePolicy` | `FailFast` | Event backpressure behavior. `FailFast` is the only supported value. | +| `MxGateway:Events:BackpressurePolicy` | `FailFast` | Event backpressure behavior. `FailFast` faults the session on public stream queue overflow. `DisconnectSubscriber` disconnects only the slow stream. | `QueueCapacity` must be greater than zero. With `FailFast`, queue overflow faults the affected worker or session instead of silently dropping MXAccess -events. +events. With `DisconnectSubscriber`, public gRPC stream overflow terminates only +the affected stream while the MXAccess session remains active. ## Dashboard Options diff --git a/docs/GatewayTesting.md b/docs/GatewayTesting.md index b250df4..42338f9 100644 --- a/docs/GatewayTesting.md +++ b/docs/GatewayTesting.md @@ -101,9 +101,10 @@ powershell -ExecutionPolicy Bypass -File scripts/discover-testmachine-tags.ps1 - `scripts/run-client-e2e-tests.ps1` drives the .NET, Go, Rust, Python, and Java client CLIs through a live gateway session. For each client it opens one -session, registers, adds and advises every discovered test tag, reads a bounded -event stream, then closes the session in a `finally` path. The script writes a -JSON report under `artifacts/e2e/`. +session, registers, verifies `SubscribeBulk` and `UnsubscribeBulk` on a bounded +tag subset, adds and advises every discovered test tag, reads a bounded event +stream, then closes the session in a `finally` path. The script writes a JSON +report under `artifacts/e2e/`. Build the gateway and worker, start the gateway, and provide a valid API key before running the client e2e script: @@ -117,7 +118,9 @@ Useful runner options: ```powershell powershell -ExecutionPolicy Bypass -File scripts/run-client-e2e-tests.ps1 -Clients dotnet,python -MachineStart 1 -MachineEnd 2 +powershell -ExecutionPolicy Bypass -File scripts/run-client-e2e-tests.ps1 -BulkTagCount 10 powershell -ExecutionPolicy Bypass -File scripts/run-client-e2e-tests.ps1 -SkipStream +powershell -ExecutionPolicy Bypass -File scripts/run-client-e2e-tests.ps1 -SkipBulk powershell -ExecutionPolicy Bypass -File scripts/run-client-e2e-tests.ps1 -Endpoint localhost:5000 -ApiKeyEnv MXGATEWAY_API_KEY ``` diff --git a/scripts/run-client-e2e-tests.ps1 b/scripts/run-client-e2e-tests.ps1 index 659ecf1..a5f2865 100644 --- a/scripts/run-client-e2e-tests.ps1 +++ b/scripts/run-client-e2e-tests.ps1 @@ -16,7 +16,9 @@ param( [string]$SqlServer = "localhost", [string]$Database = "ZB", [int]$EventLimit = 5, + [int]$BulkTagCount = 6, [switch]$SkipStream, + [switch]$SkipBulk, [switch]$DryRun, [string]$ReportPath ) @@ -50,6 +52,10 @@ if ($Attributes.Count -eq 0) { throw "At least one attribute is required." } +if ($BulkTagCount -lt 1) { + throw "BulkTagCount must be greater than zero." +} + foreach ($client in $Clients) { if ($validClients -notcontains $client) { throw "Unsupported client '$client'. Supported clients: $($validClients -join ', ')." @@ -237,6 +243,74 @@ function Get-StreamEventCount { } } +function Get-PropertyValue { + param( + [object]$Object, + [string[]]$Names + ) + + if ($null -eq $Object) { + return $null + } + + foreach ($name in $Names) { + $property = $Object.PSObject.Properties[$name] + if ($null -ne $property) { + return $property.Value + } + } + + return $null +} + +function Get-BulkResults { + param( + [string]$Client, + [string]$Operation, + [object]$Json + ) + + if ($Client -in @("go", "rust", "python", "java")) { + return @(Get-PropertyValue -Object $Json -Names @("results")) + } + + $replyName = if ($Operation -eq "subscribe-bulk") { "subscribeBulk" } else { "unsubscribeBulk" } + $reply = Get-PropertyValue -Object $Json -Names @($replyName) + return @(Get-PropertyValue -Object $reply -Names @("results")) +} + +function Get-BulkItemHandles { + param([object[]]$Results) + + return @($Results | ForEach-Object { + [int](Get-PropertyValue -Object $_ -Names @("itemHandle", "item_handle")) + } | Where-Object { + $_ -gt 0 + }) +} + +function Assert-BulkResults { + param( + [string]$Client, + [string]$Operation, + [object[]]$Results, + [int]$ExpectedCount + ) + + if ($Results.Count -ne $ExpectedCount) { + throw "$Client $Operation returned $($Results.Count) result(s); expected $ExpectedCount." + } + + foreach ($result in $Results) { + $success = Get-PropertyValue -Object $result -Names @("wasSuccessful", "was_successful") + if ($null -ne $success -and -not [bool]$success) { + $tagAddress = Get-PropertyValue -Object $result -Names @("tagAddress", "tag_address") + $errorMessage = Get-PropertyValue -Object $result -Names @("errorMessage", "error_message") + throw "$Client $Operation failed for '$tagAddress': $errorMessage" + } + } +} + function Get-ClientCommand { param( [string]$Client, @@ -266,6 +340,10 @@ function Get-ClientCommand { $arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item", $Values.item) } elseif ($Operation -eq "advise") { $arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handle", "$($Values.itemHandle)") + } elseif ($Operation -eq "subscribe-bulk") { + $arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--items", $Values.items) + } elseif ($Operation -eq "unsubscribe-bulk") { + $arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handles", $Values.itemHandles) } elseif ($Operation -eq "stream-events") { $arguments += @("--session-id", $Values.sessionId, "--max-events", "$EventLimit") } elseif ($Operation -eq "close-session") { @@ -289,6 +367,10 @@ function Get-ClientCommand { $arguments += @("-session-id", $Values.sessionId, "-server-handle", "$($Values.serverHandle)", "-item", $Values.item) } elseif ($Operation -eq "advise") { $arguments += @("-session-id", $Values.sessionId, "-server-handle", "$($Values.serverHandle)", "-item-handle", "$($Values.itemHandle)") + } elseif ($Operation -eq "subscribe-bulk") { + $arguments += @("-session-id", $Values.sessionId, "-server-handle", "$($Values.serverHandle)", "-items", $Values.items) + } elseif ($Operation -eq "unsubscribe-bulk") { + $arguments += @("-session-id", $Values.sessionId, "-server-handle", "$($Values.serverHandle)", "-item-handles", $Values.itemHandles) } elseif ($Operation -eq "stream-events") { $arguments += @("-session-id", $Values.sessionId, "-limit", "$EventLimit") } elseif ($Operation -eq "close-session") { @@ -311,6 +393,10 @@ function Get-ClientCommand { $arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item", $Values.item) } elseif ($Operation -eq "advise") { $arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handle", "$($Values.itemHandle)") + } elseif ($Operation -eq "subscribe-bulk") { + $arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--items", $Values.items) + } elseif ($Operation -eq "unsubscribe-bulk") { + $arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handles", $Values.itemHandles) } elseif ($Operation -eq "stream-events") { $arguments += @("--session-id", $Values.sessionId, "--max-events", "$EventLimit") } elseif ($Operation -eq "close-session") { @@ -334,6 +420,10 @@ function Get-ClientCommand { $arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item", $Values.item) } elseif ($Operation -eq "advise") { $arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handle", "$($Values.itemHandle)") + } elseif ($Operation -eq "subscribe-bulk") { + $arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--items", $Values.items) + } elseif ($Operation -eq "unsubscribe-bulk") { + $arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handles", $Values.itemHandles) } elseif ($Operation -eq "stream-events") { $arguments += @("--session-id", $Values.sessionId, "--max-events", "$EventLimit", "--timeout", "15") } elseif ($Operation -eq "close-session") { @@ -360,6 +450,10 @@ function Get-ClientCommand { $cliArgs += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item", $Values.item) } elseif ($Operation -eq "advise") { $cliArgs += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handle", "$($Values.itemHandle)") + } elseif ($Operation -eq "subscribe-bulk") { + $cliArgs += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--items", $Values.items) + } elseif ($Operation -eq "unsubscribe-bulk") { + $cliArgs += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handles", $Values.itemHandles) } elseif ($Operation -eq "stream-events") { $cliArgs += @("--session-id", $Values.sessionId, "--limit", "$EventLimit") } elseif ($Operation -eq "close-session") { @@ -389,6 +483,18 @@ function Invoke-ClientOperation { "open-session" { return [pscustomobject]@{ sessionId = "dry-run-session-$Client"; reply = [pscustomobject]@{ sessionId = "dry-run-session-$Client" } } } "register" { return [pscustomobject]@{ serverHandle = 1; register = [pscustomobject]@{ serverHandle = 1 }; reply = [pscustomobject]@{ register = [pscustomobject]@{ serverHandle = 1 } } } } "add-item" { return [pscustomobject]@{ itemHandle = 1; addItem = [pscustomobject]@{ itemHandle = 1 }; reply = [pscustomobject]@{ addItem = [pscustomobject]@{ itemHandle = 1 } } } } + "subscribe-bulk" { + $results = @($Values.items -split "," | ForEach-Object -Begin { $index = 1 } -Process { + [pscustomobject]@{ itemHandle = $index++; tagAddress = $_; wasSuccessful = $true } + }) + return [pscustomobject]@{ subscribeBulk = [pscustomobject]@{ results = $results }; results = $results } + } + "unsubscribe-bulk" { + $results = @($Values.itemHandles -split "," | ForEach-Object { + [pscustomobject]@{ itemHandle = [int]$_; wasSuccessful = $true } + }) + return [pscustomobject]@{ unsubscribeBulk = [pscustomobject]@{ results = $results }; results = $results } + } "stream-events" { return [pscustomobject]@{ eventCount = 1; events = @([pscustomobject]@{ workerSequence = 1 }) } } default { return [pscustomobject]@{ ok = $true; reply = [pscustomobject]@{} } } } @@ -425,7 +531,9 @@ $run = [ordered]@{ machineEnd = $MachineEnd attributes = $Attributes eventLimit = $EventLimit + bulkTagCount = $BulkTagCount skipStream = [bool]$SkipStream + skipBulk = [bool]$SkipBulk startedAt = (Get-Date).ToUniversalTime().ToString("O") discoveredTags = $tags clients = @() @@ -441,6 +549,7 @@ foreach ($client in $Clients) { language = $client sessionId = $null serverHandle = $null + bulk = $null addedItems = @() eventCount = 0 closed = $false @@ -461,6 +570,37 @@ foreach ($client in $Clients) { $serverHandle = Get-ServerHandle -Client $client -Json $registerJson $clientResult.serverHandle = $serverHandle + if (-not $SkipBulk) { + $bulkTags = @($tags | Select-Object -First ([Math]::Min($BulkTagCount, $tags.Count))) + $bulkItems = ($bulkTags | ForEach-Object { $_.fullTagReference }) -join "," + $subscribeBulkJson = Invoke-ClientOperation -Client $client -Operation "subscribe-bulk" -Values @{ + sessionId = $sessionId + serverHandle = $serverHandle + items = $bulkItems + } + $subscribeResults = @(Get-BulkResults -Client $client -Operation "subscribe-bulk" -Json $subscribeBulkJson) + Assert-BulkResults -Client $client -Operation "subscribe-bulk" -Results $subscribeResults -ExpectedCount $bulkTags.Count + $bulkItemHandles = @(Get-BulkItemHandles -Results $subscribeResults) + if ($bulkItemHandles.Count -ne $bulkTags.Count) { + throw "$client subscribe-bulk returned $($bulkItemHandles.Count) usable item handle(s); expected $($bulkTags.Count)." + } + + $unsubscribeBulkJson = Invoke-ClientOperation -Client $client -Operation "unsubscribe-bulk" -Values @{ + sessionId = $sessionId + serverHandle = $serverHandle + itemHandles = $bulkItemHandles -join "," + } + $unsubscribeResults = @(Get-BulkResults -Client $client -Operation "unsubscribe-bulk" -Json $unsubscribeBulkJson) + Assert-BulkResults -Client $client -Operation "unsubscribe-bulk" -Results $unsubscribeResults -ExpectedCount $bulkItemHandles.Count + + $clientResult.bulk = [ordered]@{ + tagCount = $bulkTags.Count + subscribedCount = $subscribeResults.Count + unsubscribedCount = $unsubscribeResults.Count + itemHandles = $bulkItemHandles + } + } + foreach ($tag in $tags) { $addJson = Invoke-ClientOperation -Client $client -Operation "add-item" -Values @{ sessionId = $sessionId diff --git a/src/MxGateway.Server/Configuration/EventBackpressurePolicy.cs b/src/MxGateway.Server/Configuration/EventBackpressurePolicy.cs index 9d341e0..089aab4 100644 --- a/src/MxGateway.Server/Configuration/EventBackpressurePolicy.cs +++ b/src/MxGateway.Server/Configuration/EventBackpressurePolicy.cs @@ -2,5 +2,7 @@ namespace MxGateway.Server.Configuration; public enum EventBackpressurePolicy { - FailFast + FailFast, + + DisconnectSubscriber } diff --git a/src/MxGateway.Server/Grpc/EventStreamService.cs b/src/MxGateway.Server/Grpc/EventStreamService.cs index 949380a..8d8249f 100644 --- a/src/MxGateway.Server/Grpc/EventStreamService.cs +++ b/src/MxGateway.Server/Grpc/EventStreamService.cs @@ -108,9 +108,19 @@ public sealed class EventStreamService( if (!writer.TryWrite(publicEvent)) { string message = $"Session {session.SessionId} event stream queue overflowed."; - session.MarkFaulted(message); metrics.QueueOverflow("grpc-event-stream"); - metrics.Fault(SessionManagerErrorCode.EventQueueOverflow.ToString()); + if (options.Value.Events.BackpressurePolicy == EventBackpressurePolicy.FailFast) + { + session.MarkFaulted(message); + metrics.Fault(SessionManagerErrorCode.EventQueueOverflow.ToString()); + } + else + { + logger.LogDebug( + "Disconnecting event stream for session {SessionId} after queue overflow.", + session.SessionId); + } + writer.TryComplete(new SessionManagerException( SessionManagerErrorCode.EventQueueOverflow, message)); diff --git a/src/MxGateway.Server/Metrics/GatewayMetrics.cs b/src/MxGateway.Server/Metrics/GatewayMetrics.cs index 82e282e..96e9412 100644 --- a/src/MxGateway.Server/Metrics/GatewayMetrics.cs +++ b/src/MxGateway.Server/Metrics/GatewayMetrics.cs @@ -1,3 +1,4 @@ +using System.Collections.Concurrent; using System.Diagnostics.Metrics; namespace MxGateway.Server.Metrics; @@ -25,8 +26,8 @@ public sealed class GatewayMetrics : IDisposable private readonly Histogram _commandLatencyHistogram; private readonly Histogram _eventStreamSendLatencyHistogram; private readonly Dictionary _commandFailuresByMethod = new(StringComparer.OrdinalIgnoreCase); - private readonly Dictionary _eventsByFamily = new(StringComparer.OrdinalIgnoreCase); - private readonly Dictionary _eventsBySession = new(StringComparer.Ordinal); + private readonly ConcurrentDictionary _eventsByFamily = new(StringComparer.OrdinalIgnoreCase); + private readonly ConcurrentDictionary _eventsBySession = new(StringComparer.Ordinal); private readonly Dictionary _retryAttemptsByArea = new(StringComparer.OrdinalIgnoreCase); private int _openSessions; @@ -173,12 +174,9 @@ public sealed class GatewayMetrics : IDisposable public void EventReceived(string sessionId, string family) { - lock (_syncRoot) - { - _eventsReceived++; - Increment(_eventsByFamily, family); - Increment(_eventsBySession, sessionId); - } + Interlocked.Increment(ref _eventsReceived); + Increment(_eventsByFamily, family); + Increment(_eventsBySession, sessionId); _eventsReceivedCounter.Add( 1, @@ -225,10 +223,7 @@ public sealed class GatewayMetrics : IDisposable public void RemoveSessionEvents(string sessionId) { - lock (_syncRoot) - { - _eventsBySession.Remove(sessionId); - } + _eventsBySession.TryRemove(sessionId, out _); } public void QueueOverflow(string queueName) @@ -296,7 +291,7 @@ public sealed class GatewayMetrics : IDisposable CommandsStarted: _commandsStarted, CommandsSucceeded: _commandsSucceeded, CommandsFailed: _commandsFailed, - EventsReceived: _eventsReceived, + EventsReceived: Interlocked.Read(ref _eventsReceived), QueueOverflows: _queueOverflows, Faults: _faults, WorkerKills: _workerKills, @@ -359,4 +354,9 @@ public sealed class GatewayMetrics : IDisposable values.TryGetValue(key, out long currentValue); values[key] = currentValue + 1; } + + private static void Increment(ConcurrentDictionary values, string key) + { + values.AddOrUpdate(key, 1, static (_, currentValue) => currentValue + 1); + } } diff --git a/src/MxGateway.Server/Sessions/SessionWorkerClientFactory.cs b/src/MxGateway.Server/Sessions/SessionWorkerClientFactory.cs index b6612d1..8b1d0bc 100644 --- a/src/MxGateway.Server/Sessions/SessionWorkerClientFactory.cs +++ b/src/MxGateway.Server/Sessions/SessionWorkerClientFactory.cs @@ -41,6 +41,9 @@ public sealed class SessionWorkerClientFactory : ISessionWorkerClientFactory NamedPipeServerStream? pipe = CreatePipe(session.PipeName); WorkerProcessHandle? processHandle = null; IWorkerClient? workerClient = null; + using CancellationTokenSource startupCancellation = + CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + startupCancellation.CancelAfter(session.StartupTimeout); try { session.TransitionTo(SessionState.StartingWorker); @@ -52,11 +55,11 @@ public sealed class SessionWorkerClientFactory : ISessionWorkerClientFactory GatewayContractInfo.WorkerProtocolVersion, session.Nonce, pipe), - cancellationToken) + startupCancellation.Token) .ConfigureAwait(false); session.TransitionTo(SessionState.WaitingForPipe); - await WaitForPipeConnectionAsync(pipe, session.StartupTimeout, cancellationToken).ConfigureAwait(false); + await WaitForPipeConnectionAsync(pipe, startupCancellation.Token).ConfigureAwait(false); session.TransitionTo(SessionState.Handshaking); WorkerFrameProtocolOptions frameOptions = new( @@ -88,14 +91,23 @@ public sealed class SessionWorkerClientFactory : ISessionWorkerClientFactory processHandle = null; session.TransitionTo(SessionState.InitializingWorker); - await workerClient.StartAsync(cancellationToken).ConfigureAwait(false); + await workerClient.StartAsync(startupCancellation.Token).ConfigureAwait(false); return workerClient; } - catch + catch (Exception exception) { if (workerClient is not null) { + try + { + workerClient.Kill("OpenSessionFailed"); + } + catch + { + // Preserve the startup failure while still disposing below. + } + await workerClient.DisposeAsync().ConfigureAwait(false); } else @@ -119,6 +131,15 @@ public sealed class SessionWorkerClientFactory : ISessionWorkerClientFactory pipe?.Dispose(); } + if (exception is OperationCanceledException + && startupCancellation.IsCancellationRequested + && !cancellationToken.IsCancellationRequested) + { + throw new TimeoutException( + $"Worker session {session.SessionId} did not complete startup within {session.StartupTimeout}.", + exception); + } + throw; } } @@ -135,11 +156,8 @@ public sealed class SessionWorkerClientFactory : ISessionWorkerClientFactory private static async Task WaitForPipeConnectionAsync( NamedPipeServerStream pipe, - TimeSpan startupTimeout, CancellationToken cancellationToken) { - using CancellationTokenSource timeout = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - timeout.CancelAfter(startupTimeout); - await pipe.WaitForConnectionAsync(timeout.Token).ConfigureAwait(false); + await pipe.WaitForConnectionAsync(cancellationToken).ConfigureAwait(false); } } diff --git a/src/MxGateway.Server/Workers/WorkerClient.cs b/src/MxGateway.Server/Workers/WorkerClient.cs index 4f944c8..25ed454 100644 --- a/src/MxGateway.Server/Workers/WorkerClient.cs +++ b/src/MxGateway.Server/Workers/WorkerClient.cs @@ -13,6 +13,7 @@ namespace MxGateway.Server.Workers; public sealed class WorkerClient : IWorkerClient { private const string GatewayVersionFallback = "unknown"; + private static readonly TimeSpan DisposeTaskTimeout = TimeSpan.FromSeconds(5); private readonly object _syncRoot = new(); private readonly WorkerClientConnection _connection; private readonly WorkerClientOptions _options; @@ -286,8 +287,19 @@ public sealed class WorkerClient : IWorkerClient WorkerClientErrorCode.GatewayShutdown, "Worker client was disposed.")); - await WaitForBackgroundTasksAsync(CancellationToken.None).ConfigureAwait(false); await _connection.Stream.DisposeAsync().ConfigureAwait(false); + using CancellationTokenSource disposeTimeout = new(DisposeTaskTimeout); + try + { + await WaitForBackgroundTasksAsync(disposeTimeout.Token).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + _logger.LogWarning( + "Timed out waiting for worker client background tasks to stop for session {SessionId}.", + SessionId); + } + _connection.ProcessHandle?.Dispose(); _pendingCommandSlots.Dispose(); _stopCts.Dispose(); diff --git a/src/MxGateway.Tests/Gateway/Grpc/EventStreamServiceTests.cs b/src/MxGateway.Tests/Gateway/Grpc/EventStreamServiceTests.cs index 4960895..d0d864f 100644 --- a/src/MxGateway.Tests/Gateway/Grpc/EventStreamServiceTests.cs +++ b/src/MxGateway.Tests/Gateway/Grpc/EventStreamServiceTests.cs @@ -114,6 +114,37 @@ public sealed class EventStreamServiceTests Assert.Equal(1, metrics.GetSnapshot().Faults); } + [Fact] + public async Task StreamEventsAsync_WhenStreamQueueOverflowsWithDisconnectPolicy_LeavesSessionReady() + { + FakeWorkerClient workerClient = new(); + GatewaySession session = CreateReadySession(workerClient); + using GatewayMetrics metrics = new(); + EventStreamService service = CreateService( + new FakeSessionManager(session), + metrics, + queueCapacity: 1, + backpressurePolicy: EventBackpressurePolicy.DisconnectSubscriber); + workerClient.Events.Add(CreateWorkerEvent(sequence: 1, MxEventFamily.OnDataChange)); + workerClient.Events.Add(CreateWorkerEvent(sequence: 2, MxEventFamily.OnDataChange)); + workerClient.Events.Add(CreateWorkerEvent(sequence: 3, MxEventFamily.OnDataChange)); + workerClient.CompleteAfterConfiguredEvents = true; + await using IAsyncEnumerator subscriber = service + .StreamEventsAsync(CreateRequest(session.SessionId), CancellationToken.None) + .GetAsyncEnumerator(); + + Assert.True(await subscriber.MoveNextAsync().AsTask().WaitAsync(TestTimeout)); + SessionManagerException exception = await Assert.ThrowsAsync( + async () => await subscriber.MoveNextAsync().AsTask().WaitAsync(TestTimeout)); + + Assert.Equal(SessionManagerErrorCode.EventQueueOverflow, exception.ErrorCode); + Assert.Equal(SessionState.Ready, session.State); + GatewayMetricsSnapshot snapshot = metrics.GetSnapshot(); + Assert.Equal(1, snapshot.QueueOverflows); + Assert.Equal(0, snapshot.Faults); + Assert.Equal(1, snapshot.StreamDisconnects); + } + [Fact] public async Task StreamEventsAsync_DoesNotSynthesizeOperationComplete() { @@ -157,7 +188,8 @@ public sealed class EventStreamServiceTests private static EventStreamService CreateService( FakeSessionManager sessionManager, GatewayMetrics? metrics = null, - int queueCapacity = 8) + int queueCapacity = 8, + EventBackpressurePolicy backpressurePolicy = EventBackpressurePolicy.FailFast) { return new EventStreamService( sessionManager, @@ -166,6 +198,7 @@ public sealed class EventStreamServiceTests Events = new EventOptions { QueueCapacity = queueCapacity, + BackpressurePolicy = backpressurePolicy, }, }), new MxAccessGrpcMapper(), diff --git a/src/MxGateway.Tests/Gateway/Sessions/SessionWorkerClientFactoryFakeWorkerTests.cs b/src/MxGateway.Tests/Gateway/Sessions/SessionWorkerClientFactoryFakeWorkerTests.cs index 625dd38..a3a422a 100644 --- a/src/MxGateway.Tests/Gateway/Sessions/SessionWorkerClientFactoryFakeWorkerTests.cs +++ b/src/MxGateway.Tests/Gateway/Sessions/SessionWorkerClientFactoryFakeWorkerTests.cs @@ -65,13 +65,33 @@ public sealed class SessionWorkerClientFactoryFakeWorkerTests Assert.True(launcher.Process.IsDisposed); } - private static GatewayOptions CreateOptions() + [Fact] + public async Task CreateAsync_WhenFakeWorkerNeverSendsReady_TimesOutAndKillsWorker() + { + NeverReadyWorkerProcessLauncher launcher = new(); + using GatewayMetrics metrics = new(); + SessionWorkerClientFactory factory = new( + launcher, + Options.Create(CreateOptions(startupTimeoutSeconds: 1)), + metrics, + NullLoggerFactory.Instance); + GatewaySession session = CreateSession(startupTimeout: TimeSpan.FromSeconds(1)); + + TimeoutException exception = await Assert.ThrowsAsync( + async () => await factory.CreateAsync(session, CancellationToken.None).WaitAsync(TestTimeout)); + + Assert.Contains("did not complete startup", exception.Message); + Assert.Equal(1, launcher.Process.KillCount); + Assert.True(launcher.Process.IsDisposed); + } + + private static GatewayOptions CreateOptions(int startupTimeoutSeconds = 5) { return new GatewayOptions { Worker = new WorkerOptions { - StartupTimeoutSeconds = 5, + StartupTimeoutSeconds = startupTimeoutSeconds, ShutdownTimeoutSeconds = 5, HeartbeatIntervalSeconds = 30, HeartbeatGraceSeconds = 30, @@ -84,7 +104,7 @@ public sealed class SessionWorkerClientFactoryFakeWorkerTests }; } - private static GatewaySession CreateSession() + private static GatewaySession CreateSession(TimeSpan? startupTimeout = null) { return new GatewaySession( FakeWorkerHarness.DefaultSessionId, @@ -94,7 +114,7 @@ public sealed class SessionWorkerClientFactoryFakeWorkerTests "test-client", "fake-worker-session-test", "client-correlation-1", - TestTimeout, + startupTimeout ?? TestTimeout, TestTimeout, TestTimeout, DateTimeOffset.UtcNow); @@ -172,6 +192,38 @@ public sealed class SessionWorkerClientFactoryFakeWorkerTests } } + private sealed class NeverReadyWorkerProcessLauncher : IWorkerProcessLauncher + { + public FakeWorkerProcess Process { get; } = new(processId: 4680); + + public Task LaunchAsync( + WorkerProcessLaunchRequest request, + CancellationToken cancellationToken = default) + { + _ = RunWorkerAsync(request, cancellationToken); + + return Task.FromResult(CreateHandle(Process)); + } + + private async Task RunWorkerAsync( + WorkerProcessLaunchRequest request, + CancellationToken cancellationToken) + { + await using FakeWorkerHarness harness = await FakeWorkerHarness.ConnectToGatewayPipeAsync( + request.SessionId, + request.Nonce, + request.PipeName, + request.ProtocolVersion, + cancellationToken: cancellationToken).ConfigureAwait(false); + _ = await harness.ReadGatewayEnvelopeAsync(cancellationToken).ConfigureAwait(false); + await harness.SendWorkerHelloAsync( + workerProcessId: Process.Id, + workerProtocolVersion: request.ProtocolVersion, + cancellationToken: cancellationToken).ConfigureAwait(false); + await Task.Delay(Timeout.InfiniteTimeSpan, cancellationToken).ConfigureAwait(false); + } + } + private static WorkerProcessHandle CreateHandle(IWorkerProcess process) { return new WorkerProcessHandle( diff --git a/src/MxGateway.Tests/Gateway/Workers/WorkerClientTests.cs b/src/MxGateway.Tests/Gateway/Workers/WorkerClientTests.cs index 1dee019..2cc1c17 100644 --- a/src/MxGateway.Tests/Gateway/Workers/WorkerClientTests.cs +++ b/src/MxGateway.Tests/Gateway/Workers/WorkerClientTests.cs @@ -166,7 +166,8 @@ public sealed class WorkerClientTests await pipePair.DisposeWorkerSideAsync(); await WaitUntilAsync( - () => client.State == WorkerClientState.Faulted, + () => client.State == WorkerClientState.Faulted + && metrics.GetSnapshot().WorkersRunning == 0, TestTimeout); GatewayMetricsSnapshot snapshot = metrics.GetSnapshot(); @@ -174,6 +175,22 @@ public sealed class WorkerClientTests Assert.Equal(1, snapshot.WorkerExits); } + [Fact] + public async Task DisposeAsync_WhenPipeReadIsBlocked_ReturnsWithinBoundedTimeout() + { + await using PipePair pipePair = await PipePair.CreateAsync(); + WorkerClient client = CreateClient(pipePair); + await CompleteHandshakeAsync(client, pipePair); + + DateTimeOffset startedAt = DateTimeOffset.UtcNow; + await client.DisposeAsync().AsTask().WaitAsync(TestTimeout); + TimeSpan elapsed = DateTimeOffset.UtcNow - startedAt; + + Assert.True( + elapsed < TimeSpan.FromSeconds(4), + $"DisposeAsync took {elapsed.TotalMilliseconds:N0}ms."); + } + [Fact] public async Task ReadLoop_WhenHeartbeatArrives_UpdatesLastHeartbeatAndWorkerProcess() { diff --git a/src/MxGateway.Tests/Metrics/GatewayMetricsTests.cs b/src/MxGateway.Tests/Metrics/GatewayMetricsTests.cs index 8ad9bbe..cea1f2f 100644 --- a/src/MxGateway.Tests/Metrics/GatewayMetricsTests.cs +++ b/src/MxGateway.Tests/Metrics/GatewayMetricsTests.cs @@ -60,4 +60,22 @@ public sealed class GatewayMetricsTests Assert.Equal("depth", exception.ParamName); } + + [Fact] + public void RemoveSessionEvents_RemovesOnlyThatSession() + { + using GatewayMetrics metrics = new(); + + metrics.EventReceived("session-1", "OnDataChange"); + metrics.EventReceived("session-2", "OnWriteComplete"); + metrics.RemoveSessionEvents("session-1"); + + GatewayMetricsSnapshot snapshot = metrics.GetSnapshot(); + + Assert.Equal(2, snapshot.EventsReceived); + Assert.False(snapshot.EventsBySession.ContainsKey("session-1")); + Assert.Equal(1, snapshot.EventsBySession["session-2"]); + Assert.Equal(1, snapshot.EventsByFamily["OnDataChange"]); + Assert.Equal(1, snapshot.EventsByFamily["OnWriteComplete"]); + } } diff --git a/src/MxGateway.Worker.Tests/Ipc/WorkerPipeSessionTests.cs b/src/MxGateway.Worker.Tests/Ipc/WorkerPipeSessionTests.cs index 311ac86..c05e574 100644 --- a/src/MxGateway.Worker.Tests/Ipc/WorkerPipeSessionTests.cs +++ b/src/MxGateway.Worker.Tests/Ipc/WorkerPipeSessionTests.cs @@ -304,6 +304,45 @@ public sealed class WorkerPipeSessionTests await SendShutdownAndWaitAsync(pipePair, runTask, cancellation.Token); } + [Fact] + public async Task RunAsync_WhenShutdownArrivesDuringCommand_DropsLateReplyAndWritesShutdownAck() + { + using CancellationTokenSource cancellation = new(TimeSpan.FromSeconds(5)); + using PipePair pipePair = await PipePair.CreateAsync(cancellation.Token); + FakeRuntimeSession runtime = new() + { + BlockDispatch = true, + }; + WorkerPipeSession session = CreatePipeSession( + pipePair.WorkerStream, + runtime, + new WorkerPipeSessionOptions + { + HeartbeatInterval = TimeSpan.FromSeconds(1), + HeartbeatGrace = TimeSpan.FromSeconds(5), + }); + Task runTask = session.RunAsync(cancellation.Token); + await CompleteGatewayHandshakeAsync(pipePair, cancellation.Token); + + await pipePair.GatewayWriter.WriteAsync( + CreateCommandEnvelope("command-during-shutdown"), + cancellation.Token); + Assert.True(runtime.DispatchStarted.Wait(TimeSpan.FromSeconds(2))); + + await pipePair.GatewayWriter + .WriteAsync(CreateShutdownEnvelope(), cancellation.Token); + + WorkerEnvelope shutdownAck = await ReadUntilAsync( + pipePair.GatewayReader, + WorkerEnvelope.BodyOneofCase.WorkerShutdownAck, + cancellation.Token); + + Assert.Equal(ProtocolStatusCode.Ok, shutdownAck.WorkerShutdownAck.Status.Code); + Task completedTask = await Task.WhenAny(runTask, Task.Delay(TimeSpan.FromSeconds(2), cancellation.Token)); + Assert.Same(runTask, completedTask); + await runTask; + } + private static WorkerPipeSession CreateSession( Stream inbound, Stream outbound, @@ -440,7 +479,7 @@ public sealed class WorkerPipeSessionTests Assert.Equal(ProtocolStatusCode.Ok, shutdownAck.WorkerShutdownAck.Status.Code); Task completedTask = await Task - .WhenAny(runTask, Task.Delay(TimeSpan.FromSeconds(2), cancellationToken)) + .WhenAny(runTask, Task.Delay(TimeSpan.FromSeconds(5), cancellationToken)) .ConfigureAwait(false); Assert.Same(runTask, completedTask); diff --git a/src/MxGateway.Worker.Tests/MxAccess/MxAccessEventQueueTests.cs b/src/MxGateway.Worker.Tests/MxAccess/MxAccessEventQueueTests.cs index 72b61ee..cf884ec 100644 --- a/src/MxGateway.Worker.Tests/MxAccess/MxAccessEventQueueTests.cs +++ b/src/MxGateway.Worker.Tests/MxAccess/MxAccessEventQueueTests.cs @@ -12,17 +12,17 @@ public sealed class MxAccessEventQueueTests { MxAccessEventQueue queue = new(capacity: 4); - WorkerEvent first = queue.Enqueue(CreateEvent(MxEventFamily.OnDataChange, itemHandle: 10)); - WorkerEvent second = queue.Enqueue(CreateEvent(MxEventFamily.OnWriteComplete, itemHandle: 11)); + queue.Enqueue(CreateEvent(MxEventFamily.OnDataChange, itemHandle: 10)); + queue.Enqueue(CreateEvent(MxEventFamily.OnWriteComplete, itemHandle: 11)); - Assert.Equal(1UL, first.Event.WorkerSequence); - Assert.Equal(2UL, second.Event.WorkerSequence); - Assert.NotNull(first.Event.WorkerTimestamp); Assert.Equal(2, queue.Count); Assert.Equal(2UL, queue.LastEventSequence); Assert.True(queue.TryDequeue(out WorkerEvent? dequeuedFirst)); Assert.True(queue.TryDequeue(out WorkerEvent? dequeuedSecond)); + Assert.Equal(1UL, dequeuedFirst?.Event.WorkerSequence); + Assert.Equal(2UL, dequeuedSecond?.Event.WorkerSequence); + Assert.NotNull(dequeuedFirst?.Event.WorkerTimestamp); Assert.Equal(10, dequeuedFirst?.Event.ItemHandle); Assert.Equal(11, dequeuedSecond?.Event.ItemHandle); Assert.False(queue.TryDequeue(out _)); diff --git a/src/MxGateway.Worker/Ipc/WorkerPipeSession.cs b/src/MxGateway.Worker/Ipc/WorkerPipeSession.cs index a6e2730..cdbea79 100644 --- a/src/MxGateway.Worker/Ipc/WorkerPipeSession.cs +++ b/src/MxGateway.Worker/Ipc/WorkerPipeSession.cs @@ -15,6 +15,7 @@ namespace MxGateway.Worker.Ipc; public sealed class WorkerPipeSession { private static readonly TimeSpan EventDrainInterval = TimeSpan.FromMilliseconds(25); + private static readonly TimeSpan BackgroundTaskStopTimeout = TimeSpan.FromSeconds(1); private const uint EventDrainBatchSize = 128; private readonly WorkerFrameProtocolOptions _options; @@ -24,9 +25,12 @@ public sealed class WorkerPipeSession private readonly IWorkerLogger? _logger; private readonly WorkerFrameReader _reader; private readonly WorkerFrameWriter _writer; + private readonly object _commandTaskGate = new(); + private readonly HashSet _activeCommandTasks = new(); private IWorkerRuntimeSession? _runtimeSession; private long _nextSequence; private WorkerState _state = WorkerState.Starting; + private bool _acceptingCommands = true; private bool _watchdogFaultSent; private bool _shutdownTimedOut; @@ -206,18 +210,31 @@ public sealed class WorkerPipeSession private async Task RunMessageLoopAsync(CancellationToken cancellationToken) { + using CancellationTokenSource loopCancellation = CancellationTokenSource + .CreateLinkedTokenSource(cancellationToken); using CancellationTokenSource heartbeatCancellation = CancellationTokenSource .CreateLinkedTokenSource(cancellationToken); Task heartbeatTask = RunHeartbeatLoopAsync(heartbeatCancellation.Token); Task eventDrainTask = RunEventDrainLoopAsync(heartbeatCancellation.Token); + Task readTask = _reader.ReadAsync(loopCancellation.Token); try { while (!cancellationToken.IsCancellationRequested) { - Task readTask = _reader.ReadAsync(cancellationToken); Task completedTask = await Task.WhenAny(readTask, heartbeatTask, eventDrainTask).ConfigureAwait(false); - if (completedTask == heartbeatTask) + if (completedTask == readTask) + { + WorkerEnvelope envelope = await readTask.ConfigureAwait(false); + bool keepReading = await DispatchGatewayEnvelopeAsync(envelope, cancellationToken).ConfigureAwait(false); + if (!keepReading) + { + return; + } + + readTask = _reader.ReadAsync(loopCancellation.Token); + } + else if (completedTask == heartbeatTask) { await heartbeatTask.ConfigureAwait(false); } @@ -225,33 +242,52 @@ public sealed class WorkerPipeSession { await eventDrainTask.ConfigureAwait(false); } - - WorkerEnvelope envelope = await readTask.ConfigureAwait(false); - bool keepReading = await DispatchGatewayEnvelopeAsync(envelope, cancellationToken).ConfigureAwait(false); - if (!keepReading) - { - return; - } } } finally { + loopCancellation.Cancel(); heartbeatCancellation.Cancel(); - try - { - await heartbeatTask.ConfigureAwait(false); - } - catch (OperationCanceledException) - { - } + await ObserveBackgroundTaskStopAsync(heartbeatTask, "Heartbeat").ConfigureAwait(false); + await ObserveBackgroundTaskStopAsync(eventDrainTask, "EventDrain").ConfigureAwait(false); + } + } - try - { - await eventDrainTask.ConfigureAwait(false); - } - catch (OperationCanceledException) - { - } + private async Task ObserveBackgroundTaskStopAsync( + Task task, + string taskName) + { + Task completedTask = await Task + .WhenAny(task, Task.Delay(BackgroundTaskStopTimeout)) + .ConfigureAwait(false); + if (completedTask != task) + { + _logger?.Error( + "WorkerPipeSessionBackgroundTaskStopTimedOut", + new Dictionary + { + ["task"] = taskName, + ["timeout_ms"] = BackgroundTaskStopTimeout.TotalMilliseconds, + }); + return; + } + + try + { + await task.ConfigureAwait(false); + } + catch (OperationCanceledException) + { + } + catch (Exception ex) + { + _logger?.Error( + "WorkerPipeSessionBackgroundTaskStopFailed", + new Dictionary + { + ["task"] = taskName, + ["exception"] = ex.ToString(), + }); } } @@ -300,7 +336,7 @@ public sealed class WorkerPipeSession switch (envelope.BodyCase) { case WorkerEnvelope.BodyOneofCase.WorkerCommand: - _ = ProcessCommandAsync(envelope, cancellationToken); + TryStartCommandTask(envelope, cancellationToken); return true; case WorkerEnvelope.BodyOneofCase.WorkerShutdown: await ShutdownAsync(envelope.WorkerShutdown, cancellationToken).ConfigureAwait(false); @@ -333,6 +369,11 @@ public sealed class WorkerPipeSession try { MxCommandReply reply = await runtimeSession.DispatchAsync(staCommand).ConfigureAwait(false); + if (_state is not WorkerState.Ready and not WorkerState.ExecutingCommand) + { + return; + } + await _writer .WriteAsync( CreateEnvelope(new WorkerCommandReply @@ -370,11 +411,13 @@ public sealed class WorkerPipeSession } TimeSpan gracePeriod = ResolveGracePeriod(shutdown); + StopAcceptingCommands(); try { MxAccessShutdownResult result = await runtimeSession .ShutdownGracefullyAsync(gracePeriod, cancellationToken) .ConfigureAwait(false); + await WaitForActiveCommandTasksAsync(gracePeriod, cancellationToken).ConfigureAwait(false); LogShutdownFailures(result.Failures); await WriteShutdownAckAsync(CreateShutdownAck(result, shutdown), cancellationToken).ConfigureAwait(false); } @@ -387,6 +430,79 @@ public sealed class WorkerPipeSession } } + private void TryStartCommandTask( + WorkerEnvelope envelope, + CancellationToken cancellationToken) + { + Task commandTask; + lock (_commandTaskGate) + { + if (!_acceptingCommands) + { + return; + } + + commandTask = ProcessCommandAsync(envelope, cancellationToken); + _activeCommandTasks.Add(commandTask); + } + + _ = ObserveCommandTaskAsync(commandTask); + } + + private async Task ObserveCommandTaskAsync(Task commandTask) + { + try + { + await commandTask.ConfigureAwait(false); + } + catch (OperationCanceledException) + { + } + finally + { + lock (_commandTaskGate) + { + _activeCommandTasks.Remove(commandTask); + } + } + } + + private void StopAcceptingCommands() + { + lock (_commandTaskGate) + { + _acceptingCommands = false; + } + } + + private async Task WaitForActiveCommandTasksAsync( + TimeSpan timeout, + CancellationToken cancellationToken) + { + Task[] activeTasks; + lock (_commandTaskGate) + { + activeTasks = new List(_activeCommandTasks).ToArray(); + } + + if (activeTasks.Length == 0) + { + return; + } + + Task activeCommandsTask = Task.WhenAll(activeTasks); + Task timeoutTask = Task.Delay(timeout, cancellationToken); + Task completedTask = await Task.WhenAny(activeCommandsTask, timeoutTask).ConfigureAwait(false); + if (completedTask == activeCommandsTask) + { + await activeCommandsTask.ConfigureAwait(false); + return; + } + + cancellationToken.ThrowIfCancellationRequested(); + throw new TimeoutException($"Worker command tasks did not stop within {timeout}."); + } + private Task WriteShutdownAckAsync( WorkerShutdownAck shutdownAck, CancellationToken cancellationToken) diff --git a/src/MxGateway.Worker/MxAccess/MxAccessEventQueue.cs b/src/MxGateway.Worker/MxAccess/MxAccessEventQueue.cs index 9a2d19f..184abd9 100644 --- a/src/MxGateway.Worker/MxAccess/MxAccessEventQueue.cs +++ b/src/MxGateway.Worker/MxAccess/MxAccessEventQueue.cs @@ -80,7 +80,7 @@ public sealed class MxAccessEventQueue } } - public WorkerEvent Enqueue(MxEvent mxEvent) + public void Enqueue(MxEvent mxEvent) { if (mxEvent is null) { @@ -109,8 +109,6 @@ public sealed class MxAccessEventQueue Event = queuedEvent, }; events.Enqueue(workerEvent); - - return workerEvent.Clone(); } } @@ -124,7 +122,7 @@ public sealed class MxAccessEventQueue return false; } - workerEvent = events.Dequeue().Clone(); + workerEvent = events.Dequeue(); return true; } } @@ -144,7 +142,7 @@ public sealed class MxAccessEventQueue List drained = new(drainCount); for (int index = 0; index < drainCount; index++) { - drained.Add(events.Dequeue().Clone()); + drained.Add(events.Dequeue()); } return drained;