Improve gateway reliability and client e2e coverage

This commit is contained in:
Joseph Doherty
2026-04-28 06:11:18 -04:00
parent 4fc355b357
commit 907aa49aea
25 changed files with 1153 additions and 83 deletions
@@ -86,6 +86,10 @@ public static class MxGatewayClientCli
.ConfigureAwait(false),
"advise" => await AdviseAsync(arguments, client, standardOutput, cancellation.Token)
.ConfigureAwait(false),
"subscribe-bulk" => await SubscribeBulkAsync(arguments, client, standardOutput, cancellation.Token)
.ConfigureAwait(false),
"unsubscribe-bulk" => await UnsubscribeBulkAsync(arguments, client, standardOutput, cancellation.Token)
.ConfigureAwait(false),
"stream-events" => await StreamEventsAsync(arguments, client, standardOutput, cancellation.Token)
.ConfigureAwait(false),
"write" => await WriteAsync(arguments, client, standardOutput, cancellation.Token)
@@ -289,6 +293,54 @@ public static class MxGatewayClientCli
cancellationToken);
}
private static Task<int> SubscribeBulkAsync(
CliArguments arguments,
IMxGatewayCliClient client,
TextWriter output,
CancellationToken cancellationToken)
{
SubscribeBulkCommand command = new()
{
ServerHandle = arguments.GetInt32("server-handle"),
};
command.TagAddresses.Add(ParseStringList(arguments.GetRequired("items")));
return InvokeAndWriteAsync(
arguments,
client,
output,
new MxCommand
{
Kind = MxCommandKind.SubscribeBulk,
SubscribeBulk = command,
},
cancellationToken);
}
private static Task<int> UnsubscribeBulkAsync(
CliArguments arguments,
IMxGatewayCliClient client,
TextWriter output,
CancellationToken cancellationToken)
{
UnsubscribeBulkCommand command = new()
{
ServerHandle = arguments.GetInt32("server-handle"),
};
command.ItemHandles.Add(ParseInt32List(arguments.GetRequired("item-handles")));
return InvokeAndWriteAsync(
arguments,
client,
output,
new MxCommand
{
Kind = MxCommandKind.UnsubscribeBulk,
UnsubscribeBulk = command,
},
cancellationToken);
}
private static Task<int> WriteAsync(
CliArguments arguments,
IMxGatewayCliClient client,
@@ -736,12 +788,40 @@ public static class MxGatewayClientCli
or "register"
or "add-item"
or "advise"
or "subscribe-bulk"
or "unsubscribe-bulk"
or "stream-events"
or "write"
or "write2"
or "smoke";
}
private static IReadOnlyList<string> ParseStringList(string value)
{
string[] items = value
.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
if (items.Length is 0)
{
throw new ArgumentException("At least one item is required.");
}
return items;
}
private static IReadOnlyList<int> ParseInt32List(string value)
{
string[] items = value
.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
if (items.Length is 0)
{
throw new ArgumentException("At least one item handle is required.");
}
return items
.Select(item => int.Parse(item, CultureInfo.InvariantCulture))
.ToArray();
}
private static string CreateCorrelationId()
{
return Guid.NewGuid().ToString("N");
@@ -756,6 +836,8 @@ public static class MxGatewayClientCli
writer.WriteLine("mxgw-dotnet register --session-id <id> --client-name <name> [--json]");
writer.WriteLine("mxgw-dotnet add-item --session-id <id> --server-handle <n> --item <ref> [--json]");
writer.WriteLine("mxgw-dotnet advise --session-id <id> --server-handle <n> --item-handle <n> [--json]");
writer.WriteLine("mxgw-dotnet subscribe-bulk --session-id <id> --server-handle <n> --items <ref,ref> [--json]");
writer.WriteLine("mxgw-dotnet unsubscribe-bulk --session-id <id> --server-handle <n> --item-handles <n,n> [--json]");
writer.WriteLine("mxgw-dotnet stream-events --session-id <id> [--max-events <n>] [--json]");
writer.WriteLine("mxgw-dotnet write --session-id <id> --server-handle <n> --item-handle <n> --type <type> --value <value> [--json]");
writer.WriteLine("mxgw-dotnet write2 --session-id <id> --server-handle <n> --item-handle <n> --type <type> --value <value> [--timestamp <iso>] [--json]");
+7 -4
View File
@@ -76,10 +76,13 @@ client, err := mxgateway.Dial(ctx, mxgateway.Options{
```
`Client.OpenSession` returns a `Session` with helpers for `Register`,
`AddItem`, `AddItem2`, `Advise`, `Write`, `Events`, and `Close`. Raw protobuf
messages remain available through the `mxgateway` package aliases and the
`Raw` helper methods. Typed errors support `errors.As` for `GatewayError`,
`CommandError`, and `MxAccessError`; command errors preserve the raw reply.
`AddItem`, `AddItem2`, `Advise`, `Write`, `Events`, and `Close`. Prefer
`SubscribeEvents` or `SubscribeEventsAfter` for long-running streams because the
returned subscription owns cancellation and exposes `Close` for deterministic
goroutine cleanup. Raw protobuf messages remain available through the
`mxgateway` package aliases and the `Raw` helper methods. Typed errors support
`errors.As` for `GatewayError`, `CommandError`, and `MxAccessError`; command
errors preserve the raw reply.
## CLI
+107 -2
View File
@@ -9,6 +9,7 @@ import (
"io"
"os"
"strconv"
"strings"
"time"
"gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/mxgateway"
@@ -77,6 +78,10 @@ func runWithIO(ctx context.Context, args []string, stdout, stderr io.Writer) err
return runAddItem(ctx, args[1:], stdout, stderr)
case "advise":
return runAdvise(ctx, args[1:], stdout, stderr)
case "subscribe-bulk":
return runSubscribeBulk(ctx, args[1:], stdout, stderr)
case "unsubscribe-bulk":
return runUnsubscribeBulk(ctx, args[1:], stdout, stderr)
case "write":
return runWrite(ctx, args[1:], stdout, stderr)
case "stream-events":
@@ -268,6 +273,60 @@ func runAdvise(ctx context.Context, args []string, stdout, stderr io.Writer) err
return writeCommandOutput(stdout, *jsonOutput, "advise", options, reply, err)
}
func runSubscribeBulk(ctx context.Context, args []string, stdout, stderr io.Writer) error {
flags := flag.NewFlagSet("subscribe-bulk", flag.ContinueOnError)
flags.SetOutput(stderr)
common := bindCommonFlags(flags)
jsonOutput := flags.Bool("json", false, "write JSON output")
sessionID := flags.String("session-id", "", "gateway session id")
serverHandle := flags.Int("server-handle", 0, "MXAccess server handle")
items := flags.String("items", "", "comma-separated item definitions")
if err := flags.Parse(args); err != nil {
return err
}
if *sessionID == "" || *items == "" {
return errors.New("session-id and items are required")
}
client, options, err := dialForCommand(ctx, common)
if err != nil {
return err
}
defer client.Close()
session := mxgateway.NewSessionForID(client, *sessionID)
results, err := session.SubscribeBulk(ctx, int32(*serverHandle), parseStringList(*items))
return writeBulkOutput(stdout, *jsonOutput, "subscribe-bulk", options, results, err)
}
func runUnsubscribeBulk(ctx context.Context, args []string, stdout, stderr io.Writer) error {
flags := flag.NewFlagSet("unsubscribe-bulk", flag.ContinueOnError)
flags.SetOutput(stderr)
common := bindCommonFlags(flags)
jsonOutput := flags.Bool("json", false, "write JSON output")
sessionID := flags.String("session-id", "", "gateway session id")
serverHandle := flags.Int("server-handle", 0, "MXAccess server handle")
itemHandles := flags.String("item-handles", "", "comma-separated item handles")
if err := flags.Parse(args); err != nil {
return err
}
if *sessionID == "" || *itemHandles == "" {
return errors.New("session-id and item-handles are required")
}
client, options, err := dialForCommand(ctx, common)
if err != nil {
return err
}
defer client.Close()
session := mxgateway.NewSessionForID(client, *sessionID)
results, err := session.UnsubscribeBulk(ctx, int32(*serverHandle), parseInt32List(*itemHandles))
return writeBulkOutput(stdout, *jsonOutput, "unsubscribe-bulk", options, results, err)
}
func runWrite(ctx context.Context, args []string, stdout, stderr io.Writer) error {
flags := flag.NewFlagSet("write", flag.ContinueOnError)
flags.SetOutput(stderr)
@@ -328,10 +387,12 @@ func runStreamEvents(ctx context.Context, args []string, stdout, stderr io.Write
session := mxgateway.NewSessionForID(client, *sessionID)
streamCtx, cancelStream := context.WithCancel(ctx)
defer cancelStream()
events, err := session.EventsAfter(streamCtx, *after)
subscription, err := session.SubscribeEventsAfter(streamCtx, *after)
if err != nil {
return err
}
defer subscription.Close()
events := subscription.Events()
count := 0
for result := range events {
@@ -426,6 +487,35 @@ func closeSmokeSession(ctx context.Context, session *mxgateway.Session, primaryE
return closeErr
}
func parseStringList(value string) []string {
parts := strings.Split(value, ",")
items := make([]string, 0, len(parts))
for _, part := range parts {
item := strings.TrimSpace(part)
if item != "" {
items = append(items, item)
}
}
return items
}
func parseInt32List(value string) []int32 {
parts := strings.Split(value, ",")
items := make([]int32, 0, len(parts))
for _, part := range parts {
item := strings.TrimSpace(part)
if item == "" {
continue
}
parsed, err := strconv.ParseInt(item, 10, 32)
if err != nil {
panic(err)
}
items = append(items, int32(parsed))
}
return items
}
func bindCommonFlags(flags *flag.FlagSet) *commonOptions {
common := &commonOptions{}
flags.StringVar(&common.Endpoint, "endpoint", "localhost:5000", "gateway endpoint")
@@ -527,6 +617,21 @@ func writeCommandOutput(stdout io.Writer, jsonOutput bool, command string, optio
return nil
}
func writeBulkOutput(stdout io.Writer, jsonOutput bool, command string, options commonOptions, results []*mxgateway.SubscribeResult, err error) error {
if err != nil {
return err
}
if jsonOutput {
return writeJSON(stdout, map[string]any{
"command": command,
"options": options,
"results": results,
})
}
fmt.Fprintln(stdout, len(results))
return nil
}
func writeJSON(writer io.Writer, value any) error {
encoder := json.NewEncoder(writer)
encoder.SetIndent("", " ")
@@ -546,5 +651,5 @@ type protojsonMessage interface {
}
func writeUsage(writer io.Writer) {
fmt.Fprintln(writer, "usage: mxgw-go <version|open-session|close-session|register|add-item|advise|write|stream-events|smoke>")
fmt.Fprintln(writer, "usage: mxgw-go <version|open-session|close-session|register|add-item|advise|subscribe-bulk|unsubscribe-bulk|write|stream-events|smoke>")
}
@@ -77,6 +77,42 @@ func TestStreamEventsAttachesAuthMetadataAndClosesOnCancellation(t *testing.T) {
}
}
func TestEventSubscriptionCloseStopsStream(t *testing.T) {
fake := &fakeGatewayServer{
streamStarted: make(chan struct{}),
streamDone: make(chan struct{}),
}
client, cleanup := newBufconnClient(t, fake)
defer cleanup()
session := NewSessionForID(client, "session-1")
subscription, err := session.SubscribeEvents(context.Background())
if err != nil {
t.Fatalf("SubscribeEvents() error = %v", err)
}
<-fake.streamStarted
first := <-subscription.Events()
if first.Err != nil {
t.Fatalf("first event error = %v", first.Err)
}
subscription.Close()
select {
case <-fake.streamDone:
case <-time.After(2 * time.Second):
t.Fatal("event stream did not stop after subscription close")
}
select {
case _, ok := <-subscription.Events():
if ok {
t.Fatal("subscription channel remained open after close")
}
case <-time.After(2 * time.Second):
t.Fatal("subscription channel did not close")
}
}
func TestSessionHelpersBuildCommandsAndExposeRawReply(t *testing.T) {
fake := &fakeGatewayServer{
invokeReply: &pb.MxCommandReply{
@@ -235,6 +271,7 @@ type fakeGatewayServer struct {
openAuth string
streamAuth string
streamStarted chan struct{}
streamDone chan struct{}
invokeReply *pb.MxCommandReply
invokeRequest *pb.MxCommandRequest
}
@@ -277,6 +314,9 @@ func (s *fakeGatewayServer) Invoke(ctx context.Context, req *pb.MxCommandRequest
func (s *fakeGatewayServer) StreamEvents(req *pb.StreamEventsRequest, stream grpc.ServerStreamingServer[pb.MxEvent]) error {
s.streamAuth = authorizationFromContext(stream.Context())
if s.streamDone != nil {
defer close(s.streamDone)
}
if s.streamStarted != nil {
close(s.streamStarted)
}
+51 -5
View File
@@ -22,6 +22,30 @@ type EventResult struct {
Err error
}
// EventSubscription owns a running gateway event stream.
type EventSubscription struct {
results <-chan EventResult
cancel context.CancelFunc
done <-chan struct{}
once sync.Once
}
// Events returns the stream results channel.
func (s *EventSubscription) Events() <-chan EventResult {
return s.results
}
// Close cancels the stream and waits for the receive goroutine to stop.
func (s *EventSubscription) Close() {
if s == nil {
return
}
s.once.Do(func() {
s.cancel()
<-s.done
})
}
// Session represents one gateway-backed MXAccess session.
type Session struct {
client *Client
@@ -394,34 +418,56 @@ func (s *Session) Events(ctx context.Context) (<-chan EventResult, error) {
// EventsAfter streams ordered session events after the given worker sequence.
func (s *Session) EventsAfter(ctx context.Context, afterWorkerSequence uint64) (<-chan EventResult, error) {
stream, err := s.client.StreamEventsRaw(ctx, &pb.StreamEventsRequest{
subscription, err := s.SubscribeEventsAfter(ctx, afterWorkerSequence)
if err != nil {
return nil, err
}
return subscription.Events(), nil
}
// SubscribeEvents starts an owned event subscription.
func (s *Session) SubscribeEvents(ctx context.Context) (*EventSubscription, error) {
return s.SubscribeEventsAfter(ctx, 0)
}
// SubscribeEventsAfter starts an owned event subscription after the given worker sequence.
func (s *Session) SubscribeEventsAfter(ctx context.Context, afterWorkerSequence uint64) (*EventSubscription, error) {
streamCtx, cancel := context.WithCancel(ctx)
stream, err := s.client.StreamEventsRaw(streamCtx, &pb.StreamEventsRequest{
SessionId: s.ID(),
AfterWorkerSequence: afterWorkerSequence,
})
if err != nil {
cancel()
return nil, err
}
results := make(chan EventResult, 16)
done := make(chan struct{})
go func() {
defer close(results)
defer close(done)
for {
event, err := stream.Recv()
if err == nil {
if !sendEventResult(ctx, results, EventResult{Event: event}) {
if !sendEventResult(streamCtx, results, EventResult{Event: event}) {
return
}
continue
}
if err == io.EOF || status.Code(err) == codes.Canceled || ctx.Err() != nil {
if err == io.EOF || status.Code(err) == codes.Canceled || streamCtx.Err() != nil {
return
}
sendEventResult(ctx, results, EventResult{Err: &GatewayError{Op: "stream events", Err: err}})
sendEventResult(streamCtx, results, EventResult{Err: &GatewayError{Op: "stream events", Err: err}})
return
}
}()
return results, nil
return &EventSubscription{
results: results,
cancel: cancel,
done: done,
}, nil
}
func ensureBulkSize(name string, length int) error {
@@ -12,7 +12,9 @@ import com.google.protobuf.util.JsonFormat;
import java.io.PrintWriter;
import java.nio.file.Path;
import java.time.Duration;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import mxaccess_gateway.v1.MxaccessGateway.CloseSessionRequest;
@@ -20,6 +22,7 @@ import mxaccess_gateway.v1.MxaccessGateway.MxCommandReply;
import mxaccess_gateway.v1.MxaccessGateway.MxEvent;
import mxaccess_gateway.v1.MxaccessGateway.MxValue;
import mxaccess_gateway.v1.MxaccessGateway.OpenSessionRequest;
import mxaccess_gateway.v1.MxaccessGateway.SubscribeResult;
import picocli.CommandLine;
import picocli.CommandLine.Command;
import picocli.CommandLine.Mixin;
@@ -75,6 +78,8 @@ public final class MxGatewayCli implements Callable<Integer> {
commandLine.addSubcommand("register", new RegisterCommand(clientFactory));
commandLine.addSubcommand("add-item", new AddItemCommand(clientFactory));
commandLine.addSubcommand("advise", new AdviseCommand(clientFactory));
commandLine.addSubcommand("subscribe-bulk", new SubscribeBulkCommand(clientFactory));
commandLine.addSubcommand("unsubscribe-bulk", new UnsubscribeBulkCommand(clientFactory));
commandLine.addSubcommand("write", new WriteCommand(clientFactory));
commandLine.addSubcommand("stream-events", new StreamEventsCommand(clientFactory));
commandLine.addSubcommand("smoke", new SmokeCommand(clientFactory));
@@ -243,6 +248,58 @@ public final class MxGatewayCli implements Callable<Integer> {
}
}
@Command(name = "subscribe-bulk", description = "Invokes MXAccess SubscribeBulk.")
static final class SubscribeBulkCommand extends GatewayCommand {
@Option(names = "--session-id", required = true, description = "Gateway session id.")
String sessionId;
@Option(names = "--server-handle", required = true, description = "MXAccess server handle.")
int serverHandle;
@Option(names = "--items", required = true, description = "Comma-separated item definitions.")
String items;
SubscribeBulkCommand(MxGatewayCliClientFactory clientFactory) {
super(clientFactory);
}
@Override
public Integer call() {
try (MxGatewayCliClient client = clientFactory.connect(common.resolved())) {
List<SubscribeResult> results =
client.session(sessionId).subscribeBulk(serverHandle, parseStringList(items));
writeBulkOutput("subscribe-bulk", common, json, results);
}
return 0;
}
}
@Command(name = "unsubscribe-bulk", description = "Invokes MXAccess UnsubscribeBulk.")
static final class UnsubscribeBulkCommand extends GatewayCommand {
@Option(names = "--session-id", required = true, description = "Gateway session id.")
String sessionId;
@Option(names = "--server-handle", required = true, description = "MXAccess server handle.")
int serverHandle;
@Option(names = "--item-handles", required = true, description = "Comma-separated item handles.")
String itemHandles;
UnsubscribeBulkCommand(MxGatewayCliClientFactory clientFactory) {
super(clientFactory);
}
@Override
public Integer call() {
try (MxGatewayCliClient client = clientFactory.connect(common.resolved())) {
List<SubscribeResult> results =
client.session(sessionId).unsubscribeBulk(serverHandle, parseIntList(itemHandles));
writeBulkOutput("unsubscribe-bulk", common, json, results);
}
return 0;
}
}
@Command(name = "write", description = "Invokes MXAccess Write.")
static final class WriteCommand extends GatewayCommand {
@Option(names = "--session-id", required = true, description = "Gateway session id.")
@@ -454,6 +511,10 @@ public final class MxGatewayCli implements Callable<Integer> {
MxCommandReply writeRaw(int serverHandle, int itemHandle, MxValue value, int userId);
List<SubscribeResult> subscribeBulk(int serverHandle, List<String> items);
List<SubscribeResult> unsubscribeBulk(int serverHandle, List<Integer> itemHandles);
MxEventStream streamEventsAfter(long afterWorkerSequence);
}
@@ -535,6 +596,16 @@ public final class MxGatewayCli implements Callable<Integer> {
return session.writeRaw(serverHandle, itemHandle, value, userId);
}
@Override
public List<SubscribeResult> subscribeBulk(int serverHandle, List<String> items) {
return session.subscribeBulk(serverHandle, items);
}
@Override
public List<SubscribeResult> unsubscribeBulk(int serverHandle, List<Integer> itemHandles) {
return session.unsubscribeBulk(serverHandle, itemHandles);
}
@Override
public MxEventStream streamEventsAfter(long afterWorkerSequence) {
return session.streamEventsAfter(afterWorkerSequence);
@@ -559,6 +630,30 @@ public final class MxGatewayCli implements Callable<Integer> {
out.println(textSupplier.get());
}
private static void writeBulkOutput(
String command, CommonOptions common, boolean json, List<SubscribeResult> results) {
PrintWriter out = common.spec.commandLine().getOut();
if (json) {
Map<String, Object> output = new LinkedHashMap<>();
output.put("command", command);
output.put("options", common.redactedJsonMap());
output.put("results", results.stream().map(MxGatewayCli::subscribeResultMap).toList());
out.println(jsonObject(output));
return;
}
out.println(results.size());
}
private static Map<String, Object> subscribeResultMap(SubscribeResult result) {
Map<String, Object> values = new LinkedHashMap<>();
values.put("serverHandle", result.getServerHandle());
values.put("tagAddress", result.getTagAddress());
values.put("itemHandle", result.getItemHandle());
values.put("wasSuccessful", result.getWasSuccessful());
values.put("errorMessage", result.getErrorMessage());
return values;
}
private static MxValue parseValue(String type, String text) {
return switch (type) {
case "bool" -> MxValues.boolValue(Boolean.parseBoolean(text));
@@ -571,6 +666,17 @@ public final class MxGatewayCli implements Callable<Integer> {
};
}
private static List<String> parseStringList(String value) {
return Arrays.stream(value.split(","))
.map(String::trim)
.filter(item -> !item.isBlank())
.toList();
}
private static List<Integer> parseIntList(String value) {
return parseStringList(value).stream().map(Integer::parseInt).toList();
}
private static Duration parseDuration(String value) {
if (value == null || value.isBlank()) {
return Duration.ofSeconds(30);
@@ -630,6 +736,20 @@ public final class MxGatewayCli implements Callable<Integer> {
if (value instanceof Map<?, ?> map) {
return jsonObject((Map<String, Object>) map);
}
if (value instanceof Iterable<?> iterable) {
StringBuilder builder = new StringBuilder();
builder.append('[');
boolean first = true;
for (Object item : iterable) {
if (!first) {
builder.append(',');
}
first = false;
builder.append(jsonValue(item));
}
builder.append(']');
return builder.toString();
}
return jsonString(value.toString());
}
@@ -6,6 +6,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.List;
import mxaccess_gateway.v1.MxaccessGateway.AddItemReply;
import mxaccess_gateway.v1.MxaccessGateway.CloseSessionReply;
import mxaccess_gateway.v1.MxaccessGateway.CloseSessionRequest;
@@ -19,6 +21,7 @@ import mxaccess_gateway.v1.MxaccessGateway.ProtocolStatus;
import mxaccess_gateway.v1.MxaccessGateway.ProtocolStatusCode;
import mxaccess_gateway.v1.MxaccessGateway.RegisterReply;
import mxaccess_gateway.v1.MxaccessGateway.SessionState;
import mxaccess_gateway.v1.MxaccessGateway.SubscribeResult;
import org.junit.jupiter.api.Test;
final class MxGatewayCliTests {
@@ -100,6 +103,44 @@ final class MxGatewayCliTests {
assertTrue(run.output().contains("\"itemHandle\":7"));
}
@Test
void subscribeBulkCommandPrintsResults() {
CliRun run = execute(
new FakeClientFactory(),
"subscribe-bulk",
"--session-id",
"session-cli",
"--server-handle",
"42",
"--items",
"TestMachine_001.TestChangingInt,TestMachine_002.TestChangingInt",
"--json");
assertEquals(0, run.exitCode());
assertTrue(run.output().contains("\"command\":\"subscribe-bulk\""));
assertTrue(run.output().contains("\"itemHandle\":100"));
assertTrue(run.output().contains("\"tagAddress\":\"TestMachine_002.TestChangingInt\""));
}
@Test
void unsubscribeBulkCommandPrintsResults() {
CliRun run = execute(
new FakeClientFactory(),
"unsubscribe-bulk",
"--session-id",
"session-cli",
"--server-handle",
"42",
"--item-handles",
"100,101",
"--json");
assertEquals(0, run.exitCode());
assertTrue(run.output().contains("\"command\":\"unsubscribe-bulk\""));
assertTrue(run.output().contains("\"itemHandle\":101"));
assertTrue(run.output().contains("\"wasSuccessful\":true"));
}
private static CliRun execute(MxGatewayCli.MxGatewayCliClientFactory factory, String... args) {
StringWriter output = new StringWriter();
StringWriter errors = new StringWriter();
@@ -227,6 +268,33 @@ final class MxGatewayCliTests {
.build();
}
@Override
public List<SubscribeResult> subscribeBulk(int serverHandle, List<String> items) {
List<SubscribeResult> results = new ArrayList<>();
for (int index = 0; index < items.size(); index++) {
results.add(SubscribeResult.newBuilder()
.setServerHandle(serverHandle)
.setTagAddress(items.get(index))
.setItemHandle(100 + index)
.setWasSuccessful(true)
.build());
}
return results;
}
@Override
public List<SubscribeResult> unsubscribeBulk(int serverHandle, List<Integer> itemHandles) {
List<SubscribeResult> results = new ArrayList<>();
for (Integer itemHandle : itemHandles) {
results.add(SubscribeResult.newBuilder()
.setServerHandle(serverHandle)
.setItemHandle(itemHandle)
.setWasSuccessful(true)
.build());
}
return results;
}
@Override
public com.dohertylan.mxgateway.client.MxEventStream streamEventsAfter(long afterWorkerSequence) {
throw new UnsupportedOperationException("stream-events is covered by client tests");
@@ -150,6 +150,40 @@ def advise(**kwargs: Any) -> None:
_run(_advise(**kwargs), output_json=kwargs["output_json"], secrets=_secrets(kwargs))
@main.command("subscribe-bulk")
@gateway_options
@click.option("--session-id", required=True, help="Gateway session id.")
@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.")
@click.option("--items", required=True, help="Comma-separated MXAccess item definitions.")
@click.option("--correlation-id", default="", help="Client correlation id.")
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
def subscribe_bulk(**kwargs: Any) -> None:
"""Invoke MXAccess SubscribeBulk."""
_run(
_subscribe_bulk(**kwargs),
output_json=kwargs["output_json"],
secrets=_secrets(kwargs),
)
@main.command("unsubscribe-bulk")
@gateway_options
@click.option("--session-id", required=True, help="Gateway session id.")
@click.option("--server-handle", required=True, type=int, help="MXAccess server handle.")
@click.option("--item-handles", required=True, help="Comma-separated MXAccess item handles.")
@click.option("--correlation-id", default="", help="Client correlation id.")
@click.option("--json", "output_json", is_flag=True, help="Emit JSON output.")
def unsubscribe_bulk(**kwargs: Any) -> None:
"""Invoke MXAccess UnsubscribeBulk."""
_run(
_unsubscribe_bulk(**kwargs),
output_json=kwargs["output_json"],
secrets=_secrets(kwargs),
)
@main.command("stream-events")
@gateway_options
@click.option("--session-id", required=True, help="Gateway session id.")
@@ -282,6 +316,28 @@ async def _advise(**kwargs: Any) -> dict[str, Any]:
return {"ok": True}
async def _subscribe_bulk(**kwargs: Any) -> dict[str, Any]:
async with await _connect(kwargs) as client:
session = _session(client, kwargs["session_id"])
results = await session.subscribe_bulk(
kwargs["server_handle"],
_parse_string_list(kwargs["items"]),
correlation_id=kwargs["correlation_id"],
)
return {"results": [_message_dict(result) for result in results]}
async def _unsubscribe_bulk(**kwargs: Any) -> dict[str, Any]:
async with await _connect(kwargs) as client:
session = _session(client, kwargs["session_id"])
results = await session.unsubscribe_bulk(
kwargs["server_handle"],
_parse_int_list(kwargs["item_handles"]),
correlation_id=kwargs["correlation_id"],
)
return {"results": [_message_dict(result) for result in results]}
async def _stream_events(**kwargs: Any) -> dict[str, Any]:
async with await _connect(kwargs) as client:
session = _session(client, kwargs["session_id"])
@@ -470,6 +526,20 @@ def _parse_datetime(raw_value: str) -> datetime:
return parsed
def _parse_string_list(raw_value: str) -> list[str]:
values = [item.strip() for item in raw_value.split(",") if item.strip()]
if not values:
raise click.BadParameter("at least one item is required", param_hint="--items")
return values
def _parse_int_list(raw_value: str) -> list[int]:
values = [item.strip() for item in raw_value.split(",") if item.strip()]
if not values:
raise click.BadParameter("at least one item handle is required", param_hint="--item-handles")
return [int(item) for item in values]
def _message_dict(message: Any) -> dict[str, Any]:
return MessageToDict(
message,
+78 -1
View File
@@ -92,6 +92,30 @@ enum Command {
#[arg(long)]
json: bool,
},
SubscribeBulk {
#[command(flatten)]
connection: ConnectionArgs,
#[arg(long)]
session_id: String,
#[arg(long)]
server_handle: i32,
#[arg(long, value_delimiter = ',')]
items: Vec<String>,
#[arg(long)]
json: bool,
},
UnsubscribeBulk {
#[command(flatten)]
connection: ConnectionArgs,
#[arg(long)]
session_id: String,
#[arg(long)]
server_handle: i32,
#[arg(long, value_delimiter = ',')]
item_handles: Vec<i32>,
#[arg(long)]
json: bool,
},
StreamEvents {
#[command(flatten)]
connection: ConnectionArgs,
@@ -103,6 +127,8 @@ enum Command {
max_events: usize,
#[arg(long)]
json: bool,
#[arg(long)]
jsonl: bool,
},
Write {
#[command(flatten)]
@@ -226,7 +252,7 @@ async fn main() -> ExitCode {
async fn run(cli: Cli) -> Result<(), Error> {
match cli.command {
Command::Version { json } => print_version(json),
Command::Version { json, .. } => print_version(json),
Command::Ping {
connection,
message,
@@ -323,6 +349,30 @@ async fn run(cli: Cli) -> Result<(), Error> {
session.advise(server_handle, item_handle).await?;
print_ok("advise", json);
}
Command::SubscribeBulk {
connection,
session_id,
server_handle,
items,
json,
} => {
let session = session_for(connection, session_id).await?;
let results = session.subscribe_bulk(server_handle, items).await?;
print_bulk_results("subscribe-bulk", &results, json);
}
Command::UnsubscribeBulk {
connection,
session_id,
server_handle,
item_handles,
json,
} => {
let session = session_for(connection, session_id).await?;
let results = session
.unsubscribe_bulk(server_handle, item_handles)
.await?;
print_bulk_results("unsubscribe-bulk", &results, json);
}
Command::StreamEvents {
connection,
session_id,
@@ -527,6 +577,33 @@ fn print_ok(operation: &str, use_json: bool) {
}
}
fn print_bulk_results(
operation: &str,
results: &[mxgateway_client::generated::mxaccess_gateway::v1::SubscribeResult],
use_json: bool,
) {
if use_json {
let results_json: Vec<_> = results
.iter()
.map(|result| {
json!({
"serverHandle": result.server_handle,
"tagAddress": result.tag_address,
"itemHandle": result.item_handle,
"wasSuccessful": result.was_successful,
"errorMessage": result.error_message,
})
})
.collect();
println!(
"{}",
json!({ "operation": operation, "results": results_json })
);
} else {
println!("{}", results.len());
}
}
fn parse_value(value_type: CliValueType, value: &str) -> Result<MxValue, Error> {
let parsed = match value_type {
CliValueType::Bool => MxValue::bool(parse_cli_value(value)?),
+3 -2
View File
@@ -113,11 +113,12 @@ ordering and avoids competing consumers.
| Option | Default | Description |
|--------|---------|-------------|
| `MxGateway:Events:QueueCapacity` | `10000` | Capacity for bounded per-session event queues used by the gateway worker event channel and the public gRPC event stream queue. |
| `MxGateway:Events:BackpressurePolicy` | `FailFast` | Event backpressure behavior. `FailFast` is the only supported value. |
| `MxGateway:Events:BackpressurePolicy` | `FailFast` | Event backpressure behavior. `FailFast` faults the session on public stream queue overflow. `DisconnectSubscriber` disconnects only the slow stream. |
`QueueCapacity` must be greater than zero. With `FailFast`, queue overflow
faults the affected worker or session instead of silently dropping MXAccess
events.
events. With `DisconnectSubscriber`, public gRPC stream overflow terminates only
the affected stream while the MXAccess session remains active.
## Dashboard Options
+6 -3
View File
@@ -101,9 +101,10 @@ powershell -ExecutionPolicy Bypass -File scripts/discover-testmachine-tags.ps1 -
`scripts/run-client-e2e-tests.ps1` drives the .NET, Go, Rust, Python, and Java
client CLIs through a live gateway session. For each client it opens one
session, registers, adds and advises every discovered test tag, reads a bounded
event stream, then closes the session in a `finally` path. The script writes a
JSON report under `artifacts/e2e/`.
session, registers, verifies `SubscribeBulk` and `UnsubscribeBulk` on a bounded
tag subset, adds and advises every discovered test tag, reads a bounded event
stream, then closes the session in a `finally` path. The script writes a JSON
report under `artifacts/e2e/`.
Build the gateway and worker, start the gateway, and provide a valid API key
before running the client e2e script:
@@ -117,7 +118,9 @@ Useful runner options:
```powershell
powershell -ExecutionPolicy Bypass -File scripts/run-client-e2e-tests.ps1 -Clients dotnet,python -MachineStart 1 -MachineEnd 2
powershell -ExecutionPolicy Bypass -File scripts/run-client-e2e-tests.ps1 -BulkTagCount 10
powershell -ExecutionPolicy Bypass -File scripts/run-client-e2e-tests.ps1 -SkipStream
powershell -ExecutionPolicy Bypass -File scripts/run-client-e2e-tests.ps1 -SkipBulk
powershell -ExecutionPolicy Bypass -File scripts/run-client-e2e-tests.ps1 -Endpoint localhost:5000 -ApiKeyEnv MXGATEWAY_API_KEY
```
+140
View File
@@ -16,7 +16,9 @@ param(
[string]$SqlServer = "localhost",
[string]$Database = "ZB",
[int]$EventLimit = 5,
[int]$BulkTagCount = 6,
[switch]$SkipStream,
[switch]$SkipBulk,
[switch]$DryRun,
[string]$ReportPath
)
@@ -50,6 +52,10 @@ if ($Attributes.Count -eq 0) {
throw "At least one attribute is required."
}
if ($BulkTagCount -lt 1) {
throw "BulkTagCount must be greater than zero."
}
foreach ($client in $Clients) {
if ($validClients -notcontains $client) {
throw "Unsupported client '$client'. Supported clients: $($validClients -join ', ')."
@@ -237,6 +243,74 @@ function Get-StreamEventCount {
}
}
function Get-PropertyValue {
param(
[object]$Object,
[string[]]$Names
)
if ($null -eq $Object) {
return $null
}
foreach ($name in $Names) {
$property = $Object.PSObject.Properties[$name]
if ($null -ne $property) {
return $property.Value
}
}
return $null
}
function Get-BulkResults {
param(
[string]$Client,
[string]$Operation,
[object]$Json
)
if ($Client -in @("go", "rust", "python", "java")) {
return @(Get-PropertyValue -Object $Json -Names @("results"))
}
$replyName = if ($Operation -eq "subscribe-bulk") { "subscribeBulk" } else { "unsubscribeBulk" }
$reply = Get-PropertyValue -Object $Json -Names @($replyName)
return @(Get-PropertyValue -Object $reply -Names @("results"))
}
function Get-BulkItemHandles {
param([object[]]$Results)
return @($Results | ForEach-Object {
[int](Get-PropertyValue -Object $_ -Names @("itemHandle", "item_handle"))
} | Where-Object {
$_ -gt 0
})
}
function Assert-BulkResults {
param(
[string]$Client,
[string]$Operation,
[object[]]$Results,
[int]$ExpectedCount
)
if ($Results.Count -ne $ExpectedCount) {
throw "$Client $Operation returned $($Results.Count) result(s); expected $ExpectedCount."
}
foreach ($result in $Results) {
$success = Get-PropertyValue -Object $result -Names @("wasSuccessful", "was_successful")
if ($null -ne $success -and -not [bool]$success) {
$tagAddress = Get-PropertyValue -Object $result -Names @("tagAddress", "tag_address")
$errorMessage = Get-PropertyValue -Object $result -Names @("errorMessage", "error_message")
throw "$Client $Operation failed for '$tagAddress': $errorMessage"
}
}
}
function Get-ClientCommand {
param(
[string]$Client,
@@ -266,6 +340,10 @@ function Get-ClientCommand {
$arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item", $Values.item)
} elseif ($Operation -eq "advise") {
$arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handle", "$($Values.itemHandle)")
} elseif ($Operation -eq "subscribe-bulk") {
$arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--items", $Values.items)
} elseif ($Operation -eq "unsubscribe-bulk") {
$arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handles", $Values.itemHandles)
} elseif ($Operation -eq "stream-events") {
$arguments += @("--session-id", $Values.sessionId, "--max-events", "$EventLimit")
} elseif ($Operation -eq "close-session") {
@@ -289,6 +367,10 @@ function Get-ClientCommand {
$arguments += @("-session-id", $Values.sessionId, "-server-handle", "$($Values.serverHandle)", "-item", $Values.item)
} elseif ($Operation -eq "advise") {
$arguments += @("-session-id", $Values.sessionId, "-server-handle", "$($Values.serverHandle)", "-item-handle", "$($Values.itemHandle)")
} elseif ($Operation -eq "subscribe-bulk") {
$arguments += @("-session-id", $Values.sessionId, "-server-handle", "$($Values.serverHandle)", "-items", $Values.items)
} elseif ($Operation -eq "unsubscribe-bulk") {
$arguments += @("-session-id", $Values.sessionId, "-server-handle", "$($Values.serverHandle)", "-item-handles", $Values.itemHandles)
} elseif ($Operation -eq "stream-events") {
$arguments += @("-session-id", $Values.sessionId, "-limit", "$EventLimit")
} elseif ($Operation -eq "close-session") {
@@ -311,6 +393,10 @@ function Get-ClientCommand {
$arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item", $Values.item)
} elseif ($Operation -eq "advise") {
$arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handle", "$($Values.itemHandle)")
} elseif ($Operation -eq "subscribe-bulk") {
$arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--items", $Values.items)
} elseif ($Operation -eq "unsubscribe-bulk") {
$arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handles", $Values.itemHandles)
} elseif ($Operation -eq "stream-events") {
$arguments += @("--session-id", $Values.sessionId, "--max-events", "$EventLimit")
} elseif ($Operation -eq "close-session") {
@@ -334,6 +420,10 @@ function Get-ClientCommand {
$arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item", $Values.item)
} elseif ($Operation -eq "advise") {
$arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handle", "$($Values.itemHandle)")
} elseif ($Operation -eq "subscribe-bulk") {
$arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--items", $Values.items)
} elseif ($Operation -eq "unsubscribe-bulk") {
$arguments += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handles", $Values.itemHandles)
} elseif ($Operation -eq "stream-events") {
$arguments += @("--session-id", $Values.sessionId, "--max-events", "$EventLimit", "--timeout", "15")
} elseif ($Operation -eq "close-session") {
@@ -360,6 +450,10 @@ function Get-ClientCommand {
$cliArgs += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item", $Values.item)
} elseif ($Operation -eq "advise") {
$cliArgs += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handle", "$($Values.itemHandle)")
} elseif ($Operation -eq "subscribe-bulk") {
$cliArgs += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--items", $Values.items)
} elseif ($Operation -eq "unsubscribe-bulk") {
$cliArgs += @("--session-id", $Values.sessionId, "--server-handle", "$($Values.serverHandle)", "--item-handles", $Values.itemHandles)
} elseif ($Operation -eq "stream-events") {
$cliArgs += @("--session-id", $Values.sessionId, "--limit", "$EventLimit")
} elseif ($Operation -eq "close-session") {
@@ -389,6 +483,18 @@ function Invoke-ClientOperation {
"open-session" { return [pscustomobject]@{ sessionId = "dry-run-session-$Client"; reply = [pscustomobject]@{ sessionId = "dry-run-session-$Client" } } }
"register" { return [pscustomobject]@{ serverHandle = 1; register = [pscustomobject]@{ serverHandle = 1 }; reply = [pscustomobject]@{ register = [pscustomobject]@{ serverHandle = 1 } } } }
"add-item" { return [pscustomobject]@{ itemHandle = 1; addItem = [pscustomobject]@{ itemHandle = 1 }; reply = [pscustomobject]@{ addItem = [pscustomobject]@{ itemHandle = 1 } } } }
"subscribe-bulk" {
$results = @($Values.items -split "," | ForEach-Object -Begin { $index = 1 } -Process {
[pscustomobject]@{ itemHandle = $index++; tagAddress = $_; wasSuccessful = $true }
})
return [pscustomobject]@{ subscribeBulk = [pscustomobject]@{ results = $results }; results = $results }
}
"unsubscribe-bulk" {
$results = @($Values.itemHandles -split "," | ForEach-Object {
[pscustomobject]@{ itemHandle = [int]$_; wasSuccessful = $true }
})
return [pscustomobject]@{ unsubscribeBulk = [pscustomobject]@{ results = $results }; results = $results }
}
"stream-events" { return [pscustomobject]@{ eventCount = 1; events = @([pscustomobject]@{ workerSequence = 1 }) } }
default { return [pscustomobject]@{ ok = $true; reply = [pscustomobject]@{} } }
}
@@ -425,7 +531,9 @@ $run = [ordered]@{
machineEnd = $MachineEnd
attributes = $Attributes
eventLimit = $EventLimit
bulkTagCount = $BulkTagCount
skipStream = [bool]$SkipStream
skipBulk = [bool]$SkipBulk
startedAt = (Get-Date).ToUniversalTime().ToString("O")
discoveredTags = $tags
clients = @()
@@ -441,6 +549,7 @@ foreach ($client in $Clients) {
language = $client
sessionId = $null
serverHandle = $null
bulk = $null
addedItems = @()
eventCount = 0
closed = $false
@@ -461,6 +570,37 @@ foreach ($client in $Clients) {
$serverHandle = Get-ServerHandle -Client $client -Json $registerJson
$clientResult.serverHandle = $serverHandle
if (-not $SkipBulk) {
$bulkTags = @($tags | Select-Object -First ([Math]::Min($BulkTagCount, $tags.Count)))
$bulkItems = ($bulkTags | ForEach-Object { $_.fullTagReference }) -join ","
$subscribeBulkJson = Invoke-ClientOperation -Client $client -Operation "subscribe-bulk" -Values @{
sessionId = $sessionId
serverHandle = $serverHandle
items = $bulkItems
}
$subscribeResults = @(Get-BulkResults -Client $client -Operation "subscribe-bulk" -Json $subscribeBulkJson)
Assert-BulkResults -Client $client -Operation "subscribe-bulk" -Results $subscribeResults -ExpectedCount $bulkTags.Count
$bulkItemHandles = @(Get-BulkItemHandles -Results $subscribeResults)
if ($bulkItemHandles.Count -ne $bulkTags.Count) {
throw "$client subscribe-bulk returned $($bulkItemHandles.Count) usable item handle(s); expected $($bulkTags.Count)."
}
$unsubscribeBulkJson = Invoke-ClientOperation -Client $client -Operation "unsubscribe-bulk" -Values @{
sessionId = $sessionId
serverHandle = $serverHandle
itemHandles = $bulkItemHandles -join ","
}
$unsubscribeResults = @(Get-BulkResults -Client $client -Operation "unsubscribe-bulk" -Json $unsubscribeBulkJson)
Assert-BulkResults -Client $client -Operation "unsubscribe-bulk" -Results $unsubscribeResults -ExpectedCount $bulkItemHandles.Count
$clientResult.bulk = [ordered]@{
tagCount = $bulkTags.Count
subscribedCount = $subscribeResults.Count
unsubscribedCount = $unsubscribeResults.Count
itemHandles = $bulkItemHandles
}
}
foreach ($tag in $tags) {
$addJson = Invoke-ClientOperation -Client $client -Operation "add-item" -Values @{
sessionId = $sessionId
@@ -2,5 +2,7 @@ namespace MxGateway.Server.Configuration;
public enum EventBackpressurePolicy
{
FailFast
FailFast,
DisconnectSubscriber
}
@@ -108,9 +108,19 @@ public sealed class EventStreamService(
if (!writer.TryWrite(publicEvent))
{
string message = $"Session {session.SessionId} event stream queue overflowed.";
session.MarkFaulted(message);
metrics.QueueOverflow("grpc-event-stream");
metrics.Fault(SessionManagerErrorCode.EventQueueOverflow.ToString());
if (options.Value.Events.BackpressurePolicy == EventBackpressurePolicy.FailFast)
{
session.MarkFaulted(message);
metrics.Fault(SessionManagerErrorCode.EventQueueOverflow.ToString());
}
else
{
logger.LogDebug(
"Disconnecting event stream for session {SessionId} after queue overflow.",
session.SessionId);
}
writer.TryComplete(new SessionManagerException(
SessionManagerErrorCode.EventQueueOverflow,
message));
+13 -13
View File
@@ -1,3 +1,4 @@
using System.Collections.Concurrent;
using System.Diagnostics.Metrics;
namespace MxGateway.Server.Metrics;
@@ -25,8 +26,8 @@ public sealed class GatewayMetrics : IDisposable
private readonly Histogram<double> _commandLatencyHistogram;
private readonly Histogram<double> _eventStreamSendLatencyHistogram;
private readonly Dictionary<string, long> _commandFailuresByMethod = new(StringComparer.OrdinalIgnoreCase);
private readonly Dictionary<string, long> _eventsByFamily = new(StringComparer.OrdinalIgnoreCase);
private readonly Dictionary<string, long> _eventsBySession = new(StringComparer.Ordinal);
private readonly ConcurrentDictionary<string, long> _eventsByFamily = new(StringComparer.OrdinalIgnoreCase);
private readonly ConcurrentDictionary<string, long> _eventsBySession = new(StringComparer.Ordinal);
private readonly Dictionary<string, long> _retryAttemptsByArea = new(StringComparer.OrdinalIgnoreCase);
private int _openSessions;
@@ -173,12 +174,9 @@ public sealed class GatewayMetrics : IDisposable
public void EventReceived(string sessionId, string family)
{
lock (_syncRoot)
{
_eventsReceived++;
Increment(_eventsByFamily, family);
Increment(_eventsBySession, sessionId);
}
Interlocked.Increment(ref _eventsReceived);
Increment(_eventsByFamily, family);
Increment(_eventsBySession, sessionId);
_eventsReceivedCounter.Add(
1,
@@ -225,10 +223,7 @@ public sealed class GatewayMetrics : IDisposable
public void RemoveSessionEvents(string sessionId)
{
lock (_syncRoot)
{
_eventsBySession.Remove(sessionId);
}
_eventsBySession.TryRemove(sessionId, out _);
}
public void QueueOverflow(string queueName)
@@ -296,7 +291,7 @@ public sealed class GatewayMetrics : IDisposable
CommandsStarted: _commandsStarted,
CommandsSucceeded: _commandsSucceeded,
CommandsFailed: _commandsFailed,
EventsReceived: _eventsReceived,
EventsReceived: Interlocked.Read(ref _eventsReceived),
QueueOverflows: _queueOverflows,
Faults: _faults,
WorkerKills: _workerKills,
@@ -359,4 +354,9 @@ public sealed class GatewayMetrics : IDisposable
values.TryGetValue(key, out long currentValue);
values[key] = currentValue + 1;
}
private static void Increment(ConcurrentDictionary<string, long> values, string key)
{
values.AddOrUpdate(key, 1, static (_, currentValue) => currentValue + 1);
}
}
@@ -41,6 +41,9 @@ public sealed class SessionWorkerClientFactory : ISessionWorkerClientFactory
NamedPipeServerStream? pipe = CreatePipe(session.PipeName);
WorkerProcessHandle? processHandle = null;
IWorkerClient? workerClient = null;
using CancellationTokenSource startupCancellation =
CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
startupCancellation.CancelAfter(session.StartupTimeout);
try
{
session.TransitionTo(SessionState.StartingWorker);
@@ -52,11 +55,11 @@ public sealed class SessionWorkerClientFactory : ISessionWorkerClientFactory
GatewayContractInfo.WorkerProtocolVersion,
session.Nonce,
pipe),
cancellationToken)
startupCancellation.Token)
.ConfigureAwait(false);
session.TransitionTo(SessionState.WaitingForPipe);
await WaitForPipeConnectionAsync(pipe, session.StartupTimeout, cancellationToken).ConfigureAwait(false);
await WaitForPipeConnectionAsync(pipe, startupCancellation.Token).ConfigureAwait(false);
session.TransitionTo(SessionState.Handshaking);
WorkerFrameProtocolOptions frameOptions = new(
@@ -88,14 +91,23 @@ public sealed class SessionWorkerClientFactory : ISessionWorkerClientFactory
processHandle = null;
session.TransitionTo(SessionState.InitializingWorker);
await workerClient.StartAsync(cancellationToken).ConfigureAwait(false);
await workerClient.StartAsync(startupCancellation.Token).ConfigureAwait(false);
return workerClient;
}
catch
catch (Exception exception)
{
if (workerClient is not null)
{
try
{
workerClient.Kill("OpenSessionFailed");
}
catch
{
// Preserve the startup failure while still disposing below.
}
await workerClient.DisposeAsync().ConfigureAwait(false);
}
else
@@ -119,6 +131,15 @@ public sealed class SessionWorkerClientFactory : ISessionWorkerClientFactory
pipe?.Dispose();
}
if (exception is OperationCanceledException
&& startupCancellation.IsCancellationRequested
&& !cancellationToken.IsCancellationRequested)
{
throw new TimeoutException(
$"Worker session {session.SessionId} did not complete startup within {session.StartupTimeout}.",
exception);
}
throw;
}
}
@@ -135,11 +156,8 @@ public sealed class SessionWorkerClientFactory : ISessionWorkerClientFactory
private static async Task WaitForPipeConnectionAsync(
NamedPipeServerStream pipe,
TimeSpan startupTimeout,
CancellationToken cancellationToken)
{
using CancellationTokenSource timeout = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
timeout.CancelAfter(startupTimeout);
await pipe.WaitForConnectionAsync(timeout.Token).ConfigureAwait(false);
await pipe.WaitForConnectionAsync(cancellationToken).ConfigureAwait(false);
}
}
+13 -1
View File
@@ -13,6 +13,7 @@ namespace MxGateway.Server.Workers;
public sealed class WorkerClient : IWorkerClient
{
private const string GatewayVersionFallback = "unknown";
private static readonly TimeSpan DisposeTaskTimeout = TimeSpan.FromSeconds(5);
private readonly object _syncRoot = new();
private readonly WorkerClientConnection _connection;
private readonly WorkerClientOptions _options;
@@ -286,8 +287,19 @@ public sealed class WorkerClient : IWorkerClient
WorkerClientErrorCode.GatewayShutdown,
"Worker client was disposed."));
await WaitForBackgroundTasksAsync(CancellationToken.None).ConfigureAwait(false);
await _connection.Stream.DisposeAsync().ConfigureAwait(false);
using CancellationTokenSource disposeTimeout = new(DisposeTaskTimeout);
try
{
await WaitForBackgroundTasksAsync(disposeTimeout.Token).ConfigureAwait(false);
}
catch (OperationCanceledException)
{
_logger.LogWarning(
"Timed out waiting for worker client background tasks to stop for session {SessionId}.",
SessionId);
}
_connection.ProcessHandle?.Dispose();
_pendingCommandSlots.Dispose();
_stopCts.Dispose();
@@ -114,6 +114,37 @@ public sealed class EventStreamServiceTests
Assert.Equal(1, metrics.GetSnapshot().Faults);
}
[Fact]
public async Task StreamEventsAsync_WhenStreamQueueOverflowsWithDisconnectPolicy_LeavesSessionReady()
{
FakeWorkerClient workerClient = new();
GatewaySession session = CreateReadySession(workerClient);
using GatewayMetrics metrics = new();
EventStreamService service = CreateService(
new FakeSessionManager(session),
metrics,
queueCapacity: 1,
backpressurePolicy: EventBackpressurePolicy.DisconnectSubscriber);
workerClient.Events.Add(CreateWorkerEvent(sequence: 1, MxEventFamily.OnDataChange));
workerClient.Events.Add(CreateWorkerEvent(sequence: 2, MxEventFamily.OnDataChange));
workerClient.Events.Add(CreateWorkerEvent(sequence: 3, MxEventFamily.OnDataChange));
workerClient.CompleteAfterConfiguredEvents = true;
await using IAsyncEnumerator<MxEvent> subscriber = service
.StreamEventsAsync(CreateRequest(session.SessionId), CancellationToken.None)
.GetAsyncEnumerator();
Assert.True(await subscriber.MoveNextAsync().AsTask().WaitAsync(TestTimeout));
SessionManagerException exception = await Assert.ThrowsAsync<SessionManagerException>(
async () => await subscriber.MoveNextAsync().AsTask().WaitAsync(TestTimeout));
Assert.Equal(SessionManagerErrorCode.EventQueueOverflow, exception.ErrorCode);
Assert.Equal(SessionState.Ready, session.State);
GatewayMetricsSnapshot snapshot = metrics.GetSnapshot();
Assert.Equal(1, snapshot.QueueOverflows);
Assert.Equal(0, snapshot.Faults);
Assert.Equal(1, snapshot.StreamDisconnects);
}
[Fact]
public async Task StreamEventsAsync_DoesNotSynthesizeOperationComplete()
{
@@ -157,7 +188,8 @@ public sealed class EventStreamServiceTests
private static EventStreamService CreateService(
FakeSessionManager sessionManager,
GatewayMetrics? metrics = null,
int queueCapacity = 8)
int queueCapacity = 8,
EventBackpressurePolicy backpressurePolicy = EventBackpressurePolicy.FailFast)
{
return new EventStreamService(
sessionManager,
@@ -166,6 +198,7 @@ public sealed class EventStreamServiceTests
Events = new EventOptions
{
QueueCapacity = queueCapacity,
BackpressurePolicy = backpressurePolicy,
},
}),
new MxAccessGrpcMapper(),
@@ -65,13 +65,33 @@ public sealed class SessionWorkerClientFactoryFakeWorkerTests
Assert.True(launcher.Process.IsDisposed);
}
private static GatewayOptions CreateOptions()
[Fact]
public async Task CreateAsync_WhenFakeWorkerNeverSendsReady_TimesOutAndKillsWorker()
{
NeverReadyWorkerProcessLauncher launcher = new();
using GatewayMetrics metrics = new();
SessionWorkerClientFactory factory = new(
launcher,
Options.Create(CreateOptions(startupTimeoutSeconds: 1)),
metrics,
NullLoggerFactory.Instance);
GatewaySession session = CreateSession(startupTimeout: TimeSpan.FromSeconds(1));
TimeoutException exception = await Assert.ThrowsAsync<TimeoutException>(
async () => await factory.CreateAsync(session, CancellationToken.None).WaitAsync(TestTimeout));
Assert.Contains("did not complete startup", exception.Message);
Assert.Equal(1, launcher.Process.KillCount);
Assert.True(launcher.Process.IsDisposed);
}
private static GatewayOptions CreateOptions(int startupTimeoutSeconds = 5)
{
return new GatewayOptions
{
Worker = new WorkerOptions
{
StartupTimeoutSeconds = 5,
StartupTimeoutSeconds = startupTimeoutSeconds,
ShutdownTimeoutSeconds = 5,
HeartbeatIntervalSeconds = 30,
HeartbeatGraceSeconds = 30,
@@ -84,7 +104,7 @@ public sealed class SessionWorkerClientFactoryFakeWorkerTests
};
}
private static GatewaySession CreateSession()
private static GatewaySession CreateSession(TimeSpan? startupTimeout = null)
{
return new GatewaySession(
FakeWorkerHarness.DefaultSessionId,
@@ -94,7 +114,7 @@ public sealed class SessionWorkerClientFactoryFakeWorkerTests
"test-client",
"fake-worker-session-test",
"client-correlation-1",
TestTimeout,
startupTimeout ?? TestTimeout,
TestTimeout,
TestTimeout,
DateTimeOffset.UtcNow);
@@ -172,6 +192,38 @@ public sealed class SessionWorkerClientFactoryFakeWorkerTests
}
}
private sealed class NeverReadyWorkerProcessLauncher : IWorkerProcessLauncher
{
public FakeWorkerProcess Process { get; } = new(processId: 4680);
public Task<WorkerProcessHandle> LaunchAsync(
WorkerProcessLaunchRequest request,
CancellationToken cancellationToken = default)
{
_ = RunWorkerAsync(request, cancellationToken);
return Task.FromResult(CreateHandle(Process));
}
private async Task RunWorkerAsync(
WorkerProcessLaunchRequest request,
CancellationToken cancellationToken)
{
await using FakeWorkerHarness harness = await FakeWorkerHarness.ConnectToGatewayPipeAsync(
request.SessionId,
request.Nonce,
request.PipeName,
request.ProtocolVersion,
cancellationToken: cancellationToken).ConfigureAwait(false);
_ = await harness.ReadGatewayEnvelopeAsync(cancellationToken).ConfigureAwait(false);
await harness.SendWorkerHelloAsync(
workerProcessId: Process.Id,
workerProtocolVersion: request.ProtocolVersion,
cancellationToken: cancellationToken).ConfigureAwait(false);
await Task.Delay(Timeout.InfiniteTimeSpan, cancellationToken).ConfigureAwait(false);
}
}
private static WorkerProcessHandle CreateHandle(IWorkerProcess process)
{
return new WorkerProcessHandle(
@@ -166,7 +166,8 @@ public sealed class WorkerClientTests
await pipePair.DisposeWorkerSideAsync();
await WaitUntilAsync(
() => client.State == WorkerClientState.Faulted,
() => client.State == WorkerClientState.Faulted
&& metrics.GetSnapshot().WorkersRunning == 0,
TestTimeout);
GatewayMetricsSnapshot snapshot = metrics.GetSnapshot();
@@ -174,6 +175,22 @@ public sealed class WorkerClientTests
Assert.Equal(1, snapshot.WorkerExits);
}
[Fact]
public async Task DisposeAsync_WhenPipeReadIsBlocked_ReturnsWithinBoundedTimeout()
{
await using PipePair pipePair = await PipePair.CreateAsync();
WorkerClient client = CreateClient(pipePair);
await CompleteHandshakeAsync(client, pipePair);
DateTimeOffset startedAt = DateTimeOffset.UtcNow;
await client.DisposeAsync().AsTask().WaitAsync(TestTimeout);
TimeSpan elapsed = DateTimeOffset.UtcNow - startedAt;
Assert.True(
elapsed < TimeSpan.FromSeconds(4),
$"DisposeAsync took {elapsed.TotalMilliseconds:N0}ms.");
}
[Fact]
public async Task ReadLoop_WhenHeartbeatArrives_UpdatesLastHeartbeatAndWorkerProcess()
{
@@ -60,4 +60,22 @@ public sealed class GatewayMetricsTests
Assert.Equal("depth", exception.ParamName);
}
[Fact]
public void RemoveSessionEvents_RemovesOnlyThatSession()
{
using GatewayMetrics metrics = new();
metrics.EventReceived("session-1", "OnDataChange");
metrics.EventReceived("session-2", "OnWriteComplete");
metrics.RemoveSessionEvents("session-1");
GatewayMetricsSnapshot snapshot = metrics.GetSnapshot();
Assert.Equal(2, snapshot.EventsReceived);
Assert.False(snapshot.EventsBySession.ContainsKey("session-1"));
Assert.Equal(1, snapshot.EventsBySession["session-2"]);
Assert.Equal(1, snapshot.EventsByFamily["OnDataChange"]);
Assert.Equal(1, snapshot.EventsByFamily["OnWriteComplete"]);
}
}
@@ -304,6 +304,45 @@ public sealed class WorkerPipeSessionTests
await SendShutdownAndWaitAsync(pipePair, runTask, cancellation.Token);
}
[Fact]
public async Task RunAsync_WhenShutdownArrivesDuringCommand_DropsLateReplyAndWritesShutdownAck()
{
using CancellationTokenSource cancellation = new(TimeSpan.FromSeconds(5));
using PipePair pipePair = await PipePair.CreateAsync(cancellation.Token);
FakeRuntimeSession runtime = new()
{
BlockDispatch = true,
};
WorkerPipeSession session = CreatePipeSession(
pipePair.WorkerStream,
runtime,
new WorkerPipeSessionOptions
{
HeartbeatInterval = TimeSpan.FromSeconds(1),
HeartbeatGrace = TimeSpan.FromSeconds(5),
});
Task runTask = session.RunAsync(cancellation.Token);
await CompleteGatewayHandshakeAsync(pipePair, cancellation.Token);
await pipePair.GatewayWriter.WriteAsync(
CreateCommandEnvelope("command-during-shutdown"),
cancellation.Token);
Assert.True(runtime.DispatchStarted.Wait(TimeSpan.FromSeconds(2)));
await pipePair.GatewayWriter
.WriteAsync(CreateShutdownEnvelope(), cancellation.Token);
WorkerEnvelope shutdownAck = await ReadUntilAsync(
pipePair.GatewayReader,
WorkerEnvelope.BodyOneofCase.WorkerShutdownAck,
cancellation.Token);
Assert.Equal(ProtocolStatusCode.Ok, shutdownAck.WorkerShutdownAck.Status.Code);
Task completedTask = await Task.WhenAny(runTask, Task.Delay(TimeSpan.FromSeconds(2), cancellation.Token));
Assert.Same(runTask, completedTask);
await runTask;
}
private static WorkerPipeSession CreateSession(
Stream inbound,
Stream outbound,
@@ -440,7 +479,7 @@ public sealed class WorkerPipeSessionTests
Assert.Equal(ProtocolStatusCode.Ok, shutdownAck.WorkerShutdownAck.Status.Code);
Task completedTask = await Task
.WhenAny(runTask, Task.Delay(TimeSpan.FromSeconds(2), cancellationToken))
.WhenAny(runTask, Task.Delay(TimeSpan.FromSeconds(5), cancellationToken))
.ConfigureAwait(false);
Assert.Same(runTask, completedTask);
@@ -12,17 +12,17 @@ public sealed class MxAccessEventQueueTests
{
MxAccessEventQueue queue = new(capacity: 4);
WorkerEvent first = queue.Enqueue(CreateEvent(MxEventFamily.OnDataChange, itemHandle: 10));
WorkerEvent second = queue.Enqueue(CreateEvent(MxEventFamily.OnWriteComplete, itemHandle: 11));
queue.Enqueue(CreateEvent(MxEventFamily.OnDataChange, itemHandle: 10));
queue.Enqueue(CreateEvent(MxEventFamily.OnWriteComplete, itemHandle: 11));
Assert.Equal(1UL, first.Event.WorkerSequence);
Assert.Equal(2UL, second.Event.WorkerSequence);
Assert.NotNull(first.Event.WorkerTimestamp);
Assert.Equal(2, queue.Count);
Assert.Equal(2UL, queue.LastEventSequence);
Assert.True(queue.TryDequeue(out WorkerEvent? dequeuedFirst));
Assert.True(queue.TryDequeue(out WorkerEvent? dequeuedSecond));
Assert.Equal(1UL, dequeuedFirst?.Event.WorkerSequence);
Assert.Equal(2UL, dequeuedSecond?.Event.WorkerSequence);
Assert.NotNull(dequeuedFirst?.Event.WorkerTimestamp);
Assert.Equal(10, dequeuedFirst?.Event.ItemHandle);
Assert.Equal(11, dequeuedSecond?.Event.ItemHandle);
Assert.False(queue.TryDequeue(out _));
+140 -24
View File
@@ -15,6 +15,7 @@ namespace MxGateway.Worker.Ipc;
public sealed class WorkerPipeSession
{
private static readonly TimeSpan EventDrainInterval = TimeSpan.FromMilliseconds(25);
private static readonly TimeSpan BackgroundTaskStopTimeout = TimeSpan.FromSeconds(1);
private const uint EventDrainBatchSize = 128;
private readonly WorkerFrameProtocolOptions _options;
@@ -24,9 +25,12 @@ public sealed class WorkerPipeSession
private readonly IWorkerLogger? _logger;
private readonly WorkerFrameReader _reader;
private readonly WorkerFrameWriter _writer;
private readonly object _commandTaskGate = new();
private readonly HashSet<Task> _activeCommandTasks = new();
private IWorkerRuntimeSession? _runtimeSession;
private long _nextSequence;
private WorkerState _state = WorkerState.Starting;
private bool _acceptingCommands = true;
private bool _watchdogFaultSent;
private bool _shutdownTimedOut;
@@ -206,18 +210,31 @@ public sealed class WorkerPipeSession
private async Task RunMessageLoopAsync(CancellationToken cancellationToken)
{
using CancellationTokenSource loopCancellation = CancellationTokenSource
.CreateLinkedTokenSource(cancellationToken);
using CancellationTokenSource heartbeatCancellation = CancellationTokenSource
.CreateLinkedTokenSource(cancellationToken);
Task heartbeatTask = RunHeartbeatLoopAsync(heartbeatCancellation.Token);
Task eventDrainTask = RunEventDrainLoopAsync(heartbeatCancellation.Token);
Task<WorkerEnvelope> readTask = _reader.ReadAsync(loopCancellation.Token);
try
{
while (!cancellationToken.IsCancellationRequested)
{
Task<WorkerEnvelope> readTask = _reader.ReadAsync(cancellationToken);
Task completedTask = await Task.WhenAny(readTask, heartbeatTask, eventDrainTask).ConfigureAwait(false);
if (completedTask == heartbeatTask)
if (completedTask == readTask)
{
WorkerEnvelope envelope = await readTask.ConfigureAwait(false);
bool keepReading = await DispatchGatewayEnvelopeAsync(envelope, cancellationToken).ConfigureAwait(false);
if (!keepReading)
{
return;
}
readTask = _reader.ReadAsync(loopCancellation.Token);
}
else if (completedTask == heartbeatTask)
{
await heartbeatTask.ConfigureAwait(false);
}
@@ -225,33 +242,52 @@ public sealed class WorkerPipeSession
{
await eventDrainTask.ConfigureAwait(false);
}
WorkerEnvelope envelope = await readTask.ConfigureAwait(false);
bool keepReading = await DispatchGatewayEnvelopeAsync(envelope, cancellationToken).ConfigureAwait(false);
if (!keepReading)
{
return;
}
}
}
finally
{
loopCancellation.Cancel();
heartbeatCancellation.Cancel();
try
{
await heartbeatTask.ConfigureAwait(false);
}
catch (OperationCanceledException)
{
}
await ObserveBackgroundTaskStopAsync(heartbeatTask, "Heartbeat").ConfigureAwait(false);
await ObserveBackgroundTaskStopAsync(eventDrainTask, "EventDrain").ConfigureAwait(false);
}
}
try
{
await eventDrainTask.ConfigureAwait(false);
}
catch (OperationCanceledException)
{
}
private async Task ObserveBackgroundTaskStopAsync(
Task task,
string taskName)
{
Task completedTask = await Task
.WhenAny(task, Task.Delay(BackgroundTaskStopTimeout))
.ConfigureAwait(false);
if (completedTask != task)
{
_logger?.Error(
"WorkerPipeSessionBackgroundTaskStopTimedOut",
new Dictionary<string, object?>
{
["task"] = taskName,
["timeout_ms"] = BackgroundTaskStopTimeout.TotalMilliseconds,
});
return;
}
try
{
await task.ConfigureAwait(false);
}
catch (OperationCanceledException)
{
}
catch (Exception ex)
{
_logger?.Error(
"WorkerPipeSessionBackgroundTaskStopFailed",
new Dictionary<string, object?>
{
["task"] = taskName,
["exception"] = ex.ToString(),
});
}
}
@@ -300,7 +336,7 @@ public sealed class WorkerPipeSession
switch (envelope.BodyCase)
{
case WorkerEnvelope.BodyOneofCase.WorkerCommand:
_ = ProcessCommandAsync(envelope, cancellationToken);
TryStartCommandTask(envelope, cancellationToken);
return true;
case WorkerEnvelope.BodyOneofCase.WorkerShutdown:
await ShutdownAsync(envelope.WorkerShutdown, cancellationToken).ConfigureAwait(false);
@@ -333,6 +369,11 @@ public sealed class WorkerPipeSession
try
{
MxCommandReply reply = await runtimeSession.DispatchAsync(staCommand).ConfigureAwait(false);
if (_state is not WorkerState.Ready and not WorkerState.ExecutingCommand)
{
return;
}
await _writer
.WriteAsync(
CreateEnvelope(new WorkerCommandReply
@@ -370,11 +411,13 @@ public sealed class WorkerPipeSession
}
TimeSpan gracePeriod = ResolveGracePeriod(shutdown);
StopAcceptingCommands();
try
{
MxAccessShutdownResult result = await runtimeSession
.ShutdownGracefullyAsync(gracePeriod, cancellationToken)
.ConfigureAwait(false);
await WaitForActiveCommandTasksAsync(gracePeriod, cancellationToken).ConfigureAwait(false);
LogShutdownFailures(result.Failures);
await WriteShutdownAckAsync(CreateShutdownAck(result, shutdown), cancellationToken).ConfigureAwait(false);
}
@@ -387,6 +430,79 @@ public sealed class WorkerPipeSession
}
}
private void TryStartCommandTask(
WorkerEnvelope envelope,
CancellationToken cancellationToken)
{
Task commandTask;
lock (_commandTaskGate)
{
if (!_acceptingCommands)
{
return;
}
commandTask = ProcessCommandAsync(envelope, cancellationToken);
_activeCommandTasks.Add(commandTask);
}
_ = ObserveCommandTaskAsync(commandTask);
}
private async Task ObserveCommandTaskAsync(Task commandTask)
{
try
{
await commandTask.ConfigureAwait(false);
}
catch (OperationCanceledException)
{
}
finally
{
lock (_commandTaskGate)
{
_activeCommandTasks.Remove(commandTask);
}
}
}
private void StopAcceptingCommands()
{
lock (_commandTaskGate)
{
_acceptingCommands = false;
}
}
private async Task WaitForActiveCommandTasksAsync(
TimeSpan timeout,
CancellationToken cancellationToken)
{
Task[] activeTasks;
lock (_commandTaskGate)
{
activeTasks = new List<Task>(_activeCommandTasks).ToArray();
}
if (activeTasks.Length == 0)
{
return;
}
Task activeCommandsTask = Task.WhenAll(activeTasks);
Task timeoutTask = Task.Delay(timeout, cancellationToken);
Task completedTask = await Task.WhenAny(activeCommandsTask, timeoutTask).ConfigureAwait(false);
if (completedTask == activeCommandsTask)
{
await activeCommandsTask.ConfigureAwait(false);
return;
}
cancellationToken.ThrowIfCancellationRequested();
throw new TimeoutException($"Worker command tasks did not stop within {timeout}.");
}
private Task WriteShutdownAckAsync(
WorkerShutdownAck shutdownAck,
CancellationToken cancellationToken)
@@ -80,7 +80,7 @@ public sealed class MxAccessEventQueue
}
}
public WorkerEvent Enqueue(MxEvent mxEvent)
public void Enqueue(MxEvent mxEvent)
{
if (mxEvent is null)
{
@@ -109,8 +109,6 @@ public sealed class MxAccessEventQueue
Event = queuedEvent,
};
events.Enqueue(workerEvent);
return workerEvent.Clone();
}
}
@@ -124,7 +122,7 @@ public sealed class MxAccessEventQueue
return false;
}
workerEvent = events.Dequeue().Clone();
workerEvent = events.Dequeue();
return true;
}
}
@@ -144,7 +142,7 @@ public sealed class MxAccessEventQueue
List<WorkerEvent> drained = new(drainCount);
for (int index = 0; index < drainCount; index++)
{
drained.Add(events.Dequeue().Clone());
drained.Add(events.Dequeue());
}
return drained;