diff --git a/src/EntityFrameworkCore.Triggered.Extensions/ServiceCollectionExtensions.cs b/src/EntityFrameworkCore.Triggered.Extensions/ServiceCollectionExtensions.cs index 2b8c382..bd24701 100644 --- a/src/EntityFrameworkCore.Triggered.Extensions/ServiceCollectionExtensions.cs +++ b/src/EntityFrameworkCore.Triggered.Extensions/ServiceCollectionExtensions.cs @@ -1,50 +1,20 @@ using System; using System.Linq; using System.Reflection; -using EntityFrameworkCore.Triggered; -using EntityFrameworkCore.Triggered.Infrastructure.Internal; -using EntityFrameworkCore.Triggered.Lifecycles; +using EntityFrameworkCore.Triggered.Extensions; using Microsoft.Extensions.DependencyInjection.Extensions; namespace Microsoft.Extensions.DependencyInjection { public static class ServiceCollectionExtensions { - static readonly Type[] _wellKnownTriggerTypes = new Type[] { - typeof(IBeforeSaveTrigger<>), - typeof(IBeforeSaveAsyncTrigger<>), - typeof(IAfterSaveTrigger<>), - typeof(IAfterSaveAsyncTrigger<>), - typeof(IAfterSaveFailedTrigger<>), - typeof(IAfterSaveFailedAsyncTrigger<>), - typeof(IBeforeSaveStartingTrigger), - typeof(IBeforeSaveStartingAsyncTrigger), - typeof(IBeforeSaveCompletedTrigger), - typeof(IBeforeSaveCompletedAsyncTrigger), - typeof(IAfterSaveFailedStartingTrigger), - typeof(IAfterSaveFailedStartingAsyncTrigger), - typeof(IAfterSaveFailedCompletedTrigger), - typeof(IAfterSaveFailedCompletedAsyncTrigger), - typeof(IAfterSaveStartingTrigger), - typeof(IAfterSaveStartingAsyncTrigger), - typeof(IAfterSaveCompletedTrigger), - typeof(IAfterSaveCompletedAsyncTrigger) - }; - - static void RegisterTriggerTypes(Type triggerImplementationType, IServiceCollection services) + private static void RegisterTriggerTypes(Type triggerImplementationType, IServiceCollection services) { - foreach (var customTriggerType in _wellKnownTriggerTypes) - { - var customTriggers = customTriggerType.IsGenericTypeDefinition -#pragma warning disable EF1001 // Internal EF Core API usage. - ? TypeHelpers.FindGenericInterfaces(triggerImplementationType, customTriggerType) -#pragma warning restore EF1001 // Internal EF Core API usage. - : triggerImplementationType.GetInterfaces().Where(x => x == customTriggerType); + var triggerInterfaces = TriggerTypeHelper.GetTriggerInterfaces(triggerImplementationType); - foreach (var customTrigger in customTriggers) - { - services.Add(new ServiceDescriptor(customTrigger, sp => sp.GetRequiredService(triggerImplementationType), ServiceLifetime.Transient)); ; - } + foreach (var triggerInterface in triggerInterfaces) + { + services.Add(new ServiceDescriptor(triggerInterface, sp => sp.GetRequiredService(triggerImplementationType), ServiceLifetime.Transient)); } } @@ -60,6 +30,11 @@ public static IServiceCollection AddTrigger(this IServiceCollection se public static IServiceCollection AddTrigger(this IServiceCollection services, object triggerInstance) { + if (triggerInstance is null) + { + throw new ArgumentNullException(nameof(triggerInstance)); + } + services.TryAddSingleton(triggerInstance); RegisterTriggerTypes(triggerInstance.GetType(), services); @@ -83,29 +58,27 @@ public static IServiceCollection AddAssemblyTriggers(this IServiceCollection ser throw new ArgumentNullException(nameof(assemblies)); } - var assemblyTypes = assemblies - .SelectMany(x => x.GetTypes()) - .Where(x => x.IsClass) - .Where(x => !x.IsAbstract); + if (assemblies.Length == 0) + { + return services; + } + + var assemblyTypes = assemblies.SelectMany(TriggerTypeHelper.GetAssemblyConcreteClasses); foreach (var assemblyType in assemblyTypes) { - var triggerTypes = assemblyType - .GetInterfaces() - .Where(x => _wellKnownTriggerTypes.Contains(x.IsConstructedGenericType ? x.GetGenericTypeDefinition() : x)); + var triggerInterfaces = TriggerTypeHelper.GetTriggerInterfaces(assemblyType); - var registered = false; - - foreach (var triggerType in triggerTypes) + if (triggerInterfaces.Length == 0) { - if (!registered) - { - services.TryAdd(new ServiceDescriptor(assemblyType, assemblyType, lifetime)); + continue; + } - registered = true; - } + services.TryAdd(new ServiceDescriptor(assemblyType, assemblyType, lifetime)); - services.Add(new ServiceDescriptor(triggerType, sp => sp.GetRequiredService(assemblyType), ServiceLifetime.Transient)); + foreach (var triggerInterface in triggerInterfaces) + { + services.Add(new ServiceDescriptor(triggerInterface, sp => sp.GetRequiredService(assemblyType), ServiceLifetime.Transient)); } } diff --git a/src/EntityFrameworkCore.Triggered.Extensions/TriggerTypeHelper.cs b/src/EntityFrameworkCore.Triggered.Extensions/TriggerTypeHelper.cs new file mode 100644 index 0000000..3e51135 --- /dev/null +++ b/src/EntityFrameworkCore.Triggered.Extensions/TriggerTypeHelper.cs @@ -0,0 +1,105 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using EntityFrameworkCore.Triggered.Lifecycles; + +namespace EntityFrameworkCore.Triggered.Extensions +{ + /// + /// Internal helpers shared between ServiceCollectionExtensions and + /// TriggersContextOptionsBuilderExtensions for trigger-type discovery. + /// + static internal class TriggerTypeHelper + { + // Open generic trigger interfaces — add new generic lifecycle interfaces here + private readonly static HashSet _genericTriggerTypes = new HashSet + { + typeof(IBeforeSaveTrigger<>), + typeof(IBeforeSaveAsyncTrigger<>), + typeof(IAfterSaveTrigger<>), + typeof(IAfterSaveAsyncTrigger<>), + typeof(IAfterSaveFailedTrigger<>), + typeof(IAfterSaveFailedAsyncTrigger<>), + }; + + // Non-generic lifecycle interfaces — add new non-generic lifecycle interfaces here + private readonly static HashSet _nonGenericTriggerTypes = new HashSet + { + typeof(IBeforeSaveStartingTrigger), + typeof(IBeforeSaveStartingAsyncTrigger), + typeof(IBeforeSaveCompletedTrigger), + typeof(IBeforeSaveCompletedAsyncTrigger), + typeof(IAfterSaveFailedStartingTrigger), + typeof(IAfterSaveFailedStartingAsyncTrigger), + typeof(IAfterSaveFailedCompletedTrigger), + typeof(IAfterSaveFailedCompletedAsyncTrigger), + typeof(IAfterSaveStartingTrigger), + typeof(IAfterSaveStartingAsyncTrigger), + typeof(IAfterSaveCompletedTrigger), + typeof(IAfterSaveCompletedAsyncTrigger), + }; + + // Caches the resolved trigger interfaces per implementation type to avoid repeated reflection + private readonly static ConcurrentDictionary _triggerInterfaceCache = + new ConcurrentDictionary(); + + /// + /// Returns the types defined in , gracefully handling + /// that occurs when some types cannot be + /// loaded due to missing dependencies (e.g. optional assemblies not present at runtime). + /// + private static IEnumerable GetAssemblyTypes(Assembly assembly) + { + try + { + return assembly.GetTypes(); + } + catch (ReflectionTypeLoadException e) + { + // Return only the types that could be loaded; nulls represent types that failed + return e.Types.OfType(); + } + } + + /// + /// Returns the non-abstract classes defined in , gracefully handling + /// that occurs when some types cannot be loaded due to missing + /// dependencies (e.g. optional assemblies not present at runtime). + /// + static internal IEnumerable GetAssemblyConcreteClasses(Assembly assembly) => + GetAssemblyTypes(assembly).Where(t => t is { IsClass: true, IsAbstract: false }); + + /// + /// Returns the subset of 's interfaces that + /// match a known trigger interface, using a per-type cache. + /// + static internal Type[] GetTriggerInterfaces(Type triggerImplementationType) + => _triggerInterfaceCache.GetOrAdd(triggerImplementationType, t => + { + var interfaces = t.GetInterfaces(); + var result = new List(interfaces.Length); + + foreach (var iface in interfaces) + { + if (iface.IsConstructedGenericType) + { + if (_genericTriggerTypes.Contains(iface.GetGenericTypeDefinition())) + { + result.Add(iface); + } + } + else if (_nonGenericTriggerTypes.Contains(iface)) + { + result.Add(iface); + } + } + + return result.ToArray(); + }); + } +} + + + diff --git a/src/EntityFrameworkCore.Triggered.Extensions/TriggersContextOptionsBuilderExtensions.cs b/src/EntityFrameworkCore.Triggered.Extensions/TriggersContextOptionsBuilderExtensions.cs index 4c6d297..8c00f88 100644 --- a/src/EntityFrameworkCore.Triggered.Extensions/TriggersContextOptionsBuilderExtensions.cs +++ b/src/EntityFrameworkCore.Triggered.Extensions/TriggersContextOptionsBuilderExtensions.cs @@ -1,6 +1,7 @@ using System; using System.Linq; using System.Reflection; +using EntityFrameworkCore.Triggered.Extensions; using EntityFrameworkCore.Triggered.Infrastructure; using Microsoft.Extensions.DependencyInjection; @@ -24,14 +25,20 @@ public static TriggersContextOptionsBuilder AddAssemblyTriggers(this TriggersCon throw new ArgumentNullException(nameof(assemblies)); } - var assemblyTypes = assemblies - .SelectMany(x => x.GetTypes()) - .Where(x => x.IsClass) - .Where(x => !x.IsAbstract); + if (assemblies.Length == 0) + { + return builder; + } + + var assemblyTypes = assemblies.SelectMany(TriggerTypeHelper.GetAssemblyConcreteClasses); foreach (var assemblyType in assemblyTypes) { - builder.AddTrigger(assemblyType, lifetime); + // Only register types that actually implement a known trigger interface + if (TriggerTypeHelper.GetTriggerInterfaces(assemblyType).Length > 0) + { + builder.AddTrigger(assemblyType, lifetime); + } } return builder; diff --git a/src/EntityFrameworkCore.Triggered/Infrastructure/Internal/TriggersOptionExtension.cs b/src/EntityFrameworkCore.Triggered/Infrastructure/Internal/TriggersOptionExtension.cs index 3c24b89..b71012b 100644 --- a/src/EntityFrameworkCore.Triggered/Infrastructure/Internal/TriggersOptionExtension.cs +++ b/src/EntityFrameworkCore.Triggered/Infrastructure/Internal/TriggersOptionExtension.cs @@ -17,6 +17,7 @@ public class TriggersOptionExtension : IDbContextOptionsExtension sealed class ExtensionInfo : DbContextOptionsExtensionInfo { private string? _logFragment; + private int? _serviceProviderHashCode; public ExtensionInfo(IDbContextOptionsExtension extension) : base(extension) { } @@ -44,14 +45,19 @@ public override void PopulateDebugInfo(IDictionary debugInfo) throw new ArgumentNullException(nameof(debugInfo)); } - debugInfo["Triggers:TriggersCount"] = (Extension._triggers?.Count() ?? 0).ToString(); - debugInfo["Triggers:TriggerTypesCount"] = (Extension._triggerTypes?.Count() ?? 0).ToString(); + debugInfo["Triggers:TriggersCount"] = (Extension._triggers?.Count ?? 0).ToString(); + debugInfo["Triggers:TriggerTypesCount"] = (Extension._triggerTypes?.Count ?? 0).ToString(); debugInfo["Triggers:MaxCascadeCycles"] = Extension._maxCascadeCycles.ToString(); debugInfo["Triggers:CascadeBehavior"] = Extension._cascadeBehavior.ToString(); } public override int GetServiceProviderHashCode() { + if (_serviceProviderHashCode.HasValue) + { + return _serviceProviderHashCode.Value; + } + var hashCode = new HashCode(); if (Extension._triggers != null) @@ -78,28 +84,56 @@ public override int GetServiceProviderHashCode() hashCode.Add(Extension._serviceProviderTransform); } - return hashCode.ToHashCode(); + _serviceProviderHashCode = hashCode.ToHashCode(); + return _serviceProviderHashCode.Value; } public override bool ShouldUseSameServiceProvider(DbContextOptionsExtensionInfo other) - => other is ExtensionInfo otherInfo - && Enumerable.SequenceEqual(Extension._triggers ?? Enumerable.Empty>(), otherInfo.Extension._triggers ?? Enumerable.Empty>()) - && Enumerable.SequenceEqual(Extension._triggerTypes ?? Enumerable.Empty(), otherInfo.Extension._triggerTypes ?? Enumerable.Empty()) - && Extension._maxCascadeCycles == otherInfo.Extension._maxCascadeCycles - && Extension._cascadeBehavior == otherInfo.Extension._cascadeBehavior - && Extension._serviceProviderTransform == otherInfo.Extension._serviceProviderTransform; + { + if (other is not ExtensionInfo otherInfo) + { + return false; + } + + // Check cheap scalar comparisons first + if (Extension._maxCascadeCycles != otherInfo.Extension._maxCascadeCycles + || Extension._cascadeBehavior != otherInfo.Extension._cascadeBehavior + || Extension._serviceProviderTransform != otherInfo.Extension._serviceProviderTransform) + { + return false; + } + + // Check list counts before doing full sequence comparison + var triggersCount = Extension._triggers?.Count ?? 0; + var otherTriggersCount = otherInfo.Extension._triggers?.Count ?? 0; + if (triggersCount != otherTriggersCount) + { + return false; + } + + var triggerTypesCount = Extension._triggerTypes?.Count ?? 0; + var otherTriggerTypesCount = otherInfo.Extension._triggerTypes?.Count ?? 0; + if (triggerTypesCount != otherTriggerTypesCount) + { + return false; + } + + // Full sequence comparison only when counts match + return Enumerable.SequenceEqual(Extension._triggers ?? Enumerable.Empty>(), otherInfo.Extension._triggers ?? Enumerable.Empty>()) + && Enumerable.SequenceEqual(Extension._triggerTypes ?? Enumerable.Empty(), otherInfo.Extension._triggerTypes ?? Enumerable.Empty()); + } } private ExtensionInfo? _info; - private IEnumerable<(object typeOrInstance, ServiceLifetime lifetime)>? _triggers; - private IEnumerable _triggerTypes; + private List<(object typeOrInstance, ServiceLifetime lifetime)>? _triggers; + private List _triggerTypes; private int _maxCascadeCycles = 100; private CascadeBehavior _cascadeBehavior = CascadeBehavior.EntityAndType; private Func? _serviceProviderTransform; public TriggersOptionExtension() { - _triggerTypes = new[] { + _triggerTypes = new List { typeof(IBeforeSaveTrigger<>), typeof(IBeforeSaveAsyncTrigger<>), typeof(IAfterSaveTrigger<>), @@ -125,10 +159,10 @@ public TriggersOptionExtension(TriggersOptionExtension copyFrom) { if (copyFrom._triggers != null) { - _triggers = copyFrom._triggers; + _triggers = new List<(object typeOrInstance, ServiceLifetime lifetime)>(copyFrom._triggers); } - _triggerTypes = copyFrom._triggerTypes; + _triggerTypes = new List(copyFrom._triggerTypes); _maxCascadeCycles = copyFrom._maxCascadeCycles; _cascadeBehavior = copyFrom._cascadeBehavior; _serviceProviderTransform = copyFrom._serviceProviderTransform; @@ -223,13 +257,13 @@ public void Validate(IDbContextOptions options) { } private bool TypeIsValidTrigger(Type type) { - if (TypeHelpers.FindGenericInterfaces(type, typeof(IBeforeSaveTrigger<>)) != null || TypeHelpers.FindGenericInterfaces(type, typeof(IAfterSaveTrigger<>)) != null) + if (TypeHelpers.FindGenericInterfaces(type, typeof(IBeforeSaveTrigger<>)).Any() || TypeHelpers.FindGenericInterfaces(type, typeof(IAfterSaveTrigger<>)).Any()) { return true; } else if (_triggerTypes != null) { - return _triggerTypes.Any(triggerType => TypeHelpers.FindGenericInterfaces(type, triggerType) != null); + return _triggerTypes.Any(triggerType => TypeHelpers.FindGenericInterfaces(type, triggerType).Any()); } else { @@ -263,17 +297,8 @@ public TriggersOptionExtension WithAdditionalTrigger(Type triggerType, ServiceLi } var clone = Clone(); - var triggerEnumerable = Enumerable.Repeat(((object)triggerType, lifetime), 1); - - if (clone._triggers == null) - { - clone._triggers = triggerEnumerable; - } - else - { - clone._triggers = clone._triggers.Concat(triggerEnumerable); - } - + clone._triggers ??= new List<(object typeOrInstance, ServiceLifetime lifetime)>(); + clone._triggers.Add(((object)triggerType, lifetime)); return clone; } @@ -291,17 +316,8 @@ public TriggersOptionExtension WithAdditionalTrigger(object instance) } var clone = Clone(); - var triggersEnumerable = Enumerable.Repeat((instance, ServiceLifetime.Singleton), 1); - - if (clone._triggers == null) - { - clone._triggers = triggersEnumerable; - } - else - { - clone._triggers = clone._triggers.Concat(triggersEnumerable); - } - + clone._triggers ??= new List<(object typeOrInstance, ServiceLifetime lifetime)>(); + clone._triggers.Add((instance, ServiceLifetime.Singleton)); return clone; } @@ -313,19 +329,9 @@ public TriggersOptionExtension WithAdditionalTriggerType(Type triggerType) throw new ArgumentNullException(nameof(triggerType)); } - var clone = Clone(); - var triggerTypesEnumerable = Enumerable.Repeat(triggerType, 1); - - if (clone._triggerTypes == null) - { - clone._triggerTypes = triggerTypesEnumerable; - } - else - { - clone._triggerTypes = clone._triggerTypes.Concat(triggerTypesEnumerable); - } - + clone._triggerTypes ??= new List(); + clone._triggerTypes.Add(triggerType); return clone; } diff --git a/src/EntityFrameworkCore.Triggered/Internal/TriggerDiscoveryService.cs b/src/EntityFrameworkCore.Triggered/Internal/TriggerDiscoveryService.cs index 10092e6..110e7ae 100644 --- a/src/EntityFrameworkCore.Triggered/Internal/TriggerDiscoveryService.cs +++ b/src/EntityFrameworkCore.Triggered/Internal/TriggerDiscoveryService.cs @@ -28,41 +28,60 @@ public IEnumerable DiscoverTriggers(Type openTriggerType, Typ { var registry = _triggerTypeRegistryService.ResolveRegistry(openTriggerType, entityType, triggerTypeDescriptorFactory); - var triggerTypeDescriptors = registry.GetTriggerTypeDescriptors(); + // On the first call for this (openTriggerType, entityType) combination the active + // descriptor cache is null — use the full hierarchy. On subsequent calls we skip + // descriptors that are known to produce no results, eliminating empty DI lookups. + var precomputedActive = registry.GetActiveDescriptors(); + var triggerTypeDescriptors = precomputedActive ?? registry.GetTriggerTypeDescriptors(); + if (triggerTypeDescriptors.Length == 0) { + if (precomputedActive == null) + { + registry.SetActiveDescriptors(Array.Empty()); + } + return Enumerable.Empty(); } - else + + List? triggerDescriptors = null; + // Only track active descriptors when the cache has not been populated yet + List? newActiveDescriptors = precomputedActive == null ? new List() : null; + + foreach (var triggerTypeDescriptor in triggerTypeDescriptors) { - List? triggerDescriptors = null; + var triggers = _triggerFactory.Resolve(ServiceProvider, triggerTypeDescriptor.TriggerType); + var addedToActive = false; - foreach (var triggerTypeDescriptor in triggerTypeDescriptors) + foreach (var trigger in triggers) { - var triggers = _triggerFactory.Resolve(ServiceProvider, triggerTypeDescriptor.TriggerType); - foreach (var trigger in triggers) + if (trigger != null) { - if (triggerDescriptors == null) - { - triggerDescriptors = new List(); - } + (triggerDescriptors ??= new List()).Add(new TriggerDescriptor(triggerTypeDescriptor, trigger)); - if (trigger != null) + if (newActiveDescriptors != null && !addedToActive) { - triggerDescriptors.Add(new TriggerDescriptor(triggerTypeDescriptor, trigger)); + newActiveDescriptors.Add(triggerTypeDescriptor); + addedToActive = true; } } } + } - if (triggerDescriptors == null) - { - return Enumerable.Empty(); - } - else - { - triggerDescriptors.Sort(_triggerDescriptorComparer); - return triggerDescriptors; - } + // Persist the active set so future calls skip the empty lookups + if (newActiveDescriptors != null) + { + registry.SetActiveDescriptors(newActiveDescriptors.ToArray()); + } + + if (triggerDescriptors == null) + { + return Enumerable.Empty(); + } + else + { + triggerDescriptors.Sort(_triggerDescriptorComparer); + return triggerDescriptors; } } @@ -70,59 +89,83 @@ public IEnumerable DiscoverAsyncTriggers(Type openTrigge { var registry = _triggerTypeRegistryService.ResolveRegistry(openTriggerType, entityType, triggerTypeDescriptorFactory); - var triggerTypeDescriptors = registry.GetTriggerTypeDescriptors(); + var precomputedActive = registry.GetActiveDescriptors(); + var triggerTypeDescriptors = precomputedActive ?? registry.GetTriggerTypeDescriptors(); + if (triggerTypeDescriptors.Length == 0) { + if (precomputedActive == null) + { + registry.SetActiveDescriptors(Array.Empty()); + } + return Enumerable.Empty(); } - else + + List? triggerDescriptors = null; + List? newActiveDescriptors = precomputedActive == null ? new List() : null; + + foreach (var triggerTypeDescriptor in triggerTypeDescriptors) { - List? triggerDescriptors = null; + var triggers = _triggerFactory.Resolve(ServiceProvider, triggerTypeDescriptor.TriggerType); + var addedToActive = false; - foreach (var triggerTypeDescriptor in triggerTypeDescriptors) + foreach (var trigger in triggers) { - var triggers = _triggerFactory.Resolve(ServiceProvider, triggerTypeDescriptor.TriggerType); - foreach (var trigger in triggers) + if (trigger != null) { - if (triggerDescriptors == null) - { - triggerDescriptors = new List(); - } + (triggerDescriptors ??= new List()).Add(new AsyncTriggerDescriptor(triggerTypeDescriptor, trigger)); - if (trigger != null) + if (newActiveDescriptors != null && !addedToActive) { - triggerDescriptors.Add(new AsyncTriggerDescriptor(triggerTypeDescriptor, trigger)); + newActiveDescriptors.Add(triggerTypeDescriptor); + addedToActive = true; } } } + } - if (triggerDescriptors == null) - { - return Enumerable.Empty(); - } - else - { - triggerDescriptors.Sort(_triggerDescriptorComparer); - return triggerDescriptors; - } + if (newActiveDescriptors != null) + { + registry.SetActiveDescriptors(newActiveDescriptors.ToArray()); + } + + if (triggerDescriptors == null) + { + return Enumerable.Empty(); + } + else + { + triggerDescriptors.Sort(_triggerDescriptorComparer); + return triggerDescriptors; } } public IEnumerable DiscoverTriggers() { // We can skip the registry as there is no generic argument - var triggers = _triggerFactory.Resolve(ServiceProvider, typeof(TTrigger)); - - return triggers - .Select((trigger, index) => ( - trigger, - defaultPriority: index, - customPriority: (trigger as ITriggerPriority)?.Priority ?? 0 - )) + var resolvedTriggers = _triggerFactory.Resolve(ServiceProvider, typeof(TTrigger)); + + // Materialise eagerly so we can short-circuit for the common case of 0 registered + // lifecycle triggers, avoiding the LINQ chain allocations on every SaveChanges call. + List<(object trigger, int defaultPriority, int customPriority)>? sorted = null; + var index = 0; + + foreach (var trigger in resolvedTriggers) + { + (sorted ??= new List<(object, int, int)>()).Add( + (trigger, index++, (trigger as ITriggerPriority)?.Priority ?? 0)); + } + + if (sorted == null) + { + return Enumerable.Empty(); + } + + return sorted .OrderBy(x => x.customPriority) .ThenBy(x => x.defaultPriority) - .Select(x => x.trigger) - .Cast(); + .Select(x => (TTrigger)x.trigger); } public IServiceProvider ServiceProvider diff --git a/src/EntityFrameworkCore.Triggered/Internal/TriggerFactory.cs b/src/EntityFrameworkCore.Triggered/Internal/TriggerFactory.cs index d24c0b8..9f98316 100644 --- a/src/EntityFrameworkCore.Triggered/Internal/TriggerFactory.cs +++ b/src/EntityFrameworkCore.Triggered/Internal/TriggerFactory.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Linq; using Microsoft.Extensions.DependencyInjection; namespace EntityFrameworkCore.Triggered.Internal @@ -33,18 +32,15 @@ public IEnumerable Resolve(IServiceProvider serviceProvider, Type trigge // Alternatively, triggers may be registered with the extension configuration var instanceFactoryType = _instanceFactoryTypeCache.GetOrAdd(triggerType, - triggerType => typeof(ITriggerInstanceFactory<>).MakeGenericType(triggerType) + t => typeof(ITriggerInstanceFactory<>).MakeGenericType(t) ); - var triggerServiceFactories = _internalServiceProvider.GetServices(instanceFactoryType); - if (triggerServiceFactories.Any()) + // Iterate once — eliminates the former .Any() + foreach double-enumeration + foreach (var triggerServiceFactory in _internalServiceProvider.GetServices(instanceFactoryType)) { - foreach (var triggerServiceFactory in triggerServiceFactories) + if (triggerServiceFactory is ITriggerInstanceFactory factory) { - if (triggerServiceFactory is not null) - { - yield return ((ITriggerInstanceFactory)triggerServiceFactory).Create(serviceProvider ?? _internalServiceProvider); - } + yield return factory.Create(serviceProvider ?? _internalServiceProvider); } } } diff --git a/src/EntityFrameworkCore.Triggered/Internal/TriggerTypeRegistry.cs b/src/EntityFrameworkCore.Triggered/Internal/TriggerTypeRegistry.cs index c531dce..f17aadf 100644 --- a/src/EntityFrameworkCore.Triggered/Internal/TriggerTypeRegistry.cs +++ b/src/EntityFrameworkCore.Triggered/Internal/TriggerTypeRegistry.cs @@ -1,8 +1,8 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using EntityFrameworkCore.Triggered.Infrastructure.Internal; -using EntityFrameworkCore.Triggered.Internal.Descriptors; namespace EntityFrameworkCore.Triggered.Internal { @@ -13,6 +13,11 @@ public sealed class TriggerTypeRegistry TTriggerTypeDescriptor[]? _resolvedDescriptors; + // Populated after the first resolution: only the descriptors that produced at least one + // trigger instance. null = not yet computed; empty array = computed, none were active. + // Written once (first-write-wins via CompareExchange) so no further synchronisation needed. + TTriggerTypeDescriptor[]? _activeDescriptors; + public TriggerTypeRegistry(Type entityType, Func triggerTypeDescriptorFactory) { _entityType = entityType; @@ -28,7 +33,6 @@ IEnumerable GetEntityTypeHierarchy() foreach (var interfaceType in type.GetInterfaces()) { yield return interfaceType; - } yield return type; @@ -52,5 +56,20 @@ public TTriggerTypeDescriptor[] GetTriggerTypeDescriptors() return _resolvedDescriptors; } + + /// + /// Returns the subset of descriptors that produced at least one trigger instance during a + /// previous resolution, or null if this information has not yet been computed. + /// An empty array means a previous resolution confirmed that no triggers are registered. + /// + public TTriggerTypeDescriptor[]? GetActiveDescriptors() + => Volatile.Read(ref _activeDescriptors); + + /// + /// Records the subset of descriptors that produced triggers during a resolution. + /// Only the first call has any effect (first-write-wins); subsequent calls are ignored. + /// + public void SetActiveDescriptors(TTriggerTypeDescriptor[] activeDescriptors) + => Interlocked.CompareExchange(ref _activeDescriptors, activeDescriptors, null); } }