Skip to content

Commit fa6c543

Browse files
committed
Add support for Kotlin suspend functions
1 parent d33cc47 commit fa6c543

File tree

7 files changed

+161
-11
lines changed

7 files changed

+161
-11
lines changed

pom.xml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
<properties>
1515
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
1616
<java.version>1.8</java.version>
17-
<kotlin.version>1.2.71</kotlin.version>
17+
<kotlin.version>1.3.0</kotlin.version>
1818
<jackson.version>2.9.6</jackson.version>
1919

2020
<maven.compiler.source>${java.version}</maven.compiler.source>
@@ -29,6 +29,16 @@
2929
<artifactId>kotlin-stdlib</artifactId>
3030
<version>${kotlin.version}</version>
3131
</dependency>
32+
<dependency>
33+
<groupId>org.jetbrains.kotlin</groupId>
34+
<artifactId>kotlin-reflect</artifactId>
35+
<version>${kotlin.version}</version>
36+
</dependency>
37+
<dependency>
38+
<groupId>org.jetbrains.kotlinx</groupId>
39+
<artifactId>kotlinx-coroutines-jdk8</artifactId>
40+
<version>1.0.0</version>
41+
</dependency>
3242
<dependency>
3343
<groupId>com.graphql-java</groupId>
3444
<artifactId>graphql-java</artifactId>

src/main/kotlin/com/coxautodev/graphql/tools/FieldResolverScanner.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import org.apache.commons.lang3.reflect.FieldUtils
99
import org.slf4j.LoggerFactory
1010
import java.lang.reflect.Modifier
1111
import java.lang.reflect.ParameterizedType
12+
import kotlin.reflect.full.valueParameters
13+
import kotlin.reflect.jvm.kotlinFunction
1214

1315
/**
1416
* @author Andrew Potter
@@ -112,7 +114,9 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
112114
true
113115
}
114116

115-
val correctParameterCount = method.parameterCount == requiredCount || (method.parameterCount == (requiredCount + 1) && allowedLastArgumentTypes.contains(method.parameterTypes.last()))
117+
val correctParameterCount = method.parameterCount == requiredCount ||
118+
(method.parameterCount == (requiredCount + 1) && allowedLastArgumentTypes.contains(method.parameterTypes.last())) ||
119+
(method.kotlinFunction?.run { isSuspend && valueParameters.size == requiredCount } == true)
116120
return correctParameterCount && appropriateFirstParameter
117121
}
118122

src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@ import graphql.language.FieldDefinition
88
import graphql.language.NonNullType
99
import graphql.schema.DataFetcher
1010
import graphql.schema.DataFetchingEnvironment
11+
import kotlinx.coroutines.GlobalScope
12+
import kotlinx.coroutines.future.future
1113
import java.lang.reflect.Method
12-
import java.lang.reflect.ParameterizedType
13-
import java.lang.reflect.TypeVariable
1414
import java.util.*
15+
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
16+
import kotlin.reflect.full.valueParameters
17+
import kotlin.reflect.jvm.javaType
18+
import kotlin.reflect.jvm.kotlinFunction
1519

1620
/**
1721
* @author Andrew Potter
@@ -31,7 +35,7 @@ internal class MethodFieldResolver(field: FieldDefinition, search: FieldResolver
3135
}
3236
}
3337

34-
private val additionalLastArgument = method.parameterCount == (field.inputValueDefinitions.size + getIndexOffset() + 1)
38+
private val additionalLastArgument = method.kotlinFunction?.valueParameters?.size ?: method.parameterCount == (field.inputValueDefinitions.size + getIndexOffset() + 1)
3539

3640
override fun createDataFetcher(): DataFetcher<*> {
3741
val batched = isBatched(method, search)
@@ -99,7 +103,7 @@ internal class MethodFieldResolver(field: FieldDefinition, search: FieldResolver
99103

100104
override fun scanForMatches(): List<TypeClassMatcher.PotentialMatch> {
101105
val batched = isBatched(method, search)
102-
val unwrappedGenericType = genericType.unwrapGenericType(method.genericReturnType)
106+
val unwrappedGenericType = genericType.unwrapGenericType(method.kotlinFunction?.returnType?.javaType ?: method.returnType)
103107
val returnValueMatch = TypeClassMatcher.PotentialMatch.returnValue(field.type, unwrappedGenericType, genericType, SchemaClassScanner.ReturnValueReference(method), batched)
104108

105109
return field.inputValueDefinitions.mapIndexed { i, inputDefinition ->
@@ -136,6 +140,7 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso
136140
// Convert to reflactasm reflection
137141
private val methodAccess = MethodAccess.get(method.declaringClass)!!
138142
private val methodIndex = methodAccess.getIndex(method.name, *method.parameterTypes)
143+
private val isSuspendFunction = method.kotlinFunction?.isSuspend == true
139144

140145
private class CompareGenericWrappers {
141146
companion object : Comparator<GenericWrapper> {
@@ -149,9 +154,18 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso
149154
override fun get(environment: DataFetchingEnvironment): Any? {
150155
val source = sourceResolver(environment)
151156
val args = this.args.map { it(environment) }.toTypedArray()
152-
val result = methodAccess.invoke(source, methodIndex, *args)
157+
158+
val result = if (isSuspendFunction) {
159+
GlobalScope.future(options.coroutineContext) {
160+
suspendCoroutineUninterceptedOrReturn<Any?> { continuation ->
161+
methodAccess.invoke(source, methodIndex, *args + continuation)
162+
}
163+
}
164+
} else {
165+
methodAccess.invoke(source, methodIndex, *args)
166+
}
153167
return if (result == null) {
154-
result
168+
null
155169
} else {
156170
val wrapper = options.genericWrappers
157171
.asSequence()

src/main/kotlin/com/coxautodev/graphql/tools/SchemaParserBuilder.kt

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@ import graphql.language.Document
88
import graphql.parser.Parser
99
import graphql.schema.DataFetchingEnvironment
1010
import graphql.schema.GraphQLScalarType
11+
import kotlinx.coroutines.Dispatchers
1112
import org.antlr.v4.runtime.RecognitionException
1213
import org.antlr.v4.runtime.misc.ParseCancellationException
1314
import org.reactivestreams.Publisher
1415
import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl
1516
import java.util.concurrent.CompletableFuture
1617
import java.util.concurrent.CompletionStage
1718
import java.util.concurrent.Future
19+
import kotlin.coroutines.CoroutineContext
1820
import kotlin.reflect.KClass
1921

2022
/**
@@ -247,7 +249,16 @@ class SchemaParserDictionary {
247249
}
248250
}
249251

250-
data class SchemaParserOptions internal constructor(val contextClass: Class<*>?, val genericWrappers: List<GenericWrapper>, val allowUnimplementedResolvers: Boolean, val objectMapperProvider: PerFieldObjectMapperProvider, val proxyHandlers: List<ProxyHandler>, val preferGraphQLResolver: Boolean, val introspectionEnabled: Boolean) {
252+
data class SchemaParserOptions internal constructor(
253+
val contextClass: Class<*>?,
254+
val genericWrappers: List<GenericWrapper>,
255+
val allowUnimplementedResolvers: Boolean,
256+
val objectMapperProvider: PerFieldObjectMapperProvider,
257+
val proxyHandlers: List<ProxyHandler>,
258+
val preferGraphQLResolver: Boolean,
259+
val introspectionEnabled: Boolean,
260+
val coroutineContext: CoroutineContext
261+
) {
251262
companion object {
252263
@JvmStatic
253264
fun newOptions() = Builder()
@@ -265,6 +276,7 @@ data class SchemaParserOptions internal constructor(val contextClass: Class<*>?,
265276
private val proxyHandlers: MutableList<ProxyHandler> = mutableListOf(Spring4AopProxyHandler(), GuiceAopProxyHandler(), JavassistProxyHandler())
266277
private var preferGraphQLResolver = false
267278
private var introspectionEnabled = true
279+
private var coroutineContext: CoroutineContext? = null
268280

269281
fun contextClass(contextClass: Class<*>) = this.apply {
270282
this.contextClass = contextClass
@@ -314,6 +326,10 @@ data class SchemaParserOptions internal constructor(val contextClass: Class<*>?,
314326
this.introspectionEnabled = introspectionEnabled
315327
}
316328

329+
fun coroutineContext(context: CoroutineContext) = this.apply {
330+
this.coroutineContext = context
331+
}
332+
317333
fun build(): SchemaParserOptions {
318334
val wrappers = if (useDefaultGenericWrappers) {
319335
genericWrappers + listOf(
@@ -326,7 +342,9 @@ data class SchemaParserOptions internal constructor(val contextClass: Class<*>?,
326342
genericWrappers
327343
}
328344

329-
return SchemaParserOptions(contextClass, wrappers, allowUnimplementedResolvers, objectMapperProvider, proxyHandlers, preferGraphQLResolver, introspectionEnabled)
345+
return SchemaParserOptions(contextClass, wrappers, allowUnimplementedResolvers, objectMapperProvider,
346+
proxyHandlers, preferGraphQLResolver, introspectionEnabled,
347+
coroutineContext ?: Dispatchers.Default)
330348
}
331349
}
332350

src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,4 +560,21 @@ class EndToEndSpec extends Specification {
560560
then:
561561
data.dataFetcherResult.name == "item1"
562562
}
563+
564+
def "generated schema supports Kotlin suspend functions"() {
565+
when:
566+
def data = Utils.assertNoGraphQlErrors(gql) {
567+
'''
568+
{
569+
coroutineItems {
570+
id
571+
name
572+
}
573+
}
574+
'''
575+
}
576+
577+
then:
578+
data.coroutineItems == [[id:0, name:"item1"], [id:1, name:"item2"]]
579+
}
563580
}

src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import graphql.language.StringValue
77
import graphql.schema.Coercing
88
import graphql.schema.DataFetchingEnvironment
99
import graphql.schema.GraphQLScalarType
10+
import kotlinx.coroutines.CompletableDeferred
1011
import org.reactivestreams.Publisher
1112
import java.util.Optional
1213
import java.util.UUID
@@ -73,6 +74,8 @@ type Query {
7374
7475
propertyField: String!
7576
dataFetcherResult: Item!
77+
78+
coroutineItems: [Item!]!
7679
}
7780
7881
type ExtendedType {
@@ -268,12 +271,13 @@ class Query: GraphQLQueryResolver, ListListResolver<String>() {
268271
fun propertyMapWithComplexItems() = propertyMapWithComplexItems
269272
fun propertyMapWithNestedComplexItems() = propertyMapWithNestedComplexItems
270273

271-
272274
private val propertyField = "test"
273275

274276
fun dataFetcherResult(): DataFetcherResult<Item> {
275277
return DataFetcherResult(items.first(), listOf())
276278
}
279+
280+
suspend fun coroutineItems(): List<Item> = CompletableDeferred(items).await()
277281
}
278282

279283
class UnusedRootResolver: GraphQLQueryResolver
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package com.coxautodev.graphql.tools
2+
3+
import graphql.ExecutionResult
4+
import graphql.execution.*
5+
import graphql.execution.instrumentation.SimpleInstrumentation
6+
import graphql.language.FieldDefinition
7+
import graphql.language.InputValueDefinition
8+
import graphql.language.TypeName
9+
import graphql.schema.DataFetcher
10+
import graphql.schema.DataFetchingEnvironment
11+
import graphql.schema.DataFetchingEnvironmentBuilder
12+
import kotlinx.coroutines.Dispatchers
13+
import kotlinx.coroutines.Job
14+
import org.junit.Assert
15+
import org.junit.Test
16+
import java.util.concurrent.CompletableFuture
17+
import kotlin.coroutines.coroutineContext
18+
19+
class MethodFieldResolverDataFetcherTest {
20+
@Test
21+
fun `data fetcher executes suspend function on coroutineContext defined by options`() {
22+
// setup
23+
val dispatcher = Dispatchers.IO
24+
val job = Job()
25+
val options = SchemaParserOptions.Builder()
26+
.coroutineContext(dispatcher + job)
27+
.build()
28+
29+
val resolver = createFetcher("active", object : GraphQLResolver<DataClass> {
30+
suspend fun isActive(data: DataClass): Boolean {
31+
return coroutineContext[dispatcher.key] == dispatcher &&
32+
coroutineContext[Job] == job.children.first()
33+
}
34+
}, options = options)
35+
36+
// expect
37+
@Suppress("UNCHECKED_CAST")
38+
val future = resolver.get(createEnvironment(DataClass())) as CompletableFuture<Boolean>
39+
Assert.assertTrue(future.get())
40+
}
41+
42+
private fun createFetcher(
43+
methodName: String,
44+
resolver: GraphQLResolver<*>,
45+
arguments: List<InputValueDefinition> = emptyList(),
46+
options: SchemaParserOptions = SchemaParserOptions.defaultOptions()
47+
): DataFetcher<*> {
48+
val field = FieldDefinition(methodName, TypeName("Boolean")).apply { inputValueDefinitions.addAll(arguments) }
49+
val resolverInfo = if (resolver is GraphQLQueryResolver) {
50+
RootResolverInfo(listOf(resolver), options)
51+
} else {
52+
NormalResolverInfo(resolver, options)
53+
}
54+
return FieldResolverScanner(options).findFieldResolver(field, resolverInfo).createDataFetcher()
55+
}
56+
57+
private fun createEnvironment(source: Any, arguments: Map<String, Any> = emptyMap(), context: Any? = null): DataFetchingEnvironment {
58+
return DataFetchingEnvironmentBuilder.newDataFetchingEnvironment()
59+
.source(source)
60+
.arguments(arguments)
61+
.context(context)
62+
.executionContext(buildExecutionContext())
63+
.build()
64+
}
65+
66+
private fun buildExecutionContext(): ExecutionContext {
67+
val executionStrategy = object : ExecutionStrategy() {
68+
override fun execute(executionContext: ExecutionContext, parameters: ExecutionStrategyParameters): CompletableFuture<ExecutionResult> {
69+
throw AssertionError("should not be called")
70+
}
71+
}
72+
val executionId = ExecutionId.from("executionId123")
73+
return ExecutionContextBuilder.newExecutionContextBuilder()
74+
.instrumentation(SimpleInstrumentation.INSTANCE)
75+
.executionId(executionId)
76+
.queryStrategy(executionStrategy)
77+
.mutationStrategy(executionStrategy)
78+
.subscriptionStrategy(executionStrategy)
79+
.build()
80+
}
81+
82+
data class DataClass(val name: String = "TestName")
83+
}

0 commit comments

Comments
 (0)