package mxgateway

import (
	"context"
	"errors"
	"io"
	"net"
	"testing"
	"time"

	pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/test/bufconn"
)

const (
	// bufSize is the buffer size of the in-memory bufconn listener.
	bufSize = 1024 * 1024

	// testWaitTimeout bounds every blocking channel wait in these tests so a
	// misbehaving client or fake server fails the test instead of hanging it.
	testWaitTimeout = 2 * time.Second
)

// TestDialAttachesAuthMetadataToUnaryCalls verifies that the API key given to
// Dial reaches the server as a bearer "authorization" header on a unary RPC.
func TestDialAttachesAuthMetadataToUnaryCalls(t *testing.T) {
	fake := &fakeGatewayServer{
		openReply: &pb.OpenSessionReply{
			SessionId:              "session-1",
			GatewayProtocolVersion: GatewayProtocolVersion,
			WorkerProtocolVersion:  WorkerProtocolVersion,
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	_, err := client.OpenSession(context.Background(), OpenSessionOptions{ClientSessionName: "fixture"})
	if err != nil {
		t.Fatalf("OpenSession() error = %v", err)
	}

	if got := fake.openAuth; got != "Bearer test-api-key" {
		t.Fatalf("authorization metadata = %q, want %q", got, "Bearer test-api-key")
	}
}

// TestStreamEventsAttachesAuthMetadataAndClosesOnCancellation verifies that
// the event stream carries the bearer header, delivers the first event, and
// that the events channel closes (without extra items) once ctx is cancelled.
func TestStreamEventsAttachesAuthMetadataAndClosesOnCancellation(t *testing.T) {
	fake := &fakeGatewayServer{
		streamStarted: make(chan struct{}),
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	ctx, cancel := context.WithCancel(context.Background())
	// cancel is invoked explicitly below to exercise cancellation; the deferred
	// call only guarantees the context is released when an earlier t.Fatal
	// fires. Calling cancel more than once is harmless.
	defer cancel()

	events, err := session.Events(ctx)
	if err != nil {
		t.Fatalf("Events() error = %v", err)
	}

	awaitTestSignal(t, fake.streamStarted, "event stream did not start")

	select {
	case first, ok := <-events:
		if !ok {
			t.Fatal("events channel closed before the first event")
		}
		if first.Err != nil {
			t.Fatalf("first event error = %v", first.Err)
		}
		if first.Event.GetWorkerSequence() != 1 {
			t.Fatalf("worker sequence = %d, want 1", first.Event.GetWorkerSequence())
		}
	case <-time.After(testWaitTimeout):
		t.Fatal("timed out waiting for the first event")
	}

	if got := fake.streamAuth; got != "Bearer test-api-key" {
		t.Fatalf("stream authorization metadata = %q, want %q", got, "Bearer test-api-key")
	}

	cancel()

	select {
	case _, ok := <-events:
		if ok {
			t.Fatal("events channel produced an extra item after cancellation")
		}
	case <-time.After(testWaitTimeout):
		t.Fatal("events channel did not close after cancellation")
	}
}

// TestEventSubscriptionCloseStopsStream verifies that closing a subscription
// terminates the server-side stream handler and closes the subscription's
// events channel.
func TestEventSubscriptionCloseStopsStream(t *testing.T) {
	fake := &fakeGatewayServer{
		streamStarted: make(chan struct{}),
		streamDone:    make(chan struct{}),
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	subscription, err := session.SubscribeEvents(context.Background())
	if err != nil {
		t.Fatalf("SubscribeEvents() error = %v", err)
	}

	awaitTestSignal(t, fake.streamStarted, "event stream did not start")

	select {
	case first, ok := <-subscription.Events():
		if !ok {
			t.Fatal("subscription channel closed before the first event")
		}
		if first.Err != nil {
			t.Fatalf("first event error = %v", first.Err)
		}
	case <-time.After(testWaitTimeout):
		t.Fatal("timed out waiting for the first event")
	}

	subscription.Close()

	awaitTestSignal(t, fake.streamDone, "event stream did not stop after subscription close")

	select {
	case _, ok := <-subscription.Events():
		if ok {
			t.Fatal("subscription channel remained open after close")
		}
	case <-time.After(testWaitTimeout):
		t.Fatal("subscription channel did not close")
	}
}

// TestEventsAfterCancelsStreamWhenCompatibilityChannelIsAbandoned sends more
// events than the caller ever reads and verifies that the server-side stream
// still stops and the result channel is eventually closed.
func TestEventsAfterCancelsStreamWhenCompatibilityChannelIsAbandoned(t *testing.T) {
	fake := &fakeGatewayServer{
		streamStarted:    make(chan struct{}),
		streamDone:       make(chan struct{}),
		streamEventCount: 64,
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	events, err := session.EventsAfter(context.Background(), 0)
	if err != nil {
		t.Fatalf("EventsAfter() error = %v", err)
	}

	awaitTestSignal(t, fake.streamStarted, "compatibility event stream did not start")
	awaitTestSignal(t, fake.streamDone, "compatibility event stream did not stop after result channel filled")

	// Drain whatever was buffered; the test passes once the channel closes.
	for {
		select {
		case _, ok := <-events:
			if !ok {
				return
			}
		case <-time.After(testWaitTimeout):
			t.Fatal("compatibility event channel did not close")
		}
	}
}

// TestSessionHelpersBuildCommandsAndExposeRawReply verifies that AddItem2
// builds an ADD_ITEM2 command (session id, correlation id, item context) and
// returns the item handle from the reply payload.
func TestSessionHelpersBuildCommandsAndExposeRawReply(t *testing.T) {
	fake := &fakeGatewayServer{
		invokeReply: &pb.MxCommandReply{
			SessionId: "session-1",
			Kind:      pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM2,
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
			Payload: &pb.MxCommandReply_AddItem2{
				AddItem2: &pb.AddItem2Reply{ItemHandle: 42},
			},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	itemHandle, err := session.AddItem2(context.Background(), 12, "Area001.Pump001.Speed", "runtime")
	if err != nil {
		t.Fatalf("AddItem2() error = %v", err)
	}
	if itemHandle != 42 {
		t.Fatalf("item handle = %d, want 42", itemHandle)
	}

	req := fake.invokeRequest
	if req.GetSessionId() != "session-1" {
		t.Fatalf("session id = %q, want session-1", req.GetSessionId())
	}
	if req.GetClientCorrelationId() == "" {
		t.Fatal("client correlation id is empty")
	}
	if req.GetCommand().GetKind() != pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM2 {
		t.Fatalf("command kind = %s", req.GetCommand().GetKind())
	}
	if req.GetCommand().GetAddItem2().GetItemContext() != "runtime" {
		t.Fatalf("item context = %q, want runtime", req.GetCommand().GetAddItem2().GetItemContext())
	}
}

// TestSubscribeBulkBuildsOneBulkCommandAndReturnsResults verifies that
// SubscribeBulk issues a single SUBSCRIBE_BULK command carrying the tag
// addresses and returns the per-tag results from the reply.
func TestSubscribeBulkBuildsOneBulkCommandAndReturnsResults(t *testing.T) {
	fake := &fakeGatewayServer{
		invokeReply: &pb.MxCommandReply{
			SessionId: "session-1",
			Kind:      pb.MxCommandKind_MX_COMMAND_KIND_SUBSCRIBE_BULK,
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
			Payload: &pb.MxCommandReply_SubscribeBulk{
				SubscribeBulk: &pb.BulkSubscribeReply{
					Results: []*pb.SubscribeResult{
						{
							ServerHandle:  12,
							TagAddress:    "Area001.Pump001.Speed",
							ItemHandle:    34,
							WasSuccessful: true,
						},
					},
				},
			},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	results, err := session.SubscribeBulk(context.Background(), 12, []string{"Area001.Pump001.Speed"})
	if err != nil {
		t.Fatalf("SubscribeBulk() error = %v", err)
	}
	if len(results) != 1 || results[0].GetItemHandle() != 34 {
		t.Fatalf("results = %#v, want item handle 34", results)
	}

	req := fake.invokeRequest
	if req.GetCommand().GetKind() != pb.MxCommandKind_MX_COMMAND_KIND_SUBSCRIBE_BULK {
		t.Fatalf("command kind = %s", req.GetCommand().GetKind())
	}
	if got := req.GetCommand().GetSubscribeBulk().GetTagAddresses(); len(got) != 1 || got[0] != "Area001.Pump001.Speed" {
		t.Fatalf("tag addresses = %#v", got)
	}
}

// TestWriteBulkBuildsOneBulkCommandAndReturnsPerEntryResults verifies that
// WriteBulk sends all entries in a single WRITE_BULK command and returns one
// result per entry.
func TestWriteBulkBuildsOneBulkCommandAndReturnsPerEntryResults(t *testing.T) {
	fake := &fakeGatewayServer{
		invokeReply: &pb.MxCommandReply{
			SessionId: "session-1",
			Kind:      pb.MxCommandKind_MX_COMMAND_KIND_WRITE_BULK,
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
			Payload: &pb.MxCommandReply_WriteBulk{
				WriteBulk: &pb.BulkWriteReply{
					Results: []*pb.BulkWriteResult{
						{ItemHandle: 10, WasSuccessful: true},
						{ItemHandle: 11, WasSuccessful: true},
					},
				},
			},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	entries := []*WriteBulkEntry{
		{ItemHandle: 10, Value: Int32Value(7), UserId: 100},
		{ItemHandle: 11, Value: Int32Value(8), UserId: 100},
	}
	results, err := session.WriteBulk(context.Background(), 12, entries)
	if err != nil {
		t.Fatalf("WriteBulk() error = %v", err)
	}
	if len(results) != 2 {
		t.Fatalf("results len = %d, want 2", len(results))
	}

	req := fake.invokeRequest
	if req.GetCommand().GetKind() != pb.MxCommandKind_MX_COMMAND_KIND_WRITE_BULK {
		t.Fatalf("command kind = %s", req.GetCommand().GetKind())
	}
	if got := req.GetCommand().GetWriteBulk().GetEntries(); len(got) != 2 {
		t.Fatalf("entry count = %d, want 2", len(got))
	}
}

// TestWriteBulkRejectsNilEntries verifies that every bulk helper returns an
// error when given a nil slice.
func TestWriteBulkRejectsNilEntries(t *testing.T) {
	fake := &fakeGatewayServer{}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	if _, err := session.WriteBulk(context.Background(), 12, nil); err == nil {
		t.Fatal("WriteBulk(nil) returned no error")
	}
	if _, err := session.Write2Bulk(context.Background(), 12, nil); err == nil {
		t.Fatal("Write2Bulk(nil) returned no error")
	}
	if _, err := session.WriteSecuredBulk(context.Background(), 12, nil); err == nil {
		t.Fatal("WriteSecuredBulk(nil) returned no error")
	}
	if _, err := session.WriteSecured2Bulk(context.Background(), 12, nil); err == nil {
		t.Fatal("WriteSecured2Bulk(nil) returned no error")
	}
	if _, err := session.ReadBulk(context.Background(), 12, nil, 0); err == nil {
		t.Fatal("ReadBulk(nil) returned no error")
	}
}

// TestBulkMethodsShortCircuitOnEmptySliceWithoutRoundTrip verifies that each
// bulk helper returns an empty result and no error for an empty (non-nil)
// slice, without ever calling Invoke on the server.
func TestBulkMethodsShortCircuitOnEmptySliceWithoutRoundTrip(t *testing.T) {
	fake := &fakeGatewayServer{
		invokeReply: &pb.MxCommandReply{
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")

	results, err := session.WriteBulk(context.Background(), 12, []*WriteBulkEntry{})
	if err != nil {
		t.Fatalf("WriteBulk(empty) error = %v", err)
	}
	if len(results) != 0 {
		t.Fatalf("WriteBulk(empty) results len = %d, want 0", len(results))
	}
	if fake.invokeRequest != nil {
		t.Fatal("WriteBulk(empty) sent a round trip; expected short-circuit")
	}

	results2, err := session.Write2Bulk(context.Background(), 12, []*Write2BulkEntry{})
	if err != nil {
		t.Fatalf("Write2Bulk(empty) error = %v", err)
	}
	if len(results2) != 0 {
		t.Fatalf("Write2Bulk(empty) results len = %d, want 0", len(results2))
	}
	if fake.invokeRequest != nil {
		t.Fatal("Write2Bulk(empty) sent a round trip; expected short-circuit")
	}

	results3, err := session.WriteSecuredBulk(context.Background(), 12, []*WriteSecuredBulkEntry{})
	if err != nil {
		t.Fatalf("WriteSecuredBulk(empty) error = %v", err)
	}
	if len(results3) != 0 {
		t.Fatalf("WriteSecuredBulk(empty) results len = %d, want 0", len(results3))
	}
	if fake.invokeRequest != nil {
		t.Fatal("WriteSecuredBulk(empty) sent a round trip; expected short-circuit")
	}

	results4, err := session.WriteSecured2Bulk(context.Background(), 12, []*WriteSecured2BulkEntry{})
	if err != nil {
		t.Fatalf("WriteSecured2Bulk(empty) error = %v", err)
	}
	if len(results4) != 0 {
		t.Fatalf("WriteSecured2Bulk(empty) results len = %d, want 0", len(results4))
	}
	if fake.invokeRequest != nil {
		t.Fatal("WriteSecured2Bulk(empty) sent a round trip; expected short-circuit")
	}

	readResults, err := session.ReadBulk(context.Background(), 12, []string{}, 0)
	if err != nil {
		t.Fatalf("ReadBulk(empty) error = %v", err)
	}
	if len(readResults) != 0 {
		t.Fatalf("ReadBulk(empty) results len = %d, want 0", len(readResults))
	}
	if fake.invokeRequest != nil {
		t.Fatal("ReadBulk(empty) sent a round trip; expected short-circuit")
	}
}

// TestReadBulkForwardsTimeoutAndUnpacksCachedFlag verifies that ReadBulk
// converts its timeout to milliseconds on the wire and surfaces the per-result
// WasCached flag from the reply.
func TestReadBulkForwardsTimeoutAndUnpacksCachedFlag(t *testing.T) {
	fake := &fakeGatewayServer{
		invokeReply: &pb.MxCommandReply{
			SessionId: "session-1",
			Kind:      pb.MxCommandKind_MX_COMMAND_KIND_READ_BULK,
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
			Payload: &pb.MxCommandReply_ReadBulk{
				ReadBulk: &pb.BulkReadReply{
					Results: []*pb.BulkReadResult{
						{TagAddress: "Tank01.Level", WasSuccessful: true, WasCached: true},
						{TagAddress: "Tank02.Level", WasSuccessful: true, WasCached: false},
					},
				},
			},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	results, err := session.ReadBulk(context.Background(), 12, []string{"Tank01.Level", "Tank02.Level"}, 250*time.Millisecond)
	if err != nil {
		t.Fatalf("ReadBulk() error = %v", err)
	}
	if len(results) != 2 {
		t.Fatalf("results len = %d, want 2", len(results))
	}
	if !results[0].GetWasCached() || results[1].GetWasCached() {
		t.Fatalf("WasCached flags = [%v %v], want [true false]", results[0].GetWasCached(), results[1].GetWasCached())
	}

	req := fake.invokeRequest
	if req.GetCommand().GetKind() != pb.MxCommandKind_MX_COMMAND_KIND_READ_BULK {
		t.Fatalf("command kind = %s", req.GetCommand().GetKind())
	}
	if got := req.GetCommand().GetReadBulk().GetTimeoutMs(); got != 250 {
		t.Fatalf("timeout ms = %d, want 250", got)
	}
}

// TestReadBulkSaturatesTimeoutAboveMaxUint32 verifies that a timeout whose
// millisecond value does not fit in uint32 is clamped to MaxUint32 on the wire.
func TestReadBulkSaturatesTimeoutAboveMaxUint32(t *testing.T) {
	fake := &fakeGatewayServer{
		invokeReply: &pb.MxCommandReply{
			SessionId: "session-1",
			Kind:      pb.MxCommandKind_MX_COMMAND_KIND_READ_BULK,
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	// 100 days in milliseconds exceeds MaxUint32 (~49.7 days).
	hugeTimeout := 100 * 24 * time.Hour
	_, err := session.ReadBulk(context.Background(), 12, []string{"Tank01.Level"}, hugeTimeout)
	if err != nil {
		t.Fatalf("ReadBulk() error = %v", err)
	}

	got := fake.invokeRequest.GetCommand().GetReadBulk().GetTimeoutMs()
	if got != ^uint32(0) {
		t.Fatalf("timeout ms = %d, want %d (MaxUint32)", got, ^uint32(0))
	}
}

// TestInvokeReturnsTypedMxAccessErrorWithRawReply verifies that an
// MXACCESS_FAILURE reply surfaces as an error that unwraps (via errors.As) to
// both *MxAccessError and *CommandError, each exposing the raw reply.
func TestInvokeReturnsTypedMxAccessErrorWithRawReply(t *testing.T) {
	hresult := int32(-2147467259)
	fake := &fakeGatewayServer{
		invokeReply: &pb.MxCommandReply{
			SessionId:         "session-1",
			Kind:              pb.MxCommandKind_MX_COMMAND_KIND_ADVISE,
			Hresult:           &hresult,
			DiagnosticMessage: "native failure",
			ProtocolStatus:    &pb.ProtocolStatus{Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_MXACCESS_FAILURE, Message: "MXAccess failed"},
			Statuses:          []*pb.MxStatusProxy{{Success: 0, DiagnosticText: "failed"}},
		},
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	err := session.Advise(context.Background(), 12, 34)

	var mxErr *MxAccessError
	if !errors.As(err, &mxErr) {
		t.Fatalf("error %T does not support errors.As(*MxAccessError)", err)
	}
	if mxErr.Reply.GetHresult() != hresult {
		t.Fatalf("raw reply HRESULT = %d, want %d", mxErr.Reply.GetHresult(), hresult)
	}

	var commandErr *CommandError
	if !errors.As(err, &commandErr) {
		t.Fatalf("error %T does not support errors.As(*CommandError)", err)
	}
	if commandErr.Reply.GetDiagnosticMessage() != "native failure" {
		t.Fatalf("raw diagnostic = %q", commandErr.Reply.GetDiagnosticMessage())
	}
}

// awaitTestSignal blocks until ch yields a value or is closed. If neither
// happens within testWaitTimeout, it fails the test with failMsg.
func awaitTestSignal(t *testing.T, ch <-chan struct{}, failMsg string) {
	t.Helper()
	select {
	case <-ch:
	case <-time.After(testWaitTimeout):
		t.Fatal(failMsg)
	}
}

// newBufconnClient serves fake on an in-memory bufconn listener. It returns a
// Client dialed against that listener in plaintext with API key
// "test-api-key".
//
// The returned cleanup func closes the client, stops the server, closes the
// listener, and waits for the Serve goroutine to exit. If Dial fails, the same
// server teardown runs before the test is failed.
func newBufconnClient(t *testing.T, fake *fakeGatewayServer) (*Client, func()) {
	t.Helper()

	listener := bufconn.Listen(bufSize)
	server := grpc.NewServer()
	pb.RegisterMxAccessGatewayServer(server, fake)

	serveDone := make(chan struct{})
	go func() {
		defer close(serveDone)
		if err := server.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
			t.Errorf("bufconn server failed: %v", err)
		}
	}()

	// shutdownServer stops the server and waits for Serve to return. This
	// ensures the goroutine above cannot report through t after the test ends.
	shutdownServer := func() {
		server.Stop()
		listener.Close()
		<-serveDone
	}

	dialer := func(ctx context.Context, _ string) (net.Conn, error) {
		return listener.DialContext(ctx)
	}
	client, err := Dial(context.Background(), Options{
		Endpoint:  "bufnet",
		APIKey:    "test-api-key",
		Plaintext: true,
		DialOptions: []grpc.DialOption{
			grpc.WithContextDialer(dialer),
		},
	})
	if err != nil {
		shutdownServer()
		t.Fatalf("Dial() error = %v", err)
	}

	return client, func() {
		client.Close()
		shutdownServer()
	}
}

// fakeGatewayServer is an in-process MxAccessGateway implementation. It
// returns canned replies and records what it received.
//
// The recorded fields (openAuth, streamAuth, invokeRequest) are written on
// server goroutines. The tests read them only after the corresponding RPC has
// returned, or after streamStarted has been closed.
type fakeGatewayServer struct {
	pb.UnimplementedMxAccessGatewayServer

	// OpenSession: openReply is returned when non-nil. openAuth records the
	// authorization header of the most recent call.
	openReply *pb.OpenSessionReply
	openAuth  string

	// StreamEvents:
	//   - streamAuth records the authorization header.
	//   - streamStarted, if non-nil, is closed right after streamAuth is set.
	//   - streamDone, if non-nil, is closed when the handler returns.
	//   - streamEventCount is the number of events sent before the handler
	//     blocks; 0 is treated as 1.
	streamAuth       string
	streamStarted    chan struct{}
	streamDone       chan struct{}
	streamEventCount int

	// Invoke: invokeReply is returned when non-nil. invokeRequest records the
	// most recent request, and stays nil if Invoke was never called.
	invokeReply   *pb.MxCommandReply
	invokeRequest *pb.MxCommandRequest
}

// OpenSession records the caller's authorization header. It returns
// s.openReply if set, otherwise an OK reply for "session-1".
func (s *fakeGatewayServer) OpenSession(ctx context.Context, req *pb.OpenSessionRequest) (*pb.OpenSessionReply, error) {
	s.openAuth = authorizationFromContext(ctx)
	if s.openReply != nil {
		return s.openReply, nil
	}
	return &pb.OpenSessionReply{
		SessionId: "session-1",
		ProtocolStatus: &pb.ProtocolStatus{
			Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
		},
	}, nil
}

// CloseSession always succeeds, echoing the requested session id.
func (s *fakeGatewayServer) CloseSession(ctx context.Context, req *pb.CloseSessionRequest) (*pb.CloseSessionReply, error) {
	return &pb.CloseSessionReply{
		SessionId: req.GetSessionId(),
		ProtocolStatus: &pb.ProtocolStatus{
			Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
		},
	}, nil
}

// Invoke records req. It returns s.invokeReply if set, otherwise an OK reply
// echoing the request's session id and command kind.
func (s *fakeGatewayServer) Invoke(ctx context.Context, req *pb.MxCommandRequest) (*pb.MxCommandReply, error) {
	s.invokeRequest = req
	if s.invokeReply != nil {
		return s.invokeReply, nil
	}
	return &pb.MxCommandReply{
		SessionId: req.GetSessionId(),
		Kind:      req.GetCommand().GetKind(),
		ProtocolStatus: &pb.ProtocolStatus{
			Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
		},
	}, nil
}

// StreamEvents records the authorization header and signals streamStarted. It
// then sends eventCount ON_DATA_CHANGE events with worker sequences
// 1..eventCount, and blocks until the stream context is done (client
// cancellation or server shutdown).
func (s *fakeGatewayServer) StreamEvents(req *pb.StreamEventsRequest, stream grpc.ServerStreamingServer[pb.MxEvent]) error {
	s.streamAuth = authorizationFromContext(stream.Context())
	if s.streamDone != nil {
		defer close(s.streamDone)
	}
	// Closed only after streamAuth is recorded, so tests that wait on
	// streamStarted observe the header value.
	if s.streamStarted != nil {
		close(s.streamStarted)
	}

	eventCount := s.streamEventCount
	if eventCount == 0 {
		eventCount = 1
	}
	for sequence := 1; sequence <= eventCount; sequence++ {
		if err := stream.Send(&pb.MxEvent{
			SessionId:      req.GetSessionId(),
			Family:         pb.MxEventFamily_MX_EVENT_FAMILY_ON_DATA_CHANGE,
			WorkerSequence: uint64(sequence),
		}); err != nil {
			return err
		}
	}

	<-stream.Context().Done()
	return io.EOF
}

// authorizationFromContext returns the first incoming value of the
// authorizationHeader metadata key, or "" when there is none.
func authorizationFromContext(ctx context.Context) string {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return ""
	}
	values := md.Get(authorizationHeader)
	if len(values) == 0 {
		return ""
	}
	return values[0]
}