package main import ( "context" "encoding/json" "errors" "flag" "fmt" "io" "os" "strconv" "time" "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/mxgateway" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/reflect/protoreflect" ) type versionOutput struct { ClientVersion string `json:"clientVersion"` GatewayProtocolVersion uint32 `json:"gatewayProtocolVersion"` WorkerProtocolVersion uint32 `json:"workerProtocolVersion"` } type commonOptions struct { Endpoint string `json:"endpoint"` APIKey string `json:"apiKey"` APIKeyEnv string `json:"apiKeyEnv,omitempty"` Plaintext bool `json:"plaintext"` CACertFile string `json:"caCertFile,omitempty"` ServerName string `json:"serverNameOverride,omitempty"` CallTimeout string `json:"callTimeout,omitempty"` apiKeyValue string timeout time.Duration } type openSessionOutput struct { Command string `json:"command"` Options commonOptions `json:"options"` Reply json.RawMessage `json:"reply"` } type commandReplyOutput struct { Command string `json:"command"` Options commonOptions `json:"options"` Reply json.RawMessage `json:"reply"` } func main() { if err := runWithIO(context.Background(), os.Args[1:], os.Stdout, os.Stderr); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(2) } } func run(args []string) error { return runWithIO(context.Background(), args, os.Stdout, os.Stderr) } func runWithIO(ctx context.Context, args []string, stdout, stderr io.Writer) error { if len(args) == 0 { writeUsage(stderr) return errors.New("missing command") } switch args[0] { case "version": return runVersion(args[1:], stdout, stderr) case "open-session": return runOpenSession(ctx, args[1:], stdout, stderr) case "close-session": return runCloseSession(ctx, args[1:], stdout, stderr) case "register": return runRegister(ctx, args[1:], stdout, stderr) case "add-item": return runAddItem(ctx, args[1:], stdout, stderr) case "advise": return runAdvise(ctx, args[1:], stdout, stderr) case "write": return runWrite(ctx, args[1:], stdout, stderr) case "stream-events": return runStreamEvents(ctx, args[1:], stdout, stderr) case "smoke": return runSmoke(ctx, args[1:], stdout, stderr) default: writeUsage(stderr) return fmt.Errorf("unknown command %q", args[0]) } } func runVersion(args []string, stdout, stderr io.Writer) error { flags := flag.NewFlagSet("version", flag.ContinueOnError) flags.SetOutput(stderr) jsonOutput := flags.Bool("json", false, "write JSON output") if err := flags.Parse(args); err != nil { return err } output := versionOutput{ ClientVersion: mxgateway.ClientVersion, GatewayProtocolVersion: mxgateway.GatewayProtocolVersion, WorkerProtocolVersion: mxgateway.WorkerProtocolVersion, } if *jsonOutput { return writeJSON(stdout, output) } fmt.Fprintf(stdout, "mxgw-go %s\n", output.ClientVersion) fmt.Fprintf(stdout, "gateway protocol %d\n", output.GatewayProtocolVersion) fmt.Fprintf(stdout, "worker protocol %d\n", output.WorkerProtocolVersion) return nil } func runOpenSession(ctx context.Context, args []string, stdout, stderr io.Writer) error { flags := flag.NewFlagSet("open-session", flag.ContinueOnError) flags.SetOutput(stderr) common := bindCommonFlags(flags) jsonOutput := flags.Bool("json", false, "write JSON output") clientName := flags.String("client-session-name", "", "client session name") backend := flags.String("backend", "", "requested backend") if err := flags.Parse(args); err != nil { return err } client, options, err := dialForCommand(ctx, common) if err != nil { return err } defer client.Close() reply, err := client.OpenSessionRaw(ctx, (&mxgateway.OpenSessionOptions{ RequestedBackend: *backend, ClientSessionName: *clientName, }).Request()) if err != nil { return err } if *jsonOutput { return writeJSON(stdout, openSessionOutput{ Command: "open-session", Options: options, Reply: mustMarshalProto(reply), }) } fmt.Fprintln(stdout, reply.GetSessionId()) return nil } func runCloseSession(ctx context.Context, args []string, stdout, stderr io.Writer) error { flags := flag.NewFlagSet("close-session", flag.ContinueOnError) flags.SetOutput(stderr) common := bindCommonFlags(flags) jsonOutput := flags.Bool("json", false, "write JSON output") sessionID := flags.String("session-id", "", "gateway session id") if err := flags.Parse(args); err != nil { return err } if *sessionID == "" { return errors.New("session-id is required") } client, options, err := dialForCommand(ctx, common) if err != nil { return err } defer client.Close() reply, err := client.CloseSessionRaw(ctx, &mxgateway.CloseSessionRequest{SessionId: *sessionID}) if err != nil { return err } if *jsonOutput { return writeJSON(stdout, commandReplyOutput{ Command: "close-session", Options: options, Reply: mustMarshalProto(reply), }) } fmt.Fprintln(stdout, reply.GetFinalState()) return nil } func runRegister(ctx context.Context, args []string, stdout, stderr io.Writer) error { flags := flag.NewFlagSet("register", flag.ContinueOnError) flags.SetOutput(stderr) common := bindCommonFlags(flags) jsonOutput := flags.Bool("json", false, "write JSON output") sessionID := flags.String("session-id", "", "gateway session id") clientName := flags.String("client-name", "", "MXAccess client name") if err := flags.Parse(args); err != nil { return err } if *sessionID == "" || *clientName == "" { return errors.New("session-id and client-name are required") } client, options, err := dialForCommand(ctx, common) if err != nil { return err } defer client.Close() session := mxgateway.NewSessionForID(client, *sessionID) reply, err := session.RegisterRaw(ctx, *clientName) return writeCommandOutput(stdout, *jsonOutput, "register", options, reply, err) } func runAddItem(ctx context.Context, args []string, stdout, stderr io.Writer) error { flags := flag.NewFlagSet("add-item", flag.ContinueOnError) flags.SetOutput(stderr) common := bindCommonFlags(flags) jsonOutput := flags.Bool("json", false, "write JSON output") sessionID := flags.String("session-id", "", "gateway session id") serverHandle := flags.Int("server-handle", 0, "MXAccess server handle") item := flags.String("item", "", "item definition") if err := flags.Parse(args); err != nil { return err } if *sessionID == "" || *item == "" { return errors.New("session-id and item are required") } client, options, err := dialForCommand(ctx, common) if err != nil { return err } defer client.Close() session := mxgateway.NewSessionForID(client, *sessionID) reply, err := session.AddItemRaw(ctx, int32(*serverHandle), *item) return writeCommandOutput(stdout, *jsonOutput, "add-item", options, reply, err) } func runAdvise(ctx context.Context, args []string, stdout, stderr io.Writer) error { flags := flag.NewFlagSet("advise", flag.ContinueOnError) flags.SetOutput(stderr) common := bindCommonFlags(flags) jsonOutput := flags.Bool("json", false, "write JSON output") sessionID := flags.String("session-id", "", "gateway session id") serverHandle := flags.Int("server-handle", 0, "MXAccess server handle") itemHandle := flags.Int("item-handle", 0, "MXAccess item handle") if err := flags.Parse(args); err != nil { return err } if *sessionID == "" { return errors.New("session-id is required") } client, options, err := dialForCommand(ctx, common) if err != nil { return err } defer client.Close() session := mxgateway.NewSessionForID(client, *sessionID) reply, err := session.AdviseRaw(ctx, int32(*serverHandle), int32(*itemHandle)) return writeCommandOutput(stdout, *jsonOutput, "advise", options, reply, err) } func runWrite(ctx context.Context, args []string, stdout, stderr io.Writer) error { flags := flag.NewFlagSet("write", flag.ContinueOnError) flags.SetOutput(stderr) common := bindCommonFlags(flags) jsonOutput := flags.Bool("json", false, "write JSON output") sessionID := flags.String("session-id", "", "gateway session id") serverHandle := flags.Int("server-handle", 0, "MXAccess server handle") itemHandle := flags.Int("item-handle", 0, "MXAccess item handle") valueType := flags.String("type", "string", "value type: bool, int32, int64, float, double, string") valueText := flags.String("value", "", "value text") userID := flags.Int("user-id", 0, "MXAccess user id") if err := flags.Parse(args); err != nil { return err } if *sessionID == "" { return errors.New("session-id is required") } value, err := parseValue(*valueType, *valueText) if err != nil { return err } client, options, err := dialForCommand(ctx, common) if err != nil { return err } defer client.Close() session := mxgateway.NewSessionForID(client, *sessionID) reply, err := session.WriteRaw(ctx, int32(*serverHandle), int32(*itemHandle), value, int32(*userID)) return writeCommandOutput(stdout, *jsonOutput, "write", options, reply, err) } func runStreamEvents(ctx context.Context, args []string, stdout, stderr io.Writer) error { flags := flag.NewFlagSet("stream-events", flag.ContinueOnError) flags.SetOutput(stderr) common := bindCommonFlags(flags) jsonOutput := flags.Bool("json", false, "write JSON output") sessionID := flags.String("session-id", "", "gateway session id") after := flags.Uint64("after-worker-sequence", 0, "first worker sequence to read after") limit := flags.Int("limit", 0, "maximum events to read; 0 means unbounded") if err := flags.Parse(args); err != nil { return err } if *sessionID == "" { return errors.New("session-id is required") } client, _, err := dialForCommand(ctx, common) if err != nil { return err } defer client.Close() session := mxgateway.NewSessionForID(client, *sessionID) streamCtx, cancelStream := context.WithCancel(ctx) defer cancelStream() events, err := session.EventsAfter(streamCtx, *after) if err != nil { return err } count := 0 for result := range events { if result.Err != nil { return result.Err } if *jsonOutput { fmt.Fprintln(stdout, string(mustMarshalProto(result.Event))) } else { fmt.Fprintf(stdout, "%d %s\n", result.Event.GetWorkerSequence(), result.Event.GetFamily()) } count++ if *limit > 0 && count >= *limit { cancelStream() return nil } } return nil } func runSmoke(ctx context.Context, args []string, stdout, stderr io.Writer) error { flags := flag.NewFlagSet("smoke", flag.ContinueOnError) flags.SetOutput(stderr) common := bindCommonFlags(flags) jsonOutput := flags.Bool("json", false, "write JSON output") clientName := flags.String("client-name", "mxgw-go-smoke", "MXAccess client name") item := flags.String("item", "", "item definition") if err := flags.Parse(args); err != nil { return err } if *item == "" { return errors.New("item is required") } client, options, err := dialForCommand(ctx, common) if err != nil { return err } defer client.Close() session, err := client.OpenSession(ctx, mxgateway.OpenSessionOptions{ClientSessionName: *clientName}) if err != nil { return err } defer session.Close(context.Background()) serverHandle, err := session.Register(ctx, *clientName) if err != nil { return err } itemHandle, err := session.AddItem(ctx, serverHandle, *item) if err != nil { return err } if err := session.Advise(ctx, serverHandle, itemHandle); err != nil { return err } output := map[string]any{ "command": "smoke", "options": options, "sessionId": session.ID(), "serverHandle": serverHandle, "itemHandle": itemHandle, } if *jsonOutput { return writeJSON(stdout, output) } fmt.Fprintf(stdout, "session=%s server=%d item=%d\n", session.ID(), serverHandle, itemHandle) return nil } func bindCommonFlags(flags *flag.FlagSet) *commonOptions { common := &commonOptions{} flags.StringVar(&common.Endpoint, "endpoint", "localhost:5000", "gateway endpoint") flags.StringVar(&common.APIKey, "api-key", "", "gateway API key") flags.StringVar(&common.APIKeyEnv, "api-key-env", "MXGATEWAY_API_KEY", "environment variable containing the API key") flags.BoolVar(&common.Plaintext, "plaintext", false, "use plaintext transport") flags.StringVar(&common.CACertFile, "ca-cert", "", "CA certificate file") flags.StringVar(&common.ServerName, "server-name-override", "", "TLS server name override") flags.StringVar(&common.CallTimeout, "call-timeout", "30s", "per-call timeout") return common } func dialForCommand(ctx context.Context, common *commonOptions) (*mxgateway.Client, commonOptions, error) { options, err := common.resolved() if err != nil { return nil, options, err } client, err := mxgateway.Dial(ctx, mxgateway.Options{ Endpoint: options.Endpoint, APIKey: options.apiKeyValue, Plaintext: options.Plaintext, CACertFile: options.CACertFile, ServerNameOverride: options.ServerName, CallTimeout: options.timeout, }) return client, options, err } func (o *commonOptions) resolved() (commonOptions, error) { resolved := *o if resolved.APIKey == "" && resolved.APIKeyEnv != "" { resolved.apiKeyValue = os.Getenv(resolved.APIKeyEnv) } else { resolved.apiKeyValue = resolved.APIKey } resolved.APIKey = mxgateway.RedactAPIKey(resolved.apiKeyValue) if resolved.CallTimeout != "" { timeout, err := time.ParseDuration(resolved.CallTimeout) if err != nil { return resolved, err } resolved.timeout = timeout } return resolved, nil } func parseValue(valueType, valueText string) (*mxgateway.MxValue, error) { switch valueType { case "bool": value, err := strconv.ParseBool(valueText) if err != nil { return nil, err } return mxgateway.BoolValue(value), nil case "int32": value, err := strconv.ParseInt(valueText, 10, 32) if err != nil { return nil, err } return mxgateway.Int32Value(int32(value)), nil case "int64": value, err := strconv.ParseInt(valueText, 10, 64) if err != nil { return nil, err } return mxgateway.Int64Value(value), nil case "float": value, err := strconv.ParseFloat(valueText, 32) if err != nil { return nil, err } return mxgateway.FloatValue(float32(value)), nil case "double": value, err := strconv.ParseFloat(valueText, 64) if err != nil { return nil, err } return mxgateway.DoubleValue(value), nil case "string": return mxgateway.StringValue(valueText), nil default: return nil, fmt.Errorf("unsupported value type %q", valueType) } } func writeCommandOutput(stdout io.Writer, jsonOutput bool, command string, options commonOptions, reply *mxgateway.MxCommandReply, err error) error { if err != nil { return err } if jsonOutput { return writeJSON(stdout, commandReplyOutput{ Command: command, Options: options, Reply: mustMarshalProto(reply), }) } fmt.Fprintln(stdout, reply.GetKind()) return nil } func writeJSON(writer io.Writer, value any) error { encoder := json.NewEncoder(writer) encoder.SetIndent("", " ") return encoder.Encode(value) } func mustMarshalProto(message protojsonMessage) json.RawMessage { data, err := protojson.MarshalOptions{UseProtoNames: false}.Marshal(message) if err != nil { panic(err) } return data } type protojsonMessage interface { ProtoReflect() protoreflect.Message } func writeUsage(writer io.Writer) { fmt.Fprintln(writer, "usage: mxgw-go ") }