using System;
using System.Linq;
using System.Threading.Tasks;
using Grpc.Core;
using Grpc.Core.Interceptors;
using Serilog;
namespace ZB.MOM.WW.LmxProxy.Host.Security
{
/// <summary>
/// gRPC server interceptor for API key authentication.
/// Unary and server-streaming calls must carry a valid key in the <c>x-api-key</c> request header;
/// methods listed in <see cref="WriteMethodNames"/> additionally require the
/// <see cref="ApiKeyRole.ReadWrite"/> role.
/// </summary>
/// <remarks>
/// NOTE(review): only the unary and server-streaming handlers are overridden here, so client-streaming
/// and duplex-streaming calls are not checked by this class. Confirm the hosted services expose none,
/// or add matching overrides.
/// </remarks>
public class ApiKeyInterceptor : Interceptor
{
    /// <summary>Name of the request header that carries the caller's API key.</summary>
    private const string ApiKeyHeaderName = "x-api-key";

    /// <summary>
    /// Key under which the validated <see cref="ApiKey"/> is stored in
    /// <see cref="ServerCallContext.UserState"/> for downstream service methods.
    /// </summary>
    private const string ApiKeyUserStateKey = "ApiKey";

    private static readonly ILogger Logger = Log.ForContext<ApiKeyInterceptor>();

    /// <summary>
    /// List of gRPC method names (without the service prefix) that require write access.
    /// </summary>
    private static readonly string[] WriteMethodNames =
    {
        "Write",
        "WriteBatch",
        "WriteBatchAndWait"
    };

    private readonly ApiKeyService _apiKeyService;

    /// <summary>
    /// Initializes a new instance of the <see cref="ApiKeyInterceptor"/> class.
    /// </summary>
    /// <param name="apiKeyService">The API key service used for validation.</param>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="apiKeyService"/> is null.</exception>
    public ApiKeyInterceptor(ApiKeyService apiKeyService)
    {
        _apiKeyService = apiKeyService ?? throw new ArgumentNullException(nameof(apiKeyService));
    }

    /// <summary>
    /// Handles unary gRPC calls, validating the API key and enforcing permissions
    /// before invoking the actual service method.
    /// </summary>
    /// <typeparam name="TRequest">The request type.</typeparam>
    /// <typeparam name="TResponse">The response type.</typeparam>
    /// <param name="request">The request message.</param>
    /// <param name="context">The server call context.</param>
    /// <param name="continuation">The continuation delegate.</param>
    /// <returns>The response message produced by <paramref name="continuation"/>.</returns>
    /// <exception cref="RpcException">Thrown if authentication or authorization fails.</exception>
    public override async Task<TResponse> UnaryServerHandler<TRequest, TResponse>(
        TRequest request,
        ServerCallContext context,
        UnaryServerMethod<TRequest, TResponse> continuation)
    {
        AuthorizeCall(context, isStreaming: false);
        return await continuation(request, context);
    }

    /// <summary>
    /// Handles server streaming gRPC calls, validating the API key and enforcing permissions
    /// before invoking the actual service method.
    /// </summary>
    /// <typeparam name="TRequest">The request type.</typeparam>
    /// <typeparam name="TResponse">The response type.</typeparam>
    /// <param name="request">The request message.</param>
    /// <param name="responseStream">The response stream writer.</param>
    /// <param name="context">The server call context.</param>
    /// <param name="continuation">The continuation delegate.</param>
    /// <returns>A task representing the asynchronous operation.</returns>
    /// <exception cref="RpcException">Thrown if authentication or authorization fails.</exception>
    public override async Task ServerStreamingServerHandler<TRequest, TResponse>(
        TRequest request,
        IServerStreamWriter<TResponse> responseStream,
        ServerCallContext context,
        ServerStreamingServerMethod<TRequest, TResponse> continuation)
    {
        AuthorizeCall(context, isStreaming: true);
        await continuation(request, responseStream, context);
    }

    /// <summary>
    /// Shared authentication and authorization path for every intercepted call type.
    /// On success the validated key is stored in <see cref="ServerCallContext.UserState"/>.
    /// </summary>
    /// <param name="context">The server call context.</param>
    /// <param name="isStreaming">
    /// true for server-streaming calls; only selects the wording of the log messages.
    /// </param>
    /// <exception cref="RpcException">
    /// <see cref="StatusCode.Unauthenticated"/> when the header is missing or empty, or when the
    /// API key service returns null for the key; <see cref="StatusCode.PermissionDenied"/> when a
    /// write method is called with a key whose role is not <see cref="ApiKeyRole.ReadWrite"/>.
    /// </exception>
    private void AuthorizeCall(ServerCallContext context, bool isStreaming)
    {
        string apiKey = GetApiKeyFromContext(context);
        if (string.IsNullOrEmpty(apiKey))
        {
            Logger.Warning(
                isStreaming
                    ? "Missing API key for streaming method {Method} from {Peer}"
                    : "Missing API key for method {Method} from {Peer}",
                context.Method, context.Peer);
            throw new RpcException(new Status(StatusCode.Unauthenticated, "API key is required"));
        }

        ApiKey? key = _apiKeyService.ValidateApiKey(apiKey);
        if (key is null)
        {
            Logger.Warning(
                isStreaming
                    ? "Invalid API key for streaming method {Method} from {Peer}"
                    : "Invalid API key for method {Method} from {Peer}",
                context.Method, context.Peer);
            throw new RpcException(new Status(StatusCode.Unauthenticated, "Invalid API key"));
        }

        // The role check runs for every call type so that a streaming method added to
        // WriteMethodNames cannot bypass it.
        if (IsWriteMethod(GetMethodName(context.Method)) && key.Role != ApiKeyRole.ReadWrite)
        {
            Logger.Warning(
                isStreaming
                    ? "Insufficient permissions for streaming method {Method} with API key {Description}"
                    : "Insufficient permissions for method {Method} with API key {Description}",
                context.Method, key.Description);
            throw new RpcException(new Status(StatusCode.PermissionDenied,
                "API key does not have write permissions"));
        }

        // Add API key info to context items for use in service methods.
        context.UserState[ApiKeyUserStateKey] = key;

        Logger.Debug(
            isStreaming
                ? "Authorized streaming method {Method} for API key {Description}"
                : "Authorized method {Method} for API key {Description}",
            context.Method, key.Description);
    }

    /// <summary>
    /// Extracts the API key from the gRPC request headers.
    /// </summary>
    /// <param name="context">The server call context.</param>
    /// <returns>The API key value, or an empty string if the header is not present.</returns>
    private static string GetApiKeyFromContext(ServerCallContext context)
    {
        // First matching header wins; the header name is compared case-insensitively.
        Metadata.Entry? entry = context.RequestHeaders.FirstOrDefault(e =>
            e.Key.Equals(ApiKeyHeaderName, StringComparison.OrdinalIgnoreCase));
        return entry?.Value ?? string.Empty;
    }

    /// <summary>
    /// Gets the method name from the full gRPC method string.
    /// </summary>
    /// <param name="method">The full method string (e.g., /package.Service/Method).</param>
    /// <returns>The text after the last '/', or the whole string if it contains no '/'.</returns>
    private static string GetMethodName(string method)
    {
        // Method format is /package.Service/Method
        int lastSlash = method.LastIndexOf('/');
        return lastSlash >= 0 ? method.Substring(lastSlash + 1) : method;
    }

    /// <summary>
    /// Determines whether the specified method name requires write access.
    /// </summary>
    /// <param name="methodName">The method name, without the service prefix.</param>
    /// <returns>
    /// true if <paramref name="methodName"/> matches an entry of <see cref="WriteMethodNames"/>
    /// (case-insensitive); otherwise, false.
    /// </returns>
    private static bool IsWriteMethod(string methodName) =>
        WriteMethodNames.Contains(methodName, StringComparer.OrdinalIgnoreCase);
}
}