Skip to content

Repositories now allows lookup of parent repositories for sub-types. #2406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.springframework.data.util.ProxyUtils;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ConcurrentLruCache;

/**
* Wrapper class to access repository instances obtained from a {@link ListableBeanFactory}.
Expand All @@ -58,6 +59,8 @@ public class Repositories implements Iterable<Class<?>> {
private final Optional<BeanFactory> beanFactory;
private final Map<Class<?>, String> repositoryBeanNames;
private final Map<Class<?>, RepositoryFactoryInformation<Object, Object>> repositoryFactoryInfos;
private final ConcurrentLruCache<Class<?>, Class<?>> domainTypeMapping = new ConcurrentLruCache<>(64,
this::getRepositoryDomainTypeFor);

/**
* Constructor to create the {@link #NONE} instance.
Expand Down Expand Up @@ -124,7 +127,7 @@ public boolean hasRepositoryFor(Class<?> domainClass) {

Assert.notNull(domainClass, DOMAIN_TYPE_MUST_NOT_BE_NULL);

Class<?> userClass = ProxyUtils.getUserClass(domainClass);
Class<?> userClass = domainTypeMapping.get(ProxyUtils.getUserClass(domainClass));

return repositoryFactoryInfos.containsKey(userClass);
}
Expand All @@ -139,7 +142,7 @@ public Optional<Object> getRepositoryFor(Class<?> domainClass) {

Assert.notNull(domainClass, DOMAIN_TYPE_MUST_NOT_BE_NULL);

Class<?> userClass = ProxyUtils.getUserClass(domainClass);
Class<?> userClass = domainTypeMapping.get(ProxyUtils.getUserClass(domainClass));
Optional<String> repositoryBeanName = Optional.ofNullable(repositoryBeanNames.get(userClass));

return beanFactory.flatMap(it -> repositoryBeanName.map(it::getBean));
Expand All @@ -157,7 +160,7 @@ private RepositoryFactoryInformation<Object, Object> getRepositoryFactoryInfoFor

Assert.notNull(domainClass, DOMAIN_TYPE_MUST_NOT_BE_NULL);

Class<?> userType = ProxyUtils.getUserClass(domainClass);
Class<?> userType = domainTypeMapping.get(ProxyUtils.getUserClass(domainClass));
RepositoryFactoryInformation<Object, Object> repositoryInfo = repositoryFactoryInfos.get(userType);

if (repositoryInfo != null) {
Expand Down Expand Up @@ -303,6 +306,33 @@ private void cacheFirstOrPrimary(Class<?> type, RepositoryFactoryInformation inf
this.repositoryBeanNames.put(type, name);
}

/**
* Returns the repository domain type for which to look up the repository. The input can either be a repository
* managed type directly. Or it can be a sub-type of a repository managed one, in which case we check the domain types
* we have repositories registered for for assignability.
*
* @param domainType must not be {@literal null}.
* @return
*/
private Class<?> getRepositoryDomainTypeFor(Class<?> domainType) {

Assert.notNull(domainType, "Domain type must not be null!");

Set<Class<?>> declaredTypes = repositoryBeanNames.keySet();

if (declaredTypes.contains(domainType)) {
return domainType;
}

for (Class<?> declaredType : declaredTypes) {
if (declaredType.isAssignableFrom(domainType)) {
return declaredType;
}
}

return domainType;
}

/**
* Null-object to avoid nasty {@literal null} checks in cache lookups.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;

import org.springframework.aop.framework.ProxyFactory;
import org.springframework.beans.factory.support.AbstractBeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
Expand Down Expand Up @@ -206,6 +205,47 @@ void keepsPrimaryRepositoryInCaseOfMultipleOnesIfContextIsNotAConfigurableListab
});
}

@Test // GH-2406
void exposesParentRepositoryForChildIfOnlyParentRepositoryIsRegistered() {

Repositories repositories = bootstrapRepositories(ParentRepository.class);

assertRepositoryAvailableFor(repositories, Child.class, ParentRepository.class);
}

@Test // GH-2406
void usesChildRepositoryIfRegistered() {

Repositories repositories = bootstrapRepositories(ParentRepository.class, ChildRepository.class);

assertRepositoryAvailableFor(repositories, Child.class, ChildRepository.class);
}

private void assertRepositoryAvailableFor(Repositories repositories, Class<?> domainTypem,
Class<?> repositoryInterface) {

assertThat(repositories.hasRepositoryFor(domainTypem)).isTrue();
assertThat(repositories.getRepositoryFor(domainTypem))
.hasValueSatisfying(it -> assertThat(it).isInstanceOf(repositoryInterface));
assertThat(repositories.getRepositoryInformationFor(domainTypem))
.hasValueSatisfying(it -> assertThat(it.getRepositoryInterface()).isEqualTo(repositoryInterface));
}

private Repositories bootstrapRepositories(Class<?>... repositoryInterfaces) {

DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory();

for (Class<?> repositoryInterface : repositoryInterfaces) {
beanFactory.registerBeanDefinition(repositoryInterface.getName(),
getRepositoryBeanDefinition(repositoryInterface));
}

context = new GenericApplicationContext(beanFactory);
context.refresh();

return new Repositories(context);
}

class Person {}

class Address {}
Expand Down Expand Up @@ -301,4 +341,14 @@ interface FirstRepository extends Repository<SomeEntity, Long> {}
interface PrimaryRepository extends Repository<SomeEntity, Long> {}

interface ThirdRepository extends Repository<SomeEntity, Long> {}

// GH-2406

static class Parent {}

static class Child extends Parent {}

interface ParentRepository extends Repository<Parent, Long> {}

interface ChildRepository extends Repository<Child, Long> {}
}