From f7c66a687f317ac9522320979c47b988ec83e614 Mon Sep 17 00:00:00 2001
From: John Blum
Date: Wed, 8 Mar 2023 15:23:09 -0800
Subject: [PATCH 1/3] Prepare issue branch.

---
 pom.xml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pom.xml b/pom.xml
index 778dd177dc..b804e93073 100644
--- a/pom.xml
+++ b/pom.xml
@@ -5,7 +5,7 @@
 org.springframework.data
 spring-data-redis
- 3.2.0-SNAPSHOT
+ 3.2.0-GH-2518-SNAPSHOT
 Spring Data Redis
 Spring Data module for Redis

From 27830824b5e9bc0d5ffb542f0e35edd9239de525 Mon Sep 17 00:00:00 2001
From: John Blum
Date: Fri, 29 Sep 2023 14:43:14 -0700
Subject: [PATCH 2/3] Remove Thread.sleep(..) from collectResults(..).

Replace Thread.sleep(..) with Future.get(timeout, TimeUnit) for 10 microseconds. As a result, Future.isDone() and Future.isCancelled() are no longer necessary. Simply try to get the results within 10 us, and if a TimeoutException is thrown, set done to false.

10 microseconds is 1/1000 of 10 milliseconds. This means a Redis cluster with 1000 nodes will run in a similar time to Thread.sleep(10L) if all Futures are blocked waiting for the computation to complete and take an equal amount of time to compute the result, which is rarely the case in practice given different hardware configurations, data access patterns, load balancing/request routing, and so on. However, using Future.get(timeout, TimeUnit) is fairer than Future.get(), which blocks until a result is returned or an ExecutionException is thrown, thereby starving computationally faster nodes while waiting on other nodes in the cluster that might be overloaded. In the meantime, some nodes may even complete in the short amount of time spent waiting on a single node.

10 microseconds was partially arbitrary, but no more so than Thread.sleep(10L) (10 milliseconds). The main objective was to give each node a chance to complete the computation at a moment's notice, balanced with the need to quickly check whether the computation is done, hence Future.get(timeout, TimeUnit.MICROSECONDS) for sub-millisecond response times. This may need to be further tuned over time, but should serve as a reasonable baseline for the time being.

Additionally, this was based on https://redis.io/docs/reference/cluster-spec/#overview-of-redis-cluster-main-components in the Redis documentation, which recommends a cluster size of no more than 1000 nodes.

One optimization might be to reorder the Map of Futures at the end of each iteration so that Futures that are done come first. Furthermore, Futures that have already completed could even be removed from the Map. Of course, there is little harm in keeping the completed Futures in the Map with the safeguard in place. This optimization was not included in these changes simply because its benefit is most likely negligible and should be measured. Reconstructing a TreeMap should run in roughly O(n log n) time, but memory consumption should also be taken into consideration.

Add test coverage for the ClusterCommandExecutor collectResults(..) method.

Clean up compiler warnings in ClusterCommandExecutorUnitTests.
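For illustration only (not part of this patch): below is a minimal, self-contained sketch of the timed-polling idea described above, using plain String results and hypothetical names (TimedPollingSketch, collect); the actual change is made to ClusterCommandExecutor.collectResults(..) in the diff that follows.

    import java.util.HashSet;
    import java.util.LinkedHashMap;
    import java.util.Map;
    import java.util.Set;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.Future;
    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.TimeoutException;

    class TimedPollingSketch {

        // Poll each Future for at most 10 microseconds per pass instead of sleeping 10 ms per pass.
        static Map<String, String> collect(Map<String, Future<String>> futures) throws InterruptedException {

            Map<String, String> results = new LinkedHashMap<>();
            Set<String> safeguard = new HashSet<>(); // collect each Future's result (or failure) only once

            boolean done = false;

            while (!done) {
                done = true;
                for (Map.Entry<String, Future<String>> entry : futures.entrySet()) {
                    if (safeguard.contains(entry.getKey())) {
                        continue;
                    }
                    try {
                        results.put(entry.getKey(), entry.getValue().get(10L, TimeUnit.MICROSECONDS));
                        safeguard.add(entry.getKey());
                    } catch (TimeoutException notDoneYet) {
                        done = false; // this Future is not finished; re-check it on the next pass
                    } catch (ExecutionException error) {
                        safeguard.add(entry.getKey()); // record the failure once, then move on
                        results.put(entry.getKey(), "error: " + error.getCause());
                    }
                }
            }

            return results;
        }

        public static void main(String[] args) throws Exception {

            ExecutorService pool = Executors.newFixedThreadPool(2);

            Map<String, Future<String>> futures = new LinkedHashMap<>();
            futures.put("node-1", pool.submit(() -> "pong"));
            futures.put("node-2", pool.submit(() -> { Thread.sleep(50); return "pong"; }));

            System.out.println(collect(futures)); // prints {node-1=pong, node-2=pong}

            pool.shutdown();
        }
    }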
Closes #2518 --- .../connection/ClusterCommandExecutor.java | 293 ++++++---- .../ClusterCommandExecutorUnitTests.java | 513 ++++++++++++++---- .../data/redis/test/util/MockitoUtils.java | 142 ++++- 3 files changed, 729 insertions(+), 219 deletions(-) diff --git a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java index f016dfc33f..22ad65cdd0 100644 --- a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java +++ b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java @@ -15,10 +15,28 @@ */ package org.springframework.data.redis.connection; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; +import java.util.Random; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.BiConsumer; import java.util.function.Function; import java.util.stream.Collectors; @@ -43,6 +61,7 @@ * * @author Christoph Strobl * @author Mark Paluch + * @author John Blum * @since 1.7 */ public class ClusterCommandExecutor implements DisposableBean { @@ -58,7 +77,7 @@ public class ClusterCommandExecutor implements DisposableBean { private final ExceptionTranslationStrategy exceptionTranslationStrategy; /** - * Create a new instance of {@link ClusterCommandExecutor}. + * Create a new {@link ClusterCommandExecutor}. * * @param topologyProvider must not be {@literal null}. * @param resourceProvider must not be {@literal null}. @@ -92,40 +111,47 @@ public ClusterCommandExecutor(ClusterTopologyProvider topologyProvider, ClusterN /** * Run {@link ClusterCommandCallback} on a random node. * - * @param commandCallback must not be {@literal null}. + * @param clusterCommand must not be {@literal null}. * @return never {@literal null}. */ - public NodeResult executeCommandOnArbitraryNode(ClusterCommandCallback commandCallback) { + public NodeResult executeCommandOnArbitraryNode(ClusterCommandCallback clusterCommand) { - Assert.notNull(commandCallback, "ClusterCommandCallback must not be null"); + Assert.notNull(clusterCommand, "ClusterCommandCallback must not be null"); List nodes = new ArrayList<>(getClusterTopology().getActiveNodes()); - return executeCommandOnSingleNode(commandCallback, nodes.get(new Random().nextInt(nodes.size()))); + RedisClusterNode arbitraryNode = nodes.get(new Random().nextInt(nodes.size())); + + return executeCommandOnSingleNode(clusterCommand, arbitraryNode); } /** * Run {@link ClusterCommandCallback} on given {@link RedisClusterNode}. * - * @param cmd must not be {@literal null}. + * @param clusterCommand must not be {@literal null}. * @param node must not be {@literal null}. * @return the {@link NodeResult} from the single, targeted {@link RedisClusterNode}. * @throws IllegalArgumentException in case no resource can be acquired for given node. 
*/ - public NodeResult executeCommandOnSingleNode(ClusterCommandCallback cmd, RedisClusterNode node) { - return executeCommandOnSingleNode(cmd, node, 0); + public NodeResult executeCommandOnSingleNode(ClusterCommandCallback clusterCommand, + RedisClusterNode node) { + + return executeCommandOnSingleNode(clusterCommand, node, 0); } - private NodeResult executeCommandOnSingleNode(ClusterCommandCallback cmd, RedisClusterNode node, - int redirectCount) { + private NodeResult executeCommandOnSingleNode(ClusterCommandCallback clusterCommand, + RedisClusterNode node, int redirectCount) { - Assert.notNull(cmd, "ClusterCommandCallback must not be null"); + Assert.notNull(clusterCommand, "ClusterCommandCallback must not be null"); Assert.notNull(node, "RedisClusterNode must not be null"); - if (redirectCount > maxRedirects) { - throw new TooManyClusterRedirectionsException(String.format( - "Cannot follow Cluster Redirects over more than %s legs; Please consider increasing the number of redirects to follow; Current value is: %s.", - redirectCount, maxRedirects)); + if (redirectCount > this.maxRedirects) { + + String message = String.format("Cannot follow Cluster Redirects over more than %s legs;" + + " Please consider increasing the number of redirects to follow; Current value is: %s.", + redirectCount, this.maxRedirects); + + throw new TooManyClusterRedirectionsException(message); } RedisClusterNode nodeToUse = lookupNode(node); @@ -135,14 +161,15 @@ private NodeResult executeCommandOnSingleNode(ClusterCommandCallback(node, cmd.doInCluster(client)); + return new NodeResult<>(node, clusterCommand.doInCluster(client)); } catch (RuntimeException cause) { RuntimeException translatedException = convertToDataAccessException(cause); if (translatedException instanceof ClusterRedirectException clusterRedirectException) { - return executeCommandOnSingleNode(cmd, topologyProvider.getTopology().lookup( - clusterRedirectException.getTargetHost(), clusterRedirectException.getTargetPort()), redirectCount + 1); + return executeCommandOnSingleNode(clusterCommand, topologyProvider.getTopology() + .lookup(clusterRedirectException.getTargetHost(), clusterRedirectException.getTargetPort()), + redirectCount + 1); } else { throw translatedException != null ? translatedException : cause; } @@ -152,10 +179,10 @@ private NodeResult executeCommandOnSingleNode(ClusterCommandCallback MultiNodeResult executeCommandOnAllNodes(final ClusterCommandCallback cmd) { - return executeCommandAsyncOnNodes(cmd, getClusterTopology().getActiveMasterNodes()); + public MultiNodeResult executeCommandOnAllNodes(ClusterCommandCallback clusterCommand) { + return executeCommandAsyncOnNodes(clusterCommand, getClusterTopology().getActiveMasterNodes()); } /** - * @param callback must not be {@literal null}. + * @param clusterCommand must not be {@literal null}. * @param nodes must not be {@literal null}. * @return never {@literal null}. * @throws ClusterCommandExecutionFailureException if a failure occurs while executing the given * {@link ClusterCommandCallback command} on any given {@link RedisClusterNode node}. 
* @throws IllegalArgumentException in case the node could not be resolved to a topology-known node */ - public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallback callback, + public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallback clusterCommand, Iterable nodes) { - Assert.notNull(callback, "Callback must not be null"); + Assert.notNull(clusterCommand, "Callback must not be null"); Assert.notNull(nodes, "Nodes must not be null"); + ClusterTopology topology = this.topologyProvider.getTopology(); List resolvedRedisClusterNodes = new ArrayList<>(); - ClusterTopology topology = topologyProvider.getTopology(); for (RedisClusterNode node : nodes) { try { resolvedRedisClusterNodes.add(topology.lookup(node)); - } catch (ClusterStateFailureException e) { - throw new IllegalArgumentException(String.format("Node %s is unknown to cluster", node), e); + } catch (ClusterStateFailureException cause) { + throw new IllegalArgumentException(String.format("Node %s is unknown to cluster", node), cause); } } Map>> futures = new LinkedHashMap<>(); for (RedisClusterNode node : resolvedRedisClusterNodes) { - futures.put(new NodeExecution(node), executor.submit(() -> executeCommandOnSingleNode(callback, node))); + Callable> nodeCommandExecution = () -> executeCommandOnSingleNode(clusterCommand, node); + futures.put(new NodeExecution(node), executor.submit(nodeCommandExecution)); } return collectResults(futures); } - private MultiNodeResult collectResults(Map>> futures) { - - boolean done = false; + MultiNodeResult collectResults(Map>> futures) { Map exceptions = new HashMap<>(); MultiNodeResult result = new MultiNodeResult<>(); - Set saveGuard = new HashSet<>(); + Set safeguard = new HashSet<>(); + + BiConsumer exceptionHandler = getExceptionHandlerFunction(exceptions); + + boolean done = false; while (!done) { @@ -227,50 +257,34 @@ private MultiNodeResult collectResults(Map>> entry : futures.entrySet()) { - if (!entry.getValue().isDone() && !entry.getValue().isCancelled()) { - done = false; - } else { - - NodeExecution execution = entry.getKey(); - - try { - - String futureId = ObjectUtils.getIdentityHexString(entry.getValue()); + NodeExecution nodeExecution = entry.getKey(); + Future> futureNodeResult = entry.getValue(); + String futureId = ObjectUtils.getIdentityHexString(futureNodeResult); - if (!saveGuard.contains(futureId)) { + try { + if (!safeguard.contains(futureId)) { - if (execution.isPositional()) { - result.add(execution.getPositionalKey(), entry.getValue().get()); - } else { - result.add(entry.getValue().get()); - } + NodeResult nodeResult = futureNodeResult.get(10L, TimeUnit.MICROSECONDS); - saveGuard.add(futureId); + if (nodeExecution.isPositional()) { + result.add(nodeExecution.getPositionalKey(), nodeResult); + } else { + result.add(nodeResult); } - } catch (ExecutionException cause) { - - RuntimeException exception = convertToDataAccessException((Exception) cause.getCause()); - - exceptions.put(execution.getNode(), exception != null ? exception : cause.getCause()); - } catch (InterruptedException cause) { - - Thread.currentThread().interrupt(); - - RuntimeException exception = convertToDataAccessException((Exception) cause.getCause()); - - exceptions.put(execution.getNode(), exception != null ? 
exception : cause.getCause()); - break; + safeguard.add(futureId); } + } catch (ExecutionException exception) { + safeguard.add(futureId); + exceptionHandler.accept(nodeExecution, exception.getCause()); + } catch (TimeoutException ignore) { + done = false; + } catch (InterruptedException cause) { + Thread.currentThread().interrupt(); + exceptionHandler.accept(nodeExecution, cause); + break; } } - - try { - Thread.sleep(10); - } catch (InterruptedException e) { - done = true; - Thread.currentThread().interrupt(); - } } if (!exceptions.isEmpty()) { @@ -280,6 +294,17 @@ private MultiNodeResult collectResults(Map getExceptionHandlerFunction(Map exceptions) { + + return (nodeExecution, throwable) -> { + + DataAccessException dataAccessException = convertToDataAccessException((Exception) throwable); + Throwable resolvedException = dataAccessException != null ? dataAccessException : throwable; + + exceptions.putIfAbsent(nodeExecution.getNode(), resolvedException); + }; + } + /** * Run {@link MultiKeyClusterCommandCallback} with on a curated set of nodes serving one or more keys. * @@ -306,8 +331,8 @@ public MultiNodeResult executeMultiKeyCommand(MultiKeyClusterCommandCa if (entry.getKey().isMaster()) { for (PositionalKey key : entry.getValue()) { - futures.put(new NodeExecution(entry.getKey(), key), this.executor - .submit(() -> executeMultiKeyCommandOnSingleNode(commandCallback, entry.getKey(), key.getBytes()))); + futures.put(new NodeExecution(entry.getKey(), key), this.executor.submit(() -> + executeMultiKeyCommandOnSingleNode(commandCallback, entry.getKey(), key.getBytes()))); } } } @@ -328,10 +353,11 @@ private NodeResult executeMultiKeyCommandOnSingleNode(MultiKeyClusterC try { return new NodeResult<>(node, commandCallback.doInCluster(client, key), key); - } catch (RuntimeException ex) { + } catch (RuntimeException cause) { - RuntimeException translatedException = convertToDataAccessException(ex); - throw translatedException != null ? translatedException : ex; + RuntimeException translatedException = convertToDataAccessException(cause); + + throw translatedException != null ? translatedException : cause; } finally { this.resourceProvider.returnResourceForSpecificNode(node, client); } @@ -343,7 +369,7 @@ private ClusterTopology getClusterTopology() { @Nullable private DataAccessException convertToDataAccessException(Exception cause) { - return exceptionTranslationStrategy.translate(cause); + return this.exceptionTranslationStrategy.translate(cause); } /** @@ -395,7 +421,7 @@ public interface MultiKeyClusterCommandCallback { * @author Mark Paluch * @since 1.7 */ - private static class NodeExecution { + static class NodeExecution { private final RedisClusterNode node; private final @Nullable PositionalKey positionalKey; @@ -414,7 +440,7 @@ private static class NodeExecution { * Get the {@link RedisClusterNode} the execution happens on. */ RedisClusterNode getNode() { - return node; + return this.node; } /** @@ -423,30 +449,31 @@ RedisClusterNode getNode() { * @since 2.0.3 */ PositionalKey getPositionalKey() { - return positionalKey; + return this.positionalKey; } boolean isPositional() { - return positionalKey != null; + return this.positionalKey != null; } } /** - * {@link NodeResult} encapsulates the actual value returned by a {@link ClusterCommandCallback} on a given - * {@link RedisClusterNode}. + * {@link NodeResult} encapsulates the actual {@link T value} returned by a {@link ClusterCommandCallback} + * on a given {@link RedisClusterNode}. 
* + * @param {@link Class Type} of the {@link Object value} returned in the result. * @author Christoph Strobl - * @param + * @author John Blum * @since 1.7 */ public static class NodeResult { private RedisClusterNode node; - private @Nullable T value; private ByteArrayWrapper key; + private @Nullable T value; /** - * Create new {@link NodeResult}. + * Create a new {@link NodeResult}. * * @param node must not be {@literal null}. * @param value can be {@literal null}. @@ -456,7 +483,7 @@ public NodeResult(RedisClusterNode node, @Nullable T value) { } /** - * Create new {@link NodeResult}. + * Create a new {@link NodeResult}. * * @param node must not be {@literal null}. * @param value can be {@literal null}. @@ -465,37 +492,36 @@ public NodeResult(RedisClusterNode node, @Nullable T value) { public NodeResult(RedisClusterNode node, @Nullable T value, byte[] key) { this.node = node; - this.value = value; - this.key = new ByteArrayWrapper(key); + this.value = value; } /** - * Get the actual value of the command execution. + * Get the {@link RedisClusterNode} the command was executed on. * - * @return can be {@literal null}. + * @return never {@literal null}. */ - @Nullable - public T getValue() { - return value; + public RedisClusterNode getNode() { + return this.node; } /** - * Get the {@link RedisClusterNode} the command was executed on. + * Return the {@link byte[] key} mapped to the value stored in Redis. * - * @return never {@literal null}. + * @return a {@link byte[] byte array} of the key mapped to the value stored in Redis. */ - public RedisClusterNode getNode() { - return node; + public byte[] getKey() { + return this.key.getArray(); } /** - * Returns the key as an array of bytes. + * Get the actual value of the command execution. * - * @return the key as an array of bytes. + * @return can be {@literal null}. */ - public byte[] getKey() { - return key.getArray(); + @Nullable + public T getValue() { + return this.value; } /** @@ -513,6 +539,34 @@ public U mapValue(Function mapper) { return mapper.apply(getValue()); } + + @Override + public boolean equals(@Nullable Object obj) { + + if (obj == this) { + return true; + } + + if (!(obj instanceof NodeResult that)) { + return false; + } + + return ObjectUtils.nullSafeEquals(this.getNode(), that.getNode()) + && Objects.equals(this.key, that.key) + && Objects.equals(this.getValue(), that.getValue()); + } + + @Override + public int hashCode() { + + int hashValue = 17; + + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(getNode()); + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(this.key); + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(getValue()); + + return hashValue; + } } /** @@ -566,12 +620,14 @@ public List resultsAsListSortBy(byte[]... keys) { if (positionalResults.isEmpty()) { List> clone = new ArrayList<>(nodeResults); + clone.sort(new ResultByReferenceKeyPositionComparator(keys)); return toList(clone); } Map> result = new TreeMap<>(new ResultByKeyPositionComparator(keys)); + result.putAll(positionalResults); return result.values().stream().map(tNodeResult -> tNodeResult.value).collect(Collectors.toList()); @@ -604,10 +660,12 @@ public T getFirstNonNullNotEmptyOrDefault(@Nullable T returnValue) { private List toList(Collection> source) { - ArrayList result = new ArrayList<>(); + List result = new ArrayList<>(); + for (NodeResult nodeResult : source) { result.add(nodeResult.getValue()); } + return result; } @@ -678,7 +736,7 @@ static PositionalKey of(byte[] key, int index) { * @return binary key. 
*/ byte[] getBytes() { - return key.getArray(); + return getKey().getArray(); } public ByteArrayWrapper getKey() { @@ -690,23 +748,23 @@ public int getPosition() { } @Override - public boolean equals(@Nullable Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; + public boolean equals(@Nullable Object obj) { - PositionalKey that = (PositionalKey) o; + if (this == obj) { + return true; + } - if (position != that.position) + if (!(obj instanceof PositionalKey that)) return false; - return ObjectUtils.nullSafeEquals(key, that.key); + + return this.getPosition() == that.getPosition() + && ObjectUtils.nullSafeEquals(this.getKey(), that.getKey()); } @Override public int hashCode() { - int result = ObjectUtils.nullSafeHashCode(key); - result = 31 * result + position; + int result = ObjectUtils.nullSafeHashCode(getKey()); + result = 31 * result + ObjectUtils.nullSafeHashCode(getPosition()); return result; } } @@ -753,6 +811,7 @@ static PositionalKeys of(byte[]... keys) { static PositionalKeys of(PositionalKey... keys) { PositionalKeys result = PositionalKeys.empty(); + result.append(keys); return result; @@ -769,12 +828,12 @@ void append(PositionalKey... keys) { * @return index of the {@link PositionalKey}. */ int indexOf(PositionalKey key) { - return keys.indexOf(key); + return this.keys.indexOf(key); } @Override public Iterator iterator() { - return keys.iterator(); + return this.keys.iterator(); } } } diff --git a/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java b/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java index f99e786ce4..25e3d47562 100644 --- a/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java +++ b/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java @@ -15,16 +15,38 @@ */ package org.springframework.data.redis.connection; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; -import static org.springframework.data.redis.test.util.MockitoUtils.*; - +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.data.redis.test.util.MockitoUtils.verifyInvocationsAcross; + +import java.time.Instant; import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Predicate; +import 
java.util.function.Supplier; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -33,6 +55,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.stubbing.Answer; import org.springframework.core.convert.converter.Converter; import org.springframework.core.task.AsyncTaskExecutor; @@ -45,14 +68,24 @@ import org.springframework.data.redis.connection.ClusterCommandExecutor.ClusterCommandCallback; import org.springframework.data.redis.connection.ClusterCommandExecutor.MultiKeyClusterCommandCallback; import org.springframework.data.redis.connection.ClusterCommandExecutor.MultiNodeResult; +import org.springframework.data.redis.connection.ClusterCommandExecutor.NodeExecution; +import org.springframework.data.redis.connection.ClusterCommandExecutor.NodeResult; import org.springframework.data.redis.connection.RedisClusterNode.LinkState; import org.springframework.data.redis.connection.RedisClusterNode.SlotRange; import org.springframework.data.redis.connection.RedisNode.NodeType; +import org.springframework.data.redis.test.util.MockitoUtils; import org.springframework.scheduling.concurrent.ConcurrentTaskExecutor; +import edu.umd.cs.mtc.MultithreadedTestCase; +import edu.umd.cs.mtc.TestFramework; + /** + * Unit Tests for {@link ClusterCommandExecutor}. + * * @author Christoph Strobl * @author Mark Paluch + * @author John Blum + * @since 1.7 */ @ExtendWith(MockitoExtension.class) class ClusterCommandExecutorUnitTests { @@ -66,17 +99,32 @@ class ClusterCommandExecutorUnitTests { private static final int CLUSTER_NODE_3_PORT = 7381; private static final RedisClusterNode CLUSTER_NODE_1 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT).serving(new SlotRange(0, 5460)) - .withId("ef570f86c7b1a953846668debc177a3a16733420").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT) + .serving(new SlotRange(0, 5460)) + .withId("ef570f86c7b1a953846668debc177a3a16733420") + .promotedAs(NodeType.MASTER) + .linkState(LinkState.CONNECTED) + .withName("ClusterNodeX") .build(); + private static final RedisClusterNode CLUSTER_NODE_2 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT).serving(new SlotRange(5461, 10922)) - .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT) + .serving(new SlotRange(5461, 10922)) + .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84") + .promotedAs(NodeType.MASTER) + .linkState(LinkState.CONNECTED) + .withName("ClusterNodeY") .build(); + private static final RedisClusterNode CLUSTER_NODE_3 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT).serving(new SlotRange(10923, 16383)) - .withId("3b9b8192a874fa8f1f09dbc0ee20afab5738eee7").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT) + .serving(new SlotRange(10923, 16383)) + .withId("3b9b8192a874fa8f1f09dbc0ee20afab5738eee7") + .promotedAs(NodeType.MASTER) + .linkState(LinkState.CONNECTED) + .withName("ClusterNodeZ") .build(); + private static final RedisClusterNode CLUSTER_NODE_2_LOOKUP = RedisClusterNode.newRedisClusterNode() .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84").build(); @@ -88,8 +136,8 @@ class ClusterCommandExecutorUnitTests { private static 
final Converter exceptionConverter = source -> { - if (source instanceof MovedException) { - return new ClusterRedirectException(1000, ((MovedException) source).host, ((MovedException) source).port, source); + if (source instanceof MovedException movedException) { + return new ClusterRedirectException(1000, movedException.host, movedException.port, source); } return new InvalidDataAccessApiUsageException(source.getMessage(), source); @@ -97,14 +145,14 @@ class ClusterCommandExecutorUnitTests { private static final MultiKeyConnectionCommandCallback MULTIKEY_CALLBACK = Connection::bloodAndAshes; - @Mock Connection con1; - @Mock Connection con2; - @Mock Connection con3; + @Mock Connection connection1; + @Mock Connection connection2; + @Mock Connection connection3; @BeforeEach void setUp() { - this.executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), new MockClusterResourceProvider(), + this.executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ImmediateExecutor()); } @@ -118,7 +166,7 @@ void executeCommandOnSingleNodeShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_2); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -127,7 +175,7 @@ void executeCommandOnSingleNodeByHostAndPortShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, new RedisClusterNode(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -135,15 +183,17 @@ void executeCommandOnSingleNodeByNodeIdShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, new RedisClusterNode(CLUSTER_NODE_2.id)); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 + @SuppressWarnings("all") void executeCommandOnSingleNodeShouldThrowExceptionWhenNodeIsNull() { assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandOnSingleNode(COMMAND_CALLBACK, null)); } @Test // DATAREDIS-315 + @SuppressWarnings("all") void executeCommandOnSingleNodeShouldThrowExceptionWhenCommandCallbackIsNull() { assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandOnSingleNode(null, CLUSTER_NODE_1)); } @@ -158,52 +208,52 @@ void executeCommandOnSingleNodeShouldThrowExceptionWhenNodeIsUnknown() { void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodes() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(CLUSTER_NODE_1, CLUSTER_NODE_2)); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, 
never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodesByHostAndPort() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(new RedisClusterNode(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT), new RedisClusterNode(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT))); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodesByNodeId() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(new RedisClusterNode(CLUSTER_NODE_1.id), CLUSTER_NODE_2_LOOKUP)); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldFailOnGivenUnknownNodes() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, @@ -214,42 +264,42 @@ void executeCommandAsyncOnNodesShouldFailOnGivenUnknownNodes() { void executeCommandOnAllNodesShouldExecuteCommandOnEveryKnownClusterNode() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandOnAllNodes(COMMAND_CALLBACK); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void 
executeCommandAsyncOnNodesShouldCompleteAndCollectErrorsOfAllNodes() { - when(con1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); - when(con2.theWheelWeavesAsTheWheelWills()).thenThrow(new IllegalStateException("(error) mat lost the dagger...")); - when(con3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); + when(connection1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); + when(connection2.theWheelWeavesAsTheWheelWills()).thenThrow(new IllegalStateException("(error) mat lost the dagger...")); + when(connection3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); try { executor.executeCommandOnAllNodes(COMMAND_CALLBACK); - } catch (ClusterCommandExecutionFailureException e) { + } catch (ClusterCommandExecutionFailureException cause) { - assertThat(e.getSuppressed()).hasSize(1); - assertThat(e.getSuppressed()[0]).isInstanceOf(DataAccessException.class); + assertThat(cause.getSuppressed()).hasSize(1); + assertThat(cause.getSuppressed()[0]).isInstanceOf(DataAccessException.class); } - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldCollectResultsCorrectly() { - when(con1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); - when(con2.theWheelWeavesAsTheWheelWills()).thenReturn("mat"); - when(con3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); + when(connection1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); + when(connection2.theWheelWeavesAsTheWheelWills()).thenReturn("mat"); + when(connection3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); MultiNodeResult result = executor.executeCommandOnAllNodes(COMMAND_CALLBACK); @@ -261,10 +311,10 @@ void executeMultikeyCommandShouldRunCommandAcrossCluster() { // key-1 and key-9 map both to node1 ArgumentCaptor captor = ArgumentCaptor.forClass(byte[].class); - when(con1.bloodAndAshes(captor.capture())).thenReturn("rand").thenReturn("egwene"); - when(con2.bloodAndAshes(any(byte[].class))).thenReturn("mat"); - when(con3.bloodAndAshes(any(byte[].class))).thenReturn("perrin"); + when(connection1.bloodAndAshes(captor.capture())).thenReturn("rand").thenReturn("egwene"); + when(connection2.bloodAndAshes(any(byte[].class))).thenReturn("mat"); + when(connection3.bloodAndAshes(any(byte[].class))).thenReturn("perrin"); MultiNodeResult result = executor.executeMultiKeyCommand(MULTIKEY_CALLBACK, new HashSet<>( @@ -279,21 +329,21 @@ void executeMultikeyCommandShouldRunCommandAcrossCluster() { @Test // DATAREDIS-315 void executeCommandOnSingleNodeAndFollowRedirect() { - when(con1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); + when(connection1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_1); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2, never()).theWheelWeavesAsTheWheelWills(); } @Test // 
DATAREDIS-315 void executeCommandOnSingleNodeAndFollowRedirectButStopsAfterMaxRedirects() { - when(con1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); - when(con3.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); - when(con2.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT)); + when(connection1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); + when(connection3.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); + when(connection2.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT)); try { executor.setMaxRedirects(4); @@ -302,9 +352,9 @@ void executeCommandOnSingleNodeAndFollowRedirectButStopsAfterMaxRedirects() { assertThat(e).isInstanceOf(TooManyClusterRedirectionsException.class); } - verify(con1, times(2)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(2)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(2)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(2)).theWheelWeavesAsTheWheelWills(); + verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -312,53 +362,249 @@ void executeCommandOnArbitraryNodeShouldPickARandomNode() { executor.executeCommandOnArbitraryNode(COMMAND_CALLBACK); - verifyInvocationsAcross("theWheelWeavesAsTheWheelWills", times(1), con1, con2, con3); + verifyInvocationsAcross("theWheelWeavesAsTheWheelWills", times(1), connection1, connection2, connection3); + } + + @Test // GH-2518 + void collectResultsCompletesSuccessfully() { + + Instant done = Instant.now().plusMillis(5); + + Predicate>> isDone = future -> Instant.now().isAfter(done); + + Map>> futures = new HashMap<>(); + + NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); + NodeResult nodeTwoB = newNodeResult(CLUSTER_NODE_2, "B"); + NodeResult nodeThreeC = newNodeResult(CLUSTER_NODE_3, "C"); + + futures.put(newNodeExecution(CLUSTER_NODE_1), mockFutureAndIsDone(nodeOneA, isDone)); + futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureAndIsDone(nodeTwoB, isDone)); + futures.put(newNodeExecution(CLUSTER_NODE_3), mockFutureAndIsDone(nodeThreeC, isDone)); + + MultiNodeResult results = this.executor.collectResults(futures); + + assertThat(results).isNotNull(); + assertThat(results.getResults()).containsExactlyInAnyOrder(nodeOneA, nodeTwoB, nodeThreeC); + + futures.values().forEach(future -> + runsSafely(() -> verify(future, times(1)).get(anyLong(), any(TimeUnit.class)))); + } + + @Test // GH-2518 + void collectResultsCompletesSuccessfullyEvenWithTimeouts() throws Exception { + + Map>> futures = new HashMap<>(); + + NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); + NodeResult nodeTwoB = newNodeResult(CLUSTER_NODE_2, "B"); + NodeResult nodeThreeC = newNodeResult(CLUSTER_NODE_3, "C"); + + Future> nodeOneFutureResult = mockFutureThrowingTimeoutException(nodeOneA, 4); + Future> nodeTwoFutureResult = mockFutureThrowingTimeoutException(nodeTwoB, 1); + Future> nodeThreeFutureResult = mockFutureThrowingTimeoutException(nodeThreeC, 2); + + futures.put(newNodeExecution(CLUSTER_NODE_1), nodeOneFutureResult); + futures.put(newNodeExecution(CLUSTER_NODE_2), nodeTwoFutureResult); + futures.put(newNodeExecution(CLUSTER_NODE_3), nodeThreeFutureResult); + + 
MultiNodeResult results = this.executor.collectResults(futures); + + assertThat(results).isNotNull(); + assertThat(results.getResults()).containsExactlyInAnyOrder(nodeOneA, nodeTwoB, nodeThreeC); + + verify(nodeOneFutureResult, times(4)).get(anyLong(), any(TimeUnit.class)); + verify(nodeTwoFutureResult, times(1)).get(anyLong(), any(TimeUnit.class)); + verify(nodeThreeFutureResult, times(2)).get(anyLong(), any(TimeUnit.class)); + verifyNoMoreInteractions(nodeOneFutureResult, nodeTwoFutureResult, nodeThreeFutureResult); + } + + @Test // GH-2518 + void collectResultsFailsWithExecutionException() { + + Map>> futures = new HashMap<>(); + + NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); + + futures.put(newNodeExecution(CLUSTER_NODE_1), mockFutureAndIsDone(nodeOneA, future -> true)); + futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureThrowingExecutionException( + new ExecutionException("TestError", new IllegalArgumentException("MockError")))); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)) + .withMessage("MockError") + .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) + .extracting(Throwable::getCause) + .extracting(Throwable::getCause) + .isInstanceOf(IllegalArgumentException.class) + .extracting(Throwable::getMessage) + .isEqualTo("MockError"); } - class MockClusterNodeProvider implements ClusterTopologyProvider { + @Test // GH-2518 + void collectResultsFailsWithInterruptedException() throws Throwable { + TestFramework.runOnce(new CollectResultsInterruptedMultithreadedTestCase(this.executor)); + } + + // Future.get() for X will get called twice if at least one other Future is not done and Future.get() for X + // threw an ExecutionException in the previous iteration, thereby marking it as done! + @Test // GH-2518 + @SuppressWarnings("all") + void collectResultsCallsFutureGetOnlyOnce() throws Exception { + + AtomicInteger count = new AtomicInteger(0); + Map>> futures = new HashMap<>(); + + Future> clusterNodeOneFutureResult = mockFutureAndIsDone(null, future -> + count.incrementAndGet() % 2 == 0); + + Future> clusterNodeTwoFutureResult = mockFutureThrowingExecutionException( + new ExecutionException("TestError", new IllegalArgumentException("MockError"))); + + futures.put(newNodeExecution(CLUSTER_NODE_1), clusterNodeOneFutureResult); + futures.put(newNodeExecution(CLUSTER_NODE_2), clusterNodeTwoFutureResult); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)); + + verify(clusterNodeOneFutureResult, times(1)).get(anyLong(), any()); + verify(clusterNodeTwoFutureResult, times(1)).get(anyLong(), any()); + } + + // Covers the case where Future.get() is mistakenly called multiple times, or if the Future.isDone() implementation + // does not properly take into account Future.get() throwing an ExecutionException during computation subsequently + // returning false instead of true. + // This should be properly handled by the "safeguard" (see collectResultsCallsFutureGetOnlyOnce()), but... + // just in case! The ExecutionException handler now stores the [DataAccess]Exception with Map.putIfAbsent(..). 
+ @Test // GH-2518 + @SuppressWarnings("all") + void collectResultsCapturesFirstExecutionExceptionOnly() { + + AtomicInteger count = new AtomicInteger(0); + AtomicInteger exceptionCount = new AtomicInteger(0); + + Map>> futures = new HashMap<>(); + + futures.put(newNodeExecution(CLUSTER_NODE_1), + mockFutureAndIsDone(null, future -> count.incrementAndGet() % 2 == 0)); + + futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureThrowingExecutionException(() -> + new ExecutionException("TestError", new IllegalStateException("MockError" + exceptionCount.getAndIncrement())))); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)) + .withMessage("MockError0") + .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) + .extracting(Throwable::getCause) + .extracting(Throwable::getCause) + .isInstanceOf(IllegalStateException.class) + .extracting(Throwable::getMessage) + .isEqualTo("MockError0"); + } + + private Future mockFutureAndIsDone(T result, Predicate> isDone) { + + return MockitoUtils.mockFuture(result, future -> { + doAnswer(invocation -> isDone.test(future)).when(future).isDone(); + return future; + }); + } + + private Future mockFutureThrowingExecutionException(ExecutionException exception) { + return mockFutureThrowingExecutionException(() -> exception); + } + + private Future mockFutureThrowingExecutionException(Supplier exceptionSupplier) { + + Answer getAnswer = invocationOnMock -> { throw exceptionSupplier.get(); }; + + return MockitoUtils.mockFuture(null, future -> { + doReturn(true).when(future).isDone(); + doAnswer(getAnswer).when(future).get(); + doAnswer(getAnswer).when(future).get(anyLong(), any()); + return future; + }); + } + + @SuppressWarnings("unchecked") + private Future mockFutureThrowingTimeoutException(T result, int timeoutCount) { + + AtomicInteger counter = new AtomicInteger(timeoutCount); + + Answer getAnswer = invocationOnMock -> { + + if (counter.decrementAndGet() > 0) { + throw new TimeoutException("TIMES UP"); + } + + doReturn(true).when((Future>) invocationOnMock.getMock()).isDone(); + + return result; + }; + + return MockitoUtils.mockFuture(result, future -> { + + doAnswer(getAnswer).when(future).get(); + doAnswer(getAnswer).when(future).get(anyLong(), any()); + + return future; + }); + } + + private NodeExecution newNodeExecution(RedisClusterNode clusterNode) { + return new NodeExecution(clusterNode); + } + + private NodeResult newNodeResult(RedisClusterNode clusterNode, T value) { + return new NodeResult<>(clusterNode, value); + } + + private void runsSafely(ThrowableOperation operation) { + + try { + operation.run(); + } catch (Throwable ignore) { } + } + + static class MockClusterNodeProvider implements ClusterTopologyProvider { @Override public ClusterTopology getTopology() { - return new ClusterTopology( - new LinkedHashSet<>(Arrays.asList(CLUSTER_NODE_1, CLUSTER_NODE_2, CLUSTER_NODE_3))); + return new ClusterTopology(Set.of(CLUSTER_NODE_1, CLUSTER_NODE_2, CLUSTER_NODE_3)); } - } - class MockClusterResourceProvider implements ClusterNodeResourceProvider { + class MockClusterNodeResourceProvider implements ClusterNodeResourceProvider { @Override - public Connection getResourceForSpecificNode(RedisClusterNode node) { - - if (CLUSTER_NODE_1.equals(node)) { - return con1; - } - if (CLUSTER_NODE_2.equals(node)) { - return con2; - } - if (CLUSTER_NODE_3.equals(node)) { - return con3; - } + @SuppressWarnings("all") + public Connection 
getResourceForSpecificNode(RedisClusterNode clusterNode) { - return null; + return CLUSTER_NODE_1.equals(clusterNode) ? connection1 + : CLUSTER_NODE_2.equals(clusterNode) ? connection2 + : CLUSTER_NODE_3.equals(clusterNode) ? connection3 + : null; } @Override public void returnResourceForSpecificNode(RedisClusterNode node, Object resource) { - // TODO Auto-generated method stub } - } - static interface ConnectionCommandCallback extends ClusterCommandCallback { + interface ConnectionCommandCallback extends ClusterCommandCallback { } - static interface MultiKeyConnectionCommandCallback extends MultiKeyClusterCommandCallback { + interface MultiKeyConnectionCommandCallback extends MultiKeyClusterCommandCallback { } - static interface Connection { + @FunctionalInterface + interface ThrowableOperation { + void run() throws Throwable; + } + + interface Connection { String theWheelWeavesAsTheWheelWills(); @@ -374,19 +620,20 @@ static class MovedException extends RuntimeException { this.host = host; this.port = port; } - } static class ImmediateExecutor implements AsyncTaskExecutor { @Override - public void execute(Runnable runnable, long l) { + public void execute(Runnable runnable) { runnable.run(); } @Override public Future submit(Runnable runnable) { + return submit(() -> { + runnable.run(); return null; @@ -395,19 +642,103 @@ public Future submit(Runnable runnable) { @Override public Future submit(Callable callable) { + try { return CompletableFuture.completedFuture(callable.call()); - } catch (Exception e) { + } catch (Exception cause) { CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(e); + + future.completeExceptionally(cause); + return future; } } + } + + @SuppressWarnings("all") + static class CollectResultsInterruptedMultithreadedTestCase extends MultithreadedTestCase { + + private static final CountDownLatch latch = new CountDownLatch(1); + + private static final Comparator NODE_COMPARATOR = + Comparator.comparing(nodeExecution -> nodeExecution.getNode().getName()); + + private final ClusterCommandExecutor clusterCommandExecutor; + + private final Map>> futureNodeResults; + + private Future> mockNodeOneFutureResult; + private Future> mockNodeTwoFutureResult; + + private volatile Thread collectResultsThread; + + private CollectResultsInterruptedMultithreadedTestCase(ClusterCommandExecutor clusterCommandExecutor) { + this.clusterCommandExecutor = clusterCommandExecutor; + this.futureNodeResults = new ConcurrentSkipListMap<>(NODE_COMPARATOR); + } @Override - public void execute(Runnable runnable) { - runnable.run(); + public void initialize() { + + super.initialize(); + + this.mockNodeOneFutureResult = this.futureNodeResults.computeIfAbsent(new NodeExecution(CLUSTER_NODE_1), + nodeExecution -> MockitoUtils.mockFuture(null, mockFuture -> { + doReturn(false).when(mockFuture).isDone(); + return mockFuture; + })); + + this.mockNodeTwoFutureResult = this.futureNodeResults.computeIfAbsent(new NodeExecution(CLUSTER_NODE_2), + nodeExecution -> MockitoUtils.mockFuture(null, mockFuture -> { + + doReturn(true).when(mockFuture).isDone(); + + doAnswer(invocation -> { + latch.await(); + return null; + }).when(mockFuture).get(anyLong(), any()); + + return mockFuture; + })); + } + + public void thread1() { + + assertTick(0); + + this.collectResultsThread = Thread.currentThread(); + this.collectResultsThread.setName("CollectResults Thread"); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> 
this.clusterCommandExecutor.collectResults(this.futureNodeResults)); + + assertThat(this.collectResultsThread.isInterrupted()).isTrue(); + } + + public void thread2() { + + assertTick(0); + + Thread.currentThread().setName("Interrupting Thread"); + + waitForTick(1); + + assertThat(this.collectResultsThread).isNotNull(); + assertThat(this.collectResultsThread.getName()).isEqualTo("CollectResults Thread"); + + this.collectResultsThread.interrupt(); + } + + @Override + public void finish() { + + try { + verify(this.mockNodeOneFutureResult, never()).get(); + verify(this.mockNodeTwoFutureResult, times(1)).get(anyLong(), any()); + } catch (ExecutionException | InterruptedException | TimeoutException cause) { + throw new RuntimeException(cause); + } } } } diff --git a/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java b/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java index e42f0749d8..318d47ec85 100644 --- a/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java +++ b/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java @@ -15,34 +15,131 @@ */ package org.springframework.data.redis.test.util; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockingDetails; +import static org.mockito.Mockito.withSettings; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; import org.mockito.internal.invocation.InvocationMatcher; import org.mockito.internal.verification.api.VerificationData; import org.mockito.invocation.Invocation; import org.mockito.invocation.MatchableInvocation; +import org.mockito.quality.Strictness; +import org.mockito.stubbing.Answer; import org.mockito.verification.VerificationMode; + +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; /** + * Utilities for using {@literal Mockito} and creating {@link Object mock objects} in {@literal unit tests}. + * * @author Christoph Strobl + * @author John Blum + * @see org.mockito.Mockito * @since 1.7 */ -public class MockitoUtils { +@SuppressWarnings("unused") +public abstract class MockitoUtils { + + /** + * Creates a mock {@link Future} returning the given {@link Object result}. + * + * @param {@link Class type} of {@link Object result} returned by the mock {@link Future}. + * @param result {@link Object value} returned as the {@literal result} of the mock {@link Future}. + * @return a new mock {@link Future}. + * @see java.util.concurrent.Future + */ + @SuppressWarnings("unchecked") + public static @NonNull Future mockFuture(@Nullable T result) { + + try { + + AtomicBoolean cancelled = new AtomicBoolean(false); + AtomicBoolean done = new AtomicBoolean(false); + + Future mockFuture = mock(Future.class, withSettings().strictness(Strictness.LENIENT)); + + // A Future can only be cancelled if not done, it was not already cancelled, and no error occurred. + // The cancel(..) logic is not Thread-safe due to compound actions involving multiple variables. + // However, the cancel(..) 
logic does not necessarily need to be Thread-safe given the task execution + // of a Future is asynchronous and cancellation is driven by Thread interrupt from another Thread. + Answer cancelAnswer = invocation -> !done.get() + && cancelled.compareAndSet(done.get(), true) + && done.compareAndSet(done.get(), true); + + Answer getAnswer = invocation -> { + + // The Future is done no matter if it returns the result or was cancelled/interrupted. + done.set(true); + + if (Thread.currentThread().isInterrupted()) { + throw new InterruptedException("Thread was interrupted"); + } + + if (cancelled.get()) { + throw new CancellationException("Task was cancelled"); + } + + return result; + }; + + doAnswer(invocation -> cancelled.get()).when(mockFuture).isCancelled(); + doAnswer(invocation -> done.get()).when(mockFuture).isDone(); + doAnswer(cancelAnswer).when(mockFuture).cancel(anyBoolean()); + doAnswer(getAnswer).when(mockFuture).get(); + doAnswer(getAnswer).when(mockFuture).get(anyLong(), isA(TimeUnit.class)); + + return mockFuture; + } + catch (Exception cause) { + String message = String.format("Failed to create a mock of Future having result [%s]", result); + throw new IllegalStateException(message, cause); + } + } + + /** + * Creates a mock {@link Future} returning the given {@link Object result}, customized with the given, + * required {@link Function}. + * + * @param {@link Class type} of {@link Object result} returned by the mock {@link Future}. + * @param result {@link Object value} returned as the {@literal result} of the mock {@link Future}. + * @param futureFunction {@link Function} used to customize the mock {@link Future} on creation; + * must not be {@literal null}. + * @return a new mock {@link Future}. + * @see java.util.concurrent.Future + * @see java.util.function.Function + * @see #mockFuture(Object) + */ + public static @NonNull Future mockFuture(@Nullable T result, + @NonNull ThrowableFunction, Future> futureFunction) { + + Future mockFuture = mockFuture(result); + + return futureFunction.apply(mockFuture); + } /** * Verifies a given method is called a total number of times across all given mocks. * - * @param method - * @param mode - * @param mocks + * @param method {@link String name} of a {@link java.lang.reflect.Method} on the {@link Object mock object}. + * @param mode mode of verification used by {@literal Mockito} to verify invocations on {@link Object mock objects}. + * @param mocks array of {@link Object mock objects} to verify. */ - @SuppressWarnings({ "rawtypes", "serial" }) - public static void verifyInvocationsAcross(final String method, final VerificationMode mode, Object... mocks) { + public static void verifyInvocationsAcross(String method, VerificationMode mode, Object... mocks) { mode.verify(new VerificationDataImpl(getInvocations(method, mocks), new InvocationMatcher(null, Collections .singletonList(org.mockito.internal.matchers.Any.ANY)) { @@ -56,17 +153,15 @@ public boolean matches(Invocation actual) { public String toString() { return String.format("%s for method: %s", mode, method); } - })); } private static List getInvocations(String method, Object... mocks) { List invocations = new ArrayList<>(); - for (Object mock : mocks) { + for (Object mock : mocks) { if (StringUtils.hasText(method)) { - for (Invocation invocation : mockingDetails(mock).getInvocations()) { if (invocation.getMethod().getName().equals(method)) { invocations.add(invocation); @@ -76,6 +171,7 @@ private static List getInvocations(String method, Object... 
mocks) { invocations.addAll(mockingDetails(mock).getInvocations()); } } + return invocations; } @@ -98,7 +194,31 @@ public List getAllInvocations() { public MatchableInvocation getTarget() { return wanted; } - } + @FunctionalInterface + public interface ThrowableFunction extends Function { + + @Override + default R apply(T target) { + + try { + return applyThrowingException(target); + } + catch (Throwable cause) { + String message = String.format("Failed to apply Function [%s] to target [%s]", this, target); + throw new IllegalStateException(message, cause); + } + } + + R applyThrowingException(T target) throws Throwable; + + @SuppressWarnings("unchecked") + default @NonNull ThrowableFunction andThen( + @Nullable ThrowableFunction after) { + + return after == null ? (ThrowableFunction) this + : target -> after.apply(this.apply(target)); + } + } } From 0cfe8bd649fc9bd3bcfedf0ab7f8b93d1f2fd6bb Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 11 Oct 2023 11:05:42 +0200 Subject: [PATCH 3/3] Polishing. Simplify tests. Reuse existing interfaces from Spring. Remove inappropriate nullability annotations and introduce annotations where required. Consistently name callbacks. Make exception collector concept explicit. Reformat code. --- .../connection/ClusterCommandExecutor.java | 169 ++++----- .../ClusterCommandExecutorUnitTests.java | 332 +++++++----------- .../data/redis/test/util/MockitoUtils.java | 91 ++--- 3 files changed, 241 insertions(+), 351 deletions(-) diff --git a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java index 22ad65cdd0..bfcf3a7853 100644 --- a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java +++ b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java @@ -15,28 +15,13 @@ */ package org.springframework.data.redis.connection; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.Map.Entry; -import java.util.Objects; -import java.util.Random; -import java.util.Set; -import java.util.TreeMap; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.function.BiConsumer; import java.util.function.Function; import java.util.stream.Collectors; @@ -111,47 +96,46 @@ public ClusterCommandExecutor(ClusterTopologyProvider topologyProvider, ClusterN /** * Run {@link ClusterCommandCallback} on a random node. * - * @param clusterCommand must not be {@literal null}. + * @param commandCallback must not be {@literal null}. * @return never {@literal null}. 
*/ - public NodeResult executeCommandOnArbitraryNode(ClusterCommandCallback clusterCommand) { + public NodeResult executeCommandOnArbitraryNode(ClusterCommandCallback commandCallback) { - Assert.notNull(clusterCommand, "ClusterCommandCallback must not be null"); + Assert.notNull(commandCallback, "ClusterCommandCallback must not be null"); List nodes = new ArrayList<>(getClusterTopology().getActiveNodes()); RedisClusterNode arbitraryNode = nodes.get(new Random().nextInt(nodes.size())); - return executeCommandOnSingleNode(clusterCommand, arbitraryNode); + return executeCommandOnSingleNode(commandCallback, arbitraryNode); } /** * Run {@link ClusterCommandCallback} on given {@link RedisClusterNode}. * - * @param clusterCommand must not be {@literal null}. + * @param commandCallback must not be {@literal null}. * @param node must not be {@literal null}. * @return the {@link NodeResult} from the single, targeted {@link RedisClusterNode}. * @throws IllegalArgumentException in case no resource can be acquired for given node. */ - public NodeResult executeCommandOnSingleNode(ClusterCommandCallback clusterCommand, + public NodeResult executeCommandOnSingleNode(ClusterCommandCallback commandCallback, RedisClusterNode node) { - return executeCommandOnSingleNode(clusterCommand, node, 0); + return executeCommandOnSingleNode(commandCallback, node, 0); } - private NodeResult executeCommandOnSingleNode(ClusterCommandCallback clusterCommand, + private NodeResult executeCommandOnSingleNode(ClusterCommandCallback commandCallback, RedisClusterNode node, int redirectCount) { - Assert.notNull(clusterCommand, "ClusterCommandCallback must not be null"); + Assert.notNull(commandCallback, "ClusterCommandCallback must not be null"); Assert.notNull(node, "RedisClusterNode must not be null"); if (redirectCount > this.maxRedirects) { - String message = String.format("Cannot follow Cluster Redirects over more than %s legs;" - + " Please consider increasing the number of redirects to follow; Current value is: %s.", - redirectCount, this.maxRedirects); - - throw new TooManyClusterRedirectionsException(message); + throw new TooManyClusterRedirectionsException(String.format( + "Cannot follow Cluster Redirects over more than %s legs; " + + "Consider increasing the number of redirects to follow; Current value is: %s.", + redirectCount, this.maxRedirects)); } RedisClusterNode nodeToUse = lookupNode(node); @@ -161,15 +145,14 @@ private NodeResult executeCommandOnSingleNode(ClusterCommandCallback(node, clusterCommand.doInCluster(client)); + return new NodeResult<>(node, commandCallback.doInCluster(client)); } catch (RuntimeException cause) { RuntimeException translatedException = convertToDataAccessException(cause); if (translatedException instanceof ClusterRedirectException clusterRedirectException) { - return executeCommandOnSingleNode(clusterCommand, topologyProvider.getTopology() - .lookup(clusterRedirectException.getTargetHost(), clusterRedirectException.getTargetPort()), - redirectCount + 1); + return executeCommandOnSingleNode(commandCallback, topologyProvider.getTopology().lookup( + clusterRedirectException.getTargetHost(), clusterRedirectException.getTargetPort()), redirectCount + 1); } else { throw translatedException != null ? 
translatedException : cause; } @@ -182,7 +165,8 @@ private NodeResult executeCommandOnSingleNode(ClusterCommandCallback MultiNodeResult executeCommandOnAllNodes(ClusterCommandCallback clusterCommand) { - return executeCommandAsyncOnNodes(clusterCommand, getClusterTopology().getActiveMasterNodes()); + public MultiNodeResult executeCommandOnAllNodes(ClusterCommandCallback commandCallback) { + return executeCommandAsyncOnNodes(commandCallback, getClusterTopology().getActiveMasterNodes()); } /** - * @param clusterCommand must not be {@literal null}. + * @param commandCallback must not be {@literal null}. * @param nodes must not be {@literal null}. * @return never {@literal null}. * @throws ClusterCommandExecutionFailureException if a failure occurs while executing the given * {@link ClusterCommandCallback command} on any given {@link RedisClusterNode node}. * @throws IllegalArgumentException in case the node could not be resolved to a topology-known node */ - public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallback clusterCommand, + public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallback commandCallback, Iterable nodes) { - Assert.notNull(clusterCommand, "Callback must not be null"); + Assert.notNull(commandCallback, "Callback must not be null"); Assert.notNull(nodes, "Nodes must not be null"); ClusterTopology topology = this.topologyProvider.getTopology(); @@ -234,7 +218,7 @@ public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallba Map>> futures = new LinkedHashMap<>(); for (RedisClusterNode node : resolvedRedisClusterNodes) { - Callable> nodeCommandExecution = () -> executeCommandOnSingleNode(clusterCommand, node); + Callable> nodeCommandExecution = () -> executeCommandOnSingleNode(commandCallback, node); futures.put(new NodeExecution(node), executor.submit(nodeCommandExecution)); } @@ -243,26 +227,22 @@ public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallba MultiNodeResult collectResults(Map>> futures) { - Map exceptions = new HashMap<>(); + NodeExceptionCollector exceptionCollector = new NodeExceptionCollector(); MultiNodeResult result = new MultiNodeResult<>(); - Set safeguard = new HashSet<>(); + Object placeholder = new Object(); + Map>, Object> safeguard = new IdentityHashMap<>(); - BiConsumer exceptionHandler = getExceptionHandlerFunction(exceptions); - - boolean done = false; - - while (!done) { - - done = true; + for (;;) { + boolean timeout = false; for (Map.Entry>> entry : futures.entrySet()) { NodeExecution nodeExecution = entry.getKey(); Future> futureNodeResult = entry.getValue(); - String futureId = ObjectUtils.getIdentityHexString(futureNodeResult); try { - if (!safeguard.contains(futureId)) { + + if (!safeguard.containsKey(futureNodeResult)) { NodeResult nodeResult = futureNodeResult.get(10L, TimeUnit.MICROSECONDS); @@ -272,39 +252,32 @@ MultiNodeResult collectResults(Map>> result.add(nodeResult); } - safeguard.add(futureId); + safeguard.put(futureNodeResult, placeholder); } } catch (ExecutionException exception) { - safeguard.add(futureId); - exceptionHandler.accept(nodeExecution, exception.getCause()); + safeguard.put(futureNodeResult, placeholder); + exceptionCollector.addException(nodeExecution, exception.getCause()); } catch (TimeoutException ignore) { - done = false; - } catch (InterruptedException cause) { + timeout = true; + } catch (InterruptedException exception) { Thread.currentThread().interrupt(); - exceptionHandler.accept(nodeExecution, cause); + exceptionCollector.addException(nodeExecution, 
exception); break; } } + + if (!timeout) { + break; + } } - if (!exceptions.isEmpty()) { - throw new ClusterCommandExecutionFailureException(new ArrayList<>(exceptions.values())); + if (exceptionCollector.hasExceptions()) { + throw new ClusterCommandExecutionFailureException(exceptionCollector.getExceptions()); } return result; } - private BiConsumer getExceptionHandlerFunction(Map exceptions) { - - return (nodeExecution, throwable) -> { - - DataAccessException dataAccessException = convertToDataAccessException((Exception) throwable); - Throwable resolvedException = dataAccessException != null ? dataAccessException : throwable; - - exceptions.putIfAbsent(nodeExecution.getNode(), resolvedException); - }; - } - /** * Run {@link MultiKeyClusterCommandCallback} with on a curated set of nodes serving one or more keys. * @@ -331,8 +304,8 @@ public MultiNodeResult executeMultiKeyCommand(MultiKeyClusterCommandCa if (entry.getKey().isMaster()) { for (PositionalKey key : entry.getValue()) { - futures.put(new NodeExecution(entry.getKey(), key), this.executor.submit(() -> - executeMultiKeyCommandOnSingleNode(commandCallback, entry.getKey(), key.getBytes()))); + futures.put(new NodeExecution(entry.getKey(), key), this.executor + .submit(() -> executeMultiKeyCommandOnSingleNode(commandCallback, entry.getKey(), key.getBytes()))); } } } @@ -458,8 +431,8 @@ boolean isPositional() { } /** - * {@link NodeResult} encapsulates the actual {@link T value} returned by a {@link ClusterCommandCallback} - * on a given {@link RedisClusterNode}. + * {@link NodeResult} encapsulates the actual {@link T value} returned by a {@link ClusterCommandCallback} on a given + * {@link RedisClusterNode}. * * @param {@link Class Type} of the {@link Object value} returned in the result. * @author Christoph Strobl @@ -468,9 +441,9 @@ boolean isPositional() { */ public static class NodeResult { - private RedisClusterNode node; - private ByteArrayWrapper key; - private @Nullable T value; + private final RedisClusterNode node; + private final ByteArrayWrapper key; + private final @Nullable T value; /** * Create a new {@link NodeResult}. @@ -551,9 +524,8 @@ public boolean equals(@Nullable Object obj) { return false; } - return ObjectUtils.nullSafeEquals(this.getNode(), that.getNode()) - && Objects.equals(this.key, that.key) - && Objects.equals(this.getValue(), that.getValue()); + return ObjectUtils.nullSafeEquals(this.getNode(), that.getNode()) && Objects.equals(this.key, that.key) + && Objects.equals(this.getValue(), that.getValue()); } @Override @@ -757,8 +729,7 @@ public boolean equals(@Nullable Object obj) { if (!(obj instanceof PositionalKey that)) return false; - return this.getPosition() == that.getPosition() - && ObjectUtils.nullSafeEquals(this.getKey(), that.getKey()); + return this.getPosition() == that.getPosition() && ObjectUtils.nullSafeEquals(this.getKey(), that.getKey()); } @Override @@ -836,4 +807,34 @@ public Iterator iterator() { return this.keys.iterator(); } } + + /** + * Collector for exceptions. Applies translation of exceptions if possible. + */ + private class NodeExceptionCollector { + + private final Map exceptions = new HashMap<>(); + + /** + * @return {@code true} if the collector contains at least one exception. + */ + public boolean hasExceptions() { + return !exceptions.isEmpty(); + } + + public void addException(NodeExecution execution, Throwable throwable) { + + Throwable translated = throwable instanceof Exception e ? 
convertToDataAccessException(e) : throwable; + Throwable resolvedException = translated != null ? translated : throwable; + + exceptions.putIfAbsent(execution.getNode(), resolvedException); + } + + /** + * @return the collected exceptions. + */ + public List getExceptions() { + return new ArrayList<>(exceptions.values()); + } + } } diff --git a/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java b/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java index 25e3d47562..1da31a7f73 100644 --- a/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java +++ b/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java @@ -15,21 +15,15 @@ */ package org.springframework.data.redis.connection; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.mockito.ArgumentMatchers.anyLong; +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.data.redis.test.util.MockitoUtils.verifyInvocationsAcross; - -import java.time.Instant; +import static org.springframework.data.redis.test.util.MockitoUtils.*; + +import edu.umd.cs.mtc.MultithreadedTestCase; +import edu.umd.cs.mtc.TestFramework; + import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; @@ -42,9 +36,12 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; import java.util.function.Predicate; import java.util.function.Supplier; @@ -56,7 +53,6 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.stubbing.Answer; - import org.springframework.core.convert.converter.Converter; import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.core.task.SyncTaskExecutor; @@ -74,11 +70,9 @@ import org.springframework.data.redis.connection.RedisClusterNode.SlotRange; import org.springframework.data.redis.connection.RedisNode.NodeType; import org.springframework.data.redis.test.util.MockitoUtils; +import org.springframework.lang.Nullable; import org.springframework.scheduling.concurrent.ConcurrentTaskExecutor; -import edu.umd.cs.mtc.MultithreadedTestCase; -import edu.umd.cs.mtc.TestFramework; - /** * Unit Tests for {@link ClusterCommandExecutor}. 
* @@ -99,34 +93,35 @@ class ClusterCommandExecutorUnitTests { private static final int CLUSTER_NODE_3_PORT = 7381; private static final RedisClusterNode CLUSTER_NODE_1 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT) - .serving(new SlotRange(0, 5460)) - .withId("ef570f86c7b1a953846668debc177a3a16733420") - .promotedAs(NodeType.MASTER) - .linkState(LinkState.CONNECTED) - .withName("ClusterNodeX") + .listeningAt(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT) // + .serving(new SlotRange(0, 5460)) // + .withId("ef570f86c7b1a953846668debc177a3a16733420") // + .promotedAs(NodeType.MASTER) // + .linkState(LinkState.CONNECTED) // + .withName("ClusterNodeX") // .build(); private static final RedisClusterNode CLUSTER_NODE_2 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT) - .serving(new SlotRange(5461, 10922)) - .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84") - .promotedAs(NodeType.MASTER) - .linkState(LinkState.CONNECTED) - .withName("ClusterNodeY") + .listeningAt(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT) // + .serving(new SlotRange(5461, 10922)) // + .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84") // + .promotedAs(NodeType.MASTER) // + .linkState(LinkState.CONNECTED) // + .withName("ClusterNodeY") // .build(); private static final RedisClusterNode CLUSTER_NODE_3 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT) - .serving(new SlotRange(10923, 16383)) - .withId("3b9b8192a874fa8f1f09dbc0ee20afab5738eee7") - .promotedAs(NodeType.MASTER) - .linkState(LinkState.CONNECTED) - .withName("ClusterNodeZ") + .listeningAt(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT) // + .serving(new SlotRange(10923, 16383)) // + .withId("3b9b8192a874fa8f1f09dbc0ee20afab5738eee7") // + .promotedAs(NodeType.MASTER) // + .linkState(LinkState.CONNECTED) // + .withName("ClusterNodeZ") // .build(); private static final RedisClusterNode CLUSTER_NODE_2_LOOKUP = RedisClusterNode.newRedisClusterNode() - .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84").build(); + .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84") // + .build(); private static final RedisClusterNode UNKNOWN_CLUSTER_NODE = new RedisClusterNode("8.8.8.8", 7379, SlotRange.empty()); @@ -166,7 +161,7 @@ void executeCommandOnSingleNodeShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_2); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -175,7 +170,7 @@ void executeCommandOnSingleNodeByHostAndPortShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, new RedisClusterNode(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -183,7 +178,7 @@ void executeCommandOnSingleNodeByNodeIdShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, new RedisClusterNode(CLUSTER_NODE_2.id)); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -213,8 +208,8 @@ void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodes() { executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(CLUSTER_NODE_1, CLUSTER_NODE_2)); - verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection2, 
times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @@ -229,8 +224,8 @@ void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodesByHostAndPort() { Arrays.asList(new RedisClusterNode(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT), new RedisClusterNode(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT))); - verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @@ -244,8 +239,8 @@ void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodesByNodeId() { executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(new RedisClusterNode(CLUSTER_NODE_1.id), CLUSTER_NODE_2_LOOKUP)); - verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @@ -269,16 +264,17 @@ void executeCommandOnAllNodesShouldExecuteCommandOnEveryKnownClusterNode() { executor.executeCommandOnAllNodes(COMMAND_CALLBACK); - verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); + verify(connection3).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldCompleteAndCollectErrorsOfAllNodes() { when(connection1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); - when(connection2.theWheelWeavesAsTheWheelWills()).thenThrow(new IllegalStateException("(error) mat lost the dagger...")); + when(connection2.theWheelWeavesAsTheWheelWills()) + .thenThrow(new IllegalStateException("(error) mat lost the dagger...")); when(connection3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); try { @@ -289,9 +285,9 @@ void executeCommandAsyncOnNodesShouldCompleteAndCollectErrorsOfAllNodes() { assertThat(cause.getSuppressed()[0]).isInstanceOf(DataAccessException.class); } - verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); + verify(connection3).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -317,8 +313,7 @@ void executeMultikeyCommandShouldRunCommandAcrossCluster() { when(connection3.bloodAndAshes(any(byte[].class))).thenReturn("perrin"); MultiNodeResult result = executor.executeMultiKeyCommand(MULTIKEY_CALLBACK, - new HashSet<>( - Arrays.asList("key-1".getBytes(), "key-2".getBytes(), "key-3".getBytes(), "key-9".getBytes()))); + new HashSet<>(Arrays.asList("key-1".getBytes(), "key-2".getBytes(), "key-3".getBytes(), "key-9".getBytes()))); assertThat(result.resultsAsList()).contains("rand", "mat", "perrin", "egwene"); @@ -329,32 +324,35 @@ void executeMultikeyCommandShouldRunCommandAcrossCluster() { @Test // DATAREDIS-315 void 
executeCommandOnSingleNodeAndFollowRedirect() { - when(connection1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); + when(connection1.theWheelWeavesAsTheWheelWills()) + .thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_1); - verify(connection1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(connection3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection3).theWheelWeavesAsTheWheelWills(); verify(connection2, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandOnSingleNodeAndFollowRedirectButStopsAfterMaxRedirects() { - when(connection1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); - when(connection3.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); - when(connection2.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT)); + when(connection1.theWheelWeavesAsTheWheelWills()) + .thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); + when(connection3.theWheelWeavesAsTheWheelWills()) + .thenThrow(new MovedException(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); + when(connection2.theWheelWeavesAsTheWheelWills()) + .thenThrow(new MovedException(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT)); - try { - executor.setMaxRedirects(4); + executor.setMaxRedirects(4); + + assertThatExceptionOfType(TooManyClusterRedirectionsException.class).isThrownBy(() -> { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_1); - } catch (Exception e) { - assertThat(e).isInstanceOf(TooManyClusterRedirectionsException.class); - } + }); verify(connection1, times(2)).theWheelWeavesAsTheWheelWills(); verify(connection3, times(2)).theWheelWeavesAsTheWheelWills(); - verify(connection2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -362,83 +360,45 @@ void executeCommandOnArbitraryNodeShouldPickARandomNode() { executor.executeCommandOnArbitraryNode(COMMAND_CALLBACK); - verifyInvocationsAcross("theWheelWeavesAsTheWheelWills", times(1), connection1, connection2, connection3); - } - - @Test // GH-2518 - void collectResultsCompletesSuccessfully() { - - Instant done = Instant.now().plusMillis(5); - - Predicate>> isDone = future -> Instant.now().isAfter(done); - - Map>> futures = new HashMap<>(); - - NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); - NodeResult nodeTwoB = newNodeResult(CLUSTER_NODE_2, "B"); - NodeResult nodeThreeC = newNodeResult(CLUSTER_NODE_3, "C"); - - futures.put(newNodeExecution(CLUSTER_NODE_1), mockFutureAndIsDone(nodeOneA, isDone)); - futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureAndIsDone(nodeTwoB, isDone)); - futures.put(newNodeExecution(CLUSTER_NODE_3), mockFutureAndIsDone(nodeThreeC, isDone)); - - MultiNodeResult results = this.executor.collectResults(futures); - - assertThat(results).isNotNull(); - assertThat(results.getResults()).containsExactlyInAnyOrder(nodeOneA, nodeTwoB, nodeThreeC); - - futures.values().forEach(future -> - runsSafely(() -> verify(future, times(1)).get(anyLong(), any(TimeUnit.class)))); + verifyInvocationsAcross("theWheelWeavesAsTheWheelWills", connection1, connection2, connection3); } @Test // GH-2518 - void 
collectResultsCompletesSuccessfullyEvenWithTimeouts() throws Exception { + void collectResultsCompletesSuccessfullyAfterTimeouts() { Map>> futures = new HashMap<>(); - NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); - NodeResult nodeTwoB = newNodeResult(CLUSTER_NODE_2, "B"); - NodeResult nodeThreeC = newNodeResult(CLUSTER_NODE_3, "C"); + NodeResult nodeOneA = new NodeResult<>(CLUSTER_NODE_1, "A"); + NodeResult nodeTwoB = new NodeResult<>(CLUSTER_NODE_2, "B"); + NodeResult nodeThreeC = new NodeResult<>(CLUSTER_NODE_3, "C"); - Future> nodeOneFutureResult = mockFutureThrowingTimeoutException(nodeOneA, 4); - Future> nodeTwoFutureResult = mockFutureThrowingTimeoutException(nodeTwoB, 1); - Future> nodeThreeFutureResult = mockFutureThrowingTimeoutException(nodeThreeC, 2); + doWithScheduler(scheduler -> { - futures.put(newNodeExecution(CLUSTER_NODE_1), nodeOneFutureResult); - futures.put(newNodeExecution(CLUSTER_NODE_2), nodeTwoFutureResult); - futures.put(newNodeExecution(CLUSTER_NODE_3), nodeThreeFutureResult); + futures.put(new NodeExecution(CLUSTER_NODE_1), scheduler.schedule(() -> nodeOneA, 15, TimeUnit.MILLISECONDS)); + futures.put(new NodeExecution(CLUSTER_NODE_2), scheduler.schedule(() -> nodeTwoB, 15, TimeUnit.MILLISECONDS)); + futures.put(new NodeExecution(CLUSTER_NODE_3), scheduler.schedule(() -> nodeThreeC, 15, TimeUnit.MILLISECONDS)); - MultiNodeResult results = this.executor.collectResults(futures); + MultiNodeResult results = this.executor.collectResults(futures); - assertThat(results).isNotNull(); - assertThat(results.getResults()).containsExactlyInAnyOrder(nodeOneA, nodeTwoB, nodeThreeC); - - verify(nodeOneFutureResult, times(4)).get(anyLong(), any(TimeUnit.class)); - verify(nodeTwoFutureResult, times(1)).get(anyLong(), any(TimeUnit.class)); - verify(nodeThreeFutureResult, times(2)).get(anyLong(), any(TimeUnit.class)); - verifyNoMoreInteractions(nodeOneFutureResult, nodeTwoFutureResult, nodeThreeFutureResult); + assertThat(results).isNotNull(); + assertThat(results.getResults()).containsExactlyInAnyOrder(nodeOneA, nodeTwoB, nodeThreeC); + }); } @Test // GH-2518 void collectResultsFailsWithExecutionException() { Map>> futures = new HashMap<>(); + NodeResult nodeOneA = new NodeResult<>(CLUSTER_NODE_1, "A"); - NodeResult nodeOneA = newNodeResult(CLUSTER_NODE_1, "A"); - - futures.put(newNodeExecution(CLUSTER_NODE_1), mockFutureAndIsDone(nodeOneA, future -> true)); - futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureThrowingExecutionException( - new ExecutionException("TestError", new IllegalArgumentException("MockError")))); + futures.put(new NodeExecution(CLUSTER_NODE_1), CompletableFuture.completedFuture(nodeOneA)); + futures.put(new NodeExecution(CLUSTER_NODE_2), + CompletableFuture.failedFuture(new IllegalArgumentException("MockError"))); assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) - .isThrownBy(() -> this.executor.collectResults(futures)) - .withMessage("MockError") - .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) - .extracting(Throwable::getCause) - .extracting(Throwable::getCause) - .isInstanceOf(IllegalArgumentException.class) - .extracting(Throwable::getMessage) - .isEqualTo("MockError"); + .isThrownBy(() -> this.executor.collectResults(futures)) // + .withMessage("MockError") // + .withRootCauseInstanceOf(IllegalArgumentException.class); } @Test // GH-2518 @@ -446,8 +406,6 @@ void collectResultsFailsWithInterruptedException() throws Throwable { TestFramework.runOnce(new 
CollectResultsInterruptedMultithreadedTestCase(this.executor)); } - // Future.get() for X will get called twice if at least one other Future is not done and Future.get() for X - // threw an ExecutionException in the previous iteration, thereby marking it as done! @Test // GH-2518 @SuppressWarnings("all") void collectResultsCallsFutureGetOnlyOnce() throws Exception { @@ -455,27 +413,22 @@ void collectResultsCallsFutureGetOnlyOnce() throws Exception { AtomicInteger count = new AtomicInteger(0); Map>> futures = new HashMap<>(); - Future> clusterNodeOneFutureResult = mockFutureAndIsDone(null, future -> - count.incrementAndGet() % 2 == 0); + Future> clusterNodeOneFutureResult = mockFutureAndIsDone(null, + future -> count.incrementAndGet() % 2 == 0); Future> clusterNodeTwoFutureResult = mockFutureThrowingExecutionException( new ExecutionException("TestError", new IllegalArgumentException("MockError"))); - futures.put(newNodeExecution(CLUSTER_NODE_1), clusterNodeOneFutureResult); - futures.put(newNodeExecution(CLUSTER_NODE_2), clusterNodeTwoFutureResult); + futures.put(new NodeExecution(CLUSTER_NODE_1), clusterNodeOneFutureResult); + futures.put(new NodeExecution(CLUSTER_NODE_2), clusterNodeTwoFutureResult); assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) - .isThrownBy(() -> this.executor.collectResults(futures)); + .isThrownBy(() -> this.executor.collectResults(futures)); - verify(clusterNodeOneFutureResult, times(1)).get(anyLong(), any()); - verify(clusterNodeTwoFutureResult, times(1)).get(anyLong(), any()); + verify(clusterNodeOneFutureResult).get(anyLong(), any()); + verify(clusterNodeTwoFutureResult).get(anyLong(), any()); } - // Covers the case where Future.get() is mistakenly called multiple times, or if the Future.isDone() implementation - // does not properly take into account Future.get() throwing an ExecutionException during computation subsequently - // returning false instead of true. - // This should be properly handled by the "safeguard" (see collectResultsCallsFutureGetOnlyOnce()), but... - // just in case! The ExecutionException handler now stores the [DataAccess]Exception with Map.putIfAbsent(..). 
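The comments removed above explain that the exception handler records failures with Map.putIfAbsent(..), so only the first failure per node is retained even if a Future is polled again. A minimal sketch of that first-exception-wins behavior, illustrative only and simplified from the NodeExceptionCollector introduced by this patch (assertThat is the AssertJ static import already used in this test class; the node and message names are borrowed from the surrounding tests):

    Map<RedisClusterNode, Throwable> exceptions = new HashMap<>();

    // Only the first failure recorded for a node is kept; a later ExecutionException for the
    // same node (for example, if Future.get(..) were mistakenly invoked twice) has no effect.
    exceptions.putIfAbsent(CLUSTER_NODE_2, new IllegalStateException("MockError0"));
    exceptions.putIfAbsent(CLUSTER_NODE_2, new IllegalStateException("MockError1")); // ignored

    assertThat(exceptions.get(CLUSTER_NODE_2)).hasMessage("MockError0");
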
@Test // GH-2518 @SuppressWarnings("all") void collectResultsCapturesFirstExecutionExceptionOnly() { @@ -485,24 +438,21 @@ void collectResultsCapturesFirstExecutionExceptionOnly() { Map>> futures = new HashMap<>(); - futures.put(newNodeExecution(CLUSTER_NODE_1), + futures.put(new NodeExecution(CLUSTER_NODE_1), mockFutureAndIsDone(null, future -> count.incrementAndGet() % 2 == 0)); - futures.put(newNodeExecution(CLUSTER_NODE_2), mockFutureThrowingExecutionException(() -> - new ExecutionException("TestError", new IllegalStateException("MockError" + exceptionCount.getAndIncrement())))); + futures.put(new NodeExecution(CLUSTER_NODE_2), + mockFutureThrowingExecutionException(() -> new ExecutionException("TestError", + new IllegalStateException("MockError" + exceptionCount.getAndIncrement())))); assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) - .isThrownBy(() -> this.executor.collectResults(futures)) - .withMessage("MockError0") - .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) - .extracting(Throwable::getCause) - .extracting(Throwable::getCause) - .isInstanceOf(IllegalStateException.class) - .extracting(Throwable::getMessage) - .isEqualTo("MockError0"); + .isThrownBy(() -> this.executor.collectResults(futures)) // + .withMessage("MockError0") // + .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) + .withRootCauseInstanceOf(IllegalStateException.class); } - private Future mockFutureAndIsDone(T result, Predicate> isDone) { + private Future mockFutureAndIsDone(@Nullable T result, Predicate> isDone) { return MockitoUtils.mockFuture(result, future -> { doAnswer(invocation -> isDone.test(future)).when(future).isDone(); @@ -516,54 +466,32 @@ private Future mockFutureThrowingExecutionException(ExecutionException ex private Future mockFutureThrowingExecutionException(Supplier exceptionSupplier) { - Answer getAnswer = invocationOnMock -> { throw exceptionSupplier.get(); }; - - return MockitoUtils.mockFuture(null, future -> { - doReturn(true).when(future).isDone(); - doAnswer(getAnswer).when(future).get(); - doAnswer(getAnswer).when(future).get(anyLong(), any()); - return future; - }); - } - - @SuppressWarnings("unchecked") - private Future mockFutureThrowingTimeoutException(T result, int timeoutCount) { - - AtomicInteger counter = new AtomicInteger(timeoutCount); - Answer getAnswer = invocationOnMock -> { - - if (counter.decrementAndGet() > 0) { - throw new TimeoutException("TIMES UP"); - } - - doReturn(true).when((Future>) invocationOnMock.getMock()).isDone(); - - return result; + throw exceptionSupplier.get(); }; - return MockitoUtils.mockFuture(result, future -> { - + return MockitoUtils.mockFuture(null, future -> { + doReturn(true).when(future).isDone(); doAnswer(getAnswer).when(future).get(); doAnswer(getAnswer).when(future).get(anyLong(), any()); - return future; }); } - private NodeExecution newNodeExecution(RedisClusterNode clusterNode) { - return new NodeExecution(clusterNode); - } - - private NodeResult newNodeResult(RedisClusterNode clusterNode, T value) { - return new NodeResult<>(clusterNode, value); - } - - private void runsSafely(ThrowableOperation operation) { + /** + * Performs the given action within the scope of a running {@link ScheduledExecutorService}. The scheduler is only + * valid during the callback and shut down after this method returns. + * + * @param callback the action to invoke. 
+ */ + private void doWithScheduler(Consumer callback) { + ScheduledExecutorService scheduler = new ScheduledThreadPoolExecutor(3); try { - operation.run(); - } catch (Throwable ignore) { } + callback.accept(scheduler); + } finally { + scheduler.shutdown(); + } } static class MockClusterNodeProvider implements ClusterTopologyProvider { @@ -581,28 +509,16 @@ class MockClusterNodeResourceProvider implements ClusterNodeResourceProvider { public Connection getResourceForSpecificNode(RedisClusterNode clusterNode) { return CLUSTER_NODE_1.equals(clusterNode) ? connection1 - : CLUSTER_NODE_2.equals(clusterNode) ? connection2 - : CLUSTER_NODE_3.equals(clusterNode) ? connection3 - : null; + : CLUSTER_NODE_2.equals(clusterNode) ? connection2 : CLUSTER_NODE_3.equals(clusterNode) ? connection3 : null; } @Override - public void returnResourceForSpecificNode(RedisClusterNode node, Object resource) { - } - } - - interface ConnectionCommandCallback extends ClusterCommandCallback { - + public void returnResourceForSpecificNode(RedisClusterNode node, Object resource) {} } - interface MultiKeyConnectionCommandCallback extends MultiKeyClusterCommandCallback { + interface ConnectionCommandCallback extends ClusterCommandCallback {} - } - - @FunctionalInterface - interface ThrowableOperation { - void run() throws Throwable; - } + interface MultiKeyConnectionCommandCallback extends MultiKeyClusterCommandCallback {} interface Connection { @@ -661,8 +577,8 @@ static class CollectResultsInterruptedMultithreadedTestCase extends Multithreade private static final CountDownLatch latch = new CountDownLatch(1); - private static final Comparator NODE_COMPARATOR = - Comparator.comparing(nodeExecution -> nodeExecution.getNode().getName()); + private static final Comparator NODE_COMPARATOR = Comparator + .comparing(nodeExecution -> nodeExecution.getNode().getName()); private final ClusterCommandExecutor clusterCommandExecutor; @@ -711,7 +627,7 @@ public void thread1() { this.collectResultsThread.setName("CollectResults Thread"); assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) - .isThrownBy(() -> this.clusterCommandExecutor.collectResults(this.futureNodeResults)); + .isThrownBy(() -> this.clusterCommandExecutor.collectResults(this.futureNodeResults)); assertThat(this.collectResultsThread.isInterrupted()).isTrue(); } @@ -735,7 +651,7 @@ public void finish() { try { verify(this.mockNodeOneFutureResult, never()).get(); - verify(this.mockNodeTwoFutureResult, times(1)).get(anyLong(), any()); + verify(this.mockNodeTwoFutureResult).get(anyLong(), any()); } catch (ExecutionException | InterruptedException | TimeoutException cause) { throw new RuntimeException(cause); } diff --git a/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java b/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java index 318d47ec85..8b3c6770ac 100644 --- a/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java +++ b/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java @@ -15,13 +15,7 @@ */ package org.springframework.data.redis.test.util; -import static org.mockito.Mockito.anyBoolean; -import static org.mockito.Mockito.anyLong; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.isA; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockingDetails; -import static org.mockito.Mockito.withSettings; +import static org.mockito.Mockito.*; import java.util.ArrayList; import java.util.Collections; @@ -39,10 +33,9 @@ 
import org.mockito.quality.Strictness; import org.mockito.stubbing.Answer; import org.mockito.verification.VerificationMode; - -import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; +import org.springframework.util.function.ThrowingFunction; /** * Utilities for using {@literal Mockito} and creating {@link Object mock objects} in {@literal unit tests}. @@ -64,7 +57,7 @@ public abstract class MockitoUtils { * @see java.util.concurrent.Future */ @SuppressWarnings("unchecked") - public static @NonNull Future mockFuture(@Nullable T result) { + public static Future mockFuture(@Nullable T result) { try { @@ -77,9 +70,8 @@ public abstract class MockitoUtils { // The cancel(..) logic is not Thread-safe due to compound actions involving multiple variables. // However, the cancel(..) logic does not necessarily need to be Thread-safe given the task execution // of a Future is asynchronous and cancellation is driven by Thread interrupt from another Thread. - Answer cancelAnswer = invocation -> !done.get() - && cancelled.compareAndSet(done.get(), true) - && done.compareAndSet(done.get(), true); + Answer cancelAnswer = invocation -> !done.get() && cancelled.compareAndSet(done.get(), true) + && done.compareAndSet(done.get(), true); Answer getAnswer = invocation -> { @@ -104,34 +96,40 @@ public abstract class MockitoUtils { doAnswer(getAnswer).when(mockFuture).get(anyLong(), isA(TimeUnit.class)); return mockFuture; - } - catch (Exception cause) { + } catch (Exception cause) { String message = String.format("Failed to create a mock of Future having result [%s]", result); throw new IllegalStateException(message, cause); } } /** - * Creates a mock {@link Future} returning the given {@link Object result}, customized with the given, - * required {@link Function}. + * Creates a mock {@link Future} returning the given {@link Object result}, customized with the given, required + * {@link Function}. * * @param {@link Class type} of {@link Object result} returned by the mock {@link Future}. * @param result {@link Object value} returned as the {@literal result} of the mock {@link Future}. - * @param futureFunction {@link Function} used to customize the mock {@link Future} on creation; - * must not be {@literal null}. + * @param futureFunction {@link Function} used to customize the mock {@link Future} on creation; must not be + * {@literal null}. * @return a new mock {@link Future}. - * @see java.util.concurrent.Future - * @see java.util.function.Function * @see #mockFuture(Object) */ - public static @NonNull Future mockFuture(@Nullable T result, - @NonNull ThrowableFunction, Future> futureFunction) { + public static Future mockFuture(@Nullable T result, ThrowingFunction, Future> futureFunction) { Future mockFuture = mockFuture(result); return futureFunction.apply(mockFuture); } + /** + * Verifies a given method is called once across all given mocks. + * + * @param method {@link String name} of a {@link java.lang.reflect.Method} on the {@link Object mock object}. + * @param mocks array of {@link Object mock objects} to verify. + */ + public static void verifyInvocationsAcross(String method, Object... mocks) { + verifyInvocationsAcross(method, times(1), mocks); + } + /** * Verifies a given method is called a total number of times across all given mocks. * @@ -141,19 +139,19 @@ public abstract class MockitoUtils { */ public static void verifyInvocationsAcross(String method, VerificationMode mode, Object... 
mocks) { - mode.verify(new VerificationDataImpl(getInvocations(method, mocks), new InvocationMatcher(null, Collections - .singletonList(org.mockito.internal.matchers.Any.ANY)) { + mode.verify(new VerificationDataImpl(getInvocations(method, mocks), + new InvocationMatcher(null, Collections.singletonList(org.mockito.internal.matchers.Any.ANY)) { - @Override - public boolean matches(Invocation actual) { - return true; - } + @Override + public boolean matches(Invocation actual) { + return true; + } - @Override - public String toString() { - return String.format("%s for method: %s", mode, method); - } - })); + @Override + public String toString() { + return String.format("%s for method: %s", mode, method); + } + })); } private static List getInvocations(String method, Object... mocks) { @@ -196,29 +194,4 @@ public MatchableInvocation getTarget() { } } - @FunctionalInterface - public interface ThrowableFunction extends Function { - - @Override - default R apply(T target) { - - try { - return applyThrowingException(target); - } - catch (Throwable cause) { - String message = String.format("Failed to apply Function [%s] to target [%s]", this, target); - throw new IllegalStateException(message, cause); - } - } - - R applyThrowingException(T target) throws Throwable; - - @SuppressWarnings("unchecked") - default @NonNull ThrowableFunction andThen( - @Nullable ThrowableFunction after) { - - return after == null ? (ThrowableFunction) this - : target -> after.apply(this.apply(target)); - } - } }
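
Not part of the patch: a short usage sketch showing how the mockFuture(..) helper and the new verifyInvocationsAcross(String, Object...) convenience overload might be combined in a test. It assumes the imports and static imports already present in ClusterCommandExecutorUnitTests (Future, TimeUnit, assertThat, doReturn) and the CLUSTER_NODE_1 fixture; the isDone() customization mirrors what mockFutureAndIsDone(..) does above.

    @Test // illustrative only
    void mockFutureReturnsStubbedNodeResult() throws Exception {

        NodeResult<String> nodeResult = new NodeResult<>(CLUSTER_NODE_1, "A");

        // The second argument customizes the lenient mock before it is handed back to the caller.
        Future<NodeResult<String>> future = MockitoUtils.mockFuture(nodeResult, mock -> {
            doReturn(true).when(mock).isDone();
            return mock;
        });

        assertThat(future.isDone()).isTrue();
        assertThat(future.get(10L, TimeUnit.MICROSECONDS)).isEqualTo(nodeResult);

        // The single-argument overload verifies the named method was invoked exactly once across the given mocks.
        MockitoUtils.verifyInvocationsAcross("get", future);
    }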