using System;
using System.Linq;
using System.Threading.Tasks;
using Grpc.Core;
using Grpc.Core.Interceptors;
using Serilog;

namespace ZB.MOM.WW.LmxProxy.Host.Security
{
    /// <summary>
    /// gRPC interceptor for API key authentication.
    /// Validates API keys for incoming requests and enforces role-based access control.
    /// </summary>
    /// <remarks>
    /// All four server call types (unary, client streaming, server streaming and duplex streaming)
    /// are authenticated. Methods listed in <see cref="WriteMethodNames"/> additionally require a key
    /// whose role is <c>ApiKeyRole.ReadWrite</c>, regardless of the call type they are exposed as.
    /// </remarks>
    public class ApiKeyInterceptor : Interceptor
    {
        private static readonly ILogger Logger = Log.ForContext<ApiKeyInterceptor>();

        /// <summary>
        /// Request metadata header that carries the API key (matched case-insensitively).
        /// </summary>
        private const string ApiKeyHeader = "x-api-key";

        /// <summary>
        /// Key under which the validated API key is stored in <see cref="ServerCallContext.UserState"/>
        /// so service methods can read it.
        /// </summary>
        private const string ApiKeyUserStateKey = "ApiKey";

        /// <summary>
        /// List of gRPC method names that require write access.
        /// </summary>
        private static readonly string[] WriteMethodNames = { "Write", "WriteBatch", "WriteBatchAndWait" };

        private readonly ApiKeyService _apiKeyService;

        /// <summary>
        /// Initializes a new instance of the <see cref="ApiKeyInterceptor"/> class.
        /// </summary>
        /// <param name="apiKeyService">The API key service used for validation.</param>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="apiKeyService"/> is null.</exception>
        public ApiKeyInterceptor(ApiKeyService apiKeyService)
        {
            _apiKeyService = apiKeyService ?? throw new ArgumentNullException(nameof(apiKeyService));
        }

        /// <summary>
        /// Handles unary gRPC calls, validating the API key and enforcing permissions.
        /// </summary>
        /// <typeparam name="TRequest">The request type.</typeparam>
        /// <typeparam name="TResponse">The response type.</typeparam>
        /// <param name="request">The request message.</param>
        /// <param name="context">The server call context.</param>
        /// <param name="continuation">The continuation delegate.</param>
        /// <returns>The response message.</returns>
        /// <exception cref="RpcException">Thrown if authentication or authorization fails.</exception>
        public override async Task<TResponse> UnaryServerHandler<TRequest, TResponse>(
            TRequest request,
            ServerCallContext context,
            UnaryServerMethod<TRequest, TResponse> continuation)
        {
            Authorize(context, isStreaming: false);
            return await continuation(request, context);
        }

        /// <summary>
        /// Handles client streaming gRPC calls, validating the API key and enforcing permissions.
        /// Without this override such calls would bypass authentication entirely.
        /// </summary>
        /// <typeparam name="TRequest">The request type.</typeparam>
        /// <typeparam name="TResponse">The response type.</typeparam>
        /// <param name="requestStream">The request stream reader.</param>
        /// <param name="context">The server call context.</param>
        /// <param name="continuation">The continuation delegate.</param>
        /// <returns>The response message.</returns>
        /// <exception cref="RpcException">Thrown if authentication or authorization fails.</exception>
        public override async Task<TResponse> ClientStreamingServerHandler<TRequest, TResponse>(
            IAsyncStreamReader<TRequest> requestStream,
            ServerCallContext context,
            ClientStreamingServerMethod<TRequest, TResponse> continuation)
        {
            Authorize(context, isStreaming: true);
            return await continuation(requestStream, context);
        }

        /// <summary>
        /// Handles server streaming gRPC calls, validating the API key and enforcing permissions.
        /// </summary>
        /// <typeparam name="TRequest">The request type.</typeparam>
        /// <typeparam name="TResponse">The response type.</typeparam>
        /// <param name="request">The request message.</param>
        /// <param name="responseStream">The response stream writer.</param>
        /// <param name="context">The server call context.</param>
        /// <param name="continuation">The continuation delegate.</param>
        /// <returns>A task representing the asynchronous operation.</returns>
        /// <exception cref="RpcException">Thrown if authentication or authorization fails.</exception>
        public override async Task ServerStreamingServerHandler<TRequest, TResponse>(
            TRequest request,
            IServerStreamWriter<TResponse> responseStream,
            ServerCallContext context,
            ServerStreamingServerMethod<TRequest, TResponse> continuation)
        {
            Authorize(context, isStreaming: true);
            await continuation(request, responseStream, context);
        }

        /// <summary>
        /// Handles duplex (bidirectional) streaming gRPC calls, validating the API key and enforcing
        /// permissions. Without this override such calls would bypass authentication entirely.
        /// </summary>
        /// <typeparam name="TRequest">The request type.</typeparam>
        /// <typeparam name="TResponse">The response type.</typeparam>
        /// <param name="requestStream">The request stream reader.</param>
        /// <param name="responseStream">The response stream writer.</param>
        /// <param name="context">The server call context.</param>
        /// <param name="continuation">The continuation delegate.</param>
        /// <returns>A task representing the asynchronous operation.</returns>
        /// <exception cref="RpcException">Thrown if authentication or authorization fails.</exception>
        public override async Task DuplexStreamingServerHandler<TRequest, TResponse>(
            IAsyncStreamReader<TRequest> requestStream,
            IServerStreamWriter<TResponse> responseStream,
            ServerCallContext context,
            DuplexStreamingServerMethod<TRequest, TResponse> continuation)
        {
            Authorize(context, isStreaming: true);
            await continuation(requestStream, responseStream, context);
        }

        /// <summary>
        /// Authenticates the caller and enforces write permissions, shared by every call type.
        /// On success the validated key is stored in <see cref="ServerCallContext.UserState"/>.
        /// </summary>
        /// <param name="context">The server call context.</param>
        /// <param name="isStreaming">
        /// Selects the "streaming method" wording in log messages; authorization rules are identical.
        /// </param>
        /// <exception cref="RpcException">
        /// <see cref="StatusCode.Unauthenticated"/> when the key is missing or invalid;
        /// <see cref="StatusCode.PermissionDenied"/> when a write method is called without write access.
        /// </exception>
        private void Authorize(ServerCallContext context, bool isStreaming)
        {
            string apiKey = GetApiKeyFromContext(context);
            if (string.IsNullOrEmpty(apiKey))
            {
                Logger.Warning(
                    isStreaming
                        ? "Missing API key for streaming method {Method} from {Peer}"
                        : "Missing API key for method {Method} from {Peer}",
                    context.Method, context.Peer);
                throw new RpcException(new Status(StatusCode.Unauthenticated, "API key is required"));
            }

            ApiKey? key = _apiKeyService.ValidateApiKey(apiKey);
            if (key == null)
            {
                Logger.Warning(
                    isStreaming
                        ? "Invalid API key for streaming method {Method} from {Peer}"
                        : "Invalid API key for method {Method} from {Peer}",
                    context.Method, context.Peer);
                throw new RpcException(new Status(StatusCode.Unauthenticated, "Invalid API key"));
            }

            // Write access is enforced for every call type so that exposing a write method as a
            // streaming RPC cannot be used to sidestep the role check.
            if (IsWriteMethod(GetMethodName(context.Method)) && key.Role != ApiKeyRole.ReadWrite)
            {
                Logger.Warning(
                    isStreaming
                        ? "Insufficient permissions for streaming method {Method} with API key {Description}"
                        : "Insufficient permissions for method {Method} with API key {Description}",
                    context.Method, key.Description);
                throw new RpcException(new Status(StatusCode.PermissionDenied, "API key does not have write permissions"));
            }

            // Make the validated key available to service methods.
            context.UserState[ApiKeyUserStateKey] = key;

            Logger.Debug(
                isStreaming
                    ? "Authorized streaming method {Method} for API key {Description}"
                    : "Authorized method {Method} for API key {Description}",
                context.Method, key.Description);
        }

        /// <summary>
        /// Extracts the API key from the gRPC request headers.
        /// </summary>
        /// <param name="context">The server call context.</param>
        /// <returns>The API key value, or an empty string if not found.</returns>
        private static string GetApiKeyFromContext(ServerCallContext context)
        {
            // Check for the API key in request metadata (headers).
            Metadata.Entry? entry = context.RequestHeaders
                .FirstOrDefault(e => e.Key.Equals(ApiKeyHeader, StringComparison.OrdinalIgnoreCase));
            return entry?.Value ?? string.Empty;
        }

        /// <summary>
        /// Gets the method name from the full gRPC method string.
        /// </summary>
        /// <param name="method">The full method string (e.g., /package.Service/Method).</param>
        /// <returns>The text after the last '/', or <paramref name="method"/> unchanged if it has no '/'.</returns>
        private static string GetMethodName(string method)
        {
            // Method format is /package.Service/Method
            int lastSlash = method.LastIndexOf('/');
            return lastSlash >= 0 ? method.Substring(lastSlash + 1) : method;
        }

        /// <summary>
        /// Determines whether the specified method name requires write access.
        /// </summary>
        /// <param name="methodName">The method name.</param>
        /// <returns><c>true</c> if the method requires write access; otherwise, <c>false</c>.</returns>
        private static bool IsWriteMethod(string methodName) =>
            WriteMethodNames.Contains(methodName, StringComparer.OrdinalIgnoreCase);
    }
}