diff --git a/src/ScadaLink.InboundAPI/Middleware/AuditWriteMiddleware.cs b/src/ScadaLink.InboundAPI/Middleware/AuditWriteMiddleware.cs index f6d3e13..d7bf85c 100644 --- a/src/ScadaLink.InboundAPI/Middleware/AuditWriteMiddleware.cs +++ b/src/ScadaLink.InboundAPI/Middleware/AuditWriteMiddleware.cs @@ -3,6 +3,8 @@ using System.Text; using System.Text.Json; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ScadaLink.AuditLog.Configuration; using ScadaLink.Commons.Entities.Audit; using ScadaLink.Commons.Interfaces.Services; using ScadaLink.Commons.Types.Enums; @@ -43,13 +45,21 @@ namespace ScadaLink.InboundAPI.Middleware; /// Body capture. The request body is buffered via /// then /// rewound so the downstream endpoint handler still sees the full payload. The -/// response body is captured by swapping for a -/// before the pipeline runs; after the pipeline -/// returns, the buffered bytes are copied to the original stream (transparent -/// to the real client) and read into . -/// Truncation to the configured inbound ceiling happens in -/// ; the -/// middleware itself stores the full buffered content. +/// response body is captured by wrapping in a +/// forwarding stream that mirrors writes to the original sink (transparent to +/// the real client) while capturing a bounded copy for audit. +/// +/// +/// +/// Bounded capture at the source. Both the request- and response-body +/// audit copies are bounded at +/// (default 1 MiB) AT THE CAPTURE SITE — we never buffer more than +/// cap + 1 bytes per body even when the client streams hundreds of MiB. +/// The downstream handler and the real client still see every byte; only the +/// audit copy is bounded. The cap is also enforced again by +/// (which OR's +/// in its own determination), so a +/// row truncated here remains truncated even if the filter is bypassed. 
/// /// public sealed class AuditWriteMiddleware @@ -77,21 +87,29 @@ public sealed class AuditWriteMiddleware private readonly RequestDelegate _next; private readonly ICentralAuditWriter _auditWriter; private readonly ILogger _logger; + private readonly IOptionsMonitor _options; public AuditWriteMiddleware( RequestDelegate next, ICentralAuditWriter auditWriter, - ILogger logger) + ILogger logger, + IOptionsMonitor options) { _next = next ?? throw new ArgumentNullException(nameof(next)); _auditWriter = auditWriter ?? throw new ArgumentNullException(nameof(auditWriter)); _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _options = options ?? throw new ArgumentNullException(nameof(options)); } public async Task InvokeAsync(HttpContext ctx) { var sw = Stopwatch.StartNew(); + // Per-request hot read of the inbound cap — mirrors the convention used + // by DefaultAuditPayloadFilter so a live config change picks up on the + // next request without re-resolving the singleton. + var cap = _options.CurrentValue.InboundMaxBytes; + // Audit Log #23 (ParentExecutionId): mint the inbound request's per-request // ExecutionId ONCE, here at the start of the request, and stash it on // HttpContext.Items. Two consumers share this single id: @@ -109,18 +127,17 @@ public sealed class AuditWriteMiddleware // of the pipeline for us — but we also rewind to position 0 after our // own read so the very next reader starts from the top. ctx.Request.EnableBuffering(); - var requestBody = await ReadBufferedRequestBodyAsync(ctx.Request).ConfigureAwait(false); + var (requestBody, requestTruncated) = + await ReadBufferedRequestBodyAsync(ctx.Request, cap).ConfigureAwait(false); - // Response body — swap in a MemoryStream so the pipeline writes are - // captured. The original Response.Body is restored in the finally block, - // and the captured bytes are copied back to it so the real client still - // receives every byte (transparent wrap). 
The captured string is then - // available for the audit row. + // Response body — wrap Response.Body in a forwarding stream that mirrors + // every write to the original sink (transparent to the real client) + // while capturing AT MOST `cap + 1` bytes for the audit copy. The + // original Response.Body is restored in the finally block. var originalResponseBody = ctx.Response.Body; - using var responseBuffer = new MemoryStream(); - ctx.Response.Body = responseBuffer; + using var captureStream = new CapturedResponseStream(originalResponseBody, cap); + ctx.Response.Body = captureStream; - string? responseBody = null; Exception? thrown = null; try { @@ -137,14 +154,19 @@ public sealed class AuditWriteMiddleware { sw.Stop(); - // Whatever the handler managed to write — full success, partial - // success before throwing, or nothing at all — copy back to the - // original stream and read for audit. - responseBody = await DrainResponseBufferAsync(responseBuffer, originalResponseBody) - .ConfigureAwait(false); + // Restore the original stream and resolve the captured audit copy. + // The forwarding wrapper has already written every byte to the + // original sink; this just pulls back the bounded UTF-8 string. ctx.Response.Body = originalResponseBody; + var (responseBody, responseTruncated) = captureStream.GetCapturedBody(); - EmitInboundAudit(ctx, sw.ElapsedMilliseconds, thrown, requestBody, responseBody); + EmitInboundAudit( + ctx, + sw.ElapsedMilliseconds, + thrown, + requestBody, + responseBody, + requestTruncated || responseTruncated); } } @@ -158,7 +180,8 @@ public sealed class AuditWriteMiddleware long durationMs, Exception? thrown, string? requestBody, - string? responseBody) + string? 
responseBody,
+        bool payloadTruncated)
     {
         try
         {
@@ -210,7 +233,7 @@ public sealed class AuditWriteMiddleware
                 ErrorMessage = thrown?.Message,
                 RequestSummary = requestBody,
                 ResponseSummary = responseBody,
-                PayloadTruncated = false,
+                PayloadTruncated = payloadTruncated,
                 Extra = extra,
                 // Central direct-write — no site-local forwarding state.
                 ForwardState = null,
@@ -231,80 +254,117 @@ public sealed class AuditWriteMiddleware
     }
 
     ///
-    /// Reads the buffered request body fully into a string and rewinds the
-    /// stream so the downstream handler sees the unconsumed payload. Returns
-    /// null for empty/missing bodies so the audit row's
+    /// Reads the buffered request body up to <paramref name="capBytes"/> bytes
+    /// into a string for the audit copy and rewinds the stream so the
+    /// downstream handler sees the unconsumed payload. Returns
+    /// <c>(null, false)</c> for empty/missing bodies so the audit row's
     /// stays null rather than
     /// containing an empty string.
     ///
-    private static async Task ReadBufferedRequestBodyAsync(HttpRequest request)
+    /// <remarks>
+    /// Reads AT MOST <c>cap + 1</c> bytes from the request stream into a
+    /// scratch buffer rented from the shared array pool, so a bodyless or
+    /// tiny request does not allocate a fresh cap-sized array. If the extra
+    /// byte arrives the body is over the cap and the returned string is
+    /// UTF-8 byte-safe truncated to at most <c>cap</c> bytes with
+    /// <c>truncated = true</c>. A negative cap is treated as 0. The cap
+    /// applies only to the audit copy — after a successful read the request
+    /// stream is rewound to position 0 so the framework's next reader (the
+    /// endpoint handler's JSON parser) sees the full body.
+    /// </remarks>
+    private static async Task<(string? body, bool truncated)> ReadBufferedRequestBodyAsync(
+        HttpRequest request,
+        int capBytes)
     {
         if (request.ContentLength is 0)
         {
-            return null;
+            return (null, false);
         }
 
+        // Clamp the configured cap: a negative value must not produce a
+        // negative buffer size and int.MaxValue must not overflow cap + 1.
+        var cap = Math.Clamp(capBytes, 0, int.MaxValue - 1);
+        var limit = cap + 1;
+        byte[]? rented = null;
+
         try
         {
             request.Body.Position = 0;
-            using var reader = new StreamReader(
-                request.Body,
-                Encoding.UTF8,
-                detectEncodingFromByteOrderMarks: false,
-                bufferSize: 1024,
-                leaveOpen: true);
-            var content = await reader.ReadToEndAsync().ConfigureAwait(false);
+
+            // Read AT MOST cap + 1 bytes — the extra byte tells us the body was
+            // over the cap without forcing us to allocate the whole payload.
+            // Rent may return a larger array with stale contents, so only the
+            // first `total` bytes are ever inspected.
+            rented = System.Buffers.ArrayPool<byte>.Shared.Rent(limit);
+            var total = 0;
+            while (total < limit)
+            {
+                var read = await request.Body
+                    .ReadAsync(rented.AsMemory(total, limit - total))
+                    .ConfigureAwait(false);
+                if (read == 0)
+                {
+                    break;
+                }
+                total += read;
+            }
             request.Body.Position = 0;
-            return string.IsNullOrEmpty(content) ? null : content;
+
+            if (total == 0)
+            {
+                return (null, false);
+            }
+
+            var truncated = total > cap;
+            var bytesForString = truncated ? cap : total;
+            var content = DecodeUtf8Bounded(rented, bytesForString, cutAtValidBytes: truncated);
+            return (string.IsNullOrEmpty(content) ? null : content, truncated);
         }
         catch
         {
             // A failed body read must not abort the request — fall through
             // with a null RequestSummary; the audit row still records the
             // outcome.
-            return null;
+            return (null, false);
         }
+        finally
+        {
+            if (rented is not null)
+            {
+                System.Buffers.ArrayPool<byte>.Shared.Return(rented);
+            }
+        }
     }
 
     ///
-    /// Copies the bytes buffered in to
-    /// (so the real client still receives them)
-    /// and returns a UTF-8 string copy for .
-    /// Returns null when no bytes were written, mirroring the
-    /// empty-body contract.
+    /// UTF-8 byte-safe decode of <paramref name="validBytes"/> bytes from
+    /// <paramref name="bytes"/>. When <paramref name="cutAtValidBytes"/> is
+    /// <c>true</c> the input is the result of a hard byte-count truncation, so
+    /// we walk back from <c>validBytes</c> while the byte is a continuation
+    /// byte (<c>byte &amp; 0xC0 == 0x80</c>) to avoid splitting a multi-byte
+    /// codepoint.
When <c>false</c> the caller is decoding the full payload
+    /// and the boundary stands as-is.
     ///
-    private static async Task DrainResponseBufferAsync(
-        MemoryStream buffer,
-        Stream originalBody)
+    /// <remarks>
+    /// Mirrors the algorithm in <c>DefaultAuditPayloadFilter.TruncateUtf8</c>;
+    /// kept local to avoid a backwards project reference from
+    /// <c>ScadaLink.AuditLog</c> into <c>ScadaLink.InboundAPI</c>.
+    /// </remarks>
+    private static string DecodeUtf8Bounded(byte[] bytes, int validBytes, bool cutAtValidBytes)
     {
-        if (buffer.Length == 0)
+        if (validBytes <= 0)
         {
-            return null;
+            return string.Empty;
         }
-
-        buffer.Position = 0;
-        // Copy first so the client never misses bytes even if the read for audit
-        // throws somehow (defensive — MemoryStream.CopyToAsync to a sink shouldn't
-        // throw on its own, but the original body may).
-        try
+        var boundary = Math.Min(validBytes, bytes.Length); // never decode past the array
+        if (cutAtValidBytes)
         {
-            await buffer.CopyToAsync(originalBody).ConfigureAwait(false);
+            while (boundary > 0 && boundary < bytes.Length && (bytes[boundary] & 0xC0) == 0x80)
+            {
+                boundary--;
+            }
         }
-        catch
-        {
-            // Best-effort: a sink that refuses our copy is the sink's problem;
-            // the audit still records what the handler produced. Do NOT rethrow.
-        }
-
-        buffer.Position = 0;
-        using var reader = new StreamReader(
-            buffer,
-            Encoding.UTF8,
-            detectEncodingFromByteOrderMarks: false,
-            bufferSize: 1024,
-            leaveOpen: true);
-        var content = await reader.ReadToEndAsync().ConfigureAwait(false);
-        return string.IsNullOrEmpty(content) ? null : content;
+        return Encoding.UTF8.GetString(bytes, 0, boundary);
     }
 
     ///
@@ -383,4 +427,166 @@ public sealed class AuditWriteMiddleware
 
         return path[(lastSlash + 1)..];
     }
+
+    /// <summary>
+    /// Write-only forwarding wrapper that mirrors every write to the inner
+    /// response stream (so the real client receives all bytes) while
+    /// capturing AT MOST <c>cap + 1</c> bytes into a private
+    /// <see cref="MemoryStream"/> for the audit copy.
+    /// </summary>
+    /// <remarks>
+    /// <para>
+    /// The inner sink is owned by the framework and is NOT disposed when this
+    /// wrapper is disposed — we only own the capture <see cref="MemoryStream"/>.
+    /// </para>
+    /// <para>
+    /// All Write overloads forward to the inner stream FIRST, then capture the
+    /// remaining quota. If the inner sink throws (e.g. the client disconnects),
+    /// the exception is allowed to propagate — capture is best-effort, the
+    /// real I/O is authoritative. The handler-throws-mid-response test
+    /// (<c>ResponseBody_OnHandlerThrow_BodyCapturedUpToTheThrow</c>) verifies
+    /// that captured bytes up to the throw are still recoverable.
+    /// </para>
+    /// <para>
+    /// Disposal is tolerated at any point: if the handler closes the response
+    /// body, later writes are still forwarded but no longer captured, and
+    /// <see cref="GetCapturedBody"/> still returns what was captured before.
+    /// </para>
+    /// </remarks>
+    private sealed class CapturedResponseStream : Stream
+    {
+        private readonly Stream _inner;
+        private readonly int _capBytes;
+        private readonly MemoryStream _captured;
+        // Tracked separately so GetCapturedBody never queries
+        // MemoryStream.Length, which throws once the stream is disposed.
+        private int _capturedLength;
+        private bool _disposed;
+
+        public CapturedResponseStream(Stream inner, int capBytes)
+        {
+            _inner = inner ?? throw new ArgumentNullException(nameof(inner));
+            // Clamp: a negative cap behaves as 0 and int.MaxValue cannot
+            // overflow the cap + 1 quota arithmetic in CaptureBytes.
+            _capBytes = Math.Clamp(capBytes, 0, int.MaxValue - 1);
+            // Grows on demand; CaptureBytes stops feeding it at cap + 1
+            // bytes, the extra byte being the over-cap marker.
+            _captured = new MemoryStream();
+        }
+
+        public override bool CanRead => false;
+        public override bool CanSeek => false;
+        public override bool CanWrite => true;
+
+        public override long Length =>
+            throw new NotSupportedException("CapturedResponseStream is write-only.");
+
+        public override long Position
+        {
+            get => throw new NotSupportedException("CapturedResponseStream is write-only.");
+            set => throw new NotSupportedException("CapturedResponseStream is write-only.");
+        }
+
+        public override void Flush() => _inner.Flush();
+
+        public override Task FlushAsync(CancellationToken cancellationToken) =>
+            _inner.FlushAsync(cancellationToken);
+
+        public override int Read(byte[] buffer, int offset, int count) =>
+            throw new NotSupportedException("CapturedResponseStream is write-only.");
+
+        public override long Seek(long offset, SeekOrigin origin) =>
+            throw new NotSupportedException("CapturedResponseStream is write-only.");
+
+        public override void SetLength(long value) =>
+            throw new NotSupportedException("CapturedResponseStream is write-only.");
+
+        public override void Write(byte[] buffer, int offset, int count)
+        {
+            // Forward to the real sink FIRST — the client must never miss
+            // bytes if capture throws.
+            _inner.Write(buffer, offset, count);
+            CaptureBytes(buffer.AsSpan(offset, count));
+        }
+
+        public override void Write(ReadOnlySpan<byte> buffer)
+        {
+            _inner.Write(buffer);
+            CaptureBytes(buffer);
+        }
+
+        public override async Task WriteAsync(
+            byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        {
+            await _inner.WriteAsync(buffer.AsMemory(offset, count), cancellationToken)
+                .ConfigureAwait(false);
+            CaptureBytes(buffer.AsSpan(offset, count));
+        }
+
+        public override async ValueTask WriteAsync(
+            ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
+        {
+            await _inner.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
+            CaptureBytes(buffer.Span);
+        }
+
+        /// <summary>
+        /// Capture up to <c>cap + 1</c> bytes total into the private
+        /// <see cref="MemoryStream"/>. Once the cap quota is reached, or the
+        /// wrapper has been disposed, further bytes are silently dropped from
+        /// the audit copy (the real sink has already received them).
+        /// </summary>
+        private void CaptureBytes(ReadOnlySpan<byte> span)
+        {
+            if (_disposed || span.IsEmpty)
+            {
+                return;
+            }
+            var quota = (_capBytes + 1) - _capturedLength;
+            if (quota <= 0)
+            {
+                return;
+            }
+            var take = Math.Min(quota, span.Length);
+            _captured.Write(span[..take]);
+            _capturedLength += take;
+        }
+
+        /// <summary>
+        /// Returns the captured response body as a UTF-8 string (byte-safe
+        /// truncated to at most <c>cap</c> bytes) and a flag indicating whether
+        /// the audit copy hit the cap. Returns <c>(null, false)</c> when no
+        /// bytes were captured, mirroring the request-body empty contract.
+        /// </summary>
+        public (string? body, bool truncated) GetCapturedBody()
+        {
+            var length = _capturedLength;
+            if (length == 0)
+            {
+                return (null, false);
+            }
+            var truncated = length > _capBytes;
+            // GetBuffer hands back the backing array without copying; unlike
+            // Length it does not throw after the MemoryStream was disposed.
+            var bytes = _captured.GetBuffer();
+            var bytesForString = truncated ? _capBytes : length;
+            var content = DecodeUtf8Bounded(bytes, bytesForString, cutAtValidBytes: truncated);
+            return (string.IsNullOrEmpty(content) ? null : content, truncated);
+        }
+
+        protected override void Dispose(bool disposing)
+        {
+            if (!_disposed)
+            {
+                if (disposing)
+                {
+                    // Own only the capture stream; the inner sink belongs to
+                    // the framework's response pipeline.
+                    _captured.Dispose();
+                }
+                _disposed = true;
+            }
+            base.Dispose(disposing);
+        }
+    }
 }
diff --git a/src/ScadaLink.InboundAPI/ScadaLink.InboundAPI.csproj b/src/ScadaLink.InboundAPI/ScadaLink.InboundAPI.csproj
index 94cf9ec..7e898aa 100644
--- a/src/ScadaLink.InboundAPI/ScadaLink.InboundAPI.csproj
+++ b/src/ScadaLink.InboundAPI/ScadaLink.InboundAPI.csproj
@@ -14,6 +14,9 @@ + +
diff --git a/tests/ScadaLink.InboundAPI.Tests/Middleware/AuditWriteMiddlewareTests.cs b/tests/ScadaLink.InboundAPI.Tests/Middleware/AuditWriteMiddlewareTests.cs
index a81a607..cd220a8 100644
--- a/tests/ScadaLink.InboundAPI.Tests/Middleware/AuditWriteMiddlewareTests.cs
+++ b/tests/ScadaLink.InboundAPI.Tests/Middleware/AuditWriteMiddlewareTests.cs
@@ -4,6 +4,8 @@ using System.Text.Json;
 using Microsoft.AspNetCore.Http;
 using Microsoft.AspNetCore.Routing;
 using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Extensions.Options;
+using ScadaLink.AuditLog.Configuration;
 using ScadaLink.Commons.Entities.Audit;
 using ScadaLink.Commons.Interfaces.Services;
 using ScadaLink.Commons.Types.Enums;
@@ -79,8 +81,32 @@ public class AuditWriteMiddlewareTests
 
     private static AuditWriteMiddleware CreateMiddleware(
         RequestDelegate next,
-        ICentralAuditWriter writer) =>
-        new(next, writer, NullLogger.Instance);
+        ICentralAuditWriter writer,
+        AuditLogOptions? options = null) =>
+        new(
+            next,
+            writer,
+            NullLogger.Instance,
+            new StaticAuditLogOptionsMonitor(options ?? new AuditLogOptions()));
+
+    ///
+    /// File-local test double — returns the
+    /// same snapshot on every read, no change-token plumbing required. Mirrors the
+    /// StaticMonitor pattern in
+    /// tests/ScadaLink.AuditLog.Tests/Payload/InboundChannelCapTests.cs.
+    ///
+    private sealed class StaticAuditLogOptionsMonitor : IOptionsMonitor<AuditLogOptions>
+    {
+        private readonly AuditLogOptions _value;
+
+        public StaticAuditLogOptionsMonitor(AuditLogOptions value) => _value = value;
+
+        public AuditLogOptions CurrentValue => _value;
+
+        public AuditLogOptions Get(string? name) => _value;
+
+        public IDisposable? OnChange(Action<AuditLogOptions, string?> listener) => null;
+    }
 
     // ---------------------------------------------------------------------
     // 1. Happy path — InboundRequest/Delivered/HttpStatus 200
@@ -581,4 +607,109 @@ public class AuditWriteMiddlewareTests
         Assert.Equal(AuditStatus.Failed, evt.Status);
         Assert.Equal("partial", evt.ResponseSummary);
     }
+
+    // ---------------------------------------------------------------------
+    // Bounded audit capture — memory safety follow-up. The capture site now
+    // honours AuditLogOptions.InboundMaxBytes at READ time (not just at
+    // filter-time), so a 500 MiB body cannot transiently allocate 500 MiB of
+    // string. The cap is local to the AUDIT copy; downstream readers and the
+    // real client still see every byte.
+    // ---------------------------------------------------------------------
+
+    [Fact]
+    public async Task RequestBody_AboveInboundMaxBytes_TruncatedToCap_PayloadTruncatedTrue()
+    {
+        // 4 KiB cap, 20 KB body — the audit copy must be UTF-8 byte-safe
+        // capped at 4 KiB AND PayloadTruncated must flip, while the
+        // downstream handler still sees the full 20 KB payload.
+        const int cap = 4096;
+        var bigBody = new string('a', 20_000);
+        var writer = new RecordingAuditWriter();
+        var ctx = BuildContext(body: bigBody);
+
+        string? observedAfterMiddleware = null;
+        var mw = CreateMiddleware(
+            async hc =>
+            {
+                using var reader = new StreamReader(hc.Request.Body);
+                observedAfterMiddleware = await reader.ReadToEndAsync();
+                hc.Response.StatusCode = 200;
+            },
+            writer,
+            options: new AuditLogOptions { InboundMaxBytes = cap });
+
+        await mw.InvokeAsync(ctx);
+
+        // (iii) Downstream handler still sees the FULL body — the cap applied
+        // only to the audit copy.
+        Assert.Equal(bigBody, observedAfterMiddleware);
+
+        var evt = Assert.Single(writer.Events);
+        // (i) Audit copy is EXACTLY the first cap bytes — an ASCII body has no
+        // multi-byte boundary to walk back from, so nothing shorter may pass.
+        Assert.Equal(new string('a', cap), evt.RequestSummary);
+        // (ii) Truncation flag set by the middleware (the filter will OR its
+        // own determination on top, but the middleware MUST set it itself).
+        Assert.True(evt.PayloadTruncated);
+    }
+
+    [Fact]
+    public async Task RequestBody_CapInsideMultiByteCodepoint_CutsAtCodepointBoundary()
+    {
+        // The cap lands one byte into the 3-byte UTF-8 encoding of '€'
+        // (E2 82 AC): the audit copy must drop the whole codepoint instead of
+        // keeping a dangling lead byte. Assumes BuildContext encodes the body
+        // as UTF-8 — TODO confirm against the helper.
+        const int cap = 4096;
+        var prefix = new string('a', cap - 1);
+        var writer = new RecordingAuditWriter();
+        var ctx = BuildContext(body: prefix + "€" + new string('a', 16));
+
+        var mw = CreateMiddleware(
+            hc =>
+            {
+                hc.Response.StatusCode = 200;
+                return Task.CompletedTask;
+            },
+            writer,
+            options: new AuditLogOptions { InboundMaxBytes = cap });
+
+        await mw.InvokeAsync(ctx);
+
+        var evt = Assert.Single(writer.Events);
+        Assert.Equal(prefix, evt.RequestSummary);
+        Assert.True(evt.PayloadTruncated);
+    }
+
+    [Fact]
+    public async Task ResponseBody_AboveInboundMaxBytes_TruncatedToCap_ClientStillReceivesAllBytes_PayloadTruncatedTrue()
+    {
+        // 4 KiB cap, 20 KB response — the test sink (acts as the real client)
+        // MUST receive all 20 KB while the audit copy is bounded at 4 KiB.
+        const int cap = 4096;
+        var bigResponse = new string('b', 20_000);
+        var writer = new RecordingAuditWriter();
+        var ctx = BuildContext();
+        using var captured = new MemoryStream();
+        ctx.Response.Body = captured; // stand-in for the client sink
+
+        var mw = CreateMiddleware(
+            async hc =>
+            {
+                hc.Response.StatusCode = 200;
+                await hc.Response.WriteAsync(bigResponse);
+            },
+            writer,
+            options: new AuditLogOptions { InboundMaxBytes = cap });
+
+        await mw.InvokeAsync(ctx);
+
+        // Client sink received every byte — the forwarding wrap is transparent.
+        Assert.Equal(bigResponse, Encoding.UTF8.GetString(captured.ToArray()));
+
+        var evt = Assert.Single(writer.Events);
+        // Audit copy is exactly the first cap bytes of the response.
+        Assert.Equal(new string('b', cap), evt.ResponseSummary);
+        Assert.True(evt.PayloadTruncated);
+    }
 }
diff --git a/tests/ScadaLink.InboundAPI.Tests/Middleware/MiddlewareOrderTests.cs b/tests/ScadaLink.InboundAPI.Tests/Middleware/MiddlewareOrderTests.cs
index b74356f..f38dec2 100644
--- a/tests/ScadaLink.InboundAPI.Tests/Middleware/MiddlewareOrderTests.cs
+++ b/tests/ScadaLink.InboundAPI.Tests/Middleware/MiddlewareOrderTests.cs
@@ -6,6 +6,8 @@ using Microsoft.AspNetCore.TestHost;
 using Microsoft.Extensions.DependencyInjection;
 using Microsoft.Extensions.Hosting;
 using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Extensions.Options;
+using ScadaLink.AuditLog.Configuration;
 using ScadaLink.Commons.Entities.Audit;
 using ScadaLink.Commons.Interfaces.Services;
 using ScadaLink.Commons.Types.Enums;
@@ -145,7 +147,8 @@ public class MiddlewareOrderTests
                     // instantiates the type correctly.
                     _ => Task.CompletedTask,
                     writer,
-                    NullLogger.Instance));
+                    NullLogger.Instance,
+                    new StaticAuditLogOptionsMonitor(new AuditLogOptions())));
                 services.AddRouting();
                 services.AddAuthorization();
                 services.AddAuthentication("TestScheme")
@@ -233,4 +236,22 @@ public class MiddlewareOrderTests
             return Task.FromResult(AuthenticateResult.Success(ticket));
         }
     }
+
+    ///
+    /// File-local test double — returns the
+    /// same snapshot on every read. Mirrors the helper in
+    /// AuditWriteMiddlewareTests.
+    ///
+    private sealed class StaticAuditLogOptionsMonitor : IOptionsMonitor
+    {
+        private readonly AuditLogOptions _value;
+
+        public StaticAuditLogOptionsMonitor(AuditLogOptions value) => _value = value;
+
+        public AuditLogOptions CurrentValue => _value;
+
+        public AuditLogOptions Get(string? name) => _value;
+
+        public IDisposable? OnChange(Action listener) => null;
+    }
 }