//! `INmxSvcCallback` callback exporter — TCP server.
//!
//! Direct port of `src/MxNativeClient/ManagedCallbackExporter.cs`. Spins up a
//! tokio TCP listener, accepts incoming DCE/RPC connections from
//! `NmxSvc.exe`, walks them through Bind / AlterContext / Request /
//! Auth3 PDUs, and dispatches `IRemUnknown` and `INmxSvcCallback` requests.
//! The .NET reference is the executable spec; every wire shape and
//! HRESULT is cited inline against it.
//!
//! Two interfaces are served:
//!
//! - `IRemUnknown` (IID `00000131-0000-0000-C000-000000000046`,
//!   `RemUnknownMessages.cs:7`) — opnums 3 (`RemQueryInterface`),
//!   4 (`RemAddRef`), 5 (`RemRelease`). The QI handler returns `S_OK`
//!   for `IRemUnknown` / `INmxSvcCallback` / `IUnknown` and
//!   `E_NOINTERFACE` (`0x80004002`) otherwise — mirrors
//!   `ManagedCallbackExporter.cs:196-200`.
//! - `INmxSvcCallback` (IID `B49F92F7-C748-4169-8ECA-A0670B012746`,
//!   `NmxProcedureMetadata.cs:6`) — opnums 3 (`DataReceived`),
//!   4 (`StatusReceived`). The handler decodes the inbound buffer via
//!   [`mxaccess_rpc::nmx_callback_messages::parse_callback_request`],
//!   emits a typed event, and returns the success response built by
//!   [`mxaccess_rpc::nmx_callback_messages::encode_callback_response`].
//!
//! Auth3 PDUs are accepted but ignored (`ManagedCallbackExporter.cs:133-137`)
//! — NTLM packet integrity for inbound frames is not yet wired (open
//! follow-up F2).
// The encoders below index byte buffers they allocate at layout-derived sizes,
// and request stubs are length-checked before being sliced; the lint is
// silenced module-wide for that reason.
#![allow(clippy::indexing_slicing)]

use std::net::SocketAddr;

use mxaccess_rpc::error::RpcError;
use mxaccess_rpc::guid::Guid;
use mxaccess_rpc::nmx_callback_messages;
use mxaccess_rpc::nmx_metadata::INMX_SVC_CALLBACK_IID;
use mxaccess_rpc::objref::ComObjRefBuilder;
use mxaccess_rpc::orpc::{OrpcThat, StdObjRef};
use mxaccess_rpc::pdu::{
    BindPdu, FaultPdu, PacketType, PduHeader, RequestPdu, ResponsePdu, SyntaxId,
};
use mxaccess_rpc::rem_unknown::{
    IREM_UNKNOWN_IID, REM_ADD_REF_OPNUM, REM_QUERY_INTERFACE_OPNUM, REM_RELEASE_OPNUM,
};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;

/// IID of `IUnknown`, `00000000-0000-0000-C000-000000000046`.
///
/// `RemQueryInterface` answers `S_OK` for this IID as well as for
/// `IRemUnknown` and `INmxSvcCallback` (`ManagedCallbackExporter.cs:198`).
pub const IUNKNOWN_IID: Guid = Guid::new([
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, //
    0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46,
]);

/// Success HRESULT.
const S_OK: i32 = 0;

/// `E_NOINTERFACE` (`0x80004002`), bit-reinterpreted as a signed HRESULT.
/// `RemQueryInterface` returns it for any IID this exporter does not serve
/// (`ManagedCallbackExporter.cs:200`).
const E_NOINTERFACE: i32 = 0x8000_4002_u32 as i32;

/// `RPC_S_PROCNUM_OUT_OF_RANGE` (1783 = `0x6F7`): fault status sent for any
/// request whose `(iid, opnum)` has no handler
/// (`ManagedCallbackExporter.cs:171,184,187`).
const RPC_S_PROCNUM_OUT_OF_RANGE: u32 = 1783;

// Fixed framing values, matching the hand-rolled writes in
// `ManagedCallbackExporter.cs:226-254`.

/// Both `max_xmit_frag` and `max_recv_frag` advertised in the BindAck.
const BIND_ACK_MAX_FRAGMENT: u16 = 4280;
/// Association group id returned in the BindAck.
const BIND_ACK_ASSOC_GROUP_ID: u32 = 0x0000_5353;
/// Header flags on every PDU this server sends (presumably
/// `PFC_FIRST_FRAG | PFC_LAST_FRAG` — confirm against C706).
const RESPONSE_PACKET_FLAGS: u8 = 0x03;
/// Data-representation word on every PDU this server sends.
const RESPONSE_DATA_REPRESENTATION: u32 = 0x0000_0010;

/// Identities the exporter publishes in OBJREFs and STDOBJREFs.
/// /// The .NET reference generates these via `Guid.NewGuid()` and /// `RandomNumberGenerator.Fill` (`ManagedCallbackExporter.cs:14-20`); the /// Rust port mirrors that for production but exposes a [`fixed`] constructor /// so tests can pin all four values for deterministic byte comparisons. /// /// [`fixed`]: ExporterIdentities::fixed #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct ExporterIdentities { pub oxid: u64, pub oid: u64, pub callback_ipid: Guid, pub rem_unknown_ipid: Guid, } impl ExporterIdentities { /// Generate identities from `rand::random()` — production default. /// Mirrors `ManagedCallbackExporter.cs:14-20` (`RandomUInt64` + /// `Guid.NewGuid()` × 2). #[must_use] pub fn random() -> Self { Self { oxid: rand::random(), oid: rand::random(), callback_ipid: Guid::new(rand::random()), rem_unknown_ipid: Guid::new(rand::random()), } } /// Construct with caller-supplied values. Useful in tests where /// deterministic OBJREF byte comparisons are required. #[must_use] pub const fn fixed(oxid: u64, oid: u64, callback_ipid: Guid, rem_unknown_ipid: Guid) -> Self { Self { oxid, oid, callback_ipid, rem_unknown_ipid, } } } /// Diagnostic events emitted by the exporter as it serves connections. /// /// The .NET reference logs strings into a `List` (per /// `ManagedCallbackExporter.cs:12,33-42,315-321`). The Rust port emits /// typed events instead — same information, more useful to consumers /// that want to assert structurally. #[derive(Debug, Clone, PartialEq, Eq)] pub enum CallbackEvent { /// A new TCP connection was accepted (`cs:78`). ClientConnected { remote: SocketAddr }, /// `accept()` returned an error (`cs:89-92`). AcceptError { reason: String }, /// Bind / AlterContext negotiated a presentation context (`cs:121`). Bind { context_id: u16, iid: Guid }, /// Auth3 PDU received and ignored (`cs:133-137`). Auth3Ignored, /// Request PDU received (`cs:162`). 
Request { iid: Guid, context_id: u16, opnum: u16, stub_len: usize, }, /// `IRemUnknown::RemQueryInterface` invoked (`cs:191-200`). RemQueryInterface { requested_iid: Guid, hresult: i32 }, /// `INmxSvcCallback::DataReceived` or `StatusReceived` invoked /// (`cs:177-183`). CallbackInvoked { opnum: u16, body: Vec }, /// Request whose `(iid, opnum)` was not dispatched — fault returned /// (`cs:171,184,187`). UnhandledRequest { iid: Guid, opnum: u16 }, /// Client closed the connection (`cs:107-110`). ClientDisconnected, /// PDU header parse / fragment-length / packet-type error. ProtocolError { reason: String }, } /// The TCP exporter handle. Drop aborts the accept task. /// /// Construction is asynchronous because [`tokio::net::TcpListener::bind`] /// is. Use [`CallbackExporter::bind`] to start; the returned tuple includes /// an [`mpsc::UnboundedReceiver`] that streams diagnostic /// events as the server runs. pub struct CallbackExporter { local_addr: SocketAddr, identities: ExporterIdentities, shutdown_tx: Option>, accept_handle: Option>, } impl CallbackExporter { /// Bind to `addr` (use port 0 for an OS-assigned ephemeral port) and /// start the accept loop. Returns the handle plus the diagnostic /// event stream. /// /// # Errors /// Returns [`std::io::Error`] from [`TcpListener::bind`]. pub async fn bind( addr: SocketAddr, identities: ExporterIdentities, ) -> std::io::Result<(Self, mpsc::UnboundedReceiver)> { let listener = TcpListener::bind(addr).await?; let local_addr = listener.local_addr()?; let (event_tx, event_rx) = mpsc::unbounded_channel(); let (shutdown_tx, shutdown_rx) = oneshot::channel(); let accept_handle = tokio::spawn(accept_loop(listener, identities, event_tx, shutdown_rx)); Ok(( Self { local_addr, identities, shutdown_tx: Some(shutdown_tx), accept_handle: Some(accept_handle), }, event_rx, )) } /// Local address the listener is bound to. #[must_use] pub fn local_addr(&self) -> SocketAddr { self.local_addr } /// Identities this exporter publishes. 
#[must_use] pub fn identities(&self) -> ExporterIdentities { self.identities } /// Build a callback OBJREF to publish back to the AVEVA service. /// /// Mirrors `ManagedCallbackExporter.CreateCallbackObjRef` /// (`cs:44-54`): the IID is `INmxSvcCallback`, /// `public_refs = 5`, OXID/OID/IPID come from `self.identities`, and /// the single string binding is `"[]"`. /// /// `std_flags = 0x280` — `SORF_OXRES4 | SORF_OXRES6` (= `0x80 | /// 0x200`). Mirrors the .NET reference's `ManagedCallbackExporter` /// (`cs:48`). #[must_use] pub fn create_callback_objref(&self, hostname: &str) -> Vec { let binding = format!("{hostname}[{port}]", port = self.local_addr.port()); ComObjRefBuilder::create_standard_objref( INMX_SVC_CALLBACK_IID, 0x280, 5, self.identities.oxid, self.identities.oid, self.identities.callback_ipid, &[binding.as_str()], ) } /// Gracefully shut down the accept loop. Returns once the loop task /// exits. Any in-flight connections continue until they EOF on read. pub async fn shutdown(mut self) { if let Some(tx) = self.shutdown_tx.take() { // Receiver dropped is fine; the loop is already exiting. let _ = tx.send(()); } if let Some(handle) = self.accept_handle.take() { let _ = handle.await; } } } impl Drop for CallbackExporter { fn drop(&mut self) { if let Some(tx) = self.shutdown_tx.take() { // Best-effort signal; if the loop is gone the receiver will be // dropped which is also fine. let _ = tx.send(()); } if let Some(handle) = self.accept_handle.take() { handle.abort(); } } } async fn accept_loop( listener: TcpListener, identities: ExporterIdentities, event_tx: mpsc::UnboundedSender, mut shutdown_rx: oneshot::Receiver<()>, ) { loop { tokio::select! 
{ _ = &mut shutdown_rx => return, accept = listener.accept() => match accept { Ok((stream, remote)) => { let _ = event_tx.send(CallbackEvent::ClientConnected { remote }); let event_tx = event_tx.clone(); tokio::spawn(serve_client(stream, identities, event_tx)); } Err(err) => { let _ = event_tx.send(CallbackEvent::AcceptError { reason: err.to_string() }); // Listener errors are usually fatal (e.g. socket closed). return; } }, } } } async fn serve_client( mut stream: TcpStream, identities: ExporterIdentities, event_tx: mpsc::UnboundedSender, ) { // The .NET reference threads `currentContext` through `EncodeBindAck` // but the encoder discards it (`ManagedCallbackExporter.cs:252: // `_ = contextId;`), so we simply record the negotiated context id in // the Bind event and read `request.context_id` per-request. let mut current_iid: Guid = Guid::ZERO; loop { let pdu = match read_pdu(&mut stream).await { Ok(Some(p)) => p, Ok(None) => { let _ = event_tx.send(CallbackEvent::ClientDisconnected); return; } Err(err) => { let _ = event_tx.send(CallbackEvent::ProtocolError { reason: err.to_string(), }); return; } }; let header = match PduHeader::decode(&pdu) { Ok(h) => h, Err(err) => { let _ = event_tx.send(CallbackEvent::ProtocolError { reason: format!("{err}"), }); return; } }; match header.packet_type { PacketType::Bind | PacketType::AlterContext => match BindPdu::decode(&pdu) { Ok(bind) => { let context_id = if let Some(first) = bind.presentation_contexts.first() { current_iid = Guid::new(first.abstract_syntax.uuid_bytes); first.context_id } else { current_iid = Guid::ZERO; 0 }; let _ = event_tx.send(CallbackEvent::Bind { context_id, iid: current_iid, }); let response = encode_bind_ack(header.call_id); if stream.write_all(&response).await.is_err() { return; } } Err(err) => { let _ = event_tx.send(CallbackEvent::ProtocolError { reason: format!("{err}"), }); return; } }, PacketType::Request => { let request = match RequestPdu::decode(&pdu) { Ok(r) => r, Err(err) => { let _ = 
event_tx.send(CallbackEvent::ProtocolError { reason: format!("{err}"), }); return; } }; let _ = event_tx.send(CallbackEvent::Request { iid: current_iid, context_id: request.context_id, opnum: request.opnum, stub_len: request.stub_data.len(), }); let response_bytes = handle_request(&request, current_iid, &identities, &event_tx); if stream.write_all(&response_bytes).await.is_err() { return; } } PacketType::Auth3 => { let _ = event_tx.send(CallbackEvent::Auth3Ignored); // No response per `cs:133-137`. } other => { let _ = event_tx.send(CallbackEvent::ProtocolError { reason: format!("unhandled PDU type {other:?}"), }); return; } } } } fn handle_request( request: &RequestPdu, current_iid: Guid, identities: &ExporterIdentities, event_tx: &mpsc::UnboundedSender, ) -> Vec { if current_iid == IREM_UNKNOWN_IID { match request.opnum { REM_QUERY_INTERFACE_OPNUM => { let response_body = encode_rem_query_interface_response(&request.stub_data, identities, event_tx); return wrap_response(request.header.call_id, request.context_id, response_body); } REM_ADD_REF_OPNUM | REM_RELEASE_OPNUM => { return wrap_response( request.header.call_id, request.context_id, encode_orpc_hresult_response(S_OK).to_vec(), ); } _ => { let _ = event_tx.send(CallbackEvent::UnhandledRequest { iid: current_iid, opnum: request.opnum, }); return encode_fault( request.header.call_id, request.context_id, RPC_S_PROCNUM_OUT_OF_RANGE, ); } } } if current_iid == INMX_SVC_CALLBACK_IID { if request.opnum == nmx_callback_messages::DATA_RECEIVED_OPNUM || request.opnum == nmx_callback_messages::STATUS_RECEIVED_OPNUM { // Decode the callback body for the event stream; if it fails to // parse, still return success per the .NET reference (`cs:179-181`) // — diagnostic info goes through `ProtocolError` instead. 
let body = match nmx_callback_messages::parse_callback_request(&request.stub_data) { Ok(parsed) => parsed.body, Err(err) => { let _ = event_tx.send(CallbackEvent::ProtocolError { reason: format!("callback request decode: {err}"), }); Vec::new() } }; let _ = event_tx.send(CallbackEvent::CallbackInvoked { opnum: request.opnum, body, }); let resp = nmx_callback_messages::encode_callback_response(S_OK); return wrap_response(request.header.call_id, request.context_id, resp.to_vec()); } let _ = event_tx.send(CallbackEvent::UnhandledRequest { iid: current_iid, opnum: request.opnum, }); return encode_fault( request.header.call_id, request.context_id, RPC_S_PROCNUM_OUT_OF_RANGE, ); } let _ = event_tx.send(CallbackEvent::UnhandledRequest { iid: current_iid, opnum: request.opnum, }); encode_fault( request.header.call_id, request.context_id, RPC_S_PROCNUM_OUT_OF_RANGE, ) } /// Build the `RemQueryInterface` response stub. Mirrors /// `ManagedCallbackExporter.EncodeRemQueryInterfaceResponse` /// (`cs:190-216`). fn encode_rem_query_interface_response( request_stub: &[u8], identities: &ExporterIdentities, event_tx: &mpsc::UnboundedSender, ) -> Vec { // Requested IID at offset 60..76 of the request body (`cs:192`). If the // stub is short, default to ZERO — same as .NET `Guid.Empty` fallback. let requested_iid = if request_stub.len() >= 76 { Guid::parse(&request_stub[60..76]).unwrap_or(Guid::ZERO) } else { Guid::ZERO }; // Pick the IPID to publish (`cs:195`): RemUnknown gets RemUnknownIpid; // anything else (including the success cases for callback / IUnknown) // gets CallbackIpid. 
let ipid = if requested_iid == IREM_UNKNOWN_IID { identities.rem_unknown_ipid } else { identities.callback_ipid }; let std = StdObjRef { flags: 0x280, public_refs: 5, oxid: identities.oxid, oid: identities.oid, ipid, }; let hresult = if requested_iid == IREM_UNKNOWN_IID || requested_iid == INMX_SVC_CALLBACK_IID || requested_iid == IUNKNOWN_IID { S_OK } else { E_NOINTERFACE }; let _ = event_tx.send(CallbackEvent::RemQueryInterface { requested_iid, hresult, }); // Layout (cs:202-215): // 0..8 OrpcThat (zeroed) // 8..12 referent_id u32 LE = 0x00020000 // 12..16 max_count u32 LE = 1 // 16..20 hresult i32 LE // 20..24 4-byte NDR pad ahead of STDOBJREF (zero) // 24..64 STDOBJREF (40 bytes) // 64..68 error_code u32 LE = 0 let mut buf = vec![0u8; OrpcThat::ENCODED_LEN + 4 + 4 + 4 + 4 + StdObjRef::ENCODED_LEN + 4]; let orpc_that = OrpcThat::default().encode(); buf[..OrpcThat::ENCODED_LEN].copy_from_slice(&orpc_that); let mut off = OrpcThat::ENCODED_LEN; buf[off..off + 4].copy_from_slice(&0x0002_0000u32.to_le_bytes()); off += 4; buf[off..off + 4].copy_from_slice(&1u32.to_le_bytes()); off += 4; buf[off..off + 4].copy_from_slice(&hresult.to_le_bytes()); off += 4; // 4-byte pad (zero). off += 4; let std_bytes = std.encode(); buf[off..off + StdObjRef::ENCODED_LEN].copy_from_slice(&std_bytes); off += StdObjRef::ENCODED_LEN; buf[off..off + 4].copy_from_slice(&0u32.to_le_bytes()); buf } /// 12-byte simple HRESULT response — `OrpcThat(8) + hresult(4)`. Mirrors /// `EncodeOrpcHResultResponse` (`cs:218-224`). fn encode_orpc_hresult_response(hresult: i32) -> [u8; 12] { let mut buf = [0u8; 12]; buf[..OrpcThat::ENCODED_LEN].copy_from_slice(&OrpcThat::default().encode()); buf[OrpcThat::ENCODED_LEN..].copy_from_slice(&hresult.to_le_bytes()); buf } /// Wrap a stub-data body in a Response PDU. Mirrors `EncodeResponse` /// (`cs:256-267`). 
fn wrap_response(call_id: u32, context_id: u16, stub_data: Vec) -> Vec { let header = PduHeader { version: 5, version_minor: 0, packet_type: PacketType::Response, packet_flags: RESPONSE_PACKET_FLAGS, data_representation: RESPONSE_DATA_REPRESENTATION, fragment_length: 0, // overwritten by RequestPdu::encode auth_length: 0, call_id, }; ResponsePdu { header, allocation_hint: u32::try_from(stub_data.len()).unwrap_or(u32::MAX), context_id, cancel_count: 0, reserved23: 0, stub_data, } .encode() } /// Build a Fault PDU. Mirrors `EncodeFault` (`cs:269-277`). fn encode_fault(call_id: u32, context_id: u16, status: u32) -> Vec { let header = PduHeader { version: 5, version_minor: 0, packet_type: PacketType::Fault, packet_flags: RESPONSE_PACKET_FLAGS, data_representation: RESPONSE_DATA_REPRESENTATION, fragment_length: 0, // overwritten by FaultPdu::encode auth_length: 0, call_id, }; FaultPdu { header, allocation_hint: 0, context_id, cancel_count: 0, reserved23: 0, status, stub_data: Vec::new(), } .encode() } /// Build the BindAck PDU. Hand-rolled — neither `mxaccess-rpc::pdu` nor the /// .NET `DceRpcBindPdu` class encode this shape; the original code at /// `ManagedCallbackExporter.cs:226-254` writes it byte-by-byte with a single /// presentation-context acceptance entry pointing at the NDR20 transfer /// syntax. 
///
/// Layout actually emitted, per `[C706]` §12.6.4.4 and `cs:226-254`. Offsets
/// are from the start of the PDU; `sl` is the secondary-address length and `r`
/// is the result-list offset.
///
/// ```text
/// offset  size  field
/// 0       16    PduHeader (ptype=BindAck=12, flags=0x03, drep=0x10)
/// 16      2     max_xmit_frag      u16 LE = 4280
/// 18      2     max_recv_frag      u16 LE = 4280
/// 20      4     assoc_group_id     u32 LE = 0x5353
/// 24      2     sec_addr_length    u16 LE = sl
/// 26      sl    secondary_address  (empty string + NUL, so sl = 1)
/// 26+sl   ..    zero padding up to r = align_up(28 + sl, 4)
/// r       1     n_results = 1
/// r+1     1     reserved
/// r+2     2     reserved2
/// r+4     2     result  u16 LE = 0 (acceptance)
/// r+6     2     reason  u16 LE = 0
/// r+8     16    transfer syntax uuid (NDR20)
/// r+24    2     transfer syntax version_major u16 LE
/// r+26    2     transfer syntax version_minor u16 LE
/// ```
///
/// With `sl = 1` the result list starts at offset 32 and the PDU is 60 bytes.
///
/// NOTE(review): C706's bind_ack restores 4-byte alignment directly after the
/// secondary address, i.e. from `26 + sl`, which gives offset 28 here. This
/// code, and the `encode_bind_ack_layout` test, align from `28 + sl`
/// instead. Confirm against `ManagedCallbackExporter.cs:226-254` or a wire
/// capture before changing either.
fn encode_bind_ack(call_id: u32) -> Vec<u8> {
    // The .NET writer emits a one-byte secondary address: empty string + NUL.
    const SECONDARY: &[u8] = b"\0";
    // Bytes from the result-list offset to the end of the PDU:
    // n_results/reserved/reserved2 (4) + result/reason (4) + NDR20 syntax id (20).
    const RESULT_LIST_LEN: usize = 4 + 4 + 20;

    let sec_addr_length = SECONDARY.len();
    let result_offset = align_up(28 + sec_addr_length, 4);
    let length = result_offset + RESULT_LIST_LEN;
    // Zero-filled, so padding, reserved bytes, and result/reason (= acceptance)
    // need no explicit writes.
    let mut pdu = vec![0u8; length];

    let header = PduHeader {
        version: 5,
        version_minor: 0,
        packet_type: PacketType::BindAck,
        packet_flags: RESPONSE_PACKET_FLAGS,
        data_representation: RESPONSE_DATA_REPRESENTATION,
        fragment_length: u16::try_from(length).unwrap_or(u16::MAX),
        auth_length: 0,
        call_id,
    };
    // The encoder's return value is ignored, as in the original implementation.
    let _ = header.encode(&mut pdu);

    pdu[16..18].copy_from_slice(&BIND_ACK_MAX_FRAGMENT.to_le_bytes());
    pdu[18..20].copy_from_slice(&BIND_ACK_MAX_FRAGMENT.to_le_bytes());
    pdu[20..24].copy_from_slice(&BIND_ACK_ASSOC_GROUP_ID.to_le_bytes());
    pdu[24..26].copy_from_slice(
        &u16::try_from(sec_addr_length)
            .unwrap_or(u16::MAX)
            .to_le_bytes(),
    );
    pdu[26..26 + sec_addr_length].copy_from_slice(SECONDARY);

    // One result entry: accepted (result 0, reason 0) with NDR20.
    pdu[result_offset] = 1;
    let syntax_offset = result_offset + 8;
    let ndr = SyntaxId::NDR20;
    pdu[syntax_offset..syntax_offset + 16].copy_from_slice(&ndr.uuid_bytes);
    pdu[syntax_offset + 16..syntax_offset + 18].copy_from_slice(&ndr.version_major.to_le_bytes());
    pdu[syntax_offset + 18..syntax_offset + 20].copy_from_slice(&ndr.version_minor.to_le_bytes());
    pdu
}

/// Round `value` up to the next multiple of `alignment`.
///
/// Values that are already aligned are returned unchanged. Panics if
/// `alignment` is 0, and on overflow in debug builds, just like the manual
/// `% / +` form this replaces.
const fn align_up(value: usize, alignment: usize) -> usize {
    value.next_multiple_of(alignment)
}

/// Read one full PDU from `stream`. Returns `Ok(None)` on clean EOF before
/// the first byte. Mirrors `ReadPduAsync` + `ReadExactAsync`
/// (`cs:279-313`).
async fn read_pdu(stream: &mut TcpStream) -> Result>, RpcError> { let mut header_buf = [0u8; PduHeader::LENGTH]; match read_exact_or_eof(stream, &mut header_buf).await { Ok(true) => {} Ok(false) => return Ok(None), Err(_) => return Ok(None), } let header = PduHeader::decode(&header_buf)?; let frag = header.fragment_length as usize; if frag < PduHeader::LENGTH { return Err(RpcError::InvalidFragmentLength { frag_length: frag, buffer_len: 0, auth_length: header.auth_length as usize, }); } let mut pdu = vec![0u8; frag]; pdu[..PduHeader::LENGTH].copy_from_slice(&header_buf); if frag > PduHeader::LENGTH { match read_exact_or_eof(stream, &mut pdu[PduHeader::LENGTH..]).await { Ok(true) => {} Ok(_) | Err(_) => return Ok(None), } } Ok(Some(pdu)) } /// Like `AsyncReadExt::read_exact` but treats clean EOF *before any byte was /// read* as `Ok(false)` so the caller can distinguish "client closed before /// the next PDU" from "client closed mid-PDU." async fn read_exact_or_eof(stream: &mut TcpStream, buf: &mut [u8]) -> std::io::Result { let mut filled = 0; while filled < buf.len() { let n = stream.read(&mut buf[filled..]).await?; if n == 0 { return Ok(filled != 0 && filled == buf.len()); } filled += n; } Ok(true) } #[cfg(test)] #[allow( clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing, clippy::panic )] mod tests { use super::*; use mxaccess_rpc::orpc::OrpcThis; use mxaccess_rpc::pdu::{BindPdu, PresentationContext, SyntaxId}; use mxaccess_rpc::rem_unknown::{ IREM_UNKNOWN_IID, REM_QUERY_INTERFACE_OPNUM, encode_rem_query_interface_request, }; fn fixed_identities() -> ExporterIdentities { ExporterIdentities::fixed( 0x1111_2222_3333_4444, 0x5555_6666_7777_8888, Guid::new([0xCC; 16]), Guid::new([0xDD; 16]), ) } fn local_addr() -> SocketAddr { "127.0.0.1:0".parse().unwrap() } fn build_bind_pdu(call_id: u32, abstract_iid: Guid) -> Vec { let header = PduHeader { version: 5, version_minor: 0, packet_type: PacketType::Bind, packet_flags: 0x03, data_representation: 
RESPONSE_DATA_REPRESENTATION, fragment_length: 0, // BindPdu::encode rewrites it auth_length: 0, call_id, }; BindPdu { header, max_transmit_fragment: BIND_ACK_MAX_FRAGMENT, max_receive_fragment: BIND_ACK_MAX_FRAGMENT, association_group_id: 0, presentation_contexts: vec![PresentationContext { context_id: 0, abstract_syntax: SyntaxId { uuid_bytes: *abstract_iid.as_bytes(), version_major: 0, version_minor: 0, }, transfer_syntaxes: vec![SyntaxId::NDR20], }], reserved25_28: [0; 3], } .encode() } fn build_request_pdu(call_id: u32, opnum: u16, stub: Vec) -> Vec { let header = PduHeader { version: 5, version_minor: 0, packet_type: PacketType::Request, packet_flags: 0x03, data_representation: RESPONSE_DATA_REPRESENTATION, fragment_length: 0, auth_length: 0, call_id, }; RequestPdu { header, allocation_hint: stub.len() as u32, context_id: 0, opnum, stub_data: stub, } .encode() } async fn collect_event(rx: &mut mpsc::UnboundedReceiver) -> CallbackEvent { tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) .await .expect("event timeout") .expect("event channel closed") } #[test] fn iunknown_iid_constant_matches_dotnet() { // .NET `new Guid("00000000-0000-0000-C000-000000000046").ToString("D")`. assert_eq!( IUNKNOWN_IID.to_string(), "00000000-0000-0000-c000-000000000046" ); } #[test] fn align_up_matches_dotnet_align() { assert_eq!(align_up(0, 4), 0); assert_eq!(align_up(1, 4), 4); assert_eq!(align_up(28, 4), 28); assert_eq!(align_up(29, 4), 32); assert_eq!(align_up(31, 4), 32); } #[test] fn encode_bind_ack_layout() { let pdu = encode_bind_ack(42); // PDU header decodes back. let header = PduHeader::decode(&pdu).unwrap(); assert_eq!(header.packet_type, PacketType::BindAck); assert_eq!(header.call_id, 42); assert_eq!(header.fragment_length as usize, pdu.len()); // Max fragments at 16/18. assert_eq!(u16::from_le_bytes([pdu[16], pdu[17]]), 4280); assert_eq!(u16::from_le_bytes([pdu[18], pdu[19]]), 4280); // Assoc group id at 20. 
assert_eq!( u32::from_le_bytes([pdu[20], pdu[21], pdu[22], pdu[23]]), 0x5353 ); // Empty secondary address (length=1, single null byte). assert_eq!(u16::from_le_bytes([pdu[24], pdu[25]]), 1); assert_eq!(pdu[26], 0); // Result-list n_results=1 at the aligned offset (32). // unpadded = 28 + 1 = 29; pad = 3; result_offset = 32. let result_offset = 32; assert_eq!(pdu[result_offset], 1); // Acceptance + NDR20 syntax follow. let syntax_offset = result_offset + 4 + 4; assert_eq!( &pdu[syntax_offset..syntax_offset + 16], &SyntaxId::NDR20.uuid_bytes ); } #[test] fn encode_orpc_hresult_response_layout() { let r = encode_orpc_hresult_response(0); assert_eq!(r.len(), 12); assert_eq!(&r[..8], &[0u8; 8]); assert_eq!(&r[8..], &0i32.to_le_bytes()); } #[tokio::test] async fn bind_request_round_trip_via_real_socket() { let (server, mut events) = CallbackExporter::bind(local_addr(), fixed_identities()) .await .unwrap(); let addr = server.local_addr(); let mut client = TcpStream::connect(addr).await.unwrap(); // 1. Send a Bind PDU advertising IRemUnknown abstract syntax. let bind = build_bind_pdu(1, IREM_UNKNOWN_IID); client.write_all(&bind).await.unwrap(); // Server emits ClientConnected then Bind. let connected = collect_event(&mut events).await; assert!(matches!(connected, CallbackEvent::ClientConnected { .. })); let bind_event = collect_event(&mut events).await; match bind_event { CallbackEvent::Bind { context_id, iid } => { assert_eq!(context_id, 0); assert_eq!(iid, IREM_UNKNOWN_IID); } other => panic!("expected Bind, got {other:?}"), } // Read the BindAck back (16 bytes header, then frag_length determines the body). let mut header_buf = [0u8; 16]; client.read_exact(&mut header_buf).await.unwrap(); let header = PduHeader::decode(&header_buf).unwrap(); assert_eq!(header.packet_type, PacketType::BindAck); let mut body = vec![0u8; header.fragment_length as usize - 16]; client.read_exact(&mut body).await.unwrap(); // 2. Send a RemQueryInterface request asking for IUnknown. 
let qi_request_body = encode_rem_query_interface_request( Guid::ZERO, // source ipid (unused by server) IUNKNOWN_IID, // requested IID — should resolve to S_OK Guid::new([0xAA; 16]), // causality id 5, ); let qi_pdu = build_request_pdu(2, REM_QUERY_INTERFACE_OPNUM, qi_request_body); client.write_all(&qi_pdu).await.unwrap(); // Server emits Request then RemQueryInterface event. let req_event = collect_event(&mut events).await; match req_event { CallbackEvent::Request { iid, opnum, .. } => { assert_eq!(iid, IREM_UNKNOWN_IID); assert_eq!(opnum, REM_QUERY_INTERFACE_OPNUM); } other => panic!("expected Request, got {other:?}"), } let qi_event = collect_event(&mut events).await; match qi_event { CallbackEvent::RemQueryInterface { requested_iid, hresult, } => { assert_eq!(requested_iid, IUNKNOWN_IID); assert_eq!(hresult, S_OK); } other => panic!("expected RemQueryInterface, got {other:?}"), } // Read response PDU header back. client.read_exact(&mut header_buf).await.unwrap(); let resp_header = PduHeader::decode(&header_buf).unwrap(); assert_eq!(resp_header.packet_type, PacketType::Response); let mut resp_body = vec![0u8; resp_header.fragment_length as usize - 16]; client.read_exact(&mut resp_body).await.unwrap(); server.shutdown().await; } #[tokio::test] async fn unknown_opnum_returns_fault() { let (server, mut events) = CallbackExporter::bind(local_addr(), fixed_identities()) .await .unwrap(); let mut client = TcpStream::connect(server.local_addr()).await.unwrap(); let bind = build_bind_pdu(1, IREM_UNKNOWN_IID); client.write_all(&bind).await.unwrap(); // Drain BindAck. let mut header_buf = [0u8; 16]; client.read_exact(&mut header_buf).await.unwrap(); let header = PduHeader::decode(&header_buf).unwrap(); let mut body = vec![0u8; header.fragment_length as usize - 16]; client.read_exact(&mut body).await.unwrap(); // Drain Connected/Bind events. let _ = collect_event(&mut events).await; let _ = collect_event(&mut events).await; // Send a request with opnum 99 (unknown). 
let req = build_request_pdu(2, 99, vec![0u8; 76]); client.write_all(&req).await.unwrap(); // Drain Request and UnhandledRequest events. let _ = collect_event(&mut events).await; let unhandled = collect_event(&mut events).await; match unhandled { CallbackEvent::UnhandledRequest { iid, opnum } => { assert_eq!(iid, IREM_UNKNOWN_IID); assert_eq!(opnum, 99); } other => panic!("expected UnhandledRequest, got {other:?}"), } client.read_exact(&mut header_buf).await.unwrap(); let resp_header = PduHeader::decode(&header_buf).unwrap(); assert_eq!(resp_header.packet_type, PacketType::Fault); server.shutdown().await; } #[tokio::test] async fn callback_invocation_emits_event_and_acks() { let (server, mut events) = CallbackExporter::bind(local_addr(), fixed_identities()) .await .unwrap(); let mut client = TcpStream::connect(server.local_addr()).await.unwrap(); // Bind to INmxSvcCallback. let bind = build_bind_pdu(1, INMX_SVC_CALLBACK_IID); client.write_all(&bind).await.unwrap(); let mut header_buf = [0u8; 16]; client.read_exact(&mut header_buf).await.unwrap(); let header = PduHeader::decode(&header_buf).unwrap(); let mut body = vec![0u8; header.fragment_length as usize - 16]; client.read_exact(&mut body).await.unwrap(); let _ = collect_event(&mut events).await; // ClientConnected let _ = collect_event(&mut events).await; // Bind // Build a callback request: OrpcThis + size + max_count + body. 
let cid = Guid::new([0xEE; 16]); let mut stub = Vec::new(); stub.extend_from_slice(&OrpcThis::create(cid, None).encode()); let payload: &[u8] = &[0xDE, 0xAD, 0xBE, 0xEF]; stub.extend_from_slice(&(payload.len() as i32).to_le_bytes()); stub.extend_from_slice(&(payload.len() as i32).to_le_bytes()); stub.extend_from_slice(payload); let req = build_request_pdu(2, nmx_callback_messages::DATA_RECEIVED_OPNUM, stub); client.write_all(&req).await.unwrap(); let _req_event = collect_event(&mut events).await; // Request let cb_event = collect_event(&mut events).await; match cb_event { CallbackEvent::CallbackInvoked { opnum, body } => { assert_eq!(opnum, nmx_callback_messages::DATA_RECEIVED_OPNUM); assert_eq!(body, payload); } other => panic!("expected CallbackInvoked, got {other:?}"), } // Drain the Response. PDU = 16-byte header + 8 response fields // (allocation_hint, context_id, cancel_count, reserved23) + stub. // For the callback success path the stub is OrpcThat(8) + hresult(4) = 12, // so resp_body (= frag_length - 16) is 8 + 12 = 20 bytes; hresult // sits at resp_body[16..20]. client.read_exact(&mut header_buf).await.unwrap(); let resp_header = PduHeader::decode(&header_buf).unwrap(); assert_eq!(resp_header.packet_type, PacketType::Response); let mut resp_body = vec![0u8; resp_header.fragment_length as usize - 16]; client.read_exact(&mut resp_body).await.unwrap(); assert_eq!(resp_body.len(), 20); assert_eq!(&resp_body[16..20], &S_OK.to_le_bytes()); server.shutdown().await; } #[tokio::test] async fn shutdown_terminates_accept_loop() { let (server, _events) = CallbackExporter::bind(local_addr(), fixed_identities()) .await .unwrap(); let addr = server.local_addr(); server.shutdown().await; // Subsequent connect should refuse (loop is gone, listener dropped). 
let res = tokio::time::timeout( std::time::Duration::from_millis(200), TcpStream::connect(addr), ) .await; // Either the timeout fires or connect returns an error — both are // acceptable evidence the listener stopped accepting. assert!(res.is_err() || res.unwrap().is_err()); } #[test] fn create_callback_objref_uses_callback_iid_and_port() { // Pure-codec test (no listener required for this part). let identities = fixed_identities(); let exporter_no_listener_objref = ComObjRefBuilder::create_standard_objref( INMX_SVC_CALLBACK_IID, 0x280, 5, identities.oxid, identities.oid, identities.callback_ipid, &["host[12345]"], ); // 8..24 is the IID assert_eq!( &exporter_no_listener_objref[8..24], INMX_SVC_CALLBACK_IID.as_bytes() ); } #[test] fn rem_query_interface_response_inspects_offset_60() { // Build a request stub with a known IID at offset 60..76 and verify // the encoder reads it back. let mut stub = vec![0u8; 76]; stub[60..76].copy_from_slice(INMX_SVC_CALLBACK_IID.as_bytes()); let identities = fixed_identities(); let (tx, _rx) = mpsc::unbounded_channel(); let body = encode_rem_query_interface_response(&stub, &identities, &tx); // hresult is at offset OrpcThat(8) + referent(4) + max_count(4) = 16. let hresult = i32::from_le_bytes([body[16], body[17], body[18], body[19]]); assert_eq!(hresult, S_OK); // Now with an unknown IID — should return E_NOINTERFACE. let mut stub2 = vec![0u8; 76]; stub2[60..76].copy_from_slice(&[0x99; 16]); let (tx2, _rx2) = mpsc::unbounded_channel(); let body2 = encode_rem_query_interface_response(&stub2, &identities, &tx2); let hresult2 = i32::from_le_bytes([body2[16], body2[17], body2[18], body2[19]]); assert_eq!(hresult2, E_NOINTERFACE); } }