Skip to content

Commit 30f2078

Browse files
committed
Extract insert jdbc operation and key generation logic from DefaultDataAccessStrategy into InsertStrategyFactory.
1 parent ae1983d commit 30f2078

File tree

11 files changed

+488
-288
lines changed

11 files changed

+488
-288
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/BatchJdbcOperations.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.springframework.jdbc.core.namedparam.NamedParameterUtils;
1212
import org.springframework.jdbc.core.namedparam.ParsedSql;
1313
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
14-
import org.springframework.jdbc.support.GeneratedKeyHolder;
1514
import org.springframework.jdbc.support.JdbcUtils;
1615
import org.springframework.jdbc.support.KeyHolder;
1716
import org.springframework.lang.Nullable;
@@ -22,16 +21,16 @@
2221
import java.util.List;
2322
import java.util.Map;
2423

25-
class BatchJdbcOperations {
24+
public class BatchJdbcOperations {
2625
private final JdbcOperations jdbcOperations;
2726

28-
BatchJdbcOperations(JdbcOperations jdbcOperations) {
27+
public BatchJdbcOperations(JdbcOperations jdbcOperations) {
2928
this.jdbcOperations = jdbcOperations;
3029
}
3130

3231
@Nullable
33-
int[] insert(String insertSql, SqlParameterSource[] sqlParameterSources,
34-
GeneratedKeyHolder generatedKeyHolder) {
32+
int[] batchUpdate(String insertSql, SqlParameterSource[] sqlParameterSources,
33+
KeyHolder generatedKeyHolder) {
3534
// TODO: This is largely duplicated from spring-jdbc and should be replaced with a call into
3635
// NamedParameterJdbcTemplate#batchUpdate once a method taking KeyHolder is added there.
3736
ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(insertSql);
@@ -80,8 +79,8 @@ public int getBatchSize() {
8079
}
8180

8281
@Nullable
83-
public int[] insert(String insertSql, SqlIdentifierParameterSource[] sqlParameterSources,
84-
GeneratedKeyHolder generatedKeyHolder, String[] keyColumnNames) {
82+
public int[] batchUpdate(String insertSql, SqlParameterSource[] sqlParameterSources,
83+
KeyHolder generatedKeyHolder, String[] keyColumnNames) {
8584
// TODO: This is largely duplicated from spring-jdbc and should be replaced with a call into
8685
// NamedParameterJdbcTemplate#batchUpdate once a method taking KeyHolder is added there.
8786
ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(insertSql);

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java

Lines changed: 13 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,15 @@
2222
import java.util.ArrayList;
2323
import java.util.Collections;
2424
import java.util.List;
25-
import java.util.Map;
26-
import java.util.function.Predicate;
25+
import java.util.Optional;
2726

28-
import org.springframework.dao.DataRetrievalFailureException;
2927
import org.springframework.dao.EmptyResultDataAccessException;
30-
import org.springframework.dao.InvalidDataAccessApiUsageException;
3128
import org.springframework.dao.OptimisticLockingFailureException;
3229
import org.springframework.data.domain.Pageable;
3330
import org.springframework.data.domain.Sort;
3431
import org.springframework.data.jdbc.core.mapping.JdbcValue;
3532
import org.springframework.data.jdbc.support.JdbcUtil;
36-
import org.springframework.data.mapping.PersistentProperty;
37-
import org.springframework.data.mapping.PersistentPropertyAccessor;
3833
import org.springframework.data.mapping.PersistentPropertyPath;
39-
import org.springframework.data.relational.core.dialect.IdGeneration;
4034
import org.springframework.data.relational.core.mapping.PersistentPropertyPathExtension;
4135
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
4236
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
@@ -47,9 +41,7 @@
4741
import org.springframework.jdbc.core.RowMapper;
4842
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
4943
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
50-
import org.springframework.jdbc.support.GeneratedKeyHolder;
5144
import org.springframework.jdbc.support.JdbcUtils;
52-
import org.springframework.jdbc.support.KeyHolder;
5345
import org.springframework.lang.Nullable;
5446
import org.springframework.util.Assert;
5547

@@ -75,8 +67,8 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
7567
private final RelationalMappingContext context;
7668
private final JdbcConverter converter;
7769
private final NamedParameterJdbcOperations operations;
78-
private final BatchJdbcOperations batchOperations;
7970
private final SqlParametersFactory sqlParametersFactory;
71+
private final InsertStrategyFactory insertStrategyFactory;
8072

8173
/**
8274
* Creates a {@link DefaultDataAccessStrategy}
@@ -88,26 +80,21 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
8880
* @since 1.1
8981
*/
9082
public DefaultDataAccessStrategy(SqlGeneratorSource sqlGeneratorSource, RelationalMappingContext context,
91-
JdbcConverter converter, NamedParameterJdbcOperations operations, SqlParametersFactory sqlParametersFactory) {
92-
this(sqlGeneratorSource, context, converter, operations, new BatchJdbcOperations(operations.getJdbcOperations()), sqlParametersFactory);
93-
}
94-
95-
DefaultDataAccessStrategy(SqlGeneratorSource sqlGeneratorSource, RelationalMappingContext context,
96-
JdbcConverter converter, NamedParameterJdbcOperations operations,
97-
BatchJdbcOperations batchOperations, SqlParametersFactory sqlParametersFactory) {
98-
83+
JdbcConverter converter, NamedParameterJdbcOperations operations, SqlParametersFactory sqlParametersFactory,
84+
InsertStrategyFactory insertStrategyFactory) {
9985
Assert.notNull(sqlGeneratorSource, "SqlGeneratorSource must not be null");
10086
Assert.notNull(context, "RelationalMappingContext must not be null");
10187
Assert.notNull(converter, "JdbcConverter must not be null");
10288
Assert.notNull(operations, "NamedParameterJdbcOperations must not be null");
10389
Assert.notNull(sqlParametersFactory, "SqlParametersFactory must not be null");
90+
Assert.notNull(insertStrategyFactory, "InsertStrategyFactory must not be null");
10491

10592
this.sqlGeneratorSource = sqlGeneratorSource;
10693
this.context = context;
10794
this.converter = converter;
10895
this.operations = operations;
109-
this.batchOperations = batchOperations;
11096
this.sqlParametersFactory = sqlParametersFactory;
97+
this.insertStrategyFactory = insertStrategyFactory;
11198
}
11299

113100
/*
@@ -121,14 +108,7 @@ public <T> Object insert(T instance, Class<T> domainType, Identifier identifier,
121108

122109
String insertSql = sql(domainType).getInsert(parameterSource.getIdentifiers());
123110

124-
if (!includeId) {
125-
RelationalPersistentEntity<T> persistentEntity = getRequiredPersistentEntity(domainType);
126-
return executeInsertAndReturnGeneratedId(persistentEntity, parameterSource, insertSql);
127-
} else {
128-
129-
operations.update(insertSql, parameterSource);
130-
return null;
131-
}
111+
return insertStrategyFactory.insertStrategy(!includeId, getIdColumn(domainType)).execute(insertSql, parameterSource);
132112
}
133113

134114
@Override
@@ -141,68 +121,7 @@ public <T> Object[] insert(List<RecordDescriptor<T>> recordDescriptors, Class<T>
141121

142122
String insertSql = sql(domainType).getInsert(sqlParameterSources[0].getIdentifiers());
143123

144-
if (includeId) {
145-
operations.batchUpdate(insertSql, sqlParameterSources);
146-
return new Object[sqlParameterSources.length];
147-
}
148-
GeneratedKeyHolder holder = new GeneratedKeyHolder();
149-
150-
IdGeneration idGeneration = sqlGeneratorSource.getDialect().getIdGeneration();
151-
152-
RelationalPersistentEntity<T> persistentEntity = getRequiredPersistentEntity(domainType);
153-
if (idGeneration.driverRequiresKeyColumnNames()) {
154-
155-
String[] keyColumnNames = getKeyColumnNames(persistentEntity.getType());
156-
if (keyColumnNames.length == 0) {
157-
batchOperations.insert(insertSql, sqlParameterSources, holder);
158-
} else {
159-
batchOperations.insert(insertSql, sqlParameterSources, holder, keyColumnNames);
160-
}
161-
} else {
162-
batchOperations.insert(insertSql, sqlParameterSources, holder);
163-
}
164-
// TODO: Is this needed?
165-
if (!persistentEntity.hasIdProperty()) {
166-
return new Object[sqlParameterSources.length];
167-
}
168-
// TODO: Duplicated in #getIdFromHolder - consider refactoring
169-
Object[] ids = new Object[sqlParameterSources.length];
170-
List<Map<String, Object>> keyList = holder.getKeyList();
171-
for (int i = 0; i < keyList.size(); i++) {
172-
Map<String, Object> keys = keyList.get(i);
173-
Object id;
174-
if (keys.size() > 1) {
175-
id = keys.get(persistentEntity.getIdColumn().getReference(getIdentifierProcessing()));
176-
} else {
177-
id = keys.entrySet().stream().findFirst() //
178-
.map(Map.Entry::getValue) //
179-
.orElseThrow(() -> new IllegalStateException("KeyHolder contains an empty key list."));
180-
}
181-
ids[i] = id;
182-
}
183-
return ids;
184-
}
185-
186-
@Nullable
187-
private <T> Object executeInsertAndReturnGeneratedId(RelationalPersistentEntity<T> persistentEntity, SqlIdentifierParameterSource parameterSource, String insertSql) {
188-
189-
KeyHolder holder = new GeneratedKeyHolder();
190-
191-
IdGeneration idGeneration = sqlGeneratorSource.getDialect().getIdGeneration();
192-
193-
if (idGeneration.driverRequiresKeyColumnNames()) {
194-
195-
String[] keyColumnNames = getKeyColumnNames(persistentEntity.getType());
196-
if (keyColumnNames.length == 0) {
197-
operations.update(insertSql, parameterSource, holder);
198-
} else {
199-
operations.update(insertSql, parameterSource, holder, keyColumnNames);
200-
}
201-
} else {
202-
operations.update(insertSql, parameterSource, holder);
203-
}
204-
205-
return getIdFromHolder(holder, persistentEntity);
124+
return insertStrategyFactory.insertStrategy(!includeId, getIdColumn(domainType)).execute(insertSql, sqlParameterSources);
206125
}
207126

208127
/*
@@ -475,26 +394,6 @@ public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
475394
return operations.query(sql(domainType).getFindAll(pageable), (RowMapper<T>) getEntityRowMapper(domainType));
476395
}
477396

478-
@Nullable
479-
private <S> Object getIdFromHolder(KeyHolder holder, RelationalPersistentEntity<S> persistentEntity) {
480-
481-
try {
482-
// MySQL just returns one value with a special name
483-
return holder.getKey();
484-
} catch (DataRetrievalFailureException | InvalidDataAccessApiUsageException e) {
485-
// Postgres returns a value for each column
486-
// MS SQL Server returns a value that might be null.
487-
488-
Map<String, Object> keys = holder.getKeys();
489-
490-
if (keys == null || persistentEntity.getIdProperty() == null) {
491-
return null;
492-
}
493-
494-
return keys.get(persistentEntity.getIdColumn().getReference(getIdentifierProcessing()));
495-
}
496-
}
497-
498397
private EntityRowMapper<?> getEntityRowMapper(Class<?> domainType) {
499398
return new EntityRowMapper<>(getRequiredPersistentEntity(domainType), converter);
500399
}
@@ -587,17 +486,10 @@ private SqlGenerator sql(Class<?> domainType) {
587486
return sqlGeneratorSource.getSqlGenerator(domainType);
588487
}
589488

590-
private <T> String[] getKeyColumnNames(Class<T> domainType) {
591-
592-
RelationalPersistentEntity<?> requiredPersistentEntity = context.getRequiredPersistentEntity(domainType);
593-
594-
if (!requiredPersistentEntity.hasIdProperty()) {
595-
return new String[0];
596-
}
597-
598-
SqlIdentifier idColumn = requiredPersistentEntity.getIdColumn();
599-
600-
return new String[] { idColumn.getReference(getIdentifierProcessing()) };
489+
@Nullable
490+
private <T> SqlIdentifier getIdColumn(Class<T> domainType) {
491+
return Optional.ofNullable(context.getRequiredPersistentEntity(domainType).getIdProperty())
492+
.map(RelationalPersistentProperty::getColumnName)
493+
.orElse(null);
601494
}
602-
603495
}

0 commit comments

Comments
 (0)