fix(management-service): resolve ManagementService-001/002/003 — enforce site scope on query/snapshot handlers and DebugStreamHub

This commit is contained in:
Joseph Doherty
2026-05-16 19:47:17 -04:00
parent 6f4efdfa2e
commit b249ca3bf7
5 changed files with 404 additions and 28 deletions

View File

@@ -2,6 +2,7 @@ using System.Text;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using ScadaLink.Commons.Interfaces.Repositories;
using ScadaLink.Commons.Messages.DebugView;
using ScadaLink.Commons.Messages.Streaming;
using ScadaLink.Communication;
@@ -17,6 +18,26 @@ namespace ScadaLink.ManagementService;
public class DebugStreamHub : Hub
{
private const string SessionIdKey = "DebugStreamSessionId";
private const string RolesKey = "DebugStreamRoles";
private const string PermittedSiteIdsKey = "DebugStreamPermittedSiteIds";
/// <summary>
/// Pure site-scope authorization check for a debug-stream subscription.
/// Returns true when the caller may subscribe to a debug stream for an instance
/// belonging to <paramref name="instanceSiteId"/>.
/// Admin role, or an empty <paramref name="permittedSiteIds"/> (system-wide
/// Deployment), grants access to any site; otherwise the instance's site must be
/// in the permitted set.
/// </summary>
public static bool IsInstanceAccessAllowed(
IReadOnlyCollection<string> roles,
IReadOnlyCollection<string> permittedSiteIds,
int instanceSiteId)
{
if (roles.Contains("Admin", StringComparer.OrdinalIgnoreCase)) return true;
if (permittedSiteIds.Count == 0) return true; // system-wide deployment
return permittedSiteIds.Contains(instanceSiteId.ToString());
}
private readonly DebugStreamService _debugStreamService;
private readonly IHubContext<DebugStreamHub> _hubContext;
@@ -93,6 +114,11 @@ public class DebugStreamHub : Hub
return;
}
// Persist the resolved identity on the connection so per-instance site-scope
// enforcement can be applied to SubscribeInstance calls.
Context.Items[RolesKey] = mappingResult.Roles.ToArray();
Context.Items[PermittedSiteIdsKey] = mappingResult.PermittedSiteIds.ToArray();
_logger.LogInformation("DebugStreamHub connection established for {Username}", username);
await base.OnConnectedAsync();
}
@@ -108,6 +134,41 @@ public class DebugStreamHub : Hub
var connectionId = Context.ConnectionId;
// Per-instance site-scope enforcement: a site-scoped Deployment user must not
// be able to stream an instance belonging to a site outside their scope.
var httpContext = Context.GetHttpContext();
if (httpContext == null)
{
_logger.LogWarning("DebugStreamHub: {ConnectionId} subscribe rejected — no HTTP context", connectionId);
await Clients.Caller.SendAsync("OnStreamTerminated", "Authorization context unavailable.");
return;
}
var roles = Context.Items.TryGetValue(RolesKey, out var rolesObj) && rolesObj is string[] r
? r : Array.Empty<string>();
var permittedSiteIds = Context.Items.TryGetValue(PermittedSiteIdsKey, out var sitesObj) && sitesObj is string[] s
? s : Array.Empty<string>();
var instanceRepo = httpContext.RequestServices.GetRequiredService<ITemplateEngineRepository>();
var instance = await instanceRepo.GetInstanceByIdAsync(instanceId);
if (instance == null)
{
_logger.LogWarning("DebugStreamHub: {ConnectionId} subscribe rejected — instance {InstanceId} not found",
connectionId, instanceId);
await Clients.Caller.SendAsync("OnStreamTerminated", $"Instance {instanceId} not found.");
return;
}
if (!IsInstanceAccessAllowed(roles, permittedSiteIds, instance.SiteId))
{
_logger.LogWarning(
"DebugStreamHub: {ConnectionId} subscribe to instance {InstanceId} denied — site {SiteId} outside permitted scope",
connectionId, instanceId, instance.SiteId);
await Clients.Caller.SendAsync("OnStreamTerminated",
$"Access denied: instance {instanceId} belongs to a site outside your Deployment scope.");
return;
}
try
{
// Use IHubContext for callbacks — the hub instance is transient (disposed after method returns),

View File

@@ -164,7 +164,7 @@ public class ManagementActor : ReceiveActor
// Instances
ListInstancesCommand cmd => await HandleListInstances(sp, cmd, user),
GetInstanceCommand cmd => await HandleGetInstance(sp, cmd),
GetInstanceCommand cmd => await HandleGetInstance(sp, cmd, user),
CreateInstanceCommand cmd => await HandleCreateInstance(sp, cmd, user),
MgmtDeployInstanceCommand cmd => await HandleDeployInstance(sp, cmd, user),
MgmtEnableInstanceCommand cmd => await HandleEnableInstance(sp, cmd, user),
@@ -179,18 +179,18 @@ public class ManagementActor : ReceiveActor
// Sites
ListSitesCommand => await HandleListSites(sp, user),
GetSiteCommand cmd => await HandleGetSite(sp, cmd),
GetSiteCommand cmd => await HandleGetSite(sp, cmd, user),
CreateSiteCommand cmd => await HandleCreateSite(sp, cmd, user.Username),
UpdateSiteCommand cmd => await HandleUpdateSite(sp, cmd, user.Username),
DeleteSiteCommand cmd => await HandleDeleteSite(sp, cmd, user.Username),
ListAreasCommand cmd => await HandleListAreas(sp, cmd),
ListAreasCommand cmd => await HandleListAreas(sp, cmd, user),
CreateAreaCommand cmd => await HandleCreateArea(sp, cmd, user.Username),
DeleteAreaCommand cmd => await HandleDeleteArea(sp, cmd, user.Username),
UpdateAreaCommand cmd => await HandleUpdateArea(sp, cmd, user.Username),
// Data Connections
ListDataConnectionsCommand cmd => await HandleListDataConnections(sp, cmd),
GetDataConnectionCommand cmd => await HandleGetDataConnection(sp, cmd),
GetDataConnectionCommand cmd => await HandleGetDataConnection(sp, cmd, user),
CreateDataConnectionCommand cmd => await HandleCreateDataConnection(sp, cmd, user.Username),
UpdateDataConnectionCommand cmd => await HandleUpdateDataConnection(sp, cmd, user.Username),
DeleteDataConnectionCommand cmd => await HandleDeleteDataConnection(sp, cmd, user.Username),
@@ -263,11 +263,11 @@ public class ManagementActor : ReceiveActor
GetSiteHealthCommand cmd => HandleGetSiteHealth(sp, cmd),
// Remote Queries
QueryEventLogsCommand cmd => await HandleQueryEventLogs(sp, cmd),
QueryParkedMessagesCommand cmd => await HandleQueryParkedMessages(sp, cmd),
RetryParkedMessageCommand cmd => await HandleRetryParkedMessage(sp, cmd),
DiscardParkedMessageCommand cmd => await HandleDiscardParkedMessage(sp, cmd),
DebugSnapshotCommand cmd => await HandleDebugSnapshot(sp, cmd),
QueryEventLogsCommand cmd => await HandleQueryEventLogs(sp, cmd, user),
QueryParkedMessagesCommand cmd => await HandleQueryParkedMessages(sp, cmd, user),
RetryParkedMessageCommand cmd => await HandleRetryParkedMessage(sp, cmd, user),
DiscardParkedMessageCommand cmd => await HandleDiscardParkedMessage(sp, cmd, user),
DebugSnapshotCommand cmd => await HandleDebugSnapshot(sp, cmd, user),
// Role resolution (for CLI LDAP auth)
ResolveRolesCommand cmd => await HandleResolveRoles(sp, cmd),
@@ -329,6 +329,21 @@ public class ManagementActor : ReceiveActor
EnforceSiteScope(user, instance.SiteId);
}
/// <summary>
/// Resolves a site by its string identifier and enforces site-scope.
/// Used by remote-query handlers that key off the site identifier rather than its ID.
/// </summary>
private static async Task EnforceSiteScopeForIdentifier(IServiceProvider sp, AuthenticatedUser user, string siteIdentifier)
{
if (user.PermittedSiteIds.Length == 0) return;
if (user.Roles.Contains("Admin", StringComparer.OrdinalIgnoreCase)) return;
var repo = sp.GetRequiredService<ISiteRepository>();
var site = await repo.GetSiteByIdentifierAsync(siteIdentifier);
if (site != null)
EnforceSiteScope(user, site.Id);
}
/// <summary>
/// Helper to log an audit entry after a successful mutation.
/// </summary>
@@ -507,10 +522,13 @@ public class ManagementActor : ReceiveActor
return instances;
}
private static async Task<object?> HandleGetInstance(IServiceProvider sp, GetInstanceCommand cmd)
private static async Task<object?> HandleGetInstance(IServiceProvider sp, GetInstanceCommand cmd, AuthenticatedUser user)
{
var repo = sp.GetRequiredService<ITemplateEngineRepository>();
return await repo.GetInstanceByIdAsync(cmd.InstanceId);
var instance = await repo.GetInstanceByIdAsync(cmd.InstanceId);
if (instance != null)
EnforceSiteScope(user, instance.SiteId);
return instance;
}
private static async Task<object?> HandleCreateInstance(IServiceProvider sp, CreateInstanceCommand cmd, AuthenticatedUser user)
@@ -638,16 +656,18 @@ public class ManagementActor : ReceiveActor
: throw new InvalidOperationException(result.Error);
}
private static async Task<object?> HandleRetryParkedMessage(IServiceProvider sp, RetryParkedMessageCommand cmd)
private static async Task<object?> HandleRetryParkedMessage(IServiceProvider sp, RetryParkedMessageCommand cmd, AuthenticatedUser user)
{
await EnforceSiteScopeForIdentifier(sp, user, cmd.SiteIdentifier);
var commService = sp.GetRequiredService<CommunicationService>();
var request = new Commons.Messages.RemoteQuery.ParkedMessageRetryRequest(
Guid.NewGuid().ToString("N"), cmd.SiteIdentifier, cmd.MessageId, DateTimeOffset.UtcNow);
return await commService.RetryParkedMessageAsync(cmd.SiteIdentifier, request);
}
private static async Task<object?> HandleDiscardParkedMessage(IServiceProvider sp, DiscardParkedMessageCommand cmd)
private static async Task<object?> HandleDiscardParkedMessage(IServiceProvider sp, DiscardParkedMessageCommand cmd, AuthenticatedUser user)
{
await EnforceSiteScopeForIdentifier(sp, user, cmd.SiteIdentifier);
var commService = sp.GetRequiredService<CommunicationService>();
var request = new Commons.Messages.RemoteQuery.ParkedMessageDiscardRequest(
Guid.NewGuid().ToString("N"), cmd.SiteIdentifier, cmd.MessageId, DateTimeOffset.UtcNow);
@@ -670,10 +690,13 @@ public class ManagementActor : ReceiveActor
return sites;
}
private static async Task<object?> HandleGetSite(IServiceProvider sp, GetSiteCommand cmd)
private static async Task<object?> HandleGetSite(IServiceProvider sp, GetSiteCommand cmd, AuthenticatedUser user)
{
var repo = sp.GetRequiredService<ISiteRepository>();
return await repo.GetSiteByIdAsync(cmd.SiteId);
var site = await repo.GetSiteByIdAsync(cmd.SiteId);
if (site != null)
EnforceSiteScope(user, site.Id);
return site;
}
private static async Task<object?> HandleCreateSite(IServiceProvider sp, CreateSiteCommand cmd, string user)
@@ -730,8 +753,9 @@ public class ManagementActor : ReceiveActor
return true;
}
private static async Task<object?> HandleListAreas(IServiceProvider sp, ListAreasCommand cmd)
private static async Task<object?> HandleListAreas(IServiceProvider sp, ListAreasCommand cmd, AuthenticatedUser user)
{
EnforceSiteScope(user, cmd.SiteId);
var repo = sp.GetRequiredService<ICentralUiRepository>();
return await repo.GetAreaTreeBySiteIdAsync(cmd.SiteId);
}
@@ -771,10 +795,13 @@ public class ManagementActor : ReceiveActor
return await repo.GetAllDataConnectionsAsync();
}
private static async Task<object?> HandleGetDataConnection(IServiceProvider sp, GetDataConnectionCommand cmd)
private static async Task<object?> HandleGetDataConnection(IServiceProvider sp, GetDataConnectionCommand cmd, AuthenticatedUser user)
{
var repo = sp.GetRequiredService<ISiteRepository>();
return await repo.GetDataConnectionByIdAsync(cmd.DataConnectionId);
var conn = await repo.GetDataConnectionByIdAsync(cmd.DataConnectionId);
if (conn != null)
EnforceSiteScope(user, conn.SiteId);
return conn;
}
private static async Task<object?> HandleCreateDataConnection(IServiceProvider sp, CreateDataConnectionCommand cmd, string user)
@@ -1462,8 +1489,9 @@ public class ManagementActor : ReceiveActor
// Remote Query handlers
// ========================================================================
private static async Task<object?> HandleQueryEventLogs(IServiceProvider sp, QueryEventLogsCommand cmd)
private static async Task<object?> HandleQueryEventLogs(IServiceProvider sp, QueryEventLogsCommand cmd, AuthenticatedUser user)
{
await EnforceSiteScopeForIdentifier(sp, user, cmd.SiteIdentifier);
var commService = sp.GetRequiredService<CommunicationService>();
var request = new EventLogQueryRequest(
Guid.NewGuid().ToString("N"),
@@ -1478,8 +1506,9 @@ public class ManagementActor : ReceiveActor
return await commService.QueryEventLogsAsync(cmd.SiteIdentifier, request);
}
private static async Task<object?> HandleQueryParkedMessages(IServiceProvider sp, QueryParkedMessagesCommand cmd)
private static async Task<object?> HandleQueryParkedMessages(IServiceProvider sp, QueryParkedMessagesCommand cmd, AuthenticatedUser user)
{
await EnforceSiteScopeForIdentifier(sp, user, cmd.SiteIdentifier);
var commService = sp.GetRequiredService<CommunicationService>();
var request = new ParkedMessageQueryRequest(
Guid.NewGuid().ToString("N"),
@@ -1490,12 +1519,14 @@ public class ManagementActor : ReceiveActor
return await commService.QueryParkedMessagesAsync(cmd.SiteIdentifier, request);
}
private static async Task<object?> HandleDebugSnapshot(IServiceProvider sp, DebugSnapshotCommand cmd)
private static async Task<object?> HandleDebugSnapshot(IServiceProvider sp, DebugSnapshotCommand cmd, AuthenticatedUser user)
{
var instanceRepo = sp.GetRequiredService<ITemplateEngineRepository>();
var instance = await instanceRepo.GetInstanceByIdAsync(cmd.InstanceId)
?? throw new InvalidOperationException($"Instance {cmd.InstanceId} not found.");
EnforceSiteScope(user, instance.SiteId);
var siteRepo = sp.GetRequiredService<ISiteRepository>();
var site = await siteRepo.GetSiteByIdAsync(instance.SiteId)
?? throw new InvalidOperationException($"Site {instance.SiteId} not found.");