Improve gateway reliability and dashboard docs
This commit is contained in:
@@ -8,6 +8,8 @@ namespace MxGateway.Client.Cli;
|
||||
|
||||
public static class MxGatewayClientCli
|
||||
{
|
||||
private const uint MaxAggregateEvents = 10_000;
|
||||
|
||||
private static readonly JsonFormatter ProtobufJsonFormatter = JsonFormatter.Default;
|
||||
|
||||
private static readonly JsonSerializerOptions JsonOptions = new(JsonSerializerDefaults.Web);
|
||||
@@ -342,8 +344,22 @@ public static class MxGatewayClientCli
|
||||
TextWriter output,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
var events = new List<MxEvent>();
|
||||
uint maxEvents = arguments.GetUInt32("max-events", 0);
|
||||
bool json = arguments.HasFlag("json");
|
||||
bool jsonLines = arguments.HasFlag("jsonl");
|
||||
if (json && !jsonLines && maxEvents is 0)
|
||||
{
|
||||
throw new ArgumentException("--json stream-events requires --max-events to bound aggregate output.");
|
||||
}
|
||||
|
||||
if (maxEvents > MaxAggregateEvents)
|
||||
{
|
||||
throw new ArgumentException($"--max-events cannot exceed {MaxAggregateEvents}.");
|
||||
}
|
||||
|
||||
var events = json && !jsonLines
|
||||
? new List<MxEvent>(checked((int)maxEvents))
|
||||
: [];
|
||||
uint eventCount = 0;
|
||||
var request = new StreamEventsRequest
|
||||
{
|
||||
@@ -355,7 +371,11 @@ public static class MxGatewayClientCli
|
||||
.WithCancellation(cancellationToken)
|
||||
.ConfigureAwait(false))
|
||||
{
|
||||
if (arguments.HasFlag("json"))
|
||||
if (jsonLines)
|
||||
{
|
||||
output.WriteLine(ProtobufJsonFormatter.Format(gatewayEvent));
|
||||
}
|
||||
else if (json)
|
||||
{
|
||||
events.Add(gatewayEvent);
|
||||
}
|
||||
@@ -371,7 +391,7 @@ public static class MxGatewayClientCli
|
||||
}
|
||||
}
|
||||
|
||||
if (arguments.HasFlag("json"))
|
||||
if (json && !jsonLines)
|
||||
{
|
||||
output.WriteLine(JsonSerializer.Serialize(
|
||||
new { events = events.Select(EventToJsonElement).ToArray() },
|
||||
|
||||
@@ -25,7 +25,7 @@ internal sealed class GrpcMxGatewayClientTransport(
|
||||
}
|
||||
catch (RpcException exception)
|
||||
{
|
||||
throw MapRpcException(exception);
|
||||
throw MapRpcException(exception, callOptions.CancellationToken);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ internal sealed class GrpcMxGatewayClientTransport(
|
||||
}
|
||||
catch (RpcException exception)
|
||||
{
|
||||
throw MapRpcException(exception);
|
||||
throw MapRpcException(exception, callOptions.CancellationToken);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ internal sealed class GrpcMxGatewayClientTransport(
|
||||
}
|
||||
catch (RpcException exception)
|
||||
{
|
||||
throw MapRpcException(exception);
|
||||
throw MapRpcException(exception, callOptions.CancellationToken);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ internal sealed class GrpcMxGatewayClientTransport(
|
||||
}
|
||||
catch (RpcException exception)
|
||||
{
|
||||
throw MapRpcException(exception);
|
||||
throw MapRpcException(exception, effectiveCancellationToken);
|
||||
}
|
||||
|
||||
yield return gatewayEvent;
|
||||
@@ -101,8 +101,18 @@ internal sealed class GrpcMxGatewayClientTransport(
|
||||
return StreamEventsAsync(request, callOptions);
|
||||
}
|
||||
|
||||
private static MxGatewayException MapRpcException(RpcException exception)
|
||||
private static Exception MapRpcException(
|
||||
RpcException exception,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
if (cancellationToken.IsCancellationRequested || exception.StatusCode == StatusCode.Cancelled)
|
||||
{
|
||||
return new OperationCanceledException(
|
||||
exception.Status.Detail,
|
||||
exception,
|
||||
cancellationToken);
|
||||
}
|
||||
|
||||
return exception.StatusCode switch
|
||||
{
|
||||
StatusCode.Unauthenticated => new MxGatewayAuthenticationException(
|
||||
|
||||
@@ -3,6 +3,9 @@ using Grpc.Net.Client;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using MxGateway.Contracts.Proto;
|
||||
using Polly;
|
||||
using System.Net.Http;
|
||||
using System.Net.Security;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
|
||||
namespace MxGateway.Client;
|
||||
|
||||
@@ -54,10 +57,12 @@ public sealed class MxGatewayClient : IAsyncDisposable
|
||||
ArgumentNullException.ThrowIfNull(options);
|
||||
options.Validate();
|
||||
|
||||
HttpMessageHandler handler = CreateHttpHandler(options);
|
||||
var channel = GrpcChannel.ForAddress(
|
||||
options.Endpoint,
|
||||
new GrpcChannelOptions
|
||||
{
|
||||
HttpHandler = handler,
|
||||
LoggerFactory = options.LoggerFactory,
|
||||
});
|
||||
|
||||
@@ -126,7 +131,7 @@ public sealed class MxGatewayClient : IAsyncDisposable
|
||||
ArgumentNullException.ThrowIfNull(request);
|
||||
ThrowIfDisposed();
|
||||
|
||||
return _transport.StreamEventsAsync(request, CreateCallOptions(cancellationToken));
|
||||
return _transport.StreamEventsAsync(request, CreateStreamCallOptions(cancellationToken));
|
||||
}
|
||||
|
||||
public ValueTask DisposeAsync()
|
||||
@@ -142,6 +147,18 @@ public sealed class MxGatewayClient : IAsyncDisposable
|
||||
}
|
||||
|
||||
internal CallOptions CreateCallOptions(CancellationToken cancellationToken)
|
||||
{
|
||||
return CreateCallOptions(cancellationToken, Options.DefaultCallTimeout);
|
||||
}
|
||||
|
||||
internal CallOptions CreateStreamCallOptions(CancellationToken cancellationToken)
|
||||
{
|
||||
return CreateCallOptions(cancellationToken, Options.StreamTimeout);
|
||||
}
|
||||
|
||||
internal CallOptions CreateCallOptions(
|
||||
CancellationToken cancellationToken,
|
||||
TimeSpan? timeout)
|
||||
{
|
||||
Metadata headers = new()
|
||||
{
|
||||
@@ -150,18 +167,61 @@ public sealed class MxGatewayClient : IAsyncDisposable
|
||||
|
||||
return new CallOptions(
|
||||
headers,
|
||||
DateTime.UtcNow.Add(Options.DefaultCallTimeout),
|
||||
timeout is null ? null : DateTime.UtcNow.Add(timeout.Value),
|
||||
cancellationToken);
|
||||
}
|
||||
|
||||
private Task<T> ExecuteSafeUnaryAsync<T>(
|
||||
private async Task<T> ExecuteSafeUnaryAsync<T>(
|
||||
Func<CancellationToken, Task<T>> call,
|
||||
CancellationToken cancellationToken)
|
||||
{
|
||||
return _safeUnaryRetryPipeline.ExecuteAsync(
|
||||
using CancellationTokenSource timeout = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
|
||||
timeout.CancelAfter(Options.DefaultCallTimeout);
|
||||
|
||||
return await _safeUnaryRetryPipeline.ExecuteAsync(
|
||||
async token => await call(token).ConfigureAwait(false),
|
||||
cancellationToken)
|
||||
.AsTask();
|
||||
timeout.Token)
|
||||
.ConfigureAwait(false);
|
||||
}
|
||||
|
||||
private static HttpMessageHandler CreateHttpHandler(MxGatewayClientOptions options)
|
||||
{
|
||||
SocketsHttpHandler handler = new()
|
||||
{
|
||||
ConnectTimeout = options.ConnectTimeout,
|
||||
};
|
||||
|
||||
if (options.UseTls)
|
||||
{
|
||||
handler.SslOptions = new SslClientAuthenticationOptions();
|
||||
if (!string.IsNullOrWhiteSpace(options.ServerNameOverride))
|
||||
{
|
||||
handler.SslOptions.TargetHost = options.ServerNameOverride;
|
||||
}
|
||||
|
||||
if (!string.IsNullOrWhiteSpace(options.CaCertificatePath))
|
||||
{
|
||||
X509Certificate2 trustedRoot = X509CertificateLoader.LoadCertificateFromFile(options.CaCertificatePath);
|
||||
handler.SslOptions.RemoteCertificateValidationCallback = (_, certificate, chain, errors) =>
|
||||
{
|
||||
if (certificate is null)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
using X509Chain customChain = new();
|
||||
customChain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
|
||||
customChain.ChainPolicy.CustomTrustStore.Add(trustedRoot);
|
||||
customChain.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck;
|
||||
customChain.ChainPolicy.VerificationFlags = X509VerificationFlags.NoFlag;
|
||||
X509Certificate2 certificateToValidate = certificate as X509Certificate2
|
||||
?? X509CertificateLoader.LoadCertificate(certificate.Export(X509ContentType.Cert));
|
||||
return customChain.Build(certificateToValidate);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return handler;
|
||||
}
|
||||
|
||||
private void ThrowIfDisposed()
|
||||
|
||||
@@ -21,6 +21,8 @@ public sealed class MxGatewayClientOptions
|
||||
|
||||
public TimeSpan DefaultCallTimeout { get; init; } = TimeSpan.FromSeconds(30);
|
||||
|
||||
public TimeSpan? StreamTimeout { get; init; }
|
||||
|
||||
public MxGatewayClientRetryOptions Retry { get; init; } = new();
|
||||
|
||||
public ILoggerFactory? LoggerFactory { get; init; }
|
||||
@@ -57,6 +59,27 @@ public sealed class MxGatewayClientOptions
|
||||
"The default call timeout must be greater than zero.");
|
||||
}
|
||||
|
||||
if (StreamTimeout is not null && StreamTimeout <= TimeSpan.Zero)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(
|
||||
nameof(StreamTimeout),
|
||||
"The stream timeout must be greater than zero when configured.");
|
||||
}
|
||||
|
||||
if (UseTls && Endpoint.Scheme != Uri.UriSchemeHttps)
|
||||
{
|
||||
throw new ArgumentException(
|
||||
"UseTls requires an https gateway endpoint.",
|
||||
nameof(Endpoint));
|
||||
}
|
||||
|
||||
if (!UseTls && Endpoint.Scheme == Uri.UriSchemeHttps)
|
||||
{
|
||||
throw new ArgumentException(
|
||||
"An https gateway endpoint requires UseTls.",
|
||||
nameof(Endpoint));
|
||||
}
|
||||
|
||||
Retry.Validate();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -377,17 +377,19 @@ func runSmoke(ctx context.Context, args []string, stdout, stderr io.Writer) erro
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer session.Close(context.Background())
|
||||
|
||||
serverHandle, err := session.Register(ctx, *clientName)
|
||||
if err != nil {
|
||||
return err
|
||||
return closeSmokeSession(ctx, session, err)
|
||||
}
|
||||
itemHandle, err := session.AddItem(ctx, serverHandle, *item)
|
||||
if err != nil {
|
||||
return err
|
||||
return closeSmokeSession(ctx, session, err)
|
||||
}
|
||||
if err := session.Advise(ctx, serverHandle, itemHandle); err != nil {
|
||||
return closeSmokeSession(ctx, session, err)
|
||||
}
|
||||
if err := closeSmokeSession(ctx, session, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -406,6 +408,24 @@ func runSmoke(ctx context.Context, args []string, stdout, stderr io.Writer) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeSmokeSession(ctx context.Context, session *mxgateway.Session, primaryErr error) error {
|
||||
closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if until := time.Until(deadline); until > 0 && until < 5*time.Second {
|
||||
cancel()
|
||||
closeCtx, cancel = context.WithTimeout(context.Background(), until)
|
||||
defer cancel()
|
||||
}
|
||||
}
|
||||
|
||||
_, closeErr := session.Close(closeCtx)
|
||||
if primaryErr != nil {
|
||||
return primaryErr
|
||||
}
|
||||
return closeErr
|
||||
}
|
||||
|
||||
func bindCommonFlags(flags *flag.FlagSet) *commonOptions {
|
||||
common := &commonOptions{}
|
||||
flags.StringVar(&common.Endpoint, "endpoint", "localhost:5000", "gateway endpoint")
|
||||
|
||||
@@ -184,8 +184,11 @@ func (c *Client) callContext(ctx context.Context) (context.Context, context.Canc
|
||||
if timeout < 0 {
|
||||
return ctx, func() {}
|
||||
}
|
||||
if _, ok := ctx.Deadline(); ok {
|
||||
return ctx, func() {}
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
timeoutDeadline := time.Now().Add(timeout)
|
||||
if deadline.Before(timeoutDeadline) {
|
||||
return ctx, func() {}
|
||||
}
|
||||
}
|
||||
return context.WithTimeout(ctx, timeout)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
@@ -13,6 +14,8 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const maxBulkItems = 1000
|
||||
|
||||
// EventResult carries either the next ordered event or a terminal stream error.
|
||||
type EventResult struct {
|
||||
Event *MxEvent
|
||||
@@ -225,6 +228,9 @@ func (s *Session) AddItemBulk(ctx context.Context, serverHandle int32, tagAddres
|
||||
if tagAddresses == nil {
|
||||
return nil, errors.New("mxgateway: tag addresses are required")
|
||||
}
|
||||
if err := ensureBulkSize("tag addresses", len(tagAddresses)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply, err := s.invokeCommand(ctx, &pb.MxCommand{
|
||||
Kind: pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM_BULK,
|
||||
Payload: &pb.MxCommand_AddItemBulk{
|
||||
@@ -245,6 +251,9 @@ func (s *Session) AdviseItemBulk(ctx context.Context, serverHandle int32, itemHa
|
||||
if itemHandles == nil {
|
||||
return nil, errors.New("mxgateway: item handles are required")
|
||||
}
|
||||
if err := ensureBulkSize("item handles", len(itemHandles)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply, err := s.invokeCommand(ctx, &pb.MxCommand{
|
||||
Kind: pb.MxCommandKind_MX_COMMAND_KIND_ADVISE_ITEM_BULK,
|
||||
Payload: &pb.MxCommand_AdviseItemBulk{
|
||||
@@ -265,6 +274,9 @@ func (s *Session) RemoveItemBulk(ctx context.Context, serverHandle int32, itemHa
|
||||
if itemHandles == nil {
|
||||
return nil, errors.New("mxgateway: item handles are required")
|
||||
}
|
||||
if err := ensureBulkSize("item handles", len(itemHandles)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply, err := s.invokeCommand(ctx, &pb.MxCommand{
|
||||
Kind: pb.MxCommandKind_MX_COMMAND_KIND_REMOVE_ITEM_BULK,
|
||||
Payload: &pb.MxCommand_RemoveItemBulk{
|
||||
@@ -285,6 +297,9 @@ func (s *Session) UnAdviseItemBulk(ctx context.Context, serverHandle int32, item
|
||||
if itemHandles == nil {
|
||||
return nil, errors.New("mxgateway: item handles are required")
|
||||
}
|
||||
if err := ensureBulkSize("item handles", len(itemHandles)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply, err := s.invokeCommand(ctx, &pb.MxCommand{
|
||||
Kind: pb.MxCommandKind_MX_COMMAND_KIND_UN_ADVISE_ITEM_BULK,
|
||||
Payload: &pb.MxCommand_UnAdviseItemBulk{
|
||||
@@ -305,6 +320,9 @@ func (s *Session) SubscribeBulk(ctx context.Context, serverHandle int32, tagAddr
|
||||
if tagAddresses == nil {
|
||||
return nil, errors.New("mxgateway: tag addresses are required")
|
||||
}
|
||||
if err := ensureBulkSize("tag addresses", len(tagAddresses)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply, err := s.invokeCommand(ctx, &pb.MxCommand{
|
||||
Kind: pb.MxCommandKind_MX_COMMAND_KIND_SUBSCRIBE_BULK,
|
||||
Payload: &pb.MxCommand_SubscribeBulk{
|
||||
@@ -325,6 +343,9 @@ func (s *Session) UnsubscribeBulk(ctx context.Context, serverHandle int32, itemH
|
||||
if itemHandles == nil {
|
||||
return nil, errors.New("mxgateway: item handles are required")
|
||||
}
|
||||
if err := ensureBulkSize("item handles", len(itemHandles)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply, err := s.invokeCommand(ctx, &pb.MxCommand{
|
||||
Kind: pb.MxCommandKind_MX_COMMAND_KIND_UNSUBSCRIBE_BULK,
|
||||
Payload: &pb.MxCommand_UnsubscribeBulk{
|
||||
@@ -387,13 +408,15 @@ func (s *Session) EventsAfter(ctx context.Context, afterWorkerSequence uint64) (
|
||||
for {
|
||||
event, err := stream.Recv()
|
||||
if err == nil {
|
||||
results <- EventResult{Event: event}
|
||||
if !sendEventResult(ctx, results, EventResult{Event: event}) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err == io.EOF || status.Code(err) == codes.Canceled || ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
results <- EventResult{Err: &GatewayError{Op: "stream events", Err: err}}
|
||||
sendEventResult(ctx, results, EventResult{Err: &GatewayError{Op: "stream events", Err: err}})
|
||||
return
|
||||
}
|
||||
}()
|
||||
@@ -401,6 +424,22 @@ func (s *Session) EventsAfter(ctx context.Context, afterWorkerSequence uint64) (
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func ensureBulkSize(name string, length int) error {
|
||||
if length > maxBulkItems {
|
||||
return fmt.Errorf("mxgateway: %s bulk commands are limited to %d item(s)", name, maxBulkItems)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sendEventResult(ctx context.Context, results chan<- EventResult, result EventResult) bool {
|
||||
select {
|
||||
case results <- result:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) invokeCommand(ctx context.Context, command *MxCommand) (*MxCommandReply, error) {
|
||||
return s.client.Invoke(ctx, &pb.MxCommandRequest{
|
||||
SessionId: s.ID(),
|
||||
|
||||
+21
-18
@@ -334,25 +334,28 @@ public final class MxGatewayCli implements Callable<Integer> {
|
||||
var session = client.openSession(OpenSessionRequest.newBuilder()
|
||||
.setClientSessionName(clientName)
|
||||
.build());
|
||||
MxGatewayCliSession cliSession = client.session(session.getSessionId());
|
||||
int serverHandle = cliSession.register(clientName);
|
||||
int itemHandle = cliSession.addItem(serverHandle, item);
|
||||
cliSession.advise(serverHandle, itemHandle);
|
||||
if (json) {
|
||||
Map<String, Object> output = new LinkedHashMap<>();
|
||||
output.put("command", "smoke");
|
||||
output.put("options", common.redactedJsonMap());
|
||||
output.put("sessionId", session.getSessionId());
|
||||
output.put("serverHandle", serverHandle);
|
||||
output.put("itemHandle", itemHandle);
|
||||
client.out().println(jsonObject(output));
|
||||
} else {
|
||||
client.out().printf(
|
||||
"session=%s server=%d item=%d%n", session.getSessionId(), serverHandle, itemHandle);
|
||||
try {
|
||||
MxGatewayCliSession cliSession = client.session(session.getSessionId());
|
||||
int serverHandle = cliSession.register(clientName);
|
||||
int itemHandle = cliSession.addItem(serverHandle, item);
|
||||
cliSession.advise(serverHandle, itemHandle);
|
||||
if (json) {
|
||||
Map<String, Object> output = new LinkedHashMap<>();
|
||||
output.put("command", "smoke");
|
||||
output.put("options", common.redactedJsonMap());
|
||||
output.put("sessionId", session.getSessionId());
|
||||
output.put("serverHandle", serverHandle);
|
||||
output.put("itemHandle", itemHandle);
|
||||
client.out().println(jsonObject(output));
|
||||
} else {
|
||||
client.out().printf(
|
||||
"session=%s server=%d item=%d%n", session.getSessionId(), serverHandle, itemHandle);
|
||||
}
|
||||
} finally {
|
||||
client.closeSession(CloseSessionRequest.newBuilder()
|
||||
.setSessionId(session.getSessionId())
|
||||
.build());
|
||||
}
|
||||
client.closeSession(CloseSessionRequest.newBuilder()
|
||||
.setSessionId(session.getSessionId())
|
||||
.build());
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
+12
-5
@@ -105,13 +105,20 @@ public final class MxEventStream implements Iterator<MxEvent>, AutoCloseable {
|
||||
private void offer(Object value) {
|
||||
Objects.requireNonNull(value, "value");
|
||||
if (value == END) {
|
||||
queue.offer(value);
|
||||
if (!queue.offer(value)) {
|
||||
queue.clear();
|
||||
queue.offer(value);
|
||||
}
|
||||
return;
|
||||
}
|
||||
try {
|
||||
queue.put(value);
|
||||
} catch (InterruptedException error) {
|
||||
Thread.currentThread().interrupt();
|
||||
if (!queue.offer(value)) {
|
||||
ClientCallStreamObserver<StreamEventsRequest> stream = requestStream;
|
||||
if (stream != null) {
|
||||
stream.cancel("client event stream queue overflowed", null);
|
||||
}
|
||||
queue.clear();
|
||||
queue.offer(new MxGatewayException("gateway stream events queue overflowed"));
|
||||
queue.offer(END);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+18
-4
@@ -63,7 +63,7 @@ public final class MxGatewayClient implements AutoCloseable {
|
||||
}
|
||||
|
||||
public MxAccessGatewayGrpc.MxAccessGatewayStub rawAsyncStub() {
|
||||
return withDeadline(asyncStub);
|
||||
return asyncStub;
|
||||
}
|
||||
|
||||
public MxGatewaySession openSession(OpenSessionRequest request) {
|
||||
@@ -140,14 +140,14 @@ public final class MxGatewayClient implements AutoCloseable {
|
||||
|
||||
public MxEventStream streamEvents(StreamEventsRequest request) {
|
||||
MxEventStream stream = new MxEventStream(16);
|
||||
rawAsyncStub().streamEvents(request, stream.observer());
|
||||
withStreamDeadline(rawAsyncStub()).streamEvents(request, stream.observer());
|
||||
return stream;
|
||||
}
|
||||
|
||||
public MxGatewayEventSubscription streamEventsAsync(
|
||||
StreamEventsRequest request, StreamObserver<MxEvent> observer) {
|
||||
MxGatewayEventSubscription subscription = new MxGatewayEventSubscription();
|
||||
rawAsyncStub().streamEvents(request, subscription.wrap(observer));
|
||||
withStreamDeadline(rawAsyncStub()).streamEvents(request, subscription.wrap(observer));
|
||||
return subscription;
|
||||
}
|
||||
|
||||
@@ -161,7 +161,9 @@ public final class MxGatewayClient implements AutoCloseable {
|
||||
public void closeAndAwaitTermination() throws InterruptedException {
|
||||
if (ownedChannel != null) {
|
||||
ownedChannel.shutdown();
|
||||
ownedChannel.awaitTermination(options.connectTimeout().toMillis(), TimeUnit.MILLISECONDS);
|
||||
if (!ownedChannel.awaitTermination(options.connectTimeout().toMillis(), TimeUnit.MILLISECONDS)) {
|
||||
ownedChannel.shutdownNow();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -199,6 +201,13 @@ public final class MxGatewayClient implements AutoCloseable {
|
||||
return stub.withDeadlineAfter(options.callTimeout().toNanos(), TimeUnit.NANOSECONDS);
|
||||
}
|
||||
|
||||
private <T extends io.grpc.stub.AbstractStub<T>> T withStreamDeadline(T stub) {
|
||||
if (options.streamTimeout() == null || options.streamTimeout().isNegative()) {
|
||||
return stub;
|
||||
}
|
||||
return stub.withDeadlineAfter(options.streamTimeout().toNanos(), TimeUnit.NANOSECONDS);
|
||||
}
|
||||
|
||||
private static <T> CompletableFuture<T> toCompletable(com.google.common.util.concurrent.ListenableFuture<T> source) {
|
||||
CompletableFuture<T> target = new CompletableFuture<>();
|
||||
Futures.addCallback(
|
||||
@@ -219,6 +228,11 @@ public final class MxGatewayClient implements AutoCloseable {
|
||||
}
|
||||
},
|
||||
MoreExecutors.directExecutor());
|
||||
target.whenComplete((ignoredResult, ignoredError) -> {
|
||||
if (target.isCancelled()) {
|
||||
source.cancel(true);
|
||||
}
|
||||
});
|
||||
return target;
|
||||
}
|
||||
|
||||
|
||||
+14
@@ -15,6 +15,7 @@ public final class MxGatewayClientOptions {
|
||||
private final String serverNameOverride;
|
||||
private final Duration connectTimeout;
|
||||
private final Duration callTimeout;
|
||||
private final Duration streamTimeout;
|
||||
|
||||
private MxGatewayClientOptions(Builder builder) {
|
||||
endpoint = requireText(builder.endpoint, "endpoint");
|
||||
@@ -24,6 +25,7 @@ public final class MxGatewayClientOptions {
|
||||
serverNameOverride = builder.serverNameOverride == null ? "" : builder.serverNameOverride;
|
||||
connectTimeout = builder.connectTimeout == null ? DEFAULT_CONNECT_TIMEOUT : builder.connectTimeout;
|
||||
callTimeout = builder.callTimeout == null ? DEFAULT_CALL_TIMEOUT : builder.callTimeout;
|
||||
streamTimeout = builder.streamTimeout;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
@@ -62,6 +64,10 @@ public final class MxGatewayClientOptions {
|
||||
return callTimeout;
|
||||
}
|
||||
|
||||
public Duration streamTimeout() {
|
||||
return streamTimeout;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "MxGatewayClientOptions{"
|
||||
@@ -82,6 +88,8 @@ public final class MxGatewayClientOptions {
|
||||
+ connectTimeout
|
||||
+ ", callTimeout="
|
||||
+ callTimeout
|
||||
+ ", streamTimeout="
|
||||
+ streamTimeout
|
||||
+ '}';
|
||||
}
|
||||
|
||||
@@ -100,6 +108,7 @@ public final class MxGatewayClientOptions {
|
||||
private String serverNameOverride;
|
||||
private Duration connectTimeout;
|
||||
private Duration callTimeout;
|
||||
private Duration streamTimeout;
|
||||
|
||||
private Builder() {
|
||||
}
|
||||
@@ -139,6 +148,11 @@ public final class MxGatewayClientOptions {
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder streamTimeout(Duration value) {
|
||||
streamTimeout = Objects.requireNonNull(value, "streamTimeout");
|
||||
return this;
|
||||
}
|
||||
|
||||
public MxGatewayClientOptions build() {
|
||||
return new MxGatewayClientOptions(this);
|
||||
}
|
||||
|
||||
+6
@@ -4,17 +4,22 @@ import io.grpc.stub.ClientCallStreamObserver;
|
||||
import io.grpc.stub.ClientResponseObserver;
|
||||
import io.grpc.stub.StreamObserver;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import mxaccess_gateway.v1.MxaccessGateway.MxEvent;
|
||||
import mxaccess_gateway.v1.MxaccessGateway.StreamEventsRequest;
|
||||
|
||||
public final class MxGatewayEventSubscription implements AutoCloseable {
|
||||
private final AtomicReference<ClientCallStreamObserver<StreamEventsRequest>> requestStream = new AtomicReference<>();
|
||||
private final AtomicBoolean cancelled = new AtomicBoolean();
|
||||
|
||||
ClientResponseObserver<StreamEventsRequest, MxEvent> wrap(StreamObserver<MxEvent> observer) {
|
||||
return new ClientResponseObserver<>() {
|
||||
@Override
|
||||
public void beforeStart(ClientCallStreamObserver<StreamEventsRequest> stream) {
|
||||
requestStream.set(stream);
|
||||
if (cancelled.get()) {
|
||||
stream.cancel("client cancelled event stream", null);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -35,6 +40,7 @@ public final class MxGatewayEventSubscription implements AutoCloseable {
|
||||
}
|
||||
|
||||
public void cancel() {
|
||||
cancelled.set(true);
|
||||
ClientCallStreamObserver<StreamEventsRequest> stream = requestStream.get();
|
||||
if (stream != null) {
|
||||
stream.cancel("client cancelled event stream", null);
|
||||
|
||||
@@ -74,9 +74,9 @@ class GatewayClient:
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
if self._channel is not None:
|
||||
await self._channel.close()
|
||||
self._closed = True
|
||||
|
||||
async def open_session(
|
||||
self,
|
||||
@@ -124,10 +124,10 @@ class GatewayClient:
|
||||
) -> AsyncIterator[pb.MxEvent]:
|
||||
"""Return an async event iterator and cancel the stream when iteration stops."""
|
||||
|
||||
call = self.raw_stub.StreamEvents(
|
||||
request,
|
||||
metadata=merge_metadata(self.options.api_key, metadata),
|
||||
)
|
||||
kwargs: dict[str, Any] = {"metadata": merge_metadata(self.options.api_key, metadata)}
|
||||
if self.options.stream_timeout is not None:
|
||||
kwargs["timeout"] = self.options.stream_timeout
|
||||
call = self.raw_stub.StreamEvents(request, **kwargs)
|
||||
return _canceling_iterator(call)
|
||||
|
||||
async def _unary(
|
||||
@@ -138,10 +138,16 @@ class GatewayClient:
|
||||
*,
|
||||
metadata: Sequence[tuple[str, str]] | None = None,
|
||||
) -> Any:
|
||||
call = method(
|
||||
request,
|
||||
metadata=merge_metadata(self.options.api_key, metadata),
|
||||
)
|
||||
kwargs: dict[str, Any] = {"metadata": merge_metadata(self.options.api_key, metadata)}
|
||||
if self.options.call_timeout is not None:
|
||||
kwargs["timeout"] = self.options.call_timeout
|
||||
try:
|
||||
call = method(request, **kwargs)
|
||||
except TypeError as error:
|
||||
if "timeout" not in kwargs or "unexpected keyword argument 'timeout'" not in str(error):
|
||||
raise
|
||||
kwargs.pop("timeout")
|
||||
call = method(request, **kwargs)
|
||||
try:
|
||||
return await call
|
||||
except asyncio.CancelledError:
|
||||
|
||||
@@ -19,6 +19,8 @@ class ClientOptions:
|
||||
plaintext: bool = False
|
||||
ca_file: str | None = None
|
||||
server_name_override: str | None = None
|
||||
call_timeout: float | None = 30.0
|
||||
stream_timeout: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.endpoint:
|
||||
@@ -26,6 +28,10 @@ class ClientOptions:
|
||||
|
||||
if self.plaintext and self.ca_file:
|
||||
raise ValueError("ca_file cannot be used with plaintext connections")
|
||||
if self.call_timeout is not None and self.call_timeout <= 0:
|
||||
raise ValueError("call_timeout must be greater than zero")
|
||||
if self.stream_timeout is not None and self.stream_timeout <= 0:
|
||||
raise ValueError("stream_timeout must be greater than zero")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
api_key = REDACTED if self.api_key else None
|
||||
@@ -33,7 +39,9 @@ class ClientOptions:
|
||||
f"{type(self).__name__}(endpoint={self.endpoint!r}, "
|
||||
f"api_key={api_key!r}, plaintext={self.plaintext!r}, "
|
||||
f"ca_file={self.ca_file!r}, "
|
||||
f"server_name_override={self.server_name_override!r})"
|
||||
f"server_name_override={self.server_name_override!r}, "
|
||||
f"call_timeout={self.call_timeout!r}, "
|
||||
f"stream_timeout={self.stream_timeout!r})"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ from .errors import ensure_mxaccess_success
|
||||
from .generated import mxaccess_gateway_pb2 as pb
|
||||
from .values import MxValueInput, to_mx_value
|
||||
|
||||
MAX_BULK_ITEMS = 1000
|
||||
|
||||
|
||||
class Session:
|
||||
"""A single gateway-backed MXAccess session."""
|
||||
@@ -40,13 +42,14 @@ class Session:
|
||||
protocol_status=pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK),
|
||||
)
|
||||
|
||||
self._closed = True
|
||||
return await self.client.close_session_raw(
|
||||
reply = await self.client.close_session_raw(
|
||||
pb.CloseSessionRequest(
|
||||
session_id=self.session_id,
|
||||
client_correlation_id=client_correlation_id,
|
||||
),
|
||||
)
|
||||
self._closed = True
|
||||
return reply
|
||||
|
||||
async def invoke(self, command: pb.MxCommand, *, correlation_id: str = "") -> pb.MxCommandReply:
|
||||
"""Invoke a raw command and enforce gateway and MXAccess success."""
|
||||
@@ -192,6 +195,7 @@ class Session:
|
||||
) -> list[pb.SubscribeResult]:
|
||||
if tag_addresses is None:
|
||||
raise TypeError("tag_addresses is required")
|
||||
_ensure_bulk_size("tag_addresses", len(tag_addresses))
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_ADD_ITEM_BULK,
|
||||
@@ -213,6 +217,7 @@ class Session:
|
||||
) -> list[pb.SubscribeResult]:
|
||||
if item_handles is None:
|
||||
raise TypeError("item_handles is required")
|
||||
_ensure_bulk_size("item_handles", len(item_handles))
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_ADVISE_ITEM_BULK,
|
||||
@@ -234,6 +239,7 @@ class Session:
|
||||
) -> list[pb.SubscribeResult]:
|
||||
if item_handles is None:
|
||||
raise TypeError("item_handles is required")
|
||||
_ensure_bulk_size("item_handles", len(item_handles))
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_REMOVE_ITEM_BULK,
|
||||
@@ -255,6 +261,7 @@ class Session:
|
||||
) -> list[pb.SubscribeResult]:
|
||||
if item_handles is None:
|
||||
raise TypeError("item_handles is required")
|
||||
_ensure_bulk_size("item_handles", len(item_handles))
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_UN_ADVISE_ITEM_BULK,
|
||||
@@ -276,6 +283,7 @@ class Session:
|
||||
) -> list[pb.SubscribeResult]:
|
||||
if tag_addresses is None:
|
||||
raise TypeError("tag_addresses is required")
|
||||
_ensure_bulk_size("tag_addresses", len(tag_addresses))
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_SUBSCRIBE_BULK,
|
||||
@@ -297,6 +305,7 @@ class Session:
|
||||
) -> list[pb.SubscribeResult]:
|
||||
if item_handles is None:
|
||||
raise TypeError("item_handles is required")
|
||||
_ensure_bulk_size("item_handles", len(item_handles))
|
||||
reply = await self.invoke(
|
||||
pb.MxCommand(
|
||||
kind=pb.MX_COMMAND_KIND_UNSUBSCRIBE_BULK,
|
||||
@@ -368,4 +377,9 @@ class Session:
|
||||
)
|
||||
|
||||
|
||||
def _ensure_bulk_size(name: str, count: int) -> None:
|
||||
if count > MAX_BULK_ITEMS:
|
||||
raise ValueError(f"{name} bulk commands are limited to {MAX_BULK_ITEMS} item(s)")
|
||||
|
||||
|
||||
from .client import GatewayClient # noqa: E402
|
||||
|
||||
@@ -20,6 +20,8 @@ from mxgateway.generated import mxaccess_gateway_pb2 as pb
|
||||
from mxgateway.options import ClientOptions
|
||||
from mxgateway.values import MxValueInput
|
||||
|
||||
MAX_AGGREGATE_EVENTS = 10_000
|
||||
|
||||
|
||||
@click.group()
|
||||
def main() -> None:
|
||||
@@ -55,6 +57,8 @@ def gateway_options(command: Callable[..., Any]) -> Callable[..., Any]:
|
||||
default=None,
|
||||
help="TLS server name override for test environments.",
|
||||
)(command)
|
||||
command = click.option("--call-timeout", default=30.0, type=float, show_default=True)(command)
|
||||
command = click.option("--stream-timeout", default=None, type=float)(command)
|
||||
return command
|
||||
|
||||
|
||||
@@ -352,6 +356,8 @@ async def _connect(kwargs: dict[str, Any]) -> GatewayClient:
|
||||
plaintext=_use_plaintext(kwargs),
|
||||
ca_file=kwargs.get("ca_file"),
|
||||
server_name_override=kwargs.get("server_name_override"),
|
||||
call_timeout=kwargs.get("call_timeout"),
|
||||
stream_timeout=kwargs.get("stream_timeout"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -416,6 +422,12 @@ async def _collect_events(
|
||||
max_events: int,
|
||||
timeout: float,
|
||||
) -> list[pb.MxEvent]:
|
||||
if max_events > MAX_AGGREGATE_EVENTS:
|
||||
raise click.BadParameter(
|
||||
f"must be less than or equal to {MAX_AGGREGATE_EVENTS}",
|
||||
param_hint="--max-events",
|
||||
)
|
||||
|
||||
collected: list[pb.MxEvent] = []
|
||||
iterator = events.__aiter__()
|
||||
try:
|
||||
@@ -423,6 +435,8 @@ async def _collect_events(
|
||||
collected.append(await asyncio.wait_for(iterator.__anext__(), timeout=timeout))
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
close = getattr(iterator, "aclose", None)
|
||||
if close is not None:
|
||||
|
||||
@@ -16,6 +16,8 @@ use mxgateway_client::{
|
||||
use serde_json::json;
|
||||
use serde_json::Value;
|
||||
|
||||
const MAX_AGGREGATE_EVENTS: usize = 10_000;
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
#[command(name = "mxgw")]
|
||||
#[command(about = "MXAccess Gateway Rust test CLI")]
|
||||
@@ -29,6 +31,8 @@ enum Command {
|
||||
Version {
|
||||
#[arg(long)]
|
||||
json: bool,
|
||||
#[arg(long)]
|
||||
jsonl: bool,
|
||||
},
|
||||
Ping {
|
||||
#[command(flatten)]
|
||||
@@ -325,7 +329,15 @@ async fn run(cli: Cli) -> Result<(), Error> {
|
||||
after_worker_sequence,
|
||||
max_events,
|
||||
json,
|
||||
jsonl,
|
||||
} => {
|
||||
if max_events > MAX_AGGREGATE_EVENTS {
|
||||
return Err(Error::InvalidArgument {
|
||||
name: "max-events".to_owned(),
|
||||
detail: format!("must be less than or equal to {MAX_AGGREGATE_EVENTS}"),
|
||||
});
|
||||
}
|
||||
|
||||
let client = connect(connection).await?;
|
||||
let mut stream = client
|
||||
.stream_events(StreamEventsRequest {
|
||||
@@ -334,19 +346,30 @@ async fn run(cli: Cli) -> Result<(), Error> {
|
||||
})
|
||||
.await?;
|
||||
let mut events = Vec::new();
|
||||
while events.len() < max_events {
|
||||
let mut event_count = 0usize;
|
||||
while event_count < max_events {
|
||||
let Some(event) = stream.next().await else {
|
||||
break;
|
||||
};
|
||||
events.push(event?);
|
||||
}
|
||||
if json {
|
||||
println!("{}", json!({ "eventCount": events.len() }));
|
||||
} else {
|
||||
for event in events {
|
||||
let event = event?;
|
||||
event_count += 1;
|
||||
if jsonl {
|
||||
println!(
|
||||
"{}",
|
||||
json!({
|
||||
"workerSequence": event.worker_sequence,
|
||||
"family": event.family,
|
||||
})
|
||||
);
|
||||
} else if json {
|
||||
events.push(event);
|
||||
} else {
|
||||
println!("{} {}", event.worker_sequence, event.family);
|
||||
}
|
||||
}
|
||||
if json {
|
||||
println!("{}", json!({ "eventCount": event_count }));
|
||||
}
|
||||
}
|
||||
Command::Write {
|
||||
connection,
|
||||
|
||||
@@ -5,7 +5,7 @@ use tonic::transport::{Certificate, Channel, ClientTlsConfig};
|
||||
use tonic::Request;
|
||||
|
||||
use crate::auth::AuthInterceptor;
|
||||
use crate::error::{ensure_command_success, Error};
|
||||
use crate::error::{ensure_command_success, ensure_protocol_success, Error};
|
||||
use crate::generated::mxaccess_gateway::v1::mx_access_gateway_client::MxAccessGatewayClient;
|
||||
use crate::generated::mxaccess_gateway::v1::{
|
||||
CloseSessionReply, CloseSessionRequest, MxCommandReply, MxCommandRequest, MxEvent,
|
||||
@@ -23,6 +23,7 @@ pub type EventStream =
|
||||
pub struct GatewayClient {
|
||||
inner: RawGatewayClient,
|
||||
call_timeout: std::time::Duration,
|
||||
stream_timeout: Option<std::time::Duration>,
|
||||
}
|
||||
|
||||
impl GatewayClient {
|
||||
@@ -57,6 +58,7 @@ impl GatewayClient {
|
||||
Ok(Self {
|
||||
inner: MxAccessGatewayClient::with_interceptor(channel, interceptor),
|
||||
call_timeout: options.call_timeout(),
|
||||
stream_timeout: options.stream_timeout(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -83,6 +85,7 @@ impl GatewayClient {
|
||||
|
||||
pub async fn open_session(&self, request: OpenSessionRequest) -> Result<Session, Error> {
|
||||
let reply = self.open_session_raw(request).await?;
|
||||
ensure_protocol_success("open session", reply.protocol_status.as_ref())?;
|
||||
Ok(Session::new(reply.session_id, self.clone()))
|
||||
}
|
||||
|
||||
@@ -107,7 +110,7 @@ impl GatewayClient {
|
||||
|
||||
pub async fn stream_events(&self, request: StreamEventsRequest) -> Result<EventStream, Error> {
|
||||
let mut client = self.inner.clone();
|
||||
let response = client.stream_events(self.unary_request(request)).await?;
|
||||
let response = client.stream_events(self.stream_request(request)).await?;
|
||||
let stream = futures_util::StreamExt::map(response.into_inner(), |result| {
|
||||
result.map_err(Error::from)
|
||||
});
|
||||
@@ -120,4 +123,13 @@ impl GatewayClient {
|
||||
request.set_timeout(self.call_timeout);
|
||||
request
|
||||
}
|
||||
|
||||
fn stream_request<T>(&self, message: T) -> Request<T> {
|
||||
let mut request = Request::new(message);
|
||||
if let Some(timeout) = self.stream_timeout {
|
||||
request.set_timeout(timeout);
|
||||
}
|
||||
|
||||
request
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use thiserror::Error as ThisError;
|
||||
use tonic::Code;
|
||||
|
||||
use crate::generated::mxaccess_gateway::v1::{MxCommandReply, ProtocolStatusCode};
|
||||
use crate::generated::mxaccess_gateway::v1::{MxCommandReply, ProtocolStatus, ProtocolStatusCode};
|
||||
|
||||
#[derive(Debug, ThisError)]
|
||||
pub enum Error {
|
||||
@@ -47,6 +47,13 @@ pub enum Error {
|
||||
|
||||
#[error("gateway command failed: {0}")]
|
||||
Command(#[from] Box<CommandError>),
|
||||
|
||||
#[error("gateway {operation} failed: {code:?}: {message}")]
|
||||
ProtocolStatus {
|
||||
operation: &'static str,
|
||||
code: ProtocolStatusCode,
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -125,6 +132,27 @@ pub fn ensure_command_success(reply: MxCommandReply) -> Result<MxCommandReply, E
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ensure_protocol_success(
|
||||
operation: &'static str,
|
||||
status: Option<&ProtocolStatus>,
|
||||
) -> Result<(), Error> {
|
||||
let code = status
|
||||
.and_then(|status| ProtocolStatusCode::try_from(status.code).ok())
|
||||
.unwrap_or(ProtocolStatusCode::Unspecified);
|
||||
|
||||
if code == ProtocolStatusCode::Ok {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::ProtocolStatus {
|
||||
operation,
|
||||
code,
|
||||
message: status
|
||||
.map(|status| status.message.clone())
|
||||
.unwrap_or_default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn redact_credentials(message: &str) -> String {
|
||||
message
|
||||
.split_whitespace()
|
||||
|
||||
@@ -13,6 +13,7 @@ pub struct ClientOptions {
|
||||
server_name_override: Option<String>,
|
||||
connect_timeout: Duration,
|
||||
call_timeout: Duration,
|
||||
stream_timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
impl ClientOptions {
|
||||
@@ -25,6 +26,7 @@ impl ClientOptions {
|
||||
server_name_override: None,
|
||||
connect_timeout: Duration::from_secs(10),
|
||||
call_timeout: Duration::from_secs(30),
|
||||
stream_timeout: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,6 +60,11 @@ impl ClientOptions {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_stream_timeout(mut self, stream_timeout: Duration) -> Self {
|
||||
self.stream_timeout = Some(stream_timeout);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn endpoint(&self) -> &str {
|
||||
&self.endpoint
|
||||
}
|
||||
@@ -85,6 +92,10 @@ impl ClientOptions {
|
||||
pub fn call_timeout(&self) -> Duration {
|
||||
self.call_timeout
|
||||
}
|
||||
|
||||
pub fn stream_timeout(&self) -> Option<Duration> {
|
||||
self.stream_timeout
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClientOptions {
|
||||
@@ -104,6 +115,7 @@ impl fmt::Debug for ClientOptions {
|
||||
.field("server_name_override", &self.server_name_override)
|
||||
.field("connect_timeout", &self.connect_timeout)
|
||||
.field("call_timeout", &self.call_timeout)
|
||||
.field("stream_timeout", &self.stream_timeout)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::client::{EventStream, GatewayClient};
|
||||
use crate::error::Error;
|
||||
use crate::error::{ensure_protocol_success, Error};
|
||||
use crate::generated::mxaccess_gateway::v1::mx_command::Payload;
|
||||
use crate::generated::mxaccess_gateway::v1::mx_command_reply;
|
||||
use crate::generated::mxaccess_gateway::v1::{
|
||||
@@ -11,6 +11,8 @@ use crate::generated::mxaccess_gateway::v1::{
|
||||
};
|
||||
use crate::value::MxValue;
|
||||
|
||||
const MAX_BULK_ITEMS: usize = 1_000;
|
||||
|
||||
/// Session identifier returned by the gateway.
|
||||
#[derive(Clone)]
|
||||
pub struct Session {
|
||||
@@ -40,12 +42,14 @@ impl Session {
|
||||
}
|
||||
|
||||
pub async fn close(&self) -> Result<(), Error> {
|
||||
self.client
|
||||
let reply = self
|
||||
.client
|
||||
.close_session_raw(CloseSessionRequest {
|
||||
session_id: self.id.clone(),
|
||||
client_correlation_id: "rust-client-close-session".to_owned(),
|
||||
})
|
||||
.await?;
|
||||
ensure_protocol_success("close session", reply.protocol_status.as_ref())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -137,6 +141,7 @@ impl Session {
|
||||
server_handle: i32,
|
||||
tag_addresses: Vec<String>,
|
||||
) -> Result<Vec<SubscribeResult>, Error> {
|
||||
ensure_bulk_size("tag_addresses", tag_addresses.len())?;
|
||||
let reply = self
|
||||
.invoke(
|
||||
MxCommandKind::AddItemBulk,
|
||||
@@ -155,6 +160,7 @@ impl Session {
|
||||
server_handle: i32,
|
||||
item_handles: Vec<i32>,
|
||||
) -> Result<Vec<SubscribeResult>, Error> {
|
||||
ensure_bulk_size("item_handles", item_handles.len())?;
|
||||
let reply = self
|
||||
.invoke(
|
||||
MxCommandKind::AdviseItemBulk,
|
||||
@@ -173,6 +179,7 @@ impl Session {
|
||||
server_handle: i32,
|
||||
item_handles: Vec<i32>,
|
||||
) -> Result<Vec<SubscribeResult>, Error> {
|
||||
ensure_bulk_size("item_handles", item_handles.len())?;
|
||||
let reply = self
|
||||
.invoke(
|
||||
MxCommandKind::RemoveItemBulk,
|
||||
@@ -191,6 +198,7 @@ impl Session {
|
||||
server_handle: i32,
|
||||
item_handles: Vec<i32>,
|
||||
) -> Result<Vec<SubscribeResult>, Error> {
|
||||
ensure_bulk_size("item_handles", item_handles.len())?;
|
||||
let reply = self
|
||||
.invoke(
|
||||
MxCommandKind::UnAdviseItemBulk,
|
||||
@@ -209,6 +217,7 @@ impl Session {
|
||||
server_handle: i32,
|
||||
tag_addresses: Vec<String>,
|
||||
) -> Result<Vec<SubscribeResult>, Error> {
|
||||
ensure_bulk_size("tag_addresses", tag_addresses.len())?;
|
||||
let reply = self
|
||||
.invoke(
|
||||
MxCommandKind::SubscribeBulk,
|
||||
@@ -227,6 +236,7 @@ impl Session {
|
||||
server_handle: i32,
|
||||
item_handles: Vec<i32>,
|
||||
) -> Result<Vec<SubscribeResult>, Error> {
|
||||
ensure_bulk_size("item_handles", item_handles.len())?;
|
||||
let reply = self
|
||||
.invoke(
|
||||
MxCommandKind::UnsubscribeBulk,
|
||||
@@ -327,6 +337,17 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
fn ensure_bulk_size(name: &'static str, len: usize) -> Result<(), Error> {
|
||||
if len > MAX_BULK_ITEMS {
|
||||
Err(Error::InvalidArgument {
|
||||
name: name.to_owned(),
|
||||
detail: format!("bulk commands are limited to {MAX_BULK_ITEMS} item(s)"),
|
||||
})
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn register_server_handle(reply: &MxCommandReply) -> i32 {
|
||||
match reply.payload.as_ref() {
|
||||
Some(mx_command_reply::Payload::Register(register)) => register.server_handle,
|
||||
|
||||
Reference in New Issue
Block a user