From fa6c543a4b905a0568163a6e0e3a9c3879c96ed7 Mon Sep 17 00:00:00 2001 From: Mykola Varahash Date: Sat, 3 Nov 2018 13:21:48 +0200 Subject: [PATCH 1/3] Add support for Kotlin suspend functions --- pom.xml | 12 ++- .../graphql/tools/FieldResolverScanner.kt | 6 +- .../graphql/tools/MethodFieldResolver.kt | 26 ++++-- .../graphql/tools/SchemaParserBuilder.kt | 22 ++++- .../graphql/tools/EndToEndSpec.groovy | 17 ++++ .../coxautodev/graphql/tools/EndToEndSpec.kt | 6 +- .../MethodFieldResolverDataFetcherTest.kt | 83 +++++++++++++++++++ 7 files changed, 161 insertions(+), 11 deletions(-) create mode 100644 src/test/kotlin/com/coxautodev/graphql/tools/MethodFieldResolverDataFetcherTest.kt diff --git a/pom.xml b/pom.xml index f7249a4f..90e38c6b 100644 --- a/pom.xml +++ b/pom.xml @@ -14,7 +14,7 @@ UTF-8 1.8 - 1.2.71 + 1.3.0 2.9.6 ${java.version} @@ -29,6 +29,16 @@ kotlin-stdlib ${kotlin.version} + + org.jetbrains.kotlin + kotlin-reflect + ${kotlin.version} + + + org.jetbrains.kotlinx + kotlinx-coroutines-jdk8 + 1.0.0 + com.graphql-java graphql-java diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/FieldResolverScanner.kt b/src/main/kotlin/com/coxautodev/graphql/tools/FieldResolverScanner.kt index cc710d2e..85d56d5f 100644 --- a/src/main/kotlin/com/coxautodev/graphql/tools/FieldResolverScanner.kt +++ b/src/main/kotlin/com/coxautodev/graphql/tools/FieldResolverScanner.kt @@ -9,6 +9,8 @@ import org.apache.commons.lang3.reflect.FieldUtils import org.slf4j.LoggerFactory import java.lang.reflect.Modifier import java.lang.reflect.ParameterizedType +import kotlin.reflect.full.valueParameters +import kotlin.reflect.jvm.kotlinFunction /** * @author Andrew Potter @@ -112,7 +114,9 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) { true } - val correctParameterCount = method.parameterCount == requiredCount || (method.parameterCount == (requiredCount + 1) && allowedLastArgumentTypes.contains(method.parameterTypes.last())) + val correctParameterCount = method.parameterCount == requiredCount || + (method.parameterCount == (requiredCount + 1) && allowedLastArgumentTypes.contains(method.parameterTypes.last())) || + (method.kotlinFunction?.run { isSuspend && valueParameters.size == requiredCount } == true) return correctParameterCount && appropriateFirstParameter } diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt b/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt index 550acc0f..7e2bf658 100644 --- a/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt +++ b/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt @@ -8,10 +8,14 @@ import graphql.language.FieldDefinition import graphql.language.NonNullType import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment +import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.future.future import java.lang.reflect.Method -import java.lang.reflect.ParameterizedType -import java.lang.reflect.TypeVariable import java.util.* +import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn +import kotlin.reflect.full.valueParameters +import kotlin.reflect.jvm.javaType +import kotlin.reflect.jvm.kotlinFunction /** * @author Andrew Potter @@ -31,7 +35,7 @@ internal class MethodFieldResolver(field: FieldDefinition, search: FieldResolver } } - private val additionalLastArgument = method.parameterCount == (field.inputValueDefinitions.size + getIndexOffset() + 1) + private val additionalLastArgument = method.kotlinFunction?.valueParameters?.size ?: method.parameterCount == (field.inputValueDefinitions.size + getIndexOffset() + 1) override fun createDataFetcher(): DataFetcher<*> { val batched = isBatched(method, search) @@ -99,7 +103,7 @@ internal class MethodFieldResolver(field: FieldDefinition, search: FieldResolver override fun scanForMatches(): List { val batched = isBatched(method, search) - val unwrappedGenericType = genericType.unwrapGenericType(method.genericReturnType) + val unwrappedGenericType = genericType.unwrapGenericType(method.kotlinFunction?.returnType?.javaType ?: method.returnType) val returnValueMatch = TypeClassMatcher.PotentialMatch.returnValue(field.type, unwrappedGenericType, genericType, SchemaClassScanner.ReturnValueReference(method), batched) return field.inputValueDefinitions.mapIndexed { i, inputDefinition -> @@ -136,6 +140,7 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso // Convert to reflactasm reflection private val methodAccess = MethodAccess.get(method.declaringClass)!! private val methodIndex = methodAccess.getIndex(method.name, *method.parameterTypes) + private val isSuspendFunction = method.kotlinFunction?.isSuspend == true private class CompareGenericWrappers { companion object : Comparator { @@ -149,9 +154,18 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso override fun get(environment: DataFetchingEnvironment): Any? { val source = sourceResolver(environment) val args = this.args.map { it(environment) }.toTypedArray() - val result = methodAccess.invoke(source, methodIndex, *args) + + val result = if (isSuspendFunction) { + GlobalScope.future(options.coroutineContext) { + suspendCoroutineUninterceptedOrReturn { continuation -> + methodAccess.invoke(source, methodIndex, *args + continuation) + } + } + } else { + methodAccess.invoke(source, methodIndex, *args) + } return if (result == null) { - result + null } else { val wrapper = options.genericWrappers .asSequence() diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt b/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt index ce406d9c..c239e268 100644 --- a/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt +++ b/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt @@ -8,6 +8,7 @@ import graphql.language.Document import graphql.parser.Parser import graphql.schema.DataFetchingEnvironment import graphql.schema.GraphQLScalarType +import kotlinx.coroutines.Dispatchers import org.antlr.v4.runtime.RecognitionException import org.antlr.v4.runtime.misc.ParseCancellationException import org.reactivestreams.Publisher @@ -15,6 +16,7 @@ import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl import java.util.concurrent.CompletableFuture import java.util.concurrent.CompletionStage import java.util.concurrent.Future +import kotlin.coroutines.CoroutineContext import kotlin.reflect.KClass /** @@ -247,7 +249,16 @@ class SchemaParserDictionary { } } -data class SchemaParserOptions internal constructor(val contextClass: Class<*>?, val genericWrappers: List, val allowUnimplementedResolvers: Boolean, val objectMapperProvider: PerFieldObjectMapperProvider, val proxyHandlers: List, val preferGraphQLResolver: Boolean, val introspectionEnabled: Boolean) { +data class SchemaParserOptions internal constructor( + val contextClass: Class<*>?, + val genericWrappers: List, + val allowUnimplementedResolvers: Boolean, + val objectMapperProvider: PerFieldObjectMapperProvider, + val proxyHandlers: List, + val preferGraphQLResolver: Boolean, + val introspectionEnabled: Boolean, + val coroutineContext: CoroutineContext +) { companion object { @JvmStatic fun newOptions() = Builder() @@ -265,6 +276,7 @@ data class SchemaParserOptions internal constructor(val contextClass: Class<*>?, private val proxyHandlers: MutableList = mutableListOf(Spring4AopProxyHandler(), GuiceAopProxyHandler(), JavassistProxyHandler()) private var preferGraphQLResolver = false private var introspectionEnabled = true + private var coroutineContext: CoroutineContext? = null fun contextClass(contextClass: Class<*>) = this.apply { this.contextClass = contextClass @@ -314,6 +326,10 @@ data class SchemaParserOptions internal constructor(val contextClass: Class<*>?, this.introspectionEnabled = introspectionEnabled } + fun coroutineContext(context: CoroutineContext) = this.apply { + this.coroutineContext = context + } + fun build(): SchemaParserOptions { val wrappers = if (useDefaultGenericWrappers) { genericWrappers + listOf( @@ -326,7 +342,9 @@ data class SchemaParserOptions internal constructor(val contextClass: Class<*>?, genericWrappers } - return SchemaParserOptions(contextClass, wrappers, allowUnimplementedResolvers, objectMapperProvider, proxyHandlers, preferGraphQLResolver, introspectionEnabled) + return SchemaParserOptions(contextClass, wrappers, allowUnimplementedResolvers, objectMapperProvider, + proxyHandlers, preferGraphQLResolver, introspectionEnabled, + coroutineContext ?: Dispatchers.Default) } } diff --git a/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy b/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy index c3f9ccd1..5267e144 100644 --- a/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy +++ b/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy @@ -560,4 +560,21 @@ class EndToEndSpec extends Specification { then: data.dataFetcherResult.name == "item1" } + + def "generated schema supports Kotlin suspend functions"() { + when: + def data = Utils.assertNoGraphQlErrors(gql) { + ''' + { + coroutineItems { + id + name + } + } + ''' + } + + then: + data.coroutineItems == [[id:0, name:"item1"], [id:1, name:"item2"]] + } } diff --git a/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt b/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt index 07c5d285..b503d91d 100644 --- a/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt +++ b/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt @@ -7,6 +7,7 @@ import graphql.language.StringValue import graphql.schema.Coercing import graphql.schema.DataFetchingEnvironment import graphql.schema.GraphQLScalarType +import kotlinx.coroutines.CompletableDeferred import org.reactivestreams.Publisher import java.util.Optional import java.util.UUID @@ -73,6 +74,8 @@ type Query { propertyField: String! dataFetcherResult: Item! + + coroutineItems: [Item!]! } type ExtendedType { @@ -268,12 +271,13 @@ class Query: GraphQLQueryResolver, ListListResolver() { fun propertyMapWithComplexItems() = propertyMapWithComplexItems fun propertyMapWithNestedComplexItems() = propertyMapWithNestedComplexItems - private val propertyField = "test" fun dataFetcherResult(): DataFetcherResult { return DataFetcherResult(items.first(), listOf()) } + + suspend fun coroutineItems(): List = CompletableDeferred(items).await() } class UnusedRootResolver: GraphQLQueryResolver diff --git a/src/test/kotlin/com/coxautodev/graphql/tools/MethodFieldResolverDataFetcherTest.kt b/src/test/kotlin/com/coxautodev/graphql/tools/MethodFieldResolverDataFetcherTest.kt new file mode 100644 index 00000000..d0e568ba --- /dev/null +++ b/src/test/kotlin/com/coxautodev/graphql/tools/MethodFieldResolverDataFetcherTest.kt @@ -0,0 +1,83 @@ +package com.coxautodev.graphql.tools + +import graphql.ExecutionResult +import graphql.execution.* +import graphql.execution.instrumentation.SimpleInstrumentation +import graphql.language.FieldDefinition +import graphql.language.InputValueDefinition +import graphql.language.TypeName +import graphql.schema.DataFetcher +import graphql.schema.DataFetchingEnvironment +import graphql.schema.DataFetchingEnvironmentBuilder +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import org.junit.Assert +import org.junit.Test +import java.util.concurrent.CompletableFuture +import kotlin.coroutines.coroutineContext + +class MethodFieldResolverDataFetcherTest { + @Test + fun `data fetcher executes suspend function on coroutineContext defined by options`() { + // setup + val dispatcher = Dispatchers.IO + val job = Job() + val options = SchemaParserOptions.Builder() + .coroutineContext(dispatcher + job) + .build() + + val resolver = createFetcher("active", object : GraphQLResolver { + suspend fun isActive(data: DataClass): Boolean { + return coroutineContext[dispatcher.key] == dispatcher && + coroutineContext[Job] == job.children.first() + } + }, options = options) + + // expect + @Suppress("UNCHECKED_CAST") + val future = resolver.get(createEnvironment(DataClass())) as CompletableFuture + Assert.assertTrue(future.get()) + } + + private fun createFetcher( + methodName: String, + resolver: GraphQLResolver<*>, + arguments: List = emptyList(), + options: SchemaParserOptions = SchemaParserOptions.defaultOptions() + ): DataFetcher<*> { + val field = FieldDefinition(methodName, TypeName("Boolean")).apply { inputValueDefinitions.addAll(arguments) } + val resolverInfo = if (resolver is GraphQLQueryResolver) { + RootResolverInfo(listOf(resolver), options) + } else { + NormalResolverInfo(resolver, options) + } + return FieldResolverScanner(options).findFieldResolver(field, resolverInfo).createDataFetcher() + } + + private fun createEnvironment(source: Any, arguments: Map = emptyMap(), context: Any? = null): DataFetchingEnvironment { + return DataFetchingEnvironmentBuilder.newDataFetchingEnvironment() + .source(source) + .arguments(arguments) + .context(context) + .executionContext(buildExecutionContext()) + .build() + } + + private fun buildExecutionContext(): ExecutionContext { + val executionStrategy = object : ExecutionStrategy() { + override fun execute(executionContext: ExecutionContext, parameters: ExecutionStrategyParameters): CompletableFuture { + throw AssertionError("should not be called") + } + } + val executionId = ExecutionId.from("executionId123") + return ExecutionContextBuilder.newExecutionContextBuilder() + .instrumentation(SimpleInstrumentation.INSTANCE) + .executionId(executionId) + .queryStrategy(executionStrategy) + .mutationStrategy(executionStrategy) + .subscriptionStrategy(executionStrategy) + .build() + } + + data class DataClass(val name: String = "TestName") +} \ No newline at end of file From 20f30f80793f0fac267a471d20d9c02eef6f3eb9 Mon Sep 17 00:00:00 2001 From: Mykola Varahash Date: Sat, 3 Nov 2018 16:16:49 +0200 Subject: [PATCH 2/3] Add support for Kotlin coroutine channels for subscription queries --- pom.xml | 11 ++++ .../graphql/tools/FieldResolverScanner.kt | 9 ++- .../graphql/tools/MethodFieldResolver.kt | 29 ++++------ .../graphql/tools/SchemaParserBuilder.kt | 20 ++++++- .../graphql/tools/EndToEndSpec.groovy | 57 ++++++++++++++++--- .../coxautodev/graphql/tools/EndToEndSpec.kt | 20 ++++++- .../MethodFieldResolverDataFetcherTest.kt | 51 +++++++++++++++++ 7 files changed, 164 insertions(+), 33 deletions(-) diff --git a/pom.xml b/pom.xml index 90e38c6b..95a469bf 100644 --- a/pom.xml +++ b/pom.xml @@ -39,6 +39,11 @@ kotlinx-coroutines-jdk8 1.0.0 + + org.jetbrains.kotlinx + kotlinx-coroutines-reactive + 1.0.0 + com.graphql-java graphql-java @@ -120,6 +125,12 @@ 2.1 test + + org.reactivestreams + reactive-streams-tck + 1.0.2 + test + diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/FieldResolverScanner.kt b/src/main/kotlin/com/coxautodev/graphql/tools/FieldResolverScanner.kt index 85d56d5f..e059b698 100644 --- a/src/main/kotlin/com/coxautodev/graphql/tools/FieldResolverScanner.kt +++ b/src/main/kotlin/com/coxautodev/graphql/tools/FieldResolverScanner.kt @@ -10,6 +10,7 @@ import org.slf4j.LoggerFactory import java.lang.reflect.Modifier import java.lang.reflect.ParameterizedType import kotlin.reflect.full.valueParameters +import kotlin.reflect.jvm.javaType import kotlin.reflect.jvm.kotlinFunction /** @@ -114,9 +115,11 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) { true } - val correctParameterCount = method.parameterCount == requiredCount || - (method.parameterCount == (requiredCount + 1) && allowedLastArgumentTypes.contains(method.parameterTypes.last())) || - (method.kotlinFunction?.run { isSuspend && valueParameters.size == requiredCount } == true) + val methodParameterCount = method.kotlinFunction?.valueParameters?.size ?: method.parameterCount + val methodLastParameter = method.kotlinFunction?.valueParameters?.lastOrNull()?.type?.javaType ?: method.parameterTypes.lastOrNull() + + val correctParameterCount = methodParameterCount == requiredCount || + (methodParameterCount == (requiredCount + 1) && allowedLastArgumentTypes.contains(methodLastParameter)) return correctParameterCount && appropriateFirstParameter } diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt b/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt index 7e2bf658..cb45c9b0 100644 --- a/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt +++ b/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt @@ -155,31 +155,26 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso val source = sourceResolver(environment) val args = this.args.map { it(environment) }.toTypedArray() - val result = if (isSuspendFunction) { + return if (isSuspendFunction) { GlobalScope.future(options.coroutineContext) { suspendCoroutineUninterceptedOrReturn { continuation -> - methodAccess.invoke(source, methodIndex, *args + continuation) + methodAccess.invoke(source, methodIndex, *args + continuation)?.transformWithGenericWrapper(environment) } } } else { - methodAccess.invoke(source, methodIndex, *args) - } - return if (result == null) { - null - } else { - val wrapper = options.genericWrappers - .asSequence() - .filter { it.type.isInstance(result) } - .sortedWith(CompareGenericWrappers) - .firstOrNull() - if (wrapper == null) { - result - } else { - wrapper.transformer.invoke(result, environment) - } + methodAccess.invoke(source, methodIndex, *args)?.transformWithGenericWrapper(environment) } } + private fun Any.transformWithGenericWrapper(environment: DataFetchingEnvironment): Any? { + return options.genericWrappers + .asSequence() + .filter { it.type.isInstance(this) } + .sortedWith(CompareGenericWrappers) + .firstOrNull() + ?.transformer?.invoke(this, environment) ?: this + } + /** * Function that return the object used to fetch the data * It can be a DataFetcher or an entity diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt b/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt index c239e268..8922b669 100644 --- a/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt +++ b/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt @@ -9,6 +9,9 @@ import graphql.parser.Parser import graphql.schema.DataFetchingEnvironment import graphql.schema.GraphQLScalarType import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.reactive.publish import org.antlr.v4.runtime.RecognitionException import org.antlr.v4.runtime.misc.ParseCancellationException import org.reactivestreams.Publisher @@ -331,20 +334,31 @@ data class SchemaParserOptions internal constructor( } fun build(): SchemaParserOptions { + val coroutineContext = coroutineContext ?: Dispatchers.Default val wrappers = if (useDefaultGenericWrappers) { genericWrappers + listOf( GenericWrapper(Future::class, 0), GenericWrapper(CompletableFuture::class, 0), GenericWrapper(CompletionStage::class, 0), - GenericWrapper(Publisher::class, 0) + GenericWrapper(Publisher::class, 0), + GenericWrapper.withTransformer(ReceiveChannel::class, 0, { receiveChannel -> + GlobalScope.publish(coroutineContext) { + try { + for (item in receiveChannel) { + send(item) + } + } finally { + receiveChannel.cancel() + } + } + }) ) } else { genericWrappers } return SchemaParserOptions(contextClass, wrappers, allowUnimplementedResolvers, objectMapperProvider, - proxyHandlers, preferGraphQLResolver, introspectionEnabled, - coroutineContext ?: Dispatchers.Default) + proxyHandlers, preferGraphQLResolver, introspectionEnabled, coroutineContext) } } diff --git a/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy b/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy index 5267e144..3e0e1ea4 100644 --- a/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy +++ b/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy @@ -7,6 +7,7 @@ import graphql.execution.batched.BatchedExecutionStrategy import graphql.schema.GraphQLSchema import org.reactivestreams.Publisher import org.reactivestreams.Subscriber +import org.reactivestreams.tck.TestEnvironment import spock.lang.Shared import spock.lang.Specification @@ -563,18 +564,56 @@ class EndToEndSpec extends Specification { def "generated schema supports Kotlin suspend functions"() { when: - def data = Utils.assertNoGraphQlErrors(gql) { - ''' - { - coroutineItems { - id - name + def data = Utils.assertNoGraphQlErrors(gql) { + ''' + { + coroutineItems { + id + name + } } - } + ''' + } + + then: + data.coroutineItems == [[id:0, name:"item1"], [id:1, name:"item2"]] + } + + def "generated schema supports Kotlin coroutine channels for the subscription query"() { + when: + def newItem = new Item(1, "item", Type.TYPE_1, UUID.randomUUID(), []) + def data = Utils.assertNoGraphQlErrors(gql, [:], new OnItemCreatedContext(newItem)) { ''' - } + subscription { + onItemCreatedCoroutineChannel { + id + } + } + ''' + } + def subscriber = new TestEnvironment().newManualSubscriber(data as Publisher) + + then: + subscriber.requestNextElement().data.get("onItemCreatedCoroutineChannel").id == 1 + subscriber.expectCompletion() + } + + def "generated schema supports Kotlin coroutine channels with suspend function for the subscription query"() { + when: + def newItem = new Item(1, "item", Type.TYPE_1, UUID.randomUUID(), []) + def data = Utils.assertNoGraphQlErrors(gql, [:], new OnItemCreatedContext(newItem)) { + ''' + subscription { + onItemCreatedCoroutineChannelAndSuspendFunction { + id + } + } + ''' + } + def subscriber = new TestEnvironment().newManualSubscriber(data as Publisher) then: - data.coroutineItems == [[id:0, name:"item1"], [id:1, name:"item2"]] + subscriber.requestNextElement().data.get("onItemCreatedCoroutineChannelAndSuspendFunction").id == 1 + subscriber.expectCompletion() } } diff --git a/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt b/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt index b503d91d..02bf7a7a 100644 --- a/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt +++ b/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt @@ -7,7 +7,9 @@ import graphql.language.StringValue import graphql.schema.Coercing import graphql.schema.DataFetchingEnvironment import graphql.schema.GraphQLScalarType -import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ReceiveChannel import org.reactivestreams.Publisher import java.util.Optional import java.util.UUID @@ -107,6 +109,8 @@ type Mutation { type Subscription { onItemCreated: Item! + onItemCreatedCoroutineChannel: Item! + onItemCreatedCoroutineChannelAndSuspendFunction: Item! } input ItemSearchInput { @@ -306,6 +310,20 @@ class Subscription : GraphQLSubscriptionResolver { subscriber.onNext(env.getContext().newItem) // subscriber.onComplete() } + + fun onItemCreatedCoroutineChannel(env: DataFetchingEnvironment): ReceiveChannel { + val channel = Channel(1) + channel.offer(env.getContext().newItem) + return channel + } + + suspend fun onItemCreatedCoroutineChannelAndSuspendFunction(env: DataFetchingEnvironment): ReceiveChannel { + return coroutineScope { + val channel = Channel(1) + channel.offer(env.getContext().newItem) + channel + } + } } class ItemResolver : GraphQLResolver { diff --git a/src/test/kotlin/com/coxautodev/graphql/tools/MethodFieldResolverDataFetcherTest.kt b/src/test/kotlin/com/coxautodev/graphql/tools/MethodFieldResolverDataFetcherTest.kt index d0e568ba..327678af 100644 --- a/src/test/kotlin/com/coxautodev/graphql/tools/MethodFieldResolverDataFetcherTest.kt +++ b/src/test/kotlin/com/coxautodev/graphql/tools/MethodFieldResolverDataFetcherTest.kt @@ -11,8 +11,12 @@ import graphql.schema.DataFetchingEnvironment import graphql.schema.DataFetchingEnvironmentBuilder import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ReceiveChannel import org.junit.Assert import org.junit.Test +import org.reactivestreams.Publisher +import org.reactivestreams.tck.TestEnvironment import java.util.concurrent.CompletableFuture import kotlin.coroutines.coroutineContext @@ -39,6 +43,53 @@ class MethodFieldResolverDataFetcherTest { Assert.assertTrue(future.get()) } + @Test + fun `canceling subscription Publisher also cancels underlying Kotlin coroutine channel`() { + // setup + val channel = Channel(10) + channel.offer("A") + channel.offer("B") + + val resolver = createFetcher("onDataNameChanged", object : GraphQLResolver { + fun onDataNameChanged(date: DataClass): ReceiveChannel { + return channel + } + }) + + // expect + @Suppress("UNCHECKED_CAST") + val publisher = resolver.get(createEnvironment(DataClass())) as Publisher + val subscriber = TestEnvironment().newManualSubscriber(publisher) + + Assert.assertEquals("A", subscriber.requestNextElement()) + + subscriber.cancel() + Thread.sleep(100) + Assert.assertTrue(channel.isClosedForReceive) + } + + @Test + fun `canceling underlying Kotlin coroutine channel also cancels subscription Publisher`() { + // setup + val channel = Channel(10) + channel.offer("A") + channel.close(IllegalStateException("Channel error")) + + val resolver = createFetcher("onDataNameChanged", object : GraphQLResolver { + fun onDataNameChanged(date: DataClass): ReceiveChannel { + return channel + } + }) + + // expect + @Suppress("UNCHECKED_CAST") + val publisher = resolver.get(createEnvironment(DataClass())) as Publisher + val subscriber = TestEnvironment().newManualSubscriber(publisher) + + Assert.assertEquals("A", subscriber.requestNextElement()) + subscriber.expectErrorWithMessage(IllegalStateException::class.java, "Channel error") + } + private fun createFetcher( methodName: String, resolver: GraphQLResolver<*>, From cfc02dcc9e11836c7ff1f9d9673799d4c82e8bb6 Mon Sep 17 00:00:00 2001 From: Mykola Varahash Date: Sat, 3 Nov 2018 19:03:04 +0200 Subject: [PATCH 3/3] Fix invoking suspend function --- .../coxautodev/graphql/tools/MethodFieldResolver.kt | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt b/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt index cb45c9b0..a7fe2844 100644 --- a/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt +++ b/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt @@ -157,9 +157,7 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso return if (isSuspendFunction) { GlobalScope.future(options.coroutineContext) { - suspendCoroutineUninterceptedOrReturn { continuation -> - methodAccess.invoke(source, methodIndex, *args + continuation)?.transformWithGenericWrapper(environment) - } + methodAccess.invokeSuspend(source, methodIndex, args)?.transformWithGenericWrapper(environment) } } else { methodAccess.invoke(source, methodIndex, *args)?.transformWithGenericWrapper(environment) @@ -185,6 +183,12 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso } } +private suspend inline fun MethodAccess.invokeSuspend(target: Any, methodIndex: Int, args: Array): Any? { + return suspendCoroutineUninterceptedOrReturn { continuation -> + invoke(target, methodIndex, *args + continuation) + } +} + class BatchedMethodFieldResolverDataFetcher(sourceResolver: SourceResolver, method: Method, args: List, options: SchemaParserOptions) : MethodFieldResolverDataFetcher(sourceResolver, method, args, options) { @Batched override fun get(environment: DataFetchingEnvironment) = super.get(environment)