diff --git a/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/KeyValueItemWriter.java b/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/KeyValueItemWriter.java index df2fed925b..586a548a6a 100644 --- a/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/KeyValueItemWriter.java +++ b/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/KeyValueItemWriter.java @@ -43,7 +43,9 @@ public void write(List items) throws Exception { K key = itemKeyMapper.convert(item); writeKeyValue(key, item); } + flush(); } + protected void flush() throws Exception {} /** * Subclasses implement this method to write each item to key value store diff --git a/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/kafka/KafkaItemWriter.java b/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/kafka/KafkaItemWriter.java index 75c6f2b040..f9ad5125d7 100644 --- a/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/kafka/KafkaItemWriter.java +++ b/spring-batch-infrastructure/src/main/java/org/springframework/batch/item/kafka/KafkaItemWriter.java @@ -19,7 +19,12 @@ import org.springframework.batch.item.ItemWriter; import org.springframework.batch.item.KeyValueItemWriter; import org.springframework.kafka.core.KafkaTemplate; +import org.springframework.kafka.support.SendResult; import org.springframework.util.Assert; +import org.springframework.util.concurrent.ListenableFuture; + +import java.util.ArrayList; +import java.util.List; /** *

@@ -34,15 +39,24 @@ public class KafkaItemWriter extends KeyValueItemWriter { protected KafkaTemplate kafkaTemplate; + private final List>> listenableFutures = new ArrayList<>(); @Override protected void writeKeyValue(K key, T value) { if (this.delete) { - this.kafkaTemplate.sendDefault(key, null); + listenableFutures.add(this.kafkaTemplate.sendDefault(key, null)); } else { - this.kafkaTemplate.sendDefault(key, value); + listenableFutures.add(this.kafkaTemplate.sendDefault(key, value)); + } + } + @Override + protected void flush() throws Exception{ + kafkaTemplate.flush(); + for(ListenableFuture> future: listenableFutures){ + future.get(); } + listenableFutures.clear(); } @Override diff --git a/spring-batch-infrastructure/src/test/java/org/springframework/batch/item/kafka/KafkaItemWriterTests.java b/spring-batch-infrastructure/src/test/java/org/springframework/batch/item/kafka/KafkaItemWriterTests.java index cab3bb8647..19337af5b5 100644 --- a/spring-batch-infrastructure/src/test/java/org/springframework/batch/item/kafka/KafkaItemWriterTests.java +++ b/spring-batch-infrastructure/src/test/java/org/springframework/batch/item/kafka/KafkaItemWriterTests.java @@ -25,25 +25,31 @@ import org.springframework.core.convert.converter.Converter; import org.springframework.kafka.core.KafkaTemplate; +import org.springframework.kafka.support.SendResult; +import org.springframework.util.concurrent.ListenableFuture; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; public class KafkaItemWriterTests { @Mock private KafkaTemplate kafkaTemplate; + @Mock + private ListenableFuture> future; + private KafkaItemKeyMapper itemKeyMapper; private KafkaItemWriter writer; @Before public void setUp() throws Exception { - MockitoAnnotations.initMocks(this); + MockitoAnnotations.openMocks(this); when(this.kafkaTemplate.getDefaultTopic()).thenReturn("defaultTopic"); + when(this.kafkaTemplate.sendDefault(any(), any())).thenReturn(future); this.itemKeyMapper = new KafkaItemKeyMapper(); this.writer = new KafkaItemWriter<>(); this.writer.setKafkaTemplate(this.kafkaTemplate); @@ -90,6 +96,8 @@ public void testBasicWrite() throws Exception { verify(this.kafkaTemplate).sendDefault(items.get(0), items.get(0)); verify(this.kafkaTemplate).sendDefault(items.get(1), items.get(1)); + verify(this.kafkaTemplate).flush(); + verify(this.future, times(2)).get(); } @Test @@ -101,6 +109,8 @@ public void testBasicDelete() throws Exception { verify(this.kafkaTemplate).sendDefault(items.get(0), null); verify(this.kafkaTemplate).sendDefault(items.get(1), null); + verify(this.kafkaTemplate).flush(); + verify(this.future, times(2)).get(); } @Test