diff --git a/src/ServiceControl.Persistence.RavenDB/Throughput/LicensingDataStore.cs b/src/ServiceControl.Persistence.RavenDB/Throughput/LicensingDataStore.cs index 6da42ddc8d..d6455cb025 100644 --- a/src/ServiceControl.Persistence.RavenDB/Throughput/LicensingDataStore.cs +++ b/src/ServiceControl.Persistence.RavenDB/Throughput/LicensingDataStore.cs @@ -190,6 +190,10 @@ public async Task UpdateUserIndicatorOnEndpoints(List userI .Where(document => document.SanitizedName.In(updates.Keys) || document.EndpointId.Name.In(updates.Keys)); var documents = await query.ToListAsync(cancellationToken); + + // Collect sanitized names needing sibling propagation to avoid issuing a query per document in the loop below. + var sanitizedNameToUserIndicator = new Dictionary(); + foreach (var document in documents) { if (updates.TryGetValue(document.SanitizedName, out var newValueFromSanitizedName)) @@ -199,14 +203,25 @@ public async Task UpdateUserIndicatorOnEndpoints(List userI else if (updates.TryGetValue(document.EndpointId.Name, out var newValueFromEndpoint)) { document.UserIndicator = newValueFromEndpoint; - //update all that match this sanitized name - var sanitizedMatchingQuery = session.Query() - .Where(sanitizedDocument => sanitizedDocument.SanitizedName == document.SanitizedName && sanitizedDocument.EndpointId.Name != document.EndpointId.Name); - var sanitizedMatchingDocuments = await sanitizedMatchingQuery.ToListAsync(cancellationToken); + sanitizedNameToUserIndicator[document.SanitizedName] = newValueFromEndpoint; + } + } - foreach (var matchingDocumentOnSanitizedName in sanitizedMatchingDocuments) + if (sanitizedNameToUserIndicator.Count > 0) + { + // One batched query for all sibling documents, instead of one query per document. + var sanitizedNames = sanitizedNameToUserIndicator.Keys.ToList(); + var alreadyLoadedIds = documents.Select(d => d.Id).ToHashSet(); + + var siblingDocuments = await session.Query() + .Where(d => d.SanitizedName.In(sanitizedNames)) + .ToListAsync(cancellationToken); + + foreach (var sibling in siblingDocuments.Where(d => !alreadyLoadedIds.Contains(d.Id))) + { + if (sanitizedNameToUserIndicator.TryGetValue(sibling.SanitizedName, out var indicator)) { - matchingDocumentOnSanitizedName.UserIndicator = newValueFromEndpoint; + sibling.UserIndicator = indicator; } } } diff --git a/src/ServiceControl.Persistence.Tests/Throughput/EndpointsTests.cs b/src/ServiceControl.Persistence.Tests/Throughput/EndpointsTests.cs index f677b03e8c..60aeb19ee9 100644 --- a/src/ServiceControl.Persistence.Tests/Throughput/EndpointsTests.cs +++ b/src/ServiceControl.Persistence.Tests/Throughput/EndpointsTests.cs @@ -199,6 +199,38 @@ public async Task Should_update_indicators_on_all_endpoint_sources_when_updated_ Assert.That(foundEndpointMonitoring.UserIndicator, Is.EqualTo(userIndicator)); } + [Test] + public async Task Should_update_user_indicators_on_more_than_30_endpoints_without_hitting_session_request_limit() + { + // Arrange + // Each pair shares a sanitized name but has different raw names. + // Updating by raw name (not sanitized name) triggers a sibling propagation query. + // In the original code, that was one DB query per endpoint, exceeding RavenDB's + // default limit of 30 requests per session when 30+ endpoints are updated at once. + const int endpointCount = 30; + var userIndicator = "someIndicator"; + + for (var i = 0; i < endpointCount; i++) + { + var sanitizedName = $"Endpoint{i}"; + await LicensingDataStore.SaveEndpoint(new Endpoint(sanitizedName, ThroughputSource.Audit) { SanitizedName = sanitizedName }, default); + await LicensingDataStore.SaveEndpoint(new Endpoint($"schema.{sanitizedName}", ThroughputSource.Monitoring) { SanitizedName = sanitizedName }, default); + } + + var updates = Enumerable.Range(0, endpointCount) + .Select(i => new UpdateUserIndicator { Name = $"schema.Endpoint{i}", UserIndicator = userIndicator }) + .ToList(); + + // Act - must not throw InvalidOperationException due to exceeding session request limit + await LicensingDataStore.UpdateUserIndicatorOnEndpoints(updates, default); + + // Assert + var allEndpoints = (await LicensingDataStore.GetAllEndpoints(true, default)).ToList(); + + Assert.That(allEndpoints, Has.Count.EqualTo(endpointCount * 2)); + Assert.That(allEndpoints, Has.All.Matches(e => e.UserIndicator == userIndicator)); + } + [TestCase(10, 5, false)] [TestCase(10, 20, true)] public async Task Should_correctly_report_throughput_existence_for_X_days(int daysSinceLastThroughputEntry, int timeFrameToCheck, bool expectedValue)