package mxgateway

import (
	"context"
	"crypto/tls"
	"errors"
	"net"
	"path/filepath"
	"reflect"
	"strings"
	"testing"
	"time"

	pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"
)

// --- Client.Go-008: resolveTransportCredentials precedence -----------------

// TestResolveTransportCredentialsPrecedence covers every branch of
// resolveTransportCredentials, which previously only had the Plaintext path
// exercised.
func TestResolveTransportCredentialsPrecedence(t *testing.T) {
	t.Run("TransportCredentialsWins", func(t *testing.T) {
		// The sentinel must have an identity of its own. insecure.NewCredentials
		// returns a zero-size value type, so any insecure credential (including
		// the one the Plaintext branch builds) compares equal to it and the
		// check below could never fail. credentials.NewTLS returns a pointer,
		// making != a true identity comparison.
		custom := credentials.NewTLS(&tls.Config{MinVersion: tls.VersionTLS12})
		creds, err := resolveTransportCredentials(Options{
			TransportCredentials: custom,
			Plaintext:            true, // must be ignored
		})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if creds != custom {
			t.Fatal("expected the explicit TransportCredentials to be returned as-is")
		}
	})
	t.Run("Plaintext", func(t *testing.T) {
		creds, err := resolveTransportCredentials(Options{Plaintext: true})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		// Compare against grpc's own insecure credentials rather than a
		// hard-coded protocol name.
		want := insecure.NewCredentials().Info().SecurityProtocol
		if got := creds.Info().SecurityProtocol; got != want {
			t.Fatalf("expected insecure credentials, got security protocol %q", got)
		}
	})
	t.Run("CACertFileMissingErrors", func(t *testing.T) {
		// A path inside a fresh temp dir is guaranteed not to exist, unlike a
		// bare relative name that depends on the working directory.
		missing := filepath.Join(t.TempDir(), "does-not-exist.pem")
		_, err := resolveTransportCredentials(Options{CACertFile: missing})
		if err == nil {
			t.Fatal("expected an error for a missing CA cert file")
		}
	})
	t.Run("TLSConfigWithServerNameOverride", func(t *testing.T) {
		creds, err := resolveTransportCredentials(Options{
			TLSConfig:          &tls.Config{MinVersion: tls.VersionTLS13},
			ServerNameOverride: "gateway.internal",
		})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got := creds.Info().ServerName; got != "gateway.internal" {
			t.Fatalf("expected ServerName override to be applied, got %q", got)
		}
	})
	t.Run("DefaultTLSFloor", func(t *testing.T) {
		creds, err := resolveTransportCredentials(Options{ServerNameOverride: "host"})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got := creds.Info().SecurityProtocol; got != "tls" {
			t.Fatalf("expected the default TLS credentials, got %q", got)
		}
	})
}

// TestResolveTransportCredentialsDoesNotMutateTLSConfig confirms the supplied
// TLSConfig is cloned, not mutated, when ServerNameOverride is applied.
func TestResolveTransportCredentialsDoesNotMutateTLSConfig(t *testing.T) {
	cfg := &tls.Config{MinVersion: tls.VersionTLS12}
	if _, err := resolveTransportCredentials(Options{
		TLSConfig:          cfg,
		ServerNameOverride: "override",
	}); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if cfg.ServerName != "" {
		t.Fatalf("resolveTransportCredentials mutated the caller's TLSConfig (ServerName=%q)", cfg.ServerName)
	}
}

// --- Client.Go-008: callContext deadline arithmetic ------------------------

// TestCallContextDeadlineArithmetic covers the shared callContext deadline
// logic, including the negative-timeout disable case and the
// caller-deadline-is-sooner case.
func TestCallContextDeadlineArithmetic(t *testing.T) {
	t.Run("ZeroUsesDefault", func(t *testing.T) {
		ctx, cancel := callContext(context.Background(), 0)
		defer cancel()
		deadline, ok := ctx.Deadline()
		if !ok {
			t.Fatal("expected a deadline for the default timeout")
		}
		remaining := time.Until(deadline)
		if remaining <= 0 || remaining > defaultCallTimeout+time.Second {
			t.Fatalf("default deadline out of range: %v", remaining)
		}
	})
	t.Run("NegativeDisablesBound", func(t *testing.T) {
		// Use a cancelable but deadline-free base so the identity check below
		// is meaningful. Every context.Background() call compares equal to
		// every other, so against a Background base the check could not tell
		// "returned the caller context unchanged" apart from "returned a
		// fresh Background".
		base, baseCancel := context.WithCancel(context.Background())
		defer baseCancel()
		ctx, cancel := callContext(base, -1)
		defer cancel()
		if _, ok := ctx.Deadline(); ok {
			t.Fatal("a negative timeout must disable the deadline entirely")
		}
		if ctx != base {
			t.Fatal("a negative timeout must return the caller context unchanged")
		}
	})
	t.Run("PositiveAppliesTimeout", func(t *testing.T) {
		ctx, cancel := callContext(context.Background(), 5*time.Second)
		defer cancel()
		deadline, ok := ctx.Deadline()
		if !ok {
			t.Fatal("expected a deadline")
		}
		remaining := time.Until(deadline)
		if remaining <= 0 || remaining > 5*time.Second+time.Second {
			t.Fatalf("deadline out of range: %v", remaining)
		}
	})
	t.Run("CallerDeadlineSoonerIsKept", func(t *testing.T) {
		base, baseCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
		defer baseCancel()
		ctx, cancel := callContext(base, 30*time.Second)
		defer cancel()
		if ctx != base {
			t.Fatal("a caller deadline sooner than the timeout must be kept as-is")
		}
	})
	t.Run("CallerDeadlineLaterIsShortened", func(t *testing.T) {
		base, baseCancel := context.WithTimeout(context.Background(), time.Hour)
		defer baseCancel()
		ctx, cancel := callContext(base, time.Second)
		defer cancel()
		deadline, ok := ctx.Deadline()
		if !ok {
			t.Fatal("expected a deadline")
		}
		if remaining := time.Until(deadline); remaining > 2*time.Second {
			t.Fatalf("expected the shorter timeout to win, got %v remaining", remaining)
		}
	})
}

// --- Client.Go-008: NativeValue / NativeArray edge branches ----------------

// TestNativeValueEdgeKinds covers the array, raw-bytes, null, and
// nil-input branches of NativeValue.
func TestNativeValueEdgeKinds(t *testing.T) {
	t.Run("NilInput", func(t *testing.T) {
		got, err := NativeValue(nil)
		if err != nil || got != nil {
			t.Fatalf("NativeValue(nil) = (%v, %v), want (nil, nil)", got, err)
		}
	})
	t.Run("ExplicitNull", func(t *testing.T) {
		got, err := NativeValue(&pb.MxValue{IsNull: true})
		if err != nil || got != nil {
			t.Fatalf("NativeValue(null) = (%v, %v), want (nil, nil)", got, err)
		}
	})
	t.Run("RawBytes", func(t *testing.T) {
		raw := []byte{0x01, 0x02, 0x03}
		got, err := NativeValue(&pb.MxValue{Kind: &pb.MxValue_RawValue{RawValue: raw}})
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		gotBytes, ok := got.([]byte)
		if !ok || !reflect.DeepEqual(gotBytes, raw) {
			t.Fatalf("NativeValue raw = %v, want %v", got, raw)
		}
		// The result must be a copy, not aliasing the protobuf field.
		gotBytes[0] = 0xFF
		if raw[0] != 0x01 {
			t.Fatal("NativeValue raw result aliases the protobuf backing array")
		}
	})
	t.Run("ArrayValue", func(t *testing.T) {
		value := &pb.MxValue{Kind: &pb.MxValue_ArrayValue{
			ArrayValue: &pb.MxArray{Values: &pb.MxArray_Int32Values{
				Int32Values: &pb.Int32Array{Values: []int32{7, 8}},
			}},
		}}
		got, err := NativeValue(value)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if !reflect.DeepEqual(got, []int32{7, 8}) {
			t.Fatalf("NativeValue array = %v, want [7 8]", got)
		}
	})
}

// TestNativeArrayEdgeKinds covers the nil, raw-bytes, timestamp-with-nil, and
// unsupported-kind branches of NativeArray.
func TestNativeArrayEdgeKinds(t *testing.T) { t.Run("NilInput", func(t *testing.T) { got, err := NativeArray(nil) if err != nil || got != nil { t.Fatalf("NativeArray(nil) = (%v, %v), want (nil, nil)", got, err) } }) t.Run("RawValues", func(t *testing.T) { got, err := NativeArray(&pb.MxArray{Values: &pb.MxArray_RawValues{ RawValues: &pb.RawArray{Values: [][]byte{{0x0A}, {0x0B}}}, }}) if err != nil { t.Fatalf("unexpected error: %v", err) } want := [][]byte{{0x0A}, {0x0B}} if !reflect.DeepEqual(got, want) { t.Fatalf("NativeArray raw = %v, want %v", got, want) } }) t.Run("TimestampWithNilEntry", func(t *testing.T) { got, err := NativeArray(&pb.MxArray{Values: &pb.MxArray_TimestampValues{ TimestampValues: &pb.TimestampArray{Values: []*timestamppb.Timestamp{nil}}, }}) if err != nil { t.Fatalf("unexpected error: %v", err) } times, ok := got.([]time.Time) if !ok || len(times) != 1 || !times[0].IsZero() { t.Fatalf("NativeArray timestamp-with-nil = %v, want [zero-time]", got) } }) t.Run("UnsupportedKind", func(t *testing.T) { // An MxArray with no oneof set hits the default branch. _, err := NativeArray(&pb.MxArray{}) if err == nil { t.Fatal("expected an error for an MxArray with no values set") } if !strings.Contains(err.Error(), "unsupported array value kind") { t.Fatalf("unexpected error text: %v", err) } }) } // TestNativeValueUnsupportedKind covers the default branch of NativeValue. func TestNativeValueUnsupportedKind(t *testing.T) { // An MxValue with no oneof Kind set and IsNull false hits the default. 
_, err := NativeValue(&pb.MxValue{}) if err == nil { t.Fatal("expected an error for an MxValue with no kind set") } if !strings.Contains(err.Error(), "unsupported value kind") { t.Fatalf("unexpected error text: %v", err) } } // --- Client.Go-005: dial migration ----------------------------------------- // TestDialFailsFastWhenGatewayUnreachable confirms that after the migration to // grpc.NewClient the DialTimeout-bounded readiness probe still fails fast (and // wraps the failure in *GatewayError) when the gateway cannot be reached. func TestDialFailsFastWhenGatewayUnreachable(t *testing.T) { dialer := func(ctx context.Context, _ string) (net.Conn, error) { return nil, errors.New("connection refused") } start := time.Now() client, err := Dial(context.Background(), Options{ Endpoint: "passthrough:///unreachable", APIKey: "k", Plaintext: true, DialTimeout: 500 * time.Millisecond, DialOptions: []grpc.DialOption{grpc.WithContextDialer(dialer)}, }) elapsed := time.Since(start) if err == nil { client.Close() t.Fatal("expected Dial to fail for an unreachable gateway") } var gwErr *GatewayError if !errors.As(err, &gwErr) || gwErr.Op != "dial" { t.Fatalf("expected a *GatewayError with Op=dial, got %#v", err) } if elapsed > 5*time.Second { t.Fatalf("Dial did not honor DialTimeout: took %v", elapsed) } } // TestDialReadinessProbeReachesReady confirms the readiness probe succeeds // against a live (bufconn) gateway, i.e. the lazy grpc.NewClient connection is // driven to Ready before Dial returns. func TestDialReadinessProbeReachesReady(t *testing.T) { client, cleanup := newBufconnClient(t, &fakeGatewayServer{ openReply: &pb.OpenSessionReply{}, }) defer cleanup() if client == nil { t.Fatal("expected a connected client") } } // --- Client.Go-006: error taxonomy ---------------------------------------- // TestGatewayErrorCode confirms GatewayError.Code surfaces the wrapped gRPC // status code without the caller unwrapping it. 
func TestGatewayErrorCode(t *testing.T) { var nilErr *GatewayError if got := nilErr.Code(); got != codes.OK { t.Fatalf("nil GatewayError.Code() = %v, want OK", got) } gwErr := &GatewayError{Op: "invoke", Err: status.Error(codes.Unavailable, "down")} if got := gwErr.Code(); got != codes.Unavailable { t.Fatalf("GatewayError.Code() = %v, want Unavailable", got) } plain := &GatewayError{Op: "dial", Err: errors.New("boom")} if got := plain.Code(); got != codes.Unknown { t.Fatalf("GatewayError.Code() for a non-status error = %v, want Unknown", got) } } // TestIsTransient verifies the transient/permanent classification including // the unwrap-through-GatewayError path. func TestIsTransient(t *testing.T) { tests := []struct { name string err error want bool }{ {name: "nil", err: nil, want: false}, {name: "unavailable wrapped", err: &GatewayError{Op: "invoke", Err: status.Error(codes.Unavailable, "x")}, want: true}, {name: "deadline wrapped", err: &GatewayError{Op: "invoke", Err: status.Error(codes.DeadlineExceeded, "x")}, want: true}, {name: "resource exhausted", err: &GatewayError{Err: status.Error(codes.ResourceExhausted, "x")}, want: true}, {name: "unauthenticated permanent", err: &GatewayError{Err: status.Error(codes.Unauthenticated, "x")}, want: false}, {name: "invalid argument permanent", err: &GatewayError{Err: status.Error(codes.InvalidArgument, "x")}, want: false}, {name: "bare status unavailable", err: status.Error(codes.Unavailable, "x"), want: true}, {name: "plain error", err: errors.New("nope"), want: false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := IsTransient(tt.err); got != tt.want { t.Fatalf("IsTransient(%v) = %v, want %v", tt.err, got, tt.want) } }) } } // --- Client.Go-007: correlation id fallback -------------------------------- // TestNewCorrelationIDUsesRandEntropy confirms the happy path yields a // 32-hex-character id. 
func TestNewCorrelationIDUsesRandEntropy(t *testing.T) {
	id := newCorrelationID()
	if len(id) != 32 {
		t.Fatalf("expected a 32-char hex id, got %q (len %d)", id, len(id))
	}
	// A length check alone would accept any 32-character string. Also require
	// every character to be a hex digit (either case is accepted).
	nonHex := strings.IndexFunc(id, func(r rune) bool {
		return !strings.ContainsRune("0123456789abcdefABCDEF", r)
	})
	if nonHex >= 0 {
		t.Fatalf("expected a hex id, got %q (non-hex character at index %d)", id, nonHex)
	}
}

// TestNewCorrelationIDFallsBackOnRandFailure reproduces Client.Go-007: when
// crypto/rand fails, newCorrelationID must not return an empty string but a
// unique, non-empty fallback id so the command stays traceable.
//
// This test swaps the package-level randRead hook, so it must not run in
// parallel with other tests that generate correlation ids.
func TestNewCorrelationIDFallsBackOnRandFailure(t *testing.T) {
	original := randRead
	t.Cleanup(func() { randRead = original })
	randRead = func([]byte) (int, error) {
		return 0, errors.New("entropy unavailable")
	}

	first := newCorrelationID()
	second := newCorrelationID()
	if first == "" || second == "" {
		t.Fatal("newCorrelationID returned an empty id on rand failure")
	}
	if !strings.HasPrefix(first, "fallback-") {
		t.Fatalf("expected a fallback- prefixed id, got %q", first)
	}
	if first == second {
		t.Fatalf("fallback correlation ids must be unique, got %q twice", first)
	}
}