diff --git a/spring-integration-core/src/main/java/org/springframework/integration/channel/FluxMessageChannel.java b/spring-integration-core/src/main/java/org/springframework/integration/channel/FluxMessageChannel.java index 484d244ad45..34d80b52d59 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/channel/FluxMessageChannel.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/channel/FluxMessageChannel.java @@ -18,6 +18,7 @@ import java.time.Duration; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.LockSupport; import org.reactivestreams.Publisher; @@ -100,18 +101,42 @@ public void subscribe(Subscriber> subscriber) { .share() .subscribe(subscriber); - this.upstreamSubscriptions.add( + Mono subscribersBarrier = Mono.fromCallable(() -> this.sink.currentSubscriberCount() > 0) .filter(Boolean::booleanValue) .doOnNext(this.subscribedSignal::tryEmitNext) .repeatWhenEmpty((repeat) -> - this.active ? repeat.delayElements(Duration.ofMillis(100)) : repeat) // NOSONAR - .subscribe()); + this.active ? repeat.delayElements(Duration.ofMillis(100)) : repeat); // NOSONAR + + addPublisherToSubscribe(Flux.from(subscribersBarrier)); + } + + private void addPublisherToSubscribe(Flux publisher) { + AtomicReference disposableReference = new AtomicReference<>(); + + Disposable disposable = + publisher + .doOnTerminate(() -> disposeUpstreamSubscription(disposableReference)) + .subscribe(); + + if (!disposable.isDisposed()) { + if (this.upstreamSubscriptions.add(disposable)) { + disposableReference.set(disposable); + } + } + } + + private void disposeUpstreamSubscription(AtomicReference disposableReference) { + Disposable disposable = disposableReference.get(); + if (disposable != null) { + this.upstreamSubscriptions.remove(disposable); + disposable.dispose(); + } } @Override public void subscribeTo(Publisher> publisher) { - this.upstreamSubscriptions.add( + Flux upstreamPublisher = Flux.from(publisher) .delaySubscription(this.subscribedSignal.asFlux().filter(Boolean::booleanValue).next()) .publishOn(this.scheduler) @@ -119,8 +144,9 @@ public void subscribeTo(Publisher> publisher) { Mono.just(message) .handle((messageToHandle, syncSink) -> sendReactiveMessage(messageToHandle)) .contextWrite(StaticMessageHeaderAccessor.getReactorContext(message))) - .contextCapture() - .subscribe()); + .contextCapture(); + + addPublisherToSubscribe(upstreamPublisher); } private void sendReactiveMessage(Message message) { diff --git a/spring-integration-core/src/test/java/org/springframework/integration/channel/reactive/FluxMessageChannelTests.java b/spring-integration-core/src/test/java/org/springframework/integration/channel/reactive/FluxMessageChannelTests.java index 9fdfb392d57..40fb4d3aa76 100644 --- a/spring-integration-core/src/test/java/org/springframework/integration/channel/reactive/FluxMessageChannelTests.java +++ b/spring-integration-core/src/test/java/org/springframework/integration/channel/reactive/FluxMessageChannelTests.java @@ -26,6 +26,8 @@ import org.junit.jupiter.api.Test; import reactor.core.Disposable; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -144,6 +146,25 @@ void testFluxMessageChannelCleanUp() throws InterruptedException { .until(() -> TestUtils.getPropertyValue(flux, "sink.sink.done", Boolean.class)); } + @Test + void noMemoryLeakInFluxMessageChannelForVolatilePublishers() { + FluxMessageChannel messageChannel = new FluxMessageChannel(); + + StepVerifier stepVerifier = StepVerifier.create(messageChannel) + .expectNextCount(3) + .thenCancel() + .verifyLater(); + + messageChannel.subscribeTo(Mono.just(new GenericMessage<>("test"))); + messageChannel.subscribeTo(Flux.just("test1", "test2").map(GenericMessage::new)); + + stepVerifier.verify(); + + Disposable.Composite upstreamSubscriptions = + TestUtils.getPropertyValue(messageChannel, "upstreamSubscriptions", Disposable.Composite.class); + assertThat(upstreamSubscriptions.size()).isEqualTo(0); + } + @Configuration @EnableIntegration public static class TestConfiguration {