package mxgateway

import (
	"context"
	"errors"
	"io"
	"net"
	"testing"
	"time"

	pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/test/bufconn"
)

// bufSize is the in-memory buffer size of the bufconn listener that hosts the
// fake gateway in these tests.
const bufSize = 1024 * 1024

// TestDialAttachesAuthMetadataToUnaryCalls verifies that a unary RPC issued
// through a Dial-ed client carries the configured API key as a bearer token
// in the "authorization" metadata seen by the server.
func TestDialAttachesAuthMetadataToUnaryCalls(t *testing.T) {
	fake := &fakeGatewayServer{
		openReply: &pb.OpenSessionReply{
			SessionId:              "session-1",
			GatewayProtocolVersion: GatewayProtocolVersion,
			WorkerProtocolVersion:  WorkerProtocolVersion,
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	_, err := client.OpenSession(context.Background(), OpenSessionOptions{ClientSessionName: "fixture"})
	if err != nil {
		t.Fatalf("OpenSession() error = %v", err)
	}
	if got := fake.openAuth; got != "Bearer test-api-key" {
		t.Fatalf("authorization metadata = %q, want %q", got, "Bearer test-api-key")
	}
}

// TestStreamEventsAttachesAuthMetadataAndClosesOnCancellation verifies that the
// server-streaming events call carries the bearer token, delivers the first
// event, and that the events channel closes without yielding another item
// once the caller cancels its context.
func TestStreamEventsAttachesAuthMetadataAndClosesOnCancellation(t *testing.T) {
	fake := &fakeGatewayServer{
		streamStarted: make(chan struct{}),
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	ctx, cancel := context.WithCancel(context.Background())
	// Deferred so the stream is torn down even if an assertion below fails
	// before the explicit cancel; a second cancel call is a no-op.
	defer cancel()

	events, err := session.Events(ctx)
	if err != nil {
		t.Fatalf("Events() error = %v", err)
	}

	// Bounded waits: a missing event should fail this test, not hang the
	// whole binary until the global `go test` timeout.
	select {
	case <-fake.streamStarted:
	case <-time.After(2 * time.Second):
		t.Fatal("event stream did not start")
	}

	select {
	case first, ok := <-events:
		if !ok {
			t.Fatal("events channel closed before delivering the first event")
		}
		if first.Err != nil {
			t.Fatalf("first event error = %v", first.Err)
		}
		if first.Event.GetWorkerSequence() != 1 {
			t.Fatalf("worker sequence = %d, want 1", first.Event.GetWorkerSequence())
		}
	case <-time.After(2 * time.Second):
		t.Fatal("first event was not delivered")
	}

	if got := fake.streamAuth; got != "Bearer test-api-key" {
		t.Fatalf("stream authorization metadata = %q, want %q", got, "Bearer test-api-key")
	}

	cancel()

	select {
	case _, ok := <-events:
		if ok {
			t.Fatal("events channel produced an extra item after cancellation")
		}
	case <-time.After(2 * time.Second):
		t.Fatal("events channel did not close after cancellation")
	}
}

// TestEventSubscriptionCloseStopsStream verifies that closing a subscription
// ends the server-side stream handler and closes the subscription's channel.
func TestEventSubscriptionCloseStopsStream(t *testing.T) {
	fake := &fakeGatewayServer{
		streamStarted: make(chan struct{}),
		streamDone:    make(chan struct{}),
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	subscription, err := session.SubscribeEvents(context.Background())
	if err != nil {
		t.Fatalf("SubscribeEvents() error = %v", err)
	}

	select {
	case <-fake.streamStarted:
	case <-time.After(2 * time.Second):
		t.Fatal("event stream did not start")
	}

	select {
	case first, ok := <-subscription.Events():
		if !ok {
			t.Fatal("subscription channel closed before delivering the first event")
		}
		if first.Err != nil {
			t.Fatalf("first event error = %v", first.Err)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("first event was not delivered")
	}

	subscription.Close()

	select {
	case <-fake.streamDone:
	case <-time.After(2 * time.Second):
		t.Fatal("event stream did not stop after subscription close")
	}

	select {
	case _, ok := <-subscription.Events():
		if ok {
			t.Fatal("subscription channel remained open after close")
		}
	case <-time.After(2 * time.Second):
		t.Fatal("subscription channel did not close")
	}
}

// TestEventsAfterCancelsStreamWhenCompatibilityChannelIsAbandoned verifies that
// when nobody drains the channel returned by EventsAfter, the client stops the
// server stream once the buffer fills and then reports ErrEventBufferOverflow
// as the last result before closing the channel.
func TestEventsAfterCancelsStreamWhenCompatibilityChannelIsAbandoned(t *testing.T) {
	fake := &fakeGatewayServer{
		streamStarted:    make(chan struct{}),
		streamDone:       make(chan struct{}),
		streamEventCount: 256,
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	events, err := session.EventsAfter(context.Background(), 0)
	if err != nil {
		t.Fatalf("EventsAfter() error = %v", err)
	}

	select {
	case <-fake.streamStarted:
	case <-time.After(2 * time.Second):
		t.Fatal("compatibility event stream did not start")
	}

	select {
	case <-fake.streamDone:
	case <-time.After(2 * time.Second):
		t.Fatal("compatibility event stream did not stop after result channel filled")
	}

	// A slow consumer that abandons the buffer must still receive an explicit
	// terminal overflow error before the channel closes, so it can tell
	// "events dropped" apart from "stream ended normally".
	//
	// One timer bounds the whole drain, instead of allocating a fresh
	// time.After per buffered item.
	deadline := time.NewTimer(2 * time.Second)
	defer deadline.Stop()

	var sawOverflow bool
	for {
		select {
		case result, ok := <-events:
			if !ok {
				if !sawOverflow {
					t.Fatal("compatibility event channel closed without an ErrEventBufferOverflow result")
				}
				return
			}
			// "Terminal" means nothing may follow the overflow error.
			if sawOverflow {
				t.Fatal("compatibility event channel delivered a result after the terminal overflow error")
			}
			if result.Err != nil {
				if !errors.Is(result.Err, ErrEventBufferOverflow) {
					t.Fatalf("terminal result error = %v, want ErrEventBufferOverflow", result.Err)
				}
				sawOverflow = true
			}
		case <-deadline.C:
			t.Fatal("compatibility event channel did not close")
		}
	}
}

// TestSessionHelpersBuildCommandsAndExposeRawReply verifies that AddItem2 sends
// an ADD_ITEM2 command for the session, with a non-empty correlation ID and
// the requested item context, and returns the item handle from the reply.
func TestSessionHelpersBuildCommandsAndExposeRawReply(t *testing.T) {
	fake := &fakeGatewayServer{
		invokeReply: &pb.MxCommandReply{
			SessionId: "session-1",
			Kind:      pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM2,
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
			Payload: &pb.MxCommandReply_AddItem2{
				AddItem2: &pb.AddItem2Reply{ItemHandle: 42},
			},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	itemHandle, err := session.AddItem2(context.Background(), 12, "Area001.Pump001.Speed", "runtime")
	if err != nil {
		t.Fatalf("AddItem2() error = %v", err)
	}
	if itemHandle != 42 {
		t.Fatalf("item handle = %d, want 42", itemHandle)
	}

	req := fake.invokeRequest
	if req.GetSessionId() != "session-1" {
		t.Fatalf("session id = %q, want session-1", req.GetSessionId())
	}
	if req.GetClientCorrelationId() == "" {
		t.Fatal("client correlation id is empty")
	}
	if req.GetCommand().GetKind() != pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM2 {
		t.Fatalf("command kind = %s", req.GetCommand().GetKind())
	}
	if req.GetCommand().GetAddItem2().GetItemContext() != "runtime" {
		t.Fatalf("item context = %q, want runtime", req.GetCommand().GetAddItem2().GetItemContext())
	}
}

// TestSubscribeBulkBuildsOneBulkCommandAndReturnsResults verifies that
// SubscribeBulk sends a single SUBSCRIBE_BULK command carrying every tag
// address and returns the per-tag results from the reply.
func TestSubscribeBulkBuildsOneBulkCommandAndReturnsResults(t *testing.T) {
	fake := &fakeGatewayServer{
		invokeReply: &pb.MxCommandReply{
			SessionId: "session-1",
			Kind:      pb.MxCommandKind_MX_COMMAND_KIND_SUBSCRIBE_BULK,
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
			Payload: &pb.MxCommandReply_SubscribeBulk{
				SubscribeBulk: &pb.BulkSubscribeReply{
					Results: []*pb.SubscribeResult{
						{
							ServerHandle:  12,
							TagAddress:    "Area001.Pump001.Speed",
							ItemHandle:    34,
							WasSuccessful: true,
						},
					},
				},
			},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	results, err := session.SubscribeBulk(context.Background(), 12, []string{"Area001.Pump001.Speed"})
	if err != nil {
		t.Fatalf("SubscribeBulk() error = %v", err)
	}
	if len(results) != 1 || results[0].GetItemHandle() != 34 {
		t.Fatalf("results = %#v, want item handle 34", results)
	}

	req := fake.invokeRequest
	if req.GetCommand().GetKind() != pb.MxCommandKind_MX_COMMAND_KIND_SUBSCRIBE_BULK {
		t.Fatalf("command kind = %s", req.GetCommand().GetKind())
	}
	if got := req.GetCommand().GetSubscribeBulk().GetTagAddresses(); len(got) != 1 || got[0] != "Area001.Pump001.Speed" {
		t.Fatalf("tag addresses = %#v", got)
	}
}

// TestWriteBulkBuildsOneBulkCommandAndReturnsPerEntryResults verifies that
// WriteBulk sends one WRITE_BULK command with every entry and surfaces the
// per-entry success/failure results from the reply.
func TestWriteBulkBuildsOneBulkCommandAndReturnsPerEntryResults(t *testing.T) {
	fake := &fakeGatewayServer{
		invokeReply: &pb.MxCommandReply{
			SessionId: "session-1",
			Kind:      pb.MxCommandKind_MX_COMMAND_KIND_WRITE_BULK,
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
			Payload: &pb.MxCommandReply_WriteBulk{
				WriteBulk: &pb.BulkWriteReply{
					Results: []*pb.BulkWriteResult{
						{ServerHandle: 12, ItemHandle: 901, WasSuccessful: true},
						{ServerHandle: 12, ItemHandle: 902, WasSuccessful: false, ErrorMessage: "invalid handle"},
					},
				},
			},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	results, err := session.WriteBulk(context.Background(), 12, []*pb.WriteBulkEntry{
		{ItemHandle: 901, UserId: 5, Value: &pb.MxValue{DataType: pb.MxDataType_MX_DATA_TYPE_INTEGER, Kind: &pb.MxValue_Int32Value{Int32Value: 11}}},
		{ItemHandle: 902, UserId: 5, Value: &pb.MxValue{DataType: pb.MxDataType_MX_DATA_TYPE_INTEGER, Kind: &pb.MxValue_Int32Value{Int32Value: 22}}},
	})
	if err != nil {
		t.Fatalf("WriteBulk() error = %v", err)
	}
	if len(results) != 2 || !results[0].GetWasSuccessful() || results[1].GetWasSuccessful() {
		t.Fatalf("results = %#v, want [success, failure]", results)
	}

	req := fake.invokeRequest
	if req.GetCommand().GetKind() != pb.MxCommandKind_MX_COMMAND_KIND_WRITE_BULK {
		t.Fatalf("command kind = %s", req.GetCommand().GetKind())
	}
	if got := req.GetCommand().GetWriteBulk().GetEntries(); len(got) != 2 {
		t.Fatalf("entries = %#v, want 2", got)
	}
}

// TestReadBulkForwardsTimeoutAndUnpacksCachedFlag verifies that ReadBulk sends
// its timeout as whole milliseconds and returns results with the WasCached
// flag and value intact.
func TestReadBulkForwardsTimeoutAndUnpacksCachedFlag(t *testing.T) {
	fake := &fakeGatewayServer{
		invokeReply: &pb.MxCommandReply{
			SessionId: "session-1",
			Kind:      pb.MxCommandKind_MX_COMMAND_KIND_READ_BULK,
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
			Payload: &pb.MxCommandReply_ReadBulk{
				ReadBulk: &pb.BulkReadReply{
					Results: []*pb.BulkReadResult{
						{
							ServerHandle:  12,
							TagAddress:    "Area001.Pump001.Speed",
							ItemHandle:    34,
							WasSuccessful: true,
							WasCached:     true,
							Value:         &pb.MxValue{DataType: pb.MxDataType_MX_DATA_TYPE_INTEGER, Kind: &pb.MxValue_Int32Value{Int32Value: 99}},
						},
					},
				},
			},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	results, err := session.ReadBulk(context.Background(), 12, []string{"Area001.Pump001.Speed"}, 750*time.Millisecond)
	if err != nil {
		t.Fatalf("ReadBulk() error = %v", err)
	}
	if len(results) != 1 || !results[0].GetWasCached() || results[0].GetValue().GetInt32Value() != 99 {
		t.Fatalf("results = %#v", results)
	}
	if got := fake.invokeRequest.GetCommand().GetReadBulk().GetTimeoutMs(); got != 750 {
		t.Fatalf("timeout_ms = %d, want 750", got)
	}
}

// TestInvokeReturnsTypedMxAccessErrorWithRawReply verifies that an
// MXACCESS_FAILURE reply surfaces as an error that unwraps, via errors.As, to
// both *MxAccessError and *CommandError, each exposing the raw reply.
func TestInvokeReturnsTypedMxAccessErrorWithRawReply(t *testing.T) {
	hresult := int32(-2147467259)
	fake := &fakeGatewayServer{
		invokeReply: &pb.MxCommandReply{
			SessionId:         "session-1",
			Kind:              pb.MxCommandKind_MX_COMMAND_KIND_ADVISE,
			Hresult:           &hresult,
			DiagnosticMessage: "native failure",
			ProtocolStatus:    &pb.ProtocolStatus{Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_MXACCESS_FAILURE, Message: "MXAccess failed"},
			Statuses:          []*pb.MxStatusProxy{{Success: 0, DiagnosticText: "failed"}},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	err := session.Advise(context.Background(), 12, 34)

	var mxErr *MxAccessError
	if !errors.As(err, &mxErr) {
		t.Fatalf("error %T does not support errors.As(*MxAccessError)", err)
	}
	if mxErr.Reply.GetHresult() != hresult {
		t.Fatalf("raw reply HRESULT = %d, want %d", mxErr.Reply.GetHresult(), hresult)
	}

	var commandErr *CommandError
	if !errors.As(err, &commandErr) {
		t.Fatalf("error %T does not support errors.As(*CommandError)", err)
	}
	if commandErr.Reply.GetDiagnosticMessage() != "native failure" {
		t.Fatalf("raw diagnostic = %q", commandErr.Reply.GetDiagnosticMessage())
	}
}

// newBufconnClient serves fake on an in-memory bufconn listener and returns a
// Client dialed against it in plaintext with API key "test-api-key".
//
// The returned cleanup func closes the client, stops the gRPC server and
// closes the listener; callers should defer it. If Dial fails, the server and
// listener are torn down before the test is failed.
func newBufconnClient(t *testing.T, fake *fakeGatewayServer) (*Client, func()) {
	t.Helper()

	listener := bufconn.Listen(bufSize)
	server := grpc.NewServer()
	pb.RegisterMxAccessGatewayServer(server, fake)
	go func() {
		if err := server.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
			t.Errorf("bufconn server failed: %v", err)
		}
	}()

	dialer := func(ctx context.Context, _ string) (net.Conn, error) {
		return listener.DialContext(ctx)
	}
	// grpc.NewClient defaults the target scheme to dns; the bufconn fake name
	// is not DNS-resolvable, so use the passthrough scheme to hand the target
	// straight to the context dialer.
	client, err := Dial(context.Background(), Options{
		Endpoint:  "passthrough:///bufnet",
		APIKey:    "test-api-key",
		Plaintext: true,
		DialOptions: []grpc.DialOption{
			grpc.WithContextDialer(dialer),
		},
	})
	if err != nil {
		// t.Fatalf exits the goroutine, so release the server and listener
		// first; otherwise the Serve goroutine outlives the test.
		server.Stop()
		listener.Close()
		t.Fatalf("Dial() error = %v", err)
	}

	return client, func() {
		client.Close()
		server.Stop()
		listener.Close()
	}
}

// fakeGatewayServer is an in-process MxAccessGateway implementation. It records
// what it receives and returns canned replies. Handlers write the recording
// fields; tests read them after the corresponding RPC has returned or after
// the relevant signal channel has closed.
type fakeGatewayServer struct {
	pb.UnimplementedMxAccessGatewayServer

	// openReply, if non-nil, is returned verbatim by OpenSession.
	openReply *pb.OpenSessionReply
	// openAuth is the authorization metadata seen by the last OpenSession.
	openAuth string
	// streamAuth is the authorization metadata seen by the last StreamEvents.
	streamAuth string
	// streamStarted, if non-nil, is closed when StreamEvents begins.
	// NOTE(review): closing twice panics, so each test expects one stream.
	streamStarted chan struct{}
	// streamDone, if non-nil, is closed when StreamEvents returns.
	streamDone chan struct{}
	// streamEventCount is the number of events StreamEvents sends; 0 means 1.
	streamEventCount int
	// invokeReply, if non-nil, is returned verbatim by Invoke.
	invokeReply *pb.MxCommandReply
	// invokeRequest is the last request received by Invoke.
	invokeRequest *pb.MxCommandRequest
}

// OpenSession records the caller's authorization metadata. It returns
// openReply when set, otherwise an OK reply for "session-1".
func (s *fakeGatewayServer) OpenSession(ctx context.Context, req *pb.OpenSessionRequest) (*pb.OpenSessionReply, error) {
	s.openAuth = authorizationFromContext(ctx)
	if s.openReply != nil {
		return s.openReply, nil
	}
	return &pb.OpenSessionReply{
		SessionId: "session-1",
		ProtocolStatus: &pb.ProtocolStatus{
			Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
		},
	}, nil
}

// CloseSession always succeeds, echoing the requested session ID.
func (s *fakeGatewayServer) CloseSession(ctx context.Context, req *pb.CloseSessionRequest) (*pb.CloseSessionReply, error) {
	return &pb.CloseSessionReply{
		SessionId: req.GetSessionId(),
		ProtocolStatus: &pb.ProtocolStatus{
			Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
		},
	}, nil
}

// Invoke records the request. It returns invokeReply when set, otherwise an OK
// reply echoing the request's session ID and command kind.
func (s *fakeGatewayServer) Invoke(ctx context.Context, req *pb.MxCommandRequest) (*pb.MxCommandReply, error) {
	s.invokeRequest = req
	if s.invokeReply != nil {
		return s.invokeReply, nil
	}
	return &pb.MxCommandReply{
		SessionId: req.GetSessionId(),
		Kind:      req.GetCommand().GetKind(),
		ProtocolStatus: &pb.ProtocolStatus{
			Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
		},
	}, nil
}

// StreamEvents records the caller's authorization metadata and signals
// streamStarted. It sends streamEventCount ON_DATA_CHANGE events (at least
// one) with WorkerSequence 1..n, then blocks until the stream context is
// done. streamDone is closed on every return path.
func (s *fakeGatewayServer) StreamEvents(req *pb.StreamEventsRequest, stream grpc.ServerStreamingServer[pb.MxEvent]) error {
	s.streamAuth = authorizationFromContext(stream.Context())
	if s.streamDone != nil {
		defer close(s.streamDone)
	}
	if s.streamStarted != nil {
		close(s.streamStarted)
	}

	eventCount := s.streamEventCount
	if eventCount == 0 {
		eventCount = 1
	}
	for sequence := 1; sequence <= eventCount; sequence++ {
		if err := stream.Send(&pb.MxEvent{
			SessionId:      req.GetSessionId(),
			Family:         pb.MxEventFamily_MX_EVENT_FAMILY_ON_DATA_CHANGE,
			WorkerSequence: uint64(sequence),
		}); err != nil {
			return err
		}
	}

	// Hold the stream open until the client cancels or disconnects.
	<-stream.Context().Done()
	return io.EOF
}

// authorizationFromContext returns the first value of the authorizationHeader
// key in the incoming gRPC metadata, or "" when there is none.
func authorizationFromContext(ctx context.Context) string {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return ""
	}
	values := md.Get(authorizationHeader)
	if len(values) == 0 {
		return ""
	}
	return values[0]
}