diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 0760918..b98426f 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -20,6 +20,12 @@ dependencies = [ "hybrid-array", ] +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + [[package]] name = "cfg-if" version = "1.0.4" @@ -140,6 +146,17 @@ dependencies = [ "digest", ] +[[package]] +name = "mio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys", +] + [[package]] name = "mxaccess" version = "0.0.0" @@ -166,6 +183,9 @@ version = "0.0.0" dependencies = [ "mxaccess-codec", "mxaccess-rpc", + "rand", + "tokio", + "tracing", ] [[package]] @@ -207,6 +227,18 @@ dependencies = [ "thiserror", ] +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + [[package]] name = "ppv-lite86" version = "0.2.21" @@ -273,6 +305,16 @@ dependencies = [ "cipher", ] +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys", +] + [[package]] name = "subtle" version = "2.6.1" @@ -310,6 +352,63 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio" +version = "1.52.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "110a78583f19d5cdb2c5ccf321d1290344e71313c6c37d43520d386027d18386" +dependencies = [ + "bytes", + "libc", + "mio", + "pin-project-lite", + "socket2", + "tokio-macros", + "windows-sys", +] + +[[package]] +name = "tokio-macros" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + [[package]] name = "typenum" version = "1.20.0" @@ -334,6 +433,21 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + [[package]] name = "zerocopy" version = "0.8.48" diff --git a/rust/crates/mxaccess-callback/Cargo.toml b/rust/crates/mxaccess-callback/Cargo.toml index 2cb04a5..4b5f8ab 100644 --- a/rust/crates/mxaccess-callback/Cargo.toml +++ b/rust/crates/mxaccess-callback/Cargo.toml @@ -11,6 +11,9 @@ authors.workspace = true [dependencies] mxaccess-rpc = { path = "../mxaccess-rpc" } mxaccess-codec = { path = "../mxaccess-codec" } +tokio = { workspace = true } +tracing = { workspace = true } +rand = "0.8" [lints] workspace = true diff --git a/rust/crates/mxaccess-callback/src/exporter.rs b/rust/crates/mxaccess-callback/src/exporter.rs new file mode 100644 index 0000000..d1b10fa --- /dev/null +++ b/rust/crates/mxaccess-callback/src/exporter.rs @@ -0,0 +1,1098 @@ +//! `INmxSvcCallback` callback exporter — TCP server. +//! +//! Direct port of `src/MxNativeClient/ManagedCallbackExporter.cs`. Spins a +//! tokio TCP listener, accepts incoming DCE/RPC connections from +//! `NmxSvc.exe`, walks them through Bind / AlterContext / Request / +//! Auth3 PDUs, and dispatches `IRemUnknown` and `INmxSvcCallback` requests. +//! The .NET reference is the executable spec; every wire shape and +//! HRESULT cited inline. +//! +//! Two interfaces are served: +//! +//! - `IRemUnknown` (IID `00000131-0000-0000-C000-000000000046`, +//! `RemUnknownMessages.cs:7`) — opnums 3 (`RemQueryInterface`), +//! 4 (`RemAddRef`), 5 (`RemRelease`). The QI handler returns `S_OK` +//! for `IRemUnknown` / `INmxSvcCallback` / `IUnknown` and +//! `E_NOINTERFACE` (`0x80004002`) otherwise — mirrors +//! `ManagedCallbackExporter.cs:196-200`. +//! - `INmxSvcCallback` (IID `B49F92F7-C748-4169-8ECA-A0670B012746`, +//! `NmxProcedureMetadata.cs:6`) — opnums 3 (`DataReceived`), +//! 4 (`StatusReceived`). The handler decodes the inbound buffer via +//! [`crate::nmx_callback_messages::parse_callback_request`] (re-export +//! from `mxaccess-rpc::nmx_callback_messages`), emits a typed event, +//! and returns the success response built by +//! [`crate::nmx_callback_messages::encode_callback_response`]. +//! +//! Auth3 PDUs are accepted but ignored (`ManagedCallbackExporter.cs:133-137`) +//! — NTLM packet integrity for inbound frames is not yet wired (open +//! follow-up F2). + +#![allow(clippy::indexing_slicing)] + +use std::net::SocketAddr; + +use mxaccess_rpc::error::RpcError; +use mxaccess_rpc::guid::Guid; +use mxaccess_rpc::nmx_callback_messages; +use mxaccess_rpc::nmx_metadata::INMX_SVC_CALLBACK_IID; +use mxaccess_rpc::objref::ComObjRefBuilder; +use mxaccess_rpc::orpc::{OrpcThat, StdObjRef}; +use mxaccess_rpc::pdu::{ + BindPdu, FaultPdu, PacketType, PduHeader, RequestPdu, ResponsePdu, SyntaxId, +}; +use mxaccess_rpc::rem_unknown::{ + IREM_UNKNOWN_IID, REM_ADD_REF_OPNUM, REM_QUERY_INTERFACE_OPNUM, REM_RELEASE_OPNUM, +}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::{mpsc, oneshot}; +use tokio::task::JoinHandle; + +/// `IUnknown` IID `00000000-0000-0000-C000-000000000046`. The QI handler +/// also returns success for this in addition to `IRemUnknown` and +/// `INmxSvcCallback` (`ManagedCallbackExporter.cs:198`). +pub const IUNKNOWN_IID: Guid = Guid::new([ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46, +]); + +/// `S_OK` HRESULT. +const S_OK: i32 = 0; + +/// `E_NOINTERFACE` HRESULT — returned by `RemQueryInterface` for any IID we +/// don't speak (`ManagedCallbackExporter.cs:200`). +const E_NOINTERFACE: i32 = 0x8000_4002u32 as i32; + +/// `RPC_S_PROCNUM_OUT_OF_RANGE` (1783 = 0x6F7) — returned as a fault for any +/// request whose `(iid, opnum)` we don't dispatch +/// (`ManagedCallbackExporter.cs:171,184,187`). +const RPC_S_PROCNUM_OUT_OF_RANGE: u32 = 0x0000_06F7; + +/// Fixed PDU framing constants. The .NET reference asserts these via +/// hand-rolled writes at `ManagedCallbackExporter.cs:226-254`. +const BIND_ACK_MAX_FRAGMENT: u16 = 4280; +const BIND_ACK_ASSOC_GROUP_ID: u32 = 0x0000_5353; +const RESPONSE_PACKET_FLAGS: u8 = 0x03; +const RESPONSE_DATA_REPRESENTATION: u32 = 0x0000_0010; + +/// Identities the exporter publishes in OBJREFs and STDOBJREFs. +/// +/// The .NET reference generates these via `Guid.NewGuid()` and +/// `RandomNumberGenerator.Fill` (`ManagedCallbackExporter.cs:14-20`); the +/// Rust port mirrors that for production but exposes a [`fixed`] constructor +/// so tests can pin all four values for deterministic byte comparisons. +/// +/// [`fixed`]: ExporterIdentities::fixed +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ExporterIdentities { + pub oxid: u64, + pub oid: u64, + pub callback_ipid: Guid, + pub rem_unknown_ipid: Guid, +} + +impl ExporterIdentities { + /// Generate identities from `rand::random()` — production default. + /// Mirrors `ManagedCallbackExporter.cs:14-20` (`RandomUInt64` + + /// `Guid.NewGuid()` × 2). + #[must_use] + pub fn random() -> Self { + Self { + oxid: rand::random(), + oid: rand::random(), + callback_ipid: Guid::new(rand::random()), + rem_unknown_ipid: Guid::new(rand::random()), + } + } + + /// Construct with caller-supplied values. Useful in tests where + /// deterministic OBJREF byte comparisons are required. + #[must_use] + pub const fn fixed(oxid: u64, oid: u64, callback_ipid: Guid, rem_unknown_ipid: Guid) -> Self { + Self { + oxid, + oid, + callback_ipid, + rem_unknown_ipid, + } + } +} + +/// Diagnostic events emitted by the exporter as it serves connections. +/// +/// The .NET reference logs strings into a `List` (per +/// `ManagedCallbackExporter.cs:12,33-42,315-321`). The Rust port emits +/// typed events instead — same information, more useful to consumers +/// that want to assert structurally. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CallbackEvent { + /// A new TCP connection was accepted (`cs:78`). + ClientConnected { remote: SocketAddr }, + /// `accept()` returned an error (`cs:89-92`). + AcceptError { reason: String }, + /// Bind / AlterContext negotiated a presentation context (`cs:121`). + Bind { context_id: u16, iid: Guid }, + /// Auth3 PDU received and ignored (`cs:133-137`). + Auth3Ignored, + /// Request PDU received (`cs:162`). + Request { + iid: Guid, + context_id: u16, + opnum: u16, + stub_len: usize, + }, + /// `IRemUnknown::RemQueryInterface` invoked (`cs:191-200`). + RemQueryInterface { requested_iid: Guid, hresult: i32 }, + /// `INmxSvcCallback::DataReceived` or `StatusReceived` invoked + /// (`cs:177-183`). + CallbackInvoked { opnum: u16, body: Vec }, + /// Request whose `(iid, opnum)` was not dispatched — fault returned + /// (`cs:171,184,187`). + UnhandledRequest { iid: Guid, opnum: u16 }, + /// Client closed the connection (`cs:107-110`). + ClientDisconnected, + /// PDU header parse / fragment-length / packet-type error. + ProtocolError { reason: String }, +} + +/// The TCP exporter handle. Drop aborts the accept task. +/// +/// Construction is asynchronous because [`tokio::net::TcpListener::bind`] +/// is. Use [`CallbackExporter::bind`] to start; the returned tuple includes +/// an [`mpsc::UnboundedReceiver`] that streams diagnostic +/// events as the server runs. +pub struct CallbackExporter { + local_addr: SocketAddr, + identities: ExporterIdentities, + shutdown_tx: Option>, + accept_handle: Option>, +} + +impl CallbackExporter { + /// Bind to `addr` (use port 0 for an OS-assigned ephemeral port) and + /// start the accept loop. Returns the handle plus the diagnostic + /// event stream. + /// + /// # Errors + /// Returns [`std::io::Error`] from [`TcpListener::bind`]. + pub async fn bind( + addr: SocketAddr, + identities: ExporterIdentities, + ) -> std::io::Result<(Self, mpsc::UnboundedReceiver)> { + let listener = TcpListener::bind(addr).await?; + let local_addr = listener.local_addr()?; + let (event_tx, event_rx) = mpsc::unbounded_channel(); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + let accept_handle = tokio::spawn(accept_loop(listener, identities, event_tx, shutdown_rx)); + + Ok(( + Self { + local_addr, + identities, + shutdown_tx: Some(shutdown_tx), + accept_handle: Some(accept_handle), + }, + event_rx, + )) + } + + /// Local address the listener is bound to. + #[must_use] + pub fn local_addr(&self) -> SocketAddr { + self.local_addr + } + + /// Identities this exporter publishes. + #[must_use] + pub fn identities(&self) -> ExporterIdentities { + self.identities + } + + /// Build a callback OBJREF to publish back to the AVEVA service. + /// + /// Mirrors `ManagedCallbackExporter.CreateCallbackObjRef` + /// (`cs:44-54`): the IID is `INmxSvcCallback`, `std_flags = 0x280`, + /// `public_refs = 5`, OXID/OID/IPID come from `self.identities`, and + /// the single string binding is `"[]"`. + #[must_use] + pub fn create_callback_objref(&self, hostname: &str) -> Vec { + let binding = format!("{hostname}[{port}]", port = self.local_addr.port()); + ComObjRefBuilder::create_standard_objref( + INMX_SVC_CALLBACK_IID, + 0x280, + 5, + self.identities.oxid, + self.identities.oid, + self.identities.callback_ipid, + &[binding.as_str()], + ) + } + + /// Gracefully shut down the accept loop. Returns once the loop task + /// exits. Any in-flight connections continue until they EOF on read. + pub async fn shutdown(mut self) { + if let Some(tx) = self.shutdown_tx.take() { + // Receiver dropped is fine; the loop is already exiting. + let _ = tx.send(()); + } + if let Some(handle) = self.accept_handle.take() { + let _ = handle.await; + } + } +} + +impl Drop for CallbackExporter { + fn drop(&mut self) { + if let Some(tx) = self.shutdown_tx.take() { + // Best-effort signal; if the loop is gone the receiver will be + // dropped which is also fine. + let _ = tx.send(()); + } + if let Some(handle) = self.accept_handle.take() { + handle.abort(); + } + } +} + +async fn accept_loop( + listener: TcpListener, + identities: ExporterIdentities, + event_tx: mpsc::UnboundedSender, + mut shutdown_rx: oneshot::Receiver<()>, +) { + loop { + tokio::select! { + _ = &mut shutdown_rx => return, + accept = listener.accept() => match accept { + Ok((stream, remote)) => { + let _ = event_tx.send(CallbackEvent::ClientConnected { remote }); + let event_tx = event_tx.clone(); + tokio::spawn(serve_client(stream, identities, event_tx)); + } + Err(err) => { + let _ = event_tx.send(CallbackEvent::AcceptError { reason: err.to_string() }); + // Listener errors are usually fatal (e.g. socket closed). + return; + } + }, + } + } +} + +async fn serve_client( + mut stream: TcpStream, + identities: ExporterIdentities, + event_tx: mpsc::UnboundedSender, +) { + // The .NET reference threads `currentContext` through `EncodeBindAck` + // but the encoder discards it (`ManagedCallbackExporter.cs:252: + // `_ = contextId;`), so we simply record the negotiated context id in + // the Bind event and read `request.context_id` per-request. + let mut current_iid: Guid = Guid::ZERO; + + loop { + let pdu = match read_pdu(&mut stream).await { + Ok(Some(p)) => p, + Ok(None) => { + let _ = event_tx.send(CallbackEvent::ClientDisconnected); + return; + } + Err(err) => { + let _ = event_tx.send(CallbackEvent::ProtocolError { + reason: err.to_string(), + }); + return; + } + }; + + let header = match PduHeader::decode(&pdu) { + Ok(h) => h, + Err(err) => { + let _ = event_tx.send(CallbackEvent::ProtocolError { + reason: format!("{err}"), + }); + return; + } + }; + + match header.packet_type { + PacketType::Bind | PacketType::AlterContext => match BindPdu::decode(&pdu) { + Ok(bind) => { + let context_id = if let Some(first) = bind.presentation_contexts.first() { + current_iid = Guid::new(first.abstract_syntax.uuid_bytes); + first.context_id + } else { + current_iid = Guid::ZERO; + 0 + }; + let _ = event_tx.send(CallbackEvent::Bind { + context_id, + iid: current_iid, + }); + let response = encode_bind_ack(header.call_id); + if stream.write_all(&response).await.is_err() { + return; + } + } + Err(err) => { + let _ = event_tx.send(CallbackEvent::ProtocolError { + reason: format!("{err}"), + }); + return; + } + }, + PacketType::Request => { + let request = match RequestPdu::decode(&pdu) { + Ok(r) => r, + Err(err) => { + let _ = event_tx.send(CallbackEvent::ProtocolError { + reason: format!("{err}"), + }); + return; + } + }; + let _ = event_tx.send(CallbackEvent::Request { + iid: current_iid, + context_id: request.context_id, + opnum: request.opnum, + stub_len: request.stub_data.len(), + }); + let response_bytes = handle_request(&request, current_iid, &identities, &event_tx); + if stream.write_all(&response_bytes).await.is_err() { + return; + } + } + PacketType::Auth3 => { + let _ = event_tx.send(CallbackEvent::Auth3Ignored); + // No response per `cs:133-137`. + } + other => { + let _ = event_tx.send(CallbackEvent::ProtocolError { + reason: format!("unhandled PDU type {other:?}"), + }); + return; + } + } + } +} + +fn handle_request( + request: &RequestPdu, + current_iid: Guid, + identities: &ExporterIdentities, + event_tx: &mpsc::UnboundedSender, +) -> Vec { + if current_iid == IREM_UNKNOWN_IID { + match request.opnum { + REM_QUERY_INTERFACE_OPNUM => { + let response_body = + encode_rem_query_interface_response(&request.stub_data, identities, event_tx); + return wrap_response(request.header.call_id, request.context_id, response_body); + } + REM_ADD_REF_OPNUM | REM_RELEASE_OPNUM => { + return wrap_response( + request.header.call_id, + request.context_id, + encode_orpc_hresult_response(S_OK).to_vec(), + ); + } + _ => { + let _ = event_tx.send(CallbackEvent::UnhandledRequest { + iid: current_iid, + opnum: request.opnum, + }); + return encode_fault( + request.header.call_id, + request.context_id, + RPC_S_PROCNUM_OUT_OF_RANGE, + ); + } + } + } + + if current_iid == INMX_SVC_CALLBACK_IID { + if request.opnum == nmx_callback_messages::DATA_RECEIVED_OPNUM + || request.opnum == nmx_callback_messages::STATUS_RECEIVED_OPNUM + { + // Decode the callback body for the event stream; if it fails to + // parse, still return success per the .NET reference (`cs:179-181`) + // — diagnostic info goes through `ProtocolError` instead. + let body = match nmx_callback_messages::parse_callback_request(&request.stub_data) { + Ok(parsed) => parsed.body, + Err(err) => { + let _ = event_tx.send(CallbackEvent::ProtocolError { + reason: format!("callback request decode: {err}"), + }); + Vec::new() + } + }; + let _ = event_tx.send(CallbackEvent::CallbackInvoked { + opnum: request.opnum, + body, + }); + let resp = nmx_callback_messages::encode_callback_response(S_OK); + return wrap_response(request.header.call_id, request.context_id, resp.to_vec()); + } + + let _ = event_tx.send(CallbackEvent::UnhandledRequest { + iid: current_iid, + opnum: request.opnum, + }); + return encode_fault( + request.header.call_id, + request.context_id, + RPC_S_PROCNUM_OUT_OF_RANGE, + ); + } + + let _ = event_tx.send(CallbackEvent::UnhandledRequest { + iid: current_iid, + opnum: request.opnum, + }); + encode_fault( + request.header.call_id, + request.context_id, + RPC_S_PROCNUM_OUT_OF_RANGE, + ) +} + +/// Build the `RemQueryInterface` response stub. Mirrors +/// `ManagedCallbackExporter.EncodeRemQueryInterfaceResponse` +/// (`cs:190-216`). +fn encode_rem_query_interface_response( + request_stub: &[u8], + identities: &ExporterIdentities, + event_tx: &mpsc::UnboundedSender, +) -> Vec { + // Requested IID at offset 60..76 of the request body (`cs:192`). If the + // stub is short, default to ZERO — same as .NET `Guid.Empty` fallback. + let requested_iid = if request_stub.len() >= 76 { + Guid::parse(&request_stub[60..76]).unwrap_or(Guid::ZERO) + } else { + Guid::ZERO + }; + + // Pick the IPID to publish (`cs:195`): RemUnknown gets RemUnknownIpid; + // anything else (including the success cases for callback / IUnknown) + // gets CallbackIpid. + let ipid = if requested_iid == IREM_UNKNOWN_IID { + identities.rem_unknown_ipid + } else { + identities.callback_ipid + }; + + let std = StdObjRef { + flags: 0x280, + public_refs: 5, + oxid: identities.oxid, + oid: identities.oid, + ipid, + }; + + let hresult = if requested_iid == IREM_UNKNOWN_IID + || requested_iid == INMX_SVC_CALLBACK_IID + || requested_iid == IUNKNOWN_IID + { + S_OK + } else { + E_NOINTERFACE + }; + + let _ = event_tx.send(CallbackEvent::RemQueryInterface { + requested_iid, + hresult, + }); + + // Layout (cs:202-215): + // 0..8 OrpcThat (zeroed) + // 8..12 referent_id u32 LE = 0x00020000 + // 12..16 max_count u32 LE = 1 + // 16..20 hresult i32 LE + // 20..24 4-byte NDR pad ahead of STDOBJREF (zero) + // 24..64 STDOBJREF (40 bytes) + // 64..68 error_code u32 LE = 0 + let mut buf = vec![0u8; OrpcThat::ENCODED_LEN + 4 + 4 + 4 + 4 + StdObjRef::ENCODED_LEN + 4]; + let orpc_that = OrpcThat::default().encode(); + buf[..OrpcThat::ENCODED_LEN].copy_from_slice(&orpc_that); + let mut off = OrpcThat::ENCODED_LEN; + buf[off..off + 4].copy_from_slice(&0x0002_0000u32.to_le_bytes()); + off += 4; + buf[off..off + 4].copy_from_slice(&1u32.to_le_bytes()); + off += 4; + buf[off..off + 4].copy_from_slice(&hresult.to_le_bytes()); + off += 4; + // 4-byte pad (zero). + off += 4; + let std_bytes = std.encode(); + buf[off..off + StdObjRef::ENCODED_LEN].copy_from_slice(&std_bytes); + off += StdObjRef::ENCODED_LEN; + buf[off..off + 4].copy_from_slice(&0u32.to_le_bytes()); + buf +} + +/// 12-byte simple HRESULT response — `OrpcThat(8) + hresult(4)`. Mirrors +/// `EncodeOrpcHResultResponse` (`cs:218-224`). +fn encode_orpc_hresult_response(hresult: i32) -> [u8; 12] { + let mut buf = [0u8; 12]; + buf[..OrpcThat::ENCODED_LEN].copy_from_slice(&OrpcThat::default().encode()); + buf[OrpcThat::ENCODED_LEN..].copy_from_slice(&hresult.to_le_bytes()); + buf +} + +/// Wrap a stub-data body in a Response PDU. Mirrors `EncodeResponse` +/// (`cs:256-267`). +fn wrap_response(call_id: u32, context_id: u16, stub_data: Vec) -> Vec { + let header = PduHeader { + version: 5, + version_minor: 0, + packet_type: PacketType::Response, + packet_flags: RESPONSE_PACKET_FLAGS, + data_representation: RESPONSE_DATA_REPRESENTATION, + fragment_length: 0, // overwritten by RequestPdu::encode + auth_length: 0, + call_id, + }; + ResponsePdu { + header, + allocation_hint: u32::try_from(stub_data.len()).unwrap_or(u32::MAX), + context_id, + cancel_count: 0, + reserved23: 0, + stub_data, + } + .encode() +} + +/// Build a Fault PDU. Mirrors `EncodeFault` (`cs:269-277`). +fn encode_fault(call_id: u32, context_id: u16, status: u32) -> Vec { + let header = PduHeader { + version: 5, + version_minor: 0, + packet_type: PacketType::Fault, + packet_flags: RESPONSE_PACKET_FLAGS, + data_representation: RESPONSE_DATA_REPRESENTATION, + fragment_length: 0, // overwritten by FaultPdu::encode + auth_length: 0, + call_id, + }; + FaultPdu { + header, + allocation_hint: 0, + context_id, + cancel_count: 0, + reserved23: 0, + status, + stub_data: Vec::new(), + } + .encode() +} + +/// Build the BindAck PDU. Hand-rolled — neither `mxaccess-rpc::pdu` nor the +/// .NET `DceRpcBindPdu` class encode this shape; the original code at +/// `ManagedCallbackExporter.cs:226-254` writes it byte-by-byte with a single +/// presentation-context acceptance entry pointing at the NDR20 transfer +/// syntax. +/// +/// Layout per `[C706]` §12.6.4.4 + `cs:226-254`: +/// +/// ```text +/// offset size field +/// 0 16 PduHeader (ptype=BindAck=12, flags=0x03, drep=0x10) +/// 16 2 max_xmit_fragment u16 LE = 4280 +/// 18 2 max_recv_fragment u16 LE = 4280 +/// 20 4 assoc_group_id u32 LE = 0x5353 +/// 24 2 sec_addr_length u16 LE +/// 26 sl secondary_address (UTF-8 + null) +/// 26+sl p pad to 4-byte alignment from start of pdu +/// resOff 1 n_results = 1 +/// resOff+1 1 reserved +/// resOff+2 2 reserved2 u16 LE +/// resOff+4 2 result u16 LE = 0 (acceptance) +/// resOff+6 2 reason u16 LE = 0 +/// resOff+8 16 transfer_syntax_uuid (NDR20) +/// resOff+24 2 transfer_syntax_version_major u16 LE +/// resOff+26 2 transfer_syntax_version_minor u16 LE +/// ``` +fn encode_bind_ack(call_id: u32) -> Vec { + // .NET writes a single-byte secondary address: empty string + "\0". + const SECONDARY: &[u8] = b"\0"; + let sec_addr_length = SECONDARY.len(); + // Pad to 4-byte alignment from the start of the PDU. + let unpadded = 28 + sec_addr_length; + let pad = align_up(unpadded, 4) - unpadded; + let result_offset = unpadded + pad; + // 4 bytes (n_results + reserved + reserved2) + 4 (result + reason) + 20 + // (NDR20 syntax id) = 28 bytes after result_offset. + let length = result_offset + 4 + 24; + let mut pdu = vec![0u8; length]; + + let header = PduHeader { + version: 5, + version_minor: 0, + packet_type: PacketType::BindAck, + packet_flags: RESPONSE_PACKET_FLAGS, + data_representation: RESPONSE_DATA_REPRESENTATION, + fragment_length: u16::try_from(length).unwrap_or(u16::MAX), + auth_length: 0, + call_id, + }; + let _ = header.encode(&mut pdu); + + pdu[16..18].copy_from_slice(&BIND_ACK_MAX_FRAGMENT.to_le_bytes()); + pdu[18..20].copy_from_slice(&BIND_ACK_MAX_FRAGMENT.to_le_bytes()); + pdu[20..24].copy_from_slice(&BIND_ACK_ASSOC_GROUP_ID.to_le_bytes()); + pdu[24..26].copy_from_slice( + &u16::try_from(sec_addr_length) + .unwrap_or(u16::MAX) + .to_le_bytes(), + ); + pdu[26..26 + SECONDARY.len()].copy_from_slice(SECONDARY); + + let mut o = result_offset; + pdu[o] = 1; // n_results + pdu[o + 1] = 0; // reserved + pdu[o + 2] = 0; // reserved2 LSB + pdu[o + 3] = 0; // reserved2 MSB + o += 4; + // result (=0 acceptance) + reason (=0). + pdu[o..o + 2].copy_from_slice(&0u16.to_le_bytes()); + pdu[o + 2..o + 4].copy_from_slice(&0u16.to_le_bytes()); + o += 4; + let ndr = SyntaxId::NDR20; + pdu[o..o + 16].copy_from_slice(&ndr.uuid_bytes); + pdu[o + 16..o + 18].copy_from_slice(&ndr.version_major.to_le_bytes()); + pdu[o + 18..o + 20].copy_from_slice(&ndr.version_minor.to_le_bytes()); + + pdu +} + +const fn align_up(value: usize, alignment: usize) -> usize { + let r = value % alignment; + if r == 0 { value } else { value + alignment - r } +} + +/// Read one full PDU from `stream`. Returns `Ok(None)` on clean EOF before +/// the first byte. Mirrors `ReadPduAsync` + `ReadExactAsync` +/// (`cs:279-313`). +async fn read_pdu(stream: &mut TcpStream) -> Result>, RpcError> { + let mut header_buf = [0u8; PduHeader::LENGTH]; + match read_exact_or_eof(stream, &mut header_buf).await { + Ok(true) => {} + Ok(false) => return Ok(None), + Err(_) => return Ok(None), + } + let header = PduHeader::decode(&header_buf)?; + let frag = header.fragment_length as usize; + if frag < PduHeader::LENGTH { + return Err(RpcError::InvalidFragmentLength { + frag_length: frag, + buffer_len: 0, + auth_length: header.auth_length as usize, + }); + } + let mut pdu = vec![0u8; frag]; + pdu[..PduHeader::LENGTH].copy_from_slice(&header_buf); + if frag > PduHeader::LENGTH { + match read_exact_or_eof(stream, &mut pdu[PduHeader::LENGTH..]).await { + Ok(true) => {} + Ok(_) | Err(_) => return Ok(None), + } + } + Ok(Some(pdu)) +} + +/// Like `AsyncReadExt::read_exact` but treats clean EOF *before any byte was +/// read* as `Ok(false)` so the caller can distinguish "client closed before +/// the next PDU" from "client closed mid-PDU." +async fn read_exact_or_eof(stream: &mut TcpStream, buf: &mut [u8]) -> std::io::Result { + let mut filled = 0; + while filled < buf.len() { + let n = stream.read(&mut buf[filled..]).await?; + if n == 0 { + return Ok(filled != 0 && filled == buf.len()); + } + filled += n; + } + Ok(true) +} + +#[cfg(test)] +#[allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::indexing_slicing, + clippy::panic +)] +mod tests { + use super::*; + use mxaccess_rpc::orpc::OrpcThis; + use mxaccess_rpc::pdu::{BindPdu, PresentationContext, SyntaxId}; + use mxaccess_rpc::rem_unknown::{ + IREM_UNKNOWN_IID, REM_QUERY_INTERFACE_OPNUM, encode_rem_query_interface_request, + }; + + fn fixed_identities() -> ExporterIdentities { + ExporterIdentities::fixed( + 0x1111_2222_3333_4444, + 0x5555_6666_7777_8888, + Guid::new([0xCC; 16]), + Guid::new([0xDD; 16]), + ) + } + + fn local_addr() -> SocketAddr { + "127.0.0.1:0".parse().unwrap() + } + + fn build_bind_pdu(call_id: u32, abstract_iid: Guid) -> Vec { + let header = PduHeader { + version: 5, + version_minor: 0, + packet_type: PacketType::Bind, + packet_flags: 0x03, + data_representation: RESPONSE_DATA_REPRESENTATION, + fragment_length: 0, // BindPdu::encode rewrites it + auth_length: 0, + call_id, + }; + BindPdu { + header, + max_transmit_fragment: BIND_ACK_MAX_FRAGMENT, + max_receive_fragment: BIND_ACK_MAX_FRAGMENT, + association_group_id: 0, + presentation_contexts: vec![PresentationContext { + context_id: 0, + abstract_syntax: SyntaxId { + uuid_bytes: *abstract_iid.as_bytes(), + version_major: 0, + version_minor: 0, + }, + transfer_syntaxes: vec![SyntaxId::NDR20], + }], + reserved25_28: [0; 3], + } + .encode() + } + + fn build_request_pdu(call_id: u32, opnum: u16, stub: Vec) -> Vec { + let header = PduHeader { + version: 5, + version_minor: 0, + packet_type: PacketType::Request, + packet_flags: 0x03, + data_representation: RESPONSE_DATA_REPRESENTATION, + fragment_length: 0, + auth_length: 0, + call_id, + }; + RequestPdu { + header, + allocation_hint: stub.len() as u32, + context_id: 0, + opnum, + stub_data: stub, + } + .encode() + } + + async fn collect_event(rx: &mut mpsc::UnboundedReceiver) -> CallbackEvent { + tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("event timeout") + .expect("event channel closed") + } + + #[test] + fn iunknown_iid_constant_matches_dotnet() { + // .NET `new Guid("00000000-0000-0000-C000-000000000046").ToString("D")`. + assert_eq!( + IUNKNOWN_IID.to_string(), + "00000000-0000-0000-c000-000000000046" + ); + } + + #[test] + fn align_up_matches_dotnet_align() { + assert_eq!(align_up(0, 4), 0); + assert_eq!(align_up(1, 4), 4); + assert_eq!(align_up(28, 4), 28); + assert_eq!(align_up(29, 4), 32); + assert_eq!(align_up(31, 4), 32); + } + + #[test] + fn encode_bind_ack_layout() { + let pdu = encode_bind_ack(42); + // PDU header decodes back. + let header = PduHeader::decode(&pdu).unwrap(); + assert_eq!(header.packet_type, PacketType::BindAck); + assert_eq!(header.call_id, 42); + assert_eq!(header.fragment_length as usize, pdu.len()); + + // Max fragments at 16/18. + assert_eq!(u16::from_le_bytes([pdu[16], pdu[17]]), 4280); + assert_eq!(u16::from_le_bytes([pdu[18], pdu[19]]), 4280); + // Assoc group id at 20. + assert_eq!( + u32::from_le_bytes([pdu[20], pdu[21], pdu[22], pdu[23]]), + 0x5353 + ); + // Empty secondary address (length=1, single null byte). + assert_eq!(u16::from_le_bytes([pdu[24], pdu[25]]), 1); + assert_eq!(pdu[26], 0); + + // Result-list n_results=1 at the aligned offset (32). + // unpadded = 28 + 1 = 29; pad = 3; result_offset = 32. + let result_offset = 32; + assert_eq!(pdu[result_offset], 1); + + // Acceptance + NDR20 syntax follow. + let syntax_offset = result_offset + 4 + 4; + assert_eq!( + &pdu[syntax_offset..syntax_offset + 16], + &SyntaxId::NDR20.uuid_bytes + ); + } + + #[test] + fn encode_orpc_hresult_response_layout() { + let r = encode_orpc_hresult_response(0); + assert_eq!(r.len(), 12); + assert_eq!(&r[..8], &[0u8; 8]); + assert_eq!(&r[8..], &0i32.to_le_bytes()); + } + + #[tokio::test] + async fn bind_request_round_trip_via_real_socket() { + let (server, mut events) = CallbackExporter::bind(local_addr(), fixed_identities()) + .await + .unwrap(); + let addr = server.local_addr(); + + let mut client = TcpStream::connect(addr).await.unwrap(); + + // 1. Send a Bind PDU advertising IRemUnknown abstract syntax. + let bind = build_bind_pdu(1, IREM_UNKNOWN_IID); + client.write_all(&bind).await.unwrap(); + + // Server emits ClientConnected then Bind. + let connected = collect_event(&mut events).await; + assert!(matches!(connected, CallbackEvent::ClientConnected { .. })); + let bind_event = collect_event(&mut events).await; + match bind_event { + CallbackEvent::Bind { context_id, iid } => { + assert_eq!(context_id, 0); + assert_eq!(iid, IREM_UNKNOWN_IID); + } + other => panic!("expected Bind, got {other:?}"), + } + + // Read the BindAck back (16 bytes header, then frag_length determines the body). + let mut header_buf = [0u8; 16]; + client.read_exact(&mut header_buf).await.unwrap(); + let header = PduHeader::decode(&header_buf).unwrap(); + assert_eq!(header.packet_type, PacketType::BindAck); + let mut body = vec![0u8; header.fragment_length as usize - 16]; + client.read_exact(&mut body).await.unwrap(); + + // 2. Send a RemQueryInterface request asking for IUnknown. + let qi_request_body = encode_rem_query_interface_request( + Guid::ZERO, // source ipid (unused by server) + IUNKNOWN_IID, // requested IID — should resolve to S_OK + Guid::new([0xAA; 16]), // causality id + 5, + ); + let qi_pdu = build_request_pdu(2, REM_QUERY_INTERFACE_OPNUM, qi_request_body); + client.write_all(&qi_pdu).await.unwrap(); + + // Server emits Request then RemQueryInterface event. + let req_event = collect_event(&mut events).await; + match req_event { + CallbackEvent::Request { iid, opnum, .. } => { + assert_eq!(iid, IREM_UNKNOWN_IID); + assert_eq!(opnum, REM_QUERY_INTERFACE_OPNUM); + } + other => panic!("expected Request, got {other:?}"), + } + let qi_event = collect_event(&mut events).await; + match qi_event { + CallbackEvent::RemQueryInterface { + requested_iid, + hresult, + } => { + assert_eq!(requested_iid, IUNKNOWN_IID); + assert_eq!(hresult, S_OK); + } + other => panic!("expected RemQueryInterface, got {other:?}"), + } + + // Read response PDU header back. + client.read_exact(&mut header_buf).await.unwrap(); + let resp_header = PduHeader::decode(&header_buf).unwrap(); + assert_eq!(resp_header.packet_type, PacketType::Response); + let mut resp_body = vec![0u8; resp_header.fragment_length as usize - 16]; + client.read_exact(&mut resp_body).await.unwrap(); + + server.shutdown().await; + } + + #[tokio::test] + async fn unknown_opnum_returns_fault() { + let (server, mut events) = CallbackExporter::bind(local_addr(), fixed_identities()) + .await + .unwrap(); + let mut client = TcpStream::connect(server.local_addr()).await.unwrap(); + + let bind = build_bind_pdu(1, IREM_UNKNOWN_IID); + client.write_all(&bind).await.unwrap(); + // Drain BindAck. + let mut header_buf = [0u8; 16]; + client.read_exact(&mut header_buf).await.unwrap(); + let header = PduHeader::decode(&header_buf).unwrap(); + let mut body = vec![0u8; header.fragment_length as usize - 16]; + client.read_exact(&mut body).await.unwrap(); + + // Drain Connected/Bind events. + let _ = collect_event(&mut events).await; + let _ = collect_event(&mut events).await; + + // Send a request with opnum 99 (unknown). + let req = build_request_pdu(2, 99, vec![0u8; 76]); + client.write_all(&req).await.unwrap(); + + // Drain Request and UnhandledRequest events. + let _ = collect_event(&mut events).await; + let unhandled = collect_event(&mut events).await; + match unhandled { + CallbackEvent::UnhandledRequest { iid, opnum } => { + assert_eq!(iid, IREM_UNKNOWN_IID); + assert_eq!(opnum, 99); + } + other => panic!("expected UnhandledRequest, got {other:?}"), + } + + client.read_exact(&mut header_buf).await.unwrap(); + let resp_header = PduHeader::decode(&header_buf).unwrap(); + assert_eq!(resp_header.packet_type, PacketType::Fault); + + server.shutdown().await; + } + + #[tokio::test] + async fn callback_invocation_emits_event_and_acks() { + let (server, mut events) = CallbackExporter::bind(local_addr(), fixed_identities()) + .await + .unwrap(); + let mut client = TcpStream::connect(server.local_addr()).await.unwrap(); + + // Bind to INmxSvcCallback. + let bind = build_bind_pdu(1, INMX_SVC_CALLBACK_IID); + client.write_all(&bind).await.unwrap(); + let mut header_buf = [0u8; 16]; + client.read_exact(&mut header_buf).await.unwrap(); + let header = PduHeader::decode(&header_buf).unwrap(); + let mut body = vec![0u8; header.fragment_length as usize - 16]; + client.read_exact(&mut body).await.unwrap(); + + let _ = collect_event(&mut events).await; // ClientConnected + let _ = collect_event(&mut events).await; // Bind + + // Build a callback request: OrpcThis + size + max_count + body. + let cid = Guid::new([0xEE; 16]); + let mut stub = Vec::new(); + stub.extend_from_slice(&OrpcThis::create(cid, None).encode()); + let payload: &[u8] = &[0xDE, 0xAD, 0xBE, 0xEF]; + stub.extend_from_slice(&(payload.len() as i32).to_le_bytes()); + stub.extend_from_slice(&(payload.len() as i32).to_le_bytes()); + stub.extend_from_slice(payload); + + let req = build_request_pdu(2, nmx_callback_messages::DATA_RECEIVED_OPNUM, stub); + client.write_all(&req).await.unwrap(); + + let _req_event = collect_event(&mut events).await; // Request + let cb_event = collect_event(&mut events).await; + match cb_event { + CallbackEvent::CallbackInvoked { opnum, body } => { + assert_eq!(opnum, nmx_callback_messages::DATA_RECEIVED_OPNUM); + assert_eq!(body, payload); + } + other => panic!("expected CallbackInvoked, got {other:?}"), + } + + // Drain the Response. PDU = 16-byte header + 8 response fields + // (allocation_hint, context_id, cancel_count, reserved23) + stub. + // For the callback success path the stub is OrpcThat(8) + hresult(4) = 12, + // so resp_body (= frag_length - 16) is 8 + 12 = 20 bytes; hresult + // sits at resp_body[16..20]. + client.read_exact(&mut header_buf).await.unwrap(); + let resp_header = PduHeader::decode(&header_buf).unwrap(); + assert_eq!(resp_header.packet_type, PacketType::Response); + let mut resp_body = vec![0u8; resp_header.fragment_length as usize - 16]; + client.read_exact(&mut resp_body).await.unwrap(); + assert_eq!(resp_body.len(), 20); + assert_eq!(&resp_body[16..20], &S_OK.to_le_bytes()); + + server.shutdown().await; + } + + #[tokio::test] + async fn shutdown_terminates_accept_loop() { + let (server, _events) = CallbackExporter::bind(local_addr(), fixed_identities()) + .await + .unwrap(); + let addr = server.local_addr(); + server.shutdown().await; + + // Subsequent connect should refuse (loop is gone, listener dropped). + let res = tokio::time::timeout( + std::time::Duration::from_millis(200), + TcpStream::connect(addr), + ) + .await; + // Either the timeout fires or connect returns an error — both are + // acceptable evidence the listener stopped accepting. + assert!(res.is_err() || res.unwrap().is_err()); + } + + #[test] + fn create_callback_objref_uses_callback_iid_and_port() { + // Pure-codec test (no listener required for this part). + let identities = fixed_identities(); + let exporter_no_listener_objref = ComObjRefBuilder::create_standard_objref( + INMX_SVC_CALLBACK_IID, + 0x280, + 5, + identities.oxid, + identities.oid, + identities.callback_ipid, + &["host[12345]"], + ); + // 8..24 is the IID + assert_eq!( + &exporter_no_listener_objref[8..24], + INMX_SVC_CALLBACK_IID.as_bytes() + ); + } + + #[test] + fn rem_query_interface_response_inspects_offset_60() { + // Build a request stub with a known IID at offset 60..76 and verify + // the encoder reads it back. + let mut stub = vec![0u8; 76]; + stub[60..76].copy_from_slice(INMX_SVC_CALLBACK_IID.as_bytes()); + let identities = fixed_identities(); + let (tx, _rx) = mpsc::unbounded_channel(); + let body = encode_rem_query_interface_response(&stub, &identities, &tx); + // hresult is at offset OrpcThat(8) + referent(4) + max_count(4) = 16. + let hresult = i32::from_le_bytes([body[16], body[17], body[18], body[19]]); + assert_eq!(hresult, S_OK); + + // Now with an unknown IID — should return E_NOINTERFACE. + let mut stub2 = vec![0u8; 76]; + stub2[60..76].copy_from_slice(&[0x99; 16]); + let (tx2, _rx2) = mpsc::unbounded_channel(); + let body2 = encode_rem_query_interface_response(&stub2, &identities, &tx2); + let hresult2 = i32::from_le_bytes([body2[16], body2[17], body2[18], body2[19]]); + assert_eq!(hresult2, E_NOINTERFACE); + } +} diff --git a/rust/crates/mxaccess-callback/src/lib.rs b/rust/crates/mxaccess-callback/src/lib.rs index 8566dd9..cc11588 100644 --- a/rust/crates/mxaccess-callback/src/lib.rs +++ b/rust/crates/mxaccess-callback/src/lib.rs @@ -1,6 +1,9 @@ //! `mxaccess-callback` — `INmxSvcCallback` RPC server (the callback exporter). //! -//! M0 stub. Real implementation lands in M2 — see `design/60-roadmap.md`. +//! M2 wave 3 landed: the [`exporter`] module ports +//! `src/MxNativeClient/ManagedCallbackExporter.cs` to a tokio-based TCP +//! server that serves `IRemUnknown` and `INmxSvcCallback` opnums and emits +//! typed [`exporter::CallbackEvent`]s for diagnostic observation. //! //! Opnums (verified against `src/MxNativeClient/NmxSvcCallbackMessages.cs:11-12`): //! - `3` `DataReceived(bufferSize: i32, dataBuffer: sbyte[bufferSize]) -> hresult` @@ -10,3 +13,7 @@ //! server-side handshake against our exported OBJREF (DoD condition for M2). #![forbid(unsafe_code)] + +pub mod exporter; + +pub use exporter::{CallbackEvent, CallbackExporter, ExporterIdentities, IUNKNOWN_IID};