package mxgateway import ( "context" "crypto/tls" "errors" "time" pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/types/known/durationpb" ) const ( defaultDialTimeout = 10 * time.Second defaultCallTimeout = 30 * time.Second ) // Client owns a gateway gRPC connection and exposes session-oriented helpers. type Client struct { conn *grpc.ClientConn raw pb.MxAccessGatewayClient opts Options } // Dial opens a gRPC connection to the gateway and configures auth metadata, // transport security, and blocking dial cancellation from ctx. func Dial(ctx context.Context, opts Options) (*Client, error) { if opts.Endpoint == "" { return nil, errors.New("mxgateway: endpoint is required") } dialCtx := ctx cancel := func() {} if opts.DialTimeout > 0 { dialCtx, cancel = context.WithTimeout(ctx, opts.DialTimeout) } else if _, ok := ctx.Deadline(); !ok { dialCtx, cancel = context.WithTimeout(ctx, defaultDialTimeout) } defer cancel() transportCredentials, err := resolveTransportCredentials(opts) if err != nil { return nil, err } dialOptions := []grpc.DialOption{ grpc.WithTransportCredentials(transportCredentials), grpc.WithUnaryInterceptor(unaryAuthInterceptor(opts.APIKey)), grpc.WithStreamInterceptor(streamAuthInterceptor(opts.APIKey)), grpc.WithBlock(), } dialOptions = append(dialOptions, opts.DialOptions...) conn, err := grpc.DialContext(dialCtx, opts.Endpoint, dialOptions...) if err != nil { return nil, &GatewayError{Op: "dial", Err: err} } return NewClient(conn, opts), nil } // NewClient wraps an existing gRPC connection. The caller owns closing conn // unless it calls Close on the returned Client. func NewClient(conn *grpc.ClientConn, opts Options) *Client { return &Client{ conn: conn, raw: pb.NewMxAccessGatewayClient(conn), opts: opts, } } // RawClient returns the generated gRPC client for command-specific parity tests. 
func (c *Client) RawClient() RawGatewayClient {
	return c.raw
}

// OpenSession builds an OpenSessionRequest from opts, sends it through
// OpenSessionRaw and, on success, wraps the reply in a Session bound to c.
// Any error from OpenSessionRaw is returned as-is with a nil Session.
func (c *Client) OpenSession(ctx context.Context, opts OpenSessionOptions) (*Session, error) {
	req := opts.Request()
	reply, err := c.OpenSessionRaw(ctx, req)
	if err != nil {
		return nil, err
	}
	return newSession(c, reply), nil
}

// OpenSessionRaw sends req as an OpenSession RPC under the client's per-call
// timeout. A nil req is rejected before any RPC is made. An RPC failure is
// wrapped in *GatewayError and returned without a reply; a failing protocol
// status is returned as an error together with the reply.
func (c *Client) OpenSessionRaw(ctx context.Context, req *OpenSessionRequest) (*OpenSessionReply, error) {
	const op = "open session"

	if req == nil {
		return nil, errors.New("mxgateway: open session request is required")
	}

	callCtx, cancel := c.callContext(ctx)
	defer cancel()

	reply, rpcErr := c.raw.OpenSession(callCtx, req)
	if rpcErr != nil {
		return nil, &GatewayError{Op: op, Err: rpcErr}
	}

	// Checked explicitly (not returned directly) so the success path always
	// yields a literal nil error.
	if statusErr := EnsureProtocolSuccess(op, reply.GetProtocolStatus(), nil); statusErr != nil {
		return reply, statusErr
	}
	return reply, nil
}

// Invoke sends req as a raw MXAccess command under the client's per-call
// timeout. It checks the protocol status first and then the MXAccess status.
// When either check fails, the raw reply is returned alongside the typed error.
// RPC failures are wrapped in *GatewayError with no reply.
func (c *Client) Invoke(ctx context.Context, req *MxCommandRequest) (*MxCommandReply, error) {
	const op = "invoke"

	if req == nil {
		return nil, errors.New("mxgateway: command request is required")
	}

	callCtx, cancel := c.callContext(ctx)
	defer cancel()

	reply, rpcErr := c.raw.Invoke(callCtx, req)
	if rpcErr != nil {
		return nil, &GatewayError{Op: op, Err: rpcErr}
	}

	if statusErr := EnsureProtocolSuccess(op, reply.GetProtocolStatus(), reply); statusErr != nil {
		return reply, statusErr
	}
	if mxErr := EnsureMxAccessSuccess(op, reply); mxErr != nil {
		return reply, mxErr
	}
	return reply, nil
}

// CloseSessionRaw sends req as a CloseSession RPC under the client's per-call
// timeout and validates the protocol status of the reply.
func (c *Client) CloseSessionRaw(ctx context.Context, req *CloseSessionRequest) (*CloseSessionReply, error) { if req == nil { return nil, errors.New("mxgateway: close session request is required") } callCtx, cancel := c.callContext(ctx) defer cancel() reply, err := c.raw.CloseSession(callCtx, req) if err != nil { return nil, &GatewayError{Op: "close session", Err: err} } if err := EnsureProtocolSuccess("close session", reply.GetProtocolStatus(), nil); err != nil { return reply, err } return reply, nil } // StreamEventsRaw starts the generated event stream for callers that need direct // control over Recv. func (c *Client) StreamEventsRaw(ctx context.Context, req *StreamEventsRequest) (RawEventStream, error) { if req == nil { return nil, errors.New("mxgateway: stream events request is required") } stream, err := c.raw.StreamEvents(ctx, req) if err != nil { return nil, &GatewayError{Op: "stream events", Err: err} } return stream, nil } // Close closes the underlying gRPC connection. func (c *Client) Close() error { if c == nil || c.conn == nil { return nil } return c.conn.Close() } func (c *Client) callContext(ctx context.Context) (context.Context, context.CancelFunc) { timeout := c.opts.CallTimeout if timeout == 0 { timeout = defaultCallTimeout } if timeout < 0 { return ctx, func() {} } if _, ok := ctx.Deadline(); ok { return ctx, func() {} } return context.WithTimeout(ctx, timeout) } func resolveTransportCredentials(opts Options) (credentials.TransportCredentials, error) { if opts.TransportCredentials != nil { return opts.TransportCredentials, nil } if opts.Plaintext { return insecure.NewCredentials(), nil } if opts.CACertFile != "" { return credentials.NewClientTLSFromFile(opts.CACertFile, opts.ServerNameOverride) } if opts.TLSConfig != nil { cfg := opts.TLSConfig.Clone() if opts.ServerNameOverride != "" { cfg.ServerName = opts.ServerNameOverride } return credentials.NewTLS(cfg), nil } return credentials.NewTLS(&tls.Config{ MinVersion: tls.VersionTLS12, 
ServerName: opts.ServerNameOverride, }), nil } // OpenSessionOptions describes fields used to create an OpenSessionRequest. type OpenSessionOptions struct { RequestedBackend string ClientSessionName string ClientCorrelationID string CommandTimeout time.Duration } // Request returns the raw protobuf OpenSessionRequest for these options. func (o OpenSessionOptions) Request() *OpenSessionRequest { req := &OpenSessionRequest{ RequestedBackend: o.RequestedBackend, ClientSessionName: o.ClientSessionName, ClientCorrelationId: o.ClientCorrelationID, } if o.CommandTimeout > 0 { req.CommandTimeout = durationpb.New(o.CommandTimeout) } return req }