diff --git a/src/main/java/org/springframework/data/repository/support/Repositories.java b/src/main/java/org/springframework/data/repository/support/Repositories.java index fb1e4ac855..65ffe8eb9c 100644 --- a/src/main/java/org/springframework/data/repository/support/Repositories.java +++ b/src/main/java/org/springframework/data/repository/support/Repositories.java @@ -39,6 +39,7 @@ import org.springframework.data.util.ProxyUtils; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.ConcurrentLruCache; /** * Wrapper class to access repository instances obtained from a {@link ListableBeanFactory}. @@ -58,6 +59,8 @@ public class Repositories implements Iterable> { private final Optional beanFactory; private final Map, String> repositoryBeanNames; private final Map, RepositoryFactoryInformation> repositoryFactoryInfos; + private final ConcurrentLruCache, Class> domainTypeMapping = new ConcurrentLruCache<>(64, + this::getRepositoryDomainTypeFor); /** * Constructor to create the {@link #NONE} instance. @@ -124,7 +127,7 @@ public boolean hasRepositoryFor(Class domainClass) { Assert.notNull(domainClass, DOMAIN_TYPE_MUST_NOT_BE_NULL); - Class userClass = ProxyUtils.getUserClass(domainClass); + Class userClass = domainTypeMapping.get(ProxyUtils.getUserClass(domainClass)); return repositoryFactoryInfos.containsKey(userClass); } @@ -139,7 +142,7 @@ public Optional getRepositoryFor(Class domainClass) { Assert.notNull(domainClass, DOMAIN_TYPE_MUST_NOT_BE_NULL); - Class userClass = ProxyUtils.getUserClass(domainClass); + Class userClass = domainTypeMapping.get(ProxyUtils.getUserClass(domainClass)); Optional repositoryBeanName = Optional.ofNullable(repositoryBeanNames.get(userClass)); return beanFactory.flatMap(it -> repositoryBeanName.map(it::getBean)); @@ -157,7 +160,7 @@ private RepositoryFactoryInformation getRepositoryFactoryInfoFor Assert.notNull(domainClass, DOMAIN_TYPE_MUST_NOT_BE_NULL); - Class userType = ProxyUtils.getUserClass(domainClass); + Class userType = domainTypeMapping.get(ProxyUtils.getUserClass(domainClass)); RepositoryFactoryInformation repositoryInfo = repositoryFactoryInfos.get(userType); if (repositoryInfo != null) { @@ -303,6 +306,33 @@ private void cacheFirstOrPrimary(Class type, RepositoryFactoryInformation inf this.repositoryBeanNames.put(type, name); } + /** + * Returns the repository domain type for which to look up the repository. The input can either be a repository + * managed type directly. Or it can be a sub-type of a repository managed one, in which case we check the domain types + * we have repositories registered for for assignability. + * + * @param domainType must not be {@literal null}. + * @return + */ + private Class getRepositoryDomainTypeFor(Class domainType) { + + Assert.notNull(domainType, "Domain type must not be null!"); + + Set> declaredTypes = repositoryBeanNames.keySet(); + + if (declaredTypes.contains(domainType)) { + return domainType; + } + + for (Class declaredType : declaredTypes) { + if (declaredType.isAssignableFrom(domainType)) { + return declaredType; + } + } + + return domainType; + } + /** * Null-object to avoid nasty {@literal null} checks in cache lookups. * diff --git a/src/test/java/org/springframework/data/repository/support/RepositoriesUnitTests.java b/src/test/java/org/springframework/data/repository/support/RepositoriesUnitTests.java index 451d573d64..47ede32019 100755 --- a/src/test/java/org/springframework/data/repository/support/RepositoriesUnitTests.java +++ b/src/test/java/org/springframework/data/repository/support/RepositoriesUnitTests.java @@ -29,7 +29,6 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; - import org.springframework.aop.framework.ProxyFactory; import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.BeanDefinitionBuilder; @@ -206,6 +205,47 @@ void keepsPrimaryRepositoryInCaseOfMultipleOnesIfContextIsNotAConfigurableListab }); } + @Test // GH-2406 + void exposesParentRepositoryForChildIfOnlyParentRepositoryIsRegistered() { + + Repositories repositories = bootstrapRepositories(ParentRepository.class); + + assertRepositoryAvailableFor(repositories, Child.class, ParentRepository.class); + } + + @Test // GH-2406 + void usesChildRepositoryIfRegistered() { + + Repositories repositories = bootstrapRepositories(ParentRepository.class, ChildRepository.class); + + assertRepositoryAvailableFor(repositories, Child.class, ChildRepository.class); + } + + private void assertRepositoryAvailableFor(Repositories repositories, Class domainTypem, + Class repositoryInterface) { + + assertThat(repositories.hasRepositoryFor(domainTypem)).isTrue(); + assertThat(repositories.getRepositoryFor(domainTypem)) + .hasValueSatisfying(it -> assertThat(it).isInstanceOf(repositoryInterface)); + assertThat(repositories.getRepositoryInformationFor(domainTypem)) + .hasValueSatisfying(it -> assertThat(it.getRepositoryInterface()).isEqualTo(repositoryInterface)); + } + + private Repositories bootstrapRepositories(Class... repositoryInterfaces) { + + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + + for (Class repositoryInterface : repositoryInterfaces) { + beanFactory.registerBeanDefinition(repositoryInterface.getName(), + getRepositoryBeanDefinition(repositoryInterface)); + } + + context = new GenericApplicationContext(beanFactory); + context.refresh(); + + return new Repositories(context); + } + class Person {} class Address {} @@ -301,4 +341,14 @@ interface FirstRepository extends Repository {} interface PrimaryRepository extends Repository {} interface ThirdRepository extends Repository {} + + // GH-2406 + + static class Parent {} + + static class Child extends Parent {} + + interface ParentRepository extends Repository {} + + interface ChildRepository extends Repository {} }