using System.Collections.Concurrent;
using System.Reflection;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using System.Text.RegularExpressions;
using Dahomey.Cbor.ObjectModel;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using SurrealDb.Net;
using SurrealDb.Net.Models.LiveQuery;
using SurrealDb.Net.Models;
using ZB.MOM.WW.CBDDC.Core;
using ZB.MOM.WW.CBDDC.Core.Network;
using ZB.MOM.WW.CBDDC.Core.Storage;
using ZB.MOM.WW.CBDDC.Core.Sync;
namespace ZB.MOM.WW.CBDDC.Persistence.Surreal;
///
/// Abstract base class for Surreal-backed document stores.
/// Handles local oplog/document-metadata persistence and remote-sync suppression.
///
/// The application context type used by the concrete store.
public abstract class SurrealDocumentStore : IDocumentStore, ISurrealCdcWorkerLifecycle, IDisposable
where TContext : class
{
/// <summary>Matches a plain identifier: an ASCII letter or underscore followed by ASCII letters, digits, or underscores.</summary>
private static readonly Regex SurrealIdentifierRegex = new Regex("^[A-Za-z_][A-Za-z0-9_]*$", RegexOptions.Compiled);

/// <summary>Subscriptions returned by <c>WatchCollection</c>; each one is disposed (best effort) in <c>Dispose</c>.</summary>
private readonly List<IDisposable> _cdcWatchers = new List<IDisposable>();

/// <summary>Polling options after missing or non-positive values were replaced with defaults.</summary>
private readonly SurrealCdcPollingOptions _cdcPollingOptions;

/// <summary>Serializes <c>StartCdcWorkerAsync</c> and <c>StopCdcWorkerAsync</c>.</summary>
private readonly SemaphoreSlim _cdcWorkerLifecycleGate = new SemaphoreSlim(1, 1);

/// <summary>
/// Wake-up signal (maximum count 1) awaited by the CDC worker loop between polls while live-select accelerators run.
/// NOTE(review): the releasing side is not visible here — presumably the live-select accelerator; confirm.
/// </summary>
private readonly SemaphoreSlim _liveSelectSignal = new SemaphoreSlim(0, 1);

/// <summary>Optional checkpoint store; when absent the CDC worker is not started and single polls do nothing.</summary>
private readonly ISurrealCdcCheckpointPersistence? _checkpointPersistence;

/// <summary>NOTE(review): unused in this section — presumably guards the HLC fields below; confirm.</summary>
private readonly object _clockLock = new object();

/// <summary>Logical collection names passed to <c>WatchCollection</c> (ordinal comparison).</summary>
private readonly HashSet<string> _registeredCollections = new HashSet<string>(StringComparer.Ordinal);

/// <summary>
/// Semaphore used to suppress CDC-triggered oplog entry creation during remote sync.
/// The in-memory CDC observer ignores events while its current count is zero.
/// </summary>
private readonly SemaphoreSlim _remoteSyncGuard = new SemaphoreSlim(1, 1);

/// <summary>Pending suppression counts keyed by <c>collection|key|operation</c>.</summary>
private readonly ConcurrentDictionary<string, int> _suppressedCdcEvents =
    new ConcurrentDictionary<string, int>(StringComparer.Ordinal);

/// <summary>Logical collection name mapped to its registration (collection plus resolved Surreal table).</summary>
private readonly ConcurrentDictionary<string, WatchedCollectionRegistration> _watchedCollections =
    new ConcurrentDictionary<string, WatchedCollectionRegistration>(StringComparer.Ordinal);

/// <summary>Cancellation source of the polling loop; <see langword="null"/> while stopped.</summary>
private CancellationTokenSource? _cdcWorkerCts;

/// <summary>Task running the polling loop; <see langword="null"/> while stopped.</summary>
private Task? _cdcWorkerTask;

/// <summary>Cancellation source shared by the live-select accelerator tasks.</summary>
private CancellationTokenSource? _liveSelectCts;

/// <summary>One task per watched collection while live-select accelerators are running.</summary>
private readonly List<Task> _liveSelectTasks = new List<Task>();

/// <summary>Peer node configuration supplied to the constructor.</summary>
protected readonly IPeerNodeConfigurationProvider _configProvider;

/// <summary>Conflict resolver; last-write-wins unless one is supplied.</summary>
protected readonly IConflictResolver _conflictResolver;

/// <summary>Application context used by the concrete store.</summary>
protected readonly TContext _context;

/// <summary>Logger typed for this store.</summary>
protected readonly ILogger<SurrealDocumentStore<TContext>> _logger;

/// <summary>Surreal schema initializer supplied to the constructor.</summary>
protected readonly ICBDDCSurrealSchemaInitializer _schemaInitializer;

/// <summary>Surreal client taken from the embedded client provider.</summary>
protected readonly ISurrealDbClient _surrealClient;

/// <summary>Vector clock service used for local oplog state.</summary>
protected readonly IVectorClockService _vectorClock;

// HLC state for local change timestamp generation.
private int _logicalCounter;
private long _lastPhysicalTime;
/// <summary>
/// Initializes a new instance of the <see cref="SurrealDocumentStore{TContext}"/> class.
/// </summary>
/// <param name="context">The application context used by the concrete store.</param>
/// <param name="surrealEmbeddedClient">The embedded Surreal client provider; its <c>Client</c> is captured for queries.</param>
/// <param name="schemaInitializer">The Surreal schema initializer.</param>
/// <param name="configProvider">The peer node configuration provider.</param>
/// <param name="vectorClockService">The vector clock service used for local oplog state.</param>
/// <param name="conflictResolver">Optional conflict resolver; defaults to <see cref="LastWriteWinsConflictResolver"/>.</param>
/// <param name="checkpointPersistence">Optional CDC checkpoint persistence component; without it the CDC worker does not start.</param>
/// <param name="cdcPollingOptions">Optional CDC polling options; missing or non-positive values fall back to defaults.</param>
/// <param name="logger">Optional logger instance; adapted to a logger typed for this store.</param>
/// <exception cref="ArgumentNullException">A required dependency is <see langword="null"/>.</exception>
protected SurrealDocumentStore(
    TContext context,
    ICBDDCSurrealEmbeddedClient surrealEmbeddedClient,
    ICBDDCSurrealSchemaInitializer schemaInitializer,
    IPeerNodeConfigurationProvider configProvider,
    IVectorClockService vectorClockService,
    IConflictResolver? conflictResolver = null,
    ISurrealCdcCheckpointPersistence? checkpointPersistence = null,
    SurrealCdcPollingOptions? cdcPollingOptions = null,
    ILogger? logger = null)
{
    // Required collaborators, checked and captured in parameter order.
    ArgumentNullException.ThrowIfNull(context);
    _context = context;

    ArgumentNullException.ThrowIfNull(surrealEmbeddedClient);
    _surrealClient = surrealEmbeddedClient.Client;

    ArgumentNullException.ThrowIfNull(schemaInitializer);
    _schemaInitializer = schemaInitializer;

    ArgumentNullException.ThrowIfNull(configProvider);
    _configProvider = configProvider;

    ArgumentNullException.ThrowIfNull(vectorClockService);
    _vectorClock = vectorClockService;

    // Optional collaborators and settings, with defaults applied.
    _conflictResolver = conflictResolver ?? new LastWriteWinsConflictResolver();
    _checkpointPersistence = checkpointPersistence;
    _cdcPollingOptions = NormalizePollingOptions(cdcPollingOptions);
    _logger = CreateTypedLogger(logger);

    // Seed the HLC state with the current wall-clock time in Unix milliseconds.
    _lastPhysicalTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
    _logicalCounter = 0;
}
// 0 = live, 1 = disposed; flipped atomically by Dispose.
private int _disposeState;

/// <summary>
/// Releases managed resources used by this document store.
/// </summary>
/// <remarks>
/// Stops the CDC worker and live-select accelerators, disposes the in-memory change subscriptions,
/// and disposes the synchronization primitives. Cleanup is best effort: failures are logged at
/// debug level instead of being thrown. Calls after the first one do nothing.
/// </remarks>
public virtual void Dispose()
{
    // Make repeated calls no-ops instead of relying on ObjectDisposedException from the already-disposed gate.
    if (Interlocked.Exchange(ref _disposeState, 1) != 0) return;

    try
    {
        // IDisposable is synchronous, so the asynchronous stop is awaited by blocking here.
        StopCdcWorkerAsync(CancellationToken.None).GetAwaiter().GetResult();
    }
    catch (Exception exception)
    {
        _logger.LogDebug(exception, "Failed to stop the Surreal CDC worker while disposing the document store.");
    }

    foreach (var watcher in _cdcWatchers)
    {
        try
        {
            watcher.Dispose();
        }
        catch (Exception exception)
        {
            _logger.LogDebug(exception, "Failed to dispose a Surreal CDC watcher while disposing the document store.");
        }
    }

    _cdcWatchers.Clear();
    _cdcWorkerCts?.Dispose();
    _liveSelectCts?.Dispose();
    _liveSelectSignal.Dispose();
    _cdcWorkerLifecycleGate.Dispose();
    _remoteSyncGuard.Dispose();
    GC.SuppressFinalize(this);
}
/// <summary>
/// Produces a logger typed for this store from an optional, possibly untyped logger.
/// </summary>
/// <param name="logger">The caller-supplied logger, or <see langword="null"/>.</param>
/// <returns>
/// A no-op logger when none was supplied, the same instance when it is already typed for this store,
/// otherwise a <see cref="ForwardingLogger"/> that delegates to it.
/// </returns>
private static ILogger<SurrealDocumentStore<TContext>> CreateTypedLogger(ILogger? logger) =>
    logger switch
    {
        null => NullLogger<SurrealDocumentStore<TContext>>.Instance,
        ILogger<SurrealDocumentStore<TContext>> alreadyTyped => alreadyTyped,
        _ => new ForwardingLogger(logger),
    };
/// <summary>
/// Exposes an untyped <see cref="ILogger"/> through the typed logger interface used by this store,
/// delegating every call unchanged.
/// </summary>
private sealed class ForwardingLogger : ILogger<SurrealDocumentStore<TContext>>
{
    private readonly ILogger _target;

    /// <summary>
    /// Initializes a new instance of the <see cref="ForwardingLogger"/> class.
    /// </summary>
    /// <param name="inner">The logger instance to forward calls to.</param>
    public ForwardingLogger(ILogger inner) => _target = inner;

    /// <inheritdoc />
    public IDisposable? BeginScope<TState>(TState state) where TState : notnull => _target.BeginScope(state);

    /// <inheritdoc />
    public bool IsEnabled(LogLevel logLevel) => _target.IsEnabled(logLevel);

    /// <inheritdoc />
    public void Log<TState>(
        LogLevel logLevel,
        EventId eventId,
        TState state,
        Exception? exception,
        Func<TState, Exception?, string> formatter) =>
        _target.Log(logLevel, eventId, state, exception, formatter);
}
#region CDC Registration
/// <summary>
/// Builds the lookup key used by the suppressed-CDC-event counters.
/// </summary>
/// <param name="collection">Logical collection name.</param>
/// <param name="key">Document key within the collection.</param>
/// <param name="operationType">Operation being suppressed.</param>
/// <returns>The text <c>collection|key|operation</c>, with the operation written as its integer value.</returns>
private static string BuildSuppressionKey(string collection, string key, OperationType operationType) =>
    string.Join("|", collection, key, ((int)operationType).ToString());
/// <summary>
/// Adds one pending suppression for the given operation; a later matching in-memory CDC event consumes it and is ignored.
/// </summary>
/// <param name="collection">Logical collection name.</param>
/// <param name="key">Document key within the collection.</param>
/// <param name="operationType">Operation whose CDC event should be suppressed.</param>
private void RegisterSuppressedCdcEvent(string collection, string key, OperationType operationType) =>
    _suppressedCdcEvents.AddOrUpdate(
        BuildSuppressionKey(collection, key, operationType),
        1,
        static (_, pending) => pending + 1);
/// <summary>
/// Atomically consumes one pending suppression for the given operation, if there is one.
/// </summary>
/// <param name="collection">Logical collection name.</param>
/// <param name="key">Document key within the collection.</param>
/// <param name="operationType">Operation of the observed CDC event.</param>
/// <returns>
/// <see langword="true"/> if a suppression was pending and has been consumed; otherwise <see langword="false"/>.
/// </returns>
private bool TryConsumeSuppressedCdcEvent(string collection, string key, OperationType operationType)
{
    string suppressionKey = BuildSuppressionKey(collection, key, operationType);
    while (true)
    {
        if (!_suppressedCdcEvents.TryGetValue(suppressionKey, out int current)) return false;

        if (current <= 1)
        {
            // Remove only if the count is still the value just observed. A concurrent RegisterSuppressedCdcEvent
            // may have incremented it; an unconditional TryRemove(key) would silently discard that registration.
            if (_suppressedCdcEvents.TryRemove(new KeyValuePair<string, int>(suppressionKey, current))) return true;
            continue;
        }

        if (_suppressedCdcEvents.TryUpdate(suppressionKey, current - 1, current)) return true;
    }
}
/// <summary>
/// Returns <see langword="true"/> when the CDC polling worker is running and the collection is registered
/// for watching. The in-memory CDC observer ignores its events while this holds.
/// </summary>
/// <param name="collection">Logical collection name.</param>
private bool IsCdcPollingWorkerActiveForCollection(string collection)
{
    if (!IsCdcWorkerRunning) return false;

    return _watchedCollections.ContainsKey(collection);
}
/// <summary>
/// Registers a watchable collection for local change tracking.
/// </summary>
/// <typeparam name="TEntity">The entity type emitted by the watch source.</typeparam>
/// <param name="collectionName">Logical collection name used by oplog and metadata records.</param>
/// <param name="collection">Watchable change source.</param>
/// <param name="keySelector">Function used to resolve the entity key.</param>
/// <param name="subscribeForInMemoryEvents">Whether to subscribe to in-memory collection events.</param>
/// <exception cref="ArgumentException"><paramref name="collectionName"/> is null, empty, or whitespace.</exception>
/// <exception cref="ArgumentNullException"><paramref name="collection"/> or <paramref name="keySelector"/> is null.</exception>
protected void WatchCollection<TEntity>(
    string collectionName,
    ISurrealWatchableCollection<TEntity> collection,
    Func<TEntity, string> keySelector,
    bool subscribeForInMemoryEvents = true)
    where TEntity : class
{
    if (string.IsNullOrWhiteSpace(collectionName))
    {
        throw new ArgumentException("Collection name is required.", nameof(collectionName));
    }

    ArgumentNullException.ThrowIfNull(collection);
    ArgumentNullException.ThrowIfNull(keySelector);

    _registeredCollections.Add(collectionName);

    // The registration (with the resolved Surreal table) is what CDC polling and the live-select accelerators iterate.
    var registration = new WatchedCollectionRegistration(
        collectionName,
        ResolveSurrealTableName(collection, collectionName));
    _watchedCollections[collectionName] = registration;

    if (subscribeForInMemoryEvents)
    {
        var subscription = collection.Subscribe(new CdcObserver<TEntity>(collectionName, keySelector, this));
        _cdcWatchers.Add(subscription);
    }
}
/// <summary>
/// Observes in-memory change notifications for one watched collection and forwards local changes to the store.
/// </summary>
/// <typeparam name="TEntity">The entity type emitted by the watch source.</typeparam>
private sealed class CdcObserver<TEntity> : IObserver<SurrealCollectionChange<TEntity>>
    where TEntity : class
{
    private readonly string _collectionName;
    private readonly Func<TEntity, string> _keySelector;
    private readonly SurrealDocumentStore<TContext> _store;

    /// <summary>
    /// Initializes a new instance of the <see cref="CdcObserver{TEntity}"/> class.
    /// </summary>
    /// <param name="collectionName">The logical collection name.</param>
    /// <param name="keySelector">The key selector for observed entities.</param>
    /// <param name="store">The owning document store.</param>
    public CdcObserver(
        string collectionName,
        Func<TEntity, string> keySelector,
        SurrealDocumentStore<TContext> store)
    {
        _collectionName = collectionName;
        _keySelector = keySelector;
        _store = store;
    }

    /// <inheritdoc />
    public void OnNext(SurrealCollectionChange<TEntity> changeEvent)
    {
        // While the polling worker covers this collection, in-memory events are ignored
        // (presumably so the same change is not recorded twice).
        if (_store.IsCdcPollingWorkerActiveForCollection(_collectionName)) return;

        var operationType = changeEvent.OperationType == OperationType.Delete
            ? OperationType.Delete
            : OperationType.Put;

        // For puts prefer the entity's own key; otherwise fall back to the event's document id.
        string entityId = changeEvent.DocumentId ?? "";
        if (operationType == OperationType.Put && changeEvent.Entity != null)
        {
            string selectedKey = _keySelector(changeEvent.Entity);
            if (!string.IsNullOrWhiteSpace(selectedKey)) entityId = selectedKey;
        }

        if (operationType == OperationType.Delete && string.IsNullOrWhiteSpace(entityId)) return;

        // An event matching a pending suppression (see RegisterSuppressedCdcEvent) consumes it and is skipped.
        if (_store.TryConsumeSuppressedCdcEvent(_collectionName, entityId, operationType)) return;

        // Remote sync holds this guard; changes made meanwhile must not create local oplog entries.
        if (_store._remoteSyncGuard.CurrentCount == 0) return;

        if (operationType == OperationType.Delete)
        {
            // IObserver<T>.OnNext is synchronous, so the asynchronous handler is awaited by blocking.
            _store.OnLocalChangeDetectedAsync(_collectionName, entityId, OperationType.Delete, null)
                .GetAwaiter().GetResult();
            return;
        }

        if (changeEvent.Entity == null) return;

        // entityId already holds the key selector's result when non-blank, else the event's document id,
        // so the selector does not need to be invoked a second time.
        if (string.IsNullOrWhiteSpace(entityId)) return;

        var content = JsonSerializer.SerializeToElement(changeEvent.Entity);
        _store.OnLocalChangeDetectedAsync(_collectionName, entityId, OperationType.Put, content)
            .GetAwaiter().GetResult();
    }

    /// <inheritdoc />
    /// <remarks>The error is logged so a broken change stream does not go unnoticed.</remarks>
    public void OnError(Exception error)
    {
        _store._logger.LogWarning(
            error,
            "In-memory change stream for collection {Collection} reported an error.",
            _collectionName);
    }

    /// <inheritdoc />
    public void OnCompleted()
    {
        // Nothing to clean up; the subscription itself is disposed by the store.
    }
}
/// <summary>
/// Discovers the Surreal table behind a watchable collection by reflecting over well-known member names.
/// </summary>
/// <remarks>
/// For each of <c>TableName</c>, <c>_tableName</c> and <c>tableName</c> (in that order) a readable instance
/// property is tried first, then an instance field; public and non-public members are both considered.
/// The first non-blank string value wins.
/// </remarks>
/// <param name="collection">The watchable collection instance to inspect.</param>
/// <param name="fallbackCollectionName">Value returned when no member yields a table name.</param>
/// <returns>The discovered table name, or <paramref name="fallbackCollectionName"/>.</returns>
private static string ResolveSurrealTableName<TEntity>(
    ISurrealWatchableCollection<TEntity> collection,
    string fallbackCollectionName)
    where TEntity : class
{
    static string? AsTableName(object? value) =>
        value is string text && !string.IsNullOrWhiteSpace(text) ? text : null;

    const BindingFlags lookup = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
    Type runtimeType = collection.GetType();
    string[] candidateNames = { "TableName", "_tableName", "tableName" };

    foreach (string candidate in candidateNames)
    {
        PropertyInfo? property = runtimeType.GetProperty(candidate, lookup);
        string? resolved = AsTableName(property is { CanRead: true } ? property.GetValue(collection) : null);
        if (resolved is not null) return resolved;

        FieldInfo? field = runtimeType.GetField(candidate, lookup);
        resolved = AsTableName(field?.GetValue(collection));
        if (resolved is not null) return resolved;
    }

    return fallbackCollectionName;
}
/// <summary>
/// Returns a fully populated copy of the CDC polling options, replacing missing or non-positive values with defaults.
/// </summary>
/// <remarks>
/// Defaults: enabled, 250 ms poll interval, batch size 100, live-select accelerator enabled, 2 s reconnect delay.
/// </remarks>
/// <param name="options">Caller-supplied options, or <see langword="null"/> for all defaults.</param>
/// <returns>A new options instance with every value set.</returns>
private static SurrealCdcPollingOptions NormalizePollingOptions(SurrealCdcPollingOptions? options)
{
    static TimeSpan PositiveOrDefault(TimeSpan? candidate, TimeSpan fallback) =>
        candidate is { } value && value > TimeSpan.Zero ? value : fallback;

    TimeSpan pollInterval = PositiveOrDefault(options?.PollInterval, TimeSpan.FromMilliseconds(250));
    int batchSize = options?.BatchSize is int requested && requested > 0 ? requested : 100;
    TimeSpan reconnectDelay = PositiveOrDefault(options?.LiveSelectReconnectDelay, TimeSpan.FromSeconds(2));

    return new SurrealCdcPollingOptions
    {
        Enabled = options?.Enabled ?? true,
        PollInterval = pollInterval,
        BatchSize = batchSize,
        EnableLiveSelectAccelerator = options?.EnableLiveSelectAccelerator ?? true,
        LiveSelectReconnectDelay = reconnectDelay
    };
}
/// <summary>
/// Pairs a logical collection name with the Surreal table name resolved for it.
/// </summary>
/// <param name="CollectionName">Logical collection name used by oplog and metadata records.</param>
/// <param name="TableName">Surreal table name resolved for the collection.</param>
private readonly record struct WatchedCollectionRegistration(string CollectionName, string TableName);
/// <summary>
/// Pairs a Surreal table name with a CDC cursor value.
/// NOTE(review): usage is not visible here — presumably a checkpoint awaiting persistence; confirm.
/// </summary>
/// <param name="TableName">Surreal table the cursor refers to.</param>
/// <param name="Cursor">CDC cursor position.</param>
protected readonly record struct PendingCursorCheckpoint(string TableName, ulong Cursor);
#endregion
#region CDC Worker Lifecycle
/// <inheritdoc />
/// <remarks>
/// True while a worker task has been started and has not completed. This property is also read from the
/// in-memory change observer without holding the lifecycle gate, so the task field is read exactly once:
/// reading it twice could observe it being set to <see langword="null"/> by <c>StopCdcWorkerAsync</c> in between.
/// </remarks>
public bool IsCdcWorkerRunning =>
    Volatile.Read(ref _cdcWorkerTask) is { IsCompleted: false };
/// <inheritdoc />
/// <remarks>
/// Does nothing when polling is disabled, when no checkpoint persistence is configured, or when the worker is
/// already running. Live-select accelerators are started before the polling loop.
/// </remarks>
public async Task StartCdcWorkerAsync(CancellationToken cancellationToken = default)
{
    if (!_cdcPollingOptions.Enabled)
    {
        _logger.LogDebug("Surreal CDC worker start skipped because polling is disabled.");
        return;
    }

    if (_checkpointPersistence == null)
    {
        _logger.LogDebug("Surreal CDC worker start skipped because checkpoint persistence is not configured.");
        return;
    }

    await _cdcWorkerLifecycleGate.WaitAsync(cancellationToken);
    try
    {
        cancellationToken.ThrowIfCancellationRequested();
        if (IsCdcWorkerRunning) return;

        await EnsureReadyAsync(cancellationToken);
        StartLiveSelectAcceleratorsUnsafe();

        // A previous worker may have completed on its own without StopCdcWorkerAsync releasing its
        // cancellation source; dispose it rather than leak it when it is replaced below.
        _cdcWorkerCts?.Dispose();

        var workerCts = new CancellationTokenSource();
        // Capture the token now: the lambda runs later on the thread pool, and by then StopCdcWorkerAsync
        // may already have set _cdcWorkerCts to null (dereferencing the field there would throw).
        CancellationToken workerToken = workerCts.Token;
        _cdcWorkerCts = workerCts;
        _cdcWorkerTask = Task.Run(() => RunCdcWorkerAsync(workerToken), CancellationToken.None);

        _logger.LogInformation(
            "Started Surreal CDC worker with interval {IntervalMs} ms, batch size {BatchSize}, live accelerator {LiveAccelerator}.",
            _cdcPollingOptions.PollInterval.TotalMilliseconds,
            _cdcPollingOptions.BatchSize,
            _cdcPollingOptions.EnableLiveSelectAccelerator);
    }
    finally
    {
        _cdcWorkerLifecycleGate.Release();
    }
}
/// <inheritdoc />
/// <remarks>
/// Returns without polling when polling is disabled, no checkpoint persistence is configured,
/// or no collections are watched.
/// </remarks>
public async Task PollCdcOnceAsync(CancellationToken cancellationToken = default)
{
    bool canPoll = _cdcPollingOptions.Enabled
                   && _checkpointPersistence is not null
                   && !_watchedCollections.IsEmpty;
    if (!canPoll) return;

    await EnsureReadyAsync(cancellationToken);
    await PollWatchedCollectionsOnceAsync(cancellationToken);
}
/// <inheritdoc />
/// <remarks>
/// Detaches the worker and live-select state under the lifecycle gate, then cancels and awaits it outside the gate.
/// Live-select task failures are logged at debug level and do not fail the stop.
/// </remarks>
public async Task StopCdcWorkerAsync(CancellationToken cancellationToken = default)
{
    Task? workerTask;
    CancellationTokenSource? workerCts;
    Task[] liveSelectTasks;
    CancellationTokenSource? liveSelectCts;

    await _cdcWorkerLifecycleGate.WaitAsync(cancellationToken);
    try
    {
        // Detach everything under the gate so a concurrent start sees a clean, stopped state.
        workerTask = _cdcWorkerTask;
        workerCts = _cdcWorkerCts;
        _cdcWorkerTask = null;
        _cdcWorkerCts = null;
        liveSelectTasks = _liveSelectTasks.ToArray();
        _liveSelectTasks.Clear();
        liveSelectCts = _liveSelectCts;
        _liveSelectCts = null;
    }
    finally
    {
        _cdcWorkerLifecycleGate.Release();
    }

    // Nothing is running: only the cancellation sources need releasing.
    if (workerTask == null && liveSelectTasks.Length == 0)
    {
        workerCts?.Dispose();
        liveSelectCts?.Dispose();
        return;
    }

    try
    {
        // The sources are disposed only in finally: calling Cancel on an already-disposed source throws.
        workerCts?.Cancel();
        liveSelectCts?.Cancel();

        if (workerTask != null) await workerTask.WaitAsync(cancellationToken);

        if (liveSelectTasks.Length > 0)
        {
            try
            {
                await Task.WhenAll(liveSelectTasks).WaitAsync(cancellationToken);
            }
            catch (OperationCanceledException)
            {
                // Expected: the accelerators were cancelled above, or the caller stopped waiting.
            }
            catch (Exception exception)
            {
                // Accelerator failures must not block shutdown, but should not vanish silently either.
                _logger.LogDebug(exception, "Surreal live-select accelerator ended with an error during shutdown.");
            }
        }
    }
    catch (OperationCanceledException) when ((workerTask?.IsCanceled ?? false) || cancellationToken.IsCancellationRequested)
    {
        // The worker was cancelled or the caller gave up waiting; both are normal outcomes of a stop.
    }
    finally
    {
        workerCts?.Dispose();
        liveSelectCts?.Dispose();
    }
}
/// <summary>
/// Background polling loop. Polls the watched collections, then waits one poll interval, or less when the
/// live-select signal arrives first. Runs until <paramref name="cancellationToken"/> is cancelled.
/// </summary>
/// <remarks>
/// A failed iteration is logged and followed by a full poll-interval delay before the next attempt.
/// </remarks>
/// <param name="cancellationToken">Token that stops the loop.</param>
private async Task RunCdcWorkerAsync(CancellationToken cancellationToken)
{
    while (!cancellationToken.IsCancellationRequested)
    {
        try
        {
            await PollCdcOnceAsync(cancellationToken);

            if (!_cdcPollingOptions.EnableLiveSelectAccelerator || _liveSelectCts == null || _liveSelectTasks.Count == 0)
            {
                await Task.Delay(_cdcPollingOptions.PollInterval, cancellationToken);
                continue;
            }

            // Wait for a live-select signal, but at most one poll interval. The timed overload removes its waiter
            // when the interval elapses. Racing an untimed WaitAsync against Task.Delay instead left an abandoned
            // waiter on every timeout; those waiters piled up and absorbed later Release calls, losing the signal.
            await _liveSelectSignal.WaitAsync(_cdcPollingOptions.PollInterval, cancellationToken);
        }
        catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
        {
            break;
        }
        catch (Exception exception)
        {
            _logger.LogError(exception, "Surreal CDC worker polling iteration failed.");
            try
            {
                await Task.Delay(_cdcPollingOptions.PollInterval, cancellationToken);
            }
            catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
            {
                break;
            }
        }
    }

    _logger.LogDebug("Stopped Surreal CDC worker.");
}
/// <summary>
/// Starts one live-select accelerator task per watched collection, in ordinal order of collection name.
/// </summary>
/// <remarks>
/// Does nothing when the accelerator is disabled, nothing is watched, or accelerators are already running.
/// The caller visible here (<c>StartCdcWorkerAsync</c>) holds <c>_cdcWorkerLifecycleGate</c> while calling it;
/// the "Unsafe" suffix presumably means callers must hold that gate.
/// </remarks>
private void StartLiveSelectAcceleratorsUnsafe()
{
    if (!_cdcPollingOptions.EnableLiveSelectAccelerator) return;
    if (_watchedCollections.IsEmpty) return;
    if (_liveSelectCts != null) return;

    var liveSelectCts = new CancellationTokenSource();
    // Capture the token up front: the task lambdas run later on the thread pool, and by then
    // StopCdcWorkerAsync may have set _liveSelectCts to null (dereferencing the field there would throw).
    CancellationToken liveSelectToken = liveSelectCts.Token;
    _liveSelectCts = liveSelectCts;
    _liveSelectTasks.Clear();

    foreach (WatchedCollectionRegistration watched in _watchedCollections.Values
                 .OrderBy(v => v.CollectionName, StringComparer.Ordinal))
    {
        _liveSelectTasks.Add(Task.Run(
            () => RunLiveSelectAcceleratorAsync(watched, liveSelectToken),
            CancellationToken.None));
    }
}
private async Task RunLiveSelectAcceleratorAsync(
WatchedCollectionRegistration watched,
CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
try
{
await using var liveQuery =
await _surrealClient.LiveTable