diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/Internal/AccessTimeService.cs b/dotnet/src/ZB.MOM.NatsNet.Server/Internal/AccessTimeService.cs index 1d5f3aa..3bacfac 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/Internal/AccessTimeService.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/Internal/AccessTimeService.cs @@ -25,6 +25,18 @@ public static class AccessTimeService // Mirror Go's init(): nothing to pre-allocate in .NET. } + /// + /// Explicit init hook for Go parity. + /// Mirrors package init() in server/ats/ats.go. + /// This method is intentionally idempotent. + /// + public static void Init() + { + // Ensure a non-zero cached timestamp is present. + var now = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() * 1_000_000L; + Interlocked.CompareExchange(ref _utime, now, 0); + } + /// /// Registers a user. Starts the background timer when the first registrant calls this. /// Each call to must be paired with a call to . diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/Internal/IpQueue.cs b/dotnet/src/ZB.MOM.NatsNet.Server/Internal/IpQueue.cs index ac1d4bc..db0e0c8 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/Internal/IpQueue.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/Internal/IpQueue.cs @@ -40,6 +40,24 @@ public sealed class IpQueue /// Default maximum size of the recycled backing-list capacity. public const int DefaultMaxRecycleSize = 4 * 1024; + /// + /// Functional option type used by . + /// Mirrors Go ipQueueOpt. + /// + public delegate void IpQueueOption(IpQueueOptions options); + + /// + /// Option bag used by . + /// Mirrors Go ipQueueOpts. + /// + public sealed class IpQueueOptions + { + public int MaxRecycleSize { get; set; } = DefaultMaxRecycleSize; + public Func? SizeCalc { get; set; } + public ulong MaxSize { get; set; } + public int MaxLen { get; set; } + } + private long _inprogress; private readonly object _lock = new(); @@ -68,6 +86,56 @@ public sealed class IpQueue /// Notification channel reader — wait on this to learn items were added. public ChannelReader Ch => _ch.Reader; + /// + /// Option helper that configures maximum recycled backing-list size. + /// Mirrors Go ipqMaxRecycleSize. + /// + public static IpQueueOption IpqMaxRecycleSize(int max) => + options => options.MaxRecycleSize = max; + + /// + /// Option helper that enables size accounting for queue elements. + /// Mirrors Go ipqSizeCalculation. + /// + public static IpQueueOption IpqSizeCalculation(Func calc) => + options => options.SizeCalc = calc; + + /// + /// Option helper that limits queue pushes by total accounted size. + /// Mirrors Go ipqLimitBySize. + /// + public static IpQueueOption IpqLimitBySize(ulong max) => + options => options.MaxSize = max; + + /// + /// Option helper that limits queue pushes by element count. + /// Mirrors Go ipqLimitByLen. + /// + public static IpQueueOption IpqLimitByLen(int max) => + options => options.MaxLen = max; + + /// + /// Factory wrapper for Go parity. + /// Mirrors newIPQueue. + /// + public static IpQueue NewIPQueue( + string name, + ConcurrentDictionary? registry = null, + params IpQueueOption[] options) + { + var opts = new IpQueueOptions(); + foreach (var option in options) + option(opts); + + return new IpQueue( + name, + registry, + opts.MaxRecycleSize, + opts.SizeCalc, + opts.MaxSize, + opts.MaxLen); + } + /// /// Creates a new queue, optionally registering it in . /// Mirrors newIPQueue. diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/Internal/RateCounter.cs b/dotnet/src/ZB.MOM.NatsNet.Server/Internal/RateCounter.cs index e1eec15..76b69ea 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/Internal/RateCounter.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/Internal/RateCounter.cs @@ -38,6 +38,12 @@ public sealed class RateCounter Interval = TimeSpan.FromSeconds(1); } + /// + /// Factory wrapper for Go parity. + /// Mirrors newRateCounter. + /// + public static RateCounter NewRateCounter(long limit) => new(limit); + /// /// Returns true if the event is within the rate limit for the current window. /// Mirrors rateCounter.allow. diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/Internal/ServerUtilities.cs b/dotnet/src/ZB.MOM.NatsNet.Server/Internal/ServerUtilities.cs index 88ed9be..fa23173 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/Internal/ServerUtilities.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/Internal/ServerUtilities.cs @@ -14,6 +14,8 @@ // Adapted from server/util.go in the NATS server Go source. using System.Net; +using System.Text; +using System.Text.Json; using System.Text.RegularExpressions; namespace ZB.MOM.NatsNet.Server.Internal; @@ -268,6 +270,25 @@ public static class ServerUtilities return client; } + /// + /// Parity wrapper for Go natsDialTimeout. + /// Accepts a network label (tcp/tcp4/tcp6) and host:port address. + /// + public static Task NatsDialTimeout( + string network, string address, TimeSpan timeout) + { + if (!string.Equals(network, "tcp", StringComparison.OrdinalIgnoreCase) && + !string.Equals(network, "tcp4", StringComparison.OrdinalIgnoreCase) && + !string.Equals(network, "tcp6", StringComparison.OrdinalIgnoreCase)) + throw new NotSupportedException($"unsupported network: {network}"); + + var (host, port, err) = ParseHostPort(address, defaultPort: 0); + if (err != null || port <= 0) + throw new InvalidOperationException($"invalid dial address: {address}", err); + + return NatsDialTimeoutAsync(host, port, timeout); + } + // ------------------------------------------------------------------------- // URL redaction // ------------------------------------------------------------------------- @@ -337,6 +358,54 @@ public static class ServerUtilities return result; } + // ------------------------------------------------------------------------- + // RefCountedUrlSet wrappers (Go parity mapping) + // ------------------------------------------------------------------------- + + /// + /// Parity wrapper for . + /// Mirrors refCountedUrlSet.addUrl. + /// + public static bool AddUrl(RefCountedUrlSet urlSet, string urlStr) + { + ArgumentNullException.ThrowIfNull(urlSet); + return urlSet.AddUrl(urlStr); + } + + /// + /// Parity wrapper for . + /// Mirrors refCountedUrlSet.removeUrl. + /// + public static bool RemoveUrl(RefCountedUrlSet urlSet, string urlStr) + { + ArgumentNullException.ThrowIfNull(urlSet); + return urlSet.RemoveUrl(urlStr); + } + + /// + /// Parity wrapper for . + /// Mirrors refCountedUrlSet.getAsStringSlice. + /// + public static string[] GetAsStringSlice(RefCountedUrlSet urlSet) + { + ArgumentNullException.ThrowIfNull(urlSet); + return urlSet.GetAsStringSlice(); + } + + // ------------------------------------------------------------------------- + // INFO helpers + // ------------------------------------------------------------------------- + + /// + /// Serialises into an INFO line (INFO {...}\r\n). + /// Mirrors generateInfoJSON. + /// + public static byte[] GenerateInfoJSON(global::ZB.MOM.NatsNet.Server.ServerInfo info) + { + var json = JsonSerializer.Serialize(info); + return Encoding.UTF8.GetBytes($"INFO {json}\r\n"); + } + // ------------------------------------------------------------------------- // Copy helpers // ------------------------------------------------------------------------- @@ -391,6 +460,13 @@ public static class ServerUtilities return channel.Writer; } + + /// + /// Parity wrapper for . + /// Mirrors parallelTaskQueue. + /// + public static System.Threading.Channels.ChannelWriter ParallelTaskQueue(int maxParallelism = 0) => + CreateParallelTaskQueue(maxParallelism); } // ------------------------------------------------------------------------- diff --git a/dotnet/src/ZB.MOM.NatsNet.Server/Internal/SignalHandler.cs b/dotnet/src/ZB.MOM.NatsNet.Server/Internal/SignalHandler.cs index f149073..f0e4429 100644 --- a/dotnet/src/ZB.MOM.NatsNet.Server/Internal/SignalHandler.cs +++ b/dotnet/src/ZB.MOM.NatsNet.Server/Internal/SignalHandler.cs @@ -187,6 +187,12 @@ public static class SignalHandler _ => throw new ArgumentOutOfRangeException(nameof(command), $"unknown signal \"{CommandToString(command)}\""), }; + /// + /// Go parity alias for . + /// Mirrors CommandToSignal in signal.go. + /// + public static UnixSignal CommandToSignal(ServerCommand command) => CommandToUnixSignal(command); + private static Exception? SendSignal(int pid, UnixSignal signal) { try diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/AccessTimeServiceTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/AccessTimeServiceTests.cs index e77f9d7..216fc86 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/AccessTimeServiceTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/AccessTimeServiceTests.cs @@ -77,4 +77,16 @@ public sealed class AccessTimeServiceTests : IDisposable // Mirror: TestUnbalancedUnregister Should.Throw(() => AccessTimeService.Unregister()); } + + [Fact] + public void Init_ShouldBeIdempotentAndNonThrowing() + { + Should.NotThrow(() => AccessTimeService.Init()); + var first = AccessTimeService.AccessTime(); + first.ShouldBeGreaterThan(0); + + Should.NotThrow(() => AccessTimeService.Init()); + var second = AccessTimeService.AccessTime(); + second.ShouldBeGreaterThan(0); + } } diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/IpQueueTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/IpQueueTests.cs index 5bac632..6edd968 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/IpQueueTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/IpQueueTests.cs @@ -28,6 +28,62 @@ namespace ZB.MOM.NatsNet.Server.Tests.Internal; /// public sealed class IpQueueTests { + [Fact] + public void IpqMaxRecycleSize_ShouldAffectQueueConfig() + { + var q = IpQueue.NewIPQueue("opt-max-recycle", null, IpQueue.IpqMaxRecycleSize(123)); + q.MaxRecycleSize.ShouldBe(123); + } + + [Fact] + public void IpqSizeCalculation_AndLimitBySize_ShouldEnforceLimit() + { + var q = IpQueue.NewIPQueue( + "opt-size-limit", + null, + IpQueue.IpqSizeCalculation(e => (ulong)e.Length), + IpQueue.IpqLimitBySize(8)); + + var (_, err1) = q.Push(new byte[4]); + err1.ShouldBeNull(); + + var (_, err2) = q.Push(new byte[4]); + err2.ShouldBeNull(); + + var (_, err3) = q.Push(new byte[1]); + err3.ShouldBeSameAs(IpQueueErrors.SizeLimitReached); + } + + [Fact] + public void IpqLimitByLen_ShouldEnforceLengthLimit() + { + var q = IpQueue.NewIPQueue("opt-len-limit", null, IpQueue.IpqLimitByLen(2)); + + q.Push(1).error.ShouldBeNull(); + q.Push(2).error.ShouldBeNull(); + q.Push(3).error.ShouldBeSameAs(IpQueueErrors.LenLimitReached); + } + + [Fact] + public void NewIPQueue_ShouldApplyOptionsAndRegister() + { + var registry = new ConcurrentDictionary(); + var q = IpQueue.NewIPQueue( + "opt-factory", + registry, + IpQueue.IpqMaxRecycleSize(55), + IpQueue.IpqLimitByLen(1)); + + q.MaxRecycleSize.ShouldBe(55); + registry.TryGetValue("opt-factory", out var registered).ShouldBeTrue(); + registered.ShouldBeSameAs(q); + + var (_, err1) = q.Push(1); + err1.ShouldBeNull(); + var (_, err2) = q.Push(2); + err2.ShouldBeSameAs(IpQueueErrors.LenLimitReached); + } + [Fact] public void Basic_ShouldInitialiseCorrectly() { diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/RateCounterTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/RateCounterTests.cs index fbb97b3..3ad04e9 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/RateCounterTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/RateCounterTests.cs @@ -22,6 +22,17 @@ namespace ZB.MOM.NatsNet.Server.Tests.Internal; /// public sealed class RateCounterTests { + [Fact] + public void NewRateCounter_ShouldCreateWithDefaultInterval() + { + var counter = RateCounter.NewRateCounter(2); + counter.Interval.ShouldBe(TimeSpan.FromSeconds(1)); + + counter.Allow().ShouldBeTrue(); + counter.Allow().ShouldBeTrue(); + counter.Allow().ShouldBeFalse(); + } + [Fact] public async Task RateCounter_ShouldAllowUpToLimitThenBlockAndReset() { diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/ServerUtilitiesTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/ServerUtilitiesTests.cs index bb66dd4..9b8eb1f 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/ServerUtilitiesTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/ServerUtilitiesTests.cs @@ -11,7 +11,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +using System.Net; +using System.Text.Json; using Shouldly; +using ZB.MOM.NatsNet.Server; using ZB.MOM.NatsNet.Server.Internal; namespace ZB.MOM.NatsNet.Server.Tests.Internal; @@ -191,4 +194,86 @@ public sealed class ServerUtilitiesTests $"VersionAtLeast({version}, {major}, {minor}, {update})"); } } + + [Fact] + public void RefCountedUrlSet_Wrappers_ShouldTrackRefCounts() + { + var set = new RefCountedUrlSet(); + ServerUtilities.AddUrl(set, "nats://a:4222").ShouldBeTrue(); + ServerUtilities.AddUrl(set, "nats://a:4222").ShouldBeFalse(); + ServerUtilities.AddUrl(set, "nats://b:4222").ShouldBeTrue(); + + ServerUtilities.RemoveUrl(set, "nats://a:4222").ShouldBeFalse(); + ServerUtilities.RemoveUrl(set, "nats://a:4222").ShouldBeTrue(); + + var urls = ServerUtilities.GetAsStringSlice(set); + urls.Length.ShouldBe(1); + urls[0].ShouldBe("nats://b:4222"); + } + + [Fact] + public async Task NatsDialTimeout_ShouldConnectWithinTimeout() + { + using var listener = new System.Net.Sockets.TcpListener(IPAddress.Loopback, 0); + listener.Start(); + var port = ((IPEndPoint)listener.LocalEndpoint).Port; + var acceptTask = listener.AcceptTcpClientAsync(); + + using var client = await ServerUtilities.NatsDialTimeout( + "tcp", + $"127.0.0.1:{port}", + TimeSpan.FromSeconds(2)); + + client.Connected.ShouldBeTrue(); + using var accepted = await acceptTask; + accepted.Connected.ShouldBeTrue(); + } + + [Fact] + public void GenerateInfoJSON_ShouldEmitInfoLineWithCRLF() + { + var info = new ServerInfo + { + Id = "S1", + Name = "n1", + Host = "127.0.0.1", + Port = 4222, + Version = "2.0.0", + Proto = 1, + GoVersion = "go1.23", + }; + + var bytes = ServerUtilities.GenerateInfoJSON(info); + var line = System.Text.Encoding.UTF8.GetString(bytes); + line.ShouldStartWith("INFO "); + line.ShouldEndWith("\r\n"); + + var json = line["INFO ".Length..^2]; + var payload = JsonSerializer.Deserialize(json); + payload.ShouldNotBeNull(); + payload!.Id.ShouldBe("S1"); + } + + [Fact] + public async Task ParallelTaskQueue_ShouldExecuteQueuedActions() + { + var writer = ServerUtilities.ParallelTaskQueue(maxParallelism: 2); + var ran = 0; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + for (var i = 0; i < 4; i++) + { + var accepted = writer.TryWrite(() => + { + if (Interlocked.Increment(ref ran) == 4) + tcs.TrySetResult(); + }); + accepted.ShouldBeTrue(); + } + + writer.TryComplete().ShouldBeTrue(); + var finished = await Task.WhenAny(tcs.Task, Task.Delay(TimeSpan.FromSeconds(2))); + finished.ShouldBe(tcs.Task); + ran.ShouldBe(4); + } } diff --git a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/SignalHandlerTests.cs b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/SignalHandlerTests.cs index 9c0d134..7e9e14f 100644 --- a/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/SignalHandlerTests.cs +++ b/dotnet/tests/ZB.MOM.NatsNet.Server.Tests/Internal/SignalHandlerTests.cs @@ -35,6 +35,16 @@ public sealed class SignalHandlerTests : IDisposable SignalHandler.CommandToUnixSignal(ServerCommand.LameDuckMode).ShouldBe(UnixSignal.SigUsr2); } + [Fact] + public void CommandToSignal_ShouldMatchCommandToUnixSignal() + { + foreach (var command in Enum.GetValues()) + { + SignalHandler.CommandToSignal(command) + .ShouldBe(SignalHandler.CommandToUnixSignal(command)); + } + } + [Fact] // T:3155 public void SetProcessName_ShouldNotThrow() { diff --git a/porting.db b/porting.db index c7f70d7..18214a0 100644 Binary files a/porting.db and b/porting.db differ diff --git a/reports/current.md b/reports/current.md index a91f8e7..b86d6cb 100644 --- a/reports/current.md +++ b/reports/current.md @@ -1,6 +1,6 @@ # NATS .NET Porting Status Report -Generated: 2026-02-27 15:04:33 UTC +Generated: 2026-02-27 15:27:06 UTC ## Modules (12 total) @@ -12,10 +12,10 @@ Generated: 2026-02-27 15:04:33 UTC | Status | Count | |--------|-------| -| deferred | 2397 | -| n_a | 18 | +| deferred | 2377 | +| n_a | 24 | | stub | 1 | -| verified | 1257 | +| verified | 1271 | ## Unit Tests (3257 total) @@ -34,4 +34,4 @@ Generated: 2026-02-27 15:04:33 UTC ## Overall Progress -**1884/6942 items complete (27.1%)** +**1904/6942 items complete (27.4%)** diff --git a/reports/report_c0aaae9.md b/reports/report_c0aaae9.md new file mode 100644 index 0000000..b86d6cb --- /dev/null +++ b/reports/report_c0aaae9.md @@ -0,0 +1,37 @@ +# NATS .NET Porting Status Report + +Generated: 2026-02-27 15:27:06 UTC + +## Modules (12 total) + +| Status | Count | +|--------|-------| +| verified | 12 | + +## Features (3673 total) + +| Status | Count | +|--------|-------| +| deferred | 2377 | +| n_a | 24 | +| stub | 1 | +| verified | 1271 | + +## Unit Tests (3257 total) + +| Status | Count | +|--------|-------| +| deferred | 2660 | +| n_a | 187 | +| verified | 410 | + +## Library Mappings (36 total) + +| Status | Count | +|--------|-------| +| mapped | 36 | + + +## Overall Progress + +**1904/6942 items complete (27.4%)** diff --git a/tools/go-analyzer/analyzer.go b/tools/go-analyzer/analyzer.go index bf26165..96c3a00 100644 --- a/tools/go-analyzer/analyzer.go +++ b/tools/go-analyzer/analyzer.go @@ -256,6 +256,10 @@ func (a *Analyzer) parseTestFile(filePath string) ([]TestFunc, []ImportInfo, int } test.FeatureName = a.inferFeatureName(name) + test.BestFeatureIdx = -1 + if fn.Body != nil { + test.Calls = a.extractCalls(fn.Body) + } tests = append(tests, test) } @@ -331,6 +335,210 @@ func (a *Analyzer) inferFeatureName(testName string) string { return name } +// extractCalls walks an AST block statement and extracts all function/method calls. +func (a *Analyzer) extractCalls(body *ast.BlockStmt) []CallInfo { + seen := make(map[string]bool) + var calls []CallInfo + + ast.Inspect(body, func(n ast.Node) bool { + callExpr, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + var ci CallInfo + switch fun := callExpr.Fun.(type) { + case *ast.Ident: + ci = CallInfo{FuncName: fun.Name} + case *ast.SelectorExpr: + ci = CallInfo{ + RecvOrPkg: extractIdent(fun.X), + MethodName: fun.Sel.Name, + IsSelector: true, + } + default: + return true + } + + key := ci.callKey() + if !seen[key] && !isFilteredCall(ci) { + seen[key] = true + calls = append(calls, ci) + } + return true + }) + + return calls +} + +// extractIdent extracts an identifier name from an expression (handles X in X.Y). +func extractIdent(expr ast.Expr) string { + switch e := expr.(type) { + case *ast.Ident: + return e.Name + case *ast.SelectorExpr: + return extractIdent(e.X) + "." + e.Sel.Name + default: + return "" + } +} + +// isFilteredCall returns true if a call should be excluded from feature matching. +func isFilteredCall(c CallInfo) bool { + if c.IsSelector { + recv := c.RecvOrPkg + // testing.T/B methods + if recv == "t" || recv == "b" || recv == "tb" { + return true + } + // stdlib packages + if stdlibPkgs[recv] { + return true + } + // NATS client libs + if recv == "nats" || recv == "nuid" || recv == "nkeys" || recv == "jwt" { + return true + } + return false + } + + // Go builtins + name := c.FuncName + if builtinFuncs[name] { + return true + } + + // Test assertion helpers + lower := strings.ToLower(name) + if strings.HasPrefix(name, "require_") { + return true + } + for _, prefix := range []string{"check", "verify", "assert", "expect"} { + if strings.HasPrefix(lower, prefix) { + return true + } + } + + return false +} + +// featureRef identifies a feature within the analysis result. +type featureRef struct { + moduleIdx int + featureIdx int + goFile string + goClass string +} + +// resolveCallGraph matches test calls against known features across all modules. +func resolveCallGraph(result *AnalysisResult) { + // Build method index: go_method name → list of feature refs + methodIndex := make(map[string][]featureRef) + for mi, mod := range result.Modules { + for fi, feat := range mod.Features { + ref := featureRef{ + moduleIdx: mi, + featureIdx: fi, + goFile: feat.GoFile, + goClass: feat.GoClass, + } + methodIndex[feat.GoMethod] = append(methodIndex[feat.GoMethod], ref) + } + } + + // For each test, resolve calls to features + for mi := range result.Modules { + mod := &result.Modules[mi] + for ti := range mod.Tests { + test := &mod.Tests[ti] + seen := make(map[int]bool) // feature indices already linked + var linked []int + + testFileBase := sourceFileBase(test.GoFile) + + for _, call := range test.Calls { + // Look up the method name + name := call.MethodName + if !call.IsSelector { + name = call.FuncName + } + + candidates := methodIndex[name] + if len(candidates) == 0 { + continue + } + // Ambiguity threshold: skip very common method names + if len(candidates) > 10 { + continue + } + + // Filter to same module + var sameModule []featureRef + for _, ref := range candidates { + if ref.moduleIdx == mi { + sameModule = append(sameModule, ref) + } + } + if len(sameModule) == 0 { + continue + } + + for _, ref := range sameModule { + if !seen[ref.featureIdx] { + seen[ref.featureIdx] = true + linked = append(linked, ref.featureIdx) + } + } + } + + test.LinkedFeatures = linked + + // Set BestFeatureIdx using priority: + // (a) existing inferFeatureName match + // (b) same-file-base match + // (c) first remaining candidate + if test.BestFeatureIdx < 0 && len(linked) > 0 { + // Try same-file-base match first + for _, fi := range linked { + featFileBase := sourceFileBase(mod.Features[fi].GoFile) + if featFileBase == testFileBase { + test.BestFeatureIdx = fi + break + } + } + // Fall back to first candidate + if test.BestFeatureIdx < 0 { + test.BestFeatureIdx = linked[0] + } + } + } + } +} + +// sourceFileBase strips _test.go suffix and path to get the base file name. +func sourceFileBase(goFile string) string { + base := filepath.Base(goFile) + base = strings.TrimSuffix(base, "_test.go") + base = strings.TrimSuffix(base, ".go") + return base +} + +var stdlibPkgs = map[string]bool{ + "fmt": true, "time": true, "strings": true, "bytes": true, "errors": true, + "os": true, "math": true, "sort": true, "reflect": true, "sync": true, + "context": true, "io": true, "filepath": true, "strconv": true, + "encoding": true, "json": true, "binary": true, "hex": true, "rand": true, + "runtime": true, "atomic": true, "slices": true, "testing": true, + "net": true, "bufio": true, "crypto": true, "log": true, "regexp": true, + "unicode": true, "http": true, "url": true, +} + +var builtinFuncs = map[string]bool{ + "make": true, "append": true, "len": true, "cap": true, "close": true, + "delete": true, "panic": true, "recover": true, "print": true, + "println": true, "copy": true, "new": true, +} + // isStdlib checks if an import path is a Go standard library package. func isStdlib(importPath string) bool { firstSlash := strings.Index(importPath, "/") diff --git a/tools/go-analyzer/main.go b/tools/go-analyzer/main.go index 22f29a0..072960c 100644 --- a/tools/go-analyzer/main.go +++ b/tools/go-analyzer/main.go @@ -11,28 +11,47 @@ func main() { sourceDir := flag.String("source", "", "Path to Go source root (e.g., ../../golang/nats-server)") dbPath := flag.String("db", "", "Path to SQLite database file (e.g., ../../porting.db)") schemaPath := flag.String("schema", "", "Path to SQL schema file (e.g., ../../porting-schema.sql)") + mode := flag.String("mode", "full", "Analysis mode: 'full' (default) or 'call-graph' (incremental)") flag.Parse() - if *sourceDir == "" || *dbPath == "" || *schemaPath == "" { - fmt.Fprintf(os.Stderr, "Usage: go-analyzer --source --db --schema \n") + if *sourceDir == "" || *dbPath == "" { + fmt.Fprintf(os.Stderr, "Usage: go-analyzer --source --db [--schema ] [--mode full|call-graph]\n") flag.PrintDefaults() os.Exit(1) } + switch *mode { + case "full": + runFull(*sourceDir, *dbPath, *schemaPath) + case "call-graph": + runCallGraph(*sourceDir, *dbPath) + default: + log.Fatalf("Unknown mode %q: must be 'full' or 'call-graph'", *mode) + } +} + +func runFull(sourceDir, dbPath, schemaPath string) { + if schemaPath == "" { + log.Fatal("--schema is required for full mode") + } + // Open DB and apply schema - db, err := OpenDB(*dbPath, *schemaPath) + db, err := OpenDB(dbPath, schemaPath) if err != nil { log.Fatalf("Failed to open database: %v", err) } defer db.Close() // Run analysis - analyzer := NewAnalyzer(*sourceDir) + analyzer := NewAnalyzer(sourceDir) result, err := analyzer.Analyze() if err != nil { log.Fatalf("Analysis failed: %v", err) } + // Resolve call graph before writing + resolveCallGraph(result) + // Write to DB writer := NewDBWriter(db) if err := writer.WriteAll(result); err != nil { @@ -46,3 +65,35 @@ func main() { fmt.Printf(" Dependencies: %d\n", len(result.Dependencies)) fmt.Printf(" Imports: %d\n", len(result.Imports)) } + +func runCallGraph(sourceDir, dbPath string) { + // Open existing DB without schema + db, err := OpenDBNoSchema(dbPath) + if err != nil { + log.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Run analysis (parse Go source) + analyzer := NewAnalyzer(sourceDir) + result, err := analyzer.Analyze() + if err != nil { + log.Fatalf("Analysis failed: %v", err) + } + + // Resolve call graph + resolveCallGraph(result) + + // Update DB incrementally + writer := NewDBWriter(db) + stats, err := writer.UpdateCallGraph(result) + if err != nil { + log.Fatalf("Failed to update call graph: %v", err) + } + + fmt.Printf("Call graph analysis complete:\n") + fmt.Printf(" Tests analyzed: %d\n", stats.TestsAnalyzed) + fmt.Printf(" Tests linked: %d\n", stats.TestsLinked) + fmt.Printf(" Dependency rows: %d\n", stats.DependencyRows) + fmt.Printf(" Feature IDs set: %d\n", stats.FeatureIDsSet) +} diff --git a/tools/go-analyzer/sqlite.go b/tools/go-analyzer/sqlite.go index ff7448e..4e76bfb 100644 --- a/tools/go-analyzer/sqlite.go +++ b/tools/go-analyzer/sqlite.go @@ -152,3 +152,176 @@ func (w *DBWriter) insertLibrary(tx *sql.Tx, imp *ImportInfo) error { ) return err } + +// OpenDBNoSchema opens an existing SQLite database without applying schema. +// It verifies that the required tables exist. +func OpenDBNoSchema(dbPath string) (*sql.DB, error) { + db, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=ON") + if err != nil { + return nil, fmt.Errorf("opening database: %w", err) + } + + // Verify required tables exist + for _, table := range []string{"modules", "features", "unit_tests", "dependencies"} { + var name string + err := db.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&name) + if err != nil { + db.Close() + return nil, fmt.Errorf("required table %q not found: %w", table, err) + } + } + + return db, nil +} + +// CallGraphStats holds summary statistics from a call-graph update. +type CallGraphStats struct { + TestsAnalyzed int + TestsLinked int + DependencyRows int + FeatureIDsSet int +} + +// UpdateCallGraph writes call-graph analysis results to the database incrementally. +func (w *DBWriter) UpdateCallGraph(result *AnalysisResult) (*CallGraphStats, error) { + stats := &CallGraphStats{} + + // Load module name→ID mapping + moduleIDs := make(map[string]int64) + rows, err := w.db.Query("SELECT id, name FROM modules") + if err != nil { + return nil, fmt.Errorf("querying modules: %w", err) + } + for rows.Next() { + var id int64 + var name string + if err := rows.Scan(&id, &name); err != nil { + rows.Close() + return nil, err + } + moduleIDs[name] = id + } + rows.Close() + + // Load feature DB IDs: "module_name:go_method:go_class" → id + type featureKey struct { + moduleName string + goMethod string + goClass string + } + featureDBIDs := make(map[featureKey]int64) + rows, err = w.db.Query(` + SELECT f.id, m.name, f.go_method, COALESCE(f.go_class, '') + FROM features f + JOIN modules m ON f.module_id = m.id + `) + if err != nil { + return nil, fmt.Errorf("querying features: %w", err) + } + for rows.Next() { + var id int64 + var modName, goMethod, goClass string + if err := rows.Scan(&id, &modName, &goMethod, &goClass); err != nil { + rows.Close() + return nil, err + } + featureDBIDs[featureKey{modName, goMethod, goClass}] = id + } + rows.Close() + + // Load test DB IDs: "module_name:go_method" → id + testDBIDs := make(map[string]int64) + rows, err = w.db.Query(` + SELECT ut.id, m.name, ut.go_method + FROM unit_tests ut + JOIN modules m ON ut.module_id = m.id + `) + if err != nil { + return nil, fmt.Errorf("querying unit_tests: %w", err) + } + for rows.Next() { + var id int64 + var modName, goMethod string + if err := rows.Scan(&id, &modName, &goMethod); err != nil { + rows.Close() + return nil, err + } + testDBIDs[modName+":"+goMethod] = id + } + rows.Close() + + // Begin transaction + tx, err := w.db.Begin() + if err != nil { + return nil, fmt.Errorf("beginning transaction: %w", err) + } + defer tx.Rollback() + + // Clear old call-graph data + if _, err := tx.Exec("DELETE FROM dependencies WHERE source_type='unit_test' AND dependency_kind='calls'"); err != nil { + return nil, fmt.Errorf("clearing old dependencies: %w", err) + } + if _, err := tx.Exec("UPDATE unit_tests SET feature_id = NULL"); err != nil { + return nil, fmt.Errorf("clearing old feature_ids: %w", err) + } + + // Prepare statements + insertDep, err := tx.Prepare("INSERT OR IGNORE INTO dependencies (source_type, source_id, target_type, target_id, dependency_kind) VALUES ('unit_test', ?, 'feature', ?, 'calls')") + if err != nil { + return nil, fmt.Errorf("preparing insert dependency: %w", err) + } + defer insertDep.Close() + + updateFeatureID, err := tx.Prepare("UPDATE unit_tests SET feature_id = ? WHERE id = ?") + if err != nil { + return nil, fmt.Errorf("preparing update feature_id: %w", err) + } + defer updateFeatureID.Close() + + // Process each module's tests + for _, mod := range result.Modules { + for _, test := range mod.Tests { + stats.TestsAnalyzed++ + + testDBID, ok := testDBIDs[mod.Name+":"+test.GoMethod] + if !ok { + continue + } + + // Insert dependency rows for linked features + if len(test.LinkedFeatures) > 0 { + stats.TestsLinked++ + } + for _, fi := range test.LinkedFeatures { + feat := mod.Features[fi] + featDBID, ok := featureDBIDs[featureKey{mod.Name, feat.GoMethod, feat.GoClass}] + if !ok { + continue + } + if _, err := insertDep.Exec(testDBID, featDBID); err != nil { + return nil, fmt.Errorf("inserting dependency for test %s: %w", test.GoMethod, err) + } + stats.DependencyRows++ + } + + // Set feature_id for best match + if test.BestFeatureIdx >= 0 { + feat := mod.Features[test.BestFeatureIdx] + featDBID, ok := featureDBIDs[featureKey{mod.Name, feat.GoMethod, feat.GoClass}] + if !ok { + continue + } + if _, err := updateFeatureID.Exec(featDBID, testDBID); err != nil { + return nil, fmt.Errorf("updating feature_id for test %s: %w", test.GoMethod, err) + } + stats.FeatureIDsSet++ + } + } + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("committing transaction: %w", err) + } + + return stats, nil +} diff --git a/tools/go-analyzer/types.go b/tools/go-analyzer/types.go index 3cbeaba..c6702fd 100644 --- a/tools/go-analyzer/types.go +++ b/tools/go-analyzer/types.go @@ -58,6 +58,28 @@ type TestFunc struct { GoLineCount int // FeatureName links this test to a feature by naming convention FeatureName string + // Calls holds raw function/method calls extracted from the test body AST + Calls []CallInfo + // LinkedFeatures holds indices into the parent module's Features slice + LinkedFeatures []int + // BestFeatureIdx is the primary feature match index (-1 = none) + BestFeatureIdx int +} + +// CallInfo represents a function or method call extracted from a test body. +type CallInfo struct { + FuncName string // direct call name: "newMemStore" + RecvOrPkg string // selector receiver/pkg: "ms", "fmt", "t" + MethodName string // selector method: "StoreMsg", "Fatalf" + IsSelector bool // true for X.Y() form +} + +// callKey returns a deduplication key for this call. +func (c CallInfo) callKey() string { + if c.IsSelector { + return c.RecvOrPkg + "." + c.MethodName + } + return c.FuncName } // Dependency represents a call relationship between two items.