"""Tabulate TCP payload records observed around each harness write.

Given a capture directory containing ``harness.log`` and ``loopback.pcapng``,
this script finds every ``mx.write.begin`` event in the log and, for each one,
writes one TSV row per record parsed out of the TCP payloads exchanged between
two endpoints inside a ``[-before, +after]`` second window around the write.
"""

from __future__ import annotations

import argparse
import csv
import datetime as dt
import json
import re
import struct
from dataclasses import dataclass
from pathlib import Path

# scapy is imported inside the functions that handle packets: ``scapy.all`` is
# slow to import, and deferring it keeps ``--help`` fast and lets the pure
# byte-parsing helpers below be used without scapy installed.

# One harness log line: "<timestamp>\t<event name>\t<JSON payload>".
EVENT_RE = re.compile(r"^(?P<timestamp>\S+)\t(?P<event>[^\t]+)\t(?P<payload>.*)$")

# Largest length prefix accepted as a plausible record; bigger values are
# treated as "this is not a length-prefixed record".
MAX_RECORD_LENGTH = 1024 * 1024

# Fractional-seconds part of an ISO-8601 timestamp (only the first match is used).
_FRACTION_RE = re.compile(r"\.(\d+)")

# Column order of the output TSV.
ROW_FIELDS = (
    "capture",
    "write_index",
    "write_value",
    "frame",
    "packet_time_relative_to_write",
    "packet_time_relative_to_complete",
    "direction",
    "src",
    "dst",
    "tcp_seq",
    "payload_offset",
    "record_index",
    "record_type",
    "record_size",
    "announced_length",
    "i32_0",
    "i32_1",
    "i32_2",
    "i32_3",
    "signature16",
    "signature24",
    "hex",
    "ascii_preview",
)


@dataclass(frozen=True)
class Endpoint:
    """A ``host:port`` pair; frozen so it is hashable and usable in sets."""

    host: str
    port: int

    @classmethod
    def parse(cls, text: str) -> "Endpoint":
        """Parse ``"host:port"``; IPv6 hosts may be written ``"[addr]:port"``.

        Raises:
            ValueError: if *text* has no ``:`` or the port is not an integer.
        """
        host, port = text.rsplit(":", 1)
        # scapy renders IPv6 addresses without brackets, so strip them here or
        # a bracketed endpoint could never match a captured packet.
        if host.startswith("[") and host.endswith("]"):
            host = host[1:-1]
        return cls(host, int(port))


def _six_digit_fraction(match: re.Match[str]) -> str:
    """Pad or truncate a matched ``.digits`` fraction to exactly six digits."""
    return "." + match.group(1)[:6].ljust(6, "0")


def parse_timestamp(text: str) -> dt.datetime:
    """Parse an ISO-8601 timestamp and return it as an aware UTC datetime.

    A ``Z`` suffix is accepted, naive timestamps are taken to be UTC, and
    fractional seconds of any precision are normalised to microseconds.

    Raises:
        ValueError: if *text* is not a parseable ISO-8601 timestamp.
    """
    normalized = text.replace("Z", "+00:00")
    # Before Python 3.11, fromisoformat() only accepts 3- or 6-digit fractions;
    # normalising keeps e.g. 7-digit (100 ns) timestamps parseable everywhere.
    normalized = _FRACTION_RE.sub(_six_digit_fraction, normalized, count=1)
    parsed = dt.datetime.fromisoformat(normalized)
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=dt.timezone.utc)
    return parsed.astimezone(dt.timezone.utc)


def harness_events(path: Path) -> list[dict[str, object]]:
    """Parse a harness log into ``{"timestamp", "event", "payload"}`` dicts.

    Lines that do not match EVENT_RE are skipped. A payload that is not valid
    JSON becomes ``{}``. A matching line whose timestamp cannot be parsed
    raises ValueError (from parse_timestamp).
    """
    events: list[dict[str, object]] = []
    # utf-8-sig decodes plain UTF-8 identically but also drops a leading BOM,
    # which would otherwise be glued onto the first line's timestamp.
    for line in path.read_text(encoding="utf-8-sig").splitlines():
        match = EVENT_RE.match(line)
        if not match:
            continue
        try:
            payload = json.loads(match.group("payload"))
        except json.JSONDecodeError:
            payload = {}
        events.append({
            "timestamp": parse_timestamp(match.group("timestamp")),
            "event": match.group("event"),
            "payload": payload,
        })
    return events


def find_event(events: list[dict[str, object]], name: str) -> dict[str, object]:
    """Return the first event called *name*.

    Raises:
        RuntimeError: if no event has that name.
    """
    for event in events:
        if event["event"] == name:
            return event
    raise RuntimeError(f"Event {name!r} was not found.")


def find_events(events: list[dict[str, object]], name: str) -> list[dict[str, object]]:
    """Return every event called *name*, in log order."""
    return [event for event in events if event["event"] == name]


def packet_hosts(packet) -> tuple[str, str] | None:
    """Return ``(src, dst)`` addresses of an IPv4/IPv6 scapy packet, else None."""
    from scapy.all import IP, IPv6

    if IP in packet:
        return str(packet[IP].src), str(packet[IP].dst)
    if IPv6 in packet:
        return str(packet[IPv6].src), str(packet[IPv6].dst)
    return None


def i32(data: bytes, offset: int) -> int | None:
    """Return the little-endian signed 32-bit int at *offset*, or None if truncated."""
    if offset + 4 > len(data):
        return None
    return struct.unpack_from("<i", data, offset)[0]


def u32(data: bytes, offset: int) -> int | None:
    """Return the little-endian unsigned 32-bit int at *offset*, or None if truncated."""
    if offset + 4 > len(data):
        return None
    return struct.unpack_from("<I", data, offset)[0]


# NOTE(review): confirm the intended default preview length for ``limit``.
def ascii_preview(data: bytes, limit: int = 64) -> str:
    """Render the first *limit* bytes as printable ASCII, '.' for anything else."""
    return "".join(chr(value) if 32 <= value <= 126 else "." for value in data[:limit])


def announced_data_records_match(data: bytes, offset: int, announced_size: int) -> bool:
    """Check whether records after the 12-byte header at *offset* total *announced_size*.

    Walks consecutive ``u32 length + payload`` records starting at
    ``offset + 12`` and returns True only if their sizes (prefix included) add
    up to exactly *announced_size* using bytes present in *data*.
    """
    if announced_size < 0:
        return False
    total = 0
    cursor = offset + 12
    while total < announced_size and cursor + 4 <= len(data):
        record_length = u32(data, cursor)
        if record_length is None or record_length > MAX_RECORD_LENGTH:
            return False
        record_size = record_length + 4
        # Zero-length records are rejected, as are records running past the data.
        if record_size <= 4 or cursor + record_size > len(data):
            return False
        total += record_size
        cursor += record_size
    return total == announced_size


def classify_control(data: bytes, offset: int) -> str | None:
    """Classify the 12-byte header at *offset* as a control record, if it is one.

    Returns ``"control"`` when the first int32 is -1 or -2;
    ``"control_announce"`` when the header is ``(size, n >= 0, 0)`` and either
    the following records total *size* bytes or the header ends *data*;
    otherwise None (also when fewer than 12 bytes remain).
    """
    first = i32(data, offset)
    second = i32(data, offset + 4)
    third = i32(data, offset + 8)
    if first is None or second is None or third is None:
        return None
    if first in {-1, -2}:
        return "control"
    if third != 0 or second < 0:
        return None
    if announced_data_records_match(data, offset, first):
        return "control_announce"
    # NOTE(review): a negative `first` other than -1/-2 also reaches this
    # branch; confirm whether that should classify as control_announce.
    if offset + 12 == len(data):
        return "control_announce"
    return None


def iter_records(payload: bytes) -> list[tuple[str, int, bytes]]:
    """Split a TCP payload into ``(record_type, offset, raw_record)`` tuples.

    Each position is tried as a 12-byte control header first, then as a
    ``u32 length + payload`` data record. The first position that is neither
    (including a tail shorter than 4 bytes) yields one ``"unknown"`` record
    holding the rest of the payload, and parsing stops.
    """
    records: list[tuple[str, int, bytes]] = []
    offset = 0
    while offset < len(payload):
        control_type = classify_control(payload, offset)
        if control_type is not None:
            records.append((control_type, offset, payload[offset:offset + 12]))
            offset += 12
            continue
        length = u32(payload, offset)
        if (
            length is not None
            and length <= MAX_RECORD_LENGTH
            and offset + 4 + length <= len(payload)
        ):
            records.append(("data", offset, payload[offset:offset + 4 + length]))
            offset += 4 + length
            continue
        # Reached for unparseable data AND for a 1-3 byte tail, so no payload
        # bytes are ever silently dropped from the report.
        records.append(("unknown", offset, payload[offset:]))
        break
    return records


def record_body(record_type: str, record: bytes) -> bytes:
    """Return a record's body: data records lose their 4-byte length prefix."""
    if record_type == "data":
        return record[4:]
    return record


@dataclass(frozen=True)
class _Segment:
    """A payload-carrying TCP segment exchanged between the two endpoints."""

    frame: int  # 1-based frame number counted over the whole capture
    epoch: float  # capture timestamp in epoch seconds
    src: Endpoint
    dst: Endpoint
    seq: int
    payload: bytes


def _load_segments(pcap_path: Path, endpoint_a: Endpoint, endpoint_b: Endpoint) -> list[_Segment]:
    """Read the capture once, keeping only TCP payloads between the two endpoints."""
    from scapy.all import Raw, TCP, rdpcap

    wanted = {endpoint_a, endpoint_b}
    segments: list[_Segment] = []
    # Frames are numbered before filtering so they match the capture file.
    for frame, packet in enumerate(rdpcap(str(pcap_path)), start=1):
        if TCP not in packet or Raw not in packet:
            continue
        hosts = packet_hosts(packet)
        if hosts is None:
            continue
        tcp = packet[TCP]
        src = Endpoint(hosts[0], int(tcp.sport))
        dst = Endpoint(hosts[1], int(tcp.dport))
        if {src, dst} != wanted:
            continue
        segments.append(
            _Segment(frame, float(packet.time), src, dst, int(tcp.seq), bytes(packet[Raw].load))
        )
    return segments


def _complete_times(
    write_begins: list[dict[str, object]],
    write_completes: list[dict[str, object]],
) -> list[dt.datetime]:
    """Pair each write-begin with the next unused write-complete at or after it.

    Completions earlier than a write are skipped; when none remain, the
    write-begin time itself is used.
    """
    times: list[dt.datetime] = []
    cursor = 0
    for write_begin in write_begins:
        write_time = write_begin["timestamp"]
        assert isinstance(write_time, dt.datetime)
        while cursor < len(write_completes):
            candidate_time = write_completes[cursor]["timestamp"]
            assert isinstance(candidate_time, dt.datetime)
            if candidate_time >= write_time:
                break
            cursor += 1
        if cursor < len(write_completes):
            complete_time = write_completes[cursor]["timestamp"]
            cursor += 1
        else:
            complete_time = write_time
        assert isinstance(complete_time, dt.datetime)
        times.append(complete_time)
    return times


def _write_metadata(write_begin: dict[str, object], ordinal: int) -> tuple[str, str]:
    """Return ``(write_index, write_value)`` strings from a write-begin payload.

    ``write_index`` falls back to *ordinal*; ``write_value`` is
    ``payload["Value"]["Value"]`` when present, else ``""``.
    """
    payload = write_begin.get("payload", {})
    if not isinstance(payload, dict):
        return str(ordinal), ""
    write_index = str(payload.get("WriteIndex", ordinal))
    value_payload = payload.get("Value")
    write_value = str(value_payload.get("Value", "")) if isinstance(value_payload, dict) else ""
    return write_index, write_value


def _i32_cell(body: bytes, offset: int) -> str:
    """Format ``i32(body, offset)`` for the TSV, ``""`` when out of range."""
    value = i32(body, offset)
    return "" if value is None else str(value)


def write_rows(
    capture_dir: Path,
    endpoint_a: Endpoint,
    endpoint_b: Endpoint,
    before: float,
    after: float,
    out: Path,
) -> None:
    """Write one TSV row per payload record seen near each ``mx.write.begin``.

    For every write-begin event in ``capture_dir/harness.log``, records from
    TCP segments between *endpoint_a* and *endpoint_b* in
    ``capture_dir/loopback.pcapng`` whose capture time lies within
    ``[write - before, write + after]`` seconds are written to *out*. A
    segment may appear under several writes if their windows overlap.

    Raises:
        RuntimeError: if the log contains no ``mx.write.begin`` event.
    """
    events = harness_events(capture_dir / "harness.log")
    write_begins = find_events(events, "mx.write.begin")
    write_completes = find_events(events, "mx.event.write-complete")
    if not write_begins:
        raise RuntimeError("Event 'mx.write.begin' was not found.")

    # Decode and filter the capture once instead of once per write.
    segments = _load_segments(capture_dir / "loopback.pcapng", endpoint_a, endpoint_b)
    complete_times = _complete_times(write_begins, write_completes)

    rows: list[dict[str, str]] = []
    for ordinal, (write_begin, complete_time) in enumerate(zip(write_begins, complete_times)):
        write_time = write_begin["timestamp"]
        assert isinstance(write_time, dt.datetime)
        write_index, write_value = _write_metadata(write_begin, ordinal)
        write_epoch = write_time.timestamp()
        complete_epoch = complete_time.timestamp()
        start_epoch = write_epoch - before
        end_epoch = write_epoch + after
        for segment in segments:
            if segment.epoch < start_epoch or segment.epoch > end_epoch:
                continue
            direction = "a_to_b" if segment.src == endpoint_a else "b_to_a"
            records = iter_records(segment.payload)
            for record_index, (record_type, payload_offset, record) in enumerate(records):
                body = record_body(record_type, record)
                rows.append({
                    "capture": capture_dir.name,
                    "write_index": write_index,
                    "write_value": write_value,
                    "frame": str(segment.frame),
                    "packet_time_relative_to_write": f"{segment.epoch - write_epoch:.9f}",
                    "packet_time_relative_to_complete": f"{segment.epoch - complete_epoch:.9f}",
                    "direction": direction,
                    "src": f"{segment.src.host}:{segment.src.port}",
                    "dst": f"{segment.dst.host}:{segment.dst.port}",
                    "tcp_seq": str(segment.seq),
                    "payload_offset": str(payload_offset),
                    "record_index": str(record_index),
                    "record_type": record_type,
                    "record_size": str(len(record)),
                    "announced_length": "" if record_type != "data" else str(len(body)),
                    "i32_0": _i32_cell(body, 0),
                    "i32_1": _i32_cell(body, 4),
                    "i32_2": _i32_cell(body, 8),
                    "i32_3": _i32_cell(body, 12),
                    "signature16": body[:16].hex(" "),
                    "signature24": body[:24].hex(" "),
                    "hex": body.hex(" "),
                    "ascii_preview": ascii_preview(body),
                })

    out.parent.mkdir(parents=True, exist_ok=True)
    with out.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(
            handle, fieldnames=list(ROW_FIELDS), delimiter="\t", lineterminator="\n"
        )
        writer.writeheader()
        writer.writerows(rows)


def main() -> int:
    """Command-line entry point; returns the process exit status."""
    parser = argparse.ArgumentParser()
    parser.add_argument("capture_dir", type=Path)
    parser.add_argument("--a", default="127.0.0.1:57415")
    parser.add_argument("--b", default="127.0.0.1:57433")
    parser.add_argument("--before", type=float, default=0.35)
    parser.add_argument("--after", type=float, default=0.75)
    parser.add_argument("--out", type=Path)
    args = parser.parse_args()
    out = args.out or (args.capture_dir / "write-window-mixed-records.tsv")
    write_rows(
        args.capture_dir,
        Endpoint.parse(args.a),
        Endpoint.parse(args.b),
        args.before,
        args.after,
        out,
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())