diff --git a/src/JsonApiDotNetCore/Configuration/ServiceDiscoveryFacade.cs b/src/JsonApiDotNetCore/Configuration/ServiceDiscoveryFacade.cs
index b28fe466ea..61f159ae1c 100644
--- a/src/JsonApiDotNetCore/Configuration/ServiceDiscoveryFacade.cs
+++ b/src/JsonApiDotNetCore/Configuration/ServiceDiscoveryFacade.cs
@@ -185,10 +185,10 @@ private void AddResourceDefinitions(Assembly assembly, ResourceDescriptor resour
private void RegisterImplementations(Assembly assembly, Type interfaceType, ResourceDescriptor resourceDescriptor)
{
var genericArguments = interfaceType.GetTypeInfo().GenericTypeParameters.Length == 2 ? new[] { resourceDescriptor.ResourceType, resourceDescriptor.IdType } : new[] { resourceDescriptor.ResourceType };
- var (implementation, registrationInterface) = TypeLocator.GetGenericInterfaceImplementation(assembly, interfaceType, genericArguments);
-
- if (implementation != null)
+ var result = TypeLocator.GetGenericInterfaceImplementation(assembly, interfaceType, genericArguments);
+ if (result != null)
{
+ var (implementation, registrationInterface) = result.Value;
_services.AddScoped(registrationInterface, implementation);
}
}
diff --git a/src/JsonApiDotNetCore/Configuration/TypeLocator.cs b/src/JsonApiDotNetCore/Configuration/TypeLocator.cs
index 6de8b18c14..fd8b420b1b 100644
--- a/src/JsonApiDotNetCore/Configuration/TypeLocator.cs
+++ b/src/JsonApiDotNetCore/Configuration/TypeLocator.cs
@@ -38,41 +38,52 @@ public static ResourceDescriptor TryGetResourceDescriptor(Type type)
}
///
- /// Gets all implementations of the generic interface.
+ /// Gets all implementations of a generic interface.
///
- /// The assembly to search.
- /// The open generic type, e.g. `typeof(IResourceService<>)`.
- /// Parameters to the generic type.
+ /// The assembly to search in.
+ /// The open generic interface.
+ /// Generic type parameters to construct the generic interface.
///
/// ), typeof(Article), typeof(Guid));
+ /// GetGenericInterfaceImplementation(assembly, typeof(IResourceService<,>), typeof(Article), typeof(Guid));
/// ]]>
///
- public static (Type implementation, Type registrationInterface) GetGenericInterfaceImplementation(Assembly assembly, Type openGenericInterface, params Type[] genericInterfaceArguments)
+ public static (Type implementation, Type registrationInterface)? GetGenericInterfaceImplementation(Assembly assembly, Type openGenericInterface, params Type[] interfaceGenericTypeArguments)
{
if (assembly == null) throw new ArgumentNullException(nameof(assembly));
if (openGenericInterface == null) throw new ArgumentNullException(nameof(openGenericInterface));
- if (genericInterfaceArguments == null) throw new ArgumentNullException(nameof(genericInterfaceArguments));
- if (genericInterfaceArguments.Length == 0) throw new ArgumentException("No arguments supplied for the generic interface.", nameof(genericInterfaceArguments));
- if (!openGenericInterface.IsGenericType) throw new ArgumentException("Requested type is not a generic type.", nameof(openGenericInterface));
+ if (interfaceGenericTypeArguments == null) throw new ArgumentNullException(nameof(interfaceGenericTypeArguments));
- foreach (var type in assembly.GetTypes())
+ if (!openGenericInterface.IsInterface || !openGenericInterface.IsGenericType ||
+ openGenericInterface != openGenericInterface.GetGenericTypeDefinition())
+ {
+ throw new ArgumentException($"Specified type '{openGenericInterface.FullName}' is not an open generic interface.", nameof(openGenericInterface));
+ }
+
+ if (interfaceGenericTypeArguments.Length != openGenericInterface.GetGenericArguments().Length)
{
- var interfaces = type.GetInterfaces();
- foreach (var @interface in interfaces)
+ throw new ArgumentException(
+ $"Interface '{openGenericInterface.FullName}' requires {openGenericInterface.GetGenericArguments().Length} type parameters instead of {interfaceGenericTypeArguments.Length}.",
+ nameof(interfaceGenericTypeArguments));
+ }
+
+ foreach (var nextType in assembly.GetTypes())
+ {
+ foreach (var nextGenericInterface in nextType.GetInterfaces().Where(x => x.IsGenericType))
{
- if (@interface.IsGenericType)
+ var nextOpenGenericInterface = nextGenericInterface.GetGenericTypeDefinition();
+ if (nextOpenGenericInterface == openGenericInterface)
{
- var genericTypeDefinition = @interface.GetGenericTypeDefinition();
- if (@interface.GetGenericArguments().First() == genericInterfaceArguments.First() &&genericTypeDefinition == openGenericInterface.GetGenericTypeDefinition())
+ var nextGenericArguments = nextGenericInterface.GetGenericArguments();
+ if (nextGenericArguments.Length == interfaceGenericTypeArguments.Length && nextGenericArguments.SequenceEqual(interfaceGenericTypeArguments))
{
- return (type, genericTypeDefinition.MakeGenericType(genericInterfaceArguments));
+ return (nextType, nextOpenGenericInterface.MakeGenericType(interfaceGenericTypeArguments));
}
}
}
}
- return (null, null);
+ return null;
}
///
diff --git a/test/UnitTests/Graph/TypeLocator_Tests.cs b/test/UnitTests/Graph/TypeLocator_Tests.cs
index cc3a3803b3..a5f5aa1803 100644
--- a/test/UnitTests/Graph/TypeLocator_Tests.cs
+++ b/test/UnitTests/Graph/TypeLocator_Tests.cs
@@ -19,15 +19,16 @@ public void GetGenericInterfaceImplementation_Gets_Implementation()
var expectedInterface = typeof(IGenericInterface);
// Act
- var (implementation, registrationInterface) = TypeLocator.GetGenericInterfaceImplementation(
+ var result = TypeLocator.GetGenericInterfaceImplementation(
assembly,
openGeneric,
genericArg
);
// Assert
- Assert.Equal(expectedImplementation, implementation);
- Assert.Equal(expectedInterface, registrationInterface);
+ Assert.NotNull(result);
+ Assert.Equal(expectedImplementation, result.Value.implementation);
+ Assert.Equal(expectedInterface, result.Value.registrationInterface);
}
[Fact]