Skip to content

Commit 0fcdf92

Browse files
bky373artembilan
authored andcommitted
GH-3328: Add missing seek callbacks on each topic partition
Fixes: #3328 When using an `AbstractConsumerSeekAware` in a multi-group listeners scenario, there are cases where the number of registered callbacks differs from the number of discovered callbacks. This is due to the value type of callbacks Map in `AbstractConsumerSeekAware` class being simply `ConsumerSeekCallback`. This causes some callbacks looking at the same partition to be missing. * Change the value type of callbacks Map in `AbstractConsumerSeekAware` class from `ConsumerSeekCallback` to `List<ConsumerSeekCallback>`. * Also modify some methods, test codes and docs that are affected by this change. * Add test codes to verify that the callbacks registered via `registeredSeekCallback()` and the ones you can get via `getSeekCallbacks()` match completely.
1 parent f91f8a9 commit 0fcdf92

File tree

7 files changed

+124
-51
lines changed

7 files changed

+124
-51
lines changed

spring-kafka-docs/src/main/antora/modules/ROOT/pages/kafka/seek.adoc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -186,17 +186,18 @@ public class SeekToLastOnIdleListener extends AbstractConsumerSeekAware {
186186
* Rewind all partitions one record.
187187
*/
188188
public void rewindAllOneRecord() {
189-
getSeekCallbacks()
190-
.forEach((tp, callback) ->
191-
callback.seekRelative(tp.topic(), tp.partition(), -1, true));
189+
getTopicsAndCallbacks()
190+
.forEach((tp, callbacks) ->
191+
callbacks.forEach(callback -> callback.seekRelative(tp.topic(), tp.partition(), -1, true))
192+
);
192193
}
193194
194195
/**
195196
* Rewind one partition one record.
196197
*/
197198
public void rewindOnePartitionOneRecord(String topic, int partition) {
198-
getSeekCallbackFor(new TopicPartition(topic, partition))
199-
.seekRelative(topic, partition, -1, true);
199+
getSeekCallbacksFor(new TopicPartition(topic, partition))
200+
.forEach(callback -> callback.seekRelative(topic, partition, -1, true));
200201
}
201202
202203
}

spring-kafka-docs/src/main/antora/modules/ROOT/pages/whats-new.adoc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ The naming convention for DLT topics has been standardized to use the "-dlt" suf
1717

1818
A new method, `getGroupId()`, has been added to the `ConsumerSeekCallback` interface.
1919
This method allows for more selective seek operations by targeting only the desired consumer group.
20+
The `AbstractConsumerSeekAware` can also now register, retrieve, and remove all callbacks for each topic partition in a multi-group listener scenario without missing any.
21+
See the new APIs (`getSeekCallbacksFor(TopicPartition topicPartition)`, `getTopicsAndCallbacks()`) for more details.
2022
For more details, see xref:kafka/seek.adoc#seek[Seek API Docs].
2123

2224
[[x33-new-option-ignore-empty-batch]]

spring-kafka/src/main/java/org/springframework/kafka/listener/AbstractConsumerSeekAware.java

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2019-2023 the original author or authors.
2+
* Copyright 2019-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,33 +16,37 @@
1616

1717
package org.springframework.kafka.listener;
1818

19+
import java.util.ArrayList;
1920
import java.util.Collection;
2021
import java.util.Collections;
2122
import java.util.LinkedList;
2223
import java.util.List;
2324
import java.util.Map;
2425
import java.util.concurrent.ConcurrentHashMap;
26+
import java.util.stream.Collectors;
2527

2628
import org.apache.kafka.common.TopicPartition;
2729

2830
import org.springframework.lang.Nullable;
31+
import org.springframework.util.CollectionUtils;
2932

3033
/**
3134
* Manages the {@link ConsumerSeekAware.ConsumerSeekCallback} s for the listener. If the
3235
* listener subclasses this class, it can easily seek arbitrary topics/partitions without
3336
* having to keep track of the callbacks itself.
3437
*
3538
* @author Gary Russell
39+
* @author Borahm Lee
3640
* @since 2.3
3741
*
3842
*/
3943
public abstract class AbstractConsumerSeekAware implements ConsumerSeekAware {
4044

4145
private final Map<Thread, ConsumerSeekCallback> callbackForThread = new ConcurrentHashMap<>();
4246

43-
private final Map<TopicPartition, ConsumerSeekCallback> callbacks = new ConcurrentHashMap<>();
47+
private final Map<TopicPartition, List<ConsumerSeekCallback>> topicToCallbacks = new ConcurrentHashMap<>();
4448

45-
private final Map<ConsumerSeekCallback, List<TopicPartition>> callbacksToTopic = new ConcurrentHashMap<>();
49+
private final Map<ConsumerSeekCallback, List<TopicPartition>> callbackToTopics = new ConcurrentHashMap<>();
4650

4751
@Override
4852
public void registerSeekCallback(ConsumerSeekCallback callback) {
@@ -54,24 +58,26 @@ public void onPartitionsAssigned(Map<TopicPartition, Long> assignments, Consumer
5458
ConsumerSeekCallback threadCallback = this.callbackForThread.get(Thread.currentThread());
5559
if (threadCallback != null) {
5660
assignments.keySet().forEach(tp -> {
57-
this.callbacks.put(tp, threadCallback);
58-
this.callbacksToTopic.computeIfAbsent(threadCallback, key -> new LinkedList<>()).add(tp);
61+
this.topicToCallbacks.computeIfAbsent(tp, key -> new ArrayList<>()).add(threadCallback);
62+
this.callbackToTopics.computeIfAbsent(threadCallback, key -> new LinkedList<>()).add(tp);
5963
});
6064
}
6165
}
6266

6367
@Override
6468
public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
6569
partitions.forEach(tp -> {
66-
ConsumerSeekCallback removed = this.callbacks.remove(tp);
67-
if (removed != null) {
68-
List<TopicPartition> topics = this.callbacksToTopic.get(removed);
69-
if (topics != null) {
70-
topics.remove(tp);
71-
if (topics.size() == 0) {
72-
this.callbacksToTopic.remove(removed);
70+
List<ConsumerSeekCallback> removedCallbacks = this.topicToCallbacks.remove(tp);
71+
if (removedCallbacks != null && !removedCallbacks.isEmpty()) {
72+
removedCallbacks.forEach(cb -> {
73+
List<TopicPartition> topics = this.callbackToTopics.get(cb);
74+
if (topics != null) {
75+
topics.remove(tp);
76+
if (topics.isEmpty()) {
77+
this.callbackToTopics.remove(cb);
78+
}
7379
}
74-
}
80+
});
7581
}
7682
});
7783
}
@@ -82,21 +88,55 @@ public void unregisterSeekCallback() {
8288
}
8389

8490
/**
85-
* Return the callback for the specified topic/partition.
86-
* @param topicPartition the topic/partition.
87-
* @return the callback (or null if there is no assignment).
88-
*/
91+
* Return the callback for the specified topic/partition.
92+
* @param topicPartition the topic/partition.
93+
* @return the callback (or null if there is no assignment).
94+
* @deprecated Replaced by {@link #getSeekCallbacksFor(TopicPartition)}
95+
*/
96+
@Deprecated(since = "3.3", forRemoval = true)
8997
@Nullable
9098
protected ConsumerSeekCallback getSeekCallbackFor(TopicPartition topicPartition) {
91-
return this.callbacks.get(topicPartition);
99+
List<ConsumerSeekCallback> callbacks = getSeekCallbacksFor(topicPartition);
100+
if (CollectionUtils.isEmpty(callbacks)) {
101+
return null;
102+
}
103+
return callbacks.get(0);
104+
}
105+
106+
/**
107+
* Return the callbacks for the specified topic/partition.
108+
* @param topicPartition the topic/partition.
109+
* @return the callbacks (or null if there is no assignment).
110+
* @since 3.3
111+
*/
112+
@Nullable
113+
protected List<ConsumerSeekCallback> getSeekCallbacksFor(TopicPartition topicPartition) {
114+
return this.topicToCallbacks.get(topicPartition);
92115
}
93116

94117
/**
95118
* The map of callbacks for all currently assigned partitions.
96119
* @return the map.
120+
* @deprecated Replaced by {@link #getTopicsAndCallbacks()}
97121
*/
122+
@Deprecated(since = "3.3", forRemoval = true)
98123
protected Map<TopicPartition, ConsumerSeekCallback> getSeekCallbacks() {
99-
return Collections.unmodifiableMap(this.callbacks);
124+
Map<TopicPartition, List<ConsumerSeekCallback>> topicsAndCallbacks = getTopicsAndCallbacks();
125+
return topicsAndCallbacks.entrySet().stream()
126+
.filter(entry -> !entry.getValue().isEmpty())
127+
.collect(Collectors.toMap(
128+
Map.Entry::getKey,
129+
entry -> entry.getValue().get(0)
130+
));
131+
}
132+
133+
/**
134+
* The map of callbacks for all currently assigned partitions.
135+
* @return the map.
136+
* @since 3.3
137+
*/
138+
protected Map<TopicPartition, List<ConsumerSeekCallback>> getTopicsAndCallbacks() {
139+
return Collections.unmodifiableMap(this.topicToCallbacks);
100140
}
101141

102142
/**
@@ -105,23 +145,23 @@ protected Map<TopicPartition, ConsumerSeekCallback> getSeekCallbacks() {
105145
* @since 2.6
106146
*/
107147
protected Map<ConsumerSeekCallback, List<TopicPartition>> getCallbacksAndTopics() {
108-
return Collections.unmodifiableMap(this.callbacksToTopic);
148+
return Collections.unmodifiableMap(this.callbackToTopics);
109149
}
110150

111151
/**
112152
* Seek all assigned partitions to the beginning.
113153
* @since 2.6
114154
*/
115155
public void seekToBeginning() {
116-
getCallbacksAndTopics().forEach((cb, topics) -> cb.seekToBeginning(topics));
156+
getCallbacksAndTopics().forEach(ConsumerSeekCallback::seekToBeginning);
117157
}
118158

119159
/**
120160
* Seek all assigned partitions to the end.
121161
* @since 2.6
122162
*/
123163
public void seekToEnd() {
124-
getCallbacksAndTopics().forEach((cb, topics) -> cb.seekToEnd(topics));
164+
getCallbacksAndTopics().forEach(ConsumerSeekCallback::seekToEnd);
125165
}
126166

127167
/**

spring-kafka/src/test/java/org/springframework/kafka/annotation/EnableKafkaIntegrationTests.java

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@
180180
* @author Nakul Mishra
181181
* @author Soby Chacko
182182
* @author Wang Zhiyang
183+
* @author Borahm Lee
183184
*/
184185
@SpringJUnitConfig
185186
@DirtiesContext
@@ -1081,7 +1082,7 @@ public void testSeekToLastOnIdle() throws InterruptedException {
10811082
assertThat(this.seekOnIdleListener.latch3.await(10, TimeUnit.SECONDS)).isTrue();
10821083
this.registry.getListenerContainer("seekOnIdle").stop();
10831084
assertThat(this.seekOnIdleListener.latch4.await(10, TimeUnit.SECONDS)).isTrue();
1084-
assertThat(KafkaTestUtils.getPropertyValue(this.seekOnIdleListener, "callbacks", Map.class)).hasSize(0);
1085+
assertThat(KafkaTestUtils.getPropertyValue(this.seekOnIdleListener, "topicToCallbacks", Map.class)).hasSize(0);
10851086
}
10861087

10871088
@SuppressWarnings({"unchecked", "rawtypes"})
@@ -2523,11 +2524,10 @@ public void listen(String in) throws InterruptedException {
25232524
if (latch1.getCount() > 0) {
25242525
latch1.countDown();
25252526
if (latch1.getCount() == 0) {
2526-
ConsumerSeekCallback seekToComputeFn = getSeekCallbackFor(
2527+
List<ConsumerSeekCallback> seekToComputeFunctions = getSeekCallbacksFor(
25272528
new org.apache.kafka.common.TopicPartition("seekToComputeFn", 0));
2528-
assertThat(seekToComputeFn).isNotNull();
2529-
seekToComputeFn.
2530-
seek("seekToComputeFn", 0, current -> 0L);
2529+
assertThat(seekToComputeFunctions).isNotEmpty();
2530+
seekToComputeFunctions.forEach(callback -> callback.seek("seekToComputeFn", 0, current -> 0L));
25312531
}
25322532
}
25332533
}
@@ -2576,14 +2576,15 @@ public void onIdleContainer(Map<org.apache.kafka.common.TopicPartition, Long> as
25762576
}
25772577

25782578
public void rewindAllOneRecord() {
2579-
getSeekCallbacks()
2580-
.forEach((tp, callback) ->
2581-
callback.seekRelative(tp.topic(), tp.partition(), -1, true));
2579+
getTopicsAndCallbacks()
2580+
.forEach((tp, callbacks) ->
2581+
callbacks.forEach(callback -> callback.seekRelative(tp.topic(), tp.partition(), -1, true))
2582+
);
25822583
}
25832584

25842585
public void rewindOnePartitionOneRecord(String topic, int partition) {
2585-
getSeekCallbackFor(new org.apache.kafka.common.TopicPartition(topic, partition))
2586-
.seekRelative(topic, partition, -1, true);
2586+
getSeekCallbacksFor(new org.apache.kafka.common.TopicPartition(topic, partition))
2587+
.forEach(callback -> callback.seekRelative(topic, partition, -1, true));
25872588
}
25882589

25892590
@Override

spring-kafka/src/test/java/org/springframework/kafka/listener/AbstractConsumerSeekAwareTests.java

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,18 @@
1717
package org.springframework.kafka.listener;
1818

1919
import static org.assertj.core.api.Assertions.assertThat;
20+
import static org.awaitility.Awaitility.await;
2021

22+
import java.time.Duration;
23+
import java.util.Collection;
24+
import java.util.List;
25+
import java.util.Map;
26+
import java.util.Set;
2127
import java.util.concurrent.CountDownLatch;
2228
import java.util.concurrent.TimeUnit;
29+
import java.util.stream.Collectors;
2330

31+
import org.apache.kafka.common.TopicPartition;
2432
import org.junit.jupiter.api.Test;
2533

2634
import org.springframework.beans.factory.annotation.Autowired;
@@ -35,6 +43,7 @@
3543
import org.springframework.kafka.core.KafkaTemplate;
3644
import org.springframework.kafka.core.ProducerFactory;
3745
import org.springframework.kafka.listener.AbstractConsumerSeekAwareTests.Config.MultiGroupListener;
46+
import org.springframework.kafka.listener.ConsumerSeekAware.ConsumerSeekCallback;
3847
import org.springframework.kafka.test.EmbeddedKafkaBroker;
3948
import org.springframework.kafka.test.context.EmbeddedKafka;
4049
import org.springframework.kafka.test.utils.KafkaTestUtils;
@@ -62,6 +71,22 @@ class AbstractConsumerSeekAwareTests {
6271
@Autowired
6372
MultiGroupListener multiGroupListener;
6473

74+
@Test
75+
public void checkCallbacksAndTopicPartitions() {
76+
await().timeout(Duration.ofSeconds(5)).untilAsserted(() -> {
77+
Map<ConsumerSeekCallback, List<TopicPartition>> callbacksAndTopics = multiGroupListener.getCallbacksAndTopics();
78+
Set<ConsumerSeekCallback> registeredCallbacks = callbacksAndTopics.keySet();
79+
Set<TopicPartition> registeredTopicPartitions = callbacksAndTopics.values().stream().flatMap(Collection::stream).collect(Collectors.toSet());
80+
81+
Map<TopicPartition, List<ConsumerSeekCallback>> topicsAndCallbacks = multiGroupListener.getTopicsAndCallbacks();
82+
Set<TopicPartition> getTopicPartitions = topicsAndCallbacks.keySet();
83+
Set<ConsumerSeekCallback> getCallbacks = topicsAndCallbacks.values().stream().flatMap(Collection::stream).collect(Collectors.toSet());
84+
85+
assertThat(registeredCallbacks).containsExactlyInAnyOrderElementsOf(getCallbacks).isNotEmpty();
86+
assertThat(registeredTopicPartitions).containsExactlyInAnyOrderElementsOf(getTopicPartitions).hasSize(3);
87+
});
88+
}
89+
6590
@Test
6691
void seekForAllGroups() throws Exception {
6792
template.send(TOPIC, "test-data");
@@ -130,12 +155,12 @@ static class MultiGroupListener extends AbstractConsumerSeekAware {
130155

131156
static CountDownLatch latch2 = new CountDownLatch(2);
132157

133-
@KafkaListener(groupId = "group1", topics = TOPIC)
158+
@KafkaListener(groupId = "group1", topics = TOPIC, concurrency = "2")
134159
void listenForGroup1(String in) {
135160
latch1.countDown();
136161
}
137162

138-
@KafkaListener(groupId = "group2", topics = TOPIC)
163+
@KafkaListener(groupId = "group2", topics = TOPIC, concurrency = "2")
139164
void listenForGroup2(String in) {
140165
latch2.countDown();
141166
}

spring-kafka/src/test/java/org/springframework/kafka/listener/ConsumerSeekAwareTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
/**
3838
* @author Gary Russell
39+
* @author Borahm Lee
3940
* @since 2.6
4041
*
4142
*/
@@ -104,8 +105,8 @@ class CSA extends AbstractConsumerSeekAware {
104105
};
105106
exec1.submit(revoke2).get();
106107
exec2.submit(revoke2).get();
107-
assertThat(KafkaTestUtils.getPropertyValue(csa, "callbacks", Map.class)).isEmpty();
108-
assertThat(KafkaTestUtils.getPropertyValue(csa, "callbacksToTopic", Map.class)).isEmpty();
108+
assertThat(KafkaTestUtils.getPropertyValue(csa, "topicToCallbacks", Map.class)).isEmpty();
109+
assertThat(KafkaTestUtils.getPropertyValue(csa, "callbackToTopics", Map.class)).isEmpty();
109110
var checkTL = (Callable<Void>) () -> {
110111
csa.unregisterSeekCallback();
111112
assertThat(KafkaTestUtils.getPropertyValue(csa, "callbackForThread", Map.class).get(Thread.currentThread()))

spring-kafka/src/test/java/org/springframework/kafka/listener/KafkaMessageListenerContainerTests.java

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@
144144
* @author Soby Chacko
145145
* @author Wang Zhiyang
146146
* @author Mikael Carlstedt
147+
* @author Borahm Lee
147148
*/
148149
@EmbeddedKafka(topics = { KafkaMessageListenerContainerTests.topic1, KafkaMessageListenerContainerTests.topic2,
149150
KafkaMessageListenerContainerTests.topic3, KafkaMessageListenerContainerTests.topic4,
@@ -2595,16 +2596,18 @@ public void onPartitionsAssigned(Map<TopicPartition, Long> assignments, Consumer
25952596
public void onMessage(ConsumerRecord<String, String> data) {
25962597
if (data.partition() == 0 && data.offset() == 0) {
25972598
TopicPartition topicPartition = new TopicPartition(data.topic(), data.partition());
2598-
final ConsumerSeekCallback seekCallbackFor = getSeekCallbackFor(topicPartition);
2599-
assertThat(seekCallbackFor).isNotNull();
2600-
seekCallbackFor.seekToBeginning(records.keySet());
2601-
Iterator<TopicPartition> iterator = records.keySet().iterator();
2602-
seekCallbackFor.seekToBeginning(Collections.singletonList(iterator.next()));
2603-
seekCallbackFor.seekToBeginning(Collections.singletonList(iterator.next()));
2604-
seekCallbackFor.seekToEnd(records.keySet());
2605-
iterator = records.keySet().iterator();
2606-
seekCallbackFor.seekToEnd(Collections.singletonList(iterator.next()));
2607-
seekCallbackFor.seekToEnd(Collections.singletonList(iterator.next()));
2599+
final List<ConsumerSeekCallback> seekCallbacksFor = getSeekCallbacksFor(topicPartition);
2600+
assertThat(seekCallbacksFor).isNotEmpty();
2601+
seekCallbacksFor.forEach(callback -> {
2602+
callback.seekToBeginning(records.keySet());
2603+
Iterator<TopicPartition> iterator = records.keySet().iterator();
2604+
callback.seekToBeginning(Collections.singletonList(iterator.next()));
2605+
callback.seekToBeginning(Collections.singletonList(iterator.next()));
2606+
callback.seekToEnd(records.keySet());
2607+
iterator = records.keySet().iterator();
2608+
callback.seekToEnd(Collections.singletonList(iterator.next()));
2609+
callback.seekToEnd(Collections.singletonList(iterator.next()));
2610+
});
26082611
}
26092612
}
26102613

0 commit comments

Comments
 (0)