Skip to content

Commit 464af98

Browse files
committed
Polishing.
Introduce doWithPlainSelect(…) callback for easier filtering of Select subtypes. Add test for known (previously) failing case. See: #3869 Original pull request: #3870
1 parent b522702 commit 464af98

File tree

2 files changed

+112
-68
lines changed

2 files changed

+112
-68
lines changed

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancer.java

Lines changed: 105 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,11 @@
4848
import java.util.List;
4949
import java.util.Set;
5050
import java.util.StringJoiner;
51+
import java.util.function.Predicate;
52+
import java.util.function.Supplier;
5153

5254
import org.springframework.data.domain.Sort;
55+
import org.springframework.data.util.Predicates;
5356
import org.springframework.lang.Nullable;
5457
import org.springframework.util.Assert;
5558
import org.springframework.util.CollectionUtils;
@@ -145,24 +148,8 @@ private static String detectAlias(ParsedType parsedType, Statement statement) {
145148

146149
if (ParsedType.SELECT.equals(parsedType)) {
147150

148-
Select selectStatement = (Select) statement;
149-
150-
/*
151-
* For all the other types ({@link ValuesStatement} and {@link SetOperationList}) it does not make sense to provide
152-
* alias since:
153-
* ValuesStatement has no alias
154-
* SetOperation can have multiple alias for each operation item
155-
*/
156-
if (!(selectStatement instanceof PlainSelect selectBody)) {
157-
return null;
158-
}
159-
160-
if (selectBody.getFromItem() == null) {
161-
return null;
162-
}
163-
164-
Alias alias = selectBody.getFromItem().getAlias();
165-
return alias == null ? null : alias.getName();
151+
return doWithPlainSelect(statement, it -> it.getFromItem() == null || it.getFromItem().getAlias() == null,
152+
it -> it.getFromItem().getAlias().getName(), () -> null);
166153
}
167154

168155
return null;
@@ -175,20 +162,24 @@ private static String detectAlias(ParsedType parsedType, Statement statement) {
175162
*/
176163
private static Set<String> getSelectionAliases(Statement statement) {
177164

178-
if (!(statement instanceof PlainSelect select) || CollectionUtils.isEmpty(select.getSelectItems())) {
179-
return Collections.emptySet();
165+
if (statement instanceof SetOperationList sel) {
166+
statement = sel.getSelect(0);
180167
}
181168

182-
Set<String> set = new HashSet<>(select.getSelectItems().size());
169+
return doWithPlainSelect(statement, it -> CollectionUtils.isEmpty(it.getSelectItems()), it -> {
170+
171+
Set<String> set = new HashSet<>(it.getSelectItems().size(), 1.0f);
183172

184-
for (SelectItem<?> selectItem : select.getSelectItems()) {
185-
Alias alias = selectItem.getAlias();
186-
if (alias != null) {
187-
set.add(alias.getName());
173+
for (SelectItem<?> selectItem : it.getSelectItems()) {
174+
Alias alias = selectItem.getAlias();
175+
if (alias != null) {
176+
set.add(alias.getName());
177+
}
188178
}
189-
}
190179

191-
return set;
180+
return set;
181+
182+
}, Collections::emptySet);
192183
}
193184

194185
/**
@@ -198,21 +189,74 @@ private static Set<String> getSelectionAliases(Statement statement) {
198189
*/
199190
private static Set<String> getJoinAliases(Statement statement) {
200191

201-
if (!(statement instanceof PlainSelect selectBody) || CollectionUtils.isEmpty(selectBody.getJoins())) {
202-
return Collections.emptySet();
192+
if (statement instanceof SetOperationList sel) {
193+
statement = sel.getSelect(0);
203194
}
204195

205-
Set<String> set = new HashSet<>(selectBody.getJoins().size());
196+
return doWithPlainSelect(statement, it -> CollectionUtils.isEmpty(it.getJoins()), it -> {
206197

207-
for (Join join : selectBody.getJoins()) {
208-
Alias alias = join.getRightItem().getAlias();
209-
if (alias != null) {
210-
set.add(alias.getName());
198+
Set<String> set = new HashSet<>(it.getJoins().size(), 1.0f);
199+
200+
for (Join join : it.getJoins()) {
201+
Alias alias = join.getRightItem().getAlias();
202+
if (alias != null) {
203+
set.add(alias.getName());
204+
}
211205
}
206+
return set;
207+
208+
}, Collections::emptySet);
209+
}
210+
211+
/**
212+
* Apply a {@link java.util.function.Function mapping function} to the {@link PlainSelect} of the given
213+
* {@link Statement} is or contains a {@link PlainSelect}.
214+
*
215+
* @param statement
216+
* @param mapper
217+
* @param fallback
218+
* @return
219+
* @param <T>
220+
*/
221+
private static <T> T doWithPlainSelect(Statement statement, java.util.function.Function<PlainSelect, T> mapper,
222+
Supplier<T> fallback) {
223+
224+
Predicate<PlainSelect> neverSkip = Predicates.isFalse();
225+
return doWithPlainSelect(statement, neverSkip, mapper, fallback);
226+
}
227+
228+
/**
229+
* Apply a {@link java.util.function.Function mapping function} to the {@link PlainSelect} of the given
230+
* {@link Statement} is or contains a {@link PlainSelect}.
231+
* <p>
232+
* The operation is only applied if {@link Predicate skipIf} returns {@literal false} for the given statement
233+
* returning the fallback value from {@code fallback}.
234+
*
235+
* @param statement
236+
* @param skipIf
237+
* @param mapper
238+
* @param fallback
239+
* @return
240+
* @param <T>
241+
*/
242+
private static <T> T doWithPlainSelect(Statement statement, Predicate<PlainSelect> skipIf,
243+
java.util.function.Function<PlainSelect, T> mapper, Supplier<T> fallback) {
244+
245+
if (!(statement instanceof Select select)) {
246+
return fallback.get();
212247
}
213248

214-
return set;
249+
try {
250+
if (skipIf.test(select.getPlainSelect())) {
251+
return fallback.get();
252+
}
253+
}
254+
// e.g. SetOperationList is a subclass of Select but it is not a PlainSelect
255+
catch (ClassCastException e) {
256+
return fallback.get();
257+
}
215258

259+
return mapper.apply(select.getPlainSelect());
216260
}
217261

218262
private static String detectProjection(Statement statement) {
@@ -231,18 +275,17 @@ private static String detectProjection(Statement statement) {
231275

232276
// using the first one since for setoperations the projection has to be the same
233277
selectBody = setOperationList.getSelects().get(0);
234-
235-
if (!(selectBody instanceof PlainSelect)) {
236-
return "";
237-
}
238278
}
239279

240-
StringJoiner joiner = new StringJoiner(", ");
241-
for (SelectItem<?> selectItem : selectBody.getPlainSelect().getSelectItems()) {
242-
joiner.add(selectItem.toString());
243-
}
244-
return joiner.toString().trim();
280+
return doWithPlainSelect(selectBody, it -> CollectionUtils.isEmpty(it.getSelectItems()), it -> {
281+
282+
StringJoiner joiner = new StringJoiner(", ");
283+
for (SelectItem<?> selectItem : it.getSelectItems()) {
284+
joiner.add(selectItem.toString());
285+
}
286+
return joiner.toString().trim();
245287

288+
}, () -> "");
246289
}
247290

248291
/**
@@ -320,20 +363,22 @@ private String applySorting(Select selectStatement, Sort sort, @Nullable String
320363
return applySortingToSetOperationList(setOperationList, sort);
321364
}
322365

323-
if (!(selectStatement instanceof PlainSelect selectBody)) {
324-
return selectStatement.toString();
325-
}
366+
doWithPlainSelect(selectStatement, it -> {
326367

327-
List<OrderByElement> orderByElements = new ArrayList<>(16);
328-
for (Sort.Order order : sort) {
329-
orderByElements.add(getOrderClause(joinAliases, selectAliases, alias, order));
330-
}
368+
List<OrderByElement> orderByElements = new ArrayList<>(16);
369+
for (Sort.Order order : sort) {
370+
orderByElements.add(getOrderClause(joinAliases, selectAliases, alias, order));
371+
}
331372

332-
if (CollectionUtils.isEmpty(selectBody.getOrderByElements())) {
333-
selectBody.setOrderByElements(orderByElements);
334-
} else {
335-
selectBody.getOrderByElements().addAll(orderByElements);
336-
}
373+
if (CollectionUtils.isEmpty(it.getOrderByElements())) {
374+
it.setOrderByElements(orderByElements);
375+
} else {
376+
it.getOrderByElements().addAll(orderByElements);
377+
}
378+
379+
return null;
380+
381+
}, () -> "");
337382

338383
return selectStatement.toString();
339384
}
@@ -348,18 +393,13 @@ public String createCountQueryFor(@Nullable String countProjection) {
348393
Assert.hasText(this.query.getQueryString(), "OriginalQuery must not be null or empty");
349394

350395
Statement statement = (Statement) deserialize(this.serialized);
351-
/*
352-
We only support count queries for {@link PlainSelect}.
353-
*/
354-
if (!(statement instanceof PlainSelect selectBody)) {
355-
return this.query.getQueryString();
356-
}
357396

358-
return createCountQueryFor(this.query, selectBody, countProjection, primaryAlias);
397+
return doWithPlainSelect(statement, it -> createCountQueryFor(it, countProjection, primaryAlias),
398+
this.query::getQueryString);
359399
}
360400

361-
private static String createCountQueryFor(DeclaredQuery query, PlainSelect selectBody,
362-
@Nullable String countProjection, @Nullable String primaryAlias) {
401+
private static String createCountQueryFor(PlainSelect selectBody, @Nullable String countProjection,
402+
@Nullable String primaryAlias) {
363403

364404
// remove order by
365405
selectBody.setOrderByElements(null);

spring-data-jpa/src/test/java/org/springframework/data/jpa/repository/query/JSqlParserQueryEnhancerUnitTests.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
* @author Geoffrey Deremetz
3535
* @author Christoph Strobl
3636
*/
37-
public class JSqlParserQueryEnhancerUnitTests extends QueryEnhancerTckTests {
37+
class JSqlParserQueryEnhancerUnitTests extends QueryEnhancerTckTests {
3838

3939
@Override
4040
QueryEnhancer createQueryEnhancer(DeclaredQuery declaredQuery) {
@@ -258,12 +258,16 @@ static Stream<Arguments> mergeStatementWorksSource() {
258258
}
259259

260260
@Test // GH-3869
261-
void shouldWorkWithoutFromClause() {
262-
String query = "SELECT is_contained_in(:innerId, :outerId)";
261+
void shouldWorkWithParenthesedSelect() {
262+
263+
String query = "(SELECT is_contained_in(:innerId, :outerId))";
263264

264265
StringQuery stringQuery = new StringQuery(query, true);
266+
QueryEnhancer queryEnhancer = QueryEnhancerFactory.forQuery(stringQuery);
265267

266268
assertThat(stringQuery.getQueryString()).isEqualTo(query);
269+
assertThat(stringQuery.getAlias()).isNull();
270+
assertThat(queryEnhancer.getProjection()).isEqualTo("is_contained_in(:innerId, :outerId)");
267271
}
268272

269273
}

0 commit comments

Comments
 (0)