Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 56 additions & 5 deletions ServiceScan.SourceGenerator.Tests/AddServicesTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,48 @@ public void AddServicesFromAnotherAssembly()
Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString());
}

[Fact]
public void AddServicesFromReferencedCompilationsByDefault()
{
var coreCompilation = CreateCompilation(
"""
namespace Core;
public interface IService { }
""")
.WithAssemblyName("Core");

var implementation1Compilation = CreateCompilation(["""
namespace Module1;
public class MyService1 : Core.IService { }
"""],
[coreCompilation])
.WithAssemblyName("Module1");

var implementation2Compilation = CreateCompilation(["""
namespace Module2;
public class MyService2 : Core.IService { }
"""],
[coreCompilation])
.WithAssemblyName("Module2");

var attribute = $"[GenerateServiceRegistrations(AssignableTo = typeof(Core.IService), Lifetime = ServiceLifetime.Scoped)]";
var registrationsCompilation = CreateCompilation(
[Sources.MethodWithAttribute(attribute)],
[coreCompilation, implementation1Compilation, implementation2Compilation]);

var results = CSharpGeneratorDriver
.Create(_generator)
.RunGenerators(registrationsCompilation)
.GetRunResult();

var registrations = $"""
return services
.AddScoped<global::Core.IService, global::Module1.MyService1>()
.AddScoped<global::Core.IService, global::Module2.MyService2>();
""";
Assert.Equal(Sources.GetMethodImplementation(registrations), results.GeneratedTrees[1].ToString());
}

[Fact]
public void AddServiceWithNonDirectInterface()
{
Expand Down Expand Up @@ -990,20 +1032,29 @@ public void DontGenerateAnythingIfTypeIsInvalid()
}

private static Compilation CreateCompilation(params string[] source)
{
return CreateCompilation(source, []);
}

private static Compilation CreateCompilation(string[] source, Compilation[] referencedCompilations)
{
var path = Path.GetDirectoryName(typeof(object).Assembly.Location)!;
var runtimeAssemblyPath = Path.Combine(path, "System.Runtime.dll");

var runtimeReference = MetadataReference.CreateFromFile(typeof(object).Assembly.Location);

return CSharpCompilation.Create("compilation",
source.Select(s => CSharpSyntaxTree.ParseText(s)),
[
var metadataReferences = new MetadataReference[]
{
MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
MetadataReference.CreateFromFile(runtimeAssemblyPath),
MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location),
MetadataReference.CreateFromFile(typeof(External.IExternalService).Assembly.Location),
],
MetadataReference.CreateFromFile(typeof(External.IExternalService).Assembly.Location)
}
.Concat(referencedCompilations.Select(c => c.ToMetadataReference()));

return CSharpCompilation.Create("compilation",
source.Select(s => CSharpSyntaxTree.ParseText(s)),
metadataReferences,
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ public partial class DependencyInjectionGenerator
var assemblyOfType = attribute.AssemblyOfTypeName is null
? null
: compilation.GetTypeByMetadataName(attribute.AssemblyOfTypeName);

var assembly = (assemblyOfType ?? containingType).ContainingAssembly;

var assemblies = assemblyOfType is not null
? [assemblyOfType.ContainingAssembly]
: GetSolutionAssemblies(compilation);
Copy link

Copilot AI May 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a comment to explain that GetSolutionAssemblies returns all solution assemblies when assemblyOfType is null, providing context on why scanning all referenced compilations is desired.

Copilot uses AI. Check for mistakes.

var assignableToType = attribute.AssignableToTypeName is null
? null
Expand Down Expand Up @@ -49,7 +51,7 @@ public partial class DependencyInjectionGenerator
excludeAssignableToType = excludeAssignableToType.Construct(typeArguments);
}

foreach (var type in GetTypesFromAssembly(assembly))
foreach (var type in assemblies.SelectMany(GetTypesFromAssembly))
Copy link

Copilot AI May 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Review the potential performance impact when scanning all solution assemblies. If the number of referenced compilations is large, consider filtering to include only assembly types that are relevant for dependency injection.

Copilot uses AI. Check for mistakes.
{
if (type.IsAbstract || type.IsStatic || !type.CanBeReferencedByName || type.TypeKind != TypeKind.Class)
continue;
Expand All @@ -71,7 +73,7 @@ public partial class DependencyInjectionGenerator

if (excludeByTypeNameRegex != null && excludeByTypeNameRegex.IsMatch(type.ToDisplayString()))
continue;

if (excludeAssignableToType != null && IsAssignableTo(type, excludeAssignableToType, out _))
continue;

Expand Down Expand Up @@ -137,6 +139,19 @@ private static bool IsAssignableTo(INamedTypeSymbol type, INamedTypeSymbol assig
return false;
}

private static IEnumerable<IAssemblySymbol> GetSolutionAssemblies(Compilation compilation)
{
yield return compilation.Assembly;

foreach (var reference in compilation.References)
{
if (reference is CompilationReference)
{
yield return (IAssemblySymbol)compilation.GetAssemblyOrModuleSymbol(reference);
}
}
}

private static IEnumerable<INamedTypeSymbol> GetTypesFromAssembly(IAssemblySymbol assemblySymbol)
{
var @namespace = assemblySymbol.GlobalNamespace;
Expand Down
Loading