Skip to content

Commit 0f81a51

Browse files
committed
Polishing.
1 parent 147f2eb commit 0f81a51

File tree

19 files changed

+456
-328
lines changed

19 files changed

+456
-328
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlGenerator.java

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package org.springframework.data.jdbc.core.convert;
1717

1818
import java.util.*;
19+
import java.util.function.BiFunction;
1920
import java.util.function.Function;
2021
import java.util.function.Predicate;
2122
import java.util.stream.Collectors;
@@ -115,7 +116,7 @@ public class SqlGenerator {
115116

116117
/**
117118
* Create a basic select structure with all the necessary joins
118-
*
119+
*
119120
* @param table the table to base the select on
120121
* @param pathFilter a filter for excluding paths from the select. All paths for which the filter returns
121122
* {@literal true} will be skipped when determining columns to select.
@@ -185,6 +186,8 @@ private Condition getSubselectCondition(AggregatePath path,
185186
Table subSelectTable = Table.create(parentPathTableInfo.qualifiedTableName());
186187

187188
Map<AggregatePath, Column> selectFilterColumns = new TreeMap<>();
189+
190+
// TODO: cannot we simply pass on the columnInfos?
188191
parentPathTableInfo.effectiveIdColumnInfos().forEach( //
189192
(ap, ci) -> //
190193
selectFilterColumns.put(ap, subSelectTable.column(ci.name())) //
@@ -468,6 +471,8 @@ String createDeleteAllSql(@Nullable PersistentPropertyPath<RelationalPersistentP
468471
* @return the statement as a {@link String}. Guaranteed to be not {@literal null}.
469472
*/
470473
String createDeleteByPath(PersistentPropertyPath<RelationalPersistentProperty> path) {
474+
// TODO: When deleting by path, why do we expect the where-value to be id and not named after the path?
475+
// See SqlGeneratorEmbeddedUnitTests.deleteByPath
471476
return createDeleteByPathAndCriteria(mappingContext.getAggregatePath(path), this::equalityCondition);
472477
}
473478

@@ -487,12 +492,10 @@ String createDeleteInByPath(PersistentPropertyPath<RelationalPersistentProperty>
487492
*/
488493
private Condition inCondition(Map<AggregatePath, Column> columnMap) {
489494

490-
List<Column> columns = List.copyOf(columnMap.values());
495+
Collection<Column> columns = columnMap.values();
491496

492-
if (columns.size() == 1) {
493-
return Conditions.in(columns.get(0), getBindMarker(IDS_SQL_PARAMETER));
494-
}
495-
return Conditions.in(TupleExpression.create(columns), getBindMarker(IDS_SQL_PARAMETER));
497+
return Conditions.in(columns.size() == 1 ? columns.iterator().next() : TupleExpression.create(columns),
498+
getBindMarker(IDS_SQL_PARAMETER));
496499
}
497500

498501
/**
@@ -501,44 +504,54 @@ private Condition inCondition(Map<AggregatePath, Column> columnMap) {
501504
*/
502505
private Condition equalityCondition(Map<AggregatePath, Column> columnMap) {
503506

504-
AggregatePath.ColumnInfos idColumnInfos = mappingContext.getAggregatePath(entity).getTableInfo().idColumnInfos();
507+
Assert.isTrue(!columnMap.isEmpty(), "Column map must not be empty");
505508

506-
Condition result = null;
507-
for (Map.Entry<AggregatePath, Column> entry : columnMap.entrySet()) {
508-
BindMarker bindMarker = getBindMarker(idColumnInfos.get(entry.getKey()).name());
509-
Comparison singleCondition = entry.getValue().isEqualTo(bindMarker);
509+
AggregatePath.ColumnInfos idColumnInfos = mappingContext.getAggregatePath(entity).getTableInfo().idColumnInfos();
510510

511-
result = result == null ? singleCondition : result.and(singleCondition);
512-
}
513-
Assert.state(result != null, "We need at least one condition");
514-
return result;
511+
return createPredicate(columnMap, (aggregatePath, column) -> {
512+
return column.isEqualTo(getBindMarker(idColumnInfos.get(aggregatePath).name()));
513+
});
515514
}
516515

517516
/**
518517
* Constructs a function for constructing where a condition. The where condition will be of the form
519518
* {@literal <column-a> IS NOT NULL AND <column-b> IS NOT NULL ... }
520519
*/
521520
private Condition isNotNullCondition(Map<AggregatePath, Column> columnMap) {
521+
return createPredicate(columnMap, (aggregatePath, column) -> column.isNotNull());
522+
}
523+
524+
/**
525+
* Constructs a function for constructing where a condition. The where condition will be of the form
526+
* {@literal <column-a> IS NOT NULL AND <column-b> IS NOT NULL ... }
527+
*/
528+
private static Condition createPredicate(Map<AggregatePath, Column> columnMap,
529+
BiFunction<AggregatePath, Column, Condition> conditionFunction) {
522530

523531
Condition result = null;
524-
for (Column column : columnMap.values()) {
525-
Condition singleCondition = column.isNotNull();
532+
for (Map.Entry<AggregatePath, Column> entry : columnMap.entrySet()) {
526533

534+
Condition singleCondition = conditionFunction.apply(entry.getKey(), entry.getValue());
527535
result = result == null ? singleCondition : result.and(singleCondition);
528536
}
529537
Assert.state(result != null, "We need at least one condition");
530538
return result;
531539
}
532540

533541
private String createFindOneSql() {
534-
535542
return render(selectBuilder().where(equalityIdWhereCondition()).build());
536543
}
537544

538545
private Condition equalityIdWhereCondition() {
546+
return equalityIdWhereCondition(getIdColumns());
547+
}
548+
549+
private Condition equalityIdWhereCondition(Iterable<Column> columns) {
550+
551+
Assert.isTrue(columns.iterator().hasNext(), "Identifier columns must not be empty");
539552

540553
Condition aggregate = null;
541-
for (Column column : getIdColumns()) {
554+
for (Column column : columns) {
542555

543556
Comparison condition = column.isEqualTo(getBindMarker(column.getName()));
544557
aggregate = aggregate == null ? condition : aggregate.and(condition);
@@ -711,19 +724,13 @@ Join getJoin(AggregatePath path) {
711724
Table parentTable = sqlContext.getTable(idDefiningParentPath);
712725
AggregatePath.ColumnInfos idColumnInfos = idDefiningParentPath.getTableInfo().idColumnInfos();
713726

714-
final Condition[] joinCondition = { null };
715-
backRefColumnInfos.forEach((ap, ci) -> {
716-
717-
Condition elementalCondition = currentTable.column(ci.name())
718-
.isEqualTo(parentTable.column(idColumnInfos.get(ap).name()));
719-
joinCondition[0] = joinCondition[0] == null ? elementalCondition : joinCondition[0].and(elementalCondition);
720-
});
727+
Condition joinCondition = backRefColumnInfos.reduce(Conditions.unrestricted(), (aggregatePath, columnInfo) -> {
721728

722-
return new Join( //
723-
currentTable, //
724-
joinCondition[0] //
725-
);
729+
return currentTable.column(columnInfo.name())
730+
.isEqualTo(parentTable.column(idColumnInfos.get(aggregatePath).name()));
731+
}, Condition::and);
726732

733+
return new Join(currentTable, joinCondition);
727734
}
728735

729736
private String createFindAllInListSql() {
@@ -862,6 +869,8 @@ private String createDeleteByPathAndCriteria(AggregatePath path,
862869

863870
Map<AggregatePath, Column> columns = new TreeMap<>();
864871
AggregatePath.ColumnInfos columnInfos = path.getTableInfo().backReferenceColumnInfos();
872+
873+
// TODO: cannot we simply pass on the columnInfos?
865874
columnInfos.forEach((ag, ci) -> columns.put(ag, table.column(ci.name())));
866875

867876
if (isFirstNonRoot(path)) {
@@ -915,17 +924,20 @@ private Table getTable() {
915924
*/
916925
private Column getSingleNonNullColumn() {
917926

927+
// getColumn() is slightly different from the code in any(…). Why?
928+
// AggregatePath.ColumnInfo columnInfo = path.getColumnInfo();
929+
// return getTable(path).column(columnInfo.name()).as(columnInfo.alias());
930+
918931
AggregatePath.ColumnInfos columnInfos = mappingContext.getAggregatePath(entity).getTableInfo().idColumnInfos();
919932
return columnInfos.any((ap, ci) -> sqlContext.getTable(columnInfos.fullPath(ap)).column(ci.name()).as(ci.alias()));
920933
}
921934

922935
private List<Column> getIdColumns() {
923936

924937
AggregatePath.ColumnInfos columnInfos = mappingContext.getAggregatePath(entity).getTableInfo().idColumnInfos();
925-
List<Column> result = new ArrayList<>(columnInfos.size());
926-
columnInfos.forEach((ap, ci) -> result.add(sqlContext.getColumn(columnInfos.fullPath(ap))));
927938

928-
return result;
939+
return columnInfos
940+
.toColumnList((aggregatePath, columnInfo) -> sqlContext.getColumn(columnInfos.fullPath(aggregatePath)));
929941
}
930942

931943
private Column getVersionColumn() {

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlParametersFactory.java

Lines changed: 56 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
import java.sql.SQLType;
1919
import java.util.ArrayList;
20+
import java.util.Collection;
2021
import java.util.List;
2122
import java.util.Map;
2223
import java.util.function.BiFunction;
23-
import java.util.function.Function;
2424
import java.util.function.Predicate;
2525

2626
import org.springframework.data.jdbc.core.mapping.JdbcValue;
@@ -34,9 +34,7 @@
3434
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
3535
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
3636
import org.springframework.data.relational.core.sql.SqlIdentifier;
37-
import org.springframework.jdbc.support.JdbcUtils;
3837
import org.springframework.lang.Nullable;
39-
import org.springframework.util.Assert;
4038

4139
/**
4240
* Creates the {@link SqlIdentifierParameterSource} for various SQL operations, dialect identifier processing rules and
@@ -45,9 +43,11 @@
4543
* @author Jens Schauder
4644
* @author Chirag Tailor
4745
* @author Mikhail Polivakha
46+
* @author Mark Paluch
4847
* @since 2.4
4948
*/
5049
public class SqlParametersFactory {
50+
5151
private final RelationalMappingContext context;
5252
private final JdbcConverter converter;
5353

@@ -119,24 +119,20 @@ <T> SqlIdentifierParameterSource forUpdate(T instance, Class<T> domainType) {
119119
*/
120120
<T> SqlIdentifierParameterSource forQueryById(Object id, Class<T> domainType) {
121121

122-
SqlIdentifierParameterSource parameterSource = new SqlIdentifierParameterSource();
123-
124-
RelationalPersistentEntity<T> entity = getRequiredPersistentEntity(domainType);
125-
RelationalPersistentProperty singleIdProperty = entity.getRequiredIdProperty();
122+
return doWithIdentifiers(domainType, (columns, idProperty, complexId) -> {
126123

127-
RelationalPersistentEntity<?> complexId = context.getPersistentEntity(singleIdProperty);
124+
SqlIdentifierParameterSource parameterSource = new SqlIdentifierParameterSource();
125+
BiFunction<Object, AggregatePath, Object> valueExtractor = getIdMapper(complexId);
128126

129-
Function<AggregatePath, Object> valueExtractor = complexId == null ? ap -> id
130-
: ap -> complexId.getPropertyPathAccessor(id).getProperty(ap.getRequiredPersistentPropertyPath());
127+
columns.forEach((ap, ci) -> addConvertedPropertyValue( //
128+
parameterSource, //
129+
ap.getRequiredLeafProperty(), //
130+
valueExtractor.apply(id, ap), //
131+
ci.name() //
132+
));
131133

132-
context.getAggregatePath(entity).getTableInfo().idColumnInfos() //
133-
.forEach((ap, ci) -> addConvertedPropertyValue( //
134-
parameterSource, //
135-
ap.getRequiredLeafProperty(), //
136-
valueExtractor.apply(ap), //
137-
ci.name() //
138-
));
139-
return parameterSource;
134+
return parameterSource;
135+
});
140136
}
141137

142138
/**
@@ -149,29 +145,44 @@ <T> SqlIdentifierParameterSource forQueryById(Object id, Class<T> domainType) {
149145
*/
150146
<T> SqlIdentifierParameterSource forQueryByIds(Iterable<?> ids, Class<T> domainType) {
151147

152-
SqlIdentifierParameterSource parameterSource = new SqlIdentifierParameterSource();
148+
return doWithIdentifiers(domainType, (columns, idProperty, complexId) -> {
153149

154-
RelationalPersistentEntity<?> entity = context.getRequiredPersistentEntity(domainType);
155-
RelationalPersistentProperty singleIdProperty = entity.getRequiredIdProperty();
156-
RelationalPersistentEntity<?> complexId = context.getPersistentEntity(singleIdProperty);
157-
AggregatePath.ColumnInfos idColumnInfos = context.getAggregatePath(entity).getTableInfo().idColumnInfos();
150+
SqlIdentifierParameterSource parameterSource = new SqlIdentifierParameterSource();
158151

159-
BiFunction<Object, AggregatePath, Object> valueExtractor = complexId == null ? (id, ap) -> id
160-
: (id, ap) -> complexId.getPropertyPathAccessor(id).getProperty(ap.getRequiredPersistentPropertyPath());
152+
BiFunction<Object, AggregatePath, Object> valueExtractor = getIdMapper(complexId);
161153

162-
List<Object[]> parameterValues = new ArrayList<>();
163-
for (Object id : ids) {
154+
List<Object[]> parameterValues = new ArrayList<>(ids instanceof Collection<?> c ? c.size() : 16);
155+
for (Object id : ids) {
164156

165-
List<Object> tupleList = new ArrayList<>();
166-
idColumnInfos.forEach((ap, ci) -> {
167-
tupleList.add(valueExtractor.apply(id, ap));
168-
});
169-
parameterValues.add(tupleList.toArray(new Object[0]));
170-
}
157+
Object[] tupleList = new Object[columns.size()];
171158

172-
parameterSource.addValue(SqlGenerator.IDS_SQL_PARAMETER, parameterValues);
159+
int i = 0;
160+
for (AggregatePath path : columns.paths()) {
161+
tupleList[i++] = valueExtractor.apply(id, path);
162+
}
173163

174-
return parameterSource;
164+
parameterValues.add(tupleList);
165+
}
166+
167+
parameterSource.addValue(SqlGenerator.IDS_SQL_PARAMETER, parameterValues);
168+
return parameterSource;
169+
});
170+
}
171+
172+
private <T> T doWithIdentifiers(Class<?> domainType, IdentifierCallback<T> callback) {
173+
174+
RelationalPersistentEntity<?> entity = context.getRequiredPersistentEntity(domainType);
175+
RelationalPersistentProperty idProperty = entity.getRequiredIdProperty();
176+
RelationalPersistentEntity<?> complexId = context.getPersistentEntity(idProperty);
177+
AggregatePath.ColumnInfos columns = context.getAggregatePath(entity).getTableInfo().idColumnInfos();
178+
179+
return callback.doWithIdentifiers(columns, idProperty, complexId);
180+
}
181+
182+
interface IdentifierCallback<T> {
183+
184+
T doWithIdentifiers(AggregatePath.ColumnInfos columns, RelationalPersistentProperty idProperty,
185+
RelationalPersistentEntity<?> complexId);
175186
}
176187

177188
/**
@@ -191,6 +202,16 @@ SqlIdentifierParameterSource forQueryByIdentifier(Identifier identifier) {
191202
return parameterSource;
192203
}
193204

205+
private BiFunction<Object, AggregatePath, Object> getIdMapper(@Nullable RelationalPersistentEntity<?> complexId) {
206+
207+
if (complexId == null) {
208+
return (id, aggregatePath) -> id;
209+
}
210+
211+
return (id, aggregatePath) -> complexId.getPropertyPathAccessor(id)
212+
.getProperty(aggregatePath.getRequiredPersistentPropertyPath());
213+
}
214+
194215
private void addConvertedPropertyValue(SqlIdentifierParameterSource parameterSource,
195216
RelationalPersistentProperty property, @Nullable Object value, SqlIdentifier name) {
196217

@@ -219,28 +240,6 @@ private void addConvertedValue(SqlIdentifierParameterSource parameterSource, @Nu
219240
jdbcValue.getJdbcType().getVendorTypeNumber());
220241
}
221242

222-
private void addConvertedPropertyValuesAsList(SqlIdentifierParameterSource parameterSource,
223-
RelationalPersistentProperty property, Iterable<?> values) {
224-
225-
List<Object> convertedIds = new ArrayList<>();
226-
JdbcValue jdbcValue = null;
227-
for (Object id : values) {
228-
229-
Class<?> columnType = converter.getColumnType(property);
230-
SQLType sqlType = converter.getTargetSqlType(property);
231-
232-
jdbcValue = converter.writeJdbcValue(id, columnType, sqlType);
233-
convertedIds.add(jdbcValue.getValue());
234-
}
235-
236-
Assert.state(jdbcValue != null, "JdbcValue must be not null at this point; Please report this as a bug");
237-
238-
SQLType jdbcType = jdbcValue.getJdbcType();
239-
int typeNumber = jdbcType == null ? JdbcUtils.TYPE_UNKNOWN : jdbcType.getVendorTypeNumber();
240-
241-
parameterSource.addValue(SqlGenerator.IDS_SQL_PARAMETER, convertedIds, typeNumber);
242-
}
243-
244243
@SuppressWarnings("unchecked")
245244
private <S> RelationalPersistentEntity<S> getRequiredPersistentEntity(Class<S> domainType) {
246245
return (RelationalPersistentEntity<S>) context.getRequiredPersistentEntity(domainType);

0 commit comments

Comments
 (0)