Skip to content

Kotlin coroutines support #201

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 4, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<java.version>1.8</java.version>
<kotlin.version>1.2.71</kotlin.version>
<kotlin.version>1.3.0</kotlin.version>
<jackson.version>2.9.6</jackson.version>

<maven.compiler.source>${java.version}</maven.compiler.source>
Expand All @@ -29,6 +29,21 @@
<artifactId>kotlin-stdlib</artifactId>
<version>${kotlin.version}</version>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-reflect</artifactId>
<version>${kotlin.version}</version>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlinx</groupId>
<artifactId>kotlinx-coroutines-jdk8</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlinx</groupId>
<artifactId>kotlinx-coroutines-reactive</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>com.graphql-java</groupId>
<artifactId>graphql-java</artifactId>
Expand Down Expand Up @@ -110,6 +125,12 @@
<version>2.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.reactivestreams</groupId>
<artifactId>reactive-streams-tck</artifactId>
<version>1.0.2</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import org.apache.commons.lang3.reflect.FieldUtils
import org.slf4j.LoggerFactory
import java.lang.reflect.Modifier
import java.lang.reflect.ParameterizedType
import kotlin.reflect.full.valueParameters
import kotlin.reflect.jvm.javaType
import kotlin.reflect.jvm.kotlinFunction

/**
* @author Andrew Potter
Expand Down Expand Up @@ -112,7 +115,11 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
true
}

val correctParameterCount = method.parameterCount == requiredCount || (method.parameterCount == (requiredCount + 1) && allowedLastArgumentTypes.contains(method.parameterTypes.last()))
val methodParameterCount = method.kotlinFunction?.valueParameters?.size ?: method.parameterCount
val methodLastParameter = method.kotlinFunction?.valueParameters?.lastOrNull()?.type?.javaType ?: method.parameterTypes.lastOrNull()

val correctParameterCount = methodParameterCount == requiredCount ||
(methodParameterCount == (requiredCount + 1) && allowedLastArgumentTypes.contains(methodLastParameter))
return correctParameterCount && appropriateFirstParameter
}

Expand Down
47 changes: 30 additions & 17 deletions src/main/kotlin/com/coxautodev/graphql/tools/MethodFieldResolver.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@ import graphql.language.FieldDefinition
import graphql.language.NonNullType
import graphql.schema.DataFetcher
import graphql.schema.DataFetchingEnvironment
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.future.future
import java.lang.reflect.Method
import java.lang.reflect.ParameterizedType
import java.lang.reflect.TypeVariable
import java.util.*
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
import kotlin.reflect.full.valueParameters
import kotlin.reflect.jvm.javaType
import kotlin.reflect.jvm.kotlinFunction

/**
* @author Andrew Potter
Expand All @@ -31,7 +35,7 @@ internal class MethodFieldResolver(field: FieldDefinition, search: FieldResolver
}
}

private val additionalLastArgument = method.parameterCount == (field.inputValueDefinitions.size + getIndexOffset() + 1)
private val additionalLastArgument = method.kotlinFunction?.valueParameters?.size ?: method.parameterCount == (field.inputValueDefinitions.size + getIndexOffset() + 1)

override fun createDataFetcher(): DataFetcher<*> {
val batched = isBatched(method, search)
Expand Down Expand Up @@ -99,7 +103,7 @@ internal class MethodFieldResolver(field: FieldDefinition, search: FieldResolver

override fun scanForMatches(): List<TypeClassMatcher.PotentialMatch> {
val batched = isBatched(method, search)
val unwrappedGenericType = genericType.unwrapGenericType(method.genericReturnType)
val unwrappedGenericType = genericType.unwrapGenericType(method.kotlinFunction?.returnType?.javaType ?: method.returnType)
val returnValueMatch = TypeClassMatcher.PotentialMatch.returnValue(field.type, unwrappedGenericType, genericType, SchemaClassScanner.ReturnValueReference(method), batched)

return field.inputValueDefinitions.mapIndexed { i, inputDefinition ->
Expand Down Expand Up @@ -136,6 +140,7 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso
// Convert to reflactasm reflection
private val methodAccess = MethodAccess.get(method.declaringClass)!!
private val methodIndex = methodAccess.getIndex(method.name, *method.parameterTypes)
private val isSuspendFunction = method.kotlinFunction?.isSuspend == true

private class CompareGenericWrappers {
companion object : Comparator<GenericWrapper> {
Expand All @@ -149,23 +154,25 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso
override fun get(environment: DataFetchingEnvironment): Any? {
val source = sourceResolver(environment)
val args = this.args.map { it(environment) }.toTypedArray()
val result = methodAccess.invoke(source, methodIndex, *args)
return if (result == null) {
result
} else {
val wrapper = options.genericWrappers
.asSequence()
.filter { it.type.isInstance(result) }
.sortedWith(CompareGenericWrappers)
.firstOrNull()
if (wrapper == null) {
result
} else {
wrapper.transformer.invoke(result, environment)

return if (isSuspendFunction) {
GlobalScope.future(options.coroutineContext) {
methodAccess.invokeSuspend(source, methodIndex, args)?.transformWithGenericWrapper(environment)
}
} else {
methodAccess.invoke(source, methodIndex, *args)?.transformWithGenericWrapper(environment)
}
}

private fun Any.transformWithGenericWrapper(environment: DataFetchingEnvironment): Any? {
return options.genericWrappers
.asSequence()
.filter { it.type.isInstance(this) }
.sortedWith(CompareGenericWrappers)
.firstOrNull()
?.transformer?.invoke(this, environment) ?: this
}

/**
* Function that return the object used to fetch the data
* It can be a DataFetcher or an entity
Expand All @@ -176,6 +183,12 @@ open class MethodFieldResolverDataFetcher(private val sourceResolver: SourceReso
}
}

private suspend inline fun MethodAccess.invokeSuspend(target: Any, methodIndex: Int, args: Array<Any?>): Any? {
return suspendCoroutineUninterceptedOrReturn { continuation ->
invoke(target, methodIndex, *args + continuation)
}
}

class BatchedMethodFieldResolverDataFetcher(sourceResolver: SourceResolver, method: Method, args: List<ArgumentPlaceholder>, options: SchemaParserOptions) : MethodFieldResolverDataFetcher(sourceResolver, method, args, options) {
@Batched
override fun get(environment: DataFetchingEnvironment) = super.get(environment)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,18 @@ import graphql.language.Document
import graphql.parser.Parser
import graphql.schema.DataFetchingEnvironment
import graphql.schema.GraphQLScalarType
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.reactive.publish
import org.antlr.v4.runtime.RecognitionException
import org.antlr.v4.runtime.misc.ParseCancellationException
import org.reactivestreams.Publisher
import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionStage
import java.util.concurrent.Future
import kotlin.coroutines.CoroutineContext
import kotlin.reflect.KClass

/**
Expand Down Expand Up @@ -247,7 +252,16 @@ class SchemaParserDictionary {
}
}

data class SchemaParserOptions internal constructor(val contextClass: Class<*>?, val genericWrappers: List<GenericWrapper>, val allowUnimplementedResolvers: Boolean, val objectMapperProvider: PerFieldObjectMapperProvider, val proxyHandlers: List<ProxyHandler>, val preferGraphQLResolver: Boolean, val introspectionEnabled: Boolean) {
data class SchemaParserOptions internal constructor(
val contextClass: Class<*>?,
val genericWrappers: List<GenericWrapper>,
val allowUnimplementedResolvers: Boolean,
val objectMapperProvider: PerFieldObjectMapperProvider,
val proxyHandlers: List<ProxyHandler>,
val preferGraphQLResolver: Boolean,
val introspectionEnabled: Boolean,
val coroutineContext: CoroutineContext
) {
companion object {
@JvmStatic
fun newOptions() = Builder()
Expand All @@ -265,6 +279,7 @@ data class SchemaParserOptions internal constructor(val contextClass: Class<*>?,
private val proxyHandlers: MutableList<ProxyHandler> = mutableListOf(Spring4AopProxyHandler(), GuiceAopProxyHandler(), JavassistProxyHandler())
private var preferGraphQLResolver = false
private var introspectionEnabled = true
private var coroutineContext: CoroutineContext? = null

fun contextClass(contextClass: Class<*>) = this.apply {
this.contextClass = contextClass
Expand Down Expand Up @@ -314,19 +329,36 @@ data class SchemaParserOptions internal constructor(val contextClass: Class<*>?,
this.introspectionEnabled = introspectionEnabled
}

fun coroutineContext(context: CoroutineContext) = this.apply {
this.coroutineContext = context
}

fun build(): SchemaParserOptions {
val coroutineContext = coroutineContext ?: Dispatchers.Default
val wrappers = if (useDefaultGenericWrappers) {
genericWrappers + listOf(
GenericWrapper(Future::class, 0),
GenericWrapper(CompletableFuture::class, 0),
GenericWrapper(CompletionStage::class, 0),
GenericWrapper(Publisher::class, 0)
GenericWrapper(Publisher::class, 0),
GenericWrapper.withTransformer(ReceiveChannel::class, 0, { receiveChannel ->
GlobalScope.publish(coroutineContext) {
try {
for (item in receiveChannel) {
send(item)
}
} finally {
receiveChannel.cancel()
}
}
})
)
} else {
genericWrappers
}

return SchemaParserOptions(contextClass, wrappers, allowUnimplementedResolvers, objectMapperProvider, proxyHandlers, preferGraphQLResolver, introspectionEnabled)
return SchemaParserOptions(contextClass, wrappers, allowUnimplementedResolvers, objectMapperProvider,
proxyHandlers, preferGraphQLResolver, introspectionEnabled, coroutineContext)
}
}

Expand Down
56 changes: 56 additions & 0 deletions src/test/groovy/com/coxautodev/graphql/tools/EndToEndSpec.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import graphql.execution.batched.BatchedExecutionStrategy
import graphql.schema.GraphQLSchema
import org.reactivestreams.Publisher
import org.reactivestreams.Subscriber
import org.reactivestreams.tck.TestEnvironment
import spock.lang.Shared
import spock.lang.Specification

Expand Down Expand Up @@ -560,4 +561,59 @@ class EndToEndSpec extends Specification {
then:
data.dataFetcherResult.name == "item1"
}

def "generated schema supports Kotlin suspend functions"() {
when:
def data = Utils.assertNoGraphQlErrors(gql) {
'''
{
coroutineItems {
id
name
}
}
'''
}

then:
data.coroutineItems == [[id:0, name:"item1"], [id:1, name:"item2"]]
}

def "generated schema supports Kotlin coroutine channels for the subscription query"() {
when:
def newItem = new Item(1, "item", Type.TYPE_1, UUID.randomUUID(), [])
def data = Utils.assertNoGraphQlErrors(gql, [:], new OnItemCreatedContext(newItem)) {
'''
subscription {
onItemCreatedCoroutineChannel {
id
}
}
'''
}
def subscriber = new TestEnvironment().newManualSubscriber(data as Publisher<ExecutionResult>)

then:
subscriber.requestNextElement().data.get("onItemCreatedCoroutineChannel").id == 1
subscriber.expectCompletion()
}

def "generated schema supports Kotlin coroutine channels with suspend function for the subscription query"() {
when:
def newItem = new Item(1, "item", Type.TYPE_1, UUID.randomUUID(), [])
def data = Utils.assertNoGraphQlErrors(gql, [:], new OnItemCreatedContext(newItem)) {
'''
subscription {
onItemCreatedCoroutineChannelAndSuspendFunction {
id
}
}
'''
}
def subscriber = new TestEnvironment().newManualSubscriber(data as Publisher<ExecutionResult>)

then:
subscriber.requestNextElement().data.get("onItemCreatedCoroutineChannelAndSuspendFunction").id == 1
subscriber.expectCompletion()
}
}
24 changes: 23 additions & 1 deletion src/test/kotlin/com/coxautodev/graphql/tools/EndToEndSpec.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import graphql.language.StringValue
import graphql.schema.Coercing
import graphql.schema.DataFetchingEnvironment
import graphql.schema.GraphQLScalarType
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import org.reactivestreams.Publisher
import java.util.Optional
import java.util.UUID
Expand Down Expand Up @@ -73,6 +76,8 @@ type Query {

propertyField: String!
dataFetcherResult: Item!

coroutineItems: [Item!]!
}

type ExtendedType {
Expand Down Expand Up @@ -104,6 +109,8 @@ type Mutation {

type Subscription {
onItemCreated: Item!
onItemCreatedCoroutineChannel: Item!
onItemCreatedCoroutineChannelAndSuspendFunction: Item!
}

input ItemSearchInput {
Expand Down Expand Up @@ -268,12 +275,13 @@ class Query: GraphQLQueryResolver, ListListResolver<String>() {
fun propertyMapWithComplexItems() = propertyMapWithComplexItems
fun propertyMapWithNestedComplexItems() = propertyMapWithNestedComplexItems


private val propertyField = "test"

fun dataFetcherResult(): DataFetcherResult<Item> {
return DataFetcherResult(items.first(), listOf())
}

suspend fun coroutineItems(): List<Item> = CompletableDeferred(items).await()
}

class UnusedRootResolver: GraphQLQueryResolver
Expand Down Expand Up @@ -302,6 +310,20 @@ class Subscription : GraphQLSubscriptionResolver {
subscriber.onNext(env.getContext<OnItemCreatedContext>().newItem)
// subscriber.onComplete()
}

fun onItemCreatedCoroutineChannel(env: DataFetchingEnvironment): ReceiveChannel<Item> {
val channel = Channel<Item>(1)
channel.offer(env.getContext<OnItemCreatedContext>().newItem)
return channel
}

suspend fun onItemCreatedCoroutineChannelAndSuspendFunction(env: DataFetchingEnvironment): ReceiveChannel<Item> {
return coroutineScope {
val channel = Channel<Item>(1)
channel.offer(env.getContext<OnItemCreatedContext>().newItem)
channel
}
}
}

class ItemResolver : GraphQLResolver<Item> {
Expand Down
Loading