diff --git a/pom.xml b/pom.xml index f7249a4f..95a469bf 100644 --- a/pom.xml +++ b/pom.xml @@ -14,7 +14,7 @@ UTF-8 1.8 - 1.2.71 + 1.3.0 2.9.6 ${java.version} @@ -29,6 +29,21 @@ kotlin-stdlib ${kotlin.version} + + org.jetbrains.kotlin + kotlin-reflect + ${kotlin.version} + + + org.jetbrains.kotlinx + kotlinx-coroutines-jdk8 + 1.0.0 + + + org.jetbrains.kotlinx + kotlinx-coroutines-reactive + 1.0.0 + com.graphql-java graphql-java @@ -110,6 +125,12 @@ 2.1 test + + org.reactivestreams + reactive-streams-tck + 1.0.2 + test + diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/FieldResolverScanner.kt b/src/main/kotlin/com/coxautodev/graphql/tools/FieldResolverScanner.kt index cc710d2e..e059b698 100644 --- a/src/main/kotlin/com/coxautodev/graphql/tools/FieldResolverScanner.kt +++ b/src/main/kotlin/com/coxautodev/graphql/tools/FieldResolverScanner.kt @@ -9,6 +9,9 @@ import org.apache.commons.lang3.reflect.FieldUtils import org.slf4j.LoggerFactory import java.lang.reflect.Modifier import java.lang.reflect.ParameterizedType +import kotlin.reflect.full.valueParameters +import kotlin.reflect.jvm.javaType +import kotlin.reflect.jvm.kotlinFunction /** * @author Andrew Potter @@ -112,7 +115,11 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) { true } - val correctParameterCount = method.parameterCount == requiredCount || (method.parameterCount == (requiredCount + 1) && allowedLastArgumentTypes.contains(method.parameterTypes.last())) + val methodParameterCount = method.kotlinFunction?.valueParameters?.size ?: method.parameterCount + val methodLastParameter = method.kotlinFunction?.valueParameters?.lastOrNull()?.type?.javaType ?: method.parameterTypes.lastOrNull() + + val correctParameterCount = methodParameterCount == requiredCount || + (methodParameterCount == (requiredCount + 1) && allowedLastArgumentTypes.contains(methodLastParameter)) return correctParameterCount && appropriateFirstParameter } diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt b/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt index 550acc0f..a7fe2844 100644 --- a/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt +++ b/src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt @@ -8,10 +8,14 @@ import graphql.language.FieldDefinition import graphql.language.NonNullType import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment +import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.future.future import java.lang.reflect.Method -import java.lang.reflect.ParameterizedType -import java.lang.reflect.TypeVariable import java.util.* +import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn +import kotlin.reflect.full.valueParameters +import kotlin.reflect.jvm.javaType +import kotlin.reflect.jvm.kotlinFunction /** * @author Andrew Potter @@ -31,7 +35,7 @@ internal class MethodFieldResolver(field: FieldDefinition, search: FieldResolver } } - private val additionalLastArgument = method.parameterCount == (field.inputValueDefinitions.size + getIndexOffset() + 1) + private val additionalLastArgument = method.kotlinFunction?.valueParameters?.size ?: method.parameterCount == (field.inputValueDefinitions.size + getIndexOffset() + 1) override fun createDataFetcher(): DataFetcher<*> { val batched = isBatched(method, search) @@ -99,7 +103,7 @@ internal class MethodFieldResolver(field: FieldDefinition, search: FieldResolver override fun scanForMatches(): List { val batched = isBatched(method, search) - val unwrappedGenericType = genericType.unwrapGenericType(method.genericReturnType) + val unwrappedGenericType = genericType.unwrapGenericType(method.kotlinFunction?.returnType?.javaType ?: method.returnType) val returnValueMatch = TypeClassMatcher.PotentialMatch.returnValue(field.type, unwrappedGenericType, genericType, SchemaClassScanner.ReturnValueReference(method), batched) return field.inputValueDefinitions.mapIndexed { i, inputDefinition -> @@ -136,6 +140,7 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso // Convert to reflactasm reflection private val methodAccess = MethodAccess.get(method.declaringClass)!! private val methodIndex = methodAccess.getIndex(method.name, *method.parameterTypes) + private val isSuspendFunction = method.kotlinFunction?.isSuspend == true private class CompareGenericWrappers { companion object : Comparator { @@ -149,23 +154,25 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso override fun get(environment: DataFetchingEnvironment): Any? { val source = sourceResolver(environment) val args = this.args.map { it(environment) }.toTypedArray() - val result = methodAccess.invoke(source, methodIndex, *args) - return if (result == null) { - result - } else { - val wrapper = options.genericWrappers - .asSequence() - .filter { it.type.isInstance(result) } - .sortedWith(CompareGenericWrappers) - .firstOrNull() - if (wrapper == null) { - result - } else { - wrapper.transformer.invoke(result, environment) + + return if (isSuspendFunction) { + GlobalScope.future(options.coroutineContext) { + methodAccess.invokeSuspend(source, methodIndex, args)?.transformWithGenericWrapper(environment) } + } else { + methodAccess.invoke(source, methodIndex, *args)?.transformWithGenericWrapper(environment) } } + private fun Any.transformWithGenericWrapper(environment: DataFetchingEnvironment): Any? { + return options.genericWrappers + .asSequence() + .filter { it.type.isInstance(this) } + .sortedWith(CompareGenericWrappers) + .firstOrNull() + ?.transformer?.invoke(this, environment) ?: this + } + /** * Function that return the object used to fetch the data * It can be a DataFetcher or an entity @@ -176,6 +183,12 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso } } +private suspend inline fun MethodAccess.invokeSuspend(target: Any, methodIndex: Int, args: Array): Any? { + return suspendCoroutineUninterceptedOrReturn { continuation -> + invoke(target, methodIndex, *args + continuation) + } +} + class BatchedMethodFieldResolverDataFetcher(sourceResolver: SourceResolver, method: Method, args: List, options: SchemaParserOptions) : MethodFieldResolverDataFetcher(sourceResolver, method, args, options) { @Batched override fun get(environment: DataFetchingEnvironment) = super.get(environment) diff --git a/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt b/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt index ce406d9c..8922b669 100644 --- a/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt +++ b/src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt @@ -8,6 +8,10 @@ import graphql.language.Document import graphql.parser.Parser import graphql.schema.DataFetchingEnvironment import graphql.schema.GraphQLScalarType +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.reactive.publish import org.antlr.v4.runtime.RecognitionException import org.antlr.v4.runtime.misc.ParseCancellationException import org.reactivestreams.Publisher @@ -15,6 +19,7 @@ import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl import java.util.concurrent.CompletableFuture import java.util.concurrent.CompletionStage import java.util.concurrent.Future +import kotlin.coroutines.CoroutineContext import kotlin.reflect.KClass /** @@ -247,7 +252,16 @@ class SchemaParserDictionary { } } -data class SchemaParserOptions internal constructor(val contextClass: Class<*>?, val genericWrappers: List, val allowUnimplementedResolvers: Boolean, val objectMapperProvider: PerFieldObjectMapperProvider, val proxyHandlers: List, val preferGraphQLResolver: Boolean, val introspectionEnabled: Boolean) { +data class SchemaParserOptions internal constructor( + val contextClass: Class<*>?, + val genericWrappers: List, + val allowUnimplementedResolvers: Boolean, + val objectMapperProvider: PerFieldObjectMapperProvider, + val proxyHandlers: List, + val preferGraphQLResolver: Boolean, + val introspectionEnabled: Boolean, + val coroutineContext: CoroutineContext +) { companion object { @JvmStatic fun newOptions() = Builder() @@ -265,6 +279,7 @@ data class SchemaParserOptions internal constructor(val contextClass: Class<*>?, private val proxyHandlers: MutableList = mutableListOf(Spring4AopProxyHandler(), GuiceAopProxyHandler(), JavassistProxyHandler()) private var preferGraphQLResolver = false private var introspectionEnabled = true + private var coroutineContext: CoroutineContext? = null fun contextClass(contextClass: Class<*>) = this.apply { this.contextClass = contextClass @@ -314,19 +329,36 @@ data class SchemaParserOptions internal constructor(val contextClass: Class<*>?, this.introspectionEnabled = introspectionEnabled } + fun coroutineContext(context: CoroutineContext) = this.apply { + this.coroutineContext = context + } + fun build(): SchemaParserOptions { + val coroutineContext = coroutineContext ?: Dispatchers.Default val wrappers = if (useDefaultGenericWrappers) { genericWrappers + listOf( GenericWrapper(Future::class, 0), GenericWrapper(CompletableFuture::class, 0), GenericWrapper(CompletionStage::class, 0), - GenericWrapper(Publisher::class, 0) + GenericWrapper(Publisher::class, 0), + GenericWrapper.withTransformer(ReceiveChannel::class, 0, { receiveChannel -> + GlobalScope.publish(coroutineContext) { + try { + for (item in receiveChannel) { + send(item) + } + } finally { + receiveChannel.cancel() + } + } + }) ) } else { genericWrappers } - return SchemaParserOptions(contextClass, wrappers, allowUnimplementedResolvers, objectMapperProvider, proxyHandlers, preferGraphQLResolver, introspectionEnabled) + return SchemaParserOptions(contextClass, wrappers, allowUnimplementedResolvers, objectMapperProvider, + proxyHandlers, preferGraphQLResolver, introspectionEnabled, coroutineContext) } } diff --git a/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy b/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy index c3f9ccd1..3e0e1ea4 100644 --- a/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy +++ b/src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy @@ -7,6 +7,7 @@ import graphql.execution.batched.BatchedExecutionStrategy import graphql.schema.GraphQLSchema import org.reactivestreams.Publisher import org.reactivestreams.Subscriber +import org.reactivestreams.tck.TestEnvironment import spock.lang.Shared import spock.lang.Specification @@ -560,4 +561,59 @@ class EndToEndSpec extends Specification { then: data.dataFetcherResult.name == "item1" } + + def "generated schema supports Kotlin suspend functions"() { + when: + def data = Utils.assertNoGraphQlErrors(gql) { + ''' + { + coroutineItems { + id + name + } + } + ''' + } + + then: + data.coroutineItems == [[id:0, name:"item1"], [id:1, name:"item2"]] + } + + def "generated schema supports Kotlin coroutine channels for the subscription query"() { + when: + def newItem = new Item(1, "item", Type.TYPE_1, UUID.randomUUID(), []) + def data = Utils.assertNoGraphQlErrors(gql, [:], new OnItemCreatedContext(newItem)) { + ''' + subscription { + onItemCreatedCoroutineChannel { + id + } + } + ''' + } + def subscriber = new TestEnvironment().newManualSubscriber(data as Publisher) + + then: + subscriber.requestNextElement().data.get("onItemCreatedCoroutineChannel").id == 1 + subscriber.expectCompletion() + } + + def "generated schema supports Kotlin coroutine channels with suspend function for the subscription query"() { + when: + def newItem = new Item(1, "item", Type.TYPE_1, UUID.randomUUID(), []) + def data = Utils.assertNoGraphQlErrors(gql, [:], new OnItemCreatedContext(newItem)) { + ''' + subscription { + onItemCreatedCoroutineChannelAndSuspendFunction { + id + } + } + ''' + } + def subscriber = new TestEnvironment().newManualSubscriber(data as Publisher) + + then: + subscriber.requestNextElement().data.get("onItemCreatedCoroutineChannelAndSuspendFunction").id == 1 + subscriber.expectCompletion() + } } diff --git a/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt b/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt index 07c5d285..02bf7a7a 100644 --- a/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt +++ b/src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt @@ -7,6 +7,9 @@ import graphql.language.StringValue import graphql.schema.Coercing import graphql.schema.DataFetchingEnvironment import graphql.schema.GraphQLScalarType +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ReceiveChannel import org.reactivestreams.Publisher import java.util.Optional import java.util.UUID @@ -73,6 +76,8 @@ type Query { propertyField: String! dataFetcherResult: Item! + + coroutineItems: [Item!]! } type ExtendedType { @@ -104,6 +109,8 @@ type Mutation { type Subscription { onItemCreated: Item! + onItemCreatedCoroutineChannel: Item! + onItemCreatedCoroutineChannelAndSuspendFunction: Item! } input ItemSearchInput { @@ -268,12 +275,13 @@ class Query: GraphQLQueryResolver, ListListResolver() { fun propertyMapWithComplexItems() = propertyMapWithComplexItems fun propertyMapWithNestedComplexItems() = propertyMapWithNestedComplexItems - private val propertyField = "test" fun dataFetcherResult(): DataFetcherResult { return DataFetcherResult(items.first(), listOf()) } + + suspend fun coroutineItems(): List = CompletableDeferred(items).await() } class UnusedRootResolver: GraphQLQueryResolver @@ -302,6 +310,20 @@ class Subscription : GraphQLSubscriptionResolver { subscriber.onNext(env.getContext().newItem) // subscriber.onComplete() } + + fun onItemCreatedCoroutineChannel(env: DataFetchingEnvironment): ReceiveChannel { + val channel = Channel(1) + channel.offer(env.getContext().newItem) + return channel + } + + suspend fun onItemCreatedCoroutineChannelAndSuspendFunction(env: DataFetchingEnvironment): ReceiveChannel { + return coroutineScope { + val channel = Channel(1) + channel.offer(env.getContext().newItem) + channel + } + } } class ItemResolver : GraphQLResolver { diff --git a/src/test/kotlin/com/coxautodev/graphql/tools/MethodFieldResolverDataFetcherTest.kt b/src/test/kotlin/com/coxautodev/graphql/tools/MethodFieldResolverDataFetcherTest.kt new file mode 100644 index 00000000..327678af --- /dev/null +++ b/src/test/kotlin/com/coxautodev/graphql/tools/MethodFieldResolverDataFetcherTest.kt @@ -0,0 +1,134 @@ +package com.coxautodev.graphql.tools + +import graphql.ExecutionResult +import graphql.execution.* +import graphql.execution.instrumentation.SimpleInstrumentation +import graphql.language.FieldDefinition +import graphql.language.InputValueDefinition +import graphql.language.TypeName +import graphql.schema.DataFetcher +import graphql.schema.DataFetchingEnvironment +import graphql.schema.DataFetchingEnvironmentBuilder +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ReceiveChannel +import org.junit.Assert +import org.junit.Test +import org.reactivestreams.Publisher +import org.reactivestreams.tck.TestEnvironment +import java.util.concurrent.CompletableFuture +import kotlin.coroutines.coroutineContext + +class MethodFieldResolverDataFetcherTest { + @Test + fun `data fetcher executes suspend function on coroutineContext defined by options`() { + // setup + val dispatcher = Dispatchers.IO + val job = Job() + val options = SchemaParserOptions.Builder() + .coroutineContext(dispatcher + job) + .build() + + val resolver = createFetcher("active", object : GraphQLResolver { + suspend fun isActive(data: DataClass): Boolean { + return coroutineContext[dispatcher.key] == dispatcher && + coroutineContext[Job] == job.children.first() + } + }, options = options) + + // expect + @Suppress("UNCHECKED_CAST") + val future = resolver.get(createEnvironment(DataClass())) as CompletableFuture + Assert.assertTrue(future.get()) + } + + @Test + fun `canceling subscription Publisher also cancels underlying Kotlin coroutine channel`() { + // setup + val channel = Channel(10) + channel.offer("A") + channel.offer("B") + + val resolver = createFetcher("onDataNameChanged", object : GraphQLResolver { + fun onDataNameChanged(date: DataClass): ReceiveChannel { + return channel + } + }) + + // expect + @Suppress("UNCHECKED_CAST") + val publisher = resolver.get(createEnvironment(DataClass())) as Publisher + val subscriber = TestEnvironment().newManualSubscriber(publisher) + + Assert.assertEquals("A", subscriber.requestNextElement()) + + subscriber.cancel() + Thread.sleep(100) + Assert.assertTrue(channel.isClosedForReceive) + } + + @Test + fun `canceling underlying Kotlin coroutine channel also cancels subscription Publisher`() { + // setup + val channel = Channel(10) + channel.offer("A") + channel.close(IllegalStateException("Channel error")) + + val resolver = createFetcher("onDataNameChanged", object : GraphQLResolver { + fun onDataNameChanged(date: DataClass): ReceiveChannel { + return channel + } + }) + + // expect + @Suppress("UNCHECKED_CAST") + val publisher = resolver.get(createEnvironment(DataClass())) as Publisher + val subscriber = TestEnvironment().newManualSubscriber(publisher) + + Assert.assertEquals("A", subscriber.requestNextElement()) + subscriber.expectErrorWithMessage(IllegalStateException::class.java, "Channel error") + } + + private fun createFetcher( + methodName: String, + resolver: GraphQLResolver<*>, + arguments: List = emptyList(), + options: SchemaParserOptions = SchemaParserOptions.defaultOptions() + ): DataFetcher<*> { + val field = FieldDefinition(methodName, TypeName("Boolean")).apply { inputValueDefinitions.addAll(arguments) } + val resolverInfo = if (resolver is GraphQLQueryResolver) { + RootResolverInfo(listOf(resolver), options) + } else { + NormalResolverInfo(resolver, options) + } + return FieldResolverScanner(options).findFieldResolver(field, resolverInfo).createDataFetcher() + } + + private fun createEnvironment(source: Any, arguments: Map = emptyMap(), context: Any? = null): DataFetchingEnvironment { + return DataFetchingEnvironmentBuilder.newDataFetchingEnvironment() + .source(source) + .arguments(arguments) + .context(context) + .executionContext(buildExecutionContext()) + .build() + } + + private fun buildExecutionContext(): ExecutionContext { + val executionStrategy = object : ExecutionStrategy() { + override fun execute(executionContext: ExecutionContext, parameters: ExecutionStrategyParameters): CompletableFuture { + throw AssertionError("should not be called") + } + } + val executionId = ExecutionId.from("executionId123") + return ExecutionContextBuilder.newExecutionContextBuilder() + .instrumentation(SimpleInstrumentation.INSTANCE) + .executionId(executionId) + .queryStrategy(executionStrategy) + .mutationStrategy(executionStrategy) + .subscriptionStrategy(executionStrategy) + .build() + } + + data class DataClass(val name: String = "TestName") +} \ No newline at end of file