diff --git a/clients/rust/Cargo.lock b/clients/rust/Cargo.lock index 78d5aa4..a0ead13 100644 --- a/clients/rust/Cargo.lock +++ b/clients/rust/Cargo.lock @@ -145,6 +145,16 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +[[package]] +name = "cc" +version = "1.2.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d" +dependencies = [ + "find-msvc-tools", + "shlex", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -225,6 +235,12 @@ version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + [[package]] name = "fixedbitset" version = "0.5.7" @@ -258,6 +274,17 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.32" @@ -277,11 +304,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-core", + "futures-macro", "futures-task", "pin-project-lite", "slab", ] +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "getrandom" version = "0.4.2" @@ -537,11 +576,14 @@ checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" name = "mxgateway-client" version = "0.1.0" dependencies = [ + "futures-core", + "futures-util", "prost", "prost-types", "serde_json", "thiserror", "tokio", + "tokio-stream", "tonic", "tonic-build", ] @@ -551,8 +593,11 @@ name = "mxgw-cli" version = "0.1.0" dependencies = [ "clap", + "futures-util", "mxgateway-client", + "serde", "serde_json", + "tokio", ] [[package]] @@ -724,6 +769,20 @@ version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rustix" version = "1.1.4" @@ -737,6 +796,41 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustls" +version = "0.23.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c2c118cb077cca2822033836dfb1b975355dfb784b5e8da48f7b6c5db74e60e" +dependencies = [ + "log", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "semver" version = "1.0.28" @@ -750,6 +844,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ "serde_core", + "serde_derive", ] [[package]] @@ -785,6 +880,12 @@ dependencies = [ "zmij", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "slab" version = "0.4.12" @@ -823,6 +924,12 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.117" @@ -847,7 +954,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", - "getrandom", + "getrandom 0.4.2", "once_cell", "rustix", "windows-sys 0.61.2", @@ -899,6 +1006,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.18" @@ -945,6 +1062,7 @@ dependencies = [ "prost", "socket2 0.5.10", "tokio", + "tokio-rustls", "tokio-stream", "tower", "tower-layer", @@ -1046,6 +1164,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "utf8parse" version = "0.2.2" @@ -1301,6 +1425,12 @@ dependencies = [ "wasmparser", ] +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + [[package]] name = "zmij" version = "1.0.21" diff --git a/clients/rust/Cargo.toml b/clients/rust/Cargo.toml index 6256c10..695d366 100644 --- a/clients/rust/Cargo.toml +++ b/clients/rust/Cargo.toml @@ -16,24 +16,31 @@ publish = false [workspace.dependencies] clap = { version = "4.5.53", features = ["derive"] } +futures-core = "0.3.31" +futures-util = "0.3.31" prost = "0.13.5" prost-types = "0.13.5" serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.145" thiserror = "2.0.17" -tokio = { version = "1.48.0", features = ["macros", "rt-multi-thread"] } -tonic = { version = "0.13.1", features = ["transport"] } +tokio = { version = "1.48.0", features = ["macros", "rt-multi-thread", "sync", "time"] } +tokio-stream = { version = "0.1.17", features = ["net"] } +tonic = { version = "0.13.1", features = ["transport", "tls-ring"] } tonic-build = "0.13.1" [dependencies] +futures-core = { workspace = true } +futures-util = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } thiserror = { workspace = true } +tokio = { workspace = true } tonic = { workspace = true } [dev-dependencies] serde_json = { workspace = true } tokio = { workspace = true } +tokio-stream = { workspace = true } [build-dependencies] tonic-build = { workspace = true } diff --git a/clients/rust/README.md b/clients/rust/README.md index 4c84c4f..9458619 100644 --- a/clients/rust/README.md +++ b/clients/rust/README.md @@ -1,7 +1,8 @@ # Rust Client Workspace The Rust client workspace contains the MXAccess Gateway client library, a -test CLI, and scaffold tests for generated contract wiring. The library uses +test CLI, and tests for generated contract wiring plus wrapper behavior. The +library uses the shared protobuf inputs documented in `../../docs/client-proto-generation.md` so the Rust bindings compile against the same public gateway and worker contracts as the server. @@ -31,6 +32,7 @@ Run the Rust workspace checks from `clients/rust`: cargo fmt --all --check cargo test --workspace cargo check --workspace +cargo clippy --workspace --all-targets -- -D warnings ``` The build script uses `protoc` from `PATH` or the Windows path recorded in @@ -38,13 +40,48 @@ The build script uses `protoc` from `PATH` or the Windows path recorded in ## CLI -The scaffold CLI exposes version information: +The CLI exposes version, session, command, event stream, write, and smoke +commands over the same client wrapper used by tests: ```powershell cargo run -p mxgw-cli -- version --json +cargo run -p mxgw-cli -- open-session --endpoint http://localhost:5000 --api-key-env MXGATEWAY_API_KEY --json +cargo run -p mxgw-cli -- register --session-id --client-name mxgw-rust-cli --json +cargo run -p mxgw-cli -- add-item --session-id --server-handle 1 --item TestChildObject.TestInt --json +cargo run -p mxgw-cli -- advise --session-id --server-handle 1 --item-handle 1 --json +cargo run -p mxgw-cli -- stream-events --session-id --max-events 1 --json +cargo run -p mxgw-cli -- write --session-id --server-handle 1 --item-handle 1 --value-type int32 --value 123 --json ``` -Additional commands are implemented with the client/session wrapper work. +Use `--tls`, `--ca-file`, and `--server-name-override` for TLS endpoints. The +CLI reads the API key from `--api-key` or from `--api-key-env`, which defaults +to `MXGATEWAY_API_KEY`. API keys are redacted by the library option and secret +types. + +## Library Surface + +`ClientOptions` configures endpoint, API key, plaintext or TLS transport, +timeouts, custom CA files, and server name override. `GatewayClient::connect` +creates an authenticated `tonic` client and attaches `authorization: Bearer +` metadata to unary and streaming calls. + +`GatewayClient` exposes raw generated calls through `open_session_raw`, +`close_session_raw`, `invoke_raw`, `stream_events`, and `raw_client`. The +session helpers keep MXAccess handles visible: + +```rust +let session = client.open_session(request).await?; +let server_handle = session.register("mxgw-rust").await?; +let item_handle = session.add_item(server_handle, "TestChildObject.TestInt").await?; +session.advise(server_handle, item_handle).await?; +let mut events = session.events().await?; +session.close().await?; +``` + +`MxValue`, `MxArrayValue`, and `MxStatus` wrap generated protobuf messages while +preserving the raw message for parity diagnostics. Command replies whose +protocol status is not `PROTOCOL_STATUS_CODE_OK` become `Error::Command` and +retain the raw `MxCommandReply`. ## Related Documentation diff --git a/clients/rust/build.rs b/clients/rust/build.rs index 33d38fb..2e17946 100644 --- a/clients/rust/build.rs +++ b/clients/rust/build.rs @@ -19,7 +19,7 @@ fn main() -> Result<(), Box> { println!("cargo:rerun-if-changed={}", worker_proto.display()); tonic_build::configure() - .build_server(false) + .build_server(true) .build_client(true) .file_descriptor_set_path(descriptor_path) .compile_protos( diff --git a/clients/rust/crates/mxgw-cli/Cargo.toml b/clients/rust/crates/mxgw-cli/Cargo.toml index 5691ac7..9145eca 100644 --- a/clients/rust/crates/mxgw-cli/Cargo.toml +++ b/clients/rust/crates/mxgw-cli/Cargo.toml @@ -10,5 +10,8 @@ path = "src/main.rs" [dependencies] clap = { workspace = true } +futures-util = { workspace = true } mxgateway-client = { path = "../.." } +serde = { workspace = true } serde_json = { workspace = true } +tokio = { workspace = true } diff --git a/clients/rust/crates/mxgw-cli/src/main.rs b/clients/rust/crates/mxgw-cli/src/main.rs index 4accfdf..35e601e 100644 --- a/clients/rust/crates/mxgw-cli/src/main.rs +++ b/clients/rust/crates/mxgw-cli/src/main.rs @@ -1,8 +1,20 @@ +use std::env; +use std::path::PathBuf; use std::process::ExitCode; +use std::time::Duration; -use clap::{Parser, Subcommand}; -use mxgateway_client::{CLIENT_VERSION, GATEWAY_PROTOCOL_VERSION, WORKER_PROTOCOL_VERSION}; +use clap::{Args, Parser, Subcommand, ValueEnum}; +use futures_util::StreamExt; +use mxgateway_client::generated::mxaccess_gateway::v1::{ + CloseSessionRequest, MxCommand, MxCommandKind, MxCommandRequest, OpenSessionRequest, + PingCommand, StreamEventsRequest, +}; +use mxgateway_client::{ + ApiKey, ClientOptions, Error, GatewayClient, MxValue, CLIENT_VERSION, GATEWAY_PROTOCOL_VERSION, + WORKER_PROTOCOL_VERSION, +}; use serde_json::json; +use serde_json::Value; #[derive(Debug, Parser)] #[command(name = "mxgw")] @@ -18,30 +30,428 @@ enum Command { #[arg(long)] json: bool, }, + Ping { + #[command(flatten)] + connection: ConnectionArgs, + #[arg(long, default_value = "ping")] + message: String, + #[arg(long)] + json: bool, + }, + OpenSession { + #[command(flatten)] + connection: ConnectionArgs, + #[arg(long, default_value = "mxgw-rust-cli")] + client_name: String, + #[arg(long)] + json: bool, + }, + CloseSession { + #[command(flatten)] + connection: ConnectionArgs, + #[arg(long)] + session_id: String, + #[arg(long)] + json: bool, + }, + Register { + #[command(flatten)] + connection: ConnectionArgs, + #[arg(long)] + session_id: String, + #[arg(long, default_value = "mxgw-rust-cli")] + client_name: String, + #[arg(long)] + json: bool, + }, + AddItem { + #[command(flatten)] + connection: ConnectionArgs, + #[arg(long)] + session_id: String, + #[arg(long)] + server_handle: i32, + #[arg(long)] + item: String, + #[arg(long)] + json: bool, + }, + Advise { + #[command(flatten)] + connection: ConnectionArgs, + #[arg(long)] + session_id: String, + #[arg(long)] + server_handle: i32, + #[arg(long)] + item_handle: i32, + #[arg(long)] + json: bool, + }, + StreamEvents { + #[command(flatten)] + connection: ConnectionArgs, + #[arg(long)] + session_id: String, + #[arg(long, default_value_t = 0)] + after_worker_sequence: u64, + #[arg(long, default_value_t = 1)] + max_events: usize, + #[arg(long)] + json: bool, + }, + Write { + #[command(flatten)] + connection: ConnectionArgs, + #[arg(long)] + session_id: String, + #[arg(long)] + server_handle: i32, + #[arg(long)] + item_handle: i32, + #[arg(long, value_enum)] + value_type: CliValueType, + #[arg(long)] + value: String, + #[arg(long, default_value_t = 0)] + user_id: i32, + #[arg(long)] + json: bool, + }, + Write2 { + #[command(flatten)] + connection: ConnectionArgs, + #[arg(long)] + session_id: String, + #[arg(long)] + server_handle: i32, + #[arg(long)] + item_handle: i32, + #[arg(long, value_enum)] + value_type: CliValueType, + #[arg(long)] + value: String, + #[arg(long)] + timestamp: String, + #[arg(long, default_value_t = 0)] + user_id: i32, + #[arg(long)] + json: bool, + }, + Smoke { + #[command(flatten)] + connection: ConnectionArgs, + #[arg(long)] + item: String, + #[arg(long, default_value = "mxgw-rust-smoke")] + client_name: String, + #[arg(long)] + json: bool, + }, } -fn main() -> ExitCode { +#[derive(Debug, Args, Clone)] +struct ConnectionArgs { + #[arg(long, default_value = "http://127.0.0.1:5000")] + endpoint: String, + #[arg(long)] + api_key: Option, + #[arg(long, default_value = "MXGATEWAY_API_KEY")] + api_key_env: String, + #[arg(long)] + plaintext: bool, + #[arg(long)] + tls: bool, + #[arg(long)] + ca_file: Option, + #[arg(long)] + server_name_override: Option, + #[arg(long, default_value_t = 10)] + connect_timeout_seconds: u64, + #[arg(long, default_value_t = 30)] + call_timeout_seconds: u64, +} + +impl ConnectionArgs { + fn options(&self) -> ClientOptions { + let mut options = ClientOptions::new(self.endpoint.clone()) + .with_plaintext(!self.tls || self.plaintext) + .with_connect_timeout(Duration::from_secs(self.connect_timeout_seconds)) + .with_call_timeout(Duration::from_secs(self.call_timeout_seconds)); + + if let Some(api_key) = self + .api_key + .clone() + .or_else(|| env::var(&self.api_key_env).ok()) + .filter(|value| !value.is_empty()) + { + options = options.with_api_key(ApiKey::new(api_key)); + } + if let Some(ca_file) = &self.ca_file { + options = options.with_ca_file(ca_file); + } + if let Some(server_name_override) = &self.server_name_override { + options = options.with_server_name_override(server_name_override); + } + + options + } +} + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum CliValueType { + Bool, + Int32, + Int64, + Float, + Double, + String, +} + +#[tokio::main] +async fn main() -> ExitCode { let cli = Cli::parse(); - run(cli); - ExitCode::SUCCESS + match run(cli).await { + Ok(()) => ExitCode::SUCCESS, + Err(error) => { + eprintln!("{error}"); + ExitCode::FAILURE + } + } } -fn run(cli: Cli) { +async fn run(cli: Cli) -> Result<(), Error> { match cli.command { Command::Version { json } => print_version(json), + Command::Ping { + connection, + message, + json, + } => { + let client = connect(connection).await?; + let reply = client + .invoke(MxCommandRequest { + client_correlation_id: "rust-cli-ping".to_owned(), + command: Some(MxCommand { + kind: MxCommandKind::Ping as i32, + payload: Some(mxgateway_client::generated::mxaccess_gateway::v1::mx_command::Payload::Ping( + PingCommand { message }, + )), + }), + ..MxCommandRequest::default() + }) + .await?; + print_command_reply("ping", &reply, json); + } + Command::OpenSession { + connection, + client_name, + json, + } => { + let client = connect(connection).await?; + let reply = client + .open_session_raw(OpenSessionRequest { + client_session_name: client_name, + ..OpenSessionRequest::default() + }) + .await?; + if json { + println!( + "{}", + json!({ + "sessionId": reply.session_id, + "backendName": reply.backend_name, + "gatewayProtocolVersion": reply.gateway_protocol_version, + "workerProtocolVersion": reply.worker_protocol_version, + }) + ); + } else { + println!("{}", reply.session_id); + } + } + Command::CloseSession { + connection, + session_id, + json, + } => { + let client = connect(connection).await?; + let reply = client + .close_session_raw(CloseSessionRequest { + session_id, + client_correlation_id: "rust-cli-close-session".to_owned(), + }) + .await?; + if json { + println!("{}", json!({ "sessionId": reply.session_id })); + } else { + println!("closed {}", reply.session_id); + } + } + Command::Register { + connection, + session_id, + client_name, + json, + } => { + let session = session_for(connection, session_id).await?; + let server_handle = session.register(&client_name).await?; + print_handle("serverHandle", server_handle, json); + } + Command::AddItem { + connection, + session_id, + server_handle, + item, + json, + } => { + let session = session_for(connection, session_id).await?; + let item_handle = session.add_item(server_handle, &item).await?; + print_handle("itemHandle", item_handle, json); + } + Command::Advise { + connection, + session_id, + server_handle, + item_handle, + json, + } => { + let session = session_for(connection, session_id).await?; + session.advise(server_handle, item_handle).await?; + print_ok("advise", json); + } + Command::StreamEvents { + connection, + session_id, + after_worker_sequence, + max_events, + json, + } => { + let client = connect(connection).await?; + let mut stream = client + .stream_events(StreamEventsRequest { + session_id, + after_worker_sequence, + }) + .await?; + let mut events = Vec::new(); + while events.len() < max_events { + let Some(event) = stream.next().await else { + break; + }; + events.push(event?); + } + if json { + println!("{}", json!({ "eventCount": events.len() })); + } else { + for event in events { + println!("{} {}", event.worker_sequence, event.family); + } + } + } + Command::Write { + connection, + session_id, + server_handle, + item_handle, + value_type, + value, + user_id, + json, + } => { + let session = session_for(connection, session_id).await?; + session + .write( + server_handle, + item_handle, + parse_value(value_type, &value)?, + user_id, + ) + .await?; + print_ok("write", json); + } + Command::Write2 { + connection, + session_id, + server_handle, + item_handle, + value_type, + value, + timestamp, + user_id, + json, + } => { + let session = session_for(connection, session_id).await?; + session + .write2( + server_handle, + item_handle, + parse_value(value_type, &value)?, + MxValue::string(timestamp), + user_id, + ) + .await?; + print_ok("write2", json); + } + Command::Smoke { + connection, + item, + client_name, + json, + } => { + let client = connect(connection).await?; + let session = client + .open_session(OpenSessionRequest { + client_session_name: client_name.clone(), + ..OpenSessionRequest::default() + }) + .await?; + let result = async { + let server_handle = session.register(&client_name).await?; + let item_handle = session.add_item(server_handle, &item).await?; + session.advise(server_handle, item_handle).await?; + Ok::<_, Error>((server_handle, item_handle)) + } + .await; + let close_result = session.close().await; + let (server_handle, item_handle) = result?; + close_result?; + if json { + println!( + "{}", + json!({ + "sessionId": session.id(), + "serverHandle": server_handle, + "itemHandle": item_handle, + "closed": true, + }) + ); + } else { + println!( + "session {} registered server {server_handle}, item {item_handle}, closed", + session.id() + ); + } + } } + + Ok(()) +} + +async fn connect(connection: ConnectionArgs) -> Result { + GatewayClient::connect(connection.options()).await +} + +async fn session_for( + connection: ConnectionArgs, + session_id: String, +) -> Result { + let client = connect(connection).await?; + Ok(client.session(session_id)) } fn print_version(use_json: bool) { if use_json { - println!( - "{}", - json!({ - "clientVersion": CLIENT_VERSION, - "gatewayProtocolVersion": GATEWAY_PROTOCOL_VERSION, - "workerProtocolVersion": WORKER_PROTOCOL_VERSION, - }) - ); + println!("{}", version_json()); return; } @@ -50,6 +460,73 @@ fn print_version(use_json: bool) { println!("worker protocol {WORKER_PROTOCOL_VERSION}"); } +fn version_json() -> Value { + json!({ + "clientVersion": CLIENT_VERSION, + "gatewayProtocolVersion": GATEWAY_PROTOCOL_VERSION, + "workerProtocolVersion": WORKER_PROTOCOL_VERSION, + }) +} + +fn print_command_reply( + operation: &str, + reply: &mxgateway_client::generated::mxaccess_gateway::v1::MxCommandReply, + use_json: bool, +) { + if use_json { + println!( + "{}", + json!({ + "operation": operation, + "sessionId": reply.session_id, + "correlationId": reply.correlation_id, + "kind": reply.kind, + }) + ); + } else { + println!("{operation} completed"); + } +} + +fn print_handle(name: &str, handle: i32, use_json: bool) { + if use_json { + println!("{}", json!({ name: handle })); + } else { + println!("{handle}"); + } +} + +fn print_ok(operation: &str, use_json: bool) { + if use_json { + println!("{}", json!({ "operation": operation, "ok": true })); + } else { + println!("{operation} completed"); + } +} + +fn parse_value(value_type: CliValueType, value: &str) -> Result { + let parsed = match value_type { + CliValueType::Bool => MxValue::bool(parse_cli_value(value)?), + CliValueType::Int32 => MxValue::int32(parse_cli_value(value)?), + CliValueType::Int64 => MxValue::int64(parse_cli_value(value)?), + CliValueType::Float => MxValue::float(parse_cli_value(value)?), + CliValueType::Double => MxValue::double(parse_cli_value(value)?), + CliValueType::String => MxValue::string(value), + }; + Ok(parsed) +} + +fn parse_cli_value(value: &str) -> Result +where + T: std::str::FromStr, + T::Err: std::fmt::Display, +{ + value.parse::().map_err(|source| Error::InvalidArgument { + name: "value".to_owned(), + detail: source.to_string(), + }) +} + #[cfg(test)] mod tests { use clap::Parser; @@ -61,4 +538,31 @@ mod tests { let parsed = Cli::try_parse_from(["mxgw", "version", "--json"]); assert!(parsed.is_ok()); } + + #[test] + fn parses_write_command() { + let parsed = Cli::try_parse_from([ + "mxgw", + "write", + "--session-id", + "session-1", + "--server-handle", + "12", + "--item-handle", + "34", + "--value-type", + "int32", + "--value", + "123", + ]); + assert!(parsed.is_ok()); + } + + #[test] + fn version_json_output_has_protocol_versions() { + let value = super::version_json(); + + assert_eq!(value["gatewayProtocolVersion"], 1); + assert_eq!(value["workerProtocolVersion"], 1); + } } diff --git a/clients/rust/src/auth.rs b/clients/rust/src/auth.rs index 6af8123..b63c51f 100644 --- a/clients/rust/src/auth.rs +++ b/clients/rust/src/auth.rs @@ -1,5 +1,9 @@ use std::fmt; +use tonic::metadata::MetadataValue; +use tonic::service::Interceptor; +use tonic::{Request, Status}; + /// API key wrapper that avoids exposing raw credentials in formatted output. #[derive(Clone, Eq, PartialEq)] pub struct ApiKey(String); @@ -28,3 +32,56 @@ impl fmt::Display for ApiKey { formatter.write_str("") } } + +/// `tonic` interceptor that attaches gateway API key metadata. +#[derive(Clone, Debug, Default)] +pub struct AuthInterceptor { + api_key: Option, +} + +impl AuthInterceptor { + pub fn new(api_key: Option) -> Self { + Self { api_key } + } +} + +impl Interceptor for AuthInterceptor { + fn call(&mut self, mut request: Request<()>) -> Result, Status> { + if let Some(api_key) = &self.api_key { + let header_value = format!("Bearer {}", api_key.expose_secret()) + .parse::>() + .map_err(|_| Status::unauthenticated("invalid API key metadata"))?; + request.metadata_mut().insert("authorization", header_value); + } + + Ok(request) + } +} + +#[cfg(test)] +mod tests { + use tonic::service::Interceptor; + use tonic::Request; + + use super::{ApiKey, AuthInterceptor}; + + #[test] + fn api_key_debug_is_redacted() { + let key = ApiKey::new("mxgw_visible_secret"); + + assert_eq!(format!("{key:?}"), "ApiKey(\"\")"); + assert!(!format!("{key:?}").contains("visible_secret")); + assert_eq!(key.to_string(), ""); + } + + #[test] + fn interceptor_attaches_bearer_metadata() { + let mut interceptor = AuthInterceptor::new(Some(ApiKey::new("mxgw_fixture_secret"))); + let request = interceptor.call(Request::new(())).unwrap(); + + assert_eq!( + request.metadata().get("authorization").unwrap(), + "Bearer mxgw_fixture_secret" + ); + } +} diff --git a/clients/rust/src/client.rs b/clients/rust/src/client.rs index d09f929..29fdc50 100644 --- a/clients/rust/src/client.rs +++ b/clients/rust/src/client.rs @@ -1,30 +1,123 @@ -use tonic::transport::Channel; +use std::fs; -use crate::error::Error; +use tonic::codegen::InterceptedService; +use tonic::transport::{Certificate, Channel, ClientTlsConfig}; +use tonic::Request; + +use crate::auth::AuthInterceptor; +use crate::error::{ensure_command_success, Error}; use crate::generated::mxaccess_gateway::v1::mx_access_gateway_client::MxAccessGatewayClient; +use crate::generated::mxaccess_gateway::v1::{ + CloseSessionReply, CloseSessionRequest, MxCommandReply, MxCommandRequest, MxEvent, + OpenSessionReply, OpenSessionRequest, StreamEventsRequest, +}; use crate::options::ClientOptions; +use crate::session::Session; + +pub type RawGatewayClient = MxAccessGatewayClient>; +pub type EventStream = + std::pin::Pin> + Send + 'static>>; /// Thin owner for the generated gateway client. +#[derive(Clone)] pub struct GatewayClient { - inner: MxAccessGatewayClient, + inner: RawGatewayClient, + call_timeout: std::time::Duration, } impl GatewayClient { pub async fn connect(options: ClientOptions) -> Result { - let endpoint = Channel::from_shared(options.endpoint().to_owned()).map_err(|source| { - Error::InvalidEndpoint { - endpoint: options.endpoint().to_owned(), - detail: source.to_string(), + let mut endpoint = + Channel::from_shared(options.endpoint().to_owned()).map_err(|source| { + Error::InvalidEndpoint { + endpoint: options.endpoint().to_owned(), + detail: source.to_string(), + } + })?; + endpoint = endpoint.connect_timeout(options.connect_timeout()); + + if !options.plaintext() { + let mut tls = ClientTlsConfig::new(); + if let Some(server_name) = options.server_name_override() { + tls = tls.domain_name(server_name.to_owned()); } - })?; + if let Some(ca_file) = options.ca_file() { + let certificate = fs::read(ca_file).map_err(|source| Error::InvalidEndpoint { + endpoint: options.endpoint().to_owned(), + detail: format!("failed to read CA file {}: {source}", ca_file.display()), + })?; + tls = tls.ca_certificate(Certificate::from_pem(certificate)); + } + endpoint = endpoint.tls_config(tls)?; + } + let channel = endpoint.connect().await?; + let interceptor = AuthInterceptor::new(options.api_key().cloned()); Ok(Self { - inner: MxAccessGatewayClient::new(channel), + inner: MxAccessGatewayClient::with_interceptor(channel, interceptor), + call_timeout: options.call_timeout(), }) } - pub fn into_inner(self) -> MxAccessGatewayClient { + pub fn raw_client(&mut self) -> &mut RawGatewayClient { + &mut self.inner + } + + pub fn into_inner(self) -> RawGatewayClient { self.inner } + + pub fn session(&self, session_id: impl Into) -> Session { + Session::new(session_id, self.clone()) + } + + pub async fn open_session_raw( + &self, + request: OpenSessionRequest, + ) -> Result { + let mut client = self.inner.clone(); + let response = client.open_session(self.unary_request(request)).await?; + Ok(response.into_inner()) + } + + pub async fn open_session(&self, request: OpenSessionRequest) -> Result { + let reply = self.open_session_raw(request).await?; + Ok(Session::new(reply.session_id, self.clone())) + } + + pub async fn close_session_raw( + &self, + request: CloseSessionRequest, + ) -> Result { + let mut client = self.inner.clone(); + let response = client.close_session(self.unary_request(request)).await?; + Ok(response.into_inner()) + } + + pub async fn invoke_raw(&self, request: MxCommandRequest) -> Result { + let mut client = self.inner.clone(); + let response = client.invoke(self.unary_request(request)).await?; + Ok(response.into_inner()) + } + + pub async fn invoke(&self, request: MxCommandRequest) -> Result { + ensure_command_success(self.invoke_raw(request).await?) + } + + pub async fn stream_events(&self, request: StreamEventsRequest) -> Result { + let mut client = self.inner.clone(); + let response = client.stream_events(self.unary_request(request)).await?; + let stream = futures_util::StreamExt::map(response.into_inner(), |result| { + result.map_err(Error::from) + }); + + Ok(Box::pin(stream)) + } + + fn unary_request(&self, message: T) -> Request { + let mut request = Request::new(message); + request.set_timeout(self.call_timeout); + request + } } diff --git a/clients/rust/src/error.rs b/clients/rust/src/error.rs index 3aa1f2a..16b28c2 100644 --- a/clients/rust/src/error.rs +++ b/clients/rust/src/error.rs @@ -1,13 +1,161 @@ use thiserror::Error as ThisError; +use tonic::Code; + +use crate::generated::mxaccess_gateway::v1::{MxCommandReply, ProtocolStatusCode}; #[derive(Debug, ThisError)] pub enum Error { #[error("invalid gateway endpoint `{endpoint}`: {detail}")] InvalidEndpoint { endpoint: String, detail: String }, + #[error("invalid argument `{name}`: {detail}")] + InvalidArgument { name: String, detail: String }, + #[error("gateway transport error: {0}")] Transport(#[from] tonic::transport::Error), + #[error("authentication failed: {message}")] + Authentication { + message: String, + #[source] + status: Box, + }, + + #[error("authorization failed: {message}")] + Authorization { + message: String, + #[source] + status: Box, + }, + + #[error("gateway call timed out: {message}")] + Timeout { + message: String, + #[source] + status: Box, + }, + + #[error("gateway call cancelled: {message}")] + Cancelled { + message: String, + #[source] + status: Box, + }, + #[error("gateway status error: {0}")] - Status(#[from] tonic::Status), + Status(Box), + + #[error("gateway command failed: {0}")] + Command(#[from] Box), +} + +#[derive(Clone, Debug)] +pub struct CommandError { + reply: MxCommandReply, +} + +impl CommandError { + pub fn new(reply: MxCommandReply) -> Self { + Self { reply } + } + + pub fn reply(&self) -> &MxCommandReply { + &self.reply + } + + pub fn into_reply(self) -> MxCommandReply { + self.reply + } +} + +impl std::fmt::Display for CommandError { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let status = self.reply.protocol_status.as_ref(); + let code = status + .and_then(|status| ProtocolStatusCode::try_from(status.code).ok()) + .unwrap_or(ProtocolStatusCode::Unspecified); + let message = status.map(|status| status.message.as_str()).unwrap_or(""); + + if message.is_empty() { + write!(formatter, "{code:?}") + } else { + write!(formatter, "{code:?}: {message}") + } + } +} + +impl std::error::Error for CommandError {} + +impl From for Error { + fn from(status: tonic::Status) -> Self { + let message = redact_credentials(status.message()); + match status.code() { + Code::Unauthenticated => Self::Authentication { + message, + status: Box::new(status), + }, + Code::PermissionDenied => Self::Authorization { + message, + status: Box::new(status), + }, + Code::DeadlineExceeded => Self::Timeout { + message, + status: Box::new(status), + }, + Code::Cancelled => Self::Cancelled { + message, + status: Box::new(status), + }, + _ => Self::Status(Box::new(status)), + } + } +} + +pub fn ensure_command_success(reply: MxCommandReply) -> Result { + let code = reply + .protocol_status + .as_ref() + .and_then(|status| ProtocolStatusCode::try_from(status.code).ok()) + .unwrap_or(ProtocolStatusCode::Unspecified); + + if code == ProtocolStatusCode::Ok { + Ok(reply) + } else { + Err(Box::new(CommandError::new(reply)).into()) + } +} + +fn redact_credentials(message: &str) -> String { + message + .split_whitespace() + .map(|part| { + if part.starts_with("mxgw_") || part.eq_ignore_ascii_case("bearer") { + "" + } else { + part + } + }) + .collect::>() + .join(" ") +} + +#[cfg(test)] +mod tests { + use tonic::{Code, Status}; + + use super::Error; + + #[test] + fn classifies_authentication_status() { + let error = Error::from(Status::new( + Code::Unauthenticated, + "invalid API key mxgw_visible_secret", + )); + + let message = error.to_string(); + + assert!(matches!(error, Error::Authentication { .. })); + assert!(message.contains("")); + assert!(!message.contains("visible_secret")); + } } diff --git a/clients/rust/src/lib.rs b/clients/rust/src/lib.rs index aac1c91..d1f4b55 100644 --- a/clients/rust/src/lib.rs +++ b/clients/rust/src/lib.rs @@ -13,9 +13,10 @@ pub mod session; pub mod value; pub mod version; -pub use auth::ApiKey; -pub use client::GatewayClient; -pub use error::Error; +pub use auth::{ApiKey, AuthInterceptor}; +pub use client::{EventStream, GatewayClient}; +pub use error::{CommandError, Error}; pub use options::ClientOptions; pub use session::Session; +pub use value::{MxArrayProjection, MxArrayValue, MxStatus, MxValue, MxValueProjection}; pub use version::{CLIENT_VERSION, GATEWAY_PROTOCOL_VERSION, WORKER_PROTOCOL_VERSION}; diff --git a/clients/rust/src/options.rs b/clients/rust/src/options.rs index 38013ff..45e8466 100644 --- a/clients/rust/src/options.rs +++ b/clients/rust/src/options.rs @@ -1,4 +1,6 @@ use std::fmt; +use std::path::PathBuf; +use std::time::Duration; use crate::auth::ApiKey; @@ -7,6 +9,10 @@ pub struct ClientOptions { endpoint: String, api_key: Option, plaintext: bool, + ca_file: Option, + server_name_override: Option, + connect_timeout: Duration, + call_timeout: Duration, } impl ClientOptions { @@ -15,6 +21,10 @@ impl ClientOptions { endpoint: endpoint.into(), api_key: None, plaintext: true, + ca_file: None, + server_name_override: None, + connect_timeout: Duration::from_secs(10), + call_timeout: Duration::from_secs(30), } } @@ -23,6 +33,31 @@ impl ClientOptions { self } + pub fn with_plaintext(mut self, plaintext: bool) -> Self { + self.plaintext = plaintext; + self + } + + pub fn with_ca_file(mut self, ca_file: impl Into) -> Self { + self.ca_file = Some(ca_file.into()); + self + } + + pub fn with_server_name_override(mut self, server_name_override: impl Into) -> Self { + self.server_name_override = Some(server_name_override.into()); + self + } + + pub fn with_connect_timeout(mut self, connect_timeout: Duration) -> Self { + self.connect_timeout = connect_timeout; + self + } + + pub fn with_call_timeout(mut self, call_timeout: Duration) -> Self { + self.call_timeout = call_timeout; + self + } + pub fn endpoint(&self) -> &str { &self.endpoint } @@ -34,6 +69,22 @@ impl ClientOptions { pub fn plaintext(&self) -> bool { self.plaintext } + + pub fn ca_file(&self) -> Option<&PathBuf> { + self.ca_file.as_ref() + } + + pub fn server_name_override(&self) -> Option<&str> { + self.server_name_override.as_deref() + } + + pub fn connect_timeout(&self) -> Duration { + self.connect_timeout + } + + pub fn call_timeout(&self) -> Duration { + self.call_timeout + } } impl Default for ClientOptions { @@ -49,6 +100,27 @@ impl fmt::Debug for ClientOptions { .field("endpoint", &self.endpoint) .field("api_key", &self.api_key.as_ref().map(|_| "")) .field("plaintext", &self.plaintext) + .field("ca_file", &self.ca_file) + .field("server_name_override", &self.server_name_override) + .field("connect_timeout", &self.connect_timeout) + .field("call_timeout", &self.call_timeout) .finish() } } + +#[cfg(test)] +mod tests { + use super::ClientOptions; + use crate::auth::ApiKey; + + #[test] + fn debug_redacts_api_key() { + let options = + ClientOptions::new("http://localhost:5000").with_api_key(ApiKey::new("mxgw_secret")); + + let debug = format!("{options:?}"); + + assert!(debug.contains("")); + assert!(!debug.contains("mxgw_secret")); + } +} diff --git a/clients/rust/src/session.rs b/clients/rust/src/session.rs index a421480..e934478 100644 --- a/clients/rust/src/session.rs +++ b/clients/rust/src/session.rs @@ -1,15 +1,234 @@ +use crate::client::{EventStream, GatewayClient}; +use crate::error::Error; +use crate::generated::mxaccess_gateway::v1::mx_command::Payload; +use crate::generated::mxaccess_gateway::v1::mx_command_reply; +use crate::generated::mxaccess_gateway::v1::{ + AddItem2Command, AddItemCommand, AdviseCommand, CloseSessionRequest, MxCommand, MxCommandKind, + MxCommandReply, MxCommandRequest, MxValue as ProtoMxValue, OpenSessionRequest, RegisterCommand, + StreamEventsRequest, Write2Command, WriteCommand, +}; +use crate::value::MxValue; + /// Session identifier returned by the gateway. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone)] pub struct Session { id: String, + client: GatewayClient, } impl Session { - pub fn new(id: impl Into) -> Self { - Self { id: id.into() } + pub(crate) fn new(id: impl Into, client: GatewayClient) -> Self { + Self { + id: id.into(), + client, + } } pub fn id(&self) -> &str { &self.id } + + pub async fn open(client: GatewayClient, client_session_name: &str) -> Result { + client + .open_session(OpenSessionRequest { + client_session_name: client_session_name.to_owned(), + ..OpenSessionRequest::default() + }) + .await + } + + pub async fn close(&self) -> Result<(), Error> { + self.client + .close_session_raw(CloseSessionRequest { + session_id: self.id.clone(), + client_correlation_id: "rust-client-close-session".to_owned(), + }) + .await?; + Ok(()) + } + + pub async fn register(&self, client_name: &str) -> Result { + let reply = self + .invoke( + MxCommandKind::Register, + Payload::Register(RegisterCommand { + client_name: client_name.to_owned(), + }), + ) + .await?; + + Ok(register_server_handle(&reply)) + } + + pub async fn add_item(&self, server_handle: i32, item_definition: &str) -> Result { + let reply = self + .invoke( + MxCommandKind::AddItem, + Payload::AddItem(AddItemCommand { + server_handle, + item_definition: item_definition.to_owned(), + }), + ) + .await?; + + Ok(add_item_handle(&reply)) + } + + pub async fn add_item2( + &self, + server_handle: i32, + item_definition: &str, + item_context: &str, + ) -> Result { + let reply = self + .invoke( + MxCommandKind::AddItem2, + Payload::AddItem2(AddItem2Command { + server_handle, + item_definition: item_definition.to_owned(), + item_context: item_context.to_owned(), + }), + ) + .await?; + + Ok(add_item2_handle(&reply)) + } + + pub async fn advise(&self, server_handle: i32, item_handle: i32) -> Result<(), Error> { + self.invoke( + MxCommandKind::Advise, + Payload::Advise(AdviseCommand { + server_handle, + item_handle, + }), + ) + .await?; + Ok(()) + } + + pub async fn write( + &self, + server_handle: i32, + item_handle: i32, + value: MxValue, + user_id: i32, + ) -> Result<(), Error> { + self.invoke( + MxCommandKind::Write, + Payload::Write(WriteCommand { + server_handle, + item_handle, + value: Some(value.into_proto()), + user_id, + }), + ) + .await?; + Ok(()) + } + + pub async fn write2( + &self, + server_handle: i32, + item_handle: i32, + value: MxValue, + timestamp_value: MxValue, + user_id: i32, + ) -> Result<(), Error> { + self.invoke( + MxCommandKind::Write2, + Payload::Write2(Write2Command { + server_handle, + item_handle, + value: Some(value.into_proto()), + timestamp_value: Some(timestamp_value.into_proto()), + user_id, + }), + ) + .await?; + Ok(()) + } + + pub async fn events(&self) -> Result { + self.events_after(0).await + } + + pub async fn events_after(&self, after_worker_sequence: u64) -> Result { + self.client + .stream_events(StreamEventsRequest { + session_id: self.id.clone(), + after_worker_sequence, + }) + .await + } + + pub async fn invoke_raw( + &self, + kind: MxCommandKind, + payload: Payload, + ) -> Result { + self.client + .invoke_raw(self.command_request(kind, payload)) + .await + } + + pub async fn invoke( + &self, + kind: MxCommandKind, + payload: Payload, + ) -> Result { + self.client + .invoke(self.command_request(kind, payload)) + .await + } + + fn command_request(&self, kind: MxCommandKind, payload: Payload) -> MxCommandRequest { + MxCommandRequest { + session_id: self.id.clone(), + client_correlation_id: format!("rust-client-{}", kind.as_str_name()), + command: Some(MxCommand { + kind: kind as i32, + payload: Some(payload), + }), + } + } +} + +fn register_server_handle(reply: &MxCommandReply) -> i32 { + match reply.payload.as_ref() { + Some(mx_command_reply::Payload::Register(register)) => register.server_handle, + _ => reply + .return_value + .as_ref() + .and_then(int32_reply_value) + .unwrap_or_default(), + } +} + +fn add_item_handle(reply: &MxCommandReply) -> i32 { + match reply.payload.as_ref() { + Some(mx_command_reply::Payload::AddItem(add_item)) => add_item.item_handle, + _ => reply + .return_value + .as_ref() + .and_then(int32_reply_value) + .unwrap_or_default(), + } +} + +fn add_item2_handle(reply: &MxCommandReply) -> i32 { + match reply.payload.as_ref() { + Some(mx_command_reply::Payload::AddItem2(add_item)) => add_item.item_handle, + _ => reply + .return_value + .as_ref() + .and_then(int32_reply_value) + .unwrap_or_default(), + } +} + +fn int32_reply_value(value: &ProtoMxValue) -> Option { + match value.kind.as_ref()? { + crate::generated::mxaccess_gateway::v1::mx_value::Kind::Int32Value(value) => Some(*value), + _ => None, + } } diff --git a/clients/rust/src/value.rs b/clients/rust/src/value.rs index 754e64d..6547b98 100644 --- a/clients/rust/src/value.rs +++ b/clients/rust/src/value.rs @@ -1,9 +1,239 @@ -use crate::generated::mxaccess_gateway::v1::MxValue; +use crate::generated::mxaccess_gateway::v1::mx_array::Values; +use crate::generated::mxaccess_gateway::v1::mx_value::Kind; +use crate::generated::mxaccess_gateway::v1::{ + BoolArray, DoubleArray, FloatArray, Int32Array, Int64Array, MxArray, MxDataType, + MxStatusCategory, MxStatusProxy, MxStatusSource, MxValue as ProtoMxValue, RawArray, + StringArray, TimestampArray, +}; -pub fn int32_value(value: i32) -> MxValue { - MxValue { - data_type: crate::generated::mxaccess_gateway::v1::MxDataType::Integer as i32, - kind: Some(crate::generated::mxaccess_gateway::v1::mx_value::Kind::Int32Value(value)), - ..MxValue::default() +#[derive(Clone, Debug, PartialEq)] +pub struct MxValue { + raw: ProtoMxValue, + projection: MxValueProjection, +} + +impl MxValue { + pub fn from_proto(raw: ProtoMxValue) -> Self { + let projection = MxValueProjection::from_proto(&raw); + Self { raw, projection } + } + + pub fn bool(value: bool) -> Self { + Self::from_proto(ProtoMxValue { + data_type: MxDataType::Boolean as i32, + variant_type: "VT_BOOL".to_owned(), + kind: Some(Kind::BoolValue(value)), + ..ProtoMxValue::default() + }) + } + + pub fn int32(value: i32) -> Self { + Self::from_proto(ProtoMxValue { + data_type: MxDataType::Integer as i32, + variant_type: "VT_I4".to_owned(), + kind: Some(Kind::Int32Value(value)), + ..ProtoMxValue::default() + }) + } + + pub fn int64(value: i64) -> Self { + Self::from_proto(ProtoMxValue { + data_type: MxDataType::Integer as i32, + variant_type: "VT_I8".to_owned(), + kind: Some(Kind::Int64Value(value)), + ..ProtoMxValue::default() + }) + } + + pub fn float(value: f32) -> Self { + Self::from_proto(ProtoMxValue { + data_type: MxDataType::Float as i32, + variant_type: "VT_R4".to_owned(), + kind: Some(Kind::FloatValue(value)), + ..ProtoMxValue::default() + }) + } + + pub fn double(value: f64) -> Self { + Self::from_proto(ProtoMxValue { + data_type: MxDataType::Double as i32, + variant_type: "VT_R8".to_owned(), + kind: Some(Kind::DoubleValue(value)), + ..ProtoMxValue::default() + }) + } + + pub fn string(value: impl Into) -> Self { + Self::from_proto(ProtoMxValue { + data_type: MxDataType::String as i32, + variant_type: "VT_BSTR".to_owned(), + kind: Some(Kind::StringValue(value.into())), + ..ProtoMxValue::default() + }) + } + + pub fn raw(&self) -> &ProtoMxValue { + &self.raw + } + + pub fn projection(&self) -> &MxValueProjection { + &self.projection + } + + pub fn into_proto(self) -> ProtoMxValue { + self.raw + } +} + +impl From for ProtoMxValue { + fn from(value: MxValue) -> Self { + value.into_proto() + } +} + +impl From for MxValue { + fn from(value: ProtoMxValue) -> Self { + Self::from_proto(value) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum MxValueProjection { + Unset, + Null, + Bool(bool), + Int32(i32), + Int64(i64), + Float(f32), + Double(f64), + String(String), + Timestamp(prost_types::Timestamp), + Array(MxArrayValue), + Raw(Vec), +} + +impl MxValueProjection { + fn from_proto(value: &ProtoMxValue) -> Self { + if value.is_null { + return Self::Null; + } + + match value.kind.as_ref() { + Some(Kind::BoolValue(value)) => Self::Bool(*value), + Some(Kind::Int32Value(value)) => Self::Int32(*value), + Some(Kind::Int64Value(value)) => Self::Int64(*value), + Some(Kind::FloatValue(value)) => Self::Float(*value), + Some(Kind::DoubleValue(value)) => Self::Double(*value), + Some(Kind::StringValue(value)) => Self::String(value.clone()), + Some(Kind::TimestampValue(value)) => Self::Timestamp(*value), + Some(Kind::ArrayValue(value)) => Self::Array(MxArrayValue::from_proto(value.clone())), + Some(Kind::RawValue(value)) => Self::Raw(value.clone()), + None => Self::Unset, + } + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct MxArrayValue { + raw: MxArray, + projection: MxArrayProjection, +} + +impl MxArrayValue { + pub fn from_proto(raw: MxArray) -> Self { + let projection = MxArrayProjection::from_proto(&raw); + Self { raw, projection } + } + + pub fn string(values: Vec) -> Self { + Self::from_proto(MxArray { + element_data_type: MxDataType::String as i32, + variant_type: "VT_ARRAY|VT_BSTR".to_owned(), + dimensions: vec![values.len() as u32], + values: Some(Values::StringValues(StringArray { values })), + ..MxArray::default() + }) + } + + pub fn raw(&self) -> &MxArray { + &self.raw + } + + pub fn projection(&self) -> &MxArrayProjection { + &self.projection + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum MxArrayProjection { + Unset, + Bool(Vec), + Int32(Vec), + Int64(Vec), + Float(Vec), + Double(Vec), + String(Vec), + Timestamp(Vec), + Raw(Vec>), +} + +impl MxArrayProjection { + fn from_proto(array: &MxArray) -> Self { + match array.values.as_ref() { + Some(Values::BoolValues(BoolArray { values })) => Self::Bool(values.clone()), + Some(Values::Int32Values(Int32Array { values })) => Self::Int32(values.clone()), + Some(Values::Int64Values(Int64Array { values })) => Self::Int64(values.clone()), + Some(Values::FloatValues(FloatArray { values })) => Self::Float(values.clone()), + Some(Values::DoubleValues(DoubleArray { values })) => Self::Double(values.clone()), + Some(Values::StringValues(StringArray { values })) => Self::String(values.clone()), + Some(Values::TimestampValues(TimestampArray { values })) => { + Self::Timestamp(values.clone()) + } + Some(Values::RawValues(RawArray { values })) => Self::Raw(values.clone()), + None => Self::Unset, + } + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct MxStatus { + raw: MxStatusProxy, +} + +impl MxStatus { + pub fn from_proto(raw: MxStatusProxy) -> Self { + Self { raw } + } + + pub fn raw(&self) -> &MxStatusProxy { + &self.raw + } + + pub fn success(&self) -> i32 { + self.raw.success + } + + pub fn category(&self) -> Option { + MxStatusCategory::try_from(self.raw.category).ok() + } + + pub fn detected_by(&self) -> Option { + MxStatusSource::try_from(self.raw.detected_by).ok() + } + + pub fn detail(&self) -> i32 { + self.raw.detail + } + + pub fn raw_category(&self) -> i32 { + self.raw.raw_category + } + + pub fn raw_detected_by(&self) -> i32 { + self.raw.raw_detected_by + } + + pub fn diagnostic_text(&self) -> &str { + &self.raw.diagnostic_text } } diff --git a/clients/rust/tests/client_behavior.rs b/clients/rust/tests/client_behavior.rs new file mode 100644 index 0000000..0afccf6 --- /dev/null +++ b/clients/rust/tests/client_behavior.rs @@ -0,0 +1,398 @@ +use std::pin::Pin; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; +use std::task::{Context, Poll}; +use std::time::Duration; + +use futures_core::Stream; +use futures_util::StreamExt; +use mxgateway_client::generated::mxaccess_gateway::v1::mx_access_gateway_server::{ + MxAccessGateway, MxAccessGatewayServer, +}; +use mxgateway_client::generated::mxaccess_gateway::v1::mx_command_reply; +use mxgateway_client::generated::mxaccess_gateway::v1::mx_value::Kind; +use mxgateway_client::generated::mxaccess_gateway::v1::{ + AddItemReply, CloseSessionReply, CloseSessionRequest, MxCommandKind, MxCommandReply, + MxDataType, MxEvent, MxEventFamily, MxStatusCategory, MxStatusProxy, MxStatusSource, MxValue, + OpenSessionReply, OpenSessionRequest, ProtocolStatus, ProtocolStatusCode, SessionState, + StreamEventsRequest, +}; +use mxgateway_client::{ + ApiKey, ClientOptions, CommandError, Error, GatewayClient, MxStatus, MxValue as ClientMxValue, + MxValueProjection, +}; +use serde_json::Value; +use tokio::net::TcpListener; +use tokio::sync::{mpsc, Mutex}; +use tokio_stream::wrappers::{ReceiverStream, TcpListenerStream}; +use tonic::transport::Server; +use tonic::{Request, Response, Status}; + +#[tokio::test] +async fn fake_server_receives_bearer_metadata_and_raw_client_is_reachable() { + let state = Arc::new(FakeState::default()); + let endpoint = spawn_fake_gateway(state.clone()).await; + let mut client = GatewayClient::connect( + ClientOptions::new(endpoint).with_api_key(ApiKey::new("mxgw_fixture_secret")), + ) + .await + .unwrap(); + + let _raw = client.raw_client(); + let session = client + .open_session(OpenSessionRequest { + client_session_name: "rust-test".to_owned(), + ..OpenSessionRequest::default() + }) + .await + .unwrap(); + + assert_eq!(session.id(), "session-fixture"); + assert_eq!( + state.authorization.lock().await.as_deref(), + Some("Bearer mxgw_fixture_secret") + ); +} + +#[tokio::test] +async fn session_helpers_build_commands_and_preserve_command_errors() { + let state = Arc::new(FakeState::default()); + let endpoint = spawn_fake_gateway(state.clone()).await; + let client = GatewayClient::connect(ClientOptions::new(endpoint)) + .await + .unwrap(); + let session = client.session("session-fixture"); + + let item_handle = session.add_item(12, "Plant.Area.Tag").await.unwrap(); + + assert_eq!(item_handle, 34); + let last_command = state.last_command_kind.lock().await; + assert_eq!(*last_command, Some(MxCommandKind::AddItem as i32)); + drop(last_command); + + let error = session + .write(12, 34, ClientMxValue::int32(123), 0) + .await + .unwrap_err(); + let Error::Command(error) = error else { + panic!("write failure should preserve the raw command reply: {error:?}"); + }; + assert_eq!( + error.reply().protocol_status.as_ref().unwrap().code, + ProtocolStatusCode::MxaccessFailure as i32 + ); + assert_eq!(error.reply().hresult, Some(-2147220992)); + assert_eq!(error.reply().statuses.len(), 2); +} + +#[tokio::test] +async fn event_stream_preserves_order_and_drop_cancels_server_stream() { + let state = Arc::new(FakeState::default()); + let endpoint = spawn_fake_gateway(state.clone()).await; + let client = GatewayClient::connect(ClientOptions::new(endpoint)) + .await + .unwrap(); + + let mut stream = client + .stream_events(StreamEventsRequest { + session_id: "session-fixture".to_owned(), + after_worker_sequence: 0, + }) + .await + .unwrap(); + + assert_eq!(stream.next().await.unwrap().unwrap().worker_sequence, 1); + assert_eq!(stream.next().await.unwrap().unwrap().worker_sequence, 2); + + drop(stream); + for _ in 0..20 { + if state.stream_dropped.load(Ordering::SeqCst) { + return; + } + tokio::time::sleep(Duration::from_millis(25)).await; + } + + assert!(state.stream_dropped.load(Ordering::SeqCst)); +} + +#[test] +fn value_conversion_fixtures_keep_typed_projection_and_raw_metadata() { + let fixture = behavior_fixture("values/value-conversion-cases.json"); + let cases = fixture["cases"].as_array().unwrap(); + + let int64_case = case_by_id(cases, "int64.large"); + let int64_value = ClientMxValue::from_proto(MxValue { + data_type: MxDataType::Integer as i32, + variant_type: "VT_I8".to_owned(), + kind: Some(Kind::Int64Value( + int64_case["value"]["int64Value"] + .as_str() + .unwrap() + .parse() + .unwrap(), + )), + ..MxValue::default() + }); + assert_eq!( + int64_value.projection(), + &MxValueProjection::Int64(9_223_372_036_854_770_000) + ); + + let raw_case = case_by_id(cases, "raw-fallback.variant"); + let raw_value = ClientMxValue::from_proto(MxValue { + data_type: MxDataType::Unknown as i32, + variant_type: "VT_RECORD".to_owned(), + raw_diagnostic: raw_case["value"]["rawDiagnostic"] + .as_str() + .unwrap() + .to_owned(), + raw_data_type: raw_case["value"]["rawDataType"].as_i64().unwrap() as i32, + kind: Some(Kind::RawValue(vec![1, 2, 3, 4, 5])), + ..MxValue::default() + }); + assert_eq!( + raw_value.projection(), + &MxValueProjection::Raw(vec![1, 2, 3, 4, 5]) + ); + assert_eq!(raw_value.raw().raw_data_type, 32767); + assert!(raw_value.raw().raw_diagnostic.contains("No lossless")); +} + +#[test] +fn status_conversion_fixtures_preserve_raw_fields() { + let fixture = behavior_fixture("statuses/status-conversion-cases.json"); + let cases = fixture["cases"].as_array().unwrap(); + let raw_case = case_by_id(cases, "raw-unknown-category"); + let status = MxStatus::from_proto(MxStatusProxy { + success: raw_case["status"]["success"].as_i64().unwrap() as i32, + category: MxStatusCategory::Unknown as i32, + detected_by: MxStatusSource::Unknown as i32, + detail: raw_case["status"]["detail"].as_i64().unwrap() as i32, + raw_category: raw_case["status"]["rawCategory"].as_i64().unwrap() as i32, + raw_detected_by: raw_case["status"]["rawDetectedBy"].as_i64().unwrap() as i32, + diagnostic_text: raw_case["status"]["diagnosticText"] + .as_str() + .unwrap() + .to_owned(), + }); + + assert_eq!(status.success(), 0); + assert_eq!(status.category(), Some(MxStatusCategory::Unknown)); + assert_eq!(status.raw_category(), 99); + assert_eq!(status.raw_detected_by(), 77); + assert!(status.diagnostic_text().contains("preserved")); +} + +#[test] +fn authentication_and_authorization_statuses_are_distinct_and_redacted() { + let auth = Error::from(Status::unauthenticated( + "invalid API key mxgw_visible_secret", + )); + let denied = Error::from(Status::permission_denied("missing scope mxaccess.write")); + + assert!(matches!(auth, Error::Authentication { .. })); + assert!(matches!(denied, Error::Authorization { .. })); + assert!(!auth.to_string().contains("visible_secret")); +} + +#[test] +fn command_error_display_keeps_raw_reply_accessible() { + let reply = mxaccess_failure_reply(); + let error = CommandError::new(reply.clone()); + + assert_eq!(error.reply().hresult, Some(-2147220992)); + assert!(error.to_string().contains("MxaccessFailure")); +} + +#[derive(Default)] +struct FakeState { + authorization: Mutex>, + last_command_kind: Mutex>, + stream_dropped: Arc, +} + +#[derive(Clone)] +struct FakeGateway { + state: Arc, +} + +#[tonic::async_trait] +impl MxAccessGateway for FakeGateway { + async fn open_session( + &self, + request: Request, + ) -> Result, Status> { + *self.state.authorization.lock().await = request + .metadata() + .get("authorization") + .and_then(|value| value.to_str().ok()) + .map(str::to_owned); + + Ok(Response::new(OpenSessionReply { + session_id: "session-fixture".to_owned(), + backend_name: "fake".to_owned(), + worker_process_id: 1234, + worker_protocol_version: 1, + gateway_protocol_version: 1, + protocol_status: Some(ok_status("opened")), + ..OpenSessionReply::default() + })) + } + + async fn close_session( + &self, + request: Request, + ) -> Result, Status> { + Ok(Response::new(CloseSessionReply { + session_id: request.into_inner().session_id, + final_state: SessionState::Closed as i32, + protocol_status: Some(ok_status("closed")), + })) + } + + async fn invoke( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + let kind = request + .command + .as_ref() + .map(|command| command.kind) + .unwrap_or_default(); + *self.state.last_command_kind.lock().await = Some(kind); + + if kind == MxCommandKind::Write as i32 { + return Ok(Response::new(mxaccess_failure_reply())); + } + + Ok(Response::new(MxCommandReply { + session_id: request.session_id, + correlation_id: "fake-correlation".to_owned(), + kind, + protocol_status: Some(ok_status("command ok")), + payload: Some(mx_command_reply::Payload::AddItem(AddItemReply { + item_handle: 34, + })), + ..MxCommandReply::default() + })) + } + + type StreamEventsStream = DropAwareStream; + + async fn stream_events( + &self, + _request: Request, + ) -> Result, Status> { + let (sender, receiver) = mpsc::channel(4); + sender.send(Ok(event(1))).await.unwrap(); + sender.send(Ok(event(2))).await.unwrap(); + + Ok(Response::new(DropAwareStream { + inner: ReceiverStream::new(receiver), + dropped: self.state.stream_dropped.clone(), + })) + } +} + +struct DropAwareStream { + inner: ReceiverStream>, + dropped: Arc, +} + +impl Stream for DropAwareStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_next(context) + } +} + +impl Drop for DropAwareStream { + fn drop(&mut self) { + self.dropped.store(true, Ordering::SeqCst); + } +} + +async fn spawn_fake_gateway(state: Arc) -> String { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + let incoming = TcpListenerStream::new(listener); + let service = MxAccessGatewayServer::new(FakeGateway { state }); + tokio::spawn(async move { + Server::builder() + .add_service(service) + .serve_with_incoming(incoming) + .await + .unwrap(); + }); + + format!("http://{address}") +} + +fn ok_status(message: &str) -> ProtocolStatus { + ProtocolStatus { + code: ProtocolStatusCode::Ok as i32, + message: message.to_owned(), + } +} + +fn mxaccess_failure_reply() -> MxCommandReply { + MxCommandReply { + session_id: "session-fixture".to_owned(), + correlation_id: "gateway-correlation-write-1".to_owned(), + kind: MxCommandKind::Write as i32, + protocol_status: Some(ProtocolStatus { + code: ProtocolStatusCode::MxaccessFailure as i32, + message: "MXAccess rejected the write.".to_owned(), + }), + hresult: Some(-2147220992), + statuses: vec![ + MxStatusProxy { + success: 0, + category: MxStatusCategory::SecurityError as i32, + detected_by: MxStatusSource::RespondingLmx as i32, + detail: 321, + raw_category: 8, + raw_detected_by: 3, + diagnostic_text: "Write denied by provider security.".to_owned(), + }, + MxStatusProxy { + success: 0, + category: MxStatusCategory::OperationalError as i32, + detected_by: MxStatusSource::RespondingNmx as i32, + detail: 902, + raw_category: 7, + raw_detected_by: 5, + diagnostic_text: "Provider rejected the item state.".to_owned(), + }, + ], + ..MxCommandReply::default() + } +} + +fn event(sequence: u64) -> MxEvent { + MxEvent { + family: MxEventFamily::OnDataChange as i32, + session_id: "session-fixture".to_owned(), + worker_sequence: sequence, + ..MxEvent::default() + } +} + +fn behavior_fixture(path: &str) -> Value { + let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("../proto/fixtures/behavior") + .join(path); + let data = std::fs::read_to_string(&path).unwrap(); + serde_json::from_str(&data).unwrap() +} + +fn case_by_id<'a>(cases: &'a [Value], id: &str) -> &'a Value { + cases + .iter() + .find(|case| case["id"].as_str() == Some(id)) + .unwrap_or_else(|| panic!("missing fixture case {id}")) +}