386 lines
11 KiB
Go
386 lines
11 KiB
Go
package mxgateway
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/test/bufconn"
|
|
)
|
|
|
|
// bufSize is the buffer capacity, in bytes, of the in-memory bufconn listener
// created for each test (1 MiB).
const bufSize = 1 << 20
|
|
|
|
func TestDialAttachesAuthMetadataToUnaryCalls(t *testing.T) {
|
|
fake := &fakeGatewayServer{
|
|
openReply: &pb.OpenSessionReply{
|
|
SessionId: "session-1",
|
|
GatewayProtocolVersion: GatewayProtocolVersion,
|
|
WorkerProtocolVersion: WorkerProtocolVersion,
|
|
ProtocolStatus: &pb.ProtocolStatus{
|
|
Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
|
|
},
|
|
},
|
|
}
|
|
client, cleanup := newBufconnClient(t, fake)
|
|
defer cleanup()
|
|
|
|
_, err := client.OpenSession(context.Background(), OpenSessionOptions{ClientSessionName: "fixture"})
|
|
if err != nil {
|
|
t.Fatalf("OpenSession() error = %v", err)
|
|
}
|
|
|
|
if got := fake.openAuth; got != "Bearer test-api-key" {
|
|
t.Fatalf("authorization metadata = %q, want %q", got, "Bearer test-api-key")
|
|
}
|
|
}
|
|
|
|
func TestStreamEventsAttachesAuthMetadataAndClosesOnCancellation(t *testing.T) {
|
|
fake := &fakeGatewayServer{
|
|
streamStarted: make(chan struct{}),
|
|
}
|
|
client, cleanup := newBufconnClient(t, fake)
|
|
defer cleanup()
|
|
session := NewSessionForID(client, "session-1")
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
events, err := session.Events(ctx)
|
|
if err != nil {
|
|
t.Fatalf("Events() error = %v", err)
|
|
}
|
|
<-fake.streamStarted
|
|
|
|
first := <-events
|
|
if first.Err != nil {
|
|
t.Fatalf("first event error = %v", first.Err)
|
|
}
|
|
if first.Event.GetWorkerSequence() != 1 {
|
|
t.Fatalf("worker sequence = %d, want 1", first.Event.GetWorkerSequence())
|
|
}
|
|
if got := fake.streamAuth; got != "Bearer test-api-key" {
|
|
t.Fatalf("stream authorization metadata = %q, want %q", got, "Bearer test-api-key")
|
|
}
|
|
|
|
cancel()
|
|
select {
|
|
case _, ok := <-events:
|
|
if ok {
|
|
t.Fatal("events channel produced an extra item after cancellation")
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("events channel did not close after cancellation")
|
|
}
|
|
}
|
|
|
|
func TestEventSubscriptionCloseStopsStream(t *testing.T) {
|
|
fake := &fakeGatewayServer{
|
|
streamStarted: make(chan struct{}),
|
|
streamDone: make(chan struct{}),
|
|
}
|
|
client, cleanup := newBufconnClient(t, fake)
|
|
defer cleanup()
|
|
session := NewSessionForID(client, "session-1")
|
|
|
|
subscription, err := session.SubscribeEvents(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("SubscribeEvents() error = %v", err)
|
|
}
|
|
<-fake.streamStarted
|
|
first := <-subscription.Events()
|
|
if first.Err != nil {
|
|
t.Fatalf("first event error = %v", first.Err)
|
|
}
|
|
|
|
subscription.Close()
|
|
|
|
select {
|
|
case <-fake.streamDone:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("event stream did not stop after subscription close")
|
|
}
|
|
select {
|
|
case _, ok := <-subscription.Events():
|
|
if ok {
|
|
t.Fatal("subscription channel remained open after close")
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("subscription channel did not close")
|
|
}
|
|
}
|
|
|
|
func TestEventsAfterCancelsStreamWhenCompatibilityChannelIsAbandoned(t *testing.T) {
|
|
fake := &fakeGatewayServer{
|
|
streamStarted: make(chan struct{}),
|
|
streamDone: make(chan struct{}),
|
|
streamEventCount: 64,
|
|
}
|
|
client, cleanup := newBufconnClient(t, fake)
|
|
defer cleanup()
|
|
session := NewSessionForID(client, "session-1")
|
|
|
|
events, err := session.EventsAfter(context.Background(), 0)
|
|
if err != nil {
|
|
t.Fatalf("EventsAfter() error = %v", err)
|
|
}
|
|
<-fake.streamStarted
|
|
|
|
select {
|
|
case <-fake.streamDone:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("compatibility event stream did not stop after result channel filled")
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case _, ok := <-events:
|
|
if !ok {
|
|
return
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("compatibility event channel did not close")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSessionHelpersBuildCommandsAndExposeRawReply(t *testing.T) {
|
|
fake := &fakeGatewayServer{
|
|
invokeReply: &pb.MxCommandReply{
|
|
SessionId: "session-1",
|
|
Kind: pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM2,
|
|
ProtocolStatus: &pb.ProtocolStatus{
|
|
Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
|
|
},
|
|
Payload: &pb.MxCommandReply_AddItem2{
|
|
AddItem2: &pb.AddItem2Reply{ItemHandle: 42},
|
|
},
|
|
},
|
|
}
|
|
client, cleanup := newBufconnClient(t, fake)
|
|
defer cleanup()
|
|
session := NewSessionForID(client, "session-1")
|
|
|
|
itemHandle, err := session.AddItem2(context.Background(), 12, "Area001.Pump001.Speed", "runtime")
|
|
if err != nil {
|
|
t.Fatalf("AddItem2() error = %v", err)
|
|
}
|
|
|
|
if itemHandle != 42 {
|
|
t.Fatalf("item handle = %d, want 42", itemHandle)
|
|
}
|
|
req := fake.invokeRequest
|
|
if req.GetSessionId() != "session-1" {
|
|
t.Fatalf("session id = %q, want session-1", req.GetSessionId())
|
|
}
|
|
if req.GetClientCorrelationId() == "" {
|
|
t.Fatal("client correlation id is empty")
|
|
}
|
|
if req.GetCommand().GetKind() != pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM2 {
|
|
t.Fatalf("command kind = %s", req.GetCommand().GetKind())
|
|
}
|
|
if req.GetCommand().GetAddItem2().GetItemContext() != "runtime" {
|
|
t.Fatalf("item context = %q, want runtime", req.GetCommand().GetAddItem2().GetItemContext())
|
|
}
|
|
}
|
|
|
|
func TestSubscribeBulkBuildsOneBulkCommandAndReturnsResults(t *testing.T) {
|
|
fake := &fakeGatewayServer{
|
|
invokeReply: &pb.MxCommandReply{
|
|
SessionId: "session-1",
|
|
Kind: pb.MxCommandKind_MX_COMMAND_KIND_SUBSCRIBE_BULK,
|
|
ProtocolStatus: &pb.ProtocolStatus{
|
|
Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
|
|
},
|
|
Payload: &pb.MxCommandReply_SubscribeBulk{
|
|
SubscribeBulk: &pb.BulkSubscribeReply{
|
|
Results: []*pb.SubscribeResult{
|
|
{
|
|
ServerHandle: 12,
|
|
TagAddress: "Area001.Pump001.Speed",
|
|
ItemHandle: 34,
|
|
WasSuccessful: true,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
client, cleanup := newBufconnClient(t, fake)
|
|
defer cleanup()
|
|
session := NewSessionForID(client, "session-1")
|
|
|
|
results, err := session.SubscribeBulk(context.Background(), 12, []string{"Area001.Pump001.Speed"})
|
|
if err != nil {
|
|
t.Fatalf("SubscribeBulk() error = %v", err)
|
|
}
|
|
|
|
if len(results) != 1 || results[0].GetItemHandle() != 34 {
|
|
t.Fatalf("results = %#v, want item handle 34", results)
|
|
}
|
|
req := fake.invokeRequest
|
|
if req.GetCommand().GetKind() != pb.MxCommandKind_MX_COMMAND_KIND_SUBSCRIBE_BULK {
|
|
t.Fatalf("command kind = %s", req.GetCommand().GetKind())
|
|
}
|
|
if got := req.GetCommand().GetSubscribeBulk().GetTagAddresses(); len(got) != 1 || got[0] != "Area001.Pump001.Speed" {
|
|
t.Fatalf("tag addresses = %#v", got)
|
|
}
|
|
}
|
|
|
|
func TestInvokeReturnsTypedMxAccessErrorWithRawReply(t *testing.T) {
|
|
hresult := int32(-2147467259)
|
|
fake := &fakeGatewayServer{
|
|
invokeReply: &pb.MxCommandReply{
|
|
SessionId: "session-1",
|
|
Kind: pb.MxCommandKind_MX_COMMAND_KIND_ADVISE,
|
|
Hresult: &hresult,
|
|
DiagnosticMessage: "native failure",
|
|
ProtocolStatus: &pb.ProtocolStatus{Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_MXACCESS_FAILURE, Message: "MXAccess failed"},
|
|
Statuses: []*pb.MxStatusProxy{{Success: 0, DiagnosticText: "failed"}},
|
|
},
|
|
}
|
|
client, cleanup := newBufconnClient(t, fake)
|
|
defer cleanup()
|
|
session := NewSessionForID(client, "session-1")
|
|
|
|
err := session.Advise(context.Background(), 12, 34)
|
|
|
|
var mxErr *MxAccessError
|
|
if !errors.As(err, &mxErr) {
|
|
t.Fatalf("error %T does not support errors.As(*MxAccessError)", err)
|
|
}
|
|
if mxErr.Reply.GetHresult() != hresult {
|
|
t.Fatalf("raw reply HRESULT = %d, want %d", mxErr.Reply.GetHresult(), hresult)
|
|
}
|
|
var commandErr *CommandError
|
|
if !errors.As(err, &commandErr) {
|
|
t.Fatalf("error %T does not support errors.As(*CommandError)", err)
|
|
}
|
|
if commandErr.Reply.GetDiagnosticMessage() != "native failure" {
|
|
t.Fatalf("raw diagnostic = %q", commandErr.Reply.GetDiagnosticMessage())
|
|
}
|
|
}
|
|
|
|
func newBufconnClient(t *testing.T, fake *fakeGatewayServer) (*Client, func()) {
|
|
t.Helper()
|
|
|
|
listener := bufconn.Listen(bufSize)
|
|
server := grpc.NewServer()
|
|
pb.RegisterMxAccessGatewayServer(server, fake)
|
|
go func() {
|
|
if err := server.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
|
|
t.Errorf("bufconn server failed: %v", err)
|
|
}
|
|
}()
|
|
|
|
dialer := func(ctx context.Context, _ string) (net.Conn, error) {
|
|
return listener.DialContext(ctx)
|
|
}
|
|
client, err := Dial(context.Background(), Options{
|
|
Endpoint: "bufnet",
|
|
APIKey: "test-api-key",
|
|
Plaintext: true,
|
|
DialOptions: []grpc.DialOption{
|
|
grpc.WithContextDialer(dialer),
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Dial() error = %v", err)
|
|
}
|
|
|
|
return client, func() {
|
|
client.Close()
|
|
server.Stop()
|
|
listener.Close()
|
|
}
|
|
}
|
|
|
|
type fakeGatewayServer struct {
|
|
pb.UnimplementedMxAccessGatewayServer
|
|
|
|
openReply *pb.OpenSessionReply
|
|
openAuth string
|
|
streamAuth string
|
|
streamStarted chan struct{}
|
|
streamDone chan struct{}
|
|
streamEventCount int
|
|
invokeReply *pb.MxCommandReply
|
|
invokeRequest *pb.MxCommandRequest
|
|
}
|
|
|
|
func (s *fakeGatewayServer) OpenSession(ctx context.Context, req *pb.OpenSessionRequest) (*pb.OpenSessionReply, error) {
|
|
s.openAuth = authorizationFromContext(ctx)
|
|
if s.openReply != nil {
|
|
return s.openReply, nil
|
|
}
|
|
return &pb.OpenSessionReply{
|
|
SessionId: "session-1",
|
|
ProtocolStatus: &pb.ProtocolStatus{
|
|
Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (s *fakeGatewayServer) CloseSession(ctx context.Context, req *pb.CloseSessionRequest) (*pb.CloseSessionReply, error) {
|
|
return &pb.CloseSessionReply{
|
|
SessionId: req.GetSessionId(),
|
|
ProtocolStatus: &pb.ProtocolStatus{
|
|
Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (s *fakeGatewayServer) Invoke(ctx context.Context, req *pb.MxCommandRequest) (*pb.MxCommandReply, error) {
|
|
s.invokeRequest = req
|
|
if s.invokeReply != nil {
|
|
return s.invokeReply, nil
|
|
}
|
|
return &pb.MxCommandReply{
|
|
SessionId: req.GetSessionId(),
|
|
Kind: req.GetCommand().GetKind(),
|
|
ProtocolStatus: &pb.ProtocolStatus{
|
|
Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (s *fakeGatewayServer) StreamEvents(req *pb.StreamEventsRequest, stream grpc.ServerStreamingServer[pb.MxEvent]) error {
|
|
s.streamAuth = authorizationFromContext(stream.Context())
|
|
if s.streamDone != nil {
|
|
defer close(s.streamDone)
|
|
}
|
|
if s.streamStarted != nil {
|
|
close(s.streamStarted)
|
|
}
|
|
eventCount := s.streamEventCount
|
|
if eventCount == 0 {
|
|
eventCount = 1
|
|
}
|
|
for sequence := 1; sequence <= eventCount; sequence++ {
|
|
if err := stream.Send(&pb.MxEvent{
|
|
SessionId: req.GetSessionId(),
|
|
Family: pb.MxEventFamily_MX_EVENT_FAMILY_ON_DATA_CHANGE,
|
|
WorkerSequence: uint64(sequence),
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
<-stream.Context().Done()
|
|
return io.EOF
|
|
}
|
|
|
|
func authorizationFromContext(ctx context.Context) string {
|
|
md, ok := metadata.FromIncomingContext(ctx)
|
|
if !ok {
|
|
return ""
|
|
}
|
|
values := md.Get(authorizationHeader)
|
|
if len(values) == 0 {
|
|
return ""
|
|
}
|
|
return values[0]
|
|
}
|