diff --git a/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs b/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs index 9208a71..17e8cff 100644 --- a/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs +++ b/ServiceScan.SourceGenerator.Tests/AddServicesTests.cs @@ -59,6 +59,48 @@ public void AddServicesFromAnotherAssembly() Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); } + [Fact] + public void AddServicesFromReferencedCompilationsByDefault() + { + var coreCompilation = CreateCompilation( + """ + namespace Core; + public interface IService { } + """) + .WithAssemblyName("Core"); + + var implementation1Compilation = CreateCompilation([""" + namespace Module1; + public class MyService1 : Core.IService { } + """], + [coreCompilation]) + .WithAssemblyName("Module1"); + + var implementation2Compilation = CreateCompilation([""" + namespace Module2; + public class MyService2 : Core.IService { } + """], + [coreCompilation]) + .WithAssemblyName("Module2"); + + var attribute = $"[GenerateServiceRegistrations(AssignableTo = typeof(Core.IService), Lifetime = ServiceLifetime.Scoped)]"; + var registrationsCompilation = CreateCompilation( + [Sources.MethodWithAttribute(attribute)], + [coreCompilation, implementation1Compilation, implementation2Compilation]); + + var results = CSharpGeneratorDriver + .Create(_generator) + .RunGenerators(registrationsCompilation) + .GetRunResult(); + + var registrations = $""" + return services + .AddScoped() + .AddScoped(); + """; + Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString()); + } + [Fact] public void AddServiceWithNonDirectInterface() { @@ -990,20 +1032,29 @@ public void DontGenerateAnythingIfTypeIsInvalid() } private static Compilation CreateCompilation(params string[] source) + { + return CreateCompilation(source, []); + } + + private static Compilation CreateCompilation(string[] source, Compilation[] referencedCompilations) { var path = Path.GetDirectoryName(typeof(object).Assembly.Location)!; var runtimeAssemblyPath = Path.Combine(path, "System.Runtime.dll"); var runtimeReference = MetadataReference.CreateFromFile(typeof(object).Assembly.Location); - return CSharpCompilation.Create("compilation", - source.Select(s => CSharpSyntaxTree.ParseText(s)), - [ + var metadataReferences = new MetadataReference[] + { MetadataReference.CreateFromFile(typeof(object).Assembly.Location), MetadataReference.CreateFromFile(runtimeAssemblyPath), MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location), - MetadataReference.CreateFromFile(typeof(External.IExternalService).Assembly.Location), - ], + MetadataReference.CreateFromFile(typeof(External.IExternalService).Assembly.Location) + } + .Concat(referencedCompilations.Select(c => c.ToMetadataReference())); + + return CSharpCompilation.Create("compilation", + source.Select(s => CSharpSyntaxTree.ParseText(s)), + metadataReferences, new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); } } diff --git a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs index a8ba52a..defef0b 100644 --- a/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs +++ b/ServiceScan.SourceGenerator/DependencyInjectionGenerator.FilterTypes.cs @@ -15,8 +15,10 @@ public partial class DependencyInjectionGenerator var assemblyOfType = attribute.AssemblyOfTypeName is null ? null : compilation.GetTypeByMetadataName(attribute.AssemblyOfTypeName); - - var assembly = (assemblyOfType ?? containingType).ContainingAssembly; + + var assemblies = assemblyOfType is not null + ? [assemblyOfType.ContainingAssembly] + : GetSolutionAssemblies(compilation); var assignableToType = attribute.AssignableToTypeName is null ? null @@ -49,7 +51,7 @@ public partial class DependencyInjectionGenerator excludeAssignableToType = excludeAssignableToType.Construct(typeArguments); } - foreach (var type in GetTypesFromAssembly(assembly)) + foreach (var type in assemblies.SelectMany(GetTypesFromAssembly)) { if (type.IsAbstract || type.IsStatic || !type.CanBeReferencedByName || type.TypeKind != TypeKind.Class) continue; @@ -71,7 +73,7 @@ public partial class DependencyInjectionGenerator if (excludeByTypeNameRegex != null && excludeByTypeNameRegex.IsMatch(type.ToDisplayString())) continue; - + if (excludeAssignableToType != null && IsAssignableTo(type, excludeAssignableToType, out _)) continue; @@ -137,6 +139,19 @@ private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assig return false; } + private static IEnumerable GetSolutionAssemblies(Compilation compilation) + { + yield return compilation.Assembly; + + foreach (var reference in compilation.References) + { + if (reference is CompilationReference) + { + yield return (IAssemblySymbol)compilation.GetAssemblyOrModuleSymbol(reference); + } + } + } + private static IEnumerable GetTypesFromAssembly(IAssemblySymbol assemblySymbol) { var @namespace = assemblySymbol.GlobalNamespace;