Skip to content

Commit caedb1f

Browse files
schauderchristophstrobl
authored andcommitted
DATAJDBC-378 - Proper handling of null and empty collections in JdbcAggregateTemplate.
1 parent 9903101 commit caedb1f

File tree

3 files changed

+47
-1
lines changed

3 files changed

+47
-1
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateTemplate.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ public <T> T update(T instance) {
152152
*/
153153
@Override
154154
public long count(Class<?> domainType) {
155+
156+
Assert.notNull(domainType, "Domain type must not be null");
157+
155158
return accessStrategy.count(domainType);
156159
}
157160

@@ -162,6 +165,9 @@ public long count(Class<?> domainType) {
162165
@Override
163166
public <T> T findById(Object id, Class<T> domainType) {
164167

168+
Assert.notNull(id, "Id must not be null");
169+
Assert.notNull(domainType, "Domain type must not be null");
170+
165171
T entity = accessStrategy.findById(id, domainType);
166172
if (entity != null) {
167173
publishAfterLoad(id, entity);
@@ -175,6 +181,10 @@ public <T> T findById(Object id, Class<T> domainType) {
175181
*/
176182
@Override
177183
public <T> boolean existsById(Object id, Class<T> domainType) {
184+
185+
Assert.notNull(id, "Id must not be null");
186+
Assert.notNull(domainType, "Domain type must not be null");
187+
178188
return accessStrategy.existsById(id, domainType);
179189
}
180190

@@ -185,6 +195,8 @@ public <T> boolean existsById(Object id, Class<T> domainType) {
185195
@Override
186196
public <T> Iterable<T> findAll(Class<T> domainType) {
187197

198+
Assert.notNull(domainType, "Domain type must not be null");
199+
188200
Iterable<T> all = accessStrategy.findAll(domainType);
189201
publishAfterLoad(all);
190202
return all;
@@ -197,6 +209,9 @@ public <T> Iterable<T> findAll(Class<T> domainType) {
197209
@Override
198210
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
199211

212+
Assert.notNull(ids, "Ids must not be null");
213+
Assert.notNull(domainType, "Domain type must not be null");
214+
200215
Iterable<T> allById = accessStrategy.findAllById(ids, domainType);
201216
publishAfterLoad(allById);
202217
return allById;
@@ -209,6 +224,9 @@ public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
209224
@Override
210225
public <S> void delete(S aggregateRoot, Class<S> domainType) {
211226

227+
Assert.notNull(aggregateRoot, "Aggregate root must not be null");
228+
Assert.notNull(domainType, "Domain type must not be null");
229+
212230
IdentifierAccessor identifierAccessor = context.getRequiredPersistentEntity(domainType)
213231
.getIdentifierAccessor(aggregateRoot);
214232

@@ -221,6 +239,10 @@ public <S> void delete(S aggregateRoot, Class<S> domainType) {
221239
*/
222240
@Override
223241
public <S> void deleteById(Object id, Class<S> domainType) {
242+
243+
Assert.notNull(id, "Id must not be null");
244+
Assert.notNull(domainType, "Domain type must not be null");
245+
224246
deleteTree(id, null, domainType);
225247
}
226248

@@ -231,6 +253,8 @@ public <S> void deleteById(Object id, Class<S> domainType) {
231253
@Override
232254
public void deleteAll(Class<?> domainType) {
233255

256+
Assert.notNull(domainType, "Domain type must not be null");
257+
234258
AggregateChange<?> change = createDeletingChange(domainType);
235259
change.executeWith(interpreter, context, converter);
236260
}

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.sql.JDBCType;
1919
import java.util.ArrayList;
2020
import java.util.Arrays;
21+
import java.util.Collections;
2122
import java.util.HashMap;
2223
import java.util.HashSet;
2324
import java.util.List;
@@ -241,6 +242,10 @@ public <T> Iterable<T> findAll(Class<T> domainType) {
241242
@SuppressWarnings("unchecked")
242243
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
243244

245+
if (!ids.iterator().hasNext()) {
246+
return Collections.emptyList();
247+
}
248+
244249
RelationalPersistentProperty idProperty = getRequiredPersistentEntity(domainType).getRequiredIdProperty();
245250
MapSqlParameterSource parameterSource = new MapSqlParameterSource();
246251

spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/JdbcAggregateTemplateIntegrationTests.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import org.junit.ClassRule;
3636
import org.junit.Rule;
3737
import org.junit.Test;
38-
3938
import org.springframework.beans.factory.annotation.Autowired;
4039
import org.springframework.context.ApplicationEventPublisher;
4140
import org.springframework.context.annotation.Bean;
@@ -597,6 +596,24 @@ public void shouldDeleteChainOfMapsWithoutIds() {
597596
});
598597
}
599598

599+
@Test // DATAJDBC-378
600+
public void findAllByIdMustNotAcceptNullArgumentForType() {
601+
602+
assertThatThrownBy(() -> template.findAllById(singleton(23L), null)).isInstanceOf(IllegalArgumentException.class);
603+
}
604+
605+
@Test // DATAJDBC-378
606+
public void findAllByIdMustNotAcceptNullArgumentForIds() {
607+
608+
assertThatThrownBy(() -> template.findAllById(null, LegoSet.class)).isInstanceOf(IllegalArgumentException.class);
609+
}
610+
611+
@Test // DATAJDBC-378
612+
public void findAllByIdWithEmpthListMustReturnEmptyResult() {
613+
614+
assertThat(template.findAllById(emptyList(), LegoSet.class)).isEmpty();
615+
}
616+
600617
private static NoIdMapChain4 createNoIdMapTree() {
601618

602619
NoIdMapChain4 chain4 = new NoIdMapChain4();

0 commit comments

Comments
 (0)