From 2a6409d3a7ab15a4aa2f152feb446b7912ef1b97 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 13:27:37 +0000 Subject: [PATCH 1/2] Initial plan From ac13efef53e45ff74f719260d0257a2e9a346d71 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 13:44:08 +0000 Subject: [PATCH 2/2] Implement ScanForTypes collection return feature (Type[] and IEnumerable) Co-authored-by: Dreamescaper <17177729+Dreamescaper@users.noreply.github.com> --- .../CustomHandlerTests.cs | 316 ++++++++++++++++++ ...jectionGenerator.FindServicesToRegister.cs | 47 ++- ...encyInjectionGenerator.ParseMethodModel.cs | 57 +++- .../DependencyInjectionGenerator.cs | 30 +- .../DiagnosticDescriptors.cs | 7 + .../Model/MethodImplementationModel.cs | 3 +- .../Model/MethodModel.cs | 23 +- 7 files changed, 454 insertions(+), 29 deletions(-) diff --git a/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs b/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs index 4a027dd..d0d5061 100644 --- a/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs +++ b/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs @@ -1343,4 +1343,320 @@ public class MyService : IService { } Assert.Contains(results.Diagnostics, d => d.Descriptor == DiagnosticDescriptors.CantMixServiceRegistrationsAndServiceHandler); } + + [Fact] + public void ScanForTypesAttribute_ReturnsTypeArray_WithNoHandler() + { + var source = """ + using ServiceScan.SourceGenerator; + using System; + + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + [ScanForTypes(AssignableTo = typeof(IService))] + public static partial Type[] GetServiceTypes(); + } + """; + + var services = + """ + namespace GeneratorTests; + + public interface IService { } + public class MyService1 : IService { } + public class MyService2 : IService { } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = """ + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + public static partial global::System.Type[] GetServiceTypes() + { + return [typeof(global::GeneratorTests.MyService1), typeof(global::GeneratorTests.MyService2)]; + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); + } + + [Fact] + public void ScanForTypesAttribute_ReturnsIEnumerableType_WithNoHandler() + { + var source = """ + using ServiceScan.SourceGenerator; + using System; + using System.Collections.Generic; + + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + [ScanForTypes(AssignableTo = typeof(IService))] + public static partial IEnumerable GetServiceTypes(); + } + """; + + var services = + """ + namespace GeneratorTests; + + public interface IService { } + public class MyService1 : IService { } + public class MyService2 : IService { } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = """ + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + public static partial global::System.Collections.Generic.IEnumerable GetServiceTypes() + { + return [typeof(global::GeneratorTests.MyService1), typeof(global::GeneratorTests.MyService2)]; + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); + } + + [Fact] + public void ScanForTypesAttribute_ReturnsResponseArray_WithHandler() + { + var source = """ + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + [ScanForTypes(AssignableTo = typeof(IService), Handler = nameof(GetServiceInfo))] + public static partial ServiceInfo[] GetServiceInfos(); + + private static ServiceInfo GetServiceInfo() => new ServiceInfo(typeof(T).Name); + } + """; + + var services = + """ + namespace GeneratorTests; + + public interface IService { } + public class MyService1 : IService { } + public class MyService2 : IService { } + + public class ServiceInfo + { + public ServiceInfo(string name) { } + } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = """ + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + public static partial global::GeneratorTests.ServiceInfo[] GetServiceInfos() + { + return [GetServiceInfo(), GetServiceInfo()]; + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); + } + + [Fact] + public void ScanForTypesAttribute_ReturnsIEnumerableResponse_WithHandler() + { + var source = """ + using ServiceScan.SourceGenerator; + using System.Collections.Generic; + + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + [ScanForTypes(AssignableTo = typeof(IService), Handler = nameof(GetServiceInfo))] + public static partial IEnumerable GetServiceInfos(); + + private static ServiceInfo GetServiceInfo() => new ServiceInfo(typeof(T).Name); + } + """; + + var services = + """ + namespace GeneratorTests; + + public interface IService { } + public class MyService1 : IService { } + public class MyService2 : IService { } + + public class ServiceInfo + { + public ServiceInfo(string name) { } + } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = """ + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + public static partial global::System.Collections.Generic.IEnumerable GetServiceInfos() + { + return [GetServiceInfo(), GetServiceInfo()]; + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); + } + + [Fact] + public void ScanForTypesAttribute_ReturnsTypeArray_MultipleAttributes() + { + var source = """ + using ServiceScan.SourceGenerator; + using System; + + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + [ScanForTypes(AssignableTo = typeof(IFirstService))] + [ScanForTypes(AssignableTo = typeof(ISecondService))] + public static partial Type[] GetServiceTypes(); + } + """; + + var services = + """ + namespace GeneratorTests; + + public interface IFirstService { } + public interface ISecondService { } + public class MyService1 : IFirstService { } + public class MyService2 : ISecondService { } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var expected = """ + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + public static partial global::System.Type[] GetServiceTypes() + { + return [typeof(global::GeneratorTests.MyService1), typeof(global::GeneratorTests.MyService2)]; + } + } + """; + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); + } + + [Fact] + public void ScanForTypesAttribute_HandlerReturnTypeMismatch_ReportsDiagnostic() + { + var source = """ + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + [ScanForTypes(AssignableTo = typeof(IService), Handler = nameof(GetServiceName))] + public static partial ServiceInfo[] GetServiceInfos(); + + private static string GetServiceName() => typeof(T).Name; + } + """; + + var services = + """ + namespace GeneratorTests; + + public interface IService { } + public class MyService : IService { } + + public class ServiceInfo { } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + Assert.Equal(results.Diagnostics.Single().Descriptor, DiagnosticDescriptors.WrongHandlerReturnTypeForCollectionReturn); + } + + [Fact] + public void ScanForTypesAttribute_NoHandlerNonTypeCollection_ReportsDiagnostic() + { + var source = """ + using ServiceScan.SourceGenerator; + + namespace GeneratorTests; + + public static partial class ServicesExtensions + { + [ScanForTypes(AssignableTo = typeof(IService))] + public static partial string[] GetServiceNames(); + } + """; + + var services = + """ + namespace GeneratorTests; + + public interface IService { } + public class MyService : IService { } + """; + + var compilation = CreateCompilation(source, services); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + Assert.Equal(results.Diagnostics.Single().Descriptor, DiagnosticDescriptors.MissingCustomHandlerOnGenerateServiceHandler); + } } diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs index 2e2da26..0687476 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FindServicesToRegister.cs @@ -26,6 +26,7 @@ private static DiagnosticModel FindServicesToRegister var containingType = compilation.GetTypeByMetadataName(method.TypeMetadataName); var registrations = new List(); var customHandlers = new List(); + var collectionItems = new List(); foreach (var attribute in attributes) { @@ -35,10 +36,44 @@ private static DiagnosticModel FindServicesToRegister { typesFound = true; - if (attribute.CustomHandler != null) + var implementationTypeName = implementationType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + if (method.ReturnTypeIsCollection) { - var implementationTypeName = implementationType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + if (attribute.CustomHandler == null) + { + // Case 1: no handler, return typeof(T) expressions + collectionItems.Add($"typeof({implementationTypeName})"); + } + else + { + // Case 2: handler maps T -> TResponse, generate handler invocation expressions + var arguments = string.Join(", ", method.Parameters.Select(p => p.Name)); + if (attribute.CustomHandlerMethodTypeParametersCount > 1 && matchedTypes != null) + { + foreach (var matchedType in matchedTypes) + { + var typeArguments = string.Join(", ", new[] { implementationTypeName } + .Concat(matchedType.TypeArguments.Select(a => a.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)))); + + if (attribute.CustomHandlerType == CustomHandlerType.Method) + collectionItems.Add($"{attribute.CustomHandler}<{typeArguments}>({arguments})"); + else + collectionItems.Add($"{implementationTypeName}.{attribute.CustomHandler}({arguments})"); + } + } + else + { + if (attribute.CustomHandlerType == CustomHandlerType.Method) + collectionItems.Add($"{attribute.CustomHandler}<{implementationTypeName}>({arguments})"); + else + collectionItems.Add($"{implementationTypeName}.{attribute.CustomHandler}({arguments})"); + } + } + } + else if (attribute.CustomHandler != null) + { // If CustomHandler method has multiple type parameters, which are resolvable from the first one - we try to provide them. // e.g. ApplyConfiguration(ModelBuilder modelBuilder) where T : IEntityTypeConfiguration if (attribute.CustomHandlerMethodTypeParametersCount > 1 && matchedTypes != null) @@ -81,7 +116,7 @@ .. matchedType.TypeArguments.Select(a => a.ToDisplayString(SymbolDisplayFormat.F { if (implementationType.IsGenericType) { - var implementationTypeName = implementationType.ConstructUnboundGenericType().ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var implementationTypeNameUnbound = implementationType.ConstructUnboundGenericType().ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); var serviceTypeName = serviceType.IsGenericType ? serviceType.ConstructUnboundGenericType().ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) : serviceType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); @@ -89,7 +124,7 @@ .. matchedType.TypeArguments.Select(a => a.ToDisplayString(SymbolDisplayFormat.F var registration = new ServiceRegistrationModel( attribute.Lifetime, serviceTypeName, - implementationTypeName, + implementationTypeNameUnbound, ResolveImplementation: false, IsOpenGeneric: true, attribute.KeySelector, @@ -103,7 +138,7 @@ .. matchedType.TypeArguments.Select(a => a.ToDisplayString(SymbolDisplayFormat.F var registration = new ServiceRegistrationModel( attribute.Lifetime, serviceType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - implementationType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + implementationTypeName, shouldResolve, IsOpenGeneric: false, attribute.KeySelector, @@ -119,7 +154,7 @@ .. matchedType.TypeArguments.Select(a => a.ToDisplayString(SymbolDisplayFormat.F diagnostic ??= Diagnostic.Create(NoMatchingTypesFound, attribute.Location); } - var implementationModel = new MethodImplementationModel(method, [.. registrations], [.. customHandlers]); + var implementationModel = new MethodImplementationModel(method, [.. registrations], [.. customHandlers], [.. collectionItems]); return new(diagnostic, implementationModel); } diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs index c699244..99738a3 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.ParseMethodModel.cs @@ -117,41 +117,62 @@ public partial class DependencyInjectionGenerator if (hasServiceRegistrationsAttribute) return null; + // Compute collection return type info upfront for use in validation + var (returnTypeIsCollection, collectionElementTypeSymbol) = MethodModel.GetCollectionReturnInfo(method.ReturnType); + var collectionElementTypeName = collectionElementTypeSymbol?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var isTypeCollection = returnTypeIsCollection && + collectionElementTypeSymbol?.ContainingNamespace?.ToDisplayString() == "System" && + collectionElementTypeSymbol?.Name == "Type"; + var position = context.TargetNode.SpanStart; var attributeData = context.Attributes.Select(a => AttributeModel.Create(a, method, context.SemanticModel)).ToArray(); foreach (var attribute in attributeData) { if (attribute.CustomHandler == null) - return Diagnostic.Create(MissingCustomHandlerOnGenerateServiceHandler, attribute.Location); + { + // Without a Handler, the method must return Type[] or IEnumerable + if (!isTypeCollection) + return Diagnostic.Create(MissingCustomHandlerOnGenerateServiceHandler, attribute.Location); + } + else + { + var customHandlerMethod = method.ContainingType.GetMethod(attribute.CustomHandler, context.SemanticModel, position); - if (!attribute.HasSearchCriteria) - return Diagnostic.Create(MissingSearchCriteria, attribute.Location); + if (customHandlerMethod != null) + { + if (!customHandlerMethod.IsGenericMethod) + return Diagnostic.Create(CustomHandlerMethodHasIncorrectSignature, attribute.Location); - if (attribute.AssemblyOfTypeName != null && attribute.AssemblyNameFilter != null) - return Diagnostic.Create(CantUseBothFromAssemblyOfAndAssemblyNameFilter, attribute.Location); + // When method returns a collection, validate that handler return type matches the element type + if (returnTypeIsCollection) + { + var handlerReturnTypeName = customHandlerMethod.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + if (handlerReturnTypeName != collectionElementTypeName) + return Diagnostic.Create(WrongHandlerReturnTypeForCollectionReturn, attribute.Location); + } - var customHandlerMethod = method.ContainingType.GetMethod(attribute.CustomHandler, context.SemanticModel, position); + var typesMatch = Enumerable.SequenceEqual( + method.Parameters.Select(p => p.Type), + customHandlerMethod.Parameters.Select(p => p.Type), + SymbolEqualityComparer.Default); - if (customHandlerMethod != null) - { - if (!customHandlerMethod.IsGenericMethod) - return Diagnostic.Create(CustomHandlerMethodHasIncorrectSignature, attribute.Location); + if (!typesMatch) + return Diagnostic.Create(CustomHandlerMethodHasIncorrectSignature, attribute.Location); + } + } - var typesMatch = Enumerable.SequenceEqual( - method.Parameters.Select(p => p.Type), - customHandlerMethod.Parameters.Select(p => p.Type), - SymbolEqualityComparer.Default); + if (!attribute.HasSearchCriteria) + return Diagnostic.Create(MissingSearchCriteria, attribute.Location); - if (!typesMatch) - return Diagnostic.Create(CustomHandlerMethodHasIncorrectSignature, attribute.Location); - } + if (attribute.AssemblyOfTypeName != null && attribute.AssemblyNameFilter != null) + return Diagnostic.Create(CantUseBothFromAssemblyOfAndAssemblyNameFilter, attribute.Location); if (attribute.HasErrors) return null; } - if (!method.ReturnsVoid && + if (!method.ReturnsVoid && !returnTypeIsCollection && (method.Parameters.Length == 0 || !SymbolEqualityComparer.Default.Equals(method.Parameters[0].Type, method.ReturnType))) { return Diagnostic.Create(WrongReturnTypeForCustomHandler, method.Locations[0]); diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs index b4800c2..bcc7ad3 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs @@ -58,10 +58,12 @@ private static void GenerateSource(SourceProductionContext context, DiagnosticMo if (src.Model == null) return; - var (method, registrations, customHandling) = src.Model; + var (method, registrations, customHandling, collectionItems) = src.Model; string source = registrations.Count > 0 ? GenerateRegistrationsSource(method, registrations) - : GenerateCustomHandlingSource(method, customHandling); + : method.ReturnTypeIsCollection + ? GenerateCollectionSource(method, collectionItems) + : GenerateCustomHandlingSource(method, customHandling); source = source.ReplaceLineEndings(); @@ -127,6 +129,30 @@ private static string GenerateRegistrationsSource(MethodModel method, EquatableA return source; } + private static string GenerateCollectionSource(MethodModel method, EquatableArray collectionItems) + { + var namespaceDeclaration = method.Namespace is null ? "" : $"namespace {method.Namespace};"; + var parameters = string.Join(",", method.Parameters.Select((p, i) => + $"{(i == 0 && method.IsExtensionMethod ? "this" : "")} {p.Type} {p.Name}")); + + var items = string.Join(", ", collectionItems); + var methodBody = $"return [{items}];"; + + var source = $$""" + {{namespaceDeclaration}} + + {{method.TypeModifiers}} class {{method.TypeName}} + { + {{method.MethodModifiers}} {{method.ReturnType}} {{method.MethodName}}({{parameters}}) + { + {{methodBody}} + } + } + """; + + return source; + } + private static string GenerateCustomHandlingSource(MethodModel method, EquatableArray customHandlers) { var invocations = string.Join("\n", customHandlers.Select(h => diff --git a/ServiceScan.SourceGenerator/DiagnosticDescriptors.cs b/ServiceScan.SourceGenerator/DiagnosticDescriptors.cs index 54c9d78..69b00c8 100644 --- a/ServiceScan.SourceGenerator/DiagnosticDescriptors.cs +++ b/ServiceScan.SourceGenerator/DiagnosticDescriptors.cs @@ -87,4 +87,11 @@ public static class DiagnosticDescriptors "Usage", DiagnosticSeverity.Error, true); + + public static readonly DiagnosticDescriptor WrongHandlerReturnTypeForCollectionReturn = new("DI0015", + "Handler return type does not match collection element type", + "Handler return type must match the element type of the collection return type", + "Usage", + DiagnosticSeverity.Error, + true); } diff --git a/ServiceScan.SourceGenerator/Model/MethodImplementationModel.cs b/ServiceScan.SourceGenerator/Model/MethodImplementationModel.cs index b372cc8..d571ccf 100644 --- a/ServiceScan.SourceGenerator/Model/MethodImplementationModel.cs +++ b/ServiceScan.SourceGenerator/Model/MethodImplementationModel.cs @@ -3,4 +3,5 @@ record MethodImplementationModel( MethodModel Method, EquatableArray Registrations, - EquatableArray CustomHandlers); + EquatableArray CustomHandlers, + EquatableArray CollectionItems); diff --git a/ServiceScan.SourceGenerator/Model/MethodModel.cs b/ServiceScan.SourceGenerator/Model/MethodModel.cs index 831c35c..9769a0e 100644 --- a/ServiceScan.SourceGenerator/Model/MethodModel.cs +++ b/ServiceScan.SourceGenerator/Model/MethodModel.cs @@ -16,7 +16,9 @@ record MethodModel( EquatableArray Parameters, bool IsExtensionMethod, bool ReturnsVoid, - string ReturnType) + string ReturnType, + bool ReturnTypeIsCollection, + string? CollectionElementTypeName) { public string ParameterName => Parameters.First().Name; @@ -27,6 +29,9 @@ public static MethodModel Create(IMethodSymbol method, SyntaxNode syntax) var typeSyntax = syntax.FirstAncestorOrSelf(); + var (returnTypeIsCollection, collectionElementTypeSymbol) = GetCollectionReturnInfo(method.ReturnType); + var collectionElementTypeName = collectionElementTypeSymbol?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + return new MethodModel( Namespace: method.ContainingNamespace.IsGlobalNamespace ? null : method.ContainingNamespace.ToDisplayString(), TypeName: method.ContainingType.Name, @@ -37,7 +42,21 @@ public static MethodModel Create(IMethodSymbol method, SyntaxNode syntax) Parameters: parameters, IsExtensionMethod: method.IsExtensionMethod, ReturnsVoid: method.ReturnsVoid, - ReturnType: method.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + ReturnType: method.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + ReturnTypeIsCollection: returnTypeIsCollection, + CollectionElementTypeName: collectionElementTypeName); + } + + public static (bool isCollection, ITypeSymbol? elementTypeSymbol) GetCollectionReturnInfo(ITypeSymbol returnType) + { + if (returnType is IArrayTypeSymbol arrayType) + return (true, arrayType.ElementType); + + if (returnType is INamedTypeSymbol { IsGenericType: true, Arity: 1 } namedType && + namedType.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T) + return (true, namedType.TypeArguments[0]); + + return (false, null); } private static string GetModifiers(SyntaxNode? syntax)