Skip to content

Commit 7cf81ae

Browse files
committed
Polishing.
Replace code duplications with doWithBatch(…) method. Return most concrete type in DefaultDataAccessStrategy and MyBatisDataAccessStrategy. See #1623 Original pull request: #1897
1 parent c4f62e9 commit 7cf81ae

File tree

3 files changed

+40
-48
lines changed

3 files changed

+40
-48
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java

Lines changed: 26 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
package org.springframework.data.jdbc.core;
1717

1818
import java.util.ArrayList;
19+
import java.util.Collection;
1920
import java.util.Collections;
2021
import java.util.HashMap;
2122
import java.util.Iterator;
2223
import java.util.LinkedHashMap;
2324
import java.util.List;
2425
import java.util.Map;
2526
import java.util.Optional;
27+
import java.util.function.Consumer;
2628
import java.util.function.Function;
2729
import java.util.stream.Collectors;
2830
import java.util.stream.StreamSupport;
@@ -56,6 +58,7 @@
5658
import org.springframework.lang.Nullable;
5759
import org.springframework.util.Assert;
5860
import org.springframework.util.ClassUtils;
61+
import org.springframework.util.ObjectUtils;
5962

6063
/**
6164
* {@link JdbcAggregateOperations} implementation, storing aggregates in and obtaining them from a JDBC data store.
@@ -173,19 +176,8 @@ public <T> T save(T instance) {
173176

174177
@Override
175178
public <T> List<T> saveAll(Iterable<T> instances) {
176-
177-
Assert.notNull(instances, "Aggregate instances must not be null");
178-
179-
if (!instances.iterator().hasNext()) {
180-
return Collections.emptyList();
181-
}
182-
183-
List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>();
184-
for (T instance : instances) {
185-
verifyIdProperty(instance);
186-
entityAndChangeCreators.add(new EntityAndChangeCreator<>(instance, changeCreatorSelectorForSave(instance)));
187-
}
188-
return performSaveAll(entityAndChangeCreators);
179+
return doWithBatch(instances, entity -> changeCreatorSelectorForSave(entity).apply(entity), this::verifyIdProperty,
180+
this::performSaveAll);
189181
}
190182

191183
/**
@@ -206,21 +198,7 @@ public <T> T insert(T instance) {
206198

207199
@Override
208200
public <T> List<T> insertAll(Iterable<T> instances) {
209-
210-
Assert.notNull(instances, "Aggregate instances must not be null");
211-
212-
if (!instances.iterator().hasNext()) {
213-
return Collections.emptyList();
214-
}
215-
216-
List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>();
217-
for (T instance : instances) {
218-
219-
Function<T, RootAggregateChange<T>> changeCreator = entity -> createInsertChange(prepareVersionForInsert(entity));
220-
EntityAndChangeCreator<T> entityChange = new EntityAndChangeCreator<>(instance, changeCreator);
221-
entityAndChangeCreators.add(entityChange);
222-
}
223-
return performSaveAll(entityAndChangeCreators);
201+
return doWithBatch(instances, entity -> createInsertChange(prepareVersionForInsert(entity)), this::performSaveAll);
224202
}
225203

226204
/**
@@ -241,21 +219,35 @@ public <T> T update(T instance) {
241219

242220
@Override
243221
public <T> List<T> updateAll(Iterable<T> instances) {
222+
return doWithBatch(instances, entity -> createUpdateChange(prepareVersionForUpdate(entity)), this::performSaveAll);
223+
}
224+
225+
private <T> List<T> doWithBatch(Iterable<T> iterable, Function<T, RootAggregateChange<T>> changeCreator,
226+
Function<List<EntityAndChangeCreator<T>>, List<T>> performFunction) {
227+
return doWithBatch(iterable, changeCreator, entity -> {}, performFunction);
228+
}
244229

245-
Assert.notNull(instances, "Aggregate instances must not be null");
230+
private <T> List<T> doWithBatch(Iterable<T> iterable, Function<T, RootAggregateChange<T>> changeCreator,
231+
Consumer<T> beforeEntityChange, Function<List<EntityAndChangeCreator<T>>, List<T>> performFunction) {
246232

247-
if (!instances.iterator().hasNext()) {
233+
Assert.notNull(iterable, "Aggregate instances must not be null");
234+
235+
if (ObjectUtils.isEmpty(iterable)) {
248236
return Collections.emptyList();
249237
}
250238

251-
List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>();
252-
for (T instance : instances) {
239+
List<EntityAndChangeCreator<T>> entityAndChangeCreators = new ArrayList<>(
240+
iterable instanceof Collection<?> c ? c.size() : 16);
241+
242+
for (T instance : iterable) {
243+
244+
beforeEntityChange.accept(instance);
253245

254-
Function<T, RootAggregateChange<T>> changeCreator = entity -> createUpdateChange(prepareVersionForUpdate(entity));
255246
EntityAndChangeCreator<T> entityChange = new EntityAndChangeCreator<>(instance, changeCreator);
256247
entityAndChangeCreators.add(entityChange);
257248
}
258-
return performSaveAll(entityAndChangeCreators);
249+
250+
return performFunction.apply(entityAndChangeCreators);
259251
}
260252

261253
@Override

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,12 @@ public <T> T findById(Object id, Class<T> domainType) {
272272
}
273273

274274
@Override
275-
public <T> Iterable<T> findAll(Class<T> domainType) {
275+
public <T> List<T> findAll(Class<T> domainType) {
276276
return operations.query(sql(domainType).getFindAll(), getEntityRowMapper(domainType));
277277
}
278278

279279
@Override
280-
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
280+
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
281281

282282
if (!ids.iterator().hasNext()) {
283283
return Collections.emptyList();
@@ -290,7 +290,7 @@ public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
290290

291291
@Override
292292
@SuppressWarnings("unchecked")
293-
public Iterable<Object> findAllByPath(Identifier identifier,
293+
public List<Object> findAllByPath(Identifier identifier,
294294
PersistentPropertyPath<? extends RelationalPersistentProperty> propertyPath) {
295295

296296
Assert.notNull(identifier, "identifier must not be null");
@@ -338,12 +338,12 @@ public <T> boolean existsById(Object id, Class<T> domainType) {
338338
}
339339

340340
@Override
341-
public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
341+
public <T> List<T> findAll(Class<T> domainType, Sort sort) {
342342
return operations.query(sql(domainType).getFindAll(sort), getEntityRowMapper(domainType));
343343
}
344344

345345
@Override
346-
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
346+
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
347347
return operations.query(sql(domainType).getFindAll(pageable), getEntityRowMapper(domainType));
348348
}
349349

@@ -361,7 +361,7 @@ public <T> Optional<T> findOne(Query query, Class<T> domainType) {
361361
}
362362

363363
@Override
364-
public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
364+
public <T> List<T> findAll(Query query, Class<T> domainType) {
365365

366366
MapSqlParameterSource parameterSource = new MapSqlParameterSource();
367367
String sqlQuery = sql(domainType).selectByQuery(query, parameterSource);
@@ -370,7 +370,7 @@ public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
370370
}
371371

372372
@Override
373-
public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
373+
public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
374374

375375
MapSqlParameterSource parameterSource = new MapSqlParameterSource();
376376
String sqlQuery = sql(domainType).selectByQuery(query, parameterSource, pageable);

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,21 +256,21 @@ public <T> T findById(Object id, Class<T> domainType) {
256256
}
257257

258258
@Override
259-
public <T> Iterable<T> findAll(Class<T> domainType) {
259+
public <T> List<T> findAll(Class<T> domainType) {
260260

261261
String statement = namespace(domainType) + ".findAll";
262262
MyBatisContext parameter = new MyBatisContext(null, null, domainType, Collections.emptyMap());
263263
return sqlSession().selectList(statement, parameter);
264264
}
265265

266266
@Override
267-
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
267+
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
268268
return sqlSession().selectList(namespace(domainType) + ".findAllById",
269269
new MyBatisContext(ids, null, domainType, Collections.emptyMap()));
270270
}
271271

272272
@Override
273-
public Iterable<Object> findAllByPath(Identifier identifier,
273+
public List<Object> findAllByPath(Identifier identifier,
274274
PersistentPropertyPath<? extends RelationalPersistentProperty> path) {
275275

276276
String statementName = namespace(getOwnerTyp(path)) + ".findAllByPath-" + path.toDotPath();
@@ -288,7 +288,7 @@ public <T> boolean existsById(Object id, Class<T> domainType) {
288288
}
289289

290290
@Override
291-
public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
291+
public <T> List<T> findAll(Class<T> domainType, Sort sort) {
292292

293293
Map<String, Object> additionalContext = new HashMap<>();
294294
additionalContext.put("sort", sort);
@@ -297,7 +297,7 @@ public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
297297
}
298298

299299
@Override
300-
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
300+
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
301301

302302
Map<String, Object> additionalContext = new HashMap<>();
303303
additionalContext.put("pageable", pageable);
@@ -311,12 +311,12 @@ public <T> Optional<T> findOne(Query query, Class<T> probeType) {
311311
}
312312

313313
@Override
314-
public <T> Iterable<T> findAll(Query query, Class<T> probeType) {
314+
public <T> List<T> findAll(Query query, Class<T> probeType) {
315315
throw new UnsupportedOperationException("Not implemented");
316316
}
317317

318318
@Override
319-
public <T> Iterable<T> findAll(Query query, Class<T> probeType, Pageable pageable) {
319+
public <T> List<T> findAll(Query query, Class<T> probeType, Pageable pageable) {
320320
throw new UnsupportedOperationException("Not implemented");
321321
}
322322

0 commit comments

Comments
 (0)