Skip to content

Add listener for errors that occur in parsing InvocationInput #439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package graphql.kickstart.servlet;

import com.fasterxml.jackson.core.JsonProcessingException;
import graphql.GraphQLException;
import graphql.kickstart.execution.input.GraphQLInvocationInput;
import java.io.IOException;
Expand All @@ -15,7 +14,7 @@ class HttpRequestHandlerImpl implements HttpRequestHandler {
private final GraphQLConfiguration configuration;
private final HttpRequestInvoker requestInvoker;

public HttpRequestHandlerImpl(GraphQLConfiguration configuration) {
HttpRequestHandlerImpl(GraphQLConfiguration configuration) {
this(
configuration,
new HttpRequestInvokerImpl(
Expand All @@ -24,28 +23,30 @@ public HttpRequestHandlerImpl(GraphQLConfiguration configuration) {
new QueryResponseWriterFactoryImpl()));
}

public HttpRequestHandlerImpl(
HttpRequestHandlerImpl(
GraphQLConfiguration configuration, HttpRequestInvoker requestInvoker) {
this.configuration = configuration;
this.requestInvoker = requestInvoker;
}

@Override
public void handle(HttpServletRequest request, HttpServletResponse response) throws IOException {
if (request.getCharacterEncoding() == null) {
request.setCharacterEncoding(StandardCharsets.UTF_8.name());
}

ListenerHandler listenerHandler =
ListenerHandler.start(request, response, configuration.getListeners());

try {
if (request.getCharacterEncoding() == null) {
request.setCharacterEncoding(StandardCharsets.UTF_8.name());
}
GraphQLInvocationInputParser invocationInputParser =
GraphQLInvocationInputParser.create(
request,
configuration.getInvocationInputFactory(),
configuration.getObjectMapper(),
configuration.getContextSetting());
GraphQLInvocationInput invocationInput =
invocationInputParser. getGraphQLInvocationInput(request, response);
requestInvoker.execute(invocationInput, request, response);
} catch (GraphQLException | JsonProcessingException e) {
GraphQLInvocationInput invocationInput = parseInvocationInput(request, response);
requestInvoker.execute(invocationInput, request, response, listenerHandler);
} catch (InvocationInputParseException e) {
response.setStatus(STATUS_BAD_REQUEST);
log.info("Bad request: cannot parse http request", e);
listenerHandler.onParseError(e);
throw e;
} catch (GraphQLException e) {
response.setStatus(STATUS_BAD_REQUEST);
log.info("Bad request: cannot handle http request", e);
throw e;
Expand All @@ -55,4 +56,20 @@ public void handle(HttpServletRequest request, HttpServletResponse response) thr
throw t;
}
}

private GraphQLInvocationInput parseInvocationInput(
HttpServletRequest request,
HttpServletResponse response) {
try {
GraphQLInvocationInputParser invocationInputParser =
GraphQLInvocationInputParser.create(
request,
configuration.getInvocationInputFactory(),
configuration.getObjectMapper(),
configuration.getContextSetting());
return invocationInputParser.getGraphQLInvocationInput(request, response);
} catch (Exception e) {
throw new InvocationInputParseException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ public interface HttpRequestInvoker {
void execute(
GraphQLInvocationInput invocationInput,
HttpServletRequest request,
HttpServletResponse response);
HttpServletResponse response,
ListenerHandler listenerHandler);
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@ public class HttpRequestInvokerImpl implements HttpRequestInvoker {
public void execute(
GraphQLInvocationInput invocationInput,
HttpServletRequest request,
HttpServletResponse response) {
ListenerHandler listenerHandler =
ListenerHandler.start(request, response, configuration.getListeners());
HttpServletResponse response,
ListenerHandler listenerHandler) {
if (request.isAsyncSupported()) {
invokeAndHandleAsync(invocationInput, request, response, listenerHandler);
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package graphql.kickstart.servlet;

public class InvocationInputParseException extends RuntimeException {

public InvocationInputParseException(Throwable t) {
super("Request parsing failed", t);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@Slf4j
@RequiredArgsConstructor
class ListenerHandler {
public class ListenerHandler {

private final List<RequestCallback> callbacks;
private final HttpServletRequest request;
Expand Down Expand Up @@ -60,6 +60,10 @@ void runCallbacks(Consumer<RequestCallback> action) {
});
}

void onParseError(Throwable throwable) {
runCallbacks(it -> it.onParseError(request, response, throwable));
}

void beforeFlush() {
runCallbacks(it -> it.beforeFlush(request, response));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import graphql.kickstart.servlet.GraphQLConfiguration;
import graphql.kickstart.servlet.HttpRequestInvoker;
import graphql.kickstart.servlet.HttpRequestInvokerImpl;
import graphql.kickstart.servlet.ListenerHandler;
import java.io.IOException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
Expand Down Expand Up @@ -36,11 +37,12 @@ public CachingHttpRequestInvoker(GraphQLConfiguration configuration) {
public void execute(
GraphQLInvocationInput invocationInput,
HttpServletRequest request,
HttpServletResponse response) {
HttpServletResponse response,
ListenerHandler listenerHandler) {
try {
if (!cacheReader.responseFromCache(
invocationInput, request, response, configuration.getResponseCacheManager())) {
requestInvoker.execute(invocationInput, request, response);
requestInvoker.execute(invocationInput, request, response, listenerHandler);
}
} catch (IOException e) {
response.setStatus(STATUS_BAD_REQUEST);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ default RequestCallback onRequest(HttpServletRequest request, HttpServletRespons
*/
interface RequestCallback {

/**
* Called when failed to parse InvocationInput and the response was not written.
* @param request http request
* @param response http response
*/
default void onParseError(
HttpServletRequest request, HttpServletResponse response, Throwable throwable) {}

/**
* Called right before the response will be written and flushed. Can be used for applying some
* changes to the response object, like adding response headers.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,22 @@ b
getResponseContent().data.echoFiles == ["test", "test again"]
}

def "errors while accessing file from the request"() {
setup:
request = Spy(MockHttpServletRequest)
request.setMethod("POST")
request.setContentType("multipart/form-data, boundary=test")
// See https://github.com/apache/tomcat/blob/main/java/org/apache/catalina/connector/Request.java#L2775...L2791
request.getParts() >> { throw new IllegalStateException() }

when:
servlet.doPost(request, response)

then:
response.getStatus() == STATUS_BAD_REQUEST
response.getContentLength() == 0
}

def "batched query over HTTP POST body returns data"() {
setup:
request.setContent('[{ "query": "query { echo(arg:\\"test\\") }" }, { "query": "query { echo(arg:\\"test\\") }" }]'.bytes)
Expand Down Expand Up @@ -1112,8 +1128,7 @@ b
servlet.doGet(request, response)

then:
noExceptionThrown()
response.getStatus() == STATUS_ERROR
response.getStatus() == STATUS_BAD_REQUEST
}

def "errors while data fetching are masked in the response"() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import graphql.kickstart.execution.GraphQLQueryResult
import graphql.kickstart.execution.input.GraphQLSingleInvocationInput
import graphql.kickstart.servlet.GraphQLConfiguration
import graphql.kickstart.servlet.HttpRequestInvoker
import graphql.kickstart.servlet.ListenerHandler
import spock.lang.Specification

import javax.servlet.ServletOutputStream
Expand All @@ -28,6 +29,7 @@ class CachingHttpRequestInvokerTest extends Specification {
def configuration
def graphqlObjectMapper
def outputStreamMock
def listenerHandlerMock

def setup() {
cacheReaderMock = Mock(CacheReader)
Expand All @@ -42,6 +44,7 @@ class CachingHttpRequestInvokerTest extends Specification {
outputStreamMock = Mock(ServletOutputStream)
graphqlInvoker.execute(invocationInputMock) >> FutureExecutionResult.single(invocationInputMock, CompletableFuture.completedFuture(Mock(GraphQLQueryResult)))
cachingInvoker = new CachingHttpRequestInvoker(configuration, httpRequestInvokerMock, cacheReaderMock)
listenerHandlerMock = Mock(ListenerHandler)

configuration.getResponseCacheManager() >> responseCacheManagerMock
configuration.getGraphQLInvoker() >> graphqlInvoker
Expand All @@ -57,29 +60,29 @@ class CachingHttpRequestInvokerTest extends Specification {
cacheReaderMock.responseFromCache(invocationInputMock, requestMock, responseMock, responseCacheManagerMock) >> false

when:
cachingInvoker.execute(invocationInputMock, requestMock, responseMock)
cachingInvoker.execute(invocationInputMock, requestMock, responseMock, listenerHandlerMock)

then:
1 * httpRequestInvokerMock.execute(invocationInputMock, requestMock, responseMock)
1 * httpRequestInvokerMock.execute(invocationInputMock, requestMock, responseMock, listenerHandlerMock)
}

def "should not execute regular invoker if cache exists"() {
given:
cacheReaderMock.responseFromCache(invocationInputMock, requestMock, responseMock, responseCacheManagerMock) >> true

when:
cachingInvoker.execute(invocationInputMock, requestMock, responseMock)
cachingInvoker.execute(invocationInputMock, requestMock, responseMock, listenerHandlerMock)

then:
0 * httpRequestInvokerMock.execute(invocationInputMock, requestMock, responseMock)
0 * httpRequestInvokerMock.execute(invocationInputMock, requestMock, responseMock, listenerHandlerMock)
}

def "should return bad request response when ioexception"() {
given:
cacheReaderMock.responseFromCache(invocationInputMock, requestMock, responseMock, responseCacheManagerMock) >> { throw new IOException() }

when:
cachingInvoker.execute(invocationInputMock, requestMock, responseMock)
cachingInvoker.execute(invocationInputMock, requestMock, responseMock, listenerHandlerMock)

then:
1 * responseMock.setStatus(400)
Expand All @@ -90,7 +93,7 @@ class CachingHttpRequestInvokerTest extends Specification {
def invoker = new CachingHttpRequestInvoker(configuration)

when:
invoker.execute(invocationInputMock, requestMock, responseMock)
invoker.execute(invocationInputMock, requestMock, responseMock, listenerHandlerMock)

then:
noExceptionThrown()
Expand Down