Improve gateway reliability and dashboard docs

This commit is contained in:
Joseph Doherty
2026-04-28 00:13:22 -04:00
parent bd4a09a35e
commit 4fc355b357
61 changed files with 1722 additions and 150 deletions
@@ -8,6 +8,8 @@ namespace MxGateway.Client.Cli;
public static class MxGatewayClientCli
{
private const uint MaxAggregateEvents = 10_000;
private static readonly JsonFormatter ProtobufJsonFormatter = JsonFormatter.Default;
private static readonly JsonSerializerOptions JsonOptions = new(JsonSerializerDefaults.Web);
@@ -342,8 +344,22 @@ public static class MxGatewayClientCli
TextWriter output,
CancellationToken cancellationToken)
{
var events = new List<MxEvent>();
uint maxEvents = arguments.GetUInt32("max-events", 0);
bool json = arguments.HasFlag("json");
bool jsonLines = arguments.HasFlag("jsonl");
if (json && !jsonLines && maxEvents is 0)
{
throw new ArgumentException("--json stream-events requires --max-events to bound aggregate output.");
}
if (maxEvents > MaxAggregateEvents)
{
throw new ArgumentException($"--max-events cannot exceed {MaxAggregateEvents}.");
}
var events = json && !jsonLines
? new List<MxEvent>(checked((int)maxEvents))
: [];
uint eventCount = 0;
var request = new StreamEventsRequest
{
@@ -355,7 +371,11 @@ public static class MxGatewayClientCli
.WithCancellation(cancellationToken)
.ConfigureAwait(false))
{
if (arguments.HasFlag("json"))
if (jsonLines)
{
output.WriteLine(ProtobufJsonFormatter.Format(gatewayEvent));
}
else if (json)
{
events.Add(gatewayEvent);
}
@@ -371,7 +391,7 @@ public static class MxGatewayClientCli
}
}
if (arguments.HasFlag("json"))
if (json && !jsonLines)
{
output.WriteLine(JsonSerializer.Serialize(
new { events = events.Select(EventToJsonElement).ToArray() },
@@ -25,7 +25,7 @@ internal sealed class GrpcMxGatewayClientTransport(
}
catch (RpcException exception)
{
throw MapRpcException(exception);
throw MapRpcException(exception, callOptions.CancellationToken);
}
}
@@ -41,7 +41,7 @@ internal sealed class GrpcMxGatewayClientTransport(
}
catch (RpcException exception)
{
throw MapRpcException(exception);
throw MapRpcException(exception, callOptions.CancellationToken);
}
}
@@ -57,7 +57,7 @@ internal sealed class GrpcMxGatewayClientTransport(
}
catch (RpcException exception)
{
throw MapRpcException(exception);
throw MapRpcException(exception, callOptions.CancellationToken);
}
}
@@ -87,7 +87,7 @@ internal sealed class GrpcMxGatewayClientTransport(
}
catch (RpcException exception)
{
throw MapRpcException(exception);
throw MapRpcException(exception, effectiveCancellationToken);
}
yield return gatewayEvent;
@@ -101,8 +101,18 @@ internal sealed class GrpcMxGatewayClientTransport(
return StreamEventsAsync(request, callOptions);
}
private static MxGatewayException MapRpcException(RpcException exception)
private static Exception MapRpcException(
RpcException exception,
CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested || exception.StatusCode == StatusCode.Cancelled)
{
return new OperationCanceledException(
exception.Status.Detail,
exception,
cancellationToken);
}
return exception.StatusCode switch
{
StatusCode.Unauthenticated => new MxGatewayAuthenticationException(
@@ -3,6 +3,9 @@ using Grpc.Net.Client;
using Microsoft.Extensions.Logging;
using MxGateway.Contracts.Proto;
using Polly;
using System.Net.Http;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
namespace MxGateway.Client;
@@ -54,10 +57,12 @@ public sealed class MxGatewayClient : IAsyncDisposable
ArgumentNullException.ThrowIfNull(options);
options.Validate();
HttpMessageHandler handler = CreateHttpHandler(options);
var channel = GrpcChannel.ForAddress(
options.Endpoint,
new GrpcChannelOptions
{
HttpHandler = handler,
LoggerFactory = options.LoggerFactory,
});
@@ -126,7 +131,7 @@ public sealed class MxGatewayClient : IAsyncDisposable
ArgumentNullException.ThrowIfNull(request);
ThrowIfDisposed();
return _transport.StreamEventsAsync(request, CreateCallOptions(cancellationToken));
return _transport.StreamEventsAsync(request, CreateStreamCallOptions(cancellationToken));
}
public ValueTask DisposeAsync()
@@ -142,6 +147,18 @@ public sealed class MxGatewayClient : IAsyncDisposable
}
internal CallOptions CreateCallOptions(CancellationToken cancellationToken)
{
return CreateCallOptions(cancellationToken, Options.DefaultCallTimeout);
}
internal CallOptions CreateStreamCallOptions(CancellationToken cancellationToken)
{
return CreateCallOptions(cancellationToken, Options.StreamTimeout);
}
internal CallOptions CreateCallOptions(
CancellationToken cancellationToken,
TimeSpan? timeout)
{
Metadata headers = new()
{
@@ -150,18 +167,61 @@ public sealed class MxGatewayClient : IAsyncDisposable
return new CallOptions(
headers,
DateTime.UtcNow.Add(Options.DefaultCallTimeout),
timeout is null ? null : DateTime.UtcNow.Add(timeout.Value),
cancellationToken);
}
private Task<T> ExecuteSafeUnaryAsync<T>(
private async Task<T> ExecuteSafeUnaryAsync<T>(
Func<CancellationToken, Task<T>> call,
CancellationToken cancellationToken)
{
return _safeUnaryRetryPipeline.ExecuteAsync(
using CancellationTokenSource timeout = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
timeout.CancelAfter(Options.DefaultCallTimeout);
return await _safeUnaryRetryPipeline.ExecuteAsync(
async token => await call(token).ConfigureAwait(false),
cancellationToken)
.AsTask();
timeout.Token)
.ConfigureAwait(false);
}
private static HttpMessageHandler CreateHttpHandler(MxGatewayClientOptions options)
{
SocketsHttpHandler handler = new()
{
ConnectTimeout = options.ConnectTimeout,
};
if (options.UseTls)
{
handler.SslOptions = new SslClientAuthenticationOptions();
if (!string.IsNullOrWhiteSpace(options.ServerNameOverride))
{
handler.SslOptions.TargetHost = options.ServerNameOverride;
}
if (!string.IsNullOrWhiteSpace(options.CaCertificatePath))
{
X509Certificate2 trustedRoot = X509CertificateLoader.LoadCertificateFromFile(options.CaCertificatePath);
handler.SslOptions.RemoteCertificateValidationCallback = (_, certificate, chain, errors) =>
{
if (certificate is null)
{
return false;
}
using X509Chain customChain = new();
customChain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
customChain.ChainPolicy.CustomTrustStore.Add(trustedRoot);
customChain.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck;
customChain.ChainPolicy.VerificationFlags = X509VerificationFlags.NoFlag;
X509Certificate2 certificateToValidate = certificate as X509Certificate2
?? X509CertificateLoader.LoadCertificate(certificate.Export(X509ContentType.Cert));
return customChain.Build(certificateToValidate);
};
}
}
return handler;
}
private void ThrowIfDisposed()
@@ -21,6 +21,8 @@ public sealed class MxGatewayClientOptions
public TimeSpan DefaultCallTimeout { get; init; } = TimeSpan.FromSeconds(30);
public TimeSpan? StreamTimeout { get; init; }
public MxGatewayClientRetryOptions Retry { get; init; } = new();
public ILoggerFactory? LoggerFactory { get; init; }
@@ -57,6 +59,27 @@ public sealed class MxGatewayClientOptions
"The default call timeout must be greater than zero.");
}
if (StreamTimeout is not null && StreamTimeout <= TimeSpan.Zero)
{
throw new ArgumentOutOfRangeException(
nameof(StreamTimeout),
"The stream timeout must be greater than zero when configured.");
}
if (UseTls && Endpoint.Scheme != Uri.UriSchemeHttps)
{
throw new ArgumentException(
"UseTls requires an https gateway endpoint.",
nameof(Endpoint));
}
if (!UseTls && Endpoint.Scheme == Uri.UriSchemeHttps)
{
throw new ArgumentException(
"An https gateway endpoint requires UseTls.",
nameof(Endpoint));
}
Retry.Validate();
}
}
+23 -3
View File
@@ -377,17 +377,19 @@ func runSmoke(ctx context.Context, args []string, stdout, stderr io.Writer) erro
if err != nil {
return err
}
defer session.Close(context.Background())
serverHandle, err := session.Register(ctx, *clientName)
if err != nil {
return err
return closeSmokeSession(ctx, session, err)
}
itemHandle, err := session.AddItem(ctx, serverHandle, *item)
if err != nil {
return err
return closeSmokeSession(ctx, session, err)
}
if err := session.Advise(ctx, serverHandle, itemHandle); err != nil {
return closeSmokeSession(ctx, session, err)
}
if err := closeSmokeSession(ctx, session, nil); err != nil {
return err
}
@@ -406,6 +408,24 @@ func runSmoke(ctx context.Context, args []string, stdout, stderr io.Writer) erro
return nil
}
func closeSmokeSession(ctx context.Context, session *mxgateway.Session, primaryErr error) error {
closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if deadline, ok := ctx.Deadline(); ok {
if until := time.Until(deadline); until > 0 && until < 5*time.Second {
cancel()
closeCtx, cancel = context.WithTimeout(context.Background(), until)
defer cancel()
}
}
_, closeErr := session.Close(closeCtx)
if primaryErr != nil {
return primaryErr
}
return closeErr
}
func bindCommonFlags(flags *flag.FlagSet) *commonOptions {
common := &commonOptions{}
flags.StringVar(&common.Endpoint, "endpoint", "localhost:5000", "gateway endpoint")
+5 -2
View File
@@ -184,8 +184,11 @@ func (c *Client) callContext(ctx context.Context) (context.Context, context.Canc
if timeout < 0 {
return ctx, func() {}
}
if _, ok := ctx.Deadline(); ok {
return ctx, func() {}
if deadline, ok := ctx.Deadline(); ok {
timeoutDeadline := time.Now().Add(timeout)
if deadline.Before(timeoutDeadline) {
return ctx, func() {}
}
}
return context.WithTimeout(ctx, timeout)
}
+41 -2
View File
@@ -5,6 +5,7 @@ import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"sync"
@@ -13,6 +14,8 @@ import (
"google.golang.org/grpc/status"
)
const maxBulkItems = 1000
// EventResult carries either the next ordered event or a terminal stream error.
type EventResult struct {
Event *MxEvent
@@ -225,6 +228,9 @@ func (s *Session) AddItemBulk(ctx context.Context, serverHandle int32, tagAddres
if tagAddresses == nil {
return nil, errors.New("mxgateway: tag addresses are required")
}
if err := ensureBulkSize("tag addresses", len(tagAddresses)); err != nil {
return nil, err
}
reply, err := s.invokeCommand(ctx, &pb.MxCommand{
Kind: pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM_BULK,
Payload: &pb.MxCommand_AddItemBulk{
@@ -245,6 +251,9 @@ func (s *Session) AdviseItemBulk(ctx context.Context, serverHandle int32, itemHa
if itemHandles == nil {
return nil, errors.New("mxgateway: item handles are required")
}
if err := ensureBulkSize("item handles", len(itemHandles)); err != nil {
return nil, err
}
reply, err := s.invokeCommand(ctx, &pb.MxCommand{
Kind: pb.MxCommandKind_MX_COMMAND_KIND_ADVISE_ITEM_BULK,
Payload: &pb.MxCommand_AdviseItemBulk{
@@ -265,6 +274,9 @@ func (s *Session) RemoveItemBulk(ctx context.Context, serverHandle int32, itemHa
if itemHandles == nil {
return nil, errors.New("mxgateway: item handles are required")
}
if err := ensureBulkSize("item handles", len(itemHandles)); err != nil {
return nil, err
}
reply, err := s.invokeCommand(ctx, &pb.MxCommand{
Kind: pb.MxCommandKind_MX_COMMAND_KIND_REMOVE_ITEM_BULK,
Payload: &pb.MxCommand_RemoveItemBulk{
@@ -285,6 +297,9 @@ func (s *Session) UnAdviseItemBulk(ctx context.Context, serverHandle int32, item
if itemHandles == nil {
return nil, errors.New("mxgateway: item handles are required")
}
if err := ensureBulkSize("item handles", len(itemHandles)); err != nil {
return nil, err
}
reply, err := s.invokeCommand(ctx, &pb.MxCommand{
Kind: pb.MxCommandKind_MX_COMMAND_KIND_UN_ADVISE_ITEM_BULK,
Payload: &pb.MxCommand_UnAdviseItemBulk{
@@ -305,6 +320,9 @@ func (s *Session) SubscribeBulk(ctx context.Context, serverHandle int32, tagAddr
if tagAddresses == nil {
return nil, errors.New("mxgateway: tag addresses are required")
}
if err := ensureBulkSize("tag addresses", len(tagAddresses)); err != nil {
return nil, err
}
reply, err := s.invokeCommand(ctx, &pb.MxCommand{
Kind: pb.MxCommandKind_MX_COMMAND_KIND_SUBSCRIBE_BULK,
Payload: &pb.MxCommand_SubscribeBulk{
@@ -325,6 +343,9 @@ func (s *Session) UnsubscribeBulk(ctx context.Context, serverHandle int32, itemH
if itemHandles == nil {
return nil, errors.New("mxgateway: item handles are required")
}
if err := ensureBulkSize("item handles", len(itemHandles)); err != nil {
return nil, err
}
reply, err := s.invokeCommand(ctx, &pb.MxCommand{
Kind: pb.MxCommandKind_MX_COMMAND_KIND_UNSUBSCRIBE_BULK,
Payload: &pb.MxCommand_UnsubscribeBulk{
@@ -387,13 +408,15 @@ func (s *Session) EventsAfter(ctx context.Context, afterWorkerSequence uint64) (
for {
event, err := stream.Recv()
if err == nil {
results <- EventResult{Event: event}
if !sendEventResult(ctx, results, EventResult{Event: event}) {
return
}
continue
}
if err == io.EOF || status.Code(err) == codes.Canceled || ctx.Err() != nil {
return
}
results <- EventResult{Err: &GatewayError{Op: "stream events", Err: err}}
sendEventResult(ctx, results, EventResult{Err: &GatewayError{Op: "stream events", Err: err}})
return
}
}()
@@ -401,6 +424,22 @@ func (s *Session) EventsAfter(ctx context.Context, afterWorkerSequence uint64) (
return results, nil
}
func ensureBulkSize(name string, length int) error {
if length > maxBulkItems {
return fmt.Errorf("mxgateway: %s bulk commands are limited to %d item(s)", name, maxBulkItems)
}
return nil
}
func sendEventResult(ctx context.Context, results chan<- EventResult, result EventResult) bool {
select {
case results <- result:
return true
case <-ctx.Done():
return false
}
}
func (s *Session) invokeCommand(ctx context.Context, command *MxCommand) (*MxCommandReply, error) {
return s.client.Invoke(ctx, &pb.MxCommandRequest{
SessionId: s.ID(),
@@ -334,25 +334,28 @@ public final class MxGatewayCli implements Callable<Integer> {
var session = client.openSession(OpenSessionRequest.newBuilder()
.setClientSessionName(clientName)
.build());
MxGatewayCliSession cliSession = client.session(session.getSessionId());
int serverHandle = cliSession.register(clientName);
int itemHandle = cliSession.addItem(serverHandle, item);
cliSession.advise(serverHandle, itemHandle);
if (json) {
Map<String, Object> output = new LinkedHashMap<>();
output.put("command", "smoke");
output.put("options", common.redactedJsonMap());
output.put("sessionId", session.getSessionId());
output.put("serverHandle", serverHandle);
output.put("itemHandle", itemHandle);
client.out().println(jsonObject(output));
} else {
client.out().printf(
"session=%s server=%d item=%d%n", session.getSessionId(), serverHandle, itemHandle);
try {
MxGatewayCliSession cliSession = client.session(session.getSessionId());
int serverHandle = cliSession.register(clientName);
int itemHandle = cliSession.addItem(serverHandle, item);
cliSession.advise(serverHandle, itemHandle);
if (json) {
Map<String, Object> output = new LinkedHashMap<>();
output.put("command", "smoke");
output.put("options", common.redactedJsonMap());
output.put("sessionId", session.getSessionId());
output.put("serverHandle", serverHandle);
output.put("itemHandle", itemHandle);
client.out().println(jsonObject(output));
} else {
client.out().printf(
"session=%s server=%d item=%d%n", session.getSessionId(), serverHandle, itemHandle);
}
} finally {
client.closeSession(CloseSessionRequest.newBuilder()
.setSessionId(session.getSessionId())
.build());
}
client.closeSession(CloseSessionRequest.newBuilder()
.setSessionId(session.getSessionId())
.build());
}
return 0;
}
@@ -105,13 +105,20 @@ public final class MxEventStream implements Iterator<MxEvent>, AutoCloseable {
private void offer(Object value) {
Objects.requireNonNull(value, "value");
if (value == END) {
queue.offer(value);
if (!queue.offer(value)) {
queue.clear();
queue.offer(value);
}
return;
}
try {
queue.put(value);
} catch (InterruptedException error) {
Thread.currentThread().interrupt();
if (!queue.offer(value)) {
ClientCallStreamObserver<StreamEventsRequest> stream = requestStream;
if (stream != null) {
stream.cancel("client event stream queue overflowed", null);
}
queue.clear();
queue.offer(new MxGatewayException("gateway stream events queue overflowed"));
queue.offer(END);
}
}
}
@@ -63,7 +63,7 @@ public final class MxGatewayClient implements AutoCloseable {
}
public MxAccessGatewayGrpc.MxAccessGatewayStub rawAsyncStub() {
return withDeadline(asyncStub);
return asyncStub;
}
public MxGatewaySession openSession(OpenSessionRequest request) {
@@ -140,14 +140,14 @@ public final class MxGatewayClient implements AutoCloseable {
public MxEventStream streamEvents(StreamEventsRequest request) {
MxEventStream stream = new MxEventStream(16);
rawAsyncStub().streamEvents(request, stream.observer());
withStreamDeadline(rawAsyncStub()).streamEvents(request, stream.observer());
return stream;
}
public MxGatewayEventSubscription streamEventsAsync(
StreamEventsRequest request, StreamObserver<MxEvent> observer) {
MxGatewayEventSubscription subscription = new MxGatewayEventSubscription();
rawAsyncStub().streamEvents(request, subscription.wrap(observer));
withStreamDeadline(rawAsyncStub()).streamEvents(request, subscription.wrap(observer));
return subscription;
}
@@ -161,7 +161,9 @@ public final class MxGatewayClient implements AutoCloseable {
public void closeAndAwaitTermination() throws InterruptedException {
if (ownedChannel != null) {
ownedChannel.shutdown();
ownedChannel.awaitTermination(options.connectTimeout().toMillis(), TimeUnit.MILLISECONDS);
if (!ownedChannel.awaitTermination(options.connectTimeout().toMillis(), TimeUnit.MILLISECONDS)) {
ownedChannel.shutdownNow();
}
}
}
@@ -199,6 +201,13 @@ public final class MxGatewayClient implements AutoCloseable {
return stub.withDeadlineAfter(options.callTimeout().toNanos(), TimeUnit.NANOSECONDS);
}
private <T extends io.grpc.stub.AbstractStub<T>> T withStreamDeadline(T stub) {
if (options.streamTimeout() == null || options.streamTimeout().isNegative()) {
return stub;
}
return stub.withDeadlineAfter(options.streamTimeout().toNanos(), TimeUnit.NANOSECONDS);
}
private static <T> CompletableFuture<T> toCompletable(com.google.common.util.concurrent.ListenableFuture<T> source) {
CompletableFuture<T> target = new CompletableFuture<>();
Futures.addCallback(
@@ -219,6 +228,11 @@ public final class MxGatewayClient implements AutoCloseable {
}
},
MoreExecutors.directExecutor());
target.whenComplete((ignoredResult, ignoredError) -> {
if (target.isCancelled()) {
source.cancel(true);
}
});
return target;
}
@@ -15,6 +15,7 @@ public final class MxGatewayClientOptions {
private final String serverNameOverride;
private final Duration connectTimeout;
private final Duration callTimeout;
private final Duration streamTimeout;
private MxGatewayClientOptions(Builder builder) {
endpoint = requireText(builder.endpoint, "endpoint");
@@ -24,6 +25,7 @@ public final class MxGatewayClientOptions {
serverNameOverride = builder.serverNameOverride == null ? "" : builder.serverNameOverride;
connectTimeout = builder.connectTimeout == null ? DEFAULT_CONNECT_TIMEOUT : builder.connectTimeout;
callTimeout = builder.callTimeout == null ? DEFAULT_CALL_TIMEOUT : builder.callTimeout;
streamTimeout = builder.streamTimeout;
}
public static Builder builder() {
@@ -62,6 +64,10 @@ public final class MxGatewayClientOptions {
return callTimeout;
}
public Duration streamTimeout() {
return streamTimeout;
}
@Override
public String toString() {
return "MxGatewayClientOptions{"
@@ -82,6 +88,8 @@ public final class MxGatewayClientOptions {
+ connectTimeout
+ ", callTimeout="
+ callTimeout
+ ", streamTimeout="
+ streamTimeout
+ '}';
}
@@ -100,6 +108,7 @@ public final class MxGatewayClientOptions {
private String serverNameOverride;
private Duration connectTimeout;
private Duration callTimeout;
private Duration streamTimeout;
private Builder() {
}
@@ -139,6 +148,11 @@ public final class MxGatewayClientOptions {
return this;
}
public Builder streamTimeout(Duration value) {
streamTimeout = Objects.requireNonNull(value, "streamTimeout");
return this;
}
public MxGatewayClientOptions build() {
return new MxGatewayClientOptions(this);
}
@@ -4,17 +4,22 @@ import io.grpc.stub.ClientCallStreamObserver;
import io.grpc.stub.ClientResponseObserver;
import io.grpc.stub.StreamObserver;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.AtomicBoolean;
import mxaccess_gateway.v1.MxaccessGateway.MxEvent;
import mxaccess_gateway.v1.MxaccessGateway.StreamEventsRequest;
public final class MxGatewayEventSubscription implements AutoCloseable {
private final AtomicReference<ClientCallStreamObserver<StreamEventsRequest>> requestStream = new AtomicReference<>();
private final AtomicBoolean cancelled = new AtomicBoolean();
ClientResponseObserver<StreamEventsRequest, MxEvent> wrap(StreamObserver<MxEvent> observer) {
return new ClientResponseObserver<>() {
@Override
public void beforeStart(ClientCallStreamObserver<StreamEventsRequest> stream) {
requestStream.set(stream);
if (cancelled.get()) {
stream.cancel("client cancelled event stream", null);
}
}
@Override
@@ -35,6 +40,7 @@ public final class MxGatewayEventSubscription implements AutoCloseable {
}
public void cancel() {
cancelled.set(true);
ClientCallStreamObserver<StreamEventsRequest> stream = requestStream.get();
if (stream != null) {
stream.cancel("client cancelled event stream", null);
+15 -9
View File
@@ -74,9 +74,9 @@ class GatewayClient:
if self._closed:
return
self._closed = True
if self._channel is not None:
await self._channel.close()
self._closed = True
async def open_session(
self,
@@ -124,10 +124,10 @@ class GatewayClient:
) -> AsyncIterator[pb.MxEvent]:
"""Return an async event iterator and cancel the stream when iteration stops."""
call = self.raw_stub.StreamEvents(
request,
metadata=merge_metadata(self.options.api_key, metadata),
)
kwargs: dict[str, Any] = {"metadata": merge_metadata(self.options.api_key, metadata)}
if self.options.stream_timeout is not None:
kwargs["timeout"] = self.options.stream_timeout
call = self.raw_stub.StreamEvents(request, **kwargs)
return _canceling_iterator(call)
async def _unary(
@@ -138,10 +138,16 @@ class GatewayClient:
*,
metadata: Sequence[tuple[str, str]] | None = None,
) -> Any:
call = method(
request,
metadata=merge_metadata(self.options.api_key, metadata),
)
kwargs: dict[str, Any] = {"metadata": merge_metadata(self.options.api_key, metadata)}
if self.options.call_timeout is not None:
kwargs["timeout"] = self.options.call_timeout
try:
call = method(request, **kwargs)
except TypeError as error:
if "timeout" not in kwargs or "unexpected keyword argument 'timeout'" not in str(error):
raise
kwargs.pop("timeout")
call = method(request, **kwargs)
try:
return await call
except asyncio.CancelledError:
+9 -1
View File
@@ -19,6 +19,8 @@ class ClientOptions:
plaintext: bool = False
ca_file: str | None = None
server_name_override: str | None = None
call_timeout: float | None = 30.0
stream_timeout: float | None = None
def __post_init__(self) -> None:
if not self.endpoint:
@@ -26,6 +28,10 @@ class ClientOptions:
if self.plaintext and self.ca_file:
raise ValueError("ca_file cannot be used with plaintext connections")
if self.call_timeout is not None and self.call_timeout <= 0:
raise ValueError("call_timeout must be greater than zero")
if self.stream_timeout is not None and self.stream_timeout <= 0:
raise ValueError("stream_timeout must be greater than zero")
def __repr__(self) -> str:
api_key = REDACTED if self.api_key else None
@@ -33,7 +39,9 @@ class ClientOptions:
f"{type(self).__name__}(endpoint={self.endpoint!r}, "
f"api_key={api_key!r}, plaintext={self.plaintext!r}, "
f"ca_file={self.ca_file!r}, "
f"server_name_override={self.server_name_override!r})"
f"server_name_override={self.server_name_override!r}, "
f"call_timeout={self.call_timeout!r}, "
f"stream_timeout={self.stream_timeout!r})"
)
+16 -2
View File
@@ -8,6 +8,8 @@ from .errors import ensure_mxaccess_success
from .generated import mxaccess_gateway_pb2 as pb
from .values import MxValueInput, to_mx_value
MAX_BULK_ITEMS = 1000
class Session:
"""A single gateway-backed MXAccess session."""
@@ -40,13 +42,14 @@ class Session:
protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK),
)
self._closed = True
return await self.client.close_session_raw(
reply = await self.client.close_session_raw(
pb.CloseSessionRequest(
session_id=self.session_id,
client_correlation_id=client_correlation_id,
),
)
self._closed = True
return reply
async def invoke(self, command: pb.MxCommand, *, correlation_id: str = "") -> pb.MxCommandReply:
"""Invoke a raw command and enforce gateway and MXAccess success."""
@@ -192,6 +195,7 @@ class Session:
) -> list[pb.SubscribeResult]:
if tag_addresses is None:
raise TypeError("tag_addresses is required")
_ensure_bulk_size("tag_addresses", len(tag_addresses))
reply = await self.invoke(
pb.MxCommand(
kind=pb.MX_COMMAND_KIND_ADD_ITEM_BULK,
@@ -213,6 +217,7 @@ class Session:
) -> list[pb.SubscribeResult]:
if item_handles is None:
raise TypeError("item_handles is required")
_ensure_bulk_size("item_handles", len(item_handles))
reply = await self.invoke(
pb.MxCommand(
kind=pb.MX_COMMAND_KIND_ADVISE_ITEM_BULK,
@@ -234,6 +239,7 @@ class Session:
) -> list[pb.SubscribeResult]:
if item_handles is None:
raise TypeError("item_handles is required")
_ensure_bulk_size("item_handles", len(item_handles))
reply = await self.invoke(
pb.MxCommand(
kind=pb.MX_COMMAND_KIND_REMOVE_ITEM_BULK,
@@ -255,6 +261,7 @@ class Session:
) -> list[pb.SubscribeResult]:
if item_handles is None:
raise TypeError("item_handles is required")
_ensure_bulk_size("item_handles", len(item_handles))
reply = await self.invoke(
pb.MxCommand(
kind=pb.MX_COMMAND_KIND_UN_ADVISE_ITEM_BULK,
@@ -276,6 +283,7 @@ class Session:
) -> list[pb.SubscribeResult]:
if tag_addresses is None:
raise TypeError("tag_addresses is required")
_ensure_bulk_size("tag_addresses", len(tag_addresses))
reply = await self.invoke(
pb.MxCommand(
kind=pb.MX_COMMAND_KIND_SUBSCRIBE_BULK,
@@ -297,6 +305,7 @@ class Session:
) -> list[pb.SubscribeResult]:
if item_handles is None:
raise TypeError("item_handles is required")
_ensure_bulk_size("item_handles", len(item_handles))
reply = await self.invoke(
pb.MxCommand(
kind=pb.MX_COMMAND_KIND_UNSUBSCRIBE_BULK,
@@ -368,4 +377,9 @@ class Session:
)
def _ensure_bulk_size(name: str, count: int) -> None:
if count > MAX_BULK_ITEMS:
raise ValueError(f"{name} bulk commands are limited to {MAX_BULK_ITEMS} item(s)")
from .client import GatewayClient # noqa: E402
@@ -20,6 +20,8 @@ from mxgateway.generated import mxaccess_gateway_pb2 as pb
from mxgateway.options import ClientOptions
from mxgateway.values import MxValueInput
MAX_AGGREGATE_EVENTS = 10_000
@click.group()
def main() -> None:
@@ -55,6 +57,8 @@ def gateway_options(command: Callable[..., Any]) -> Callable[..., Any]:
default=None,
help="TLS server name override for test environments.",
)(command)
command = click.option("--call-timeout", default=30.0, type=float, show_default=True)(command)
command = click.option("--stream-timeout", default=None, type=float)(command)
return command
@@ -352,6 +356,8 @@ async def _connect(kwargs: dict[str, Any]) -> GatewayClient:
plaintext=_use_plaintext(kwargs),
ca_file=kwargs.get("ca_file"),
server_name_override=kwargs.get("server_name_override"),
call_timeout=kwargs.get("call_timeout"),
stream_timeout=kwargs.get("stream_timeout"),
),
)
@@ -416,6 +422,12 @@ async def _collect_events(
max_events: int,
timeout: float,
) -> list[pb.MxEvent]:
if max_events > MAX_AGGREGATE_EVENTS:
raise click.BadParameter(
f"must be less than or equal to {MAX_AGGREGATE_EVENTS}",
param_hint="--max-events",
)
collected: list[pb.MxEvent] = []
iterator = events.__aiter__()
try:
@@ -423,6 +435,8 @@ async def _collect_events(
collected.append(await asyncio.wait_for(iterator.__anext__(), timeout=timeout))
except StopAsyncIteration:
pass
except asyncio.TimeoutError:
pass
finally:
close = getattr(iterator, "aclose", None)
if close is not None:
+30 -7
View File
@@ -16,6 +16,8 @@ use mxgateway_client::{
use serde_json::json;
use serde_json::Value;
const MAX_AGGREGATE_EVENTS: usize = 10_000;
#[derive(Debug, Parser)]
#[command(name = "mxgw")]
#[command(about = "MXAccess Gateway Rust test CLI")]
@@ -29,6 +31,8 @@ enum Command {
Version {
#[arg(long)]
json: bool,
#[arg(long)]
jsonl: bool,
},
Ping {
#[command(flatten)]
@@ -325,7 +329,15 @@ async fn run(cli: Cli) -> Result<(), Error> {
after_worker_sequence,
max_events,
json,
jsonl,
} => {
if max_events > MAX_AGGREGATE_EVENTS {
return Err(Error::InvalidArgument {
name: "max-events".to_owned(),
detail: format!("must be less than or equal to {MAX_AGGREGATE_EVENTS}"),
});
}
let client = connect(connection).await?;
let mut stream = client
.stream_events(StreamEventsRequest {
@@ -334,19 +346,30 @@ async fn run(cli: Cli) -> Result<(), Error> {
})
.await?;
let mut events = Vec::new();
while events.len() < max_events {
let mut event_count = 0usize;
while event_count < max_events {
let Some(event) = stream.next().await else {
break;
};
events.push(event?);
}
if json {
println!("{}", json!({ "eventCount": events.len() }));
} else {
for event in events {
let event = event?;
event_count += 1;
if jsonl {
println!(
"{}",
json!({
"workerSequence": event.worker_sequence,
"family": event.family,
})
);
} else if json {
events.push(event);
} else {
println!("{} {}", event.worker_sequence, event.family);
}
}
if json {
println!("{}", json!({ "eventCount": event_count }));
}
}
Command::Write {
connection,
+14 -2
View File
@@ -5,7 +5,7 @@ use tonic::transport::{Certificate, Channel, ClientTlsConfig};
use tonic::Request;
use crate::auth::AuthInterceptor;
use crate::error::{ensure_command_success, Error};
use crate::error::{ensure_command_success, ensure_protocol_success, Error};
use crate::generated::mxaccess_gateway::v1::mx_access_gateway_client::MxAccessGatewayClient;
use crate::generated::mxaccess_gateway::v1::{
CloseSessionReply, CloseSessionRequest, MxCommandReply, MxCommandRequest, MxEvent,
@@ -23,6 +23,7 @@ pub type EventStream =
pub struct GatewayClient {
inner: RawGatewayClient,
call_timeout: std::time::Duration,
stream_timeout: Option<std::time::Duration>,
}
impl GatewayClient {
@@ -57,6 +58,7 @@ impl GatewayClient {
Ok(Self {
inner: MxAccessGatewayClient::with_interceptor(channel, interceptor),
call_timeout: options.call_timeout(),
stream_timeout: options.stream_timeout(),
})
}
@@ -83,6 +85,7 @@ impl GatewayClient {
pub async fn open_session(&self, request: OpenSessionRequest) -> Result<Session, Error> {
let reply = self.open_session_raw(request).await?;
ensure_protocol_success("open session", reply.protocol_status.as_ref())?;
Ok(Session::new(reply.session_id, self.clone()))
}
@@ -107,7 +110,7 @@ impl GatewayClient {
pub async fn stream_events(&self, request: StreamEventsRequest) -> Result<EventStream, Error> {
let mut client = self.inner.clone();
let response = client.stream_events(self.unary_request(request)).await?;
let response = client.stream_events(self.stream_request(request)).await?;
let stream = futures_util::StreamExt::map(response.into_inner(), |result| {
result.map_err(Error::from)
});
@@ -120,4 +123,13 @@ impl GatewayClient {
request.set_timeout(self.call_timeout);
request
}
fn stream_request<T>(&self, message: T) -> Request<T> {
let mut request = Request::new(message);
if let Some(timeout) = self.stream_timeout {
request.set_timeout(timeout);
}
request
}
}
+29 -1
View File
@@ -1,7 +1,7 @@
use thiserror::Error as ThisError;
use tonic::Code;
use crate::generated::mxaccess_gateway::v1::{MxCommandReply, ProtocolStatusCode};
use crate::generated::mxaccess_gateway::v1::{MxCommandReply, ProtocolStatus, ProtocolStatusCode};
#[derive(Debug, ThisError)]
pub enum Error {
@@ -47,6 +47,13 @@ pub enum Error {
#[error("gateway command failed: {0}")]
Command(#[from] Box<CommandError>),
#[error("gateway {operation} failed: {code:?}: {message}")]
ProtocolStatus {
operation: &'static str,
code: ProtocolStatusCode,
message: String,
},
}
#[derive(Clone, Debug)]
@@ -125,6 +132,27 @@ pub fn ensure_command_success(reply: MxCommandReply) -> Result<MxCommandReply, E
}
}
pub fn ensure_protocol_success(
operation: &'static str,
status: Option<&ProtocolStatus>,
) -> Result<(), Error> {
let code = status
.and_then(|status| ProtocolStatusCode::try_from(status.code).ok())
.unwrap_or(ProtocolStatusCode::Unspecified);
if code == ProtocolStatusCode::Ok {
Ok(())
} else {
Err(Error::ProtocolStatus {
operation,
code,
message: status
.map(|status| status.message.clone())
.unwrap_or_default(),
})
}
}
fn redact_credentials(message: &str) -> String {
message
.split_whitespace()
+12
View File
@@ -13,6 +13,7 @@ pub struct ClientOptions {
server_name_override: Option<String>,
connect_timeout: Duration,
call_timeout: Duration,
stream_timeout: Option<Duration>,
}
impl ClientOptions {
@@ -25,6 +26,7 @@ impl ClientOptions {
server_name_override: None,
connect_timeout: Duration::from_secs(10),
call_timeout: Duration::from_secs(30),
stream_timeout: None,
}
}
@@ -58,6 +60,11 @@ impl ClientOptions {
self
}
pub fn with_stream_timeout(mut self, stream_timeout: Duration) -> Self {
self.stream_timeout = Some(stream_timeout);
self
}
pub fn endpoint(&self) -> &str {
&self.endpoint
}
@@ -85,6 +92,10 @@ impl ClientOptions {
pub fn call_timeout(&self) -> Duration {
self.call_timeout
}
pub fn stream_timeout(&self) -> Option<Duration> {
self.stream_timeout
}
}
impl Default for ClientOptions {
@@ -104,6 +115,7 @@ impl fmt::Debug for ClientOptions {
.field("server_name_override", &self.server_name_override)
.field("connect_timeout", &self.connect_timeout)
.field("call_timeout", &self.call_timeout)
.field("stream_timeout", &self.stream_timeout)
.finish()
}
}
+23 -2
View File
@@ -1,5 +1,5 @@
use crate::client::{EventStream, GatewayClient};
use crate::error::Error;
use crate::error::{ensure_protocol_success, Error};
use crate::generated::mxaccess_gateway::v1::mx_command::Payload;
use crate::generated::mxaccess_gateway::v1::mx_command_reply;
use crate::generated::mxaccess_gateway::v1::{
@@ -11,6 +11,8 @@ use crate::generated::mxaccess_gateway::v1::{
};
use crate::value::MxValue;
const MAX_BULK_ITEMS: usize = 1_000;
/// Session identifier returned by the gateway.
#[derive(Clone)]
pub struct Session {
@@ -40,12 +42,14 @@ impl Session {
}
pub async fn close(&self) -> Result<(), Error> {
self.client
let reply = self
.client
.close_session_raw(CloseSessionRequest {
session_id: self.id.clone(),
client_correlation_id: "rust-client-close-session".to_owned(),
})
.await?;
ensure_protocol_success("close session", reply.protocol_status.as_ref())?;
Ok(())
}
@@ -137,6 +141,7 @@ impl Session {
server_handle: i32,
tag_addresses: Vec<String>,
) -> Result<Vec<SubscribeResult>, Error> {
ensure_bulk_size("tag_addresses", tag_addresses.len())?;
let reply = self
.invoke(
MxCommandKind::AddItemBulk,
@@ -155,6 +160,7 @@ impl Session {
server_handle: i32,
item_handles: Vec<i32>,
) -> Result<Vec<SubscribeResult>, Error> {
ensure_bulk_size("item_handles", item_handles.len())?;
let reply = self
.invoke(
MxCommandKind::AdviseItemBulk,
@@ -173,6 +179,7 @@ impl Session {
server_handle: i32,
item_handles: Vec<i32>,
) -> Result<Vec<SubscribeResult>, Error> {
ensure_bulk_size("item_handles", item_handles.len())?;
let reply = self
.invoke(
MxCommandKind::RemoveItemBulk,
@@ -191,6 +198,7 @@ impl Session {
server_handle: i32,
item_handles: Vec<i32>,
) -> Result<Vec<SubscribeResult>, Error> {
ensure_bulk_size("item_handles", item_handles.len())?;
let reply = self
.invoke(
MxCommandKind::UnAdviseItemBulk,
@@ -209,6 +217,7 @@ impl Session {
server_handle: i32,
tag_addresses: Vec<String>,
) -> Result<Vec<SubscribeResult>, Error> {
ensure_bulk_size("tag_addresses", tag_addresses.len())?;
let reply = self
.invoke(
MxCommandKind::SubscribeBulk,
@@ -227,6 +236,7 @@ impl Session {
server_handle: i32,
item_handles: Vec<i32>,
) -> Result<Vec<SubscribeResult>, Error> {
ensure_bulk_size("item_handles", item_handles.len())?;
let reply = self
.invoke(
MxCommandKind::UnsubscribeBulk,
@@ -327,6 +337,17 @@ impl Session {
}
}
fn ensure_bulk_size(name: &'static str, len: usize) -> Result<(), Error> {
if len > MAX_BULK_ITEMS {
Err(Error::InvalidArgument {
name: name.to_owned(),
detail: format!("bulk commands are limited to {MAX_BULK_ITEMS} item(s)"),
})
} else {
Ok(())
}
}
fn register_server_handle(reply: &MxCommandReply) -> i32 {
match reply.payload.as_ref() {
Some(mx_command_reply::Payload::Register(register)) => register.server_handle,