diff --git a/pom.xml b/pom.xml index 778dd177dc..b804e93073 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-redis - 3.2.0-SNAPSHOT + 3.2.0-GH-2518-SNAPSHOT Spring Data Redis Spring Data module for Redis diff --git a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java index f016dfc33f..bfcf3a7853 100644 --- a/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java +++ b/src/main/java/org/springframework/data/redis/connection/ClusterCommandExecutor.java @@ -17,8 +17,11 @@ import java.util.*; import java.util.Map.Entry; +import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.function.Function; import java.util.stream.Collectors; @@ -43,6 +46,7 @@ * * @author Christoph Strobl * @author Mark Paluch + * @author John Blum * @since 1.7 */ public class ClusterCommandExecutor implements DisposableBean { @@ -58,7 +62,7 @@ public class ClusterCommandExecutor implements DisposableBean { private final ExceptionTranslationStrategy exceptionTranslationStrategy; /** - * Create a new instance of {@link ClusterCommandExecutor}. + * Create a new {@link ClusterCommandExecutor}. * * @param topologyProvider must not be {@literal null}. * @param resourceProvider must not be {@literal null}. @@ -101,31 +105,37 @@ public NodeResult executeCommandOnArbitraryNode(ClusterCommandCallback nodes = new ArrayList<>(getClusterTopology().getActiveNodes()); - return executeCommandOnSingleNode(commandCallback, nodes.get(new Random().nextInt(nodes.size()))); + RedisClusterNode arbitraryNode = nodes.get(new Random().nextInt(nodes.size())); + + return executeCommandOnSingleNode(commandCallback, arbitraryNode); } /** * Run {@link ClusterCommandCallback} on given {@link RedisClusterNode}. * - * @param cmd must not be {@literal null}. + * @param commandCallback must not be {@literal null}. * @param node must not be {@literal null}. * @return the {@link NodeResult} from the single, targeted {@link RedisClusterNode}. * @throws IllegalArgumentException in case no resource can be acquired for given node. 
*/ - public NodeResult executeCommandOnSingleNode(ClusterCommandCallback cmd, RedisClusterNode node) { - return executeCommandOnSingleNode(cmd, node, 0); + public NodeResult executeCommandOnSingleNode(ClusterCommandCallback commandCallback, + RedisClusterNode node) { + + return executeCommandOnSingleNode(commandCallback, node, 0); } - private NodeResult executeCommandOnSingleNode(ClusterCommandCallback cmd, RedisClusterNode node, - int redirectCount) { + private NodeResult executeCommandOnSingleNode(ClusterCommandCallback commandCallback, + RedisClusterNode node, int redirectCount) { - Assert.notNull(cmd, "ClusterCommandCallback must not be null"); + Assert.notNull(commandCallback, "ClusterCommandCallback must not be null"); Assert.notNull(node, "RedisClusterNode must not be null"); - if (redirectCount > maxRedirects) { + if (redirectCount > this.maxRedirects) { + throw new TooManyClusterRedirectionsException(String.format( - "Cannot follow Cluster Redirects over more than %s legs; Please consider increasing the number of redirects to follow; Current value is: %s.", - redirectCount, maxRedirects)); + "Cannot follow Cluster Redirects over more than %s legs; " + + "Consider increasing the number of redirects to follow; Current value is: %s.", + redirectCount, this.maxRedirects)); } RedisClusterNode nodeToUse = lookupNode(node); @@ -135,13 +145,13 @@ private NodeResult executeCommandOnSingleNode(ClusterCommandCallback(node, cmd.doInCluster(client)); + return new NodeResult<>(node, commandCallback.doInCluster(client)); } catch (RuntimeException cause) { RuntimeException translatedException = convertToDataAccessException(cause); if (translatedException instanceof ClusterRedirectException clusterRedirectException) { - return executeCommandOnSingleNode(cmd, topologyProvider.getTopology().lookup( + return executeCommandOnSingleNode(commandCallback, topologyProvider.getTopology().lookup( clusterRedirectException.getTargetHost(), clusterRedirectException.getTargetPort()), redirectCount + 1); } else { throw translatedException != null ? translatedException : cause; @@ -152,10 +162,11 @@ private NodeResult executeCommandOnSingleNode(ClusterCommandCallback MultiNodeResult executeCommandOnAllNodes(final ClusterCommandCallback cmd) { - return executeCommandAsyncOnNodes(cmd, getClusterTopology().getActiveMasterNodes()); + public MultiNodeResult executeCommandOnAllNodes(ClusterCommandCallback commandCallback) { + return executeCommandAsyncOnNodes(commandCallback, getClusterTopology().getActiveMasterNodes()); } /** - * @param callback must not be {@literal null}. + * @param commandCallback must not be {@literal null}. * @param nodes must not be {@literal null}. * @return never {@literal null}. * @throws ClusterCommandExecutionFailureException if a failure occurs while executing the given * {@link ClusterCommandCallback command} on any given {@link RedisClusterNode node}. 
* @throws IllegalArgumentException in case the node could not be resolved to a topology-known node */ - public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallback callback, + public MultiNodeResult executeCommandAsyncOnNodes(ClusterCommandCallback commandCallback, Iterable nodes) { - Assert.notNull(callback, "Callback must not be null"); + Assert.notNull(commandCallback, "Callback must not be null"); Assert.notNull(nodes, "Nodes must not be null"); + ClusterTopology topology = this.topologyProvider.getTopology(); List resolvedRedisClusterNodes = new ArrayList<>(); - ClusterTopology topology = topologyProvider.getTopology(); for (RedisClusterNode node : nodes) { try { resolvedRedisClusterNodes.add(topology.lookup(node)); - } catch (ClusterStateFailureException e) { - throw new IllegalArgumentException(String.format("Node %s is unknown to cluster", node), e); + } catch (ClusterStateFailureException cause) { + throw new IllegalArgumentException(String.format("Node %s is unknown to cluster", node), cause); } } Map>> futures = new LinkedHashMap<>(); for (RedisClusterNode node : resolvedRedisClusterNodes) { - futures.put(new NodeExecution(node), executor.submit(() -> executeCommandOnSingleNode(callback, node))); + Callable> nodeCommandExecution = () -> executeCommandOnSingleNode(commandCallback, node); + futures.put(new NodeExecution(node), executor.submit(nodeCommandExecution)); } return collectResults(futures); } - private MultiNodeResult collectResults(Map>> futures) { + MultiNodeResult collectResults(Map>> futures) { - boolean done = false; - - Map exceptions = new HashMap<>(); + NodeExceptionCollector exceptionCollector = new NodeExceptionCollector(); MultiNodeResult result = new MultiNodeResult<>(); - Set saveGuard = new HashSet<>(); - - while (!done) { + Object placeholder = new Object(); + Map>, Object> safeguard = new IdentityHashMap<>(); - done = true; + for (;;) { + boolean timeout = false; for (Map.Entry>> entry : futures.entrySet()) { - if (!entry.getValue().isDone() && !entry.getValue().isCancelled()) { - done = false; - } else { + NodeExecution nodeExecution = entry.getKey(); + Future> futureNodeResult = entry.getValue(); - NodeExecution execution = entry.getKey(); + try { - try { + if (!safeguard.containsKey(futureNodeResult)) { - String futureId = ObjectUtils.getIdentityHexString(entry.getValue()); + NodeResult nodeResult = futureNodeResult.get(10L, TimeUnit.MICROSECONDS); - if (!saveGuard.contains(futureId)) { - - if (execution.isPositional()) { - result.add(execution.getPositionalKey(), entry.getValue().get()); - } else { - result.add(entry.getValue().get()); - } - - saveGuard.add(futureId); + if (nodeExecution.isPositional()) { + result.add(nodeExecution.getPositionalKey(), nodeResult); + } else { + result.add(nodeResult); } - } catch (ExecutionException cause) { - - RuntimeException exception = convertToDataAccessException((Exception) cause.getCause()); - - exceptions.put(execution.getNode(), exception != null ? exception : cause.getCause()); - } catch (InterruptedException cause) { - - Thread.currentThread().interrupt(); - - RuntimeException exception = convertToDataAccessException((Exception) cause.getCause()); - exceptions.put(execution.getNode(), exception != null ? 
exception : cause.getCause()); - - break; + safeguard.put(futureNodeResult, placeholder); } + } catch (ExecutionException exception) { + safeguard.put(futureNodeResult, placeholder); + exceptionCollector.addException(nodeExecution, exception.getCause()); + } catch (TimeoutException ignore) { + timeout = true; + } catch (InterruptedException exception) { + Thread.currentThread().interrupt(); + exceptionCollector.addException(nodeExecution, exception); + break; } } - try { - Thread.sleep(10); - } catch (InterruptedException e) { - done = true; - Thread.currentThread().interrupt(); + if (!timeout) { + break; } } - if (!exceptions.isEmpty()) { - throw new ClusterCommandExecutionFailureException(new ArrayList<>(exceptions.values())); + if (exceptionCollector.hasExceptions()) { + throw new ClusterCommandExecutionFailureException(exceptionCollector.getExceptions()); } return result; @@ -328,10 +326,11 @@ private NodeResult executeMultiKeyCommandOnSingleNode(MultiKeyClusterC try { return new NodeResult<>(node, commandCallback.doInCluster(client, key), key); - } catch (RuntimeException ex) { + } catch (RuntimeException cause) { + + RuntimeException translatedException = convertToDataAccessException(cause); - RuntimeException translatedException = convertToDataAccessException(ex); - throw translatedException != null ? translatedException : ex; + throw translatedException != null ? translatedException : cause; } finally { this.resourceProvider.returnResourceForSpecificNode(node, client); } @@ -343,7 +342,7 @@ private ClusterTopology getClusterTopology() { @Nullable private DataAccessException convertToDataAccessException(Exception cause) { - return exceptionTranslationStrategy.translate(cause); + return this.exceptionTranslationStrategy.translate(cause); } /** @@ -395,7 +394,7 @@ public interface MultiKeyClusterCommandCallback { * @author Mark Paluch * @since 1.7 */ - private static class NodeExecution { + static class NodeExecution { private final RedisClusterNode node; private final @Nullable PositionalKey positionalKey; @@ -414,7 +413,7 @@ private static class NodeExecution { * Get the {@link RedisClusterNode} the execution happens on. */ RedisClusterNode getNode() { - return node; + return this.node; } /** @@ -423,30 +422,31 @@ RedisClusterNode getNode() { * @since 2.0.3 */ PositionalKey getPositionalKey() { - return positionalKey; + return this.positionalKey; } boolean isPositional() { - return positionalKey != null; + return this.positionalKey != null; } } /** - * {@link NodeResult} encapsulates the actual value returned by a {@link ClusterCommandCallback} on a given + * {@link NodeResult} encapsulates the actual {@link T value} returned by a {@link ClusterCommandCallback} on a given * {@link RedisClusterNode}. * + * @param {@link Class Type} of the {@link Object value} returned in the result. * @author Christoph Strobl - * @param + * @author John Blum * @since 1.7 */ public static class NodeResult { - private RedisClusterNode node; - private @Nullable T value; - private ByteArrayWrapper key; + private final RedisClusterNode node; + private final ByteArrayWrapper key; + private final @Nullable T value; /** - * Create new {@link NodeResult}. + * Create a new {@link NodeResult}. * * @param node must not be {@literal null}. * @param value can be {@literal null}. @@ -456,7 +456,7 @@ public NodeResult(RedisClusterNode node, @Nullable T value) { } /** - * Create new {@link NodeResult}. + * Create a new {@link NodeResult}. * * @param node must not be {@literal null}. 
* @param value can be {@literal null}. @@ -465,37 +465,36 @@ public NodeResult(RedisClusterNode node, @Nullable T value) { public NodeResult(RedisClusterNode node, @Nullable T value, byte[] key) { this.node = node; - this.value = value; - this.key = new ByteArrayWrapper(key); + this.key = new ByteArrayWrapper(key); + this.value = value; } /** - * Get the actual value of the command execution. + * Get the {@link RedisClusterNode} the command was executed on. * - * @return can be {@literal null}. + * @return never {@literal null}. */ - @Nullable - public T getValue() { - return value; + public RedisClusterNode getNode() { + return this.node; } /** - * Get the {@link RedisClusterNode} the command was executed on. + * Return the {@link byte[] key} mapped to the value stored in Redis. * - * @return never {@literal null}. + * @return a {@link byte[] byte array} of the key mapped to the value stored in Redis. */ - public RedisClusterNode getNode() { - return node; + public byte[] getKey() { + return this.key.getArray(); } /** - * Returns the key as an array of bytes. + * Get the actual value of the command execution. * - * @return the key as an array of bytes. + * @return can be {@literal null}. */ - public byte[] getKey() { - return key.getArray(); + @Nullable + public T getValue() { + return this.value; } /** @@ -513,6 +512,33 @@ public U mapValue(Function mapper) { return mapper.apply(getValue()); } + + @Override + public boolean equals(@Nullable Object obj) { + + if (obj == this) { + return true; + } + + if (!(obj instanceof NodeResult that)) { + return false; + } + + return ObjectUtils.nullSafeEquals(this.getNode(), that.getNode()) && Objects.equals(this.key, that.key) + && Objects.equals(this.getValue(), that.getValue()); + } + + @Override + public int hashCode() { + + int hashValue = 17; + + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(getNode()); + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(this.key); + hashValue = 37 * hashValue + ObjectUtils.nullSafeHashCode(getValue()); + + return hashValue; + } } /** @@ -566,12 +592,14 @@ public List resultsAsListSortBy(byte[]... keys) { if (positionalResults.isEmpty()) { List> clone = new ArrayList<>(nodeResults); + clone.sort(new ResultByReferenceKeyPositionComparator(keys)); return toList(clone); } Map> result = new TreeMap<>(new ResultByKeyPositionComparator(keys)); + result.putAll(positionalResults); return result.values().stream().map(tNodeResult -> tNodeResult.value).collect(Collectors.toList()); @@ -604,10 +632,12 @@ public T getFirstNonNullNotEmptyOrDefault(@Nullable T returnValue) { private List toList(Collection> source) { - ArrayList result = new ArrayList<>(); + List result = new ArrayList<>(); + for (NodeResult nodeResult : source) { result.add(nodeResult.getValue()); } + return result; } @@ -678,7 +708,7 @@ static PositionalKey of(byte[] key, int index) { * @return binary key.
*/ byte[] getBytes() { - return key.getArray(); + return getKey().getArray(); } public ByteArrayWrapper getKey() { @@ -690,23 +720,22 @@ public int getPosition() { } @Override - public boolean equals(@Nullable Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; + public boolean equals(@Nullable Object obj) { - PositionalKey that = (PositionalKey) o; + if (this == obj) { + return true; + } - if (position != that.position) + if (!(obj instanceof PositionalKey that)) return false; - return ObjectUtils.nullSafeEquals(key, that.key); + + return this.getPosition() == that.getPosition() && ObjectUtils.nullSafeEquals(this.getKey(), that.getKey()); } @Override public int hashCode() { - int result = ObjectUtils.nullSafeHashCode(key); - result = 31 * result + position; + int result = ObjectUtils.nullSafeHashCode(getKey()); + result = 31 * result + ObjectUtils.nullSafeHashCode(getPosition()); return result; } } @@ -753,6 +782,7 @@ static PositionalKeys of(byte[]... keys) { static PositionalKeys of(PositionalKey... keys) { PositionalKeys result = PositionalKeys.empty(); + result.append(keys); return result; @@ -769,12 +799,42 @@ void append(PositionalKey... keys) { * @return index of the {@link PositionalKey}. */ int indexOf(PositionalKey key) { - return keys.indexOf(key); + return this.keys.indexOf(key); } @Override public Iterator iterator() { - return keys.iterator(); + return this.keys.iterator(); + } + } + + /** + * Collector for exceptions. Applies translation of exceptions if possible. + */ + private class NodeExceptionCollector { + + private final Map exceptions = new HashMap<>(); + + /** + * @return {@code true} if the collector contains at least one exception. + */ + public boolean hasExceptions() { + return !exceptions.isEmpty(); + } + + public void addException(NodeExecution execution, Throwable throwable) { + + Throwable translated = throwable instanceof Exception e ? convertToDataAccessException(e) : throwable; + Throwable resolvedException = translated != null ? translated : throwable; + + exceptions.putIfAbsent(execution.getNode(), resolvedException); + } + + /** + * @return the collected exceptions. 
+ */ + public List getExceptions() { + return new ArrayList<>(exceptions.values()); } } } diff --git a/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java b/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java index f99e786ce4..1da31a7f73 100644 --- a/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java +++ b/src/test/java/org/springframework/data/redis/connection/ClusterCommandExecutorUnitTests.java @@ -16,15 +16,34 @@ package org.springframework.data.redis.connection; import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; import static org.springframework.data.redis.test.util.MockitoUtils.*; +import edu.umd.cs.mtc.MultithreadedTestCase; +import edu.umd.cs.mtc.TestFramework; + import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; +import java.util.function.Predicate; +import java.util.function.Supplier; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -33,7 +52,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; - +import org.mockito.stubbing.Answer; import org.springframework.core.convert.converter.Converter; import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.core.task.SyncTaskExecutor; @@ -45,14 +64,22 @@ import org.springframework.data.redis.connection.ClusterCommandExecutor.ClusterCommandCallback; import org.springframework.data.redis.connection.ClusterCommandExecutor.MultiKeyClusterCommandCallback; import org.springframework.data.redis.connection.ClusterCommandExecutor.MultiNodeResult; +import org.springframework.data.redis.connection.ClusterCommandExecutor.NodeExecution; +import org.springframework.data.redis.connection.ClusterCommandExecutor.NodeResult; import org.springframework.data.redis.connection.RedisClusterNode.LinkState; import org.springframework.data.redis.connection.RedisClusterNode.SlotRange; import org.springframework.data.redis.connection.RedisNode.NodeType; +import org.springframework.data.redis.test.util.MockitoUtils; +import org.springframework.lang.Nullable; import org.springframework.scheduling.concurrent.ConcurrentTaskExecutor; /** + * Unit Tests for {@link ClusterCommandExecutor}. 
+ * * @author Christoph Strobl * @author Mark Paluch + * @author John Blum + * @since 1.7 */ @ExtendWith(MockitoExtension.class) class ClusterCommandExecutorUnitTests { @@ -66,19 +93,35 @@ class ClusterCommandExecutorUnitTests { private static final int CLUSTER_NODE_3_PORT = 7381; private static final RedisClusterNode CLUSTER_NODE_1 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT).serving(new SlotRange(0, 5460)) - .withId("ef570f86c7b1a953846668debc177a3a16733420").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT) // + .serving(new SlotRange(0, 5460)) // + .withId("ef570f86c7b1a953846668debc177a3a16733420") // + .promotedAs(NodeType.MASTER) // + .linkState(LinkState.CONNECTED) // + .withName("ClusterNodeX") // .build(); + private static final RedisClusterNode CLUSTER_NODE_2 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT).serving(new SlotRange(5461, 10922)) - .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT) // + .serving(new SlotRange(5461, 10922)) // + .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84") // + .promotedAs(NodeType.MASTER) // + .linkState(LinkState.CONNECTED) // + .withName("ClusterNodeY") // .build(); + private static final RedisClusterNode CLUSTER_NODE_3 = RedisClusterNode.newRedisClusterNode() - .listeningAt(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT).serving(new SlotRange(10923, 16383)) - .withId("3b9b8192a874fa8f1f09dbc0ee20afab5738eee7").promotedAs(NodeType.MASTER).linkState(LinkState.CONNECTED) + .listeningAt(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT) // + .serving(new SlotRange(10923, 16383)) // + .withId("3b9b8192a874fa8f1f09dbc0ee20afab5738eee7") // + .promotedAs(NodeType.MASTER) // + .linkState(LinkState.CONNECTED) // + .withName("ClusterNodeZ") // .build(); + private static final RedisClusterNode CLUSTER_NODE_2_LOOKUP = RedisClusterNode.newRedisClusterNode() - .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84").build(); + .withId("0f2ee5df45d18c50aca07228cc18b1da96fd5e84") // + .build(); private static final RedisClusterNode UNKNOWN_CLUSTER_NODE = new RedisClusterNode("8.8.8.8", 7379, SlotRange.empty()); @@ -88,8 +131,8 @@ class ClusterCommandExecutorUnitTests { private static final Converter exceptionConverter = source -> { - if (source instanceof MovedException) { - return new ClusterRedirectException(1000, ((MovedException) source).host, ((MovedException) source).port, source); + if (source instanceof MovedException movedException) { + return new ClusterRedirectException(1000, movedException.host, movedException.port, source); } return new InvalidDataAccessApiUsageException(source.getMessage(), source); @@ -97,14 +140,14 @@ class ClusterCommandExecutorUnitTests { private static final MultiKeyConnectionCommandCallback MULTIKEY_CALLBACK = Connection::bloodAndAshes; - @Mock Connection con1; - @Mock Connection con2; - @Mock Connection con3; + @Mock Connection connection1; + @Mock Connection connection2; + @Mock Connection connection3; @BeforeEach void setUp() { - this.executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), new MockClusterResourceProvider(), + this.executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ImmediateExecutor()); } @@ -118,7 
+161,7 @@ void executeCommandOnSingleNodeShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_2); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -127,7 +170,7 @@ void executeCommandOnSingleNodeByHostAndPortShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, new RedisClusterNode(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -135,15 +178,17 @@ void executeCommandOnSingleNodeByNodeIdShouldBeExecutedCorrectly() { executor.executeCommandOnSingleNode(COMMAND_CALLBACK, new RedisClusterNode(CLUSTER_NODE_2.id)); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 + @SuppressWarnings("all") void executeCommandOnSingleNodeShouldThrowExceptionWhenNodeIsNull() { assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandOnSingleNode(COMMAND_CALLBACK, null)); } @Test // DATAREDIS-315 + @SuppressWarnings("all") void executeCommandOnSingleNodeShouldThrowExceptionWhenCommandCallbackIsNull() { assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandOnSingleNode(null, CLUSTER_NODE_1)); } @@ -158,52 +203,52 @@ void executeCommandOnSingleNodeShouldThrowExceptionWhenNodeIsUnknown() { void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodes() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(CLUSTER_NODE_1, CLUSTER_NODE_2)); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); + verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodesByHostAndPort() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(new RedisClusterNode(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT), new RedisClusterNode(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT))); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); + verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldExecuteCommandOnGivenNodesByNodeId() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new 
MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, Arrays.asList(new RedisClusterNode(CLUSTER_NODE_1.id), CLUSTER_NODE_2_LOOKUP)); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); + verify(connection3, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldFailOnGivenUnknownNodes() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); assertThatIllegalArgumentException().isThrownBy(() -> executor.executeCommandAsyncOnNodes(COMMAND_CALLBACK, @@ -214,42 +259,43 @@ void executeCommandAsyncOnNodesShouldFailOnGivenUnknownNodes() { void executeCommandOnAllNodesShouldExecuteCommandOnEveryKnownClusterNode() { ClusterCommandExecutor executor = new ClusterCommandExecutor(new MockClusterNodeProvider(), - new MockClusterResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), + new MockClusterNodeResourceProvider(), new PassThroughExceptionTranslationStrategy(exceptionConverter), new ConcurrentTaskExecutor(new SyncTaskExecutor())); executor.executeCommandOnAllNodes(COMMAND_CALLBACK); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); + verify(connection3).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldCompleteAndCollectErrorsOfAllNodes() { - when(con1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); - when(con2.theWheelWeavesAsTheWheelWills()).thenThrow(new IllegalStateException("(error) mat lost the dagger...")); - when(con3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); + when(connection1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); + when(connection2.theWheelWeavesAsTheWheelWills()) + .thenThrow(new IllegalStateException("(error) mat lost the dagger...")); + when(connection3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); try { executor.executeCommandOnAllNodes(COMMAND_CALLBACK); - } catch (ClusterCommandExecutionFailureException e) { + } catch (ClusterCommandExecutionFailureException cause) { - assertThat(e.getSuppressed()).hasSize(1); - assertThat(e.getSuppressed()[0]).isInstanceOf(DataAccessException.class); + assertThat(cause.getSuppressed()).hasSize(1); + assertThat(cause.getSuppressed()[0]).isInstanceOf(DataAccessException.class); } - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + 
verify(connection2).theWheelWeavesAsTheWheelWills(); + verify(connection3).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandAsyncOnNodesShouldCollectResultsCorrectly() { - when(con1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); - when(con2.theWheelWeavesAsTheWheelWills()).thenReturn("mat"); - when(con3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); + when(connection1.theWheelWeavesAsTheWheelWills()).thenReturn("rand"); + when(connection2.theWheelWeavesAsTheWheelWills()).thenReturn("mat"); + when(connection3.theWheelWeavesAsTheWheelWills()).thenReturn("perrin"); MultiNodeResult result = executor.executeCommandOnAllNodes(COMMAND_CALLBACK); @@ -261,14 +307,13 @@ void executeMultikeyCommandShouldRunCommandAcrossCluster() { // key-1 and key-9 map both to node1 ArgumentCaptor captor = ArgumentCaptor.forClass(byte[].class); - when(con1.bloodAndAshes(captor.capture())).thenReturn("rand").thenReturn("egwene"); - when(con2.bloodAndAshes(any(byte[].class))).thenReturn("mat"); - when(con3.bloodAndAshes(any(byte[].class))).thenReturn("perrin"); + when(connection1.bloodAndAshes(captor.capture())).thenReturn("rand").thenReturn("egwene"); + when(connection2.bloodAndAshes(any(byte[].class))).thenReturn("mat"); + when(connection3.bloodAndAshes(any(byte[].class))).thenReturn("perrin"); MultiNodeResult result = executor.executeMultiKeyCommand(MULTIKEY_CALLBACK, - new HashSet<>( - Arrays.asList("key-1".getBytes(), "key-2".getBytes(), "key-3".getBytes(), "key-9".getBytes()))); + new HashSet<>(Arrays.asList("key-1".getBytes(), "key-2".getBytes(), "key-3".getBytes(), "key-9".getBytes()))); assertThat(result.resultsAsList()).contains("rand", "mat", "perrin", "egwene"); @@ -279,32 +324,35 @@ void executeMultikeyCommandShouldRunCommandAcrossCluster() { @Test // DATAREDIS-315 void executeCommandOnSingleNodeAndFollowRedirect() { - when(con1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); + when(connection1.theWheelWeavesAsTheWheelWills()) + .thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_1); - verify(con1, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(1)).theWheelWeavesAsTheWheelWills(); - verify(con2, never()).theWheelWeavesAsTheWheelWills(); + verify(connection1).theWheelWeavesAsTheWheelWills(); + verify(connection3).theWheelWeavesAsTheWheelWills(); + verify(connection2, never()).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 void executeCommandOnSingleNodeAndFollowRedirectButStopsAfterMaxRedirects() { - when(con1.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); - when(con3.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); - when(con2.theWheelWeavesAsTheWheelWills()).thenThrow(new MovedException(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT)); + when(connection1.theWheelWeavesAsTheWheelWills()) + .thenThrow(new MovedException(CLUSTER_NODE_3_HOST, CLUSTER_NODE_3_PORT)); + when(connection3.theWheelWeavesAsTheWheelWills()) + .thenThrow(new MovedException(CLUSTER_NODE_2_HOST, CLUSTER_NODE_2_PORT)); + when(connection2.theWheelWeavesAsTheWheelWills()) + .thenThrow(new MovedException(CLUSTER_NODE_1_HOST, CLUSTER_NODE_1_PORT)); - try { - executor.setMaxRedirects(4); + executor.setMaxRedirects(4); + + assertThatExceptionOfType(TooManyClusterRedirectionsException.class).isThrownBy(() -> { 
executor.executeCommandOnSingleNode(COMMAND_CALLBACK, CLUSTER_NODE_1); - } catch (Exception e) { - assertThat(e).isInstanceOf(TooManyClusterRedirectionsException.class); - } + }); - verify(con1, times(2)).theWheelWeavesAsTheWheelWills(); - verify(con3, times(2)).theWheelWeavesAsTheWheelWills(); - verify(con2, times(1)).theWheelWeavesAsTheWheelWills(); + verify(connection1, times(2)).theWheelWeavesAsTheWheelWills(); + verify(connection3, times(2)).theWheelWeavesAsTheWheelWills(); + verify(connection2).theWheelWeavesAsTheWheelWills(); } @Test // DATAREDIS-315 @@ -312,53 +360,167 @@ void executeCommandOnArbitraryNodeShouldPickARandomNode() { executor.executeCommandOnArbitraryNode(COMMAND_CALLBACK); - verifyInvocationsAcross("theWheelWeavesAsTheWheelWills", times(1), con1, con2, con3); + verifyInvocationsAcross("theWheelWeavesAsTheWheelWills", connection1, connection2, connection3); } - class MockClusterNodeProvider implements ClusterTopologyProvider { + @Test // GH-2518 + void collectResultsCompletesSuccessfullyAfterTimeouts() { - @Override - public ClusterTopology getTopology() { - return new ClusterTopology( - new LinkedHashSet<>(Arrays.asList(CLUSTER_NODE_1, CLUSTER_NODE_2, CLUSTER_NODE_3))); - } + Map>> futures = new HashMap<>(); + + NodeResult nodeOneA = new NodeResult<>(CLUSTER_NODE_1, "A"); + NodeResult nodeTwoB = new NodeResult<>(CLUSTER_NODE_2, "B"); + NodeResult nodeThreeC = new NodeResult<>(CLUSTER_NODE_3, "C"); + doWithScheduler(scheduler -> { + + futures.put(new NodeExecution(CLUSTER_NODE_1), scheduler.schedule(() -> nodeOneA, 15, TimeUnit.MILLISECONDS)); + futures.put(new NodeExecution(CLUSTER_NODE_2), scheduler.schedule(() -> nodeTwoB, 15, TimeUnit.MILLISECONDS)); + futures.put(new NodeExecution(CLUSTER_NODE_3), scheduler.schedule(() -> nodeThreeC, 15, TimeUnit.MILLISECONDS)); + + MultiNodeResult results = this.executor.collectResults(futures); + + assertThat(results).isNotNull(); + assertThat(results.getResults()).containsExactlyInAnyOrder(nodeOneA, nodeTwoB, nodeThreeC); + }); } - class MockClusterResourceProvider implements ClusterNodeResourceProvider { + @Test // GH-2518 + void collectResultsFailsWithExecutionException() { - @Override - public Connection getResourceForSpecificNode(RedisClusterNode node) { + Map>> futures = new HashMap<>(); + NodeResult nodeOneA = new NodeResult<>(CLUSTER_NODE_1, "A"); - if (CLUSTER_NODE_1.equals(node)) { - return con1; - } - if (CLUSTER_NODE_2.equals(node)) { - return con2; - } - if (CLUSTER_NODE_3.equals(node)) { - return con3; - } + futures.put(new NodeExecution(CLUSTER_NODE_1), CompletableFuture.completedFuture(nodeOneA)); + futures.put(new NodeExecution(CLUSTER_NODE_2), + CompletableFuture.failedFuture(new IllegalArgumentException("MockError"))); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)) // + .withMessage("MockError") // + .withRootCauseInstanceOf(IllegalArgumentException.class); + } + + @Test // GH-2518 + void collectResultsFailsWithInterruptedException() throws Throwable { + TestFramework.runOnce(new CollectResultsInterruptedMultithreadedTestCase(this.executor)); + } + + @Test // GH-2518 + @SuppressWarnings("all") + void collectResultsCallsFutureGetOnlyOnce() throws Exception { + + AtomicInteger count = new AtomicInteger(0); + Map>> futures = new HashMap<>(); - return null; + Future> clusterNodeOneFutureResult = mockFutureAndIsDone(null, + future -> count.incrementAndGet() % 2 == 0); + + Future> clusterNodeTwoFutureResult = 
mockFutureThrowingExecutionException( + new ExecutionException("TestError", new IllegalArgumentException("MockError"))); + + futures.put(new NodeExecution(CLUSTER_NODE_1), clusterNodeOneFutureResult); + futures.put(new NodeExecution(CLUSTER_NODE_2), clusterNodeTwoFutureResult); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)); + + verify(clusterNodeOneFutureResult).get(anyLong(), any()); + verify(clusterNodeTwoFutureResult).get(anyLong(), any()); + } + + @Test // GH-2518 + @SuppressWarnings("all") + void collectResultsCapturesFirstExecutionExceptionOnly() { + + AtomicInteger count = new AtomicInteger(0); + AtomicInteger exceptionCount = new AtomicInteger(0); + + Map>> futures = new HashMap<>(); + + futures.put(new NodeExecution(CLUSTER_NODE_1), + mockFutureAndIsDone(null, future -> count.incrementAndGet() % 2 == 0)); + + futures.put(new NodeExecution(CLUSTER_NODE_2), + mockFutureThrowingExecutionException(() -> new ExecutionException("TestError", + new IllegalStateException("MockError" + exceptionCount.getAndIncrement())))); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.executor.collectResults(futures)) // + .withMessage("MockError0") // + .withCauseInstanceOf(InvalidDataAccessApiUsageException.class) + .withRootCauseInstanceOf(IllegalStateException.class); + } + + private Future mockFutureAndIsDone(@Nullable T result, Predicate> isDone) { + + return MockitoUtils.mockFuture(result, future -> { + doAnswer(invocation -> isDone.test(future)).when(future).isDone(); + return future; + }); + } + + private Future mockFutureThrowingExecutionException(ExecutionException exception) { + return mockFutureThrowingExecutionException(() -> exception); + } + + private Future mockFutureThrowingExecutionException(Supplier exceptionSupplier) { + + Answer getAnswer = invocationOnMock -> { + throw exceptionSupplier.get(); + }; + + return MockitoUtils.mockFuture(null, future -> { + doReturn(true).when(future).isDone(); + doAnswer(getAnswer).when(future).get(); + doAnswer(getAnswer).when(future).get(anyLong(), any()); + return future; + }); + } + + /** + * Performs the given action within the scope of a running {@link ScheduledExecutorService}. The scheduler is only + * valid during the callback and shut down after this method returns. + * + * @param callback the action to invoke. + */ + private void doWithScheduler(Consumer callback) { + + ScheduledExecutorService scheduler = new ScheduledThreadPoolExecutor(3); + try { + callback.accept(scheduler); + } finally { + scheduler.shutdown(); } + } + + static class MockClusterNodeProvider implements ClusterTopologyProvider { @Override - public void returnResourceForSpecificNode(RedisClusterNode node, Object resource) { - // TODO Auto-generated method stub + public ClusterTopology getTopology() { + return new ClusterTopology(Set.of(CLUSTER_NODE_1, CLUSTER_NODE_2, CLUSTER_NODE_3)); } - } - static interface ConnectionCommandCallback extends ClusterCommandCallback { + class MockClusterNodeResourceProvider implements ClusterNodeResourceProvider { - } + @Override + @SuppressWarnings("all") + public Connection getResourceForSpecificNode(RedisClusterNode clusterNode) { - static interface MultiKeyConnectionCommandCallback extends MultiKeyClusterCommandCallback { + return CLUSTER_NODE_1.equals(clusterNode) ? connection1 + : CLUSTER_NODE_2.equals(clusterNode) ? connection2 : CLUSTER_NODE_3.equals(clusterNode) ? 
connection3 : null; + } + @Override + public void returnResourceForSpecificNode(RedisClusterNode node, Object resource) {} } - static interface Connection { + interface ConnectionCommandCallback extends ClusterCommandCallback {} + + interface MultiKeyConnectionCommandCallback extends MultiKeyClusterCommandCallback {} + + interface Connection { String theWheelWeavesAsTheWheelWills(); @@ -374,19 +536,20 @@ static class MovedException extends RuntimeException { this.host = host; this.port = port; } - } static class ImmediateExecutor implements AsyncTaskExecutor { @Override - public void execute(Runnable runnable, long l) { + public void execute(Runnable runnable) { runnable.run(); } @Override public Future submit(Runnable runnable) { + return submit(() -> { + runnable.run(); return null; @@ -395,19 +558,103 @@ public Future submit(Runnable runnable) { @Override public Future submit(Callable callable) { + try { return CompletableFuture.completedFuture(callable.call()); - } catch (Exception e) { + } catch (Exception cause) { CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(e); + + future.completeExceptionally(cause); + return future; } } + } + + @SuppressWarnings("all") + static class CollectResultsInterruptedMultithreadedTestCase extends MultithreadedTestCase { + + private static final CountDownLatch latch = new CountDownLatch(1); + + private static final Comparator NODE_COMPARATOR = Comparator + .comparing(nodeExecution -> nodeExecution.getNode().getName()); + + private final ClusterCommandExecutor clusterCommandExecutor; + + private final Map>> futureNodeResults; + + private Future> mockNodeOneFutureResult; + private Future> mockNodeTwoFutureResult; + + private volatile Thread collectResultsThread; + + private CollectResultsInterruptedMultithreadedTestCase(ClusterCommandExecutor clusterCommandExecutor) { + this.clusterCommandExecutor = clusterCommandExecutor; + this.futureNodeResults = new ConcurrentSkipListMap<>(NODE_COMPARATOR); + } @Override - public void execute(Runnable runnable) { - runnable.run(); + public void initialize() { + + super.initialize(); + + this.mockNodeOneFutureResult = this.futureNodeResults.computeIfAbsent(new NodeExecution(CLUSTER_NODE_1), + nodeExecution -> MockitoUtils.mockFuture(null, mockFuture -> { + doReturn(false).when(mockFuture).isDone(); + return mockFuture; + })); + + this.mockNodeTwoFutureResult = this.futureNodeResults.computeIfAbsent(new NodeExecution(CLUSTER_NODE_2), + nodeExecution -> MockitoUtils.mockFuture(null, mockFuture -> { + + doReturn(true).when(mockFuture).isDone(); + + doAnswer(invocation -> { + latch.await(); + return null; + }).when(mockFuture).get(anyLong(), any()); + + return mockFuture; + })); + } + + public void thread1() { + + assertTick(0); + + this.collectResultsThread = Thread.currentThread(); + this.collectResultsThread.setName("CollectResults Thread"); + + assertThatExceptionOfType(ClusterCommandExecutionFailureException.class) + .isThrownBy(() -> this.clusterCommandExecutor.collectResults(this.futureNodeResults)); + + assertThat(this.collectResultsThread.isInterrupted()).isTrue(); + } + + public void thread2() { + + assertTick(0); + + Thread.currentThread().setName("Interrupting Thread"); + + waitForTick(1); + + assertThat(this.collectResultsThread).isNotNull(); + assertThat(this.collectResultsThread.getName()).isEqualTo("CollectResults Thread"); + + this.collectResultsThread.interrupt(); + } + + @Override + public void finish() { + + try { + verify(this.mockNodeOneFutureResult, never()).get(); 
+ verify(this.mockNodeTwoFutureResult).get(anyLong(), any()); + } catch (ExecutionException | InterruptedException | TimeoutException cause) { + throw new RuntimeException(cause); + } } } } diff --git a/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java b/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java index e42f0749d8..8b3c6770ac 100644 --- a/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java +++ b/src/test/java/org/springframework/data/redis/test/util/MockitoUtils.java @@ -20,53 +20,146 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; import org.mockito.internal.invocation.InvocationMatcher; import org.mockito.internal.verification.api.VerificationData; import org.mockito.invocation.Invocation; import org.mockito.invocation.MatchableInvocation; +import org.mockito.quality.Strictness; +import org.mockito.stubbing.Answer; import org.mockito.verification.VerificationMode; +import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; +import org.springframework.util.function.ThrowingFunction; /** + * Utilities for using {@literal Mockito} and creating {@link Object mock objects} in {@literal unit tests}. + * * @author Christoph Strobl + * @author John Blum + * @see org.mockito.Mockito * @since 1.7 */ -public class MockitoUtils { +@SuppressWarnings("unused") +public abstract class MockitoUtils { /** - * Verifies a given method is called a total number of times across all given mocks. + * Creates a mock {@link Future} returning the given {@link Object result}. * - * @param method - * @param mode - * @param mocks + * @param {@link Class type} of {@link Object result} returned by the mock {@link Future}. + * @param result {@link Object value} returned as the {@literal result} of the mock {@link Future}. + * @return a new mock {@link Future}. + * @see java.util.concurrent.Future */ - @SuppressWarnings({ "rawtypes", "serial" }) - public static void verifyInvocationsAcross(final String method, final VerificationMode mode, Object... mocks) { + @SuppressWarnings("unchecked") + public static Future mockFuture(@Nullable T result) { - mode.verify(new VerificationDataImpl(getInvocations(method, mocks), new InvocationMatcher(null, Collections - .singletonList(org.mockito.internal.matchers.Any.ANY)) { + try { - @Override - public boolean matches(Invocation actual) { - return true; - } + AtomicBoolean cancelled = new AtomicBoolean(false); + AtomicBoolean done = new AtomicBoolean(false); - @Override - public String toString() { - return String.format("%s for method: %s", mode, method); - } + Future mockFuture = mock(Future.class, withSettings().strictness(Strictness.LENIENT)); + + // A Future can only be cancelled if not done, it was not already cancelled, and no error occurred. + // The cancel(..) logic is not Thread-safe due to compound actions involving multiple variables. + // However, the cancel(..) logic does not necessarily need to be Thread-safe given the task execution + // of a Future is asynchronous and cancellation is driven by Thread interrupt from another Thread. 
+ Answer cancelAnswer = invocation -> !done.get() && cancelled.compareAndSet(done.get(), true) + && done.compareAndSet(done.get(), true); + + Answer getAnswer = invocation -> { + + // The Future is done no matter if it returns the result or was cancelled/interrupted. + done.set(true); + + if (Thread.currentThread().isInterrupted()) { + throw new InterruptedException("Thread was interrupted"); + } + + if (cancelled.get()) { + throw new CancellationException("Task was cancelled"); + } + + return result; + }; + + doAnswer(invocation -> cancelled.get()).when(mockFuture).isCancelled(); + doAnswer(invocation -> done.get()).when(mockFuture).isDone(); + doAnswer(cancelAnswer).when(mockFuture).cancel(anyBoolean()); + doAnswer(getAnswer).when(mockFuture).get(); + doAnswer(getAnswer).when(mockFuture).get(anyLong(), isA(TimeUnit.class)); + + return mockFuture; + } catch (Exception cause) { + String message = String.format("Failed to create a mock of Future having result [%s]", result); + throw new IllegalStateException(message, cause); + } + } + + /** + * Creates a mock {@link Future} returning the given {@link Object result}, customized with the given, required + * {@link Function}. + * + * @param {@link Class type} of {@link Object result} returned by the mock {@link Future}. + * @param result {@link Object value} returned as the {@literal result} of the mock {@link Future}. + * @param futureFunction {@link Function} used to customize the mock {@link Future} on creation; must not be + * {@literal null}. + * @return a new mock {@link Future}. + * @see #mockFuture(Object) + */ + public static Future mockFuture(@Nullable T result, ThrowingFunction, Future> futureFunction) { + + Future mockFuture = mockFuture(result); - })); + return futureFunction.apply(mockFuture); + } + + /** + * Verifies a given method is called once across all given mocks. + * + * @param method {@link String name} of a {@link java.lang.reflect.Method} on the {@link Object mock object}. + * @param mocks array of {@link Object mock objects} to verify. + */ + public static void verifyInvocationsAcross(String method, Object... mocks) { + verifyInvocationsAcross(method, times(1), mocks); + } + + /** + * Verifies a given method is called a total number of times across all given mocks. + * + * @param method {@link String name} of a {@link java.lang.reflect.Method} on the {@link Object mock object}. + * @param mode mode of verification used by {@literal Mockito} to verify invocations on {@link Object mock objects}. + * @param mocks array of {@link Object mock objects} to verify. + */ + public static void verifyInvocationsAcross(String method, VerificationMode mode, Object... mocks) { + + mode.verify(new VerificationDataImpl(getInvocations(method, mocks), + new InvocationMatcher(null, Collections.singletonList(org.mockito.internal.matchers.Any.ANY)) { + + @Override + public boolean matches(Invocation actual) { + return true; + } + + @Override + public String toString() { + return String.format("%s for method: %s", mode, method); + } + })); } private static List getInvocations(String method, Object... mocks) { List invocations = new ArrayList<>(); - for (Object mock : mocks) { + for (Object mock : mocks) { if (StringUtils.hasText(method)) { - for (Invocation invocation : mockingDetails(mock).getInvocations()) { if (invocation.getMethod().getName().equals(method)) { invocations.add(invocation); @@ -76,6 +169,7 @@ private static List getInvocations(String method, Object... 
mocks) { invocations.addAll(mockingDetails(mock).getInvocations()); } } + return invocations; } @@ -98,7 +192,6 @@ public List getAllInvocations() { public MatchableInvocation getTarget() { return wanted; } - } }
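For reference, a minimal usage sketch (not part of this patch) of the new MockitoUtils.mockFuture(..) helper introduced above. The test method name is hypothetical; the generic parameters, the CLUSTER_NODE_1 fixture, and the usual static Mockito/AssertJ imports are assumed from the test changes in ClusterCommandExecutorUnitTests:

    @Test
    void mockFutureReturnsConfiguredResult() throws Exception {

        NodeResult<String> nodeResult = new NodeResult<>(CLUSTER_NODE_1, "A");

        // the second argument customizes the lenient mock before it is handed to the code under test
        Future<NodeResult<String>> future = MockitoUtils.mockFuture(nodeResult, mockFuture -> {
            doReturn(true).when(mockFuture).isDone();
            return mockFuture;
        });

        assertThat(future.isDone()).isTrue();
        assertThat(future.get(10, TimeUnit.MILLISECONDS)).isEqualTo(nodeResult);
    }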
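The reworked collectResults(..) above replaces the former Thread.sleep(10) loop: each outstanding Future is polled with a very short timeout, and futures whose results were already consumed are tracked by identity so Future.get(..) is invoked at most once per result. A simplified, hedged sketch of that polling pattern (exception translation, positional keys, and the NodeExceptionCollector are omitted; names are illustrative and java.util plus java.util.concurrent imports are assumed):

    static <T> List<T> awaitAll(Collection<Future<T>> futures) {

        Map<Future<T>, Object> consumed = new IdentityHashMap<>();
        List<T> results = new ArrayList<>();

        for (;;) {

            boolean timeout = false;

            for (Future<T> future : futures) {
                try {
                    if (!consumed.containsKey(future)) {
                        // poll with a tiny timeout instead of sleeping between passes
                        results.add(future.get(10, TimeUnit.MICROSECONDS));
                        consumed.put(future, Boolean.TRUE);
                    }
                } catch (ExecutionException cause) {
                    consumed.put(future, Boolean.TRUE); // collected and translated in the real implementation
                } catch (TimeoutException stillRunning) {
                    timeout = true; // at least one Future is still running; poll again
                } catch (InterruptedException cause) {
                    Thread.currentThread().interrupt();
                    return results; // the real implementation records the exception and stops waiting
                }
            }

            if (!timeout) {
                return results;
            }
        }
    }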