Skip to content

Commit 8320262

Browse files
committed
Propagate CoroutineContext to WebClient filter
This commit introduces a new ResponseSpec.awaitEntityOrNull() extension function to replace ResponseSpec.toEntity(...).awaitFirstOrNull() and pass the CoroutineContext to the CoExchangeFilterFunction. CoroutineContext propagation is implemented via ReactorContext and ClientRequest attribute. See gh-32148 Signed-off-by: Dmitry Sulman <dmitry.sulman@gmail.com>
1 parent 86d8163 commit 8320262

File tree

4 files changed

+100
-7
lines changed

4 files changed

+100
-7
lines changed

spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@
5858
import org.springframework.web.util.UriBuilder;
5959
import org.springframework.web.util.UriBuilderFactory;
6060

61+
import static org.springframework.web.reactive.function.client.CoExchangeFilterFunction.COROUTINE_CONTEXT_ATTRIBUTE;
62+
6163
/**
6264
* The default implementation of {@link WebClient},
6365
* as created by the static factory methods.
@@ -430,6 +432,8 @@ private Mono<ClientResponse> exchange() {
430432
if (filterFunctions != null) {
431433
filterFunction = filterFunctions.andThen(filterFunction);
432434
}
435+
contextView.getOrEmpty(COROUTINE_CONTEXT_ATTRIBUTE)
436+
.ifPresent(context -> requestBuilder.attribute(COROUTINE_CONTEXT_ATTRIBUTE, context));
433437
ClientRequest request = requestBuilder.build();
434438
observationContext.setUriTemplate((String) request.attribute(URI_TEMPLATE_ATTRIBUTE).orElse(null));
435439
observationContext.setRequest(request);

spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/CoExchangeFilterFunction.kt

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,9 +17,13 @@
1717
package org.springframework.web.reactive.function.client
1818

1919
import kotlinx.coroutines.Dispatchers
20+
import kotlinx.coroutines.Job
21+
import kotlinx.coroutines.currentCoroutineContext
2022
import kotlinx.coroutines.reactor.awaitSingle
2123
import kotlinx.coroutines.reactor.mono
2224
import reactor.core.publisher.Mono
25+
import kotlin.coroutines.CoroutineContext
26+
import kotlin.jvm.optionals.getOrNull
2327

2428
/**
2529
* Kotlin-specific implementation of the [ExchangeFilterFunction] interface
@@ -31,10 +35,14 @@ import reactor.core.publisher.Mono
3135
abstract class CoExchangeFilterFunction : ExchangeFilterFunction {
3236

3337
final override fun filter(request: ClientRequest, next: ExchangeFunction): Mono<ClientResponse> {
34-
return mono(Dispatchers.Unconfined) {
38+
val context = request.attribute(COROUTINE_CONTEXT_ATTRIBUTE).getOrNull() as CoroutineContext?
39+
return mono(context ?: Dispatchers.Unconfined) {
3540
filter(request, object : CoExchangeFunction {
3641
override suspend fun exchange(request: ClientRequest): ClientResponse {
37-
return next.exchange(request).awaitSingle()
42+
val newRequest = ClientRequest.from(request)
43+
.attribute(COROUTINE_CONTEXT_ATTRIBUTE, currentCoroutineContext().minusKey(Job.Key))
44+
.build()
45+
return next.exchange(newRequest).awaitSingle()
3846
}
3947
})
4048
}
@@ -58,6 +66,17 @@ abstract class CoExchangeFilterFunction : ExchangeFilterFunction {
5866
* @return the filtered response
5967
*/
6068
protected abstract suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse
69+
70+
companion object {
71+
72+
/**
73+
* Name of the [ClientRequest] attribute that contains the
74+
* [kotlin.coroutines.CoroutineContext] to be passed to the
75+
* [CoExchangeFilterFunction.filter].
76+
*/
77+
@JvmField
78+
val COROUTINE_CONTEXT_ATTRIBUTE = CoExchangeFilterFunction::class.java.name + ".context"
79+
}
6180
}
6281

6382

spring-webflux/src/main/kotlin/org/springframework/web/reactive/function/client/WebClientExtensions.kt

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@ import kotlinx.coroutines.Job
2020
import kotlinx.coroutines.currentCoroutineContext
2121
import kotlinx.coroutines.flow.Flow
2222
import kotlinx.coroutines.reactive.asFlow
23-
import kotlinx.coroutines.reactor.asFlux
24-
import kotlinx.coroutines.reactor.awaitSingle
25-
import kotlinx.coroutines.reactor.awaitSingleOrNull
26-
import kotlinx.coroutines.reactor.mono
23+
import kotlinx.coroutines.reactor.*
24+
import kotlinx.coroutines.withContext
2725
import org.reactivestreams.Publisher
2826
import org.springframework.core.ParameterizedTypeReference
2927
import org.springframework.http.ResponseEntity
28+
import org.springframework.web.reactive.function.client.CoExchangeFilterFunction.Companion.COROUTINE_CONTEXT_ATTRIBUTE
3029
import org.springframework.web.reactive.function.client.WebClient.RequestBodySpec
3130
import org.springframework.web.reactive.function.client.WebClient.RequestHeadersSpec
3231
import reactor.core.publisher.Flux
3332
import reactor.core.publisher.Mono
33+
import reactor.util.context.Context
3434

3535
/**
3636
* Extension for [WebClient.RequestBodySpec.body] providing a `body(Publisher<T>)` variant
@@ -203,3 +203,19 @@ inline fun <reified T : Any> WebClient.ResponseSpec.toEntityList(): Mono<Respons
203203
*/
204204
inline fun <reified T : Any> WebClient.ResponseSpec.toEntityFlux(): Mono<ResponseEntity<Flux<T>>> =
205205
toEntityFlux(object : ParameterizedTypeReference<T>() {})
206+
207+
208+
/**
209+
* Extension for [WebClient.ResponseSpec.toEntity] providing a `toEntity<Foo>()` variant
210+
* leveraging Kotlin reified type parameters and allows [kotlin.coroutines.CoroutineContext]
211+
* propagation to the [CoExchangeFilterFunction]. This extension is not subject to type erasure
212+
* and retains actual generic type arguments.
213+
*
214+
* @since 7.0.0
215+
*/
216+
suspend inline fun <reified T : Any> WebClient.ResponseSpec.awaitEntityOrNull(): ResponseEntity<T>? {
217+
val coroutineContext = currentCoroutineContext().minusKey(Job.Key).minusKey(ReactorContext.Key)
218+
val reactorContext = currentCoroutineContext()[ReactorContext.Key]?.context ?: Context.empty()
219+
val newReactorContext = reactorContext.put(COROUTINE_CONTEXT_ATTRIBUTE, coroutineContext)
220+
return withContext(newReactorContext.asCoroutineContext()) { toEntity(T::class.java).awaitSingleOrNull() }
221+
}

spring-webflux/src/test/kotlin/org/springframework/web/reactive/function/client/WebClientExtensionsTests.kt

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,19 @@ import kotlinx.coroutines.flow.Flow
2525
import kotlinx.coroutines.flow.flow
2626
import kotlinx.coroutines.flow.toList
2727
import kotlinx.coroutines.runBlocking
28+
import kotlinx.coroutines.withContext
2829
import org.assertj.core.api.Assertions.assertThat
2930
import org.junit.jupiter.api.Test
3031
import org.reactivestreams.Publisher
3132
import org.springframework.core.ParameterizedTypeReference
33+
import org.springframework.http.HttpHeaders
34+
import org.springframework.http.HttpStatus
35+
import org.springframework.http.MediaType
3236
import org.springframework.http.ResponseEntity
37+
import org.springframework.web.reactive.function.client.CoExchangeFilterFunction.Companion.COROUTINE_CONTEXT_ATTRIBUTE
3338
import reactor.core.publisher.Flux
3439
import reactor.core.publisher.Mono
40+
import java.util.*
3541
import java.util.concurrent.CompletableFuture
3642
import java.util.function.Function
3743
import kotlin.coroutines.AbstractCoroutineContextElement
@@ -226,9 +232,57 @@ class WebClientExtensionsTests {
226232
verify { responseSpec.toEntityFlux(object : ParameterizedTypeReference<List<Foo>>() {}) }
227233
}
228234

235+
@Test
236+
fun `ResponseSpec#awaitEntityOrNull with coroutine context propagation`() {
237+
val exchangeFunction = mockk<ExchangeFunction>()
238+
val mockResponse = mockk<ClientResponse>()
239+
val foo = mockk<Foo>()
240+
val slot = slot<ClientRequest>()
241+
every { exchangeFunction.exchange(capture(slot)) } returns Mono.just(mockResponse)
242+
every { mockResponse.statusCode() } returns HttpStatus.OK
243+
every { mockResponse.headers() } returns MockClientHeaders()
244+
every { mockResponse.bodyToMono(Foo::class.java) } returns Mono.just(foo)
245+
runBlocking {
246+
withContext(FooContextElement(foo)) {
247+
val responseEntity = WebClient.builder()
248+
.exchangeFunction(exchangeFunction)
249+
.filter(object : CoExchangeFilterFunction() {
250+
override suspend fun filter(request: ClientRequest, next: CoExchangeFunction): ClientResponse {
251+
assertThat(currentCoroutineContext()[FooContextElement.Key]!!.foo).isEqualTo(foo)
252+
return next.exchange(request)
253+
}
254+
})
255+
.build().get().uri("/path").retrieve().awaitEntityOrNull<Foo>()
256+
val capturedContext = slot.captured.attribute(COROUTINE_CONTEXT_ATTRIBUTE).get() as CoroutineContext
257+
assertThat(capturedContext[FooContextElement.Key]!!.foo).isEqualTo(foo)
258+
assertThat(responseEntity!!.body).isEqualTo(foo)
259+
}
260+
}
261+
}
262+
229263
class Foo
230264

231265
private data class FooContextElement(val foo: Foo) : AbstractCoroutineContextElement(FooContextElement) {
232266
companion object Key : CoroutineContext.Key<FooContextElement>
233267
}
268+
269+
private class MockClientHeaders : ClientResponse.Headers {
270+
private val headers = HttpHeaders()
271+
272+
override fun contentLength(): OptionalLong {
273+
return OptionalLong.empty()
274+
}
275+
276+
override fun contentType(): Optional<MediaType> {
277+
return Optional.empty()
278+
}
279+
280+
override fun header(headerName: String): List<String> {
281+
return emptyList()
282+
}
283+
284+
override fun asHttpHeaders(): HttpHeaders {
285+
return headers
286+
}
287+
}
234288
}

0 commit comments

Comments
 (0)