diff --git a/graphql-java-spring-webflux/src/main/java/graphql/spring/web/reactive/components/DefaultGraphQLInvocation.java b/graphql-java-spring-webflux/src/main/java/graphql/spring/web/reactive/components/DefaultGraphQLInvocation.java index 196c835..b3669f4 100644 --- a/graphql-java-spring-webflux/src/main/java/graphql/spring/web/reactive/components/DefaultGraphQLInvocation.java +++ b/graphql-java-spring-webflux/src/main/java/graphql/spring/web/reactive/components/DefaultGraphQLInvocation.java @@ -7,6 +7,7 @@ import graphql.spring.web.reactive.ExecutionInputCustomizer; import graphql.spring.web.reactive.GraphQLInvocation; import graphql.spring.web.reactive.GraphQLInvocationData; +import org.dataloader.DataLoaderRegistry; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import org.springframework.web.server.ServerWebExchange; @@ -19,16 +20,22 @@ public class DefaultGraphQLInvocation implements GraphQLInvocation { @Autowired GraphQL graphQL; + @Autowired(required = false) + DataLoaderRegistry dataLoaderRegistry; + @Autowired ExecutionInputCustomizer executionInputCustomizer; @Override public Mono invoke(GraphQLInvocationData invocationData, ServerWebExchange serverWebExchange) { - ExecutionInput executionInput = ExecutionInput.newExecutionInput() + ExecutionInput.Builder executionInputBuilder = ExecutionInput.newExecutionInput() .query(invocationData.getQuery()) .operationName(invocationData.getOperationName()) - .variables(invocationData.getVariables()) - .build(); + .variables(invocationData.getVariables()); + if (dataLoaderRegistry != null) { + executionInputBuilder.dataLoaderRegistry(dataLoaderRegistry); + } + ExecutionInput executionInput = executionInputBuilder.build(); Mono customizedExecutionInputMono = executionInputCustomizer.customizeExecutionInput(executionInput, serverWebExchange); return customizedExecutionInputMono.flatMap(customizedExecutionInput -> Mono.fromCompletionStage(graphQL.executeAsync(customizedExecutionInput))); } diff --git a/graphql-java-spring-webmvc/src/main/java/graphql/spring/web/servlet/components/DefaultGraphQLInvocation.java b/graphql-java-spring-webmvc/src/main/java/graphql/spring/web/servlet/components/DefaultGraphQLInvocation.java index 9e5dfe7..d96cff9 100644 --- a/graphql-java-spring-webmvc/src/main/java/graphql/spring/web/servlet/components/DefaultGraphQLInvocation.java +++ b/graphql-java-spring-webmvc/src/main/java/graphql/spring/web/servlet/components/DefaultGraphQLInvocation.java @@ -7,6 +7,7 @@ import graphql.spring.web.servlet.ExecutionInputCustomizer; import graphql.spring.web.servlet.GraphQLInvocation; import graphql.spring.web.servlet.GraphQLInvocationData; +import org.dataloader.DataLoaderRegistry; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import org.springframework.web.context.request.WebRequest; @@ -20,16 +21,22 @@ public class DefaultGraphQLInvocation implements GraphQLInvocation { @Autowired GraphQL graphQL; + @Autowired(required = false) + DataLoaderRegistry dataLoaderRegistry; + @Autowired ExecutionInputCustomizer executionInputCustomizer; @Override public CompletableFuture invoke(GraphQLInvocationData invocationData, WebRequest webRequest) { - ExecutionInput executionInput = ExecutionInput.newExecutionInput() + ExecutionInput.Builder executionInputBuilder = ExecutionInput.newExecutionInput() .query(invocationData.getQuery()) .operationName(invocationData.getOperationName()) - .variables(invocationData.getVariables()) - .build(); + .variables(invocationData.getVariables()); + if (dataLoaderRegistry != null) { + executionInputBuilder.dataLoaderRegistry(dataLoaderRegistry); + } + ExecutionInput executionInput = executionInputBuilder.build(); CompletableFuture customizedExecutionInput = executionInputCustomizer.customizeExecutionInput(executionInput, webRequest); return customizedExecutionInput.thenCompose(graphQL::executeAsync); }