Skip to content

Commit 643aee4

Browse files
committed
Create DbAction.InsertBatch instead of individual inserts when mapping collections to db actions.
+ InsertBatch is a simple container of Insert so that existing insert logic can largely be reused. + Create a new class to encapsulate the information about an entity to be inserted. + Integration tests passing for all stores other than DB2 and MSSQL
1 parent 9f36815 commit 643aee4

25 files changed

+803
-60
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/AggregateChangeExecutor.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ private void execute(DbAction<?> action, JdbcAggregateChangeExecutionContext exe
6666
executionContext.executeInsertRoot((DbAction.InsertRoot<?>) action);
6767
} else if (action instanceof DbAction.Insert) {
6868
executionContext.executeInsert((DbAction.Insert<?>) action);
69+
} else if (action instanceof DbAction.InsertBatch) {
70+
executionContext.executeInsertBatch((DbAction.InsertBatch<?>) action);
6971
} else if (action instanceof DbAction.UpdateRoot) {
7072
executionContext.executeUpdateRoot((DbAction.UpdateRoot<?>) action);
7173
} else if (action instanceof DbAction.Update) {

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutionContext.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@
2525
import java.util.Map;
2626
import java.util.Set;
2727
import java.util.function.BiConsumer;
28+
import java.util.stream.Collectors;
2829

2930
import org.springframework.dao.IncorrectUpdateSemanticsDataAccessException;
3031
import org.springframework.dao.OptimisticLockingFailureException;
3132
import org.springframework.data.jdbc.core.convert.DataAccessStrategy;
3233
import org.springframework.data.jdbc.core.convert.Identifier;
3334
import org.springframework.data.jdbc.core.convert.JdbcConverter;
3435
import org.springframework.data.jdbc.core.convert.JdbcIdentifierBuilder;
36+
import org.springframework.data.jdbc.core.convert.RecordDescriptor;
3537
import org.springframework.data.mapping.PersistentProperty;
3638
import org.springframework.data.mapping.PersistentPropertyAccessor;
3739
import org.springframework.data.mapping.PersistentPropertyPath;
@@ -104,6 +106,20 @@ <T> void executeInsert(DbAction.Insert<T> insert) {
104106
add(new DbActionExecutionResult(insert, id));
105107
}
106108

109+
<T> void executeInsertBatch(DbAction.InsertBatch<T> insertBatch) {
110+
111+
List<DbAction.Insert<T>> inserts = insertBatch.getInserts();
112+
List<RecordDescriptor<T>> recordDescriptors = inserts.stream()
113+
.map(insert -> RecordDescriptor.of(insert.getEntity(), getParentKeys(insert, converter)))
114+
.collect(Collectors.toList());
115+
116+
Object[] ids = accessStrategy.insert(recordDescriptors, insertBatch.getEntityType());
117+
118+
for (int i = 0; i < inserts.size(); i++) {
119+
add(new DbActionExecutionResult(inserts.get(i), ids.length > 0 ? ids[i] : null));
120+
}
121+
}
122+
107123
<T> void executeUpdateRoot(DbAction.UpdateRoot<T> update) {
108124

109125
RelationalPersistentEntity<T> persistentEntity = getRequiredPersistentEntity(update.getEntityType());
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package org.springframework.data.jdbc.core.convert;
2+
3+
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
4+
import org.springframework.jdbc.core.ColumnMapRowMapper;
5+
import org.springframework.jdbc.core.JdbcOperations;
6+
import org.springframework.jdbc.core.PreparedStatementCallback;
7+
import org.springframework.jdbc.core.PreparedStatementCreator;
8+
import org.springframework.jdbc.core.PreparedStatementCreatorFactory;
9+
import org.springframework.jdbc.core.RowMapperResultSetExtractor;
10+
import org.springframework.jdbc.core.SqlParameter;
11+
import org.springframework.jdbc.core.namedparam.NamedParameterUtils;
12+
import org.springframework.jdbc.core.namedparam.ParsedSql;
13+
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
14+
import org.springframework.jdbc.support.GeneratedKeyHolder;
15+
import org.springframework.jdbc.support.JdbcUtils;
16+
import org.springframework.jdbc.support.KeyHolder;
17+
import org.springframework.lang.Nullable;
18+
19+
import java.sql.PreparedStatement;
20+
import java.sql.ResultSet;
21+
import java.sql.SQLException;
22+
import java.util.List;
23+
import java.util.Map;
24+
25+
class BatchJdbcOperations {
26+
private final JdbcOperations jdbcOperations;
27+
28+
BatchJdbcOperations(JdbcOperations jdbcOperations) {
29+
this.jdbcOperations = jdbcOperations;
30+
}
31+
32+
@Nullable
33+
int[] insert(String insertSql, SqlParameterSource[] sqlParameterSources,
34+
GeneratedKeyHolder generatedKeyHolder) {
35+
// TODO: This is largely duplicated from spring-jdbc and should be replaced with a call into
36+
// NamedParameterJdbcTemplate#batchUpdate once a method taking KeyHolder is added there.
37+
ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(insertSql);
38+
SqlParameterSource paramSource = sqlParameterSources[0];
39+
String sqlToUse = NamedParameterUtils.substituteNamedParameters(parsedSql, paramSource);
40+
List<SqlParameter> declaredParameters = NamedParameterUtils.buildSqlParameterList(parsedSql, paramSource);
41+
PreparedStatementCreatorFactory pscf = new PreparedStatementCreatorFactory(sqlToUse, declaredParameters);
42+
pscf.setReturnGeneratedKeys(true);
43+
Object[] params = NamedParameterUtils.buildValueArray(parsedSql, paramSource, null);
44+
PreparedStatementCreator psc = pscf.newPreparedStatementCreator(params);
45+
BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() {
46+
@Override
47+
public void setValues(PreparedStatement ps, int i) throws SQLException {
48+
Object[] values = NamedParameterUtils.buildValueArray(parsedSql, sqlParameterSources[i], null);
49+
pscf.newPreparedStatementSetter(values).setValues(ps);
50+
}
51+
52+
@Override
53+
public int getBatchSize() {
54+
return sqlParameterSources.length;
55+
}
56+
};
57+
PreparedStatementCallback<int[]> preparedStatementCallback = ps -> {
58+
int batchSize = bpss.getBatchSize();
59+
for (int i = 0; i < batchSize; i++) {
60+
bpss.setValues(ps, i);
61+
ps.addBatch();
62+
}
63+
int[] results = ps.executeBatch();
64+
List<Map<String, Object>> generatedKeys = ((KeyHolder) generatedKeyHolder).getKeyList();
65+
generatedKeys.clear();
66+
ResultSet keys = ps.getGeneratedKeys();
67+
if (keys != null) {
68+
try {
69+
RowMapperResultSetExtractor<Map<String, Object>> rse =
70+
new RowMapperResultSetExtractor<>(new ColumnMapRowMapper(), 1);
71+
//noinspection ConstantConditions
72+
generatedKeys.addAll(rse.extractData(keys));
73+
} finally {
74+
JdbcUtils.closeResultSet(keys);
75+
}
76+
}
77+
return results;
78+
};
79+
return jdbcOperations.execute(psc, preparedStatementCallback);
80+
}
81+
}

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ public <T> Object insert(T instance, Class<T> domainType, Identifier identifier)
5454
return collect(das -> das.insert(instance, domainType, identifier));
5555
}
5656

57+
@Override
58+
public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType) {
59+
return collect(das -> das.insert(recordDescriptors, domainType));
60+
}
61+
5762
/*
5863
* (non-Javadoc)
5964
* @see org.springframework.data.jdbc.core.DataAccessStrategy#update(java.lang.Object, java.lang.Class)

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package org.springframework.data.jdbc.core.convert;
1717

18+
import java.util.List;
1819
import java.util.Map;
1920

2021
import org.springframework.dao.OptimisticLockingFailureException;
@@ -53,6 +54,8 @@ public interface DataAccessStrategy extends RelationResolver {
5354
@Nullable
5455
<T> Object insert(T instance, Class<T> domainType, Identifier identifier);
5556

57+
<T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType);
58+
5659
/**
5760
* Updates the data of a single entity in the database. Referenced entities don't get handled.
5861
*

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import java.sql.SQLType;
2222
import java.util.ArrayList;
2323
import java.util.Collections;
24-
import java.util.HashSet;
2524
import java.util.List;
2625
import java.util.Map;
2726
import java.util.function.Predicate;
@@ -76,6 +75,7 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
7675
private final RelationalMappingContext context;
7776
private final JdbcConverter converter;
7877
private final NamedParameterJdbcOperations operations;
78+
private final BatchJdbcOperations batchOperations;
7979

8080
/**
8181
* Creates a {@link DefaultDataAccessStrategy}
@@ -88,6 +88,12 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
8888
*/
8989
public DefaultDataAccessStrategy(SqlGeneratorSource sqlGeneratorSource, RelationalMappingContext context,
9090
JdbcConverter converter, NamedParameterJdbcOperations operations) {
91+
this(sqlGeneratorSource, context, converter, operations, new BatchJdbcOperations(operations.getJdbcOperations()));
92+
}
93+
94+
DefaultDataAccessStrategy(SqlGeneratorSource sqlGeneratorSource, RelationalMappingContext context,
95+
JdbcConverter converter, NamedParameterJdbcOperations operations,
96+
BatchJdbcOperations batchOperations) {
9197

9298
Assert.notNull(sqlGeneratorSource, "SqlGeneratorSource must not be null");
9399
Assert.notNull(context, "RelationalMappingContext must not be null");
@@ -98,6 +104,7 @@ public DefaultDataAccessStrategy(SqlGeneratorSource sqlGeneratorSource, Relation
98104
this.context = context;
99105
this.converter = converter;
100106
this.operations = operations;
107+
this.batchOperations = batchOperations;
101108
}
102109

103110
/*
@@ -122,27 +129,74 @@ public <T> Object insert(T instance, Class<T> domainType, Identifier identifier)
122129
addConvertedPropertyValue(parameterSource, idProperty, idValue, idProperty.getColumnName());
123130
}
124131

125-
String insertSql = sqlGenerator.getInsert(new HashSet<>(parameterSource.getIdentifiers()));
132+
String insertSql = sqlGenerator.getInsert(parameterSource.getIdentifiers());
126133

127134
if (idValue == null) {
128-
return executeInsertAndReturnGeneratedId(domainType, persistentEntity, parameterSource, insertSql);
135+
return executeInsertAndReturnGeneratedId(persistentEntity, parameterSource, insertSql);
129136
} else {
130137

131138
operations.update(insertSql, parameterSource);
132139
return null;
133140
}
134141
}
135142

143+
@Override
144+
public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType) {
145+
146+
SqlIdentifierParameterSource[] sqlParameterSources = recordDescriptors.stream()
147+
.map(recordDescriptor -> getParameterSource(recordDescriptor, getRequiredPersistentEntity(domainType)))
148+
.toArray(SqlIdentifierParameterSource[]::new);
149+
150+
String insertSql = sql(domainType).getInsert(sqlParameterSources[0].getIdentifiers());
151+
152+
GeneratedKeyHolder generatedKeyHolder = new GeneratedKeyHolder();
153+
batchOperations.insert(insertSql, sqlParameterSources, generatedKeyHolder);
154+
RelationalPersistentEntity<T> persistentEntity = getRequiredPersistentEntity(domainType);
155+
if (!persistentEntity.hasIdProperty()) {
156+
return new Object[recordDescriptors.size()];
157+
}
158+
// TODO: Duplicated in #getIdFromHolder - consider refactoring
159+
return generatedKeyHolder.getKeyList().stream() //
160+
.map(keys -> {
161+
if (!persistentEntity.hasIdProperty()) {
162+
return null;
163+
}
164+
if (keys.size() > 1) {
165+
return keys.get(persistentEntity.getIdColumn().getReference(getIdentifierProcessing()));
166+
} else {
167+
return keys.entrySet().stream().findFirst() //
168+
.map(Map.Entry::getValue) //
169+
.orElse(null);
170+
}
171+
}) //
172+
.toArray();
173+
}
174+
175+
private <T> SqlIdentifierParameterSource getParameterSource(RecordDescriptor<T> recordDescriptor, RelationalPersistentEntity<T> persistentEntity) {
176+
SqlIdentifierParameterSource parameterSource = getParameterSource(recordDescriptor.getInstance(), persistentEntity, "",
177+
PersistentProperty::isIdProperty, getIdentifierProcessing());
178+
179+
recordDescriptor.getIdentifier().forEach((name, value, type) -> addConvertedPropertyValue(parameterSource, name, value, type));
180+
181+
Object idValue = getIdValueOrNull(recordDescriptor.getInstance(), persistentEntity);
182+
if (idValue != null) {
183+
184+
RelationalPersistentProperty idProperty = persistentEntity.getRequiredIdProperty();
185+
addConvertedPropertyValue(parameterSource, idProperty, idValue, idProperty.getColumnName());
186+
}
187+
return parameterSource;
188+
}
189+
136190
@Nullable
137-
private <T> Object executeInsertAndReturnGeneratedId(Class<T> domainType, RelationalPersistentEntity<T> persistentEntity, SqlIdentifierParameterSource parameterSource, String insertSql) {
191+
private <T> Object executeInsertAndReturnGeneratedId(RelationalPersistentEntity<T> persistentEntity, SqlIdentifierParameterSource parameterSource, String insertSql) {
138192

139193
KeyHolder holder = new GeneratedKeyHolder();
140194

141195
IdGeneration idGeneration = sqlGeneratorSource.getDialect().getIdGeneration();
142196

143197
if (idGeneration.driverRequiresKeyColumnNames()) {
144198

145-
String[] keyColumnNames = getKeyColumnNames(domainType);
199+
String[] keyColumnNames = getKeyColumnNames(persistentEntity.getType());
146200
if (keyColumnNames.length == 0) {
147201
operations.update(insertSql, parameterSource, holder);
148202
} else {

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import org.springframework.data.relational.core.sql.LockMode;
2323
import org.springframework.util.Assert;
2424

25+
import java.util.List;
26+
2527
/**
2628
* Delegates all method calls to an instance set after construction. This is useful for {@link DataAccessStrategy}s with
2729
* cyclic dependencies.
@@ -45,6 +47,11 @@ public <T> Object insert(T instance, Class<T> domainType, Identifier identifier)
4547
return delegate.insert(instance, domainType, identifier);
4648
}
4749

50+
@Override
51+
public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType) {
52+
return delegate.insert(recordDescriptors, domainType);
53+
}
54+
4855
/*
4956
* (non-Javadoc)
5057
* @see org.springframework.data.jdbc.core.DataAccessStrategy#update(java.lang.Object, java.lang.Class)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package org.springframework.data.jdbc.core.convert;
2+
3+
import java.util.Objects;
4+
5+
public final class RecordDescriptor<T> {
6+
private final T instance;
7+
private final Identifier identifier;
8+
9+
public static <T> RecordDescriptor<T> of(T instance, Identifier identifier) {
10+
return new RecordDescriptor<>(instance, identifier);
11+
}
12+
13+
private RecordDescriptor(T instance, Identifier identifier) {
14+
this.instance = instance;
15+
this.identifier = identifier;
16+
}
17+
18+
public T getInstance() {
19+
return instance;
20+
}
21+
22+
public Identifier getIdentifier() {
23+
return identifier;
24+
}
25+
26+
@Override
27+
public boolean equals(Object o) {
28+
if (this == o)
29+
return true;
30+
if (o == null || getClass() != o.getClass())
31+
return false;
32+
RecordDescriptor<?> that = (RecordDescriptor<?>) o;
33+
return Objects.equals(instance, that.instance) && Objects.equals(identifier, that.identifier);
34+
}
35+
36+
@Override
37+
public int hashCode() {
38+
return Objects.hash(instance, identifier);
39+
}
40+
41+
@Override
42+
public String toString() {
43+
return "RecordDescriptor{" + "instance=" + instance + ", identifier=" + identifier + '}';
44+
}
45+
}

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlGenerator.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,12 +579,16 @@ private String createInsertSql(Set<SqlIdentifier> additionalColumns) {
579579
insert = insert.column(table.column(cn));
580580
}
581581

582+
if (columnNamesForInsert.isEmpty()) {
583+
return render(insert.build());
584+
}
585+
582586
InsertBuilder.InsertValuesWithBuild insertWithValues = null;
583587
for (SqlIdentifier cn : columnNamesForInsert) {
584588
insertWithValues = (insertWithValues == null ? insert : insertWithValues).values(getBindMarker(cn));
585589
}
586590

587-
return render(insertWithValues == null ? insert.build() : insertWithValues.build());
591+
return render(insertWithValues.build());
588592
}
589593

590594
private String createUpdateSql() {

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import java.util.Collections;
2121
import java.util.HashMap;
22+
import java.util.List;
2223
import java.util.Map;
2324
import java.util.stream.Collectors;
2425

@@ -35,6 +36,7 @@
3536
import org.springframework.data.jdbc.core.convert.DelegatingDataAccessStrategy;
3637
import org.springframework.data.jdbc.core.convert.Identifier;
3738
import org.springframework.data.jdbc.core.convert.JdbcConverter;
39+
import org.springframework.data.jdbc.core.convert.RecordDescriptor;
3840
import org.springframework.data.jdbc.core.convert.SqlGeneratorSource;
3941
import org.springframework.data.mapping.PersistentPropertyPath;
4042
import org.springframework.data.mapping.PropertyPath;
@@ -159,6 +161,12 @@ public <T> Object insert(T instance, Class<T> domainType, Identifier identifier)
159161
return myBatisContext.getId();
160162
}
161163

164+
@Override
165+
public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType) {
166+
// TODO: Figure out what this should look like
167+
throw new UnsupportedOperationException();
168+
}
169+
162170
/*
163171
* (non-Javadoc)
164172
* @see org.springframework.data.jdbc.core.DataAccessStrategy#update(java.lang.Object, java.lang.Class)

0 commit comments

Comments
 (0)