diff --git a/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs b/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs index 2c8e67b..af7aa1c 100644 --- a/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs +++ b/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs @@ -37,7 +37,7 @@ public class MyService2 : IService { } .Add{lifetime}() .Add{lifetime}(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -56,7 +56,7 @@ public void AddServicesFromAnotherAssembly() .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -98,7 +98,7 @@ public class MyService2 : Core.IService { } .AddScoped() .AddScoped(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -127,7 +127,7 @@ public class MyService2 : AbstractService { } .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -156,7 +156,7 @@ public class MyService2 : AbstractService { } .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -184,7 +184,7 @@ public class MyStringService : IService { } .AddTransient, global::GeneratorTests.MyIntService>() .AddTransient, global::GeneratorTests.MyStringService>(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -212,7 +212,7 @@ public class MyIntAndStringService : IService, IService, IOtherInte .AddTransient, global::GeneratorTests.MyIntAndStringService>() .AddTransient, global::GeneratorTests.MyIntAndStringService>(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -240,7 +240,7 @@ public class MyIntAndStringService : IService, IService { } .AddSingleton>(s => s.GetRequiredService()) .AddSingleton>(s => s.GetRequiredService()); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -267,7 +267,7 @@ public class MyStringService : IService { } return services .AddTransient, global::GeneratorTests.MyIntService>(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -295,7 +295,7 @@ public class MyService2 : AbstractService { } .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -323,7 +323,7 @@ public class MyService2 : AbstractService { } .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -348,7 +348,7 @@ public class MyService { } return services .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -376,7 +376,7 @@ public class MyStringService : AbstractService { } .AddTransient, global::GeneratorTests.MyIntService>() .AddTransient, global::GeneratorTests.MyStringService>(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -404,7 +404,7 @@ public class MyService2 : IGenericService { } .AddTransient(typeof(global::GeneratorTests.IGenericService<>), typeof(global::GeneratorTests.MyService1<>)) .AddTransient(typeof(global::GeneratorTests.IGenericService<>), typeof(global::GeneratorTests.MyService2<>)); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -432,7 +432,7 @@ public class MyService2 : IService { } .AddTransient(typeof(global::GeneratorTests.IService), typeof(global::GeneratorTests.MyService1<>)) .AddTransient(typeof(global::GeneratorTests.IService), typeof(global::GeneratorTests.MyService2<>)); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -461,7 +461,7 @@ public class ServiceWithNonMatchingName {} .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -498,7 +498,7 @@ public class ServiceWithoutAttribute {} .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -533,7 +533,7 @@ public class ServiceWithNonMatchingName {} return services .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -562,7 +562,7 @@ public class ServiceWithNonMatchingName {} .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -590,7 +590,7 @@ public class ThirdService {} .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -626,7 +626,7 @@ public class ThirdService {} .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -664,7 +664,7 @@ public class FourthService {} .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -696,7 +696,7 @@ public class ThirdService {} .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -728,7 +728,7 @@ public class ThirdService {} .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -762,7 +762,7 @@ public class FourthService {} .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -797,7 +797,7 @@ public class FourthService {} .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -828,7 +828,7 @@ public class MyThirdService : IService { } .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -860,7 +860,7 @@ public class InterfacelessService {} .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -895,7 +895,7 @@ public class MyService: IServiceA, IServiceB {} .AddSingleton(s => s.GetRequiredService()) .AddSingleton(s => s.GetRequiredService()); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -934,7 +934,7 @@ public void Dispose() {} .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -971,7 +971,7 @@ private class NestedPrivateService : IService { } // Shouldn't be added as non-a .AddTransient() .AddTransient(); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -1002,7 +1002,7 @@ public class MyService2 : IService { } .AddKeyedTransient(GetName()) .AddKeyedTransient(GetName()); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -1033,7 +1033,7 @@ public class MyService2 : IService { } .AddKeyedTransient(GetName(typeof(global::GeneratorTests.MyService1))) .AddKeyedTransient(GetName(typeof(global::GeneratorTests.MyService2))); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -1069,7 +1069,7 @@ public class MyService2 : IService .AddKeyedTransient(global::GeneratorTests.MyService1.Key) .AddKeyedTransient(global::GeneratorTests.MyService2.Key); """; - Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[2].ToString()); } [Fact] @@ -1084,8 +1084,8 @@ public void DontGenerateAnythingIfTypeIsInvalid() .RunGenerators(compilation) .GetRunResult(); - // One file for generated attribute itself. - Assert.Single(results.GeneratedTrees); + // Two files: one for EmbeddedAttribute and one for generated attribute itself. + Assert.Equal(2, results.GeneratedTrees.Length); } private static Compilation CreateCompilation(params string[] source) diff --git a/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs b/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs index 3fb9e01..4b65b65 100644 --- a/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs +++ b/ServiceScan.SourceGenerator.Tests/CustomHandlerTests.cs @@ -54,7 +54,7 @@ public static partial void ProcessServices() } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -102,7 +102,7 @@ public static partial void ProcessServices( string value, decimal number) } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -147,7 +147,7 @@ public static partial void ProcessServices() } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -201,7 +201,7 @@ public static partial class ServicesExtensions } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -259,7 +259,7 @@ public static partial IServiceCollection ProcessServices(this IServiceCollection } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -310,7 +310,7 @@ public static partial void ProcessServices() } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -360,7 +360,7 @@ public static partial void ProcessServices() } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -439,7 +439,7 @@ public static partial class ModelBuilderExtensions } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -487,7 +487,7 @@ public partial void ProcessServices() } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -538,7 +538,7 @@ public partial void ProcessServices() } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -592,7 +592,7 @@ public partial void ProcessServices() } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -645,7 +645,7 @@ public partial void ProcessServices( global::Microsoft.Extensions.DependencyInje } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -705,7 +705,7 @@ public static partial class ServiceCollectionExtensions } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -758,7 +758,7 @@ public static partial void ProcessServices() } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -803,7 +803,7 @@ public static partial void ProcessServices() } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -858,7 +858,7 @@ public static partial void AddHandlers() } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -914,7 +914,7 @@ public static partial void AddHandlers() } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -967,7 +967,7 @@ public static partial void AddProcessors() } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -1028,7 +1028,7 @@ public static partial void AddHandlers() } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -1079,7 +1079,7 @@ public static partial void ProcessServices() } } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } private static Compilation CreateCompilation(params string[] source) diff --git a/ServiceScan.SourceGenerator.Tests/DiagnosticTests.cs b/ServiceScan.SourceGenerator.Tests/DiagnosticTests.cs index f01284d..34ae7c0 100644 --- a/ServiceScan.SourceGenerator.Tests/DiagnosticTests.cs +++ b/ServiceScan.SourceGenerator.Tests/DiagnosticTests.cs @@ -171,7 +171,7 @@ public static partial class ServicesExtensions } } """; - Assert.Equal(expectedFile, results.GeneratedTrees[1].ToString()); + Assert.Equal(expectedFile, results.GeneratedTrees[2].ToString()); } [Fact] @@ -211,7 +211,7 @@ public static partial void AddServices(this global::Microsoft.Extensions.Depende } } """; - Assert.Equal(expectedFile, results.GeneratedTrees[1].ToString()); + Assert.Equal(expectedFile, results.GeneratedTrees[2].ToString()); } [Fact] diff --git a/ServiceScan.SourceGenerator.Tests/GeneratedMethodTests.cs b/ServiceScan.SourceGenerator.Tests/GeneratedMethodTests.cs index 2acb602..c0132b3 100644 --- a/ServiceScan.SourceGenerator.Tests/GeneratedMethodTests.cs +++ b/ServiceScan.SourceGenerator.Tests/GeneratedMethodTests.cs @@ -57,7 +57,7 @@ namespace GeneratorTests; } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -97,7 +97,7 @@ public static partial void AddServices(this IServiceCollection services) } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -137,7 +137,7 @@ public static partial IServiceCollection AddServices( IServiceCollection service } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -177,7 +177,7 @@ private partial void AddServices( IServiceCollection services) } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -216,7 +216,7 @@ public static partial IServiceCollection AddServices(this IServiceCollection ser } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } [Fact] @@ -256,7 +256,7 @@ public static partial IServiceCollection AddServices(this IServiceCollection str } """; - Assert.Equal(expected, results.GeneratedTrees[1].ToString()); + Assert.Equal(expected, results.GeneratedTrees[2].ToString()); } private static Compilation CreateCompilation(params string[] source) diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs index f2970f4..0ef50fc 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.cs @@ -15,6 +15,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { context.RegisterPostInitializationOutput(context => { + context.AddEmbeddedAttributeDefinition(); context.AddSource("ServiceScanAttributes.Generated.cs", SourceText.From(GenerateAttributeInfo.Source, Encoding.UTF8)); }); diff --git a/ServiceScan.SourceGenerator/Extensions/IncrementalGeneratorPostInitializationContextExtensions.cs b/ServiceScan.SourceGenerator/Extensions/IncrementalGeneratorPostInitializationContextExtensions.cs new file mode 100644 index 0000000..d764b34 --- /dev/null +++ b/ServiceScan.SourceGenerator/Extensions/IncrementalGeneratorPostInitializationContextExtensions.cs @@ -0,0 +1,22 @@ +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; + +namespace ServiceScan.SourceGenerator.Extensions; + +internal static class IncrementalGeneratorPostInitializationContextExtensions +{ + private const string EmbeddedAttributeSource = """ + namespace Microsoft.CodeAnalysis + { + internal sealed partial class EmbeddedAttribute : global::System.Attribute + { + } + } + """; + + public static void AddEmbeddedAttributeDefinition(this IncrementalGeneratorPostInitializationContext context) + { + context.AddSource("Microsoft.CodeAnalysis.EmbeddedAttribute", SourceText.From(EmbeddedAttributeSource, Encoding.UTF8)); + } +} diff --git a/ServiceScan.SourceGenerator/GenerateAttributeInfo.cs b/ServiceScan.SourceGenerator/GenerateAttributeInfo.cs index 9fef6b0..d2dfc5e 100644 --- a/ServiceScan.SourceGenerator/GenerateAttributeInfo.cs +++ b/ServiceScan.SourceGenerator/GenerateAttributeInfo.cs @@ -8,12 +8,12 @@ internal static class GenerateAttributeInfo #nullable enable using System; - using System.Diagnostics; + using Microsoft.CodeAnalysis; using Microsoft.Extensions.DependencyInjection; namespace ServiceScan.SourceGenerator; - [Conditional("CODE_ANALYSIS")] + [Embedded] [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] internal class GenerateServiceRegistrationsAttribute : Attribute {