diff --git a/spring-integration-core/src/main/java/org/springframework/integration/store/AbstractKeyValueMessageStore.java b/spring-integration-core/src/main/java/org/springframework/integration/store/AbstractKeyValueMessageStore.java index 68505a830b9..1f5eac29f68 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/store/AbstractKeyValueMessageStore.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/store/AbstractKeyValueMessageStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ import java.util.stream.Stream; import org.springframework.jmx.export.annotation.ManagedAttribute; +import org.springframework.lang.Nullable; import org.springframework.messaging.Message; import org.springframework.util.Assert; @@ -45,7 +46,7 @@ public abstract class AbstractKeyValueMessageStore extends AbstractMessageGroupS protected static final String MESSAGE_KEY_PREFIX = "MESSAGE_"; - protected static final String MESSAGE_GROUP_KEY_PREFIX = "MESSAGE_GROUP_"; + protected static final String MESSAGE_GROUP_KEY_PREFIX = "GROUP_OF_MESSAGES_"; private final String messagePrefix; @@ -57,9 +58,9 @@ protected AbstractKeyValueMessageStore() { /** * Construct an instance based on the provided prefix for keys to distinguish between - * different store instances in the same target key-value data base. Defaults to an + * different store instances in the same target key-value database. Defaults to an * empty string - no prefix. The actual prefix for messages is - * {@code prefix + MESSAGE_}; for message groups - {@code prefix + MESSAGE_GROUP_} + * {@code prefix + MESSAGE_}; for message groups - {@code prefix + GROUP_OF_MESSAGES_} * @param prefix the prefix to use * @since 4.3.12 */ @@ -71,18 +72,18 @@ protected AbstractKeyValueMessageStore(String prefix) { /** * Return the configured prefix for message keys to distinguish between different - * store instances in the same target key-value data base. Defaults to the + * store instances in the same target key-value database. Defaults to the * {@value MESSAGE_KEY_PREFIX} - without a custom prefix. * @return the prefix for keys * @since 4.3.12 */ - protected String getMessagePrefix() { + public String getMessagePrefix() { return this.messagePrefix; } /** * Return the configured prefix for message group keys to distinguish between - * different store instances in the same target key-value data base. Defaults to the + * different store instances in the same target key-value database. Defaults to the * {@value MESSAGE_GROUP_KEY_PREFIX} - without custom prefix. * @return the prefix for keys * @since 4.3.12 @@ -140,10 +141,15 @@ public Message addMessage(Message message) { } protected void doAddMessage(Message message) { + doAddMessage(message, null); + } + + protected void doAddMessage(Message message, @Nullable Object groupId) { Assert.notNull(message, "'message' must not be null"); UUID messageId = message.getHeaders().getId(); Assert.notNull(messageId, "Cannot store messages without an ID header"); - doStoreIfAbsent(this.messagePrefix + messageId, new MessageHolder(message)); + String messageKey = this.messagePrefix + (groupId != null ? groupId.toString() + '_' : "") + messageId; + doStoreIfAbsent(messageKey, new MessageHolder(message)); } @Override @@ -165,7 +171,6 @@ public long getMessageCount() { return (messageIds != null) ? messageIds.size() : 0; } - // MessageGroupStore methods /** @@ -211,7 +216,7 @@ public void addMessagesToGroup(Object groupId, Message... messages) { } for (Message message : messages) { - doAddMessage(message); + doAddMessage(message, groupId); if (metadata != null) { metadata.add(message.getHeaders().getId()); } @@ -253,7 +258,7 @@ public void removeMessagesFromGroup(Object groupId, Collection> messa List messageIds = new ArrayList<>(); for (UUID id : ids) { - messageIds.add(this.messagePrefix + id); + messageIds.add(this.messagePrefix + groupId + '_' + id); } doRemoveAll(messageIds); @@ -288,7 +293,7 @@ public void removeMessageGroup(Object groupId) { List messageIds = messageGroupMetadata.getMessageIds() .stream() - .map(id -> this.messagePrefix + id) + .map(id -> this.messagePrefix + groupId + '_' + id) .collect(Collectors.toList()); doRemoveAll(messageIds); @@ -326,24 +331,47 @@ public Message pollMessageFromGroup(Object groupId) { groupMetadata.remove(firstId); groupMetadata.setLastModified(System.currentTimeMillis()); doStore(this.groupPrefix + groupId, groupMetadata); - return removeMessage(firstId); + return removeMessageFromGroup(firstId, groupId); } } return null; } + private Message removeMessageFromGroup(UUID id, Object groupId) { + Assert.notNull(id, "'id' must not be null"); + Object object = doRemove(this.messagePrefix + groupId + '_' + id); + if (object != null) { + return extractMessage(object); + } + else { + return null; + } + } + @Override public Message getOneMessageFromGroup(Object groupId) { MessageGroupMetadata groupMetadata = getGroupMetadata(groupId); if (groupMetadata != null) { UUID messageId = groupMetadata.firstId(); if (messageId != null) { - return getMessage(messageId); + return getMessageFromGroup(messageId, groupId); } } return null; } + @Nullable + private Message getMessageFromGroup(UUID messageId, Object groupId) { + Assert.notNull(messageId, "'messageId' must not be null"); + Object object = doRetrieve(this.messagePrefix + groupId + '_' + messageId); + if (object != null) { + return extractMessage(object); + } + else { + return null; + } + } + @Override public Collection> getMessagesForGroup(Object groupId) { MessageGroupMetadata groupMetadata = getGroupMetadata(groupId); @@ -351,7 +379,7 @@ public Collection> getMessagesForGroup(Object groupId) { if (groupMetadata != null) { Iterator messageIds = groupMetadata.messageIdIterator(); while (messageIds.hasNext()) { - messages.add(getMessage(messageIds.next())); + messages.add(getMessageFromGroup(messageIds.next(), groupId)); } } return messages; @@ -362,7 +390,7 @@ public Stream> streamMessagesForGroup(Object groupId) { return getGroupMetadata(groupId) .getMessageIds() .stream() - .map(this::getMessage); + .map((messageId) -> getMessageFromGroup(messageId, groupId)); } @Override @@ -376,8 +404,8 @@ public Iterator iterator() { private Collection normalizeKeys(Collection keys) { Set normalizedKeys = new HashSet<>(); - for (Object key : keys) { - String strKey = (String) key; + for (String key : keys) { + String strKey = key; if (strKey.startsWith(this.groupPrefix)) { strKey = strKey.replace(this.groupPrefix, ""); } diff --git a/spring-integration-hazelcast/src/main/java/org/springframework/integration/hazelcast/store/HazelcastMessageStore.java b/spring-integration-hazelcast/src/main/java/org/springframework/integration/hazelcast/store/HazelcastMessageStore.java index a6e3cb51db9..b58a20927cf 100644 --- a/spring-integration-hazelcast/src/main/java/org/springframework/integration/hazelcast/store/HazelcastMessageStore.java +++ b/spring-integration-hazelcast/src/main/java/org/springframework/integration/hazelcast/store/HazelcastMessageStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2017-2022 the original author or authors. + * Copyright 2017-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/spring-integration-hazelcast/src/test/java/org/springframework/integration/hazelcast/store/HazelcastMessageStoreTests.java b/spring-integration-hazelcast/src/test/java/org/springframework/integration/hazelcast/store/HazelcastMessageStoreTests.java index 23cd21d10cf..d2b88a377e6 100644 --- a/spring-integration-hazelcast/src/test/java/org/springframework/integration/hazelcast/store/HazelcastMessageStoreTests.java +++ b/spring-integration-hazelcast/src/test/java/org/springframework/integration/hazelcast/store/HazelcastMessageStoreTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2017-2022 the original author or authors. + * Copyright 2017-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,10 @@ import com.hazelcast.core.Hazelcast; import com.hazelcast.core.HazelcastInstance; import com.hazelcast.map.IMap; -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.springframework.integration.channel.DirectChannel; import org.springframework.integration.history.MessageHistory; @@ -50,26 +50,25 @@ public class HazelcastMessageStoreTests { private static IMap map; - @BeforeClass + @BeforeAll public static void init() { instance = Hazelcast.newHazelcastInstance(); map = instance.getMap("customTestsMessageStore"); store = new HazelcastMessageStore(map); } - @AfterClass + @AfterAll public static void destroy() { instance.shutdown(); } - @Before + @BeforeEach public void clean() { map.clear(); } @Test public void testWithMessageHistory() { - Message message = new GenericMessage<>("Hello"); DirectChannel fooChannel = new DirectChannel(); fooChannel.setBeanName("fooChannel"); @@ -107,7 +106,6 @@ public void testAddAndRemoveMessagesFromMessageGroup() { @Test public void addAndGetMessage() { - Message message = MessageBuilder.withPayload("test").build(); store.addMessage(message); Message retrieved = store.getMessage(message.getHeaders().getId()); @@ -145,4 +143,34 @@ public void messageStoreIterator() { assertThat(groupCount).isEqualTo(1); } + @Test + public void sameMessageInTwoGroupsNotRemovedByFirstGroup() { + GenericMessage testMessage = new GenericMessage<>("test data"); + + store.addMessageToGroup("1", testMessage); + store.addMessageToGroup("2", testMessage); + + store.removeMessageGroup("1"); + + assertThat(store.getMessageCount()).isEqualTo(1); + + store.removeMessageGroup("2"); + + assertThat(store.getMessageCount()).isEqualTo(0); + } + + @Test + public void removeMessagesFromGroupDontRemoveSameMessageInOtherGroup() { + GenericMessage testMessage = new GenericMessage<>("test data"); + + store.addMessageToGroup("1", testMessage); + store.addMessageToGroup("2", testMessage); + + store.removeMessagesFromGroup("1", testMessage); + + assertThat(store.getMessageCount()).isEqualTo(1); + assertThat(store.messageGroupSize("1")).isEqualTo(0); + assertThat(store.messageGroupSize("2")).isEqualTo(1); + } + } diff --git a/spring-integration-mongodb/src/main/java/org/springframework/integration/mongodb/store/ConfigurableMongoDbMessageStore.java b/spring-integration-mongodb/src/main/java/org/springframework/integration/mongodb/store/ConfigurableMongoDbMessageStore.java index cb6c809f8fb..df6ab0363cb 100644 --- a/spring-integration-mongodb/src/main/java/org/springframework/integration/mongodb/store/ConfigurableMongoDbMessageStore.java +++ b/spring-integration-mongodb/src/main/java/org/springframework/integration/mongodb/store/ConfigurableMongoDbMessageStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2022 the original author or authors. + * Copyright 2013-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -101,7 +101,8 @@ public Message addMessage(Message message) { @Override public Message removeMessage(UUID id) { Assert.notNull(id, "'id' must not be null"); - Query query = Query.query(Criteria.where(MessageDocumentFields.MESSAGE_ID).is(id)); + Query query = Query.query(Criteria.where(MessageDocumentFields.MESSAGE_ID).is(id) + .and(MessageDocumentFields.GROUP_ID).exists(false)); MessageDocument document = getMongoTemplate().findAndRemove(query, MessageDocument.class, this.collectionName); return (document != null) ? document.getMessage() : null; } diff --git a/spring-integration-mongodb/src/main/java/org/springframework/integration/mongodb/store/MongoDbMessageStore.java b/spring-integration-mongodb/src/main/java/org/springframework/integration/mongodb/store/MongoDbMessageStore.java index 9696ffa387e..81e68d6620a 100644 --- a/spring-integration-mongodb/src/main/java/org/springframework/integration/mongodb/store/MongoDbMessageStore.java +++ b/spring-integration-mongodb/src/main/java/org/springframework/integration/mongodb/store/MongoDbMessageStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -258,14 +258,15 @@ public MessageMetadata getMessageMetadata(UUID id) { @Override @ManagedAttribute public long getMessageCount() { - return this.template.getCollection(this.collectionName).countDocuments(); + Query query = Query.query(Criteria.where("headers.id").exists(true).and(GROUP_ID_KEY).exists(false)); + return this.template.getCollection(this.collectionName).countDocuments(query.getQueryObject()); } @Override public Message removeMessage(UUID id) { Assert.notNull(id, "'id' must not be null"); - MessageWrapper messageWrapper = - this.template.findAndRemove(whereMessageIdIs(id), MessageWrapper.class, this.collectionName); + Query query = Query.query(Criteria.where("headers.id").is(id).and(GROUP_ID_KEY).exists(false)); + MessageWrapper messageWrapper = this.template.findAndRemove(query, MessageWrapper.class, this.collectionName); return (messageWrapper != null ? messageWrapper.getMessage() : null); } diff --git a/spring-integration-mongodb/src/test/java/org/springframework/integration/mongodb/store/AbstractMongoDbMessageGroupStoreTests.java b/spring-integration-mongodb/src/test/java/org/springframework/integration/mongodb/store/AbstractMongoDbMessageGroupStoreTests.java index d8385cc1781..0c94b46eb13 100644 --- a/spring-integration-mongodb/src/test/java/org/springframework/integration/mongodb/store/AbstractMongoDbMessageGroupStoreTests.java +++ b/spring-integration-mongodb/src/test/java/org/springframework/integration/mongodb/store/AbstractMongoDbMessageGroupStoreTests.java @@ -536,6 +536,28 @@ void testWithMessageHistory() { .containsEntry("type", "channel"); } + @Test + public void removeMessageDoesntRemoveSameMessageInTheGroup() { + GenericMessage testMessage = new GenericMessage<>("test data"); + + MessageGroupStore store = getMessageGroupStore(); + + store.addMessageToGroup("1", testMessage); + + MessageStore messageStore = (MessageStore) store; + + messageStore.removeMessage(testMessage.getHeaders().getId()); + + assertThat(messageStore.getMessageCount()).isEqualTo(0); + assertThat(store.getMessageCountForAllMessageGroups()).isEqualTo(1); + assertThat(store.messageGroupSize("1")).isEqualTo(1); + + store.removeMessageGroup("1"); + + assertThat(store.getMessageCountForAllMessageGroups()).isEqualTo(0); + assertThat(store.messageGroupSize("1")).isEqualTo(0); + } + protected abstract MessageGroupStore getMessageGroupStore(); protected abstract MessageStore getMessageStore(); diff --git a/spring-integration-redis/src/test/java/org/springframework/integration/redis/store/RedisMessageGroupStoreTests.java b/spring-integration-redis/src/test/java/org/springframework/integration/redis/store/RedisMessageGroupStoreTests.java index e70bf2c741c..1da003ac884 100644 --- a/spring-integration-redis/src/test/java/org/springframework/integration/redis/store/RedisMessageGroupStoreTests.java +++ b/spring-integration-redis/src/test/java/org/springframework/integration/redis/store/RedisMessageGroupStoreTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2007-2022 the original author or authors. + * Copyright 2007-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -56,6 +56,7 @@ import org.springframework.messaging.support.GenericMessage; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.assertj.core.api.Assertions.fail; /** @@ -65,6 +66,7 @@ * @author Artem Vozhdayenko */ class RedisMessageGroupStoreTests implements RedisContainerTest { + private static RedisConnectionFactory redisConnectionFactory; @BeforeAll @@ -74,17 +76,18 @@ static void setupConnection() { private final UUID groupId = UUID.randomUUID(); + RedisMessageStore store = new RedisMessageStore(redisConnectionFactory); + @BeforeEach @AfterEach void setUpTearDown() { StringRedisTemplate template = RedisContainerTest.createStringRedisTemplate(redisConnectionFactory); - template.delete(template.keys("MESSAGE_GROUP_*")); + template.delete(template.keys("MESSAGE_*")); + template.delete(template.keys("GROUP_OF_MESSAGES_*")); } @Test void testNonExistingEmptyMessageGroup() { - RedisMessageStore store = new RedisMessageStore(redisConnectionFactory); - MessageGroup messageGroup = store.getMessageGroup(this.groupId); assertThat(messageGroup).isNotNull(); assertThat(messageGroup).isInstanceOf(SimpleMessageGroup.class); @@ -93,8 +96,6 @@ void testNonExistingEmptyMessageGroup() { @Test void testMessageGroupUpdatedDateChangesWithEachAddedMessage() throws Exception { - RedisMessageStore store = new RedisMessageStore(redisConnectionFactory); - Message message = new GenericMessage<>("Hello"); MessageGroup messageGroup = store.addMessageToGroup(this.groupId, message); assertThat(messageGroup.size()).isEqualTo(1); @@ -117,8 +118,6 @@ void testMessageGroupUpdatedDateChangesWithEachAddedMessage() throws Exception { @Test void testMessageGroupWithAddedMessage() { - RedisMessageStore store = new RedisMessageStore(redisConnectionFactory); - Message message = new GenericMessage<>("Hello"); MessageGroup messageGroup = store.addMessageToGroup(this.groupId, message); assertThat(messageGroup.size()).isEqualTo(1); @@ -132,8 +131,6 @@ void testMessageGroupWithAddedMessage() { @Test void testRemoveMessageGroup() { - RedisMessageStore store = new RedisMessageStore(redisConnectionFactory); - MessageGroup messageGroup = store.getMessageGroup(this.groupId); Message message = new GenericMessage<>("Hello"); messageGroup = store.addMessageToGroup(messageGroup.getGroupId(), message); @@ -157,8 +154,6 @@ void testRemoveMessageGroup() { @Test void testCompleteMessageGroup() { - RedisMessageStore store = new RedisMessageStore(redisConnectionFactory); - MessageGroup messageGroup = store.getMessageGroup(this.groupId); Message message = new GenericMessage<>("Hello"); messageGroup = store.addMessageToGroup(messageGroup.getGroupId(), message); @@ -169,8 +164,6 @@ void testCompleteMessageGroup() { @Test void testLastReleasedSequenceNumber() { - RedisMessageStore store = new RedisMessageStore(redisConnectionFactory); - MessageGroup messageGroup = store.getMessageGroup(this.groupId); Message message = new GenericMessage<>("Hello"); messageGroup = store.addMessageToGroup(messageGroup.getGroupId(), message); @@ -181,8 +174,6 @@ void testLastReleasedSequenceNumber() { @Test void testRemoveMessageFromTheGroup() { - RedisMessageStore store = new RedisMessageStore(redisConnectionFactory); - MessageGroup messageGroup = store.getMessageGroup(this.groupId); Message message = new GenericMessage<>("2"); store.addMessagesToGroup(messageGroup.getGroupId(), new GenericMessage<>("1"), message); @@ -202,8 +193,6 @@ void testRemoveMessageFromTheGroup() { @Test void testWithMessageHistory() { - RedisMessageStore store = new RedisMessageStore(redisConnectionFactory); - Message message = new GenericMessage<>("Hello"); DirectChannel fooChannel = new DirectChannel(); fooChannel.setBeanName("fooChannel"); @@ -227,17 +216,16 @@ void testWithMessageHistory() { @Test void testRemoveNonExistingMessageFromTheGroup() { - RedisMessageStore store = new RedisMessageStore(redisConnectionFactory); - MessageGroup messageGroup = store.getMessageGroup(this.groupId); store.addMessagesToGroup(messageGroup.getGroupId(), new GenericMessage<>("1")); - store.removeMessagesFromGroup(this.groupId, new GenericMessage<>("2")); + assertThatNoException() + .isThrownBy(() -> store.removeMessagesFromGroup(this.groupId, new GenericMessage<>("2"))); } @Test void testRemoveNonExistingMessageFromNonExistingTheGroup() { - RedisMessageStore store = new RedisMessageStore(redisConnectionFactory); - store.removeMessagesFromGroup(this.groupId, new GenericMessage<>("2")); + assertThatNoException() + .isThrownBy(() -> store.removeMessagesFromGroup(this.groupId, new GenericMessage<>("2"))); } @@ -283,14 +271,9 @@ void testIteratorOfMessageGroups() { while (messageGroups.hasNext()) { MessageGroup group = messageGroups.next(); String groupId = (String) group.getGroupId(); - if (groupId.equals("1")) { - assertThat(group.getMessages().size()).isEqualTo(1); - } - else if (groupId.equals("2")) { - assertThat(group.getMessages().size()).isEqualTo(1); - } - else if (groupId.equals("3")) { - assertThat(group.getMessages().size()).isEqualTo(2); + switch (groupId) { + case "1", "2" -> assertThat(group.getMessages().size()).isEqualTo(1); + case "3" -> assertThat(group.getMessages().size()).isEqualTo(2); } counter++; } @@ -390,25 +373,22 @@ void testWithAggregatorWithShutdown() { @Test void testAddAndRemoveMessagesFromMessageGroup() { - RedisMessageStore messageStore = new RedisMessageStore(redisConnectionFactory); - List> messages = new ArrayList>(); + List> messages = new ArrayList<>(); for (int i = 0; i < 25; i++) { Message message = MessageBuilder.withPayload("foo").setCorrelationId(this.groupId).build(); - messageStore.addMessagesToGroup(this.groupId, message); + store.addMessagesToGroup(this.groupId, message); messages.add(message); } - MessageGroup group = messageStore.getMessageGroup(this.groupId); + MessageGroup group = store.getMessageGroup(this.groupId); assertThat(group.size()).isEqualTo(25); - messageStore.removeMessagesFromGroup(this.groupId, messages); - group = messageStore.getMessageGroup(this.groupId); + store.removeMessagesFromGroup(this.groupId, messages); + group = store.getMessageGroup(this.groupId); assertThat(group.size()).isZero(); - messageStore.removeMessageGroup(this.groupId); + store.removeMessageGroup(this.groupId); } @Test void testJsonSerialization() { - RedisMessageStore store = new RedisMessageStore(redisConnectionFactory); - ObjectMapper mapper = JacksonJsonUtils.messagingAwareMapper(); GenericJackson2JsonRedisSerializer serializer = new GenericJackson2JsonRedisSerializer(mapper); @@ -476,6 +456,36 @@ void testJsonSerialization() { assertThat(messageGroup.getMessages().iterator().next()).isEqualTo(fooMessage); } + @Test + public void sameMessageInTwoGroupsNotRemovedByFirstGroup() { + GenericMessage testMessage = new GenericMessage<>("test data"); + + store.addMessageToGroup("1", testMessage); + store.addMessageToGroup("2", testMessage); + + store.removeMessageGroup("1"); + + assertThat(store.getMessageCount()).isEqualTo(1); + + store.removeMessageGroup("2"); + + assertThat(store.getMessageCount()).isEqualTo(0); + } + + @Test + public void removeMessagesFromGroupDontRemoveSameMessageInOtherGroup() { + GenericMessage testMessage = new GenericMessage<>("test data"); + + store.addMessageToGroup("1", testMessage); + store.addMessageToGroup("2", testMessage); + + store.removeMessagesFromGroup("1", testMessage); + + assertThat(store.getMessageCount()).isEqualTo(1); + assertThat(store.messageGroupSize("1")).isEqualTo(0); + assertThat(store.messageGroupSize("2")).isEqualTo(1); + } + private static class Foo { private String foo;