Skip to content

Commit 20d2716

Browse files
committed
Add includeId parameter to batch insert method and populate it appropriately when the InsertBatch is created.
+ Copied over a lot of the insert logic around keyColumnNames and added method to batchJdbcOperations taking keyColumnNames
1 parent 643aee4 commit 20d2716

File tree

12 files changed

+162
-54
lines changed

12 files changed

+162
-54
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutionContext.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ <T> void executeInsertBatch(DbAction.InsertBatch<T> insertBatch) {
113113
.map(insert -> RecordDescriptor.of(insert.getEntity(), getParentKeys(insert, converter)))
114114
.collect(Collectors.toList());
115115

116-
Object[] ids = accessStrategy.insert(recordDescriptors, insertBatch.getEntityType());
116+
Object[] ids = accessStrategy.insert(recordDescriptors, insertBatch.getEntityType(), insertBatch.getIncludeId());
117117

118118
for (int i = 0; i < inserts.size(); i++) {
119119
add(new DbActionExecutionResult(inserts.get(i), ids.length > 0 ? ids[i] : null));

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/BatchJdbcOperations.java

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,55 @@ public int getBatchSize() {
7878
};
7979
return jdbcOperations.execute(psc, preparedStatementCallback);
8080
}
81+
82+
@Nullable
83+
public int[] insert(String insertSql, SqlIdentifierParameterSource[] sqlParameterSources,
84+
GeneratedKeyHolder generatedKeyHolder, String[] keyColumnNames) {
85+
// TODO: This is largely duplicated from spring-jdbc and should be replaced with a call into
86+
// NamedParameterJdbcTemplate#batchUpdate once a method taking KeyHolder is added there.
87+
ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(insertSql);
88+
SqlParameterSource paramSource = sqlParameterSources[0];
89+
String sqlToUse = NamedParameterUtils.substituteNamedParameters(parsedSql, paramSource);
90+
List<SqlParameter> declaredParameters = NamedParameterUtils.buildSqlParameterList(parsedSql, paramSource);
91+
PreparedStatementCreatorFactory pscf = new PreparedStatementCreatorFactory(sqlToUse, declaredParameters);
92+
pscf.setReturnGeneratedKeys(true);
93+
pscf.setGeneratedKeysColumnNames(keyColumnNames);
94+
Object[] params = NamedParameterUtils.buildValueArray(parsedSql, paramSource, null);
95+
PreparedStatementCreator psc = pscf.newPreparedStatementCreator(params);
96+
BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() {
97+
@Override
98+
public void setValues(PreparedStatement ps, int i) throws SQLException {
99+
Object[] values = NamedParameterUtils.buildValueArray(parsedSql, sqlParameterSources[i], null);
100+
pscf.newPreparedStatementSetter(values).setValues(ps);
101+
}
102+
103+
@Override
104+
public int getBatchSize() {
105+
return sqlParameterSources.length;
106+
}
107+
};
108+
PreparedStatementCallback<int[]> preparedStatementCallback = ps -> {
109+
int batchSize = bpss.getBatchSize();
110+
for (int i = 0; i < batchSize; i++) {
111+
bpss.setValues(ps, i);
112+
ps.addBatch();
113+
}
114+
int[] results = ps.executeBatch();
115+
List<Map<String, Object>> generatedKeys = ((KeyHolder) generatedKeyHolder).getKeyList();
116+
generatedKeys.clear();
117+
ResultSet keys = ps.getGeneratedKeys();
118+
if (keys != null) {
119+
try {
120+
RowMapperResultSetExtractor<Map<String, Object>> rse =
121+
new RowMapperResultSetExtractor<>(new ColumnMapRowMapper(), 1);
122+
//noinspection ConstantConditions
123+
generatedKeys.addAll(rse.extractData(keys));
124+
} finally {
125+
JdbcUtils.closeResultSet(keys);
126+
}
127+
}
128+
return results;
129+
};
130+
return jdbcOperations.execute(psc, preparedStatementCallback);
131+
}
81132
}

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/CascadingDataAccessStrategy.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ public <T> Object insert(T instance, Class<T> domainType, Identifier identifier)
5555
}
5656

5757
@Override
58-
public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType) {
59-
return collect(das -> das.insert(recordDescriptors, domainType));
58+
public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType, boolean includeId) {
59+
return collect(das -> das.insert(recordDescriptors, domainType, includeId));
6060
}
6161

6262
/*

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public interface DataAccessStrategy extends RelationResolver {
5454
@Nullable
5555
<T> Object insert(T instance, Class<T> domainType, Identifier identifier);
5656

57-
<T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType);
57+
<T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType, boolean includeId);
5858

5959
/**
6060
* Updates the data of a single entity in the database. Referenced entities don't get handled.

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -141,35 +141,54 @@ public <T> Object insert(T instance, Class<T> domainType, Identifier identifier)
141141
}
142142

143143
@Override
144-
public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType) {
144+
public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType, boolean includeId) {
145145

146+
RelationalPersistentEntity<T> persistentEntity = getRequiredPersistentEntity(domainType);
146147
SqlIdentifierParameterSource[] sqlParameterSources = recordDescriptors.stream()
147-
.map(recordDescriptor -> getParameterSource(recordDescriptor, getRequiredPersistentEntity(domainType)))
148+
.map(recordDescriptor -> getParameterSource(recordDescriptor, persistentEntity))
148149
.toArray(SqlIdentifierParameterSource[]::new);
149150

150151
String insertSql = sql(domainType).getInsert(sqlParameterSources[0].getIdentifiers());
151152

152-
GeneratedKeyHolder generatedKeyHolder = new GeneratedKeyHolder();
153-
batchOperations.insert(insertSql, sqlParameterSources, generatedKeyHolder);
154-
RelationalPersistentEntity<T> persistentEntity = getRequiredPersistentEntity(domainType);
153+
if (includeId) {
154+
operations.batchUpdate(insertSql, sqlParameterSources);
155+
return new Object[sqlParameterSources.length];
156+
}
157+
GeneratedKeyHolder holder = new GeneratedKeyHolder();
158+
159+
IdGeneration idGeneration = sqlGeneratorSource.getDialect().getIdGeneration();
160+
161+
if (idGeneration.driverRequiresKeyColumnNames()) {
162+
163+
String[] keyColumnNames = getKeyColumnNames(persistentEntity.getType());
164+
if (keyColumnNames.length == 0) {
165+
batchOperations.insert(insertSql, sqlParameterSources, holder);
166+
} else {
167+
batchOperations.insert(insertSql, sqlParameterSources, holder, keyColumnNames);
168+
}
169+
} else {
170+
batchOperations.insert(insertSql, sqlParameterSources, holder);
171+
}
155172
if (!persistentEntity.hasIdProperty()) {
156-
return new Object[recordDescriptors.size()];
173+
return new Object[sqlParameterSources.length];
157174
}
158175
// TODO: Duplicated in #getIdFromHolder - consider refactoring
159-
return generatedKeyHolder.getKeyList().stream() //
160-
.map(keys -> {
161-
if (!persistentEntity.hasIdProperty()) {
162-
return null;
163-
}
164-
if (keys.size() > 1) {
165-
return keys.get(persistentEntity.getIdColumn().getReference(getIdentifierProcessing()));
166-
} else {
167-
return keys.entrySet().stream().findFirst() //
168-
.map(Map.Entry::getValue) //
169-
.orElse(null);
170-
}
171-
}) //
172-
.toArray();
176+
Object[] ids = new Object[sqlParameterSources.length];
177+
List<Map<String, Object>> keyList = holder.getKeyList();
178+
for (int i = 0; i < keyList.size(); i++) {
179+
Map<String, Object> keys = keyList.get(i);
180+
Object id;
181+
if (keys.size() > 1) {
182+
id = keys.get(persistentEntity.getIdColumn().getReference(getIdentifierProcessing()));
183+
} else {
184+
id = keys.entrySet().stream().findFirst() //
185+
.map(Map.Entry::getValue) //
186+
// TODO: Missing a test for this
187+
.orElse(null);
188+
}
189+
ids[i] = id;
190+
}
191+
return ids;
173192
}
174193

175194
private <T> SqlIdentifierParameterSource getParameterSource(RecordDescriptor<T> recordDescriptor, RelationalPersistentEntity<T> persistentEntity) {

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DelegatingDataAccessStrategy.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ public <T> Object insert(T instance, Class<T> domainType, Identifier identifier)
4848
}
4949

5050
@Override
51-
public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType) {
52-
return delegate.insert(recordDescriptors, domainType);
51+
public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType, boolean includeId) {
52+
return delegate.insert(recordDescriptors, domainType, includeId);
5353
}
5454

5555
/*

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ public <T> Object insert(T instance, Class<T> domainType, Identifier identifier)
162162
}
163163

164164
@Override
165-
public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType) {
165+
public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T> domainType, boolean includeId) {
166166
// TODO: Figure out what this should look like
167167
throw new UnsupportedOperationException();
168168
}

spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutorContextUnitTests.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,11 @@ void batchInsertOperation_withGeneratedIds() {
147147
Identifier identifier = Identifier.empty()
148148
.withPart(SqlIdentifier.quoted("DUMMY_ENTITY"), 123L, Long.class)
149149
.withPart(SqlIdentifier.quoted("DUMMY_ENTITY_KEY"), 0, Integer.class);
150-
when(accessStrategy.insert(singletonList(RecordDescriptor.of(content, identifier)), Content.class))
150+
when(accessStrategy.insert(singletonList(RecordDescriptor.of(content, identifier)), Content.class, false))
151151
.thenReturn(new Object[] { 456L });
152152
DbAction.InsertBatch<?> insertBatch = new DbAction.InsertBatch<>(
153-
singletonList(createInsert(rootInsert, "list", content, 0))
153+
singletonList(createInsert(rootInsert, "list", content, 0)),
154+
false
154155
);
155156
executionContext.executeInsertBatch(insertBatch);
156157

@@ -172,10 +173,11 @@ void batchInsertOperation_withoutGeneratedIds() {
172173
Identifier identifier = Identifier.empty()
173174
.withPart(SqlIdentifier.quoted("DUMMY_ENTITY"), 123L, Long.class)
174175
.withPart(SqlIdentifier.quoted("DUMMY_ENTITY_KEY"), 0, Integer.class);
175-
when(accessStrategy.insert(singletonList(RecordDescriptor.of(content, identifier)), Content.class))
176-
.thenReturn(new Object[] {});
176+
when(accessStrategy.insert(singletonList(RecordDescriptor.of(content, identifier)), Content.class, true))
177+
.thenReturn(new Object[] { null });
177178
DbAction.InsertBatch<?> insertBatch = new DbAction.InsertBatch<>(
178-
singletonList(createInsert(rootInsert, "list", content, 0))
179+
singletonList(createInsert(rootInsert, "list", content, 0)),
180+
true
179181
);
180182
executionContext.executeInsertBatch(insertBatch);
181183

spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategyUnitTests.java

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -230,13 +230,13 @@ public void insertWithUndefinedIdRetrievesGeneratedKeys() {
230230
class BatchInsertTests {
231231

232232
@Test
233-
void batchInsertMapsParametersForEachRecord() {
233+
void batchInsertParametersWhenNotIncludeId() {
234234

235235
accessStrategy.insert(asList(
236236
RecordDescriptor.of(new DummyEntity(null), Identifier.of(SqlIdentifier.quoted("id"), 1L, Long.class)),
237237
RecordDescriptor.of(new DummyEntity(null), Identifier.of(SqlIdentifier.quoted("id"), 2L, Long.class))),
238-
DummyEntity.class
239-
);
238+
DummyEntity.class,
239+
false);
240240

241241
ArgumentCaptor<SqlParameterSource[]> sqlParametersCaptor = ArgumentCaptor.forClass(SqlParameterSource[].class);
242242
verify(batchJdbcOperations).insert(any(), sqlParametersCaptor.capture(), any());
@@ -246,20 +246,41 @@ void batchInsertMapsParametersForEachRecord() {
246246
assertThat(sqlParameterSources[1].getValue("id")).isEqualTo(2L);
247247
}
248248

249+
@Test
250+
void batchInsertParametersWhenIncludeId() {
251+
252+
Object[] ids = accessStrategy.insert(asList(
253+
RecordDescriptor.of(new DummyEntity(null), Identifier.of(quoted("id"), 1L, Long.class)),
254+
RecordDescriptor.of(new DummyEntity(null), Identifier.of(quoted("id"), 2L, Long.class))),
255+
DummyEntity.class,
256+
true);
257+
258+
assertThat(ids).hasSize(2);
259+
assertThat(ids).containsOnlyNulls();
260+
261+
verifyNoInteractions(batchJdbcOperations);
262+
ArgumentCaptor<SqlParameterSource[]> sqlParametersCaptor = ArgumentCaptor.forClass(SqlParameterSource[].class);
263+
verify(namedJdbcOperations).batchUpdate(any(), sqlParametersCaptor.capture());
264+
SqlParameterSource[] sqlParameterSources = sqlParametersCaptor.getValue();
265+
assertThat(sqlParameterSources).hasSize(2);
266+
assertThat(sqlParameterSources[0].getValue("id")).isEqualTo(1L);
267+
assertThat(sqlParameterSources[1].getValue("id")).isEqualTo(2L);
268+
}
269+
249270
@Test
250271
void batchInsertWithSingleKeyPerRecord() {
251272

252273
when(batchJdbcOperations.insert(any(), any(), any())).thenAnswer(invocationOnMock -> {
253274
KeyHolder keyHolder = invocationOnMock.getArgument(2);
254-
keyHolder.getKeyList().addAll(asList(singletonMap("id", "id1"), singletonMap("id", "id2")));
275+
keyHolder.getKeyList().addAll(singletonList(singletonMap("id", "id1")));
255276
return new int[]{};
256277
});
257278

258279
Object[] ids = accessStrategy.insert(
259280
singletonList(
260281
RecordDescriptor.of(new DummyEntity(null), Identifier.of(SqlIdentifier.quoted("id"), 1L, Long.class))),
261-
DummyEntity.class);
262-
assertThat(ids).containsExactly("id1", "id2");
282+
DummyEntity.class, false);
283+
assertThat(ids).containsExactly("id1");
263284
}
264285

265286
@Test
@@ -270,8 +291,9 @@ void batchInsertWithNoGeneratedKeys() {
270291
Object[] ids = accessStrategy.insert(
271292
singletonList(
272293
RecordDescriptor.of(new DummyEntity(null), Identifier.of(SqlIdentifier.quoted("id"), 1L, Long.class))),
273-
DummyEntity.class);
274-
assertThat(ids).isEmpty();
294+
DummyEntity.class, false);
295+
assertThat(ids).hasSize(1);
296+
assertThat(ids).containsOnlyNulls();
275297
}
276298

277299
@Test
@@ -289,7 +311,7 @@ void batchInsertWithMultipleKeysPerRecord_getsTheKeyForTheIdAnnotatedProperty()
289311
Object[] ids = accessStrategy.insert(
290312
singletonList(
291313
RecordDescriptor.of(new DummyEntity(null), Identifier.of(SqlIdentifier.quoted("id"), 1L, Long.class))),
292-
DummyEntity.class);
314+
DummyEntity.class, false);
293315
assertThat(ids).containsExactly("someId");
294316
}
295317

@@ -299,7 +321,7 @@ void batchInsertWhenEntityHasNoIdAnnotatedProperty() {
299321
when(batchJdbcOperations.insert(any(), any(), any())).thenReturn(new int[]{});
300322

301323
Object[] ids = accessStrategy.insert(singletonList(
302-
RecordDescriptor.of(new DummyEntityWithoutIdAnnotation(null), Identifier.empty())), DummyEntityWithoutIdAnnotation.class);
324+
RecordDescriptor.of(new DummyEntityWithoutIdAnnotation(null), Identifier.empty())), DummyEntityWithoutIdAnnotation.class, false);
303325
assertThat(ids).hasSize(1);
304326
assertThat(ids).containsOnlyNulls();
305327
}
@@ -318,7 +340,7 @@ void batchInsertWithSingleKeyPerRecord_whenKeyDoesNotMatchEntityIdAnnotatedPrope
318340
Object[] ids = accessStrategy.insert(
319341
singletonList(
320342
RecordDescriptor.of(new DummyEntity(null), Identifier.of(SqlIdentifier.quoted("id"), 1L, Long.class))),
321-
DummyEntity.class);
343+
DummyEntity.class, false);
322344
assertThat(ids).containsExactly("someId");
323345
}
324346

spring-data-relational/src/main/java/org/springframework/data/relational/core/conversion/DbAction.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,10 +375,12 @@ public String toString() {
375375

376376
final class InsertBatch<T> implements DbAction<T> {
377377
private final List<Insert<T>> inserts;
378+
private final Boolean includeId;
378379

379-
public InsertBatch(List<Insert<T>> inserts) {
380+
public InsertBatch(List<Insert<T>> inserts, Boolean includeId) {
380381
Assert.notEmpty(inserts, "Inserts must contains at least one insert");
381382
this.inserts = inserts;
383+
this.includeId = includeId;
382384
}
383385

384386
@Override
@@ -389,6 +391,10 @@ public Class<T> getEntityType() {
389391
public List<Insert<T>> getInserts() {
390392
return inserts;
391393
}
394+
395+
public Boolean getIncludeId() {
396+
return includeId;
397+
}
392398
}
393399

394400
/**

spring-data-relational/src/main/java/org/springframework/data/relational/core/conversion/WritingContext.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,12 @@ private List<? extends DbAction<?>> insertAll(PersistentPropertyPath<RelationalP
150150
.add(insert);
151151
previousActions.put(node, insert);
152152
});
153-
return insertsPartitionedByHasId.values().stream()
154-
.filter(batch -> (!batch.isEmpty()))
155-
.map(batch -> {
153+
return insertsPartitionedByHasId.entrySet().stream()
154+
.filter(entry -> (!entry.getValue().isEmpty()))
155+
.map(entry -> {
156+
List<DbAction.Insert<Object>> batch = entry.getValue();
156157
if (batch.size() > 1) {
157-
return new DbAction.InsertBatch<>(batch);
158+
return new DbAction.InsertBatch<>(batch, entry.getKey());
158159
}
159160
return batch.get(0);
160161
})

spring-data-relational/src/test/java/org/springframework/data/relational/core/conversion/RelationalEntityWriterUnitTests.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -675,19 +675,17 @@ void newEntityWithCollectionWhereSomeElementsHaveIdSet_producesABatchInsertEachF
675675
tuple(InsertBatch.class, Element.class, "", null, false), //
676676
tuple(InsertBatch.class, Element.class, "", null, false) //
677677
);
678-
List<InsertBatch<Element>> insertBatchActions = getInsertBatchActions(actions, Element.class);
679-
assertThat(insertBatchActions).hasSize(2);
680-
List<Insert<Element>> noIdBatchInserts = insertBatchActions.get(0).getInserts();
681-
assertThat(noIdBatchInserts).extracting(DbAction::getClass, //
678+
InsertBatch<Element> insertBatchWithoutId = getInsertBatchAction(actions, Element.class, false);
679+
assertThat(insertBatchWithoutId.getInserts()).extracting(DbAction::getClass, //
682680
DbAction::getEntityType, //
683681
this::getListKey, //
684682
DbActionTestSupport::extractPath) //
685683
.containsExactly( //
686684
tuple(Insert.class, Element.class, 0, "elements"), //
687685
tuple(Insert.class, Element.class, 2, "elements") //
688686
);
689-
List<Insert<Element>> idBatchInserts = insertBatchActions.get(1).getInserts();
690-
assertThat(idBatchInserts).extracting(DbAction::getClass, //
687+
InsertBatch<Element> insertBatchWithId = getInsertBatchAction(actions, Element.class, true);
688+
assertThat(insertBatchWithId.getInserts()).extracting(DbAction::getClass, //
691689
DbAction::getEntityType, //
692690
this::getListKey, //
693691
DbActionTestSupport::extractPath) //
@@ -711,6 +709,15 @@ private <T> InsertBatch<T> getInsertBatchAction(List<DbAction<?>> actions, Class
711709
.orElseThrow(() -> new RuntimeException("No InsertBatch action found!"));
712710
}
713711

712+
@NotNull
713+
private <T> InsertBatch<T> getInsertBatchAction(List<DbAction<?>> actions, Class<T> entityType,
714+
boolean includeId) {
715+
return getInsertBatchActions(actions, entityType).stream()
716+
.filter(insertBatch -> insertBatch.getIncludeId().equals(includeId))
717+
.findFirst()
718+
.orElseThrow(() -> new RuntimeException(String.format("No InsertBatch with includeId '%s' found!", includeId)));
719+
}
720+
714721
@NotNull
715722
private <T> List<InsertBatch<T>> getInsertBatchActions(List<DbAction<?>> actions, Class<T> entityType) {
716723
//noinspection unchecked

0 commit comments

Comments
 (0)