Skip to content

Commit 20f30f8

Browse files
committed
Add support for Kotlin coroutine channels for subscription queries
1 parent fa6c543 commit 20f30f8

File tree

7 files changed

+164
-33
lines changed

7 files changed

+164
-33
lines changed

pom.xml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@
3939
<artifactId>kotlinx-coroutines-jdk8</artifactId>
4040
<version>1.0.0</version>
4141
</dependency>
42+
<dependency>
43+
<groupId>org.jetbrains.kotlinx</groupId>
44+
<artifactId>kotlinx-coroutines-reactive</artifactId>
45+
<version>1.0.0</version>
46+
</dependency>
4247
<dependency>
4348
<groupId>com.graphql-java</groupId>
4449
<artifactId>graphql-java</artifactId>
@@ -120,6 +125,12 @@
120125
<version>2.1</version>
121126
<scope>test</scope>
122127
</dependency>
128+
<dependency>
129+
<groupId>org.reactivestreams</groupId>
130+
<artifactId>reactive-streams-tck</artifactId>
131+
<version>1.0.2</version>
132+
<scope>test</scope>
133+
</dependency>
123134
</dependencies>
124135

125136
<build>

src/main/kotlin/com/coxautodev/graphql/tools/FieldResolverScanner.kt

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import org.slf4j.LoggerFactory
1010
import java.lang.reflect.Modifier
1111
import java.lang.reflect.ParameterizedType
1212
import kotlin.reflect.full.valueParameters
13+
import kotlin.reflect.jvm.javaType
1314
import kotlin.reflect.jvm.kotlinFunction
1415

1516
/**
@@ -114,9 +115,11 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
114115
true
115116
}
116117

117-
val correctParameterCount = method.parameterCount == requiredCount ||
118-
(method.parameterCount == (requiredCount + 1) && allowedLastArgumentTypes.contains(method.parameterTypes.last())) ||
119-
(method.kotlinFunction?.run { isSuspend && valueParameters.size == requiredCount } == true)
118+
val methodParameterCount = method.kotlinFunction?.valueParameters?.size ?: method.parameterCount
119+
val methodLastParameter = method.kotlinFunction?.valueParameters?.lastOrNull()?.type?.javaType ?: method.parameterTypes.lastOrNull()
120+
121+
val correctParameterCount = methodParameterCount == requiredCount ||
122+
(methodParameterCount == (requiredCount + 1) && allowedLastArgumentTypes.contains(methodLastParameter))
120123
return correctParameterCount && appropriateFirstParameter
121124
}
122125

src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -155,31 +155,26 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso
155155
val source = sourceResolver(environment)
156156
val args = this.args.map { it(environment) }.toTypedArray()
157157

158-
val result = if (isSuspendFunction) {
158+
return if (isSuspendFunction) {
159159
GlobalScope.future(options.coroutineContext) {
160160
suspendCoroutineUninterceptedOrReturn<Any?> { continuation ->
161-
methodAccess.invoke(source, methodIndex, *args + continuation)
161+
methodAccess.invoke(source, methodIndex, *args + continuation)?.transformWithGenericWrapper(environment)
162162
}
163163
}
164164
} else {
165-
methodAccess.invoke(source, methodIndex, *args)
166-
}
167-
return if (result == null) {
168-
null
169-
} else {
170-
val wrapper = options.genericWrappers
171-
.asSequence()
172-
.filter { it.type.isInstance(result) }
173-
.sortedWith(CompareGenericWrappers)
174-
.firstOrNull()
175-
if (wrapper == null) {
176-
result
177-
} else {
178-
wrapper.transformer.invoke(result, environment)
179-
}
165+
methodAccess.invoke(source, methodIndex, *args)?.transformWithGenericWrapper(environment)
180166
}
181167
}
182168

169+
private fun Any.transformWithGenericWrapper(environment: DataFetchingEnvironment): Any? {
170+
return options.genericWrappers
171+
.asSequence()
172+
.filter { it.type.isInstance(this) }
173+
.sortedWith(CompareGenericWrappers)
174+
.firstOrNull()
175+
?.transformer?.invoke(this, environment) ?: this
176+
}
177+
183178
/**
184179
* Function that return the object used to fetch the data
185180
* It can be a DataFetcher or an entity

src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ import graphql.parser.Parser
99
import graphql.schema.DataFetchingEnvironment
1010
import graphql.schema.GraphQLScalarType
1111
import kotlinx.coroutines.Dispatchers
12+
import kotlinx.coroutines.GlobalScope
13+
import kotlinx.coroutines.channels.ReceiveChannel
14+
import kotlinx.coroutines.reactive.publish
1215
import org.antlr.v4.runtime.RecognitionException
1316
import org.antlr.v4.runtime.misc.ParseCancellationException
1417
import org.reactivestreams.Publisher
@@ -331,20 +334,31 @@ data class SchemaParserOptions internal constructor(
331334
}
332335

333336
fun build(): SchemaParserOptions {
337+
val coroutineContext = coroutineContext ?: Dispatchers.Default
334338
val wrappers = if (useDefaultGenericWrappers) {
335339
genericWrappers + listOf(
336340
GenericWrapper(Future::class, 0),
337341
GenericWrapper(CompletableFuture::class, 0),
338342
GenericWrapper(CompletionStage::class, 0),
339-
GenericWrapper(Publisher::class, 0)
343+
GenericWrapper(Publisher::class, 0),
344+
GenericWrapper.withTransformer(ReceiveChannel::class, 0, { receiveChannel ->
345+
GlobalScope.publish(coroutineContext) {
346+
try {
347+
for (item in receiveChannel) {
348+
send(item)
349+
}
350+
} finally {
351+
receiveChannel.cancel()
352+
}
353+
}
354+
})
340355
)
341356
} else {
342357
genericWrappers
343358
}
344359

345360
return SchemaParserOptions(contextClass, wrappers, allowUnimplementedResolvers, objectMapperProvider,
346-
proxyHandlers, preferGraphQLResolver, introspectionEnabled,
347-
coroutineContext ?: Dispatchers.Default)
361+
proxyHandlers, preferGraphQLResolver, introspectionEnabled, coroutineContext)
348362
}
349363
}
350364

src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import graphql.execution.batched.BatchedExecutionStrategy
77
import graphql.schema.GraphQLSchema
88
import org.reactivestreams.Publisher
99
import org.reactivestreams.Subscriber
10+
import org.reactivestreams.tck.TestEnvironment
1011
import spock.lang.Shared
1112
import spock.lang.Specification
1213

@@ -563,18 +564,56 @@ class EndToEndSpec extends Specification {
563564
564565
def "generated schema supports Kotlin suspend functions"() {
565566
when:
566-
def data = Utils.assertNoGraphQlErrors(gql) {
567-
'''
568-
{
569-
coroutineItems {
570-
id
571-
name
567+
def data = Utils.assertNoGraphQlErrors(gql) {
568+
'''
569+
{
570+
coroutineItems {
571+
id
572+
name
573+
}
572574
}
573-
}
575+
'''
576+
}
577+
578+
then:
579+
data.coroutineItems == [[id:0, name:"item1"], [id:1, name:"item2"]]
580+
}
581+
582+
def "generated schema supports Kotlin coroutine channels for the subscription query"() {
583+
when:
584+
def newItem = new Item(1, "item", Type.TYPE_1, UUID.randomUUID(), [])
585+
def data = Utils.assertNoGraphQlErrors(gql, [:], new OnItemCreatedContext(newItem)) {
574586
'''
575-
}
587+
subscription {
588+
onItemCreatedCoroutineChannel {
589+
id
590+
}
591+
}
592+
'''
593+
}
594+
def subscriber = new TestEnvironment().newManualSubscriber(data as Publisher<ExecutionResult>)
595+
596+
then:
597+
subscriber.requestNextElement().data.get("onItemCreatedCoroutineChannel").id == 1
598+
subscriber.expectCompletion()
599+
}
600+
601+
def "generated schema supports Kotlin coroutine channels with suspend function for the subscription query"() {
602+
when:
603+
def newItem = new Item(1, "item", Type.TYPE_1, UUID.randomUUID(), [])
604+
def data = Utils.assertNoGraphQlErrors(gql, [:], new OnItemCreatedContext(newItem)) {
605+
'''
606+
subscription {
607+
onItemCreatedCoroutineChannelAndSuspendFunction {
608+
id
609+
}
610+
}
611+
'''
612+
}
613+
def subscriber = new TestEnvironment().newManualSubscriber(data as Publisher<ExecutionResult>)
576614
577615
then:
578-
data.coroutineItems == [[id:0, name:"item1"], [id:1, name:"item2"]]
616+
subscriber.requestNextElement().data.get("onItemCreatedCoroutineChannelAndSuspendFunction").id == 1
617+
subscriber.expectCompletion()
579618
}
580619
}

src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import graphql.language.StringValue
77
import graphql.schema.Coercing
88
import graphql.schema.DataFetchingEnvironment
99
import graphql.schema.GraphQLScalarType
10-
import kotlinx.coroutines.CompletableDeferred
10+
import kotlinx.coroutines.*
11+
import kotlinx.coroutines.channels.Channel
12+
import kotlinx.coroutines.channels.ReceiveChannel
1113
import org.reactivestreams.Publisher
1214
import java.util.Optional
1315
import java.util.UUID
@@ -107,6 +109,8 @@ type Mutation {
107109
108110
type Subscription {
109111
onItemCreated: Item!
112+
onItemCreatedCoroutineChannel: Item!
113+
onItemCreatedCoroutineChannelAndSuspendFunction: Item!
110114
}
111115
112116
input ItemSearchInput {
@@ -306,6 +310,20 @@ class Subscription : GraphQLSubscriptionResolver {
306310
subscriber.onNext(env.getContext<OnItemCreatedContext>().newItem)
307311
// subscriber.onComplete()
308312
}
313+
314+
fun onItemCreatedCoroutineChannel(env: DataFetchingEnvironment): ReceiveChannel<Item> {
315+
val channel = Channel<Item>(1)
316+
channel.offer(env.getContext<OnItemCreatedContext>().newItem)
317+
return channel
318+
}
319+
320+
suspend fun onItemCreatedCoroutineChannelAndSuspendFunction(env: DataFetchingEnvironment): ReceiveChannel<Item> {
321+
return coroutineScope {
322+
val channel = Channel<Item>(1)
323+
channel.offer(env.getContext<OnItemCreatedContext>().newItem)
324+
channel
325+
}
326+
}
309327
}
310328

311329
class ItemResolver : GraphQLResolver<Item> {

src/test/kotlin/com/coxautodev/graphql/tools/MethodFieldResolverDataFetcherTest.kt

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@ import graphql.schema.DataFetchingEnvironment
1111
import graphql.schema.DataFetchingEnvironmentBuilder
1212
import kotlinx.coroutines.Dispatchers
1313
import kotlinx.coroutines.Job
14+
import kotlinx.coroutines.channels.Channel
15+
import kotlinx.coroutines.channels.ReceiveChannel
1416
import org.junit.Assert
1517
import org.junit.Test
18+
import org.reactivestreams.Publisher
19+
import org.reactivestreams.tck.TestEnvironment
1620
import java.util.concurrent.CompletableFuture
1721
import kotlin.coroutines.coroutineContext
1822

@@ -39,6 +43,53 @@ class MethodFieldResolverDataFetcherTest {
3943
Assert.assertTrue(future.get())
4044
}
4145

46+
@Test
47+
fun `canceling subscription Publisher also cancels underlying Kotlin coroutine channel`() {
48+
// setup
49+
val channel = Channel<String>(10)
50+
channel.offer("A")
51+
channel.offer("B")
52+
53+
val resolver = createFetcher("onDataNameChanged", object : GraphQLResolver<DataClass> {
54+
fun onDataNameChanged(date: DataClass): ReceiveChannel<String> {
55+
return channel
56+
}
57+
})
58+
59+
// expect
60+
@Suppress("UNCHECKED_CAST")
61+
val publisher = resolver.get(createEnvironment(DataClass())) as Publisher<String>
62+
val subscriber = TestEnvironment().newManualSubscriber(publisher)
63+
64+
Assert.assertEquals("A", subscriber.requestNextElement())
65+
66+
subscriber.cancel()
67+
Thread.sleep(100)
68+
Assert.assertTrue(channel.isClosedForReceive)
69+
}
70+
71+
@Test
72+
fun `canceling underlying Kotlin coroutine channel also cancels subscription Publisher`() {
73+
// setup
74+
val channel = Channel<String>(10)
75+
channel.offer("A")
76+
channel.close(IllegalStateException("Channel error"))
77+
78+
val resolver = createFetcher("onDataNameChanged", object : GraphQLResolver<DataClass> {
79+
fun onDataNameChanged(date: DataClass): ReceiveChannel<String> {
80+
return channel
81+
}
82+
})
83+
84+
// expect
85+
@Suppress("UNCHECKED_CAST")
86+
val publisher = resolver.get(createEnvironment(DataClass())) as Publisher<String>
87+
val subscriber = TestEnvironment().newManualSubscriber(publisher)
88+
89+
Assert.assertEquals("A", subscriber.requestNextElement())
90+
subscriber.expectErrorWithMessage(IllegalStateException::class.java, "Channel error")
91+
}
92+
4293
private fun createFetcher(
4394
methodName: String,
4495
resolver: GraphQLResolver<*>,

0 commit comments

Comments
 (0)