Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ public class HelloWorldEndpoint : IEndpoint

public static partial class ServiceCollectionExtensions
{
[GenerateServiceRegistrations(AssignableTo = typeof(IEndpoint), CustomHandler = nameof(IEndpoint.MapEndpoint))]
[ScanForTypes(AssignableTo = typeof(IEndpoint), Handler = nameof(IEndpoint.MapEndpoint))]
public static partial IEndpointRouteBuilder MapEndpoints(this IEndpointRouteBuilder endpoints);
}
```

### Register Options types
Another example of `CustomHandler` is to register Options types. We can define custom `OptionAttribute`, which allows to specify configuration section key.
And then read that value in our `CustomHandler`:
Another example of `Handler` is to register Options types. We can define custom `OptionAttribute`, which allows to specify configuration section key.
And then read that value in our `Handler`:
```csharp
[AttributeUsage(AttributeTargets.Class, AllowMultiple = false)]
public class OptionAttribute(string? section = null) : Attribute
Expand All @@ -116,7 +116,7 @@ public record SectionOption { }

public static partial class ServiceCollectionExtensions
{
[GenerateServiceRegistrations(AttributeFilter = typeof(OptionAttribute), CustomHandler = nameof(AddOption))]
[ScanForTypes(AttributeFilter = typeof(OptionAttribute), Handler = nameof(AddOption))]
public static partial IServiceCollection AddOptions(this IServiceCollection services, IConfiguration configuration);

private static void AddOption<T>(IServiceCollection services, IConfiguration configuration) where T : class
Expand All @@ -133,7 +133,7 @@ public static partial class ServiceCollectionExtensions
```csharp
public static partial class ModelBuilderExtensions
{
[GenerateServiceRegistrations(AssignableTo = typeof(IEntityTypeConfiguration<>), CustomHandler = nameof(ApplyConfiguration))]
[ScanForTypes(AssignableTo = typeof(IEntityTypeConfiguration<>), Handler = nameof(ApplyConfiguration))]
public static partial ModelBuilder ApplyEntityConfigurations(this ModelBuilder modelBuilder);

private static void ApplyConfiguration<T, TEntity>(ModelBuilder modelBuilder)
Expand Down Expand Up @@ -164,4 +164,17 @@ public static partial class ModelBuilderExtensions
| **ExcludeByTypeName** | Sets this value to exclude types from being registered by their full name. You can use '*' wildcards. You can also use ',' to separate multiple filters. |
| **ExcludeByAttribute** | Excludes matching types by the specified attribute type being present. |
| **KeySelector** | Sets this property to add types as keyed services. This property should point to one of the following: <br>- The name of a static method in the current type with a string return type. The method should be either generic or have a single parameter of type `Type`. <br>- A constant field or static property in the implementation type. |
| **CustomHandler** | Sets this property to invoke a custom method for each type found instead of regular registration logic. This property should point to one of the following: <br>- Name of a generic method in the current type. <br>- Static method name in found types. <br>This property is incompatible with `Lifetime`, `AsImplementedInterfaces`, `AsSelf`, and `KeySelector` properties. <br>**Note:** When using a generic `CustomHandler` method, types are automatically filtered by the generic constraints defined on the method's type parameters (e.g., `class`, `struct`, `new()`, interface constraints). |
| **CustomHandler** | *(Obsolete — use `ScanForTypes` instead.)* Sets this property to invoke a custom method for each type found instead of regular registration logic. |

`ScanForTypes` attribute is used to invoke a custom method for each matched type. It has the same filtering properties as `GenerateServiceRegistrations`, but without the registration-specific ones (`Lifetime`, `AsImplementedInterfaces`, `AsSelf`, `KeySelector`):
| Property | Description |
| --- | --- |
| **Handler** | Sets this property to invoke a custom method for each type found. This property should point to one of the following: <br>- Name of a generic method in the current type. <br>- Static method name in found types. <br>**Note:** Types are automatically filtered by the generic constraints defined on the method's type parameters (e.g., `class`, `struct`, `new()`, interface constraints). |
| **FromAssemblyOf** | Sets the assembly containing the given type as the source of types to scan. If not specified, the assembly containing the method with this attribute will be used. |
| **AssemblyNameFilter** | Sets this value to filter scanned assemblies by assembly name. This option is incompatible with `FromAssemblyOf`. You can use '*' wildcards. You can also use ',' to separate multiple filters. |
| **AssignableTo** | Sets the type that the scanned types must be assignable to. |
| **ExcludeAssignableTo** | Sets the type that the scanned types must *not* be assignable to. |
| **TypeNameFilter** | Sets this value to filter the types by their full name. You can use '*' wildcards. You can also use ',' to separate multiple filters. |
| **AttributeFilter** | Filters types by the specified attribute type being present. |
| **ExcludeByTypeName** | Sets this value to exclude types by their full name. You can use '*' wildcards. You can also use ',' to separate multiple filters. |
| **ExcludeByAttribute** | Excludes matching types by the specified attribute type being present. |
244 changes: 244 additions & 0 deletions ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1099,4 +1099,248 @@ private static Compilation CreateCompilation(params string[] source)
],
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
}

[Fact]
public void ScanForTypesAttribute_WithNoParameters()
{
var source = $$"""
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public static partial class ServicesExtensions
{
[ScanForTypes(AssignableTo = typeof(IService), Handler = nameof(HandleType))]
public static partial void ProcessServices();

private static void HandleType<T>() => System.Console.WriteLine(typeof(T).Name);
}
""";

var services =
"""
namespace GeneratorTests;

public interface IService { }
public class MyService1 : IService { }
public class MyService2 : IService { }
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var expected = $$"""
namespace GeneratorTests;

public static partial class ServicesExtensions
{
public static partial void ProcessServices()
{
HandleType<global::GeneratorTests.MyService1>();
HandleType<global::GeneratorTests.MyService2>();
}
}
""";
Assert.Equal(expected, results.GeneratedTrees[2].ToString());
}

[Fact]
public void ScanForTypesAttribute_WithParameters()
{
var source = $$"""
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public static partial class ServicesExtensions
{
[ScanForTypes(TypeNameFilter = "*Service", Handler = nameof(HandleType))]
public static partial void ProcessServices(string value);

private static void HandleType<T>(string value) => System.Console.WriteLine(value + typeof(T).Name);
}
""";

var services =
"""
namespace GeneratorTests;

public class MyFirstService {}
public class MySecondService {}
public class ServiceWithNonMatchingName {}
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var expected = $$"""
namespace GeneratorTests;

public static partial class ServicesExtensions
{
public static partial void ProcessServices( string value)
{
HandleType<global::GeneratorTests.MyFirstService>(value);
HandleType<global::GeneratorTests.MySecondService>(value);
}
}
""";
Assert.Equal(expected, results.GeneratedTrees[2].ToString());
}

[Fact]
public void ScanForTypesAttribute_MultipleAttributes()
{
var source = $$"""
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public static partial class ServicesExtensions
{
[ScanForTypes(AssignableTo = typeof(IFirstService), Handler = nameof(HandleFirstType))]
[ScanForTypes(AssignableTo = typeof(ISecondService), Handler = nameof(HandleSecondType))]
public static partial void ProcessServices();

private static void HandleFirstType<T>() => System.Console.WriteLine("First:" + typeof(T).Name);
private static void HandleSecondType<T>() => System.Console.WriteLine("Second:" + typeof(T).Name);
}
""";

var services =
"""
namespace GeneratorTests;

public interface IFirstService { }
public interface ISecondService { }
public class MyService1 : IFirstService { }
public class MyService2 : ISecondService { }
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

var expected = $$"""
namespace GeneratorTests;

public static partial class ServicesExtensions
{
public static partial void ProcessServices()
{
HandleFirstType<global::GeneratorTests.MyService1>();
HandleSecondType<global::GeneratorTests.MyService2>();
}
}
""";
Assert.Equal(expected, results.GeneratedTrees[2].ToString());
}

[Fact]
public void ScanForTypesAttribute_MissingHandler_ReportsDiagnostic()
{
var source = $$"""
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public static partial class ServicesExtensions
{
[ScanForTypes(AssignableTo = typeof(IService))]
public static partial void ProcessServices();
}
""";

var services =
"""
namespace GeneratorTests;

public interface IService { }
public class MyService : IService { }
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

Assert.Equal(results.Diagnostics.Single().Descriptor, DiagnosticDescriptors.MissingCustomHandlerOnGenerateServiceHandler);
}

[Fact]
public void ScanForTypesAttribute_MissingSearchCriteria_ReportsDiagnostic()
{
var source = $$"""
using ServiceScan.SourceGenerator;

namespace GeneratorTests;

public static partial class ServicesExtensions
{
[ScanForTypes(Handler = nameof(HandleType))]
public static partial void ProcessServices();

private static void HandleType<T>() { }
}
""";

var compilation = CreateCompilation(source);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

Assert.Equal(results.Diagnostics.Single().Descriptor, DiagnosticDescriptors.MissingSearchCriteria);
}

[Fact]
public void MixingGenerateServiceRegistrationsAndScanForTypes_ReportsDiagnostic()
{
var source = $$"""
using ServiceScan.SourceGenerator;
using Microsoft.Extensions.DependencyInjection;

namespace GeneratorTests;

public static partial class ServicesExtensions
{
[GenerateServiceRegistrations(AssignableTo = typeof(IService))]
[ScanForTypes(AssignableTo = typeof(IService), Handler = nameof(HandleType))]
public static partial IServiceCollection ProcessServices(this IServiceCollection services);

private static void HandleType<T>() { }
}
""";

var services =
"""
namespace GeneratorTests;

public interface IService { }
public class MyService : IService { }
""";

var compilation = CreateCompilation(source, services);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(compilation)
.GetRunResult();

Assert.Contains(results.Diagnostics, d => d.Descriptor == DiagnosticDescriptors.CantMixServiceRegistrationsAndServiceHandler);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ public partial class DependencyInjectionGenerator
if (!method.IsPartialDefinition)
return Diagnostic.Create(NotPartialDefinition, method.Locations[0]);

// Check if ScanForTypesAttribute is also on this method - that's not allowed
var hasServiceHandlerAttribute = method.GetAttributes()
.Any(a => a.AttributeClass?.ToDisplayString() == GenerateAttributeInfo.HandlerMetadataName);

if (hasServiceHandlerAttribute)
return Diagnostic.Create(CantMixServiceRegistrationsAndServiceHandler, method.Locations[0]);

var position = context.TargetNode.SpanStart;
var attributeData = context.Attributes.Select(a => AttributeModel.Create(a, method, context.SemanticModel)).ToArray();
var hasCustomHandlers = attributeData.Any(a => a.CustomHandler != null);
Expand Down Expand Up @@ -94,4 +101,63 @@ public partial class DependencyInjectionGenerator
var model = MethodModel.Create(method, context.TargetNode);
return new MethodWithAttributesModel(model, [.. attributeData]);
}

private static DiagnosticModel<MethodWithAttributesModel>? ParseHandlerMethodModel(GeneratorAttributeSyntaxContext context)
{
if (context.TargetSymbol is not IMethodSymbol method)
return null;

if (!method.IsPartialDefinition)
return Diagnostic.Create(NotPartialDefinition, method.Locations[0]);

// Skip if this method also has GenerateServiceRegistrationsAttribute - that provider reports the mixing error
var hasServiceRegistrationsAttribute = method.GetAttributes()
.Any(a => a.AttributeClass?.ToDisplayString() == GenerateAttributeInfo.MetadataName);

if (hasServiceRegistrationsAttribute)
return null;

var position = context.TargetNode.SpanStart;
var attributeData = context.Attributes.Select(a => AttributeModel.Create(a, method, context.SemanticModel)).ToArray();

foreach (var attribute in attributeData)
{
if (attribute.CustomHandler == null)
return Diagnostic.Create(MissingCustomHandlerOnGenerateServiceHandler, attribute.Location);

if (!attribute.HasSearchCriteria)
return Diagnostic.Create(MissingSearchCriteria, attribute.Location);

if (attribute.AssemblyOfTypeName != null && attribute.AssemblyNameFilter != null)
return Diagnostic.Create(CantUseBothFromAssemblyOfAndAssemblyNameFilter, attribute.Location);

var customHandlerMethod = method.ContainingType.GetMethod(attribute.CustomHandler, context.SemanticModel, position);

if (customHandlerMethod != null)
{
if (!customHandlerMethod.IsGenericMethod)
return Diagnostic.Create(CustomHandlerMethodHasIncorrectSignature, attribute.Location);

var typesMatch = Enumerable.SequenceEqual(
method.Parameters.Select(p => p.Type),
customHandlerMethod.Parameters.Select(p => p.Type),
SymbolEqualityComparer.Default);

if (!typesMatch)
return Diagnostic.Create(CustomHandlerMethodHasIncorrectSignature, attribute.Location);
}

if (attribute.HasErrors)
return null;
}

if (!method.ReturnsVoid &&
(method.Parameters.Length == 0 || !SymbolEqualityComparer.Default.Equals(method.Parameters[0].Type, method.ReturnType)))
{
return Diagnostic.Create(WrongReturnTypeForCustomHandler, method.Locations[0]);
}

var model = MethodModel.Create(method, context.TargetNode);
return new MethodWithAttributesModel(model, [.. attributeData]);
}
}
Loading
Loading