diff --git a/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/KeyValueItemWriter.java b/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/KeyValueItemWriter.java
index df2fed925b..586a548a6a 100644
--- a/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/KeyValueItemWriter.java
+++ b/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/KeyValueItemWriter.java
@@ -43,7 +43,9 @@ public void write(List extends V> items) throws Exception {
K key = itemKeyMapper.convert(item);
writeKeyValue(key, item);
}
+ flush();
}
+ protected void flush() throws Exception {}
/**
* Subclasses implement this method to write each item to key value store
diff --git a/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/kafka/KafkaItemWriter.java b/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/kafka/KafkaItemWriter.java
index 75c6f2b040..f9ad5125d7 100644
--- a/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/kafka/KafkaItemWriter.java
+++ b/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/kafka/KafkaItemWriter.java
@@ -19,7 +19,12 @@
import org.springframework.batch.item.ItemWriter;
import org.springframework.batch.item.KeyValueItemWriter;
import org.springframework.kafka.core.KafkaTemplate;
+import org.springframework.kafka.support.SendResult;
import org.springframework.util.Assert;
+import org.springframework.util.concurrent.ListenableFuture;
+
+import java.util.ArrayList;
+import java.util.List;
/**
*
@@ -34,15 +39,24 @@
public class KafkaItemWriter extends KeyValueItemWriter {
protected KafkaTemplate kafkaTemplate;
+ private final List>> listenableFutures = new ArrayList<>();
@Override
protected void writeKeyValue(K key, T value) {
if (this.delete) {
- this.kafkaTemplate.sendDefault(key, null);
+ listenableFutures.add(this.kafkaTemplate.sendDefault(key, null));
}
else {
- this.kafkaTemplate.sendDefault(key, value);
+ listenableFutures.add(this.kafkaTemplate.sendDefault(key, value));
+ }
+ }
+ @Override
+ protected void flush() throws Exception{
+ kafkaTemplate.flush();
+ for(ListenableFuture> future: listenableFutures){
+ future.get();
}
+ listenableFutures.clear();
}
@Override
diff --git a/spring-batch-infrastructure/src/test/java/org/springframework/batch/item/kafka/KafkaItemWriterTests.java b/spring-batch-infrastructure/src/test/java/org/springframework/batch/item/kafka/KafkaItemWriterTests.java
index cab3bb8647..19337af5b5 100644
--- a/spring-batch-infrastructure/src/test/java/org/springframework/batch/item/kafka/KafkaItemWriterTests.java
+++ b/spring-batch-infrastructure/src/test/java/org/springframework/batch/item/kafka/KafkaItemWriterTests.java
@@ -25,25 +25,31 @@
import org.springframework.core.convert.converter.Converter;
import org.springframework.kafka.core.KafkaTemplate;
+import org.springframework.kafka.support.SendResult;
+import org.springframework.util.concurrent.ListenableFuture;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.*;
public class KafkaItemWriterTests {
@Mock
private KafkaTemplate kafkaTemplate;
+ @Mock
+ private ListenableFuture> future;
+
private KafkaItemKeyMapper itemKeyMapper;
private KafkaItemWriter writer;
@Before
public void setUp() throws Exception {
- MockitoAnnotations.initMocks(this);
+ MockitoAnnotations.openMocks(this);
when(this.kafkaTemplate.getDefaultTopic()).thenReturn("defaultTopic");
+ when(this.kafkaTemplate.sendDefault(any(), any())).thenReturn(future);
this.itemKeyMapper = new KafkaItemKeyMapper();
this.writer = new KafkaItemWriter<>();
this.writer.setKafkaTemplate(this.kafkaTemplate);
@@ -90,6 +96,8 @@ public void testBasicWrite() throws Exception {
verify(this.kafkaTemplate).sendDefault(items.get(0), items.get(0));
verify(this.kafkaTemplate).sendDefault(items.get(1), items.get(1));
+ verify(this.kafkaTemplate).flush();
+ verify(this.future, times(2)).get();
}
@Test
@@ -101,6 +109,8 @@ public void testBasicDelete() throws Exception {
verify(this.kafkaTemplate).sendDefault(items.get(0), null);
verify(this.kafkaTemplate).sendDefault(items.get(1), null);
+ verify(this.kafkaTemplate).flush();
+ verify(this.future, times(2)).get();
}
@Test