Skip to content

Commit 549f815

Browse files
committed
Terminate stream with error on null values returned by RedisElementReader for top-level elements.
We now emit InvalidDataAccessApiUsageException when a RedisElementReader returns null in the context of a top-level stream to indicate invalid API usage although RedisElementReader.read can generally return null values if these are being collected in a container or value wrapper or parent complex object.
1 parent 7d3e805 commit 549f815

20 files changed

+253
-68
lines changed

src/main/java/org/springframework/data/redis/connection/convert/Converters.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ public static Object parse(Object source, String sourcePath, Map<String, Class<?
465465
* @return
466466
* @since 2.6
467467
*/
468-
public static <K, V> Map.Entry<K, V> entryOf(K key, V value) {
468+
public static <K, V> Map.Entry<K, V> entryOf(@Nullable K key, @Nullable V value) {
469469
return new AbstractMap.SimpleImmutableEntry<>(key, value);
470470
}
471471

src/main/java/org/springframework/data/redis/core/DefaultReactiveGeoOperations.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import java.util.stream.Collectors;
2727

2828
import org.reactivestreams.Publisher;
29-
3029
import org.springframework.data.geo.Circle;
3130
import org.springframework.data.geo.Distance;
3231
import org.springframework.data.geo.GeoResult;
@@ -40,6 +39,7 @@
4039
import org.springframework.data.redis.domain.geo.GeoReference.GeoMemberReference;
4140
import org.springframework.data.redis.domain.geo.GeoShape;
4241
import org.springframework.data.redis.serializer.RedisSerializationContext;
42+
import org.springframework.lang.Nullable;
4343
import org.springframework.util.Assert;
4444

4545
/**
@@ -321,6 +321,7 @@ private ByteBuffer rawValue(V value) {
321321
return serializationContext.getValueSerializationPair().write(value);
322322
}
323323

324+
@Nullable
324325
private V readValue(ByteBuffer buffer) {
325326
return serializationContext.getValueSerializationPair().read(buffer);
326327
}

src/main/java/org/springframework/data/redis/core/DefaultReactiveHashOperations.java

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@
2626
import java.util.function.Function;
2727

2828
import org.reactivestreams.Publisher;
29-
29+
import org.springframework.dao.InvalidDataAccessApiUsageException;
3030
import org.springframework.data.redis.connection.ReactiveHashCommands;
3131
import org.springframework.data.redis.connection.convert.Converters;
3232
import org.springframework.data.redis.serializer.RedisSerializationContext;
33+
import org.springframework.lang.Nullable;
3334
import org.springframework.util.Assert;
3435

3536
/**
@@ -127,7 +128,7 @@ public Mono<HK> randomKey(H key) {
127128
Assert.notNull(key, "Key must not be null");
128129

129130
return template.doCreateMono(connection -> connection //
130-
.hashCommands().hRandField(rawKey(key))).map(this::readHashKey);
131+
.hashCommands().hRandField(rawKey(key))).map(this::readRequiredHashKey);
131132
}
132133

133134
@Override
@@ -145,7 +146,7 @@ public Flux<HK> randomKeys(H key, long count) {
145146
Assert.notNull(key, "Key must not be null");
146147

147148
return template.doCreateFlux(connection -> connection //
148-
.hashCommands().hRandField(rawKey(key), count)).map(this::readHashKey);
149+
.hashCommands().hRandField(rawKey(key), count)).map(this::readRequiredHashKey);
149150
}
150151

151152
@Override
@@ -163,7 +164,7 @@ public Flux<HK> keys(H key) {
163164
Assert.notNull(key, "Key must not be null");
164165

165166
return createFlux(connection -> connection.hKeys(rawKey(key)) //
166-
.map(this::readHashKey));
167+
.map(this::readRequiredHashKey));
167168
}
168169

169170
@Override
@@ -211,7 +212,7 @@ public Flux<HV> values(H key) {
211212
Assert.notNull(key, "Key must not be null");
212213

213214
return createFlux(connection -> connection.hVals(rawKey(key)) //
214-
.map(this::readHashValue));
215+
.map(this::readRequiredHashValue));
215216
}
216217

217218
@Override
@@ -268,13 +269,37 @@ private ByteBuffer rawHashValue(HV key) {
268269
}
269270

270271
@SuppressWarnings("unchecked")
272+
@Nullable
271273
private HK readHashKey(ByteBuffer value) {
272274
return (HK) serializationContext.getHashKeySerializationPair().read(value);
273275
}
274276

277+
private HK readRequiredHashKey(ByteBuffer buffer) {
278+
279+
HK hashKey = readHashKey(buffer);
280+
281+
if (hashKey == null) {
282+
throw new InvalidDataAccessApiUsageException("Deserialized hash key is null");
283+
}
284+
285+
return hashKey;
286+
}
287+
275288
@SuppressWarnings("unchecked")
276-
private HV readHashValue(ByteBuffer value) {
277-
return (HV) (value == null ? value : serializationContext.getHashValueSerializationPair().read(value));
289+
@Nullable
290+
private HV readHashValue(@Nullable ByteBuffer value) {
291+
return (HV) (value == null ? null : serializationContext.getHashValueSerializationPair().read(value));
292+
}
293+
294+
private HV readRequiredHashValue(ByteBuffer buffer) {
295+
296+
HV hashValue = readHashValue(buffer);
297+
298+
if (hashValue == null) {
299+
throw new InvalidDataAccessApiUsageException("Deserialized hash value is null");
300+
}
301+
302+
return hashValue;
278303
}
279304

280305
private Map.Entry<HK, HV> deserializeHashEntry(Map.Entry<ByteBuffer, ByteBuffer> source) {

src/main/java/org/springframework/data/redis/core/DefaultReactiveListOperations.java

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@
2727
import java.util.function.Function;
2828

2929
import org.reactivestreams.Publisher;
30+
import org.springframework.dao.InvalidDataAccessApiUsageException;
3031
import org.springframework.data.redis.connection.ReactiveListCommands;
3132
import org.springframework.data.redis.connection.ReactiveListCommands.Direction;
3233
import org.springframework.data.redis.connection.ReactiveListCommands.LPosCommand;
3334
import org.springframework.data.redis.connection.RedisListCommands.Position;
3435
import org.springframework.data.redis.serializer.RedisSerializationContext;
36+
import org.springframework.lang.Nullable;
3537
import org.springframework.util.Assert;
3638

3739
/**
@@ -58,7 +60,7 @@ public Flux<V> range(K key, long start, long end) {
5860

5961
Assert.notNull(key, "Key must not be null");
6062

61-
return createFlux(connection -> connection.lRange(rawKey(key), start, end).map(this::readValue));
63+
return createFlux(connection -> connection.lRange(rawKey(key), start, end).map(this::readRequiredValue));
6264
}
6365

6466
@Override
@@ -170,7 +172,8 @@ public Mono<V> move(K sourceKey, Direction from, K destinationKey, Direction to)
170172
Assert.notNull(to, "To direction must not be null");
171173

172174
return createMono(
173-
connection -> connection.lMove(rawKey(sourceKey), rawKey(destinationKey), from, to).map(this::readValue));
175+
connection -> connection.lMove(rawKey(sourceKey), rawKey(destinationKey), from, to)
176+
.map(this::readRequiredValue));
174177
}
175178

176179
@Override
@@ -183,7 +186,7 @@ public Mono<V> move(K sourceKey, Direction from, K destinationKey, Direction to,
183186
Assert.notNull(timeout, "Timeout must not be null");
184187

185188
return createMono(connection -> connection.bLMove(rawKey(sourceKey), rawKey(destinationKey), from, to, timeout)
186-
.map(this::readValue));
189+
.map(this::readRequiredValue));
187190
}
188191

189192
@Override
@@ -208,7 +211,7 @@ public Mono<V> index(K key, long index) {
208211

209212
Assert.notNull(key, "Key must not be null");
210213

211-
return createMono(connection -> connection.lIndex(rawKey(key), index).map(this::readValue));
214+
return createMono(connection -> connection.lIndex(rawKey(key), index).map(this::readRequiredValue));
212215
}
213216

214217
@Override
@@ -232,7 +235,7 @@ public Mono<V> leftPop(K key) {
232235

233236
Assert.notNull(key, "Key must not be null");
234237

235-
return createMono(connection -> connection.lPop(rawKey(key)).map(this::readValue));
238+
return createMono(connection -> connection.lPop(rawKey(key)).map(this::readRequiredValue));
236239

237240
}
238241

@@ -244,15 +247,15 @@ public Mono<V> leftPop(K key, Duration timeout) {
244247
Assert.isTrue(isZeroOrGreater1Second(timeout), "Duration must be either zero or greater or equal to 1 second");
245248

246249
return createMono(connection -> connection.blPop(Collections.singletonList(rawKey(key)), timeout)
247-
.map(popResult -> readValue(popResult.getValue())));
250+
.mapNotNull(popResult -> readValue(popResult.getValue())));
248251
}
249252

250253
@Override
251254
public Mono<V> rightPop(K key) {
252255

253256
Assert.notNull(key, "Key must not be null");
254257

255-
return createMono(connection -> connection.rPop(rawKey(key)).map(this::readValue));
258+
return createMono(connection -> connection.rPop(rawKey(key)).map(this::readRequiredValue));
256259
}
257260

258261
@Override
@@ -263,7 +266,7 @@ public Mono<V> rightPop(K key, Duration timeout) {
263266
Assert.isTrue(isZeroOrGreater1Second(timeout), "Duration must be either zero or greater or equal to 1 second");
264267

265268
return createMono(connection -> connection.brPop(Collections.singletonList(rawKey(key)), timeout)
266-
.map(popResult -> readValue(popResult.getValue())));
269+
.mapNotNull(popResult -> readValue(popResult.getValue())));
267270
}
268271

269272
@Override
@@ -273,7 +276,7 @@ public Mono<V> rightPopAndLeftPush(K sourceKey, K destinationKey) {
273276
Assert.notNull(destinationKey, "Destination key must not be null");
274277

275278
return createMono(
276-
connection -> connection.rPopLPush(rawKey(sourceKey), rawKey(destinationKey)).map(this::readValue));
279+
connection -> connection.rPopLPush(rawKey(sourceKey), rawKey(destinationKey)).map(this::readRequiredValue));
277280
}
278281

279282
@Override
@@ -285,7 +288,8 @@ public Mono<V> rightPopAndLeftPush(K sourceKey, K destinationKey, Duration timeo
285288
Assert.isTrue(isZeroOrGreater1Second(timeout), "Duration must be either zero or greater or equal to 1 second");
286289

287290
return createMono(
288-
connection -> connection.bRPopLPush(rawKey(sourceKey), rawKey(destinationKey), timeout).map(this::readValue));
291+
connection -> connection.bRPopLPush(rawKey(sourceKey), rawKey(destinationKey), timeout)
292+
.map(this::readRequiredValue));
289293
}
290294

291295
@Override
@@ -322,7 +326,19 @@ private ByteBuffer rawValue(V value) {
322326
return serializationContext.getValueSerializationPair().write(value);
323327
}
324328

329+
@Nullable
325330
private V readValue(ByteBuffer buffer) {
326331
return serializationContext.getValueSerializationPair().read(buffer);
327332
}
333+
334+
private V readRequiredValue(ByteBuffer buffer) {
335+
336+
V v = readValue(buffer);
337+
338+
if (v == null) {
339+
throw new InvalidDataAccessApiUsageException("Deserialized list value is null");
340+
}
341+
342+
return v;
343+
}
328344
}

src/main/java/org/springframework/data/redis/core/DefaultReactiveSetOperations.java

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@
2828
import java.util.function.Function;
2929

3030
import org.reactivestreams.Publisher;
31+
import org.springframework.dao.InvalidDataAccessApiUsageException;
3132
import org.springframework.data.redis.connection.ReactiveSetCommands;
3233
import org.springframework.data.redis.serializer.RedisSerializationContext;
34+
import org.springframework.lang.Nullable;
3335
import org.springframework.util.Assert;
3436

3537
/**
@@ -88,15 +90,15 @@ public Mono<V> pop(K key) {
8890

8991
Assert.notNull(key, "Key must not be null");
9092

91-
return createMono(connection -> connection.sPop(rawKey(key)).map(this::readValue));
93+
return createMono(connection -> connection.sPop(rawKey(key)).map(this::readRequiredValue));
9294
}
9395

9496
@Override
9597
public Flux<V> pop(K key, long count) {
9698

9799
Assert.notNull(key, "Key must not be null");
98100

99-
return createFlux(connection -> connection.sPop(rawKey(key), count).map(this::readValue));
101+
return createFlux(connection -> connection.sPop(rawKey(key), count).map(this::readRequiredValue));
100102
}
101103

102104
@Override
@@ -176,7 +178,7 @@ public Flux<V> intersect(Collection<K> keys) {
176178
.map(this::rawKey) //
177179
.collectList() //
178180
.flatMapMany(connection::sInter) //
179-
.map(this::readValue));
181+
.map(this::readRequiredValue));
180182
}
181183

182184
@Override
@@ -238,7 +240,7 @@ public Flux<V> union(Collection<K> keys) {
238240
.map(this::rawKey) //
239241
.collectList() //
240242
.flatMapMany(connection::sUnion) //
241-
.map(this::readValue));
243+
.map(this::readRequiredValue));
242244
}
243245

244246
@Override
@@ -300,7 +302,7 @@ public Flux<V> difference(Collection<K> keys) {
300302
.map(this::rawKey) //
301303
.collectList() //
302304
.flatMapMany(connection::sDiff) //
303-
.map(this::readValue));
305+
.map(this::readRequiredValue));
304306
}
305307

306308
@Override
@@ -340,7 +342,7 @@ public Flux<V> members(K key) {
340342

341343
Assert.notNull(key, "Key must not be null");
342344

343-
return createFlux(connection -> connection.sMembers(rawKey(key)).map(this::readValue));
345+
return createFlux(connection -> connection.sMembers(rawKey(key)).map(this::readRequiredValue));
344346
}
345347

346348
@Override
@@ -349,31 +351,31 @@ public Flux<V> scan(K key, ScanOptions options) {
349351
Assert.notNull(key, "Key must not be null");
350352
Assert.notNull(options, "ScanOptions must not be null");
351353

352-
return createFlux(connection -> connection.sScan(rawKey(key), options).map(this::readValue));
354+
return createFlux(connection -> connection.sScan(rawKey(key), options).map(this::readRequiredValue));
353355
}
354356

355357
@Override
356358
public Mono<V> randomMember(K key) {
357359

358360
Assert.notNull(key, "Key must not be null");
359361

360-
return createMono(connection -> connection.sRandMember(rawKey(key)).map(this::readValue));
362+
return createMono(connection -> connection.sRandMember(rawKey(key)).map(this::readRequiredValue));
361363
}
362364

363365
@Override
364366
public Flux<V> distinctRandomMembers(K key, long count) {
365367

366368
Assert.isTrue(count > 0, "Negative count not supported; Use randomMembers to allow duplicate elements");
367369

368-
return createFlux(connection -> connection.sRandMember(rawKey(key), count).map(this::readValue));
370+
return createFlux(connection -> connection.sRandMember(rawKey(key), count).map(this::readRequiredValue));
369371
}
370372

371373
@Override
372374
public Flux<V> randomMembers(K key, long count) {
373375

374376
Assert.isTrue(count > 0, "Use a positive number for count; This method is already allowing duplicate elements");
375377

376-
return createFlux(connection -> connection.sRandMember(rawKey(key), -count).map(this::readValue));
378+
return createFlux(connection -> connection.sRandMember(rawKey(key), -count).map(this::readRequiredValue));
377379
}
378380

379381
@Override
@@ -416,7 +418,19 @@ private ByteBuffer rawValue(V value) {
416418
return serializationContext.getValueSerializationPair().write(value);
417419
}
418420

421+
@Nullable
419422
private V readValue(ByteBuffer buffer) {
420423
return serializationContext.getValueSerializationPair().read(buffer);
421424
}
425+
426+
private V readRequiredValue(ByteBuffer buffer) {
427+
428+
V v = readValue(buffer);
429+
430+
if (v == null) {
431+
throw new InvalidDataAccessApiUsageException("Deserialized set value is null");
432+
}
433+
434+
return v;
435+
}
422436
}

0 commit comments

Comments
 (0)