Skip to content

Commit 8f1d619

Browse files
committed
Fix query execution mode detection for aggregate types that implement Streamable.
We now short-circuit the QueryMethod.isCollectionQuery() algorithm in case we find the concrete domain type or any subclass of it. Fixes #2869.
1 parent 6bfc7c8 commit 8f1d619

File tree

5 files changed

+118
-6
lines changed

5 files changed

+118
-6
lines changed

src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,13 @@ public TypeInformation<?> getReturnType(Method method) {
9696
return returnType;
9797
}
9898

99+
@Override
99100
public Class<?> getReturnedDomainClass(Method method) {
100101

101102
TypeInformation<?> returnType = getReturnType(method);
103+
returnType = ReactiveWrapperConverters.unwrapWrapperTypes(returnType);
102104

103-
return QueryExecutionConverters.unwrapWrapperTypes(ReactiveWrapperConverters.unwrapWrapperTypes(returnType))
104-
.getType();
105+
return QueryExecutionConverters.unwrapWrapperTypes(returnType, getDomainTypeInformation()).getType();
105106
}
106107

107108
public Class<?> getRepositoryInterface() {

src/main/java/org/springframework/data/repository/query/QueryMethod.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.springframework.data.repository.util.QueryExecutionConverters;
3232
import org.springframework.data.repository.util.ReactiveWrapperConverters;
3333
import org.springframework.data.util.Lazy;
34+
import org.springframework.data.util.NullableWrapperConverters;
3435
import org.springframework.data.util.TypeInformation;
3536
import org.springframework.util.Assert;
3637

@@ -270,7 +271,15 @@ private boolean calculateIsCollectionQuery() {
270271
return false;
271272
}
272273

273-
Class<?> returnType = metadata.getReturnType(method).getType();
274+
TypeInformation<?> returnTypeInformation = metadata.getReturnType(method);
275+
276+
// Check against simple wrapper types first
277+
if (metadata.getDomainTypeInformation()
278+
.isAssignableFrom(NullableWrapperConverters.unwrapActualType(returnTypeInformation))) {
279+
return false;
280+
}
281+
282+
Class<?> returnType = returnTypeInformation.getType();
274283

275284
if (QueryExecutionConverters.supports(returnType) && !QueryExecutionConverters.isSingleValue(returnType)) {
276285
return true;

src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ public abstract class QueryExecutionConverters {
8484
private static final Set<Class<?>> ALLOWED_PAGEABLE_TYPES = new HashSet<>();
8585
private static final Map<Class<?>, ExecutionAdapter> EXECUTION_ADAPTER = new HashMap<>();
8686
private static final Map<Class<?>, Boolean> supportsCache = new ConcurrentReferenceHashMap<>();
87+
private static final TypeInformation<Void> VOID_INFORMATION = TypeInformation.of(Void.class);
8788

8889
static {
8990

@@ -233,15 +234,21 @@ public static Object unwrap(@Nullable Object source) {
233234
}
234235

235236
/**
236-
* Recursively unwraps well known wrapper types from the given {@link TypeInformation}.
237+
* Recursively unwraps well known wrapper types from the given {@link TypeInformation} but aborts at the given
238+
* reference type.
237239
*
238240
* @param type must not be {@literal null}.
241+
* @param reference must not be {@literal null}.
239242
* @return will never be {@literal null}.
240243
*/
241-
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
244+
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type, TypeInformation<?> reference) {
242245

243246
Assert.notNull(type, "type must not be null");
244247

248+
if (reference.isAssignableFrom(type)) {
249+
return type;
250+
}
251+
245252
Class<?> rawType = type.getType();
246253

247254
boolean needToUnwrap = type.isCollectionLike() //
@@ -251,7 +258,17 @@ public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
251258
|| supports(rawType) //
252259
|| Stream.class.isAssignableFrom(rawType);
253260

254-
return needToUnwrap ? unwrapWrapperTypes(type.getRequiredComponentType()) : type;
261+
return needToUnwrap ? unwrapWrapperTypes(type.getRequiredComponentType(), reference) : type;
262+
}
263+
264+
/**
265+
* Recursively unwraps well known wrapper types from the given {@link TypeInformation}.
266+
*
267+
* @param type must not be {@literal null}.
268+
* @return will never be {@literal null}.
269+
*/
270+
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
271+
return unwrapWrapperTypes(type, VOID_INFORMATION);
255272
}
256273

257274
/**

src/test/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadataUnitTests.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,21 @@
2222
import java.util.Collections;
2323
import java.util.List;
2424
import java.util.Map;
25+
import java.util.Optional;
2526
import java.util.Set;
27+
import java.util.stream.Stream;
2628

29+
30+
import org.junit.jupiter.api.DynamicTest;
2731
import org.junit.jupiter.api.Test;
32+
import org.junit.jupiter.api.TestFactory;
2833
import org.springframework.data.domain.Page;
2934
import org.springframework.data.domain.Pageable;
3035
import org.springframework.data.querydsl.User;
3136
import org.springframework.data.repository.PagingAndSortingRepository;
3237
import org.springframework.data.repository.Repository;
3338
import org.springframework.data.repository.core.RepositoryMetadata;
39+
import org.springframework.data.util.Streamable;
3440

3541
/**
3642
* Unit tests for {@link AbstractRepositoryMetadata}.
@@ -113,6 +119,25 @@ void doesNotUnwrapCustomTypeImplementingIterable() throws Exception {
113119
assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(Container.class);
114120
}
115121

122+
@TestFactory // GH-2869
123+
Stream<DynamicTest> detectsReturnTypesForStreamableAggregates() throws Exception {
124+
125+
var metadata = AbstractRepositoryMetadata.getMetadata(StreamableAggregateRepository.class);
126+
var methods = Stream.of(
127+
Map.entry("findBy", StreamableAggregate.class),
128+
Map.entry("findSubTypeBy", StreamableAggregateSubType.class),
129+
Map.entry("findAllBy", StreamableAggregate.class),
130+
Map.entry("findOptional", StreamableAggregate.class));
131+
132+
return DynamicTest.stream(methods, //
133+
it -> it.getKey() + "'s returned domain class is " + it.getValue(), //
134+
it -> {
135+
136+
var method = StreamableAggregateRepository.class.getMethod(it.getKey());
137+
assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(it.getValue());
138+
});
139+
}
140+
116141
interface UserRepository extends Repository<User, Long> {
117142

118143
User findSingle();
@@ -157,4 +182,20 @@ interface ContainerRepository extends Repository<Container, Long> {
157182

158183
interface CompletePageableAndSortingRepository extends PagingAndSortingRepository<Container, Long> {}
159184

185+
// GH-2869
186+
187+
static abstract class StreamableAggregate implements Streamable<Object> {}
188+
189+
interface StreamableAggregateRepository extends Repository<StreamableAggregate, Object> {
190+
191+
StreamableAggregate findBy();
192+
193+
StreamableAggregateSubType findSubTypeBy();
194+
195+
Streamable<StreamableAggregate> findAllBy();
196+
197+
Optional<StreamableAggregate> findOptional();
198+
}
199+
200+
static abstract class StreamableAggregateSubType extends StreamableAggregate {}
160201
}

src/test/java/org/springframework/data/repository/query/QueryMethodUnitTests.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,16 @@
2323

2424
import java.io.Serializable;
2525
import java.util.List;
26+
import java.util.Map;
27+
import java.util.Optional;
2628
import java.util.concurrent.CompletableFuture;
2729
import java.util.concurrent.Future;
2830
import java.util.stream.Stream;
2931

3032
import org.eclipse.collections.api.list.ImmutableList;
33+
import org.junit.jupiter.api.DynamicTest;
3134
import org.junit.jupiter.api.Test;
35+
import org.junit.jupiter.api.TestFactory;
3236
import org.springframework.data.domain.Page;
3337
import org.springframework.data.domain.Pageable;
3438
import org.springframework.data.domain.Slice;
@@ -38,6 +42,7 @@
3842
import org.springframework.data.repository.core.RepositoryMetadata;
3943
import org.springframework.data.repository.core.support.AbstractRepositoryMetadata;
4044
import org.springframework.data.repository.core.support.DefaultRepositoryMetadata;
45+
import org.springframework.data.util.Streamable;
4146

4247
/**
4348
* Unit tests for {@link QueryMethod}.
@@ -257,6 +262,28 @@ void considersEclipseCollectionCollectionQuery() throws Exception {
257262
assertThat(queryMethod.isCollectionQuery()).isTrue();
258263
}
259264

265+
@TestFactory // GH-2869
266+
Stream<DynamicTest> doesNotConsiderQueryMethodReturningAggregateImplementingStreamableACollectionQuery()
267+
throws Exception {
268+
269+
var metadata = AbstractRepositoryMetadata.getMetadata(StreamableAggregateRepository.class);
270+
var stream = Stream.of(
271+
Map.entry("findBy", false),
272+
Map.entry("findSubTypeBy", false),
273+
Map.entry("findAllBy", true),
274+
Map.entry("findOptionalBy", false));
275+
276+
return DynamicTest.stream(stream, //
277+
it -> it.getKey() + " considered collection query -> " + it.getValue(), //
278+
it -> {
279+
280+
var method = StreamableAggregateRepository.class.getMethod(it.getKey());
281+
var queryMethod = new QueryMethod(method, metadata, factory);
282+
283+
assertThat(queryMethod.isCollectionQuery()).isEqualTo(it.getValue());
284+
});
285+
}
286+
260287
interface SampleRepository extends Repository<User, Serializable> {
261288

262289
String pagingMethodWithInvalidReturnType(Pageable pageable);
@@ -324,4 +351,21 @@ abstract class Container implements Iterable<Element> {}
324351
interface ContainerRepository extends Repository<Container, Long> {
325352
Container someMethod();
326353
}
354+
355+
// GH-2869
356+
357+
static abstract class StreamableAggregate implements Streamable<Object> {}
358+
359+
interface StreamableAggregateRepository extends Repository<StreamableAggregate, Object> {
360+
361+
StreamableAggregate findBy();
362+
363+
StreamableAggregateSubType findSubTypeBy();
364+
365+
Optional<StreamableAggregate> findOptionalBy();
366+
367+
Streamable<StreamableAggregate> findAllBy();
368+
}
369+
370+
static abstract class StreamableAggregateSubType extends StreamableAggregate {}
327371
}

0 commit comments

Comments
 (0)