package mxgateway

import (
	"context"
	"errors"
	"io"
	"net"
	"testing"
	"time"

	pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/test/bufconn"
)

// bufSize is the buffer size, in bytes, passed to bufconn.Listen.
const bufSize = 1024 * 1024

// TestDialAttachesAuthMetadataToUnaryCalls verifies that a client created by
// Dial with an API key sends "Bearer <key>" as authorization metadata on the
// unary OpenSession call.
func TestDialAttachesAuthMetadataToUnaryCalls(t *testing.T) {
	fake := &fakeGatewayServer{
		openReply: &pb.OpenSessionReply{
			SessionId:              "session-1",
			GatewayProtocolVersion: GatewayProtocolVersion,
			WorkerProtocolVersion:  WorkerProtocolVersion,
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	_, err := client.OpenSession(context.Background(), OpenSessionOptions{ClientSessionName: "fixture"})
	if err != nil {
		t.Fatalf("OpenSession() error = %v", err)
	}
	if got := fake.openAuth; got != "Bearer test-api-key" {
		t.Fatalf("authorization metadata = %q, want %q", got, "Bearer test-api-key")
	}
}

// TestStreamEventsAttachesAuthMetadataAndClosesOnCancellation verifies that
// the event stream carries the authorization metadata, delivers the event
// sent by the fake server, and that the events channel is closed once the
// caller's context is cancelled.
func TestStreamEventsAttachesAuthMetadataAndClosesOnCancellation(t *testing.T) {
	// Upper bound for every blocking wait below, so that a broken stream
	// fails the test instead of hanging the whole test binary.
	const waitTimeout = 2 * time.Second

	fake := &fakeGatewayServer{
		streamStarted: make(chan struct{}),
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	ctx, cancel := context.WithCancel(context.Background())
	// Release the stream context on early t.Fatal exits as well; calling
	// cancel a second time further down is a no-op.
	defer cancel()

	events, err := session.Events(ctx)
	if err != nil {
		t.Fatalf("Events() error = %v", err)
	}

	select {
	case <-fake.streamStarted:
	case <-time.After(waitTimeout):
		t.Fatal("StreamEvents handler did not start")
	}

	select {
	case first, ok := <-events:
		if !ok {
			t.Fatal("events channel closed before the first event")
		}
		if first.Err != nil {
			t.Fatalf("first event error = %v", first.Err)
		}
		if first.Event.GetWorkerSequence() != 1 {
			t.Fatalf("worker sequence = %d, want 1", first.Event.GetWorkerSequence())
		}
	case <-time.After(waitTimeout):
		t.Fatal("first event did not arrive")
	}

	// The fake stores streamAuth before it closes streamStarted, so the
	// receive from streamStarted above orders this read after that write.
	if got := fake.streamAuth; got != "Bearer test-api-key" {
		t.Fatalf("stream authorization metadata = %q, want %q", got, "Bearer test-api-key")
	}

	cancel()
	select {
	case _, ok := <-events:
		if ok {
			t.Fatal("events channel produced an extra item after cancellation")
		}
	case <-time.After(waitTimeout):
		t.Fatal("events channel did not close after cancellation")
	}
}

// TestSessionHelpersBuildCommandsAndExposeRawReply verifies that
// Session.AddItem2 sends an ADD_ITEM2 command carrying the session id, a
// non-empty client correlation id and the item context, and that it returns
// the item handle from the reply.
func TestSessionHelpersBuildCommandsAndExposeRawReply(t *testing.T) {
	fake := &fakeGatewayServer{
		invokeReply: &pb.MxCommandReply{
			SessionId: "session-1",
			Kind:      pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM2,
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
			Payload: &pb.MxCommandReply_AddItem2{
				AddItem2: &pb.AddItem2Reply{ItemHandle: 42},
			},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	itemHandle, err := session.AddItem2(context.Background(), 12, "Area001.Pump001.Speed", "runtime")
	if err != nil {
		t.Fatalf("AddItem2() error = %v", err)
	}
	if itemHandle != 42 {
		t.Fatalf("item handle = %d, want 42", itemHandle)
	}

	req := fake.invokeRequest
	if req.GetSessionId() != "session-1" {
		t.Fatalf("session id = %q, want session-1", req.GetSessionId())
	}
	if req.GetClientCorrelationId() == "" {
		t.Fatal("client correlation id is empty")
	}
	if req.GetCommand().GetKind() != pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM2 {
		t.Fatalf("command kind = %s", req.GetCommand().GetKind())
	}
	if req.GetCommand().GetAddItem2().GetItemContext() != "runtime" {
		t.Fatalf("item context = %q, want runtime", req.GetCommand().GetAddItem2().GetItemContext())
	}
}

// TestInvokeReturnsTypedMxAccessErrorWithRawReply verifies that a reply with
// an MXACCESS_FAILURE protocol status surfaces as an error that matches both
// *MxAccessError and *CommandError via errors.As, each exposing the raw reply.
func TestInvokeReturnsTypedMxAccessErrorWithRawReply(t *testing.T) {
	hresult := int32(-2147467259)
	fake := &fakeGatewayServer{
		invokeReply: &pb.MxCommandReply{
			SessionId:         "session-1",
			Kind:              pb.MxCommandKind_MX_COMMAND_KIND_ADVISE,
			Hresult:           &hresult,
			DiagnosticMessage: "native failure",
			ProtocolStatus:    &pb.ProtocolStatus{Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_MXACCESS_FAILURE, Message: "MXAccess failed"},
			Statuses:          []*pb.MxStatusProxy{{Success: 0, DiagnosticText: "failed"}},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	err := session.Advise(context.Background(), 12, 34)

	var mxErr *MxAccessError
	if !errors.As(err, &mxErr) {
		t.Fatalf("error %T does not support errors.As(*MxAccessError)", err)
	}
	if mxErr.Reply.GetHresult() != hresult {
		t.Fatalf("raw reply HRESULT = %d, want %d", mxErr.Reply.GetHresult(), hresult)
	}

	var commandErr *CommandError
	if !errors.As(err, &commandErr) {
		t.Fatalf("error %T does not support errors.As(*CommandError)", err)
	}
	if commandErr.Reply.GetDiagnosticMessage() != "native failure" {
		t.Fatalf("raw diagnostic = %q", commandErr.Reply.GetDiagnosticMessage())
	}
}

// newBufconnClient serves fake on an in-memory bufconn listener and returns a
// Client dialed against it with the API key "test-api-key" over plaintext.
//
// The returned function closes the client, stops the server, closes the
// listener and waits for the serving goroutine to exit. If Dial fails, the
// server side is shut down the same way before the test is failed.
func newBufconnClient(t *testing.T, fake *fakeGatewayServer) (*Client, func()) {
	t.Helper()

	listener := bufconn.Listen(bufSize)
	server := grpc.NewServer()
	pb.RegisterMxAccessGatewayServer(server, fake)

	// served is closed when the Serve goroutine returns. Waiting on it during
	// shutdown keeps that goroutine from outliving the test and from calling
	// t.Errorf after the test has finished.
	served := make(chan struct{})
	go func() {
		defer close(served)
		if err := server.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
			t.Errorf("bufconn server failed: %v", err)
		}
	}()

	shutdownServer := func() {
		server.Stop()
		listener.Close()
		<-served
	}

	dialer := func(ctx context.Context, _ string) (net.Conn, error) {
		return listener.DialContext(ctx)
	}
	client, err := Dial(context.Background(), Options{
		Endpoint:  "bufnet",
		APIKey:    "test-api-key",
		Plaintext: true,
		DialOptions: []grpc.DialOption{
			grpc.WithContextDialer(dialer),
		},
	})
	if err != nil {
		// t.Fatalf ends the test immediately, so release the server side first.
		shutdownServer()
		t.Fatalf("Dial() error = %v", err)
	}

	return client, func() {
		client.Close()
		shutdownServer()
	}
}

// fakeGatewayServer is an in-process MxAccessGateway implementation that
// records what it receives and returns canned replies.
type fakeGatewayServer struct {
	pb.UnimplementedMxAccessGatewayServer

	// openReply, when non-nil, is returned verbatim by OpenSession.
	openReply *pb.OpenSessionReply
	// openAuth is the authorization metadata seen by the last OpenSession call.
	openAuth string
	// streamAuth is the authorization metadata seen by the last StreamEvents call.
	streamAuth string
	// streamStarted, when non-nil, is closed by StreamEvents after streamAuth
	// has been recorded.
	streamStarted chan struct{}
	// invokeReply, when non-nil, is returned verbatim by Invoke.
	invokeReply *pb.MxCommandReply
	// invokeRequest is the request passed to the last Invoke call.
	invokeRequest *pb.MxCommandRequest
}

// OpenSession records the caller's authorization metadata and returns
// openReply if set, otherwise a default OK reply for "session-1".
func (s *fakeGatewayServer) OpenSession(ctx context.Context, req *pb.OpenSessionRequest) (*pb.OpenSessionReply, error) {
	s.openAuth = authorizationFromContext(ctx)
	if s.openReply != nil {
		return s.openReply, nil
	}
	return &pb.OpenSessionReply{
		SessionId: "session-1",
		ProtocolStatus: &pb.ProtocolStatus{
			Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
		},
	}, nil
}

// CloseSession returns an OK reply that echoes the request's session id.
func (s *fakeGatewayServer) CloseSession(ctx context.Context, req *pb.CloseSessionRequest) (*pb.CloseSessionReply, error) {
	return &pb.CloseSessionReply{
		SessionId: req.GetSessionId(),
		ProtocolStatus: &pb.ProtocolStatus{
			Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
		},
	}, nil
}

// Invoke records the request and returns invokeReply if set, otherwise an OK
// reply that echoes the request's session id and command kind.
func (s *fakeGatewayServer) Invoke(ctx context.Context, req *pb.MxCommandRequest) (*pb.MxCommandReply, error) {
	s.invokeRequest = req
	if s.invokeReply != nil {
		return s.invokeReply, nil
	}
	return &pb.MxCommandReply{
		SessionId: req.GetSessionId(),
		Kind:      req.GetCommand().GetKind(),
		ProtocolStatus: &pb.ProtocolStatus{
			Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
		},
	}, nil
}

// StreamEvents records the caller's authorization metadata, signals
// streamStarted, sends a single ON_DATA_CHANGE event with worker sequence 1,
// and then blocks until the stream context is done.
func (s *fakeGatewayServer) StreamEvents(req *pb.StreamEventsRequest, stream grpc.ServerStreamingServer[pb.MxEvent]) error {
	// Record the metadata before signalling, so a reader that waits on
	// streamStarted observes the stored value.
	s.streamAuth = authorizationFromContext(stream.Context())
	if s.streamStarted != nil {
		close(s.streamStarted)
	}

	if err := stream.Send(&pb.MxEvent{
		SessionId:      req.GetSessionId(),
		Family:         pb.MxEventFamily_MX_EVENT_FAMILY_ON_DATA_CHANGE,
		WorkerSequence: 1,
	}); err != nil {
		return err
	}

	<-stream.Context().Done()
	return io.EOF
}

// authorizationFromContext returns the first authorizationHeader value from
// the incoming gRPC metadata on ctx, or "" when there is no incoming metadata
// or the header is absent.
func authorizationFromContext(ctx context.Context) string {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return ""
	}
	values := md.Get(authorizationHeader)
	if len(values) == 0 {
		return ""
	}
	return values[0]
}