diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java index b254e8ef3064..127b1cd54533 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java @@ -71,6 +71,9 @@ */ final class DefaultWebClient implements WebClient { + // Copy of CoExchangeFilterFunction.COROUTINE_CONTEXT_ATTRIBUTE value to avoid compilation errors in Eclipse + private static final String COROUTINE_CONTEXT_ATTRIBUTE = "org.springframework.web.reactive.function.client.CoExchangeFilterFunction.context"; + private static final String URI_TEMPLATE_ATTRIBUTE = WebClient.class.getName() + ".uriTemplate"; private static final Mono NO_HTTP_CLIENT_RESPONSE_ERROR = Mono.error( @@ -430,6 +433,8 @@ private Mono exchange() { if (filterFunctions != null) { filterFunction = filterFunctions.andThen(filterFunction); } + contextView.getOrEmpty(COROUTINE_CONTEXT_ATTRIBUTE) + .ifPresent(context -> requestBuilder.attribute(COROUTINE_CONTEXT_ATTRIBUTE, context)); ClientRequest request = requestBuilder.build(); observationContext.setUriTemplate((String) request.attribute(URI_TEMPLATE_ATTRIBUTE).orElse(null)); observationContext.setRequest(request); diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/CoExchangeFilterFunction.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/CoExchangeFilterFunction.kt index 940eb7210cb0..21d4236f500d 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/CoExchangeFilterFunction.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/CoExchangeFilterFunction.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,9 +17,13 @@ package org.springframework.web.reactive.function.client import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.reactor.awaitSingle import kotlinx.coroutines.reactor.mono import reactor.core.publisher.Mono +import kotlin.coroutines.CoroutineContext +import kotlin.jvm.optionals.getOrNull /** * Kotlin-specific implementation of the [ExchangeFilterFunction] interface @@ -31,10 +35,14 @@ import reactor.core.publisher.Mono abstract class CoExchangeFilterFunction : ExchangeFilterFunction { final override fun filter(request: ClientRequest, next: ExchangeFunction): Mono { - return mono(Dispatchers.Unconfined) { + val context = request.attribute(COROUTINE_CONTEXT_ATTRIBUTE).getOrNull() as CoroutineContext? + return mono(context ?: Dispatchers.Unconfined) { filter(request, object : CoExchangeFunction { override suspend fun exchange(request: ClientRequest): ClientResponse { - return next.exchange(request).awaitSingle() + val newRequest = ClientRequest.from(request) + .attribute(COROUTINE_CONTEXT_ATTRIBUTE, currentCoroutineContext().minusKey(Job.Key)) + .build() + return next.exchange(newRequest).awaitSingle() } }) } @@ -58,6 +66,17 @@ abstract class CoExchangeFilterFunction : ExchangeFilterFunction { * @return the filtered response */ protected abstract suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse + + companion object { + + /** + * Name of the [ClientRequest] attribute that contains the + * [kotlin.coroutines.CoroutineContext] to be passed to the + * [CoExchangeFilterFunction.filter]. + */ + @JvmField + val COROUTINE_CONTEXT_ATTRIBUTE = CoExchangeFilterFunction::class.java.name + ".context" + } } diff --git a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/WebClientExtensions.kt b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/WebClientExtensions.kt index 2f1e61c9b582..dd8449969c1a 100644 --- a/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/WebClientExtensions.kt +++ b/spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/WebClientExtensions.kt @@ -20,17 +20,18 @@ import kotlinx.coroutines.Job import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.reactive.asFlow -import kotlinx.coroutines.reactor.asFlux -import kotlinx.coroutines.reactor.awaitSingle -import kotlinx.coroutines.reactor.awaitSingleOrNull -import kotlinx.coroutines.reactor.mono +import kotlinx.coroutines.reactor.* +import kotlinx.coroutines.withContext import org.reactivestreams.Publisher import org.springframework.core.ParameterizedTypeReference import org.springframework.http.ResponseEntity +import org.springframework.web.reactive.function.client.CoExchangeFilterFunction.Companion.COROUTINE_CONTEXT_ATTRIBUTE import org.springframework.web.reactive.function.client.WebClient.RequestBodySpec import org.springframework.web.reactive.function.client.WebClient.RequestHeadersSpec import reactor.core.publisher.Flux import reactor.core.publisher.Mono +import reactor.util.context.Context +import kotlin.coroutines.CoroutineContext /** * Extension for [WebClient.RequestBodySpec.body] providing a `body(Publisher)` variant @@ -38,6 +39,7 @@ import reactor.core.publisher.Mono * erasure and retains actual generic type arguments. * * @author Sebastien Deleuze + * @author Dmitry Sulman * @since 5.0 */ inline fun > RequestBodySpec.body(publisher: S): RequestHeadersSpec<*> = @@ -89,7 +91,7 @@ inline fun RequestBodySpec.bodyValueWithType(body: T): Request */ suspend fun RequestHeadersSpec>.awaitExchange(responseHandler: suspend (ClientResponse) -> T): T { val context = currentCoroutineContext().minusKey(Job.Key) - return exchangeToMono { mono(context) { responseHandler.invoke(it) } }.awaitSingle() + return withContext(context.toReactorContext()) { exchangeToMono { mono(context) { responseHandler.invoke(it) } }.awaitSingle() } } /** @@ -99,7 +101,7 @@ suspend fun RequestHeadersSpec>.awaitExchange */ suspend fun RequestHeadersSpec>.awaitExchangeOrNull(responseHandler: suspend (ClientResponse) -> T?): T? { val context = currentCoroutineContext().minusKey(Job.Key) - return exchangeToMono { mono(context) { responseHandler.invoke(it) } }.awaitSingleOrNull() + return withContext(context.toReactorContext()) { exchangeToMono { mono(context) { responseHandler.invoke(it) } }.awaitSingleOrNull() } } /** @@ -150,11 +152,15 @@ inline fun WebClient.ResponseSpec.bodyToFlow(): Flow = * @author Sebastien Deleuze * @since 5.2 */ -suspend inline fun WebClient.ResponseSpec.awaitBody() : T = - when (T::class) { - Unit::class -> awaitBodilessEntity().let { Unit as T } - else -> bodyToMono().awaitSingle() +suspend inline fun WebClient.ResponseSpec.awaitBody() : T { + val context = currentCoroutineContext().minusKey(Job.Key) + return withContext(context.toReactorContext()) { + when (T::class) { + Unit::class -> toBodilessEntity().awaitSingle().let { Unit as T } + else -> bodyToMono().awaitSingle() + } } +} /** * Coroutines variant of [WebClient.ResponseSpec.bodyToMono]. @@ -162,17 +168,23 @@ suspend inline fun WebClient.ResponseSpec.awaitBody() : T = * @author Valentin Shakhov * @since 5.3.6 */ -suspend inline fun WebClient.ResponseSpec.awaitBodyOrNull() : T? = - when (T::class) { - Unit::class -> awaitBodilessEntity().let { Unit as T? } - else -> bodyToMono().awaitSingleOrNull() +suspend inline fun WebClient.ResponseSpec.awaitBodyOrNull() : T? { + val context = currentCoroutineContext().minusKey(Job.Key) + return withContext(context.toReactorContext()) { + when (T::class) { + Unit::class -> toBodilessEntity().awaitSingle().let { Unit as T? } + else -> bodyToMono().awaitSingleOrNull() + } } +} /** * Coroutines variant of [WebClient.ResponseSpec.toBodilessEntity]. */ -suspend fun WebClient.ResponseSpec.awaitBodilessEntity() = - toBodilessEntity().awaitSingle() +suspend fun WebClient.ResponseSpec.awaitBodilessEntity(): ResponseEntity { + val context = currentCoroutineContext().minusKey(Job.Key) + return withContext(context.toReactorContext()) { toBodilessEntity().awaitSingle() } +} /** * Extension for [WebClient.ResponseSpec.toEntity] providing a `toEntity()` variant @@ -203,3 +215,22 @@ inline fun WebClient.ResponseSpec.toEntityList(): Mono WebClient.ResponseSpec.toEntityFlux(): Mono>> = toEntityFlux(object : ParameterizedTypeReference() {}) + +/** + * Extension for [WebClient.ResponseSpec.toEntity] providing a `toEntity()` variant + * leveraging Kotlin reified type parameters and allows [kotlin.coroutines.CoroutineContext] + * propagation to the [CoExchangeFilterFunction]. This extension is not subject to type erasure + * and retains actual generic type arguments. + * + * @since 7.0.0 + */ +suspend inline fun WebClient.ResponseSpec.awaitEntity(): ResponseEntity { + val context = currentCoroutineContext().minusKey(Job.Key) + return withContext(context.toReactorContext()) { toEntity(T::class.java).awaitSingle() } +} + +@PublishedApi +internal fun CoroutineContext.toReactorContext(): ReactorContext { + val context = Context.of(COROUTINE_CONTEXT_ATTRIBUTE, this).readOnly() + return (this[ReactorContext.Key]?.context?.putAll(context) ?: context).asCoroutineContext() +} diff --git a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt index 048b472676ff..0528a0cfd1a7 100644 --- a/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt +++ b/spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt @@ -25,13 +25,18 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withContext import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.reactivestreams.Publisher import org.springframework.core.ParameterizedTypeReference +import org.springframework.http.HttpHeaders +import org.springframework.http.HttpStatus import org.springframework.http.ResponseEntity +import org.springframework.web.reactive.function.client.CoExchangeFilterFunction.Companion.COROUTINE_CONTEXT_ATTRIBUTE import reactor.core.publisher.Flux import reactor.core.publisher.Mono +import java.time.Duration import java.util.concurrent.CompletableFuture import java.util.function.Function import kotlin.coroutines.AbstractCoroutineContextElement @@ -41,6 +46,7 @@ import kotlin.coroutines.CoroutineContext * Mock object based tests for [WebClient] Kotlin extensions * * @author Sebastien Deleuze + * @author Dmitry Sulman */ class WebClientExtensionsTests { @@ -226,6 +232,225 @@ class WebClientExtensionsTests { verify { responseSpec.toEntityFlux(object : ParameterizedTypeReference>() {}) } } + @Test + fun `ResponseSpec#awaitEntity with coroutine context propagation`() { + val exchangeFunction = mockk() + val mockResponse = mockk() + val mockClientHeaders = mockk() + val foo = mockk() + val slot = slot() + every { exchangeFunction.exchange(capture(slot)) } returns Mono.just(mockResponse) + every { mockResponse.statusCode() } returns HttpStatus.OK + every { mockResponse.headers() } returns mockClientHeaders + every { mockClientHeaders.asHttpHeaders() } returns HttpHeaders() + every { mockResponse.bodyToMono(Foo::class.java) } returns Mono.just(foo) + runBlocking(FooContextElement(foo)) { + val responseEntity = WebClient.builder() + .exchangeFunction(exchangeFunction) + .filter(object : CoExchangeFilterFunction() { + override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse { + assertThat(currentCoroutineContext()[FooContextElement.Key]!!.foo).isEqualTo(foo) + return next.exchange(request) + } + }) + .build().get().uri("/path").retrieve().awaitEntity() + val capturedContext = slot.captured.attribute(COROUTINE_CONTEXT_ATTRIBUTE).get() as CoroutineContext + assertThat(capturedContext[FooContextElement.Key]!!.foo).isEqualTo(foo) + assertThat(responseEntity.body).isEqualTo(foo) + } + } + + @Test + fun `ResponseSpec#awaitEntity with coroutine context propagation to multiple CoExchangeFilterFunctions`() { + val exchangeFunction = mockk() + val mockResponse = mockk() + val mockClientHeaders = mockk() + val foo = mockk() + val slot = slot() + every { exchangeFunction.exchange(capture(slot)) } returns Mono.just(mockResponse) + every { mockResponse.statusCode() } returns HttpStatus.OK + every { mockResponse.headers() } returns mockClientHeaders + every { mockClientHeaders.asHttpHeaders() } returns HttpHeaders() + every { mockResponse.bodyToMono(Foo::class.java) } returns Mono.just(foo) + runBlocking { + val responseEntity = WebClient.builder() + .exchangeFunction(exchangeFunction) + .filter(object : CoExchangeFilterFunction() { + override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse { + return withContext(FooContextElement(foo)) { next.exchange(request) } + } + }) + .filter(object : CoExchangeFilterFunction() { + override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse { + assertThat(currentCoroutineContext()[FooContextElement.Key]!!.foo).isEqualTo(foo) + return next.exchange(request) + } + }) + .build().get().uri("/path").retrieve().awaitEntity() + val capturedContext = slot.captured.attribute(COROUTINE_CONTEXT_ATTRIBUTE).get() as CoroutineContext + assertThat(capturedContext[FooContextElement.Key]!!.foo).isEqualTo(foo) + assertThat(responseEntity.body).isEqualTo(foo) + } + } + + @Test + fun `ResponseSpec#toEntity with coroutine context propagation to multiple CoExchangeFilterFunctions`() { + val exchangeFunction = mockk() + val mockResponse = mockk() + val mockClientHeaders = mockk() + val foo = mockk() + val slot = slot() + every { exchangeFunction.exchange(capture(slot)) } returns Mono.just(mockResponse) + every { mockResponse.statusCode() } returns HttpStatus.OK + every { mockResponse.headers() } returns mockClientHeaders + every { mockClientHeaders.asHttpHeaders() } returns HttpHeaders() + every { mockResponse.bodyToMono(Foo::class.java) } returns Mono.just(foo) + val responseEntity = WebClient.builder() + .exchangeFunction(exchangeFunction) + .filter(object : CoExchangeFilterFunction() { + override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse { + return withContext(FooContextElement(foo)) { next.exchange(request) } + } + }) + .filter(object : CoExchangeFilterFunction() { + override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse { + assertThat(currentCoroutineContext()[FooContextElement.Key]!!.foo).isEqualTo(foo) + return next.exchange(request) + } + }) + .build().get().uri("/path").retrieve().toEntity(Foo::class.java) + .block(Duration.ofSeconds(10)) + val capturedContext = slot.captured.attribute(COROUTINE_CONTEXT_ATTRIBUTE).get() as CoroutineContext + assertThat(capturedContext[FooContextElement.Key]!!.foo).isEqualTo(foo) + assertThat(responseEntity!!.body).isEqualTo(foo) + } + + @Test + fun `ResponseSpec#awaitExchange with coroutine context propagation`() { + val exchangeFunction = mockk() + val mockResponse = mockk() + val foo = mockk() + val slot = slot() + every { exchangeFunction.exchange(capture(slot)) } returns Mono.just(mockResponse) + every { mockResponse.releaseBody() } returns Mono.empty() + runBlocking(FooContextElement(foo)) { + val responseBody = WebClient.builder() + .exchangeFunction(exchangeFunction) + .filter(object : CoExchangeFilterFunction() { + override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse { + assertThat(currentCoroutineContext()[FooContextElement.Key]!!.foo).isEqualTo(foo) + return next.exchange(request) + } + }) + .build().get().uri("/path").awaitExchange { foo } + val capturedContext = slot.captured.attribute(COROUTINE_CONTEXT_ATTRIBUTE).get() as CoroutineContext + assertThat(capturedContext[FooContextElement.Key]!!.foo).isEqualTo(foo) + assertThat(responseBody).isEqualTo(foo) + } + } + + @Test + fun `ResponseSpec#awaitExchangeOrNull with coroutine context propagation`() { + val exchangeFunction = mockk() + val mockResponse = mockk() + val foo = mockk() + val slot = slot() + every { exchangeFunction.exchange(capture(slot)) } returns Mono.just(mockResponse) + every { mockResponse.releaseBody() } returns Mono.empty() + runBlocking(FooContextElement(foo)) { + val responseBody = WebClient.builder() + .exchangeFunction(exchangeFunction) + .filter(object : CoExchangeFilterFunction() { + override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse { + assertThat(currentCoroutineContext()[FooContextElement.Key]!!.foo).isEqualTo(foo) + return next.exchange(request) + } + }) + .build().get().uri("/path").awaitExchangeOrNull { foo } + val capturedContext = slot.captured.attribute(COROUTINE_CONTEXT_ATTRIBUTE).get() as CoroutineContext + assertThat(capturedContext[FooContextElement.Key]!!.foo).isEqualTo(foo) + assertThat(responseBody).isEqualTo(foo) + } + } + + @Test + fun `ResponseSpec#awaitBody with coroutine context propagation`() { + val exchangeFunction = mockk() + val mockResponse = mockk() + val foo = mockk() + val slot = slot() + every { exchangeFunction.exchange(capture(slot)) } returns Mono.just(mockResponse) + every { mockResponse.statusCode() } returns HttpStatus.OK + every { mockResponse.bodyToMono(object : ParameterizedTypeReference() {}) } returns Mono.just(foo) + runBlocking(FooContextElement(foo)) { + val responseBody = WebClient.builder() + .exchangeFunction(exchangeFunction) + .filter(object : CoExchangeFilterFunction() { + override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse { + assertThat(currentCoroutineContext()[FooContextElement.Key]!!.foo).isEqualTo(foo) + return next.exchange(request) + } + }) + .build().get().uri("/path").retrieve().awaitBody() + val capturedContext = slot.captured.attribute(COROUTINE_CONTEXT_ATTRIBUTE).get() as CoroutineContext + assertThat(capturedContext[FooContextElement.Key]!!.foo).isEqualTo(foo) + assertThat(responseBody).isEqualTo(foo) + } + } + + @Test + fun `ResponseSpec#awaitBodyOrNull with coroutine context propagation`() { + val exchangeFunction = mockk() + val mockResponse = mockk() + val foo = mockk() + val slot = slot() + every { exchangeFunction.exchange(capture(slot)) } returns Mono.just(mockResponse) + every { mockResponse.statusCode() } returns HttpStatus.OK + every { mockResponse.bodyToMono(object : ParameterizedTypeReference() {}) } returns Mono.just(foo) + runBlocking(FooContextElement(foo)) { + val responseBody = WebClient.builder() + .exchangeFunction(exchangeFunction) + .filter(object : CoExchangeFilterFunction() { + override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse { + assertThat(currentCoroutineContext()[FooContextElement.Key]!!.foo).isEqualTo(foo) + return next.exchange(request) + } + }) + .build().get().uri("/path").retrieve().awaitBodyOrNull() + val capturedContext = slot.captured.attribute(COROUTINE_CONTEXT_ATTRIBUTE).get() as CoroutineContext + assertThat(capturedContext[FooContextElement.Key]!!.foo).isEqualTo(foo) + assertThat(responseBody).isEqualTo(foo) + } + } + + @Test + fun `ResponseSpec#awaitBodilessEntity with coroutine context propagation`() { + val exchangeFunction = mockk() + val mockResponse = mockk() + val mockClientHeaders = mockk() + val foo = mockk() + val slot = slot() + every { exchangeFunction.exchange(capture(slot)) } returns Mono.just(mockResponse) + every { mockResponse.statusCode() } returns HttpStatus.OK + every { mockResponse.releaseBody() } returns Mono.empty() + every { mockResponse.headers() } returns mockClientHeaders + every { mockClientHeaders.asHttpHeaders() } returns HttpHeaders() + runBlocking(FooContextElement(foo)) { + val responseEntity = WebClient.builder() + .exchangeFunction(exchangeFunction) + .filter(object : CoExchangeFilterFunction() { + override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse { + assertThat(currentCoroutineContext()[FooContextElement.Key]!!.foo).isEqualTo(foo) + return next.exchange(request) + } + }) + .build().get().uri("/path").retrieve().awaitBodilessEntity() + val capturedContext = slot.captured.attribute(COROUTINE_CONTEXT_ATTRIBUTE).get() as CoroutineContext + assertThat(capturedContext[FooContextElement.Key]!!.foo).isEqualTo(foo) + assertThat(responseEntity.hasBody()).isEqualTo(false) + } + } + class Foo private data class FooContextElement(val foo: Foo) : AbstractCoroutineContextElement(FooContextElement) {