237 lines
6.4 KiB
Go
237 lines
6.4 KiB
Go
package mxgateway
|
|
|
|
import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"time"

	pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/protobuf/types/known/durationpb"
)
|
|
|
|
// defaultDialTimeout bounds Dial when Options.DialTimeout is not positive and
// the caller's context carries no deadline of its own.
const defaultDialTimeout = time.Second * 10

// defaultCallTimeout bounds each unary RPC when Options.CallTimeout is zero
// and the caller's context carries no deadline of its own.
const defaultCallTimeout = time.Second * 30
|
|
|
|
// Client owns a gateway gRPC connection and exposes session-oriented helpers.
//
// Close closes the underlying connection. The unary helpers (OpenSessionRaw,
// Invoke, CloseSessionRaw) bound each call via callContext, which reads
// opts.CallTimeout.
type Client struct {
	// conn is the underlying gRPC connection; Close closes it.
	conn *grpc.ClientConn

	// raw is the generated stub bound to conn; RawClient exposes it.
	raw pb.MxAccessGatewayClient

	// opts holds the Options the Client was created with.
	opts Options
}
|
|
|
|
// Dial opens a gRPC connection to the gateway and configures auth metadata,
|
|
// transport security, and blocking dial cancellation from ctx.
|
|
func Dial(ctx context.Context, opts Options) (*Client, error) {
|
|
if opts.Endpoint == "" {
|
|
return nil, errors.New("mxgateway: endpoint is required")
|
|
}
|
|
|
|
dialCtx := ctx
|
|
cancel := func() {}
|
|
if opts.DialTimeout > 0 {
|
|
dialCtx, cancel = context.WithTimeout(ctx, opts.DialTimeout)
|
|
} else if _, ok := ctx.Deadline(); !ok {
|
|
dialCtx, cancel = context.WithTimeout(ctx, defaultDialTimeout)
|
|
}
|
|
defer cancel()
|
|
|
|
transportCredentials, err := resolveTransportCredentials(opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
dialOptions := []grpc.DialOption{
|
|
grpc.WithTransportCredentials(transportCredentials),
|
|
grpc.WithUnaryInterceptor(unaryAuthInterceptor(opts.APIKey)),
|
|
grpc.WithStreamInterceptor(streamAuthInterceptor(opts.APIKey)),
|
|
grpc.WithBlock(),
|
|
}
|
|
dialOptions = append(dialOptions, opts.DialOptions...)
|
|
|
|
conn, err := grpc.DialContext(dialCtx, opts.Endpoint, dialOptions...)
|
|
if err != nil {
|
|
return nil, &GatewayError{Op: "dial", Err: err}
|
|
}
|
|
|
|
return NewClient(conn, opts), nil
|
|
}
|
|
|
|
// NewClient wraps an existing gRPC connection. The caller owns closing conn
|
|
// unless it calls Close on the returned Client.
|
|
func NewClient(conn *grpc.ClientConn, opts Options) *Client {
|
|
return &Client{
|
|
conn: conn,
|
|
raw: pb.NewMxAccessGatewayClient(conn),
|
|
opts: opts,
|
|
}
|
|
}
|
|
|
|
// RawClient returns the generated gRPC client for command-specific parity tests.
|
|
func (c *Client) RawClient() RawGatewayClient {
|
|
return c.raw
|
|
}
|
|
|
|
// OpenSession creates a gateway-backed MXAccess session.
|
|
func (c *Client) OpenSession(ctx context.Context, opts OpenSessionOptions) (*Session, error) {
|
|
reply, err := c.OpenSessionRaw(ctx, opts.Request())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return newSession(c, reply), nil
|
|
}
|
|
|
|
// OpenSessionRaw sends a raw OpenSession request and validates protocol status.
|
|
func (c *Client) OpenSessionRaw(ctx context.Context, req *OpenSessionRequest) (*OpenSessionReply, error) {
|
|
if req == nil {
|
|
return nil, errors.New("mxgateway: open session request is required")
|
|
}
|
|
|
|
callCtx, cancel := c.callContext(ctx)
|
|
defer cancel()
|
|
|
|
reply, err := c.raw.OpenSession(callCtx, req)
|
|
if err != nil {
|
|
return nil, &GatewayError{Op: "open session", Err: err}
|
|
}
|
|
if err := EnsureProtocolSuccess("open session", reply.GetProtocolStatus(), nil); err != nil {
|
|
return reply, err
|
|
}
|
|
|
|
return reply, nil
|
|
}
|
|
|
|
// Invoke sends a raw MXAccess command request and validates protocol and
|
|
// MXAccess status fields while preserving the raw reply on typed errors.
|
|
func (c *Client) Invoke(ctx context.Context, req *MxCommandRequest) (*MxCommandReply, error) {
|
|
if req == nil {
|
|
return nil, errors.New("mxgateway: command request is required")
|
|
}
|
|
|
|
callCtx, cancel := c.callContext(ctx)
|
|
defer cancel()
|
|
|
|
reply, err := c.raw.Invoke(callCtx, req)
|
|
if err != nil {
|
|
return nil, &GatewayError{Op: "invoke", Err: err}
|
|
}
|
|
if err := EnsureProtocolSuccess("invoke", reply.GetProtocolStatus(), reply); err != nil {
|
|
return reply, err
|
|
}
|
|
if err := EnsureMxAccessSuccess("invoke", reply); err != nil {
|
|
return reply, err
|
|
}
|
|
|
|
return reply, nil
|
|
}
|
|
|
|
// CloseSessionRaw sends a raw CloseSession request and validates protocol
|
|
// status.
|
|
func (c *Client) CloseSessionRaw(ctx context.Context, req *CloseSessionRequest) (*CloseSessionReply, error) {
|
|
if req == nil {
|
|
return nil, errors.New("mxgateway: close session request is required")
|
|
}
|
|
|
|
callCtx, cancel := c.callContext(ctx)
|
|
defer cancel()
|
|
|
|
reply, err := c.raw.CloseSession(callCtx, req)
|
|
if err != nil {
|
|
return nil, &GatewayError{Op: "close session", Err: err}
|
|
}
|
|
if err := EnsureProtocolSuccess("close session", reply.GetProtocolStatus(), nil); err != nil {
|
|
return reply, err
|
|
}
|
|
|
|
return reply, nil
|
|
}
|
|
|
|
// StreamEventsRaw starts the generated event stream for callers that need direct
|
|
// control over Recv.
|
|
func (c *Client) StreamEventsRaw(ctx context.Context, req *StreamEventsRequest) (RawEventStream, error) {
|
|
if req == nil {
|
|
return nil, errors.New("mxgateway: stream events request is required")
|
|
}
|
|
|
|
stream, err := c.raw.StreamEvents(ctx, req)
|
|
if err != nil {
|
|
return nil, &GatewayError{Op: "stream events", Err: err}
|
|
}
|
|
|
|
return stream, nil
|
|
}
|
|
|
|
// Close closes the underlying gRPC connection.
|
|
func (c *Client) Close() error {
|
|
if c == nil || c.conn == nil {
|
|
return nil
|
|
}
|
|
|
|
return c.conn.Close()
|
|
}
|
|
|
|
func (c *Client) callContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
|
timeout := c.opts.CallTimeout
|
|
if timeout == 0 {
|
|
timeout = defaultCallTimeout
|
|
}
|
|
if timeout < 0 {
|
|
return ctx, func() {}
|
|
}
|
|
if _, ok := ctx.Deadline(); ok {
|
|
return ctx, func() {}
|
|
}
|
|
return context.WithTimeout(ctx, timeout)
|
|
}
|
|
|
|
func resolveTransportCredentials(opts Options) (credentials.TransportCredentials, error) {
|
|
if opts.TransportCredentials != nil {
|
|
return opts.TransportCredentials, nil
|
|
}
|
|
if opts.Plaintext {
|
|
return insecure.NewCredentials(), nil
|
|
}
|
|
if opts.CACertFile != "" {
|
|
return credentials.NewClientTLSFromFile(opts.CACertFile, opts.ServerNameOverride)
|
|
}
|
|
if opts.TLSConfig != nil {
|
|
cfg := opts.TLSConfig.Clone()
|
|
if opts.ServerNameOverride != "" {
|
|
cfg.ServerName = opts.ServerNameOverride
|
|
}
|
|
return credentials.NewTLS(cfg), nil
|
|
}
|
|
|
|
return credentials.NewTLS(&tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
ServerName: opts.ServerNameOverride,
|
|
}), nil
|
|
}
|
|
|
|
// OpenSessionOptions describes fields used to create an OpenSessionRequest.
// Its Request method copies these fields onto a new protobuf request.
type OpenSessionOptions struct {
	// RequestedBackend is copied to OpenSessionRequest.RequestedBackend.
	RequestedBackend string

	// ClientSessionName is copied to OpenSessionRequest.ClientSessionName.
	ClientSessionName string

	// ClientCorrelationID is copied to OpenSessionRequest.ClientCorrelationId.
	ClientCorrelationID string

	// CommandTimeout is sent as OpenSessionRequest.CommandTimeout only when
	// positive; zero or negative values leave that field unset.
	CommandTimeout time.Duration
}
|
|
|
|
// Request returns the raw protobuf OpenSessionRequest for these options.
|
|
func (o OpenSessionOptions) Request() *OpenSessionRequest {
|
|
req := &OpenSessionRequest{
|
|
RequestedBackend: o.RequestedBackend,
|
|
ClientSessionName: o.ClientSessionName,
|
|
ClientCorrelationId: o.ClientCorrelationID,
|
|
}
|
|
if o.CommandTimeout > 0 {
|
|
req.CommandTimeout = durationpb.New(o.CommandTimeout)
|
|
}
|
|
return req
|
|
}
|