diff --git a/src/JsonApiDotNetCore/Configuration/ServiceDiscoveryFacade.cs b/src/JsonApiDotNetCore/Configuration/ServiceDiscoveryFacade.cs index b28fe466ea..61f159ae1c 100644 --- a/src/JsonApiDotNetCore/Configuration/ServiceDiscoveryFacade.cs +++ b/src/JsonApiDotNetCore/Configuration/ServiceDiscoveryFacade.cs @@ -185,10 +185,10 @@ private void AddResourceDefinitions(Assembly assembly, ResourceDescriptor resour private void RegisterImplementations(Assembly assembly, Type interfaceType, ResourceDescriptor resourceDescriptor) { var genericArguments = interfaceType.GetTypeInfo().GenericTypeParameters.Length == 2 ? new[] { resourceDescriptor.ResourceType, resourceDescriptor.IdType } : new[] { resourceDescriptor.ResourceType }; - var (implementation, registrationInterface) = TypeLocator.GetGenericInterfaceImplementation(assembly, interfaceType, genericArguments); - - if (implementation != null) + var result = TypeLocator.GetGenericInterfaceImplementation(assembly, interfaceType, genericArguments); + if (result != null) { + var (implementation, registrationInterface) = result.Value; _services.AddScoped(registrationInterface, implementation); } } diff --git a/src/JsonApiDotNetCore/Configuration/TypeLocator.cs b/src/JsonApiDotNetCore/Configuration/TypeLocator.cs index 6de8b18c14..fd8b420b1b 100644 --- a/src/JsonApiDotNetCore/Configuration/TypeLocator.cs +++ b/src/JsonApiDotNetCore/Configuration/TypeLocator.cs @@ -38,41 +38,52 @@ public static ResourceDescriptor TryGetResourceDescriptor(Type type) } /// - /// Gets all implementations of the generic interface. + /// Gets all implementations of a generic interface. /// - /// The assembly to search. - /// The open generic type, e.g. `typeof(IResourceService<>)`. - /// Parameters to the generic type. + /// The assembly to search in. + /// The open generic interface. + /// Generic type parameters to construct the generic interface. /// /// ), typeof(Article), typeof(Guid)); + /// GetGenericInterfaceImplementation(assembly, typeof(IResourceService<,>), typeof(Article), typeof(Guid)); /// ]]> /// - public static (Type implementation, Type registrationInterface) GetGenericInterfaceImplementation(Assembly assembly, Type openGenericInterface, params Type[] genericInterfaceArguments) + public static (Type implementation, Type registrationInterface)? GetGenericInterfaceImplementation(Assembly assembly, Type openGenericInterface, params Type[] interfaceGenericTypeArguments) { if (assembly == null) throw new ArgumentNullException(nameof(assembly)); if (openGenericInterface == null) throw new ArgumentNullException(nameof(openGenericInterface)); - if (genericInterfaceArguments == null) throw new ArgumentNullException(nameof(genericInterfaceArguments)); - if (genericInterfaceArguments.Length == 0) throw new ArgumentException("No arguments supplied for the generic interface.", nameof(genericInterfaceArguments)); - if (!openGenericInterface.IsGenericType) throw new ArgumentException("Requested type is not a generic type.", nameof(openGenericInterface)); + if (interfaceGenericTypeArguments == null) throw new ArgumentNullException(nameof(interfaceGenericTypeArguments)); - foreach (var type in assembly.GetTypes()) + if (!openGenericInterface.IsInterface || !openGenericInterface.IsGenericType || + openGenericInterface != openGenericInterface.GetGenericTypeDefinition()) + { + throw new ArgumentException($"Specified type '{openGenericInterface.FullName}' is not an open generic interface.", nameof(openGenericInterface)); + } + + if (interfaceGenericTypeArguments.Length != openGenericInterface.GetGenericArguments().Length) { - var interfaces = type.GetInterfaces(); - foreach (var @interface in interfaces) + throw new ArgumentException( + $"Interface '{openGenericInterface.FullName}' requires {openGenericInterface.GetGenericArguments().Length} type parameters instead of {interfaceGenericTypeArguments.Length}.", + nameof(interfaceGenericTypeArguments)); + } + + foreach (var nextType in assembly.GetTypes()) + { + foreach (var nextGenericInterface in nextType.GetInterfaces().Where(x => x.IsGenericType)) { - if (@interface.IsGenericType) + var nextOpenGenericInterface = nextGenericInterface.GetGenericTypeDefinition(); + if (nextOpenGenericInterface == openGenericInterface) { - var genericTypeDefinition = @interface.GetGenericTypeDefinition(); - if (@interface.GetGenericArguments().First() == genericInterfaceArguments.First() &&genericTypeDefinition == openGenericInterface.GetGenericTypeDefinition()) + var nextGenericArguments = nextGenericInterface.GetGenericArguments(); + if (nextGenericArguments.Length == interfaceGenericTypeArguments.Length && nextGenericArguments.SequenceEqual(interfaceGenericTypeArguments)) { - return (type, genericTypeDefinition.MakeGenericType(genericInterfaceArguments)); + return (nextType, nextOpenGenericInterface.MakeGenericType(interfaceGenericTypeArguments)); } } } } - return (null, null); + return null; } /// diff --git a/test/UnitTests/Graph/TypeLocator_Tests.cs b/test/UnitTests/Graph/TypeLocator_Tests.cs index cc3a3803b3..a5f5aa1803 100644 --- a/test/UnitTests/Graph/TypeLocator_Tests.cs +++ b/test/UnitTests/Graph/TypeLocator_Tests.cs @@ -19,15 +19,16 @@ public void GetGenericInterfaceImplementation_Gets_Implementation() var expectedInterface = typeof(IGenericInterface); // Act - var (implementation, registrationInterface) = TypeLocator.GetGenericInterfaceImplementation( + var result = TypeLocator.GetGenericInterfaceImplementation( assembly, openGeneric, genericArg ); // Assert - Assert.Equal(expectedImplementation, implementation); - Assert.Equal(expectedInterface, registrationInterface); + Assert.NotNull(result); + Assert.Equal(expectedImplementation, result.Value.implementation); + Assert.Equal(expectedInterface, result.Value.registrationInterface); } [Fact]