0702551c25
Regenerate Python proto bindings to pick up MxSparseArray/MxSparseElement/
sparse_array_value from the shared mxaccess_gateway.proto. Add
Session.write_array_elements which builds an MxValue(sparse_array_value=…)
from a {index→scalar} dict and delegates to the existing write(). Add 8 pytest
tests covering builder correctness and full round-trip wire shape. Update
README with a default-fill semantics paragraph and bare-name array-write note.
210 lines
6.9 KiB
Python
210 lines
6.9 KiB
Python
"""Tests for Session.write_array_elements default-fill sparse-array helper."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
from zb_mom_ww_mxgateway import ClientOptions, GatewayClient
|
|
from zb_mom_ww_mxgateway.generated import mxaccess_gateway_pb2 as pb
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_sparse_mx_value(
    element_data_type: "pb.MxDataType.ValueType",
    total_length: int,
    elements: dict[int, Any],
) -> pb.MxValue:
    """Assemble a sparse-array ``MxValue`` from plain Python inputs.

    Each ``index -> scalar`` pair in *elements* becomes one
    ``MxSparseElement`` whose value is converted with ``to_mx_value``.
    The result is intended to mirror what ``Session.write_array_elements``
    builds, so tests can check the wire shape without the gRPC stack.

    Args:
        element_data_type: Data-type enum stored on the sparse array.
        total_length: Value stored in the array's ``total_length`` field.
        elements: Mapping of element index to the scalar to place there.

    Returns:
        An ``MxValue`` whose ``sparse_array_value`` field is populated.
    """
    from zb_mom_ww_mxgateway.values import to_mx_value

    sparse_elements = []
    for position, scalar in elements.items():
        sparse_elements.append(
            pb.MxSparseElement(index=position, value=to_mx_value(scalar))
        )

    sparse_array = pb.MxSparseArray(
        element_data_type=element_data_type,
        total_length=total_length,
        elements=sparse_elements,
    )
    return pb.MxValue(sparse_array_value=sparse_array)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fake stub (minimal — only needs Invoke / OpenSession)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _FakeUnary:
|
|
def __init__(self, replies: list[Any]) -> None:
|
|
self.replies = list(replies)
|
|
self.requests: list[Any] = []
|
|
self.metadata: tuple[tuple[str, str], ...] | None = None
|
|
|
|
async def __call__(
|
|
self,
|
|
request: Any,
|
|
*,
|
|
metadata: tuple[tuple[str, str], ...],
|
|
) -> Any:
|
|
self.requests.append(request)
|
|
self.metadata = metadata
|
|
return self.replies.pop(0)
|
|
|
|
|
|
class _FakeStub:
    """Minimal stub that satisfies GatewayClient for a single invoke round-trip."""

    def __init__(self) -> None:
        """Queue one OK open-session reply and one OK write reply."""
        ok_status = pb.ProtocolStatus(code=pb.PROTOCOL_STATUS_CODE_OK)
        session_reply = pb.OpenSessionReply(session_id="s1", protocol_status=ok_status)
        write_reply = pb.MxCommandReply(
            session_id="s1",
            kind=pb.MX_COMMAND_KIND_WRITE,
            protocol_status=ok_status,
        )

        # Each fake is exposed under a snake_case name (read by the tests)
        # and a PascalCase alias bound to the very same object.
        self.open_session = self.OpenSession = _FakeUnary([session_reply])
        self.invoke = self.Invoke = _FakeUnary([write_reply])
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Unit tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_sparse_mx_value_builder_sets_correct_oneof() -> None:
    """The built MxValue must report 'sparse_array_value' as its populated 'kind'."""
    built = _make_sparse_mx_value(pb.MX_DATA_TYPE_INTEGER, 5, {0: 10, 3: 30})
    populated_field = built.WhichOneof("kind")
    assert populated_field == "sparse_array_value"
|
|
|
|
|
|
def test_sparse_mx_value_builder_total_length() -> None:
    """The sparse array's total_length must be the value handed to the builder."""
    sparse = _make_sparse_mx_value(pb.MX_DATA_TYPE_INTEGER, 20, {1: 7}).sparse_array_value
    assert sparse.total_length == 20
|
|
|
|
|
|
def test_sparse_mx_value_builder_element_count_and_values() -> None:
    """Every input pair must appear once, with its scalar stored as int32_value."""
    built = _make_sparse_mx_value(pb.MX_DATA_TYPE_INTEGER, 10, {0: 11, 4: 55, 9: 99})
    sparse = built.sparse_array_value
    assert len(sparse.elements) == 3

    # Re-key by index so the checks do not depend on element order.
    int_by_index = {element.index: element.value.int32_value for element in sparse.elements}
    assert int_by_index[0] == 11
    assert int_by_index[4] == 55
    assert int_by_index[9] == 99
|
|
|
|
|
|
def test_sparse_mx_value_builder_element_data_type() -> None:
    """The element_data_type given to the builder must come back unchanged."""
    sparse = _make_sparse_mx_value(pb.MX_DATA_TYPE_FLOAT, 3, {}).sparse_array_value
    assert sparse.element_data_type == pb.MX_DATA_TYPE_FLOAT
|
|
|
|
|
|
def test_sparse_mx_value_builder_empty_elements() -> None:
    """With no elements supplied, the array is empty but keeps its total_length."""
    sparse = _make_sparse_mx_value(pb.MX_DATA_TYPE_BOOLEAN, 8, {}).sparse_array_value
    assert len(sparse.elements) == 0
    assert sparse.total_length == 8
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration-level: write_array_elements routes through Session.write
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_write_array_elements_sends_sparse_array_write_command() -> None:
    """A single WRITE command carrying a sparse_array_value must reach the stub."""
    fake_stub = _FakeStub()
    options = ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True)
    gateway = await GatewayClient.connect(options, stub=fake_stub)
    session = await gateway.open_session()

    await session.write_array_elements(
        server_handle=1,
        item_handle=2,
        element_data_type=pb.MX_DATA_TYPE_INTEGER,
        total_length=10,
        elements={0: 100, 5: 500},
    )

    # Exactly one Invoke call, and it must be a WRITE.
    recorded = fake_stub.invoke.requests
    assert len(recorded) == 1
    command = recorded[0].command
    assert command.kind == pb.MX_COMMAND_KIND_WRITE

    written_value = command.write.value
    assert written_value.WhichOneof("kind") == "sparse_array_value"

    sparse = written_value.sparse_array_value
    assert sparse.element_data_type == pb.MX_DATA_TYPE_INTEGER
    assert sparse.total_length == 10
    assert len(sparse.elements) == 2

    # Re-key by index so the checks do not depend on element order.
    int_by_index = {element.index: element.value.int32_value for element in sparse.elements}
    assert int_by_index[0] == 100
    assert int_by_index[5] == 500
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_write_array_elements_forwards_user_id() -> None:
    """The user_id argument must appear on the write payload sent to the stub."""
    fake_stub = _FakeStub()
    options = ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True)
    gateway = await GatewayClient.connect(options, stub=fake_stub)
    session = await gateway.open_session()

    await session.write_array_elements(
        server_handle=1,
        item_handle=2,
        element_data_type=pb.MX_DATA_TYPE_BOOLEAN,
        total_length=4,
        elements={},
        user_id=42,
    )

    first_request = fake_stub.invoke.requests[0]
    assert first_request.command.write.user_id == 42
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_write_array_elements_string_elements() -> None:
    """String inputs must arrive as string_value scalars at their given indices."""
    fake_stub = _FakeStub()
    options = ClientOptions(endpoint="fake", api_key="mxgw_test_secret", plaintext=True)
    gateway = await GatewayClient.connect(options, stub=fake_stub)
    session = await gateway.open_session()

    await session.write_array_elements(
        server_handle=1,
        item_handle=2,
        element_data_type=pb.MX_DATA_TYPE_STRING,
        total_length=3,
        elements={1: "hello", 2: "world"},
    )

    sparse = fake_stub.invoke.requests[0].command.write.value.sparse_array_value
    # Re-key by index so the checks do not depend on element order.
    text_by_index = {element.index: element.value.string_value for element in sparse.elements}
    assert text_by_index[1] == "hello"
    assert text_by_index[2] == "world"
|