use std::pin::Pin; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; use std::task::{Context, Poll}; use std::time::Duration; use futures_core::Stream; use futures_util::StreamExt; use serde_json::Value; use tokio::net::TcpListener; use tokio::sync::{mpsc, Mutex}; use tokio_stream::wrappers::{ReceiverStream, TcpListenerStream}; use tonic::transport::Server; use tonic::{Request, Response, Status}; use zb_mom_ww_mxgateway_client::generated::mxaccess_gateway::v1::mx_access_gateway_server::{ MxAccessGateway, MxAccessGatewayServer, }; use zb_mom_ww_mxgateway_client::generated::mxaccess_gateway::v1::mx_command_reply; use zb_mom_ww_mxgateway_client::generated::mxaccess_gateway::v1::mx_value::Kind; use zb_mom_ww_mxgateway_client::generated::mxaccess_gateway::v1::{ AcknowledgeAlarmReply, AcknowledgeAlarmRequest, ActiveAlarmSnapshot, AddItemReply, AlarmFeedMessage, BulkSubscribeReply, CloseSessionReply, CloseSessionRequest, MxCommandKind, MxCommandReply, MxDataType, MxEvent, MxEventFamily, MxStatusCategory, MxStatusProxy, MxStatusSource, MxValue, OpenSessionReply, OpenSessionRequest, ProtocolStatus, ProtocolStatusCode, QueryActiveAlarmsRequest, SessionState, StreamAlarmsRequest, StreamEventsRequest, SubscribeResult, }; use zb_mom_ww_mxgateway_client::{ ApiKey, ClientOptions, CommandError, Error, GatewayClient, MxStatus, MxValue as ClientMxValue, MxValueProjection, }; #[tokio::test] async fn fake_server_receives_bearer_metadata_and_raw_client_is_reachable() { let state = Arc::new(FakeState::default()); let endpoint = spawn_fake_gateway(state.clone()).await; let mut client = GatewayClient::connect( ClientOptions::new(endpoint).with_api_key(ApiKey::new("mxgw_fixture_secret")), ) .await .unwrap(); let _raw = client.raw_client(); let session = client .open_session(OpenSessionRequest { client_session_name: "rust-test".to_owned(), ..OpenSessionRequest::default() }) .await .unwrap(); assert_eq!(session.id(), "session-fixture"); assert_eq!( state.authorization.lock().await.as_deref(), Some("Bearer mxgw_fixture_secret") ); } #[tokio::test] async fn session_helpers_build_commands_and_preserve_command_errors() { let state = Arc::new(FakeState::default()); let endpoint = spawn_fake_gateway(state.clone()).await; let client = GatewayClient::connect(ClientOptions::new(endpoint)) .await .unwrap(); let session = client.session("session-fixture"); let item_handle = session.add_item(12, "Plant.Area.Tag").await.unwrap(); assert_eq!(item_handle, 34); let last_command = state.last_command_kind.lock().await; assert_eq!(*last_command, Some(MxCommandKind::AddItem as i32)); drop(last_command); let error = session .write(12, 34, ClientMxValue::int32(123), 0) .await .unwrap_err(); let Error::Command(error) = error else { panic!("write failure should preserve the raw command reply: {error:?}"); }; assert_eq!( error.reply().protocol_status.as_ref().unwrap().code, ProtocolStatusCode::MxaccessFailure as i32 ); assert_eq!(error.reply().hresult, Some(-2147220992)); assert_eq!(error.reply().statuses.len(), 2); } #[tokio::test] async fn subscribe_bulk_builds_one_bulk_command_and_returns_results() { let state = Arc::new(FakeState::default()); let endpoint = spawn_fake_gateway(state.clone()).await; let client = GatewayClient::connect(ClientOptions::new(endpoint)) .await .unwrap(); let session = client.session("session-fixture"); let results = session .subscribe_bulk(12, vec!["Area001.Pump001.Speed".to_owned()]) .await .unwrap(); assert_eq!(results[0].item_handle, 34); let last_command = state.last_command_kind.lock().await; assert_eq!(*last_command, Some(MxCommandKind::SubscribeBulk as i32)); } #[tokio::test] async fn event_stream_preserves_order_and_drop_cancels_server_stream() { let state = Arc::new(FakeState::default()); let endpoint = spawn_fake_gateway(state.clone()).await; let client = GatewayClient::connect(ClientOptions::new(endpoint)) .await .unwrap(); let mut stream = client .stream_events(StreamEventsRequest { session_id: "session-fixture".to_owned(), after_worker_sequence: 0, }) .await .unwrap(); assert_eq!(stream.next().await.unwrap().unwrap().worker_sequence, 1); assert_eq!(stream.next().await.unwrap().unwrap().worker_sequence, 2); drop(stream); for _ in 0..20 { if state.stream_dropped.load(Ordering::SeqCst) { return; } tokio::time::sleep(Duration::from_millis(25)).await; } assert!(state.stream_dropped.load(Ordering::SeqCst)); } #[tokio::test] async fn acknowledge_alarm_returns_reply_with_native_status() { let state = Arc::new(FakeState::default()); let endpoint = spawn_fake_gateway(state.clone()).await; let client = GatewayClient::connect(ClientOptions::new(endpoint)) .await .unwrap(); let reply = client .acknowledge_alarm(AcknowledgeAlarmRequest { client_correlation_id: "corr-1".to_owned(), alarm_full_reference: "Tank01.Level.HiHi".to_owned(), comment: "investigating".to_owned(), operator_user: "alice".to_owned(), }) .await .unwrap(); assert_eq!( reply.protocol_status.as_ref().unwrap().code, ProtocolStatusCode::Ok as i32 ); assert_eq!(reply.status.as_ref().unwrap().success, 1); } #[tokio::test] async fn query_active_alarms_streams_snapshot_rows() { let state = Arc::new(FakeState::default()); let endpoint = spawn_fake_gateway(state.clone()).await; let client = GatewayClient::connect(ClientOptions::new(endpoint)) .await .unwrap(); let mut stream = client .query_active_alarms(QueryActiveAlarmsRequest { session_id: "session-fixture".to_owned(), ..QueryActiveAlarmsRequest::default() }) .await .unwrap(); let first = stream.next().await.unwrap().unwrap(); assert_eq!(first.alarm_full_reference, "Tank01.Level.HiHi"); } #[test] fn value_conversion_fixtures_keep_typed_projection_and_raw_metadata() { let fixture = behavior_fixture("values/value-conversion-cases.json"); let cases = fixture["cases"].as_array().unwrap(); let int64_case = case_by_id(cases, "int64.large"); let int64_value = ClientMxValue::from_proto(MxValue { data_type: MxDataType::Integer as i32, variant_type: "VT_I8".to_owned(), kind: Some(Kind::Int64Value( int64_case["value"]["int64Value"] .as_str() .unwrap() .parse() .unwrap(), )), ..MxValue::default() }); assert_eq!( int64_value.projection(), &MxValueProjection::Int64(9_223_372_036_854_770_000) ); let raw_case = case_by_id(cases, "raw-fallback.variant"); let raw_value = ClientMxValue::from_proto(MxValue { data_type: MxDataType::Unknown as i32, variant_type: "VT_RECORD".to_owned(), raw_diagnostic: raw_case["value"]["rawDiagnostic"] .as_str() .unwrap() .to_owned(), raw_data_type: raw_case["value"]["rawDataType"].as_i64().unwrap() as i32, kind: Some(Kind::RawValue(vec![1, 2, 3, 4, 5])), ..MxValue::default() }); assert_eq!( raw_value.projection(), &MxValueProjection::Raw(vec![1, 2, 3, 4, 5]) ); assert_eq!(raw_value.raw().raw_data_type, 32767); assert!(raw_value.raw().raw_diagnostic.contains("No lossless")); } #[test] fn status_conversion_fixtures_preserve_raw_fields() { let fixture = behavior_fixture("statuses/status-conversion-cases.json"); let cases = fixture["cases"].as_array().unwrap(); let raw_case = case_by_id(cases, "raw-unknown-category"); let status = MxStatus::from_proto(MxStatusProxy { success: raw_case["status"]["success"].as_i64().unwrap() as i32, category: MxStatusCategory::Unknown as i32, detected_by: MxStatusSource::Unknown as i32, detail: raw_case["status"]["detail"].as_i64().unwrap() as i32, raw_category: raw_case["status"]["rawCategory"].as_i64().unwrap() as i32, raw_detected_by: raw_case["status"]["rawDetectedBy"].as_i64().unwrap() as i32, diagnostic_text: raw_case["status"]["diagnosticText"] .as_str() .unwrap() .to_owned(), }); assert_eq!(status.success(), 0); assert_eq!(status.category(), Some(MxStatusCategory::Unknown)); assert_eq!(status.raw_category(), 99); assert_eq!(status.raw_detected_by(), 77); assert!(status.diagnostic_text().contains("preserved")); } #[test] fn authentication_and_authorization_statuses_are_distinct_and_redacted() { let auth = Error::from(Status::unauthenticated( "invalid API key mxgw_visible_secret", )); let denied = Error::from(Status::permission_denied("missing scope mxaccess.write")); assert!(matches!(auth, Error::Authentication { .. })); assert!(matches!(denied, Error::Authorization { .. })); assert!(!auth.to_string().contains("visible_secret")); } #[test] fn command_error_display_keeps_raw_reply_accessible() { let reply = mxaccess_failure_reply(); let error = CommandError::new(reply.clone()); assert_eq!(error.reply().hresult, Some(-2147220992)); assert!(error.to_string().contains("MxaccessFailure")); } #[derive(Default)] struct FakeState { authorization: Mutex>, last_command_kind: Mutex>, stream_dropped: Arc, } #[derive(Clone)] struct FakeGateway { state: Arc, } #[tonic::async_trait] impl MxAccessGateway for FakeGateway { async fn open_session( &self, request: Request, ) -> Result, Status> { *self.state.authorization.lock().await = request .metadata() .get("authorization") .and_then(|value| value.to_str().ok()) .map(str::to_owned); Ok(Response::new(OpenSessionReply { session_id: "session-fixture".to_owned(), backend_name: "fake".to_owned(), worker_process_id: 1234, worker_protocol_version: 1, gateway_protocol_version: 1, protocol_status: Some(ok_status("opened")), ..OpenSessionReply::default() })) } async fn close_session( &self, request: Request, ) -> Result, Status> { Ok(Response::new(CloseSessionReply { session_id: request.into_inner().session_id, final_state: SessionState::Closed as i32, protocol_status: Some(ok_status("closed")), })) } async fn invoke( &self, request: Request< zb_mom_ww_mxgateway_client::generated::mxaccess_gateway::v1::MxCommandRequest, >, ) -> Result, Status> { let request = request.into_inner(); let kind = request .command .as_ref() .map(|command| command.kind) .unwrap_or_default(); *self.state.last_command_kind.lock().await = Some(kind); if kind == MxCommandKind::Write as i32 { return Ok(Response::new(mxaccess_failure_reply())); } if kind == MxCommandKind::SubscribeBulk as i32 { return Ok(Response::new(MxCommandReply { session_id: request.session_id, correlation_id: "fake-correlation".to_owned(), kind, protocol_status: Some(ok_status("command ok")), payload: Some(mx_command_reply::Payload::SubscribeBulk( BulkSubscribeReply { results: vec![SubscribeResult { server_handle: 12, tag_address: "Area001.Pump001.Speed".to_owned(), item_handle: 34, was_successful: true, error_message: String::new(), }], }, )), ..MxCommandReply::default() })); } Ok(Response::new(MxCommandReply { session_id: request.session_id, correlation_id: "fake-correlation".to_owned(), kind, protocol_status: Some(ok_status("command ok")), payload: Some(mx_command_reply::Payload::AddItem(AddItemReply { item_handle: 34, })), ..MxCommandReply::default() })) } type StreamEventsStream = DropAwareStream; async fn stream_events( &self, _request: Request, ) -> Result, Status> { let (sender, receiver) = mpsc::channel(4); sender.send(Ok(event(1))).await.unwrap(); sender.send(Ok(event(2))).await.unwrap(); Ok(Response::new(DropAwareStream { inner: ReceiverStream::new(receiver), dropped: self.state.stream_dropped.clone(), })) } async fn acknowledge_alarm( &self, _request: Request, ) -> Result, Status> { Ok(Response::new(AcknowledgeAlarmReply { correlation_id: "corr-1".to_owned(), protocol_status: Some(ok_status("ack ok")), status: Some(MxStatusProxy { success: 1, category: MxStatusCategory::Ok as i32, detected_by: MxStatusSource::RespondingLmx as i32, ..MxStatusProxy::default() }), ..AcknowledgeAlarmReply::default() })) } type StreamAlarmsStream = Pin> + Send + 'static>>; async fn stream_alarms( &self, _request: Request, ) -> Result, Status> { let (_sender, receiver) = mpsc::channel::>(1); let stream = ReceiverStream::new(receiver); Ok(Response::new(Box::pin(stream))) } type QueryActiveAlarmsStream = Pin> + Send + 'static>>; async fn query_active_alarms( &self, _request: Request, ) -> Result, Status> { let (sender, receiver) = mpsc::channel(4); sender .send(Ok(ActiveAlarmSnapshot { alarm_full_reference: "Tank01.Level.HiHi".to_owned(), ..ActiveAlarmSnapshot::default() })) .await .unwrap(); let stream = ReceiverStream::new(receiver); Ok(Response::new(Box::pin(stream))) } } struct DropAwareStream { inner: ReceiverStream>, dropped: Arc, } impl Stream for DropAwareStream { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_next(context) } } impl Drop for DropAwareStream { fn drop(&mut self) { self.dropped.store(true, Ordering::SeqCst); } } async fn spawn_fake_gateway(state: Arc) -> String { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let address = listener.local_addr().unwrap(); let incoming = TcpListenerStream::new(listener); let service = MxAccessGatewayServer::new(FakeGateway { state }); tokio::spawn(async move { Server::builder() .add_service(service) .serve_with_incoming(incoming) .await .unwrap(); }); format!("http://{address}") } fn ok_status(message: &str) -> ProtocolStatus { ProtocolStatus { code: ProtocolStatusCode::Ok as i32, message: message.to_owned(), } } fn mxaccess_failure_reply() -> MxCommandReply { MxCommandReply { session_id: "session-fixture".to_owned(), correlation_id: "gateway-correlation-write-1".to_owned(), kind: MxCommandKind::Write as i32, protocol_status: Some(ProtocolStatus { code: ProtocolStatusCode::MxaccessFailure as i32, message: "MXAccess rejected the write.".to_owned(), }), hresult: Some(-2147220992), statuses: vec![ MxStatusProxy { success: 0, category: MxStatusCategory::SecurityError as i32, detected_by: MxStatusSource::RespondingLmx as i32, detail: 321, raw_category: 8, raw_detected_by: 3, diagnostic_text: "Write denied by provider security.".to_owned(), }, MxStatusProxy { success: 0, category: MxStatusCategory::OperationalError as i32, detected_by: MxStatusSource::RespondingNmx as i32, detail: 902, raw_category: 7, raw_detected_by: 5, diagnostic_text: "Provider rejected the item state.".to_owned(), }, ], ..MxCommandReply::default() } } fn event(sequence: u64) -> MxEvent { MxEvent { family: MxEventFamily::OnDataChange as i32, session_id: "session-fixture".to_owned(), worker_sequence: sequence, ..MxEvent::default() } } fn behavior_fixture(path: &str) -> Value { let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) .join("../proto/fixtures/behavior") .join(path); let data = std::fs::read_to_string(&path).unwrap(); serde_json::from_str(&data).unwrap() } fn case_by_id<'a>(cases: &'a [Value], id: &str) -> &'a Value { cases .iter() .find(|case| case["id"].as_str() == Some(id)) .unwrap_or_else(|| panic!("missing fixture case {id}")) }