Files
mxaccessgw/clients/go/mxgateway/client_session_test.go

262 lines
7.3 KiB
Go

package mxgateway
import (
"context"
"errors"
"io"
"net"
"testing"
"time"
pb "gitea.dohertylan.com/dohertj2/mxaccessgw/clients/go/internal/generated"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/test/bufconn"
)
// bufSize is the buffer size (1024*1024) handed to bufconn.Listen when the
// in-memory test listener is created.
const bufSize = 1024 * 1024
func TestDialAttachesAuthMetadataToUnaryCalls(t *testing.T) {
fake := &fakeGatewayServer{
openReply: &pb.OpenSessionReply{
SessionId: "session-1",
GatewayProtocolVersion: GatewayProtocolVersion,
WorkerProtocolVersion: WorkerProtocolVersion,
ProtocolStatus: &pb.ProtocolStatus{
Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
},
},
}
client, cleanup := newBufconnClient(t, fake)
defer cleanup()
_, err := client.OpenSession(context.Background(), OpenSessionOptions{ClientSessionName: "fixture"})
if err != nil {
t.Fatalf("OpenSession() error = %v", err)
}
if got := fake.openAuth; got != "Bearer test-api-key" {
t.Fatalf("authorization metadata = %q, want %q", got, "Bearer test-api-key")
}
}
// TestStreamEventsAttachesAuthMetadataAndClosesOnCancellation subscribes to the
// fake gateway's event stream and verifies three things: the first event is
// delivered with worker sequence 1, the streaming call carried the bearer
// authorization metadata, and the events channel is closed (without yielding
// another item) once the caller's context is cancelled.
//
// Every blocking receive is bounded by a timeout so a broken stream fails the
// test quickly instead of hanging until the global test deadline.
func TestStreamEventsAttachesAuthMetadataAndClosesOnCancellation(t *testing.T) {
	const waitLimit = 2 * time.Second

	fake := &fakeGatewayServer{
		streamStarted: make(chan struct{}),
	}
	client, cleanup := newBufconnClient(t, fake)
	defer cleanup()

	session := NewSessionForID(client, "session-1")
	ctx, cancel := context.WithCancel(context.Background())
	// Deferred so the stream context is released even when an assertion below
	// aborts the test early; the explicit cancel() further down stays safe
	// because cancel functions may be called more than once.
	defer cancel()

	events, err := session.Events(ctx)
	if err != nil {
		t.Fatalf("Events() error = %v", err)
	}

	// Wait for the server-side handler to signal that it is running.
	select {
	case <-fake.streamStarted:
	case <-time.After(waitLimit):
		t.Fatal("stream handler did not start")
	}

	select {
	case first, ok := <-events:
		if !ok {
			t.Fatal("events channel closed before the first event")
		}
		if first.Err != nil {
			t.Fatalf("first event error = %v", first.Err)
		}
		if first.Event.GetWorkerSequence() != 1 {
			t.Fatalf("worker sequence = %d, want 1", first.Event.GetWorkerSequence())
		}
	case <-time.After(waitLimit):
		t.Fatal("first event was not delivered")
	}

	// streamAuth is written by the handler before it closes streamStarted, so
	// reading it after the receive above is ordered after the write.
	if got := fake.streamAuth; got != "Bearer test-api-key" {
		t.Fatalf("stream authorization metadata = %q, want %q", got, "Bearer test-api-key")
	}

	cancel()
	select {
	case _, ok := <-events:
		if ok {
			t.Fatal("events channel produced an extra item after cancellation")
		}
	case <-time.After(waitLimit):
		t.Fatal("events channel did not close after cancellation")
	}
}
func TestSessionHelpersBuildCommandsAndExposeRawReply(t *testing.T) {
fake := &fakeGatewayServer{
invokeReply: &pb.MxCommandReply{
SessionId: "session-1",
Kind: pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM2,
ProtocolStatus: &pb.ProtocolStatus{
Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
},
Payload: &pb.MxCommandReply_AddItem2{
AddItem2: &pb.AddItem2Reply{ItemHandle: 42},
},
},
}
client, cleanup := newBufconnClient(t, fake)
defer cleanup()
session := NewSessionForID(client, "session-1")
itemHandle, err := session.AddItem2(context.Background(), 12, "Area001.Pump001.Speed", "runtime")
if err != nil {
t.Fatalf("AddItem2() error = %v", err)
}
if itemHandle != 42 {
t.Fatalf("item handle = %d, want 42", itemHandle)
}
req := fake.invokeRequest
if req.GetSessionId() != "session-1" {
t.Fatalf("session id = %q, want session-1", req.GetSessionId())
}
if req.GetClientCorrelationId() == "" {
t.Fatal("client correlation id is empty")
}
if req.GetCommand().GetKind() != pb.MxCommandKind_MX_COMMAND_KIND_ADD_ITEM2 {
t.Fatalf("command kind = %s", req.GetCommand().GetKind())
}
if req.GetCommand().GetAddItem2().GetItemContext() != "runtime" {
t.Fatalf("item context = %q, want runtime", req.GetCommand().GetAddItem2().GetItemContext())
}
}
func TestInvokeReturnsTypedMxAccessErrorWithRawReply(t *testing.T) {
hresult := int32(-2147467259)
fake := &fakeGatewayServer{
invokeReply: &pb.MxCommandReply{
SessionId: "session-1",
Kind: pb.MxCommandKind_MX_COMMAND_KIND_ADVISE,
Hresult: &hresult,
DiagnosticMessage: "native failure",
ProtocolStatus: &pb.ProtocolStatus{Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_MXACCESS_FAILURE, Message: "MXAccess failed"},
Statuses: []*pb.MxStatusProxy{{Success: 0, DiagnosticText: "failed"}},
},
}
client, cleanup := newBufconnClient(t, fake)
defer cleanup()
session := NewSessionForID(client, "session-1")
err := session.Advise(context.Background(), 12, 34)
var mxErr *MxAccessError
if !errors.As(err, &mxErr) {
t.Fatalf("error %T does not support errors.As(*MxAccessError)", err)
}
if mxErr.Reply.GetHresult() != hresult {
t.Fatalf("raw reply HRESULT = %d, want %d", mxErr.Reply.GetHresult(), hresult)
}
var commandErr *CommandError
if !errors.As(err, &commandErr) {
t.Fatalf("error %T does not support errors.As(*CommandError)", err)
}
if commandErr.Reply.GetDiagnosticMessage() != "native failure" {
t.Fatalf("raw diagnostic = %q", commandErr.Reply.GetDiagnosticMessage())
}
}
// newBufconnClient serves fake on an in-memory bufconn listener and returns a
// Client dialed to it together with a cleanup function.
//
// The cleanup function closes the client, stops the gRPC server, closes the
// listener and then waits for the serving goroutine to exit, so that goroutine
// can never report through t after the test has finished. If Dial fails the
// server and listener are released before the test is aborted.
func newBufconnClient(t *testing.T, fake *fakeGatewayServer) (*Client, func()) {
	t.Helper()

	listener := bufconn.Listen(bufSize)
	server := grpc.NewServer()
	pb.RegisterMxAccessGatewayServer(server, fake)

	// served is closed when the Serve goroutine returns.
	served := make(chan struct{})
	go func() {
		defer close(served)
		if err := server.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
			t.Errorf("bufconn server failed: %v", err)
		}
	}()

	// shutdown tears down the server side and blocks until Serve has returned.
	shutdown := func() {
		server.Stop()
		listener.Close()
		<-served
	}

	// Route every dial through the in-memory listener; the address is unused.
	dialer := func(ctx context.Context, _ string) (net.Conn, error) {
		return listener.DialContext(ctx)
	}
	client, err := Dial(context.Background(), Options{
		Endpoint:  "bufnet",
		APIKey:    "test-api-key",
		Plaintext: true,
		DialOptions: []grpc.DialOption{
			grpc.WithContextDialer(dialer),
		},
	})
	if err != nil {
		// Do not leak the server goroutine or listener when setup fails.
		shutdown()
		t.Fatalf("Dial() error = %v", err)
	}

	return client, func() {
		client.Close()
		shutdown()
	}
}
// fakeGatewayServer is a test double for the MxAccessGateway gRPC service. Its
// handlers return canned replies and record what they observed so tests can
// assert on it afterwards.
type fakeGatewayServer struct {
pb.UnimplementedMxAccessGatewayServer
// openReply, when non-nil, is returned by OpenSession instead of its default reply.
openReply *pb.OpenSessionReply
// openAuth is the authorization metadata value recorded by OpenSession.
openAuth string
// streamAuth is the authorization metadata value recorded by StreamEvents.
streamAuth string
// streamStarted, when non-nil, is closed by StreamEvents once the handler is running.
streamStarted chan struct{}
// invokeReply, when non-nil, is returned by Invoke instead of its default reply.
invokeReply *pb.MxCommandReply
// invokeRequest is the request most recently received by Invoke.
invokeRequest *pb.MxCommandRequest
}
// OpenSession records the caller's authorization metadata and answers with the
// configured openReply, or with a default OK reply for "session-1" when none
// is configured.
func (s *fakeGatewayServer) OpenSession(ctx context.Context, req *pb.OpenSessionRequest) (*pb.OpenSessionReply, error) {
	s.openAuth = authorizationFromContext(ctx)

	reply := s.openReply
	if reply == nil {
		reply = &pb.OpenSessionReply{
			SessionId: "session-1",
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
		}
	}
	return reply, nil
}
// CloseSession always succeeds, echoing the request's session id back with an
// OK protocol status.
func (s *fakeGatewayServer) CloseSession(ctx context.Context, req *pb.CloseSessionRequest) (*pb.CloseSessionReply, error) {
	okStatus := &pb.ProtocolStatus{
		Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
	}
	reply := &pb.CloseSessionReply{
		SessionId:      req.GetSessionId(),
		ProtocolStatus: okStatus,
	}
	return reply, nil
}
// Invoke stores the incoming request for later inspection and answers with the
// configured invokeReply, or with a default OK reply that mirrors the request's
// session id and command kind when none is configured.
func (s *fakeGatewayServer) Invoke(ctx context.Context, req *pb.MxCommandRequest) (*pb.MxCommandReply, error) {
	s.invokeRequest = req

	reply := s.invokeReply
	if reply == nil {
		reply = &pb.MxCommandReply{
			SessionId: req.GetSessionId(),
			Kind:      req.GetCommand().GetKind(),
			ProtocolStatus: &pb.ProtocolStatus{
				Code: pb.ProtocolStatusCode_PROTOCOL_STATUS_CODE_OK,
			},
		}
	}
	return reply, nil
}
// StreamEvents records the stream's authorization metadata, signals
// streamStarted (when set) by closing it, sends a single data-change event with
// worker sequence 1, and then blocks until the stream context is done before
// returning io.EOF. A send failure is returned as-is.
func (s *fakeGatewayServer) StreamEvents(req *pb.StreamEventsRequest, stream grpc.ServerStreamingServer[pb.MxEvent]) error {
	streamCtx := stream.Context()
	s.streamAuth = authorizationFromContext(streamCtx)

	if s.streamStarted != nil {
		close(s.streamStarted)
	}

	event := &pb.MxEvent{
		SessionId:      req.GetSessionId(),
		Family:         pb.MxEventFamily_MX_EVENT_FAMILY_ON_DATA_CHANGE,
		WorkerSequence: 1,
	}
	if sendErr := stream.Send(event); sendErr != nil {
		return sendErr
	}

	// Hold the stream open until the caller cancels or the connection ends.
	<-streamCtx.Done()
	return io.EOF
}
// authorizationFromContext returns the first authorizationHeader value found in
// ctx's incoming gRPC metadata, or the empty string when the context carries no
// incoming metadata or the header is absent.
func authorizationFromContext(ctx context.Context) string {
	if md, ok := metadata.FromIncomingContext(ctx); ok {
		if values := md.Get(authorizationHeader); len(values) > 0 {
			return values[0]
		}
	}
	return ""
}