From 87930118383703a78ca7bec1662dbdf994fe6ab5 Mon Sep 17 00:00:00 2001 From: Joseph Doherty Date: Sun, 26 Apr 2026 20:09:58 -0400 Subject: [PATCH] Issue #42: implement Go client session values errors and CLI --- clients/go/README.md | 40 +- clients/go/cmd/mxgw-go/main.go | 489 +++++++++++++++++++- clients/go/cmd/mxgw-go/main_test.go | 58 +++ clients/go/mxgateway/auth.go | 30 ++ clients/go/mxgateway/client.go | 236 ++++++++++ clients/go/mxgateway/client_session_test.go | 261 +++++++++++ clients/go/mxgateway/conversion_test.go | 73 +++ clients/go/mxgateway/errors.go | 118 +++++ clients/go/mxgateway/options.go | 26 +- clients/go/mxgateway/session.go | 260 +++++++++++ clients/go/mxgateway/status.go | 6 + clients/go/mxgateway/types.go | 70 +++ clients/go/mxgateway/values.go | 148 ++++++ 13 files changed, 1791 insertions(+), 24 deletions(-) create mode 100644 clients/go/cmd/mxgw-go/main_test.go create mode 100644 clients/go/mxgateway/auth.go create mode 100644 clients/go/mxgateway/client.go create mode 100644 clients/go/mxgateway/client_session_test.go create mode 100644 clients/go/mxgateway/conversion_test.go create mode 100644 clients/go/mxgateway/errors.go create mode 100644 clients/go/mxgateway/session.go create mode 100644 clients/go/mxgateway/status.go create mode 100644 clients/go/mxgateway/types.go create mode 100644 clients/go/mxgateway/values.go diff --git a/clients/go/README.md b/clients/go/README.md index d0d0fa4..00c1024 100644 --- a/clients/go/README.md +++ b/clients/go/README.md @@ -37,19 +37,47 @@ Run the Go module checks from `clients/go`: ```powershell go test ./... go build ./... +go vet ./... ``` -The scaffold tests parse the shared golden JSON fixtures with the generated Go -types. Later client implementation tests add fake gRPC services, auth metadata, -streaming, value conversion, and CLI behavior. +The tests parse the shared JSON fixtures, exercise value and status conversion, +use `bufconn` for fake gateway auth and streaming behavior, and cover CLI JSON +redaction. + +## Client API + +Use `mxgateway.Dial` with `mxgateway.Options` to configure plaintext or TLS +transport, API-key metadata, dial timeout, and per-call timeout: + +```go +client, err := mxgateway.Dial(ctx, mxgateway.Options{ + Endpoint: "localhost:5000", + APIKey: os.Getenv("MXGATEWAY_API_KEY"), + Plaintext: true, +}) +``` + +`Client.OpenSession` returns a `Session` with helpers for `Register`, +`AddItem`, `AddItem2`, `Advise`, `Write`, `Events`, and `Close`. Raw protobuf +messages remain available through the `mxgateway` package aliases and the +`Raw` helper methods. Typed errors support `errors.As` for `GatewayError`, +`CommandError`, and `MxAccessError`; command errors preserve the raw reply. ## CLI -The scaffold CLI exposes version information: +The `mxgw-go` CLI emits JSON with redacted API keys for commands that connect to +the gateway: ```powershell go run ./cmd/mxgw-go version -json +go run ./cmd/mxgw-go open-session -endpoint localhost:5000 -plaintext -json +go run ./cmd/mxgw-go register -session-id -client-name mxgw-go -plaintext -json +go run ./cmd/mxgw-go add-item -session-id -server-handle 1 -item Area001.Tag.Value -plaintext -json +go run ./cmd/mxgw-go advise -session-id -server-handle 1 -item-handle 1 -plaintext -json +go run ./cmd/mxgw-go write -session-id -server-handle 1 -item-handle 1 -type int32 -value 123 -plaintext -json +go run ./cmd/mxgw-go stream-events -session-id -plaintext -json +go run ./cmd/mxgw-go smoke -item Area001.Tag.Value -plaintext -json ``` -Additional commands are implemented with the client/session wrapper work. - +Use `-api-key-env MXGATEWAY_API_KEY` or `-api-key ` when authentication is +enabled. CLI output redacts the key value and never writes the raw secret. diff --git a/clients/go/cmd/mxgw-go/main.go b/clients/go/cmd/mxgw-go/main.go index bbc21c9..c3a0504 100644 --- a/clients/go/cmd/mxgw-go/main.go +++ b/clients/go/cmd/mxgw-go/main.go @@ -1,12 +1,19 @@ package main import ( + "context" "encoding/json" + "errors" "flag" "fmt" + "io" "os" + "strconv" + "time" "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/mxgateway" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/reflect/protoreflect" ) type versionOutput struct { @@ -15,29 +22,76 @@ type versionOutput struct { WorkerProtocolVersion uint32 `json:"workerProtocolVersion"` } +type commonOptions struct { + Endpoint string `json:"endpoint"` + APIKey string `json:"apiKey"` + APIKeyEnv string `json:"apiKeyEnv,omitempty"` + Plaintext bool `json:"plaintext"` + CACertFile string `json:"caCertFile,omitempty"` + ServerName string `json:"serverNameOverride,omitempty"` + CallTimeout string `json:"callTimeout,omitempty"` + + apiKeyValue string + timeout time.Duration +} + +type openSessionOutput struct { + Command string `json:"command"` + Options commonOptions `json:"options"` + Reply json.RawMessage `json:"reply"` +} + +type commandReplyOutput struct { + Command string `json:"command"` + Options commonOptions `json:"options"` + Reply json.RawMessage `json:"reply"` +} + func main() { - if err := run(os.Args[1:]); err != nil { + if err := runWithIO(context.Background(), os.Args[1:], os.Stdout, os.Stderr); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(2) } } func run(args []string) error { + return runWithIO(context.Background(), args, os.Stdout, os.Stderr) +} + +func runWithIO(ctx context.Context, args []string, stdout, stderr io.Writer) error { if len(args) == 0 { - return fmt.Errorf("usage: mxgw-go version [-json]") + writeUsage(stderr) + return errors.New("missing command") } switch args[0] { case "version": - return runVersion(args[1:]) + return runVersion(args[1:], stdout, stderr) + case "open-session": + return runOpenSession(ctx, args[1:], stdout, stderr) + case "close-session": + return runCloseSession(ctx, args[1:], stdout, stderr) + case "register": + return runRegister(ctx, args[1:], stdout, stderr) + case "add-item": + return runAddItem(ctx, args[1:], stdout, stderr) + case "advise": + return runAdvise(ctx, args[1:], stdout, stderr) + case "write": + return runWrite(ctx, args[1:], stdout, stderr) + case "stream-events": + return runStreamEvents(ctx, args[1:], stdout, stderr) + case "smoke": + return runSmoke(ctx, args[1:], stdout, stderr) default: + writeUsage(stderr) return fmt.Errorf("unknown command %q", args[0]) } } -func runVersion(args []string) error { +func runVersion(args []string, stdout, stderr io.Writer) error { flags := flag.NewFlagSet("version", flag.ContinueOnError) - flags.SetOutput(os.Stderr) + flags.SetOutput(stderr) jsonOutput := flags.Bool("json", false, "write JSON output") if err := flags.Parse(args); err != nil { @@ -51,13 +105,426 @@ func runVersion(args []string) error { } if *jsonOutput { - encoder := json.NewEncoder(os.Stdout) - encoder.SetIndent("", " ") - return encoder.Encode(output) + return writeJSON(stdout, output) } - fmt.Fprintf(os.Stdout, "mxgw-go %s\n", output.ClientVersion) - fmt.Fprintf(os.Stdout, "gateway protocol %d\n", output.GatewayProtocolVersion) - fmt.Fprintf(os.Stdout, "worker protocol %d\n", output.WorkerProtocolVersion) + fmt.Fprintf(stdout, "mxgw-go %s\n", output.ClientVersion) + fmt.Fprintf(stdout, "gateway protocol %d\n", output.GatewayProtocolVersion) + fmt.Fprintf(stdout, "worker protocol %d\n", output.WorkerProtocolVersion) return nil } + +func runOpenSession(ctx context.Context, args []string, stdout, stderr io.Writer) error { + flags := flag.NewFlagSet("open-session", flag.ContinueOnError) + flags.SetOutput(stderr) + common := bindCommonFlags(flags) + jsonOutput := flags.Bool("json", false, "write JSON output") + clientName := flags.String("client-session-name", "", "client session name") + backend := flags.String("backend", "", "requested backend") + + if err := flags.Parse(args); err != nil { + return err + } + + client, options, err := dialForCommand(ctx, common) + if err != nil { + return err + } + defer client.Close() + + reply, err := client.OpenSessionRaw(ctx, (&mxgateway.OpenSessionOptions{ + RequestedBackend: *backend, + ClientSessionName: *clientName, + }).Request()) + if err != nil { + return err + } + + if *jsonOutput { + return writeJSON(stdout, openSessionOutput{ + Command: "open-session", + Options: options, + Reply: mustMarshalProto(reply), + }) + } + + fmt.Fprintln(stdout, reply.GetSessionId()) + return nil +} + +func runCloseSession(ctx context.Context, args []string, stdout, stderr io.Writer) error { + flags := flag.NewFlagSet("close-session", flag.ContinueOnError) + flags.SetOutput(stderr) + common := bindCommonFlags(flags) + jsonOutput := flags.Bool("json", false, "write JSON output") + sessionID := flags.String("session-id", "", "gateway session id") + + if err := flags.Parse(args); err != nil { + return err + } + if *sessionID == "" { + return errors.New("session-id is required") + } + + client, options, err := dialForCommand(ctx, common) + if err != nil { + return err + } + defer client.Close() + + reply, err := client.CloseSessionRaw(ctx, &mxgateway.CloseSessionRequest{SessionId: *sessionID}) + if err != nil { + return err + } + if *jsonOutput { + return writeJSON(stdout, commandReplyOutput{ + Command: "close-session", + Options: options, + Reply: mustMarshalProto(reply), + }) + } + + fmt.Fprintln(stdout, reply.GetFinalState()) + return nil +} + +func runRegister(ctx context.Context, args []string, stdout, stderr io.Writer) error { + flags := flag.NewFlagSet("register", flag.ContinueOnError) + flags.SetOutput(stderr) + common := bindCommonFlags(flags) + jsonOutput := flags.Bool("json", false, "write JSON output") + sessionID := flags.String("session-id", "", "gateway session id") + clientName := flags.String("client-name", "", "MXAccess client name") + + if err := flags.Parse(args); err != nil { + return err + } + if *sessionID == "" || *clientName == "" { + return errors.New("session-id and client-name are required") + } + + client, options, err := dialForCommand(ctx, common) + if err != nil { + return err + } + defer client.Close() + + session := mxgateway.NewSessionForID(client, *sessionID) + reply, err := session.RegisterRaw(ctx, *clientName) + return writeCommandOutput(stdout, *jsonOutput, "register", options, reply, err) +} + +func runAddItem(ctx context.Context, args []string, stdout, stderr io.Writer) error { + flags := flag.NewFlagSet("add-item", flag.ContinueOnError) + flags.SetOutput(stderr) + common := bindCommonFlags(flags) + jsonOutput := flags.Bool("json", false, "write JSON output") + sessionID := flags.String("session-id", "", "gateway session id") + serverHandle := flags.Int("server-handle", 0, "MXAccess server handle") + item := flags.String("item", "", "item definition") + + if err := flags.Parse(args); err != nil { + return err + } + if *sessionID == "" || *item == "" { + return errors.New("session-id and item are required") + } + + client, options, err := dialForCommand(ctx, common) + if err != nil { + return err + } + defer client.Close() + + session := mxgateway.NewSessionForID(client, *sessionID) + reply, err := session.AddItemRaw(ctx, int32(*serverHandle), *item) + return writeCommandOutput(stdout, *jsonOutput, "add-item", options, reply, err) +} + +func runAdvise(ctx context.Context, args []string, stdout, stderr io.Writer) error { + flags := flag.NewFlagSet("advise", flag.ContinueOnError) + flags.SetOutput(stderr) + common := bindCommonFlags(flags) + jsonOutput := flags.Bool("json", false, "write JSON output") + sessionID := flags.String("session-id", "", "gateway session id") + serverHandle := flags.Int("server-handle", 0, "MXAccess server handle") + itemHandle := flags.Int("item-handle", 0, "MXAccess item handle") + + if err := flags.Parse(args); err != nil { + return err + } + if *sessionID == "" { + return errors.New("session-id is required") + } + + client, options, err := dialForCommand(ctx, common) + if err != nil { + return err + } + defer client.Close() + + session := mxgateway.NewSessionForID(client, *sessionID) + reply, err := session.AdviseRaw(ctx, int32(*serverHandle), int32(*itemHandle)) + return writeCommandOutput(stdout, *jsonOutput, "advise", options, reply, err) +} + +func runWrite(ctx context.Context, args []string, stdout, stderr io.Writer) error { + flags := flag.NewFlagSet("write", flag.ContinueOnError) + flags.SetOutput(stderr) + common := bindCommonFlags(flags) + jsonOutput := flags.Bool("json", false, "write JSON output") + sessionID := flags.String("session-id", "", "gateway session id") + serverHandle := flags.Int("server-handle", 0, "MXAccess server handle") + itemHandle := flags.Int("item-handle", 0, "MXAccess item handle") + valueType := flags.String("type", "string", "value type: bool, int32, int64, float, double, string") + valueText := flags.String("value", "", "value text") + userID := flags.Int("user-id", 0, "MXAccess user id") + + if err := flags.Parse(args); err != nil { + return err + } + if *sessionID == "" { + return errors.New("session-id is required") + } + + value, err := parseValue(*valueType, *valueText) + if err != nil { + return err + } + + client, options, err := dialForCommand(ctx, common) + if err != nil { + return err + } + defer client.Close() + + session := mxgateway.NewSessionForID(client, *sessionID) + reply, err := session.WriteRaw(ctx, int32(*serverHandle), int32(*itemHandle), value, int32(*userID)) + return writeCommandOutput(stdout, *jsonOutput, "write", options, reply, err) +} + +func runStreamEvents(ctx context.Context, args []string, stdout, stderr io.Writer) error { + flags := flag.NewFlagSet("stream-events", flag.ContinueOnError) + flags.SetOutput(stderr) + common := bindCommonFlags(flags) + jsonOutput := flags.Bool("json", false, "write JSON output") + sessionID := flags.String("session-id", "", "gateway session id") + after := flags.Uint64("after-worker-sequence", 0, "first worker sequence to read after") + limit := flags.Int("limit", 0, "maximum events to read; 0 means unbounded") + + if err := flags.Parse(args); err != nil { + return err + } + if *sessionID == "" { + return errors.New("session-id is required") + } + + client, _, err := dialForCommand(ctx, common) + if err != nil { + return err + } + defer client.Close() + + session := mxgateway.NewSessionForID(client, *sessionID) + streamCtx, cancelStream := context.WithCancel(ctx) + defer cancelStream() + events, err := session.EventsAfter(streamCtx, *after) + if err != nil { + return err + } + + count := 0 + for result := range events { + if result.Err != nil { + return result.Err + } + if *jsonOutput { + fmt.Fprintln(stdout, string(mustMarshalProto(result.Event))) + } else { + fmt.Fprintf(stdout, "%d %s\n", result.Event.GetWorkerSequence(), result.Event.GetFamily()) + } + count++ + if *limit > 0 && count >= *limit { + cancelStream() + return nil + } + } + return nil +} + +func runSmoke(ctx context.Context, args []string, stdout, stderr io.Writer) error { + flags := flag.NewFlagSet("smoke", flag.ContinueOnError) + flags.SetOutput(stderr) + common := bindCommonFlags(flags) + jsonOutput := flags.Bool("json", false, "write JSON output") + clientName := flags.String("client-name", "mxgw-go-smoke", "MXAccess client name") + item := flags.String("item", "", "item definition") + + if err := flags.Parse(args); err != nil { + return err + } + if *item == "" { + return errors.New("item is required") + } + + client, options, err := dialForCommand(ctx, common) + if err != nil { + return err + } + defer client.Close() + + session, err := client.OpenSession(ctx, mxgateway.OpenSessionOptions{ClientSessionName: *clientName}) + if err != nil { + return err + } + defer session.Close(context.Background()) + + serverHandle, err := session.Register(ctx, *clientName) + if err != nil { + return err + } + itemHandle, err := session.AddItem(ctx, serverHandle, *item) + if err != nil { + return err + } + if err := session.Advise(ctx, serverHandle, itemHandle); err != nil { + return err + } + + output := map[string]any{ + "command": "smoke", + "options": options, + "sessionId": session.ID(), + "serverHandle": serverHandle, + "itemHandle": itemHandle, + } + if *jsonOutput { + return writeJSON(stdout, output) + } + + fmt.Fprintf(stdout, "session=%s server=%d item=%d\n", session.ID(), serverHandle, itemHandle) + return nil +} + +func bindCommonFlags(flags *flag.FlagSet) *commonOptions { + common := &commonOptions{} + flags.StringVar(&common.Endpoint, "endpoint", "localhost:5000", "gateway endpoint") + flags.StringVar(&common.APIKey, "api-key", "", "gateway API key") + flags.StringVar(&common.APIKeyEnv, "api-key-env", "MXGATEWAY_API_KEY", "environment variable containing the API key") + flags.BoolVar(&common.Plaintext, "plaintext", false, "use plaintext transport") + flags.StringVar(&common.CACertFile, "ca-cert", "", "CA certificate file") + flags.StringVar(&common.ServerName, "server-name-override", "", "TLS server name override") + flags.StringVar(&common.CallTimeout, "call-timeout", "30s", "per-call timeout") + return common +} + +func dialForCommand(ctx context.Context, common *commonOptions) (*mxgateway.Client, commonOptions, error) { + options, err := common.resolved() + if err != nil { + return nil, options, err + } + + client, err := mxgateway.Dial(ctx, mxgateway.Options{ + Endpoint: options.Endpoint, + APIKey: options.apiKeyValue, + Plaintext: options.Plaintext, + CACertFile: options.CACertFile, + ServerNameOverride: options.ServerName, + CallTimeout: options.timeout, + }) + return client, options, err +} + +func (o *commonOptions) resolved() (commonOptions, error) { + resolved := *o + if resolved.APIKey == "" && resolved.APIKeyEnv != "" { + resolved.apiKeyValue = os.Getenv(resolved.APIKeyEnv) + } else { + resolved.apiKeyValue = resolved.APIKey + } + resolved.APIKey = mxgateway.RedactAPIKey(resolved.apiKeyValue) + if resolved.CallTimeout != "" { + timeout, err := time.ParseDuration(resolved.CallTimeout) + if err != nil { + return resolved, err + } + resolved.timeout = timeout + } + return resolved, nil +} + +func parseValue(valueType, valueText string) (*mxgateway.MxValue, error) { + switch valueType { + case "bool": + value, err := strconv.ParseBool(valueText) + if err != nil { + return nil, err + } + return mxgateway.BoolValue(value), nil + case "int32": + value, err := strconv.ParseInt(valueText, 10, 32) + if err != nil { + return nil, err + } + return mxgateway.Int32Value(int32(value)), nil + case "int64": + value, err := strconv.ParseInt(valueText, 10, 64) + if err != nil { + return nil, err + } + return mxgateway.Int64Value(value), nil + case "float": + value, err := strconv.ParseFloat(valueText, 32) + if err != nil { + return nil, err + } + return mxgateway.FloatValue(float32(value)), nil + case "double": + value, err := strconv.ParseFloat(valueText, 64) + if err != nil { + return nil, err + } + return mxgateway.DoubleValue(value), nil + case "string": + return mxgateway.StringValue(valueText), nil + default: + return nil, fmt.Errorf("unsupported value type %q", valueType) + } +} + +func writeCommandOutput(stdout io.Writer, jsonOutput bool, command string, options commonOptions, reply *mxgateway.MxCommandReply, err error) error { + if err != nil { + return err + } + if jsonOutput { + return writeJSON(stdout, commandReplyOutput{ + Command: command, + Options: options, + Reply: mustMarshalProto(reply), + }) + } + fmt.Fprintln(stdout, reply.GetKind()) + return nil +} + +func writeJSON(writer io.Writer, value any) error { + encoder := json.NewEncoder(writer) + encoder.SetIndent("", " ") + return encoder.Encode(value) +} + +func mustMarshalProto(message protojsonMessage) json.RawMessage { + data, err := protojson.MarshalOptions{UseProtoNames: false}.Marshal(message) + if err != nil { + panic(err) + } + return data +} + +type protojsonMessage interface { + ProtoReflect() protoreflect.Message +} + +func writeUsage(writer io.Writer) { + fmt.Fprintln(writer, "usage: mxgw-go ") +} diff --git a/clients/go/cmd/mxgw-go/main_test.go b/clients/go/cmd/mxgw-go/main_test.go new file mode 100644 index 0000000..945cf09 --- /dev/null +++ b/clients/go/cmd/mxgw-go/main_test.go @@ -0,0 +1,58 @@ +package main + +import ( + "bytes" + "encoding/json" + "strings" + "testing" +) + +func TestRunVersionJSON(t *testing.T) { + var stdout bytes.Buffer + var stderr bytes.Buffer + + if err := runWithIO(t.Context(), []string{"version", "-json"}, &stdout, &stderr); err != nil { + t.Fatalf("runWithIO() error = %v; stderr = %s", err, stderr.String()) + } + + var output versionOutput + if err := json.Unmarshal(stdout.Bytes(), &output); err != nil { + t.Fatalf("parse JSON: %v", err) + } + if output.GatewayProtocolVersion == 0 || output.WorkerProtocolVersion == 0 { + t.Fatalf("protocol versions were not populated: %+v", output) + } +} + +func TestCommonOptionsRedactsAPIKey(t *testing.T) { + options, err := (&commonOptions{ + Endpoint: "localhost:5000", + APIKey: "mxgw_super_secret", + Plaintext: true, + CallTimeout: "2s", + }).resolved() + if err != nil { + t.Fatalf("resolved() error = %v", err) + } + + data, err := json.Marshal(options) + if err != nil { + t.Fatalf("marshal options: %v", err) + } + if strings.Contains(string(data), "super_secret") { + t.Fatalf("redacted JSON leaked API key: %s", data) + } + if !strings.Contains(string(data), "mxgw") { + t.Fatalf("redacted JSON did not preserve key shape: %s", data) + } +} + +func TestParseValueBuildsTypedValue(t *testing.T) { + value, err := parseValue("int32", "123") + if err != nil { + t.Fatalf("parseValue() error = %v", err) + } + if got := value.GetInt32Value(); got != 123 { + t.Fatalf("int32 value = %d, want 123", got) + } +} diff --git a/clients/go/mxgateway/auth.go b/clients/go/mxgateway/auth.go new file mode 100644 index 0000000..b8157b9 --- /dev/null +++ b/clients/go/mxgateway/auth.go @@ -0,0 +1,30 @@ +package mxgateway + +import ( + "context" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +const authorizationHeader = "authorization" + +func unaryAuthInterceptor(apiKey string) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + return invoker(authContext(ctx, apiKey), method, req, reply, cc, opts...) + } +} + +func streamAuthInterceptor(apiKey string) grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return streamer(authContext(ctx, apiKey), desc, cc, method, opts...) + } +} + +func authContext(ctx context.Context, apiKey string) context.Context { + if apiKey == "" { + return ctx + } + + return metadata.AppendToOutgoingContext(ctx, authorizationHeader, "Bearer "+apiKey) +} diff --git a/clients/go/mxgateway/client.go b/clients/go/mxgateway/client.go new file mode 100644 index 0000000..0f040cd --- /dev/null +++ b/clients/go/mxgateway/client.go @@ -0,0 +1,236 @@ +package mxgateway + +import ( + "context" + "crypto/tls" + "errors" + "time" + + pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/protobuf/types/known/durationpb" +) + +const ( + defaultDialTimeout = 10 * time.Second + defaultCallTimeout = 30 * time.Second +) + +// Client owns a gateway gRPC connection and exposes session-oriented helpers. +type Client struct { + conn *grpc.ClientConn + raw pb.MxAccessGatewayClient + opts Options +} + +// Dial opens a gRPC connection to the gateway and configures auth metadata, +// transport security, and blocking dial cancellation from ctx. +func Dial(ctx context.Context, opts Options) (*Client, error) { + if opts.Endpoint == "" { + return nil, errors.New("mxgateway: endpoint is required") + } + + dialCtx := ctx + cancel := func() {} + if opts.DialTimeout > 0 { + dialCtx, cancel = context.WithTimeout(ctx, opts.DialTimeout) + } else if _, ok := ctx.Deadline(); !ok { + dialCtx, cancel = context.WithTimeout(ctx, defaultDialTimeout) + } + defer cancel() + + transportCredentials, err := resolveTransportCredentials(opts) + if err != nil { + return nil, err + } + + dialOptions := []grpc.DialOption{ + grpc.WithTransportCredentials(transportCredentials), + grpc.WithUnaryInterceptor(unaryAuthInterceptor(opts.APIKey)), + grpc.WithStreamInterceptor(streamAuthInterceptor(opts.APIKey)), + grpc.WithBlock(), + } + dialOptions = append(dialOptions, opts.DialOptions...) + + conn, err := grpc.DialContext(dialCtx, opts.Endpoint, dialOptions...) + if err != nil { + return nil, &GatewayError{Op: "dial", Err: err} + } + + return NewClient(conn, opts), nil +} + +// NewClient wraps an existing gRPC connection. The caller owns closing conn +// unless it calls Close on the returned Client. +func NewClient(conn *grpc.ClientConn, opts Options) *Client { + return &Client{ + conn: conn, + raw: pb.NewMxAccessGatewayClient(conn), + opts: opts, + } +} + +// RawClient returns the generated gRPC client for command-specific parity tests. +func (c *Client) RawClient() RawGatewayClient { + return c.raw +} + +// OpenSession creates a gateway-backed MXAccess session. +func (c *Client) OpenSession(ctx context.Context, opts OpenSessionOptions) (*Session, error) { + reply, err := c.OpenSessionRaw(ctx, opts.Request()) + if err != nil { + return nil, err + } + + return newSession(c, reply), nil +} + +// OpenSessionRaw sends a raw OpenSession request and validates protocol status. +func (c *Client) OpenSessionRaw(ctx context.Context, req *OpenSessionRequest) (*OpenSessionReply, error) { + if req == nil { + return nil, errors.New("mxgateway: open session request is required") + } + + callCtx, cancel := c.callContext(ctx) + defer cancel() + + reply, err := c.raw.OpenSession(callCtx, req) + if err != nil { + return nil, &GatewayError{Op: "open session", Err: err} + } + if err := EnsureProtocolSuccess("open session", reply.GetProtocolStatus(), nil); err != nil { + return reply, err + } + + return reply, nil +} + +// Invoke sends a raw MXAccess command request and validates protocol and +// MXAccess status fields while preserving the raw reply on typed errors. +func (c *Client) Invoke(ctx context.Context, req *MxCommandRequest) (*MxCommandReply, error) { + if req == nil { + return nil, errors.New("mxgateway: command request is required") + } + + callCtx, cancel := c.callContext(ctx) + defer cancel() + + reply, err := c.raw.Invoke(callCtx, req) + if err != nil { + return nil, &GatewayError{Op: "invoke", Err: err} + } + if err := EnsureProtocolSuccess("invoke", reply.GetProtocolStatus(), reply); err != nil { + return reply, err + } + if err := EnsureMxAccessSuccess("invoke", reply); err != nil { + return reply, err + } + + return reply, nil +} + +// CloseSessionRaw sends a raw CloseSession request and validates protocol +// status. +func (c *Client) CloseSessionRaw(ctx context.Context, req *CloseSessionRequest) (*CloseSessionReply, error) { + if req == nil { + return nil, errors.New("mxgateway: close session request is required") + } + + callCtx, cancel := c.callContext(ctx) + defer cancel() + + reply, err := c.raw.CloseSession(callCtx, req) + if err != nil { + return nil, &GatewayError{Op: "close session", Err: err} + } + if err := EnsureProtocolSuccess("close session", reply.GetProtocolStatus(), nil); err != nil { + return reply, err + } + + return reply, nil +} + +// StreamEventsRaw starts the generated event stream for callers that need direct +// control over Recv. +func (c *Client) StreamEventsRaw(ctx context.Context, req *StreamEventsRequest) (RawEventStream, error) { + if req == nil { + return nil, errors.New("mxgateway: stream events request is required") + } + + stream, err := c.raw.StreamEvents(ctx, req) + if err != nil { + return nil, &GatewayError{Op: "stream events", Err: err} + } + + return stream, nil +} + +// Close closes the underlying gRPC connection. +func (c *Client) Close() error { + if c == nil || c.conn == nil { + return nil + } + + return c.conn.Close() +} + +func (c *Client) callContext(ctx context.Context) (context.Context, context.CancelFunc) { + timeout := c.opts.CallTimeout + if timeout == 0 { + timeout = defaultCallTimeout + } + if timeout < 0 { + return ctx, func() {} + } + if _, ok := ctx.Deadline(); ok { + return ctx, func() {} + } + return context.WithTimeout(ctx, timeout) +} + +func resolveTransportCredentials(opts Options) (credentials.TransportCredentials, error) { + if opts.TransportCredentials != nil { + return opts.TransportCredentials, nil + } + if opts.Plaintext { + return insecure.NewCredentials(), nil + } + if opts.CACertFile != "" { + return credentials.NewClientTLSFromFile(opts.CACertFile, opts.ServerNameOverride) + } + if opts.TLSConfig != nil { + cfg := opts.TLSConfig.Clone() + if opts.ServerNameOverride != "" { + cfg.ServerName = opts.ServerNameOverride + } + return credentials.NewTLS(cfg), nil + } + + return credentials.NewTLS(&tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: opts.ServerNameOverride, + }), nil +} + +// OpenSessionOptions describes fields used to create an OpenSessionRequest. +type OpenSessionOptions struct { + RequestedBackend string + ClientSessionName string + ClientCorrelationID string + CommandTimeout time.Duration +} + +// Request returns the raw protobuf OpenSessionRequest for these options. +func (o OpenSessionOptions) Request() *OpenSessionRequest { + req := &OpenSessionRequest{ + RequestedBackend: o.RequestedBackend, + ClientSessionName: o.ClientSessionName, + ClientCorrelationId: o.ClientCorrelationID, + } + if o.CommandTimeout > 0 { + req.CommandTimeout = durationpb.New(o.CommandTimeout) + } + return req +} diff --git a/clients/go/mxgateway/client_session_test.go b/clients/go/mxgateway/client_session_test.go new file mode 100644 index 0000000..991b438 --- /dev/null +++ b/clients/go/mxgateway/client_session_test.go @@ -0,0 +1,261 @@ +package mxgateway + +import ( + "context" + "errors" + "io" + "net" + "testing" + "time" + + pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/test/bufconn" +) + +const bufSize = 1024 * 1024 + +func TestDialAttachesAuthMetadataToUnaryCalls(t *testing.T) { + fake := &fakeGatewayServer{ + openReply: &pb.OpenSessionReply{ + SessionId: "session-1", + GatewayProtocolVersion: GatewayProtocolVersion, + WorkerProtocolVersion: WorkerProtocolVersion, + ProtocolStatus: &pb.ProtocolStatus{ + Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK, + }, + }, + } + client, cleanup := newBufconnClient(t, fake) + defer cleanup() + + _, err := client.OpenSession(context.Background(), OpenSessionOptions{ClientSessionName: "fixture"}) + if err != nil { + t.Fatalf("OpenSession() error = %v", err) + } + + if got := fake.openAuth; got != "Bearer test-api-key" { + t.Fatalf("authorization metadata = %q, want %q", got, "Bearer test-api-key") + } +} + +func TestStreamEventsAttachesAuthMetadataAndClosesOnCancellation(t *testing.T) { + fake := &fakeGatewayServer{ + streamStarted: make(chan struct{}), + } + client, cleanup := newBufconnClient(t, fake) + defer cleanup() + session := NewSessionForID(client, "session-1") + ctx, cancel := context.WithCancel(context.Background()) + + events, err := session.Events(ctx) + if err != nil { + t.Fatalf("Events() error = %v", err) + } + <-fake.streamStarted + + first := <-events + if first.Err != nil { + t.Fatalf("first event error = %v", first.Err) + } + if first.Event.GetWorkerSequence() != 1 { + t.Fatalf("worker sequence = %d, want 1", first.Event.GetWorkerSequence()) + } + if got := fake.streamAuth; got != "Bearer test-api-key" { + t.Fatalf("stream authorization metadata = %q, want %q", got, "Bearer test-api-key") + } + + cancel() + select { + case _, ok := <-events: + if ok { + t.Fatal("events channel produced an extra item after cancellation") + } + case <-time.After(2 * time.Second): + t.Fatal("events channel did not close after cancellation") + } +} + +func TestSessionHelpersBuildCommandsAndExposeRawReply(t *testing.T) { + fake := &fakeGatewayServer{ + invokeReply: &pb.MxCommandReply{ + SessionId: "session-1", + Kind: pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM2, + ProtocolStatus: &pb.ProtocolStatus{ + Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK, + }, + Payload: &pb.MxCommandReply_AddItem2{ + AddItem2: &pb.AddItem2Reply{ItemHandle: 42}, + }, + }, + } + client, cleanup := newBufconnClient(t, fake) + defer cleanup() + session := NewSessionForID(client, "session-1") + + itemHandle, err := session.AddItem2(context.Background(), 12, "Area001.Pump001.Speed", "runtime") + if err != nil { + t.Fatalf("AddItem2() error = %v", err) + } + + if itemHandle != 42 { + t.Fatalf("item handle = %d, want 42", itemHandle) + } + req := fake.invokeRequest + if req.GetSessionId() != "session-1" { + t.Fatalf("session id = %q, want session-1", req.GetSessionId()) + } + if req.GetClientCorrelationId() == "" { + t.Fatal("client correlation id is empty") + } + if req.GetCommand().GetKind() != pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM2 { + t.Fatalf("command kind = %s", req.GetCommand().GetKind()) + } + if req.GetCommand().GetAddItem2().GetItemContext() != "runtime" { + t.Fatalf("item context = %q, want runtime", req.GetCommand().GetAddItem2().GetItemContext()) + } +} + +func TestInvokeReturnsTypedMxAccessErrorWithRawReply(t *testing.T) { + hresult := int32(-2147467259) + fake := &fakeGatewayServer{ + invokeReply: &pb.MxCommandReply{ + SessionId: "session-1", + Kind: pb.MxCommandKind_MX_COMMAND_KIND_ADVISE, + Hresult: &hresult, + DiagnosticMessage: "native failure", + ProtocolStatus: &pb.ProtocolStatus{Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_MXACCESS_FAILURE, Message: "MXAccess failed"}, + Statuses: []*pb.MxStatusProxy{{Success: 0, DiagnosticText: "failed"}}, + }, + } + client, cleanup := newBufconnClient(t, fake) + defer cleanup() + session := NewSessionForID(client, "session-1") + + err := session.Advise(context.Background(), 12, 34) + + var mxErr *MxAccessError + if !errors.As(err, &mxErr) { + t.Fatalf("error %T does not support errors.As(*MxAccessError)", err) + } + if mxErr.Reply.GetHresult() != hresult { + t.Fatalf("raw reply HRESULT = %d, want %d", mxErr.Reply.GetHresult(), hresult) + } + var commandErr *CommandError + if !errors.As(err, &commandErr) { + t.Fatalf("error %T does not support errors.As(*CommandError)", err) + } + if commandErr.Reply.GetDiagnosticMessage() != "native failure" { + t.Fatalf("raw diagnostic = %q", commandErr.Reply.GetDiagnosticMessage()) + } +} + +func newBufconnClient(t *testing.T, fake *fakeGatewayServer) (*Client, func()) { + t.Helper() + + listener := bufconn.Listen(bufSize) + server := grpc.NewServer() + pb.RegisterMxAccessGatewayServer(server, fake) + go func() { + if err := server.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) { + t.Errorf("bufconn server failed: %v", err) + } + }() + + dialer := func(ctx context.Context, _ string) (net.Conn, error) { + return listener.DialContext(ctx) + } + client, err := Dial(context.Background(), Options{ + Endpoint: "bufnet", + APIKey: "test-api-key", + Plaintext: true, + DialOptions: []grpc.DialOption{ + grpc.WithContextDialer(dialer), + }, + }) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + + return client, func() { + client.Close() + server.Stop() + listener.Close() + } +} + +type fakeGatewayServer struct { + pb.UnimplementedMxAccessGatewayServer + + openReply *pb.OpenSessionReply + openAuth string + streamAuth string + streamStarted chan struct{} + invokeReply *pb.MxCommandReply + invokeRequest *pb.MxCommandRequest +} + +func (s *fakeGatewayServer) OpenSession(ctx context.Context, req *pb.OpenSessionRequest) (*pb.OpenSessionReply, error) { + s.openAuth = authorizationFromContext(ctx) + if s.openReply != nil { + return s.openReply, nil + } + return &pb.OpenSessionReply{ + SessionId: "session-1", + ProtocolStatus: &pb.ProtocolStatus{ + Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK, + }, + }, nil +} + +func (s *fakeGatewayServer) CloseSession(ctx context.Context, req *pb.CloseSessionRequest) (*pb.CloseSessionReply, error) { + return &pb.CloseSessionReply{ + SessionId: req.GetSessionId(), + ProtocolStatus: &pb.ProtocolStatus{ + Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK, + }, + }, nil +} + +func (s *fakeGatewayServer) Invoke(ctx context.Context, req *pb.MxCommandRequest) (*pb.MxCommandReply, error) { + s.invokeRequest = req + if s.invokeReply != nil { + return s.invokeReply, nil + } + return &pb.MxCommandReply{ + SessionId: req.GetSessionId(), + Kind: req.GetCommand().GetKind(), + ProtocolStatus: &pb.ProtocolStatus{ + Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK, + }, + }, nil +} + +func (s *fakeGatewayServer) StreamEvents(req *pb.StreamEventsRequest, stream grpc.ServerStreamingServer[pb.MxEvent]) error { + s.streamAuth = authorizationFromContext(stream.Context()) + if s.streamStarted != nil { + close(s.streamStarted) + } + if err := stream.Send(&pb.MxEvent{ + SessionId: req.GetSessionId(), + Family: pb.MxEventFamily_MX_EVENT_FAMILY_ON_DATA_CHANGE, + WorkerSequence: 1, + }); err != nil { + return err + } + <-stream.Context().Done() + return io.EOF +} + +func authorizationFromContext(ctx context.Context) string { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "" + } + values := md.Get(authorizationHeader) + if len(values) == 0 { + return "" + } + return values[0] +} diff --git a/clients/go/mxgateway/conversion_test.go b/clients/go/mxgateway/conversion_test.go new file mode 100644 index 0000000..c7633ec --- /dev/null +++ b/clients/go/mxgateway/conversion_test.go @@ -0,0 +1,73 @@ +package mxgateway + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated" + "google.golang.org/protobuf/encoding/protojson" +) + +func TestValueConversionFixtures(t *testing.T) { + data, err := os.ReadFile(filepath.Join("..", "..", "proto", "fixtures", "behavior", "values", "value-conversion-cases.json")) + if err != nil { + t.Fatalf("read fixture: %v", err) + } + + var fixture struct { + Cases []struct { + ID string `json:"id"` + ExpectedKind string `json:"expectedKind"` + Value json.RawMessage `json:"value"` + } `json:"cases"` + } + if err := json.Unmarshal(data, &fixture); err != nil { + t.Fatalf("parse fixture manifest: %v", err) + } + + for _, tc := range fixture.Cases { + t.Run(tc.ID, func(t *testing.T) { + var value pb.MxValue + if err := protojson.Unmarshal(tc.Value, &value); err != nil { + t.Fatalf("parse value: %v", err) + } + if _, err := NativeValue(&value); err != nil { + t.Fatalf("NativeValue() error = %v", err) + } + if got := value.ProtoReflect().WhichOneof(value.ProtoReflect().Descriptor().Oneofs().ByName("kind")).JSONName(); got != tc.ExpectedKind { + t.Fatalf("kind = %q, want %q", got, tc.ExpectedKind) + } + }) + } +} + +func TestStatusConversionFixtures(t *testing.T) { + data, err := os.ReadFile(filepath.Join("..", "..", "proto", "fixtures", "behavior", "statuses", "status-conversion-cases.json")) + if err != nil { + t.Fatalf("read fixture: %v", err) + } + + var fixture struct { + Cases []struct { + ID string `json:"id"` + Status json.RawMessage `json:"status"` + } `json:"cases"` + } + if err := json.Unmarshal(data, &fixture); err != nil { + t.Fatalf("parse fixture manifest: %v", err) + } + + for _, tc := range fixture.Cases { + t.Run(tc.ID, func(t *testing.T) { + var status pb.MxStatusProxy + if err := protojson.Unmarshal(tc.Status, &status); err != nil { + t.Fatalf("parse status: %v", err) + } + if got, want := StatusSucceeded(&status), status.GetSuccess() != 0; got != want { + t.Fatalf("StatusSucceeded() = %v, want %v", got, want) + } + }) + } +} diff --git a/clients/go/mxgateway/errors.go b/clients/go/mxgateway/errors.go new file mode 100644 index 0000000..890f14b --- /dev/null +++ b/clients/go/mxgateway/errors.go @@ -0,0 +1,118 @@ +package mxgateway + +import ( + "fmt" + + pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated" +) + +// GatewayError wraps transport-level gRPC failures. +type GatewayError struct { + Op string + Err error +} + +func (e *GatewayError) Error() string { + if e == nil { + return "" + } + if e.Op == "" { + return fmt.Sprintf("mxgateway: %v", e.Err) + } + return fmt.Sprintf("mxgateway: %s failed: %v", e.Op, e.Err) +} + +func (e *GatewayError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +// CommandError reports a non-OK gateway protocol status and keeps the raw +// command reply when one exists. +type CommandError struct { + Op string + Status *ProtocolStatus + Reply *MxCommandReply +} + +func (e *CommandError) Error() string { + if e == nil { + return "" + } + status := e.Status + if status == nil { + return fmt.Sprintf("mxgateway: %s failed with missing protocol status", e.Op) + } + if status.GetMessage() == "" { + return fmt.Sprintf("mxgateway: %s failed with protocol status %s", e.Op, status.GetCode()) + } + return fmt.Sprintf("mxgateway: %s failed with protocol status %s: %s", e.Op, status.GetCode(), status.GetMessage()) +} + +// MxAccessError reports HRESULT or MXSTATUS_PROXY failures returned by MXAccess. +type MxAccessError struct { + Command *CommandError + Reply *MxCommandReply +} + +func (e *MxAccessError) Error() string { + if e == nil { + return "" + } + if e.Command != nil && e.Command.Status != nil && e.Command.Status.GetMessage() != "" { + return e.Command.Error() + } + if e.Reply != nil && e.Reply.GetDiagnosticMessage() != "" { + return fmt.Sprintf("mxgateway: MXAccess command %s failed: %s", e.Reply.GetKind(), e.Reply.GetDiagnosticMessage()) + } + if e.Reply != nil && e.Reply.Hresult != nil { + return fmt.Sprintf("mxgateway: MXAccess command %s failed with HRESULT 0x%08X", e.Reply.GetKind(), uint32(e.Reply.GetHresult())) + } + return "mxgateway: MXAccess command failed" +} + +func (e *MxAccessError) Unwrap() error { + if e == nil { + return nil + } + return e.Command +} + +// EnsureProtocolSuccess returns a typed CommandError when status is non-OK. +func EnsureProtocolSuccess(op string, status *ProtocolStatus, reply *MxCommandReply) error { + if status == nil || status.GetCode() == pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK { + return nil + } + + commandError := &CommandError{ + Op: op, + Status: status, + Reply: reply, + } + if status.GetCode() == pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_MXACCESS_FAILURE { + return &MxAccessError{ + Command: commandError, + Reply: reply, + } + } + return commandError +} + +// EnsureMxAccessSuccess returns a typed MxAccessError for failing HRESULTs or +// MXSTATUS_PROXY entries. +func EnsureMxAccessSuccess(op string, reply *MxCommandReply) error { + if reply == nil { + return nil + } + if reply.Hresult != nil && reply.GetHresult() != 0 { + return &MxAccessError{Reply: reply} + } + for _, status := range reply.GetStatuses() { + if !StatusSucceeded(status) { + return &MxAccessError{Reply: reply} + } + } + return nil +} diff --git a/clients/go/mxgateway/options.go b/clients/go/mxgateway/options.go index 8d6bda6..782634d 100644 --- a/clients/go/mxgateway/options.go +++ b/clients/go/mxgateway/options.go @@ -1,14 +1,26 @@ package mxgateway -import "strings" +import ( + "crypto/tls" + "strings" + "time" -// Options configures future gateway connections. + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +// Options configures gateway connections. type Options struct { - Endpoint string - APIKey string - Plaintext bool - CACertFile string - ServerNameOverride string + Endpoint string + APIKey string + Plaintext bool + CACertFile string + ServerNameOverride string + DialTimeout time.Duration + CallTimeout time.Duration + TLSConfig *tls.Config + TransportCredentials credentials.TransportCredentials + DialOptions []grpc.DialOption } // RedactedAPIKey returns a display-safe representation of the configured API diff --git a/clients/go/mxgateway/session.go b/clients/go/mxgateway/session.go new file mode 100644 index 0000000..0a7412e --- /dev/null +++ b/clients/go/mxgateway/session.go @@ -0,0 +1,260 @@ +package mxgateway + +import ( + "context" + "crypto/rand" + "encoding/hex" + "errors" + "io" + "sync" + + pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// EventResult carries either the next ordered event or a terminal stream error. +type EventResult struct { + Event *MxEvent + Err error +} + +// Session represents one gateway-backed MXAccess session. +type Session struct { + client *Client + openReply *OpenSessionReply + closeMu sync.Mutex + closeReply *CloseSessionReply +} + +func newSession(client *Client, openReply *OpenSessionReply) *Session { + return &Session{ + client: client, + openReply: openReply, + } +} + +// NewSessionForID creates a session wrapper for commands against an existing +// gateway session id. +func NewSessionForID(client *Client, sessionID string) *Session { + return newSession(client, &pb.OpenSessionReply{SessionId: sessionID}) +} + +// ID returns the gateway session identifier. +func (s *Session) ID() string { + return s.openReply.GetSessionId() +} + +// OpenReply returns the raw OpenSession reply. +func (s *Session) OpenReply() *OpenSessionReply { + return s.openReply +} + +// Close closes the gateway session once and returns the raw close reply. +func (s *Session) Close(ctx context.Context) (*CloseSessionReply, error) { + s.closeMu.Lock() + defer s.closeMu.Unlock() + + if s.closeReply != nil { + return s.closeReply, nil + } + + reply, err := s.client.CloseSessionRaw(ctx, &pb.CloseSessionRequest{SessionId: s.ID()}) + if err != nil { + return reply, err + } + s.closeReply = reply + return reply, nil +} + +// Register invokes MXAccess Register and returns the server handle. +func (s *Session) Register(ctx context.Context, clientName string) (int32, error) { + reply, err := s.RegisterRaw(ctx, clientName) + if err != nil { + return 0, err + } + if reply.GetRegister() != nil { + return reply.GetRegister().GetServerHandle(), nil + } + return reply.GetReturnValue().GetInt32Value(), nil +} + +// RegisterRaw invokes MXAccess Register and returns the raw reply. +func (s *Session) RegisterRaw(ctx context.Context, clientName string) (*MxCommandReply, error) { + if clientName == "" { + return nil, errors.New("mxgateway: client name is required") + } + + return s.invokeCommand(ctx, &pb.MxCommand{ + Kind: pb.MxCommandKind_MX_COMMAND_KIND_REGISTER, + Payload: &pb.MxCommand_Register{ + Register: &pb.RegisterCommand{ClientName: clientName}, + }, + }) +} + +// Unregister invokes MXAccess Unregister. +func (s *Session) Unregister(ctx context.Context, serverHandle int32) error { + _, err := s.invokeCommand(ctx, &pb.MxCommand{ + Kind: pb.MxCommandKind_MX_COMMAND_KIND_UNREGISTER, + Payload: &pb.MxCommand_Unregister{ + Unregister: &pb.UnregisterCommand{ServerHandle: serverHandle}, + }, + }) + return err +} + +// AddItem invokes MXAccess AddItem and returns the item handle. +func (s *Session) AddItem(ctx context.Context, serverHandle int32, itemDefinition string) (int32, error) { + reply, err := s.AddItemRaw(ctx, serverHandle, itemDefinition) + if err != nil { + return 0, err + } + if reply.GetAddItem() != nil { + return reply.GetAddItem().GetItemHandle(), nil + } + return reply.GetReturnValue().GetInt32Value(), nil +} + +// AddItemRaw invokes MXAccess AddItem and returns the raw reply. +func (s *Session) AddItemRaw(ctx context.Context, serverHandle int32, itemDefinition string) (*MxCommandReply, error) { + if itemDefinition == "" { + return nil, errors.New("mxgateway: item definition is required") + } + + return s.invokeCommand(ctx, &pb.MxCommand{ + Kind: pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM, + Payload: &pb.MxCommand_AddItem{ + AddItem: &pb.AddItemCommand{ + ServerHandle: serverHandle, + ItemDefinition: itemDefinition, + }, + }, + }) +} + +// AddItem2 invokes MXAccess AddItem2 and returns the item handle. +func (s *Session) AddItem2(ctx context.Context, serverHandle int32, itemDefinition, itemContext string) (int32, error) { + reply, err := s.AddItem2Raw(ctx, serverHandle, itemDefinition, itemContext) + if err != nil { + return 0, err + } + if reply.GetAddItem2() != nil { + return reply.GetAddItem2().GetItemHandle(), nil + } + return reply.GetReturnValue().GetInt32Value(), nil +} + +// AddItem2Raw invokes MXAccess AddItem2 and returns the raw reply. +func (s *Session) AddItem2Raw(ctx context.Context, serverHandle int32, itemDefinition, itemContext string) (*MxCommandReply, error) { + if itemDefinition == "" { + return nil, errors.New("mxgateway: item definition is required") + } + + return s.invokeCommand(ctx, &pb.MxCommand{ + Kind: pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM2, + Payload: &pb.MxCommand_AddItem2{ + AddItem2: &pb.AddItem2Command{ + ServerHandle: serverHandle, + ItemDefinition: itemDefinition, + ItemContext: itemContext, + }, + }, + }) +} + +// Advise invokes MXAccess Advise. +func (s *Session) Advise(ctx context.Context, serverHandle, itemHandle int32) error { + _, err := s.AdviseRaw(ctx, serverHandle, itemHandle) + return err +} + +// AdviseRaw invokes MXAccess Advise and returns the raw reply. +func (s *Session) AdviseRaw(ctx context.Context, serverHandle, itemHandle int32) (*MxCommandReply, error) { + return s.invokeCommand(ctx, &pb.MxCommand{ + Kind: pb.MxCommandKind_MX_COMMAND_KIND_ADVISE, + Payload: &pb.MxCommand_Advise{ + Advise: &pb.AdviseCommand{ + ServerHandle: serverHandle, + ItemHandle: itemHandle, + }, + }, + }) +} + +// Write invokes MXAccess Write. +func (s *Session) Write(ctx context.Context, serverHandle, itemHandle int32, value *MxValue, userID int32) error { + _, err := s.WriteRaw(ctx, serverHandle, itemHandle, value, userID) + return err +} + +// WriteRaw invokes MXAccess Write and returns the raw reply. +func (s *Session) WriteRaw(ctx context.Context, serverHandle, itemHandle int32, value *MxValue, userID int32) (*MxCommandReply, error) { + if value == nil { + return nil, errors.New("mxgateway: write value is required") + } + + return s.invokeCommand(ctx, &pb.MxCommand{ + Kind: pb.MxCommandKind_MX_COMMAND_KIND_WRITE, + Payload: &pb.MxCommand_Write{ + Write: &pb.WriteCommand{ + ServerHandle: serverHandle, + ItemHandle: itemHandle, + Value: value, + UserId: userID, + }, + }, + }) +} + +// Events streams ordered session events until the server ends the stream, +// context cancellation stops Recv, or a terminal error is sent. +func (s *Session) Events(ctx context.Context) (<-chan EventResult, error) { + return s.EventsAfter(ctx, 0) +} + +// EventsAfter streams ordered session events after the given worker sequence. +func (s *Session) EventsAfter(ctx context.Context, afterWorkerSequence uint64) (<-chan EventResult, error) { + stream, err := s.client.StreamEventsRaw(ctx, &pb.StreamEventsRequest{ + SessionId: s.ID(), + AfterWorkerSequence: afterWorkerSequence, + }) + if err != nil { + return nil, err + } + + results := make(chan EventResult, 16) + go func() { + defer close(results) + for { + event, err := stream.Recv() + if err == nil { + results <- EventResult{Event: event} + continue + } + if err == io.EOF || status.Code(err) == codes.Canceled || ctx.Err() != nil { + return + } + results <- EventResult{Err: &GatewayError{Op: "stream events", Err: err}} + return + } + }() + + return results, nil +} + +func (s *Session) invokeCommand(ctx context.Context, command *MxCommand) (*MxCommandReply, error) { + return s.client.Invoke(ctx, &pb.MxCommandRequest{ + SessionId: s.ID(), + ClientCorrelationId: newCorrelationID(), + Command: command, + }) +} + +func newCorrelationID() string { + var buffer [16]byte + if _, err := rand.Read(buffer[:]); err != nil { + return "" + } + return hex.EncodeToString(buffer[:]) +} diff --git a/clients/go/mxgateway/status.go b/clients/go/mxgateway/status.go new file mode 100644 index 0000000..3ea60c6 --- /dev/null +++ b/clients/go/mxgateway/status.go @@ -0,0 +1,6 @@ +package mxgateway + +// StatusSucceeded reports whether an MXSTATUS_PROXY entry represents success. +func StatusSucceeded(status *MxStatusProxy) bool { + return status == nil || status.GetSuccess() != 0 +} diff --git a/clients/go/mxgateway/types.go b/clients/go/mxgateway/types.go new file mode 100644 index 0000000..637b16c --- /dev/null +++ b/clients/go/mxgateway/types.go @@ -0,0 +1,70 @@ +package mxgateway + +import pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated" + +// RawGatewayClient is the generated gRPC client interface exposed for callers +// that need direct contract access. +type RawGatewayClient = pb.MxAccessGatewayClient + +// RawEventStream is the generated StreamEvents client stream. +type RawEventStream = pb.MxAccessGateway_StreamEventsClient + +// Generated protobuf aliases keep raw contract access available from the public +// mxgateway package while generated code remains under internal/generated. +type ( + OpenSessionRequest = pb.OpenSessionRequest + OpenSessionReply = pb.OpenSessionReply + CloseSessionRequest = pb.CloseSessionRequest + CloseSessionReply = pb.CloseSessionReply + StreamEventsRequest = pb.StreamEventsRequest + MxCommandRequest = pb.MxCommandRequest + MxCommandReply = pb.MxCommandReply + MxCommand = pb.MxCommand + MxEvent = pb.MxEvent + MxValue = pb.MxValue + Value = pb.MxValue + MxArray = pb.MxArray + MxStatusProxy = pb.MxStatusProxy + ProtocolStatus = pb.ProtocolStatus + RegisterCommand = pb.RegisterCommand + UnregisterCommand = pb.UnregisterCommand + AddItemCommand = pb.AddItemCommand + AddItem2Command = pb.AddItem2Command + AdviseCommand = pb.AdviseCommand + WriteCommand = pb.WriteCommand + Write2Command = pb.Write2Command + RegisterReply = pb.RegisterReply + AddItemReply = pb.AddItemReply + AddItem2Reply = pb.AddItem2Reply +) + +type ( + MxCommandKind = pb.MxCommandKind + MxDataType = pb.MxDataType + MxEventFamily = pb.MxEventFamily + MxStatusCategory = pb.MxStatusCategory + MxStatusSource = pb.MxStatusSource + ProtocolStatusCode = pb.ProtocolStatusCode + SessionState = pb.SessionState +) + +const ( + CommandKindRegister = pb.MxCommandKind_MX_COMMAND_KIND_REGISTER + CommandKindUnregister = pb.MxCommandKind_MX_COMMAND_KIND_UNREGISTER + CommandKindAddItem = pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM + CommandKindAddItem2 = pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM2 + CommandKindAdvise = pb.MxCommandKind_MX_COMMAND_KIND_ADVISE + CommandKindWrite = pb.MxCommandKind_MX_COMMAND_KIND_WRITE + CommandKindWrite2 = pb.MxCommandKind_MX_COMMAND_KIND_WRITE2 + + DataTypeUnknown = pb.MxDataType_MX_DATA_TYPE_UNKNOWN + DataTypeBoolean = pb.MxDataType_MX_DATA_TYPE_BOOLEAN + DataTypeInteger = pb.MxDataType_MX_DATA_TYPE_INTEGER + DataTypeFloat = pb.MxDataType_MX_DATA_TYPE_FLOAT + DataTypeDouble = pb.MxDataType_MX_DATA_TYPE_DOUBLE + DataTypeString = pb.MxDataType_MX_DATA_TYPE_STRING + DataTypeTime = pb.MxDataType_MX_DATA_TYPE_TIME + + ProtocolStatusOK = pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK + ProtocolStatusMxAccessFailure = pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_MXACCESS_FAILURE +) diff --git a/clients/go/mxgateway/values.go b/clients/go/mxgateway/values.go new file mode 100644 index 0000000..6a44687 --- /dev/null +++ b/clients/go/mxgateway/values.go @@ -0,0 +1,148 @@ +package mxgateway + +import ( + "fmt" + "time" + + pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// BoolValue builds an MXAccess Boolean value. +func BoolValue(value bool) *MxValue { + return &pb.MxValue{ + DataType: pb.MxDataType_MX_DATA_TYPE_BOOLEAN, + VariantType: "VT_BOOL", + Kind: &pb.MxValue_BoolValue{BoolValue: value}, + } +} + +// Int32Value builds an MXAccess Int32 value. +func Int32Value(value int32) *MxValue { + return &pb.MxValue{ + DataType: pb.MxDataType_MX_DATA_TYPE_INTEGER, + VariantType: "VT_I4", + Kind: &pb.MxValue_Int32Value{Int32Value: value}, + } +} + +// Int64Value builds an MXAccess Int64 value. +func Int64Value(value int64) *MxValue { + return &pb.MxValue{ + DataType: pb.MxDataType_MX_DATA_TYPE_INTEGER, + VariantType: "VT_I8", + Kind: &pb.MxValue_Int64Value{Int64Value: value}, + } +} + +// FloatValue builds an MXAccess Float value. +func FloatValue(value float32) *MxValue { + return &pb.MxValue{ + DataType: pb.MxDataType_MX_DATA_TYPE_FLOAT, + VariantType: "VT_R4", + Kind: &pb.MxValue_FloatValue{FloatValue: value}, + } +} + +// DoubleValue builds an MXAccess Double value. +func DoubleValue(value float64) *MxValue { + return &pb.MxValue{ + DataType: pb.MxDataType_MX_DATA_TYPE_DOUBLE, + VariantType: "VT_R8", + Kind: &pb.MxValue_DoubleValue{DoubleValue: value}, + } +} + +// StringValue builds an MXAccess String value. +func StringValue(value string) *MxValue { + return &pb.MxValue{ + DataType: pb.MxDataType_MX_DATA_TYPE_STRING, + VariantType: "VT_BSTR", + Kind: &pb.MxValue_StringValue{StringValue: value}, + } +} + +// TimestampValue builds an MXAccess timestamp value from a Go time. +func TimestampValue(value time.Time) *MxValue { + return &pb.MxValue{ + DataType: pb.MxDataType_MX_DATA_TYPE_TIME, + VariantType: "VT_DATE", + Kind: &pb.MxValue_TimestampValue{TimestampValue: timestamppb.New(value)}, + } +} + +// NativeValue converts a protobuf MxValue to the closest Go representation +// without discarding raw fallback data. +func NativeValue(value *MxValue) (any, error) { + if value == nil || value.GetIsNull() { + return nil, nil + } + + switch kind := value.GetKind().(type) { + case *pb.MxValue_BoolValue: + return kind.BoolValue, nil + case *pb.MxValue_Int32Value: + return kind.Int32Value, nil + case *pb.MxValue_Int64Value: + return kind.Int64Value, nil + case *pb.MxValue_FloatValue: + return kind.FloatValue, nil + case *pb.MxValue_DoubleValue: + return kind.DoubleValue, nil + case *pb.MxValue_StringValue: + return kind.StringValue, nil + case *pb.MxValue_TimestampValue: + if kind.TimestampValue == nil { + return nil, nil + } + return kind.TimestampValue.AsTime(), nil + case *pb.MxValue_ArrayValue: + return NativeArray(kind.ArrayValue) + case *pb.MxValue_RawValue: + return append([]byte(nil), kind.RawValue...), nil + default: + return nil, fmt.Errorf("mxgateway: unsupported value kind %T", kind) + } +} + +// NativeArray converts a protobuf MxArray to the closest Go slice +// representation. +func NativeArray(array *MxArray) (any, error) { + if array == nil { + return nil, nil + } + + switch values := array.GetValues().(type) { + case *pb.MxArray_BoolValues: + return append([]bool(nil), values.BoolValues.GetValues()...), nil + case *pb.MxArray_Int32Values: + return append([]int32(nil), values.Int32Values.GetValues()...), nil + case *pb.MxArray_Int64Values: + return append([]int64(nil), values.Int64Values.GetValues()...), nil + case *pb.MxArray_FloatValues: + return append([]float32(nil), values.FloatValues.GetValues()...), nil + case *pb.MxArray_DoubleValues: + return append([]float64(nil), values.DoubleValues.GetValues()...), nil + case *pb.MxArray_StringValues: + return append([]string(nil), values.StringValues.GetValues()...), nil + case *pb.MxArray_TimestampValues: + result := make([]time.Time, 0, len(values.TimestampValues.GetValues())) + for _, value := range values.TimestampValues.GetValues() { + if value == nil { + result = append(result, time.Time{}) + continue + } + result = append(result, value.AsTime()) + } + return result, nil + case *pb.MxArray_RawValues: + rawValues := values.RawValues.GetValues() + result := make([][]byte, 0, len(rawValues)) + for _, value := range rawValues { + result = append(result, append([]byte(nil), value...)) + } + return result, nil + default: + return nil, fmt.Errorf("mxgateway: unsupported array value kind %T", values) + } +} -- 2.52.0