diff --git a/README.md b/README.md index f724417..8687e3a 100644 --- a/README.md +++ b/README.md @@ -115,5 +115,6 @@ public static partial class ServiceCollectionExtensions | **AsImplementedInterfaces** | If true, the registered types will be registered as implemented interfaces instead of their actual type. | | **AsSelf** | If true, types will be registered with their actual type. It can be combined with `AsImplementedInterfaces`. In that case implemented interfaces will be "forwarded" to an actual implementation type | | **TypeNameFilter** | Set this value to filter the types to register by their full name. You can use '*' wildcards. You can also use ',' to separate multiple filters. | +| **WithAttribute** | Filter types by specified attribute type present. | | **KeySelector** | Set this value to a static method name returning string. Returned value will be used as a key for the registration. Method should either be generic, or have a single parameter of type `Type`. | -| **CustomHandler** | Set this property to a static generic method name in the current class. Set this property to a static generic method name in the current class. This property is incompatible with `Lifetime`, `AsImplementedInterfaces`, `AsSelf`, `KeySelector` properties. | +| **CustomHandler** | Set this property to a static generic method name in the current class. This property is incompatible with `Lifetime`, `AsImplementedInterfaces`, `AsSelf`, `KeySelector` properties. | diff --git a/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs b/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs index 9c05990..d364492 100644 --- a/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs +++ b/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs @@ -366,6 +366,78 @@ public class ServiceWithNonMatchingName {} Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); } + [Fact] + public void AddServicesWithAttributeFilter() + { + var attribute = """[GenerateServiceRegistrations(WithAttribute = typeof(ServiceAttribute))]"""; + + var compilation = CreateCompilation( + Sources.MethodWithAttribute(attribute), + """ + using System; + + namespace GeneratorTests; + + [AttributeUsage(AttributeTargets.Class)] + public sealed class ServiceAttribute : Attribute; + + [Service] + public class MyFirstService {} + + [Service] + public class MySecondService {} + + public class ServiceWithoutAttribute {} + """); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var registrations = $""" + return services + .AddTransient() + .AddTransient(); + """; + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + } + + [Fact] + public void AddServicesWithAttributeFilterAndTypeNameFilter() + { + var attribute = """[GenerateServiceRegistrations(WithAttribute = typeof(ServiceAttribute), TypeNameFilter = "*Service")]"""; + + var compilation = CreateCompilation( + Sources.MethodWithAttribute(attribute), + """ + using System; + + namespace GeneratorTests; + + [AttributeUsage(AttributeTargets.Class)] + public sealed class ServiceAttribute : Attribute; + + [Service] + public class MyFirstService {} + + public class MySecondServiceWithoutAttribute {} + + public class ServiceWithNonMatchingName {} + """); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(compilation) + .GetRunResult(); + + var registrations = $""" + return services + .AddTransient(); + """; + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + } + [Fact] public void AddServicesWithTypeNameFilter_MultipleGroups() { diff --git a/ServiceScan.SourceGenerator.Tests/ServiceScan.SourceGenerator.Tests.csproj b/ServiceScan.SourceGenerator.Tests/ServiceScan.SourceGenerator.Tests.csproj index f32ed40..40944e3 100644 --- a/ServiceScan.SourceGenerator.Tests/ServiceScan.SourceGenerator.Tests.csproj +++ b/ServiceScan.SourceGenerator.Tests/ServiceScan.SourceGenerator.Tests.csproj @@ -9,9 +9,9 @@ - + - + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs index 64d9c84..7a0a11a 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs @@ -20,6 +20,10 @@ public partial class DependencyInjectionGenerator ? null : compilation.GetTypeByMetadataName(attribute.AssignableToTypeName); + var withAttributeType = attribute.WithAttributeTypeName is null + ? null + : compilation.GetTypeByMetadataName(attribute.WithAttributeTypeName); + if (assignableToType != null && attribute.AssignableToGenericArguments != null) { var typeArguments = attribute.AssignableToGenericArguments.Value.Select(t => compilation.GetTypeByMetadataName(t)).ToArray(); @@ -31,6 +35,12 @@ public partial class DependencyInjectionGenerator if (type.IsAbstract || type.IsStatic || !type.CanBeReferencedByName || type.TypeKind != TypeKind.Class) continue; + if (withAttributeType != null) + { + if (!type.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, withAttributeType))) + continue; + } + if (attribute.TypeNameFilter != null) { var regex = $"^({Regex.Escape(attribute.TypeNameFilter).Replace(@"\*", ".*").Replace(",", "|")})$"; diff --git a/ServiceScan.SourceGenerator/GenerateAttributeSource.cs b/ServiceScan.SourceGenerator/GenerateAttributeSource.cs index bc324aa..4e2caef 100644 --- a/ServiceScan.SourceGenerator/GenerateAttributeSource.cs +++ b/ServiceScan.SourceGenerator/GenerateAttributeSource.cs @@ -20,13 +20,18 @@ internal class GenerateServiceRegistrationsAttribute : Attribute /// If not specified, the assembly containing the method with this attribute will be used. /// public Type? FromAssemblyOf { get; set; } - + /// /// Set the type that the registered types must be assignable to. /// Types will be registered with this type as the service type, /// unless or is set. /// public Type? AssignableTo { get; set; } + + /// + /// Filter types by specified attribute type present. + /// + public Type? WithAttribute { get; set; } /// /// Set the lifetime of the registered services. diff --git a/ServiceScan.SourceGenerator/Model/AttributeModel.cs b/ServiceScan.SourceGenerator/Model/AttributeModel.cs index 40ba27a..249ae4c 100644 --- a/ServiceScan.SourceGenerator/Model/AttributeModel.cs +++ b/ServiceScan.SourceGenerator/Model/AttributeModel.cs @@ -7,6 +7,7 @@ record AttributeModel( string? AssignableToTypeName, EquatableArray? AssignableToGenericArguments, string? AssemblyOfTypeName, + string? WithAttributeTypeName, string Lifetime, string? TypeNameFilter, string? KeySelector, @@ -17,12 +18,13 @@ record AttributeModel( Location Location, bool HasErrors) { - public bool HasSearchCriteria => TypeNameFilter != null || AssignableToTypeName != null; + public bool HasSearchCriteria => TypeNameFilter != null || AssignableToTypeName != null || WithAttributeTypeName != null; public static AttributeModel Create(AttributeData attribute, IMethodSymbol method) { var assemblyType = attribute.NamedArguments.FirstOrDefault(a => a.Key == "FromAssemblyOf").Value.Value as INamedTypeSymbol; var assignableTo = attribute.NamedArguments.FirstOrDefault(a => a.Key == "AssignableTo").Value.Value as INamedTypeSymbol; + var withAttributeType = attribute.NamedArguments.FirstOrDefault(a => a.Key == "WithAttribute").Value.Value as INamedTypeSymbol; var asImplementedInterfaces = attribute.NamedArguments.FirstOrDefault(a => a.Key == "AsImplementedInterfaces").Value.Value is true; var asSelf = attribute.NamedArguments.FirstOrDefault(a => a.Key == "AsSelf").Value.Value is true; var typeNameFilter = attribute.NamedArguments.FirstOrDefault(a => a.Key == "TypeNameFilter").Value.Value as string; @@ -45,6 +47,7 @@ public static AttributeModel Create(AttributeData attribute, IMethodSymbol metho if (string.IsNullOrWhiteSpace(typeNameFilter)) typeNameFilter = null; + var withAttributeTypeName = withAttributeType?.ToFullMetadataName(); var assemblyOfTypeName = assemblyType?.ToFullMetadataName(); var assignableToTypeName = assignableTo?.ToFullMetadataName(); EquatableArray? assignableToGenericArguments = assignableTo != null && assignableTo.IsGenericType && !assignableTo.IsUnboundGenericType @@ -62,12 +65,15 @@ public static AttributeModel Create(AttributeData attribute, IMethodSymbol metho var textSpan = attribute.ApplicationSyntaxReference.Span; var location = Location.Create(syntax, textSpan); - var hasError = assemblyType is { TypeKind: TypeKind.Error } || assignableTo is { TypeKind: TypeKind.Error }; + var hasError = assemblyType is { TypeKind: TypeKind.Error } + || assignableTo is { TypeKind: TypeKind.Error } + || withAttributeType is { TypeKind: TypeKind.Error }; return new( assignableToTypeName, assignableToGenericArguments, assemblyOfTypeName, + withAttributeTypeName, lifetime, typeNameFilter, keySelector, diff --git a/version.json b/version.json index 60d4107..4a04ae3 100644 --- a/version.json +++ b/version.json @@ -1,6 +1,6 @@ { "$schema": "https://raw.githubusercontent.com/dotnet/Nerdbank.GitVersioning/main/src/NerdBank.GitVersioning/version.schema.json", - "version": "1.3", + "version": "1.4", "publicReleaseRefSpec": [ "^refs/heads/main", "^refs/heads/v\\d+(?:\\.\\d+)?$"