diff --git a/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java b/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java index 3120a216..27cc8023 100644 --- a/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java +++ b/src/main/java/graphql/servlet/AbstractGraphQLHttpServlet.java @@ -365,12 +365,7 @@ private void query(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQL } private void queryBatched(GraphQLQueryInvoker queryInvoker, GraphQLObjectMapper graphQLObjectMapper, GraphQLBatchedInvocationInput invocationInput, HttpServletResponse resp) throws Exception { - resp.setContentType(APPLICATION_JSON_UTF8); - resp.setStatus(STATUS_OK); - - Writer respWriter = resp.getWriter(); - - queryInvoker.query(invocationInput, respWriter, graphQLObjectMapper); + queryInvoker.query(invocationInput, resp, graphQLObjectMapper); } private List runListeners(Function action) { diff --git a/src/main/java/graphql/servlet/BatchExecutionHandler.java b/src/main/java/graphql/servlet/BatchExecutionHandler.java index 850f4843..b07fe5d6 100644 --- a/src/main/java/graphql/servlet/BatchExecutionHandler.java +++ b/src/main/java/graphql/servlet/BatchExecutionHandler.java @@ -3,6 +3,7 @@ import graphql.ExecutionInput; import graphql.ExecutionResult; +import javax.servlet.http.HttpServletResponse; import java.io.Writer; import java.util.function.BiFunction; @@ -16,9 +17,9 @@ public interface BatchExecutionHandler { * @param batchedInvocationInput the batch query input * @param queryFunction Function to produce query results. * @param graphQLObjectMapper object mapper used to serialize results - * @param writer request writer to ouput results. + * @param response http response object */ - void handleBatch(GraphQLBatchedInvocationInput batchedInvocationInput, Writer writer, GraphQLObjectMapper graphQLObjectMapper, + void handleBatch(GraphQLBatchedInvocationInput batchedInvocationInput, HttpServletResponse response, GraphQLObjectMapper graphQLObjectMapper, BiFunction queryFunction); } diff --git a/src/main/java/graphql/servlet/DefaultBatchExecutionHandler.java b/src/main/java/graphql/servlet/DefaultBatchExecutionHandler.java index 6ad9fcf1..fbf92adb 100644 --- a/src/main/java/graphql/servlet/DefaultBatchExecutionHandler.java +++ b/src/main/java/graphql/servlet/DefaultBatchExecutionHandler.java @@ -3,6 +3,7 @@ import graphql.ExecutionInput; import graphql.ExecutionResult; +import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.io.Writer; import java.util.Iterator; @@ -11,10 +12,14 @@ public class DefaultBatchExecutionHandler implements BatchExecutionHandler { @Override - public void handleBatch(GraphQLBatchedInvocationInput batchedInvocationInput, Writer writer, GraphQLObjectMapper graphQLObjectMapper, + public void handleBatch(GraphQLBatchedInvocationInput batchedInvocationInput, HttpServletResponse response, GraphQLObjectMapper graphQLObjectMapper, BiFunction queryFunction) { - Iterator executionInputIterator = batchedInvocationInput.getExecutionInputs().iterator(); + response.setContentType(AbstractGraphQLHttpServlet.APPLICATION_JSON_UTF8); + response.setStatus(AbstractGraphQLHttpServlet.STATUS_OK); try { + Writer writer = response.getWriter(); + Iterator executionInputIterator = batchedInvocationInput.getExecutionInputs().iterator(); + writer.write("["); while (executionInputIterator.hasNext()) { ExecutionResult result = queryFunction.apply(batchedInvocationInput, executionInputIterator.next()); diff --git a/src/main/java/graphql/servlet/GraphQLQueryInvoker.java b/src/main/java/graphql/servlet/GraphQLQueryInvoker.java index 1b21c340..400ee499 100644 --- a/src/main/java/graphql/servlet/GraphQLQueryInvoker.java +++ b/src/main/java/graphql/servlet/GraphQLQueryInvoker.java @@ -13,6 +13,7 @@ import graphql.schema.GraphQLSchema; import javax.security.auth.Subject; +import javax.servlet.http.HttpServletResponse; import java.io.Writer; import java.security.AccessController; import java.security.PrivilegedAction; @@ -43,8 +44,8 @@ public ExecutionResult query(GraphQLSingleInvocationInput singleInvocationInput) return query(singleInvocationInput, singleInvocationInput.getExecutionInput()); } - public void query(GraphQLBatchedInvocationInput batchedInvocationInput, Writer writer, GraphQLObjectMapper graphQLObjectMapper) { - batchExecutionHandler.handleBatch(batchedInvocationInput, writer, graphQLObjectMapper, this::query); + public void query(GraphQLBatchedInvocationInput batchedInvocationInput, HttpServletResponse response, GraphQLObjectMapper graphQLObjectMapper) { + batchExecutionHandler.handleBatch(batchedInvocationInput, response, graphQLObjectMapper, this::query); } private GraphQL newGraphQL(GraphQLSchema schema, Object context) { diff --git a/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy b/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy index 54712c84..27d23eed 100644 --- a/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy +++ b/src/test/groovy/graphql/servlet/AbstractGraphQLHttpServletSpec.groovy @@ -285,7 +285,7 @@ class AbstractGraphQLHttpServletSpec extends Specification { getBatchedResponseContent()[1].data.echo == "test" } - def "Execution Result Handler allows limiting number of queries"() { + def "Batch Execution Handler allows limiting batches and sending error messages."() { setup: servlet = TestUtils.createBatchCustomizedServlet({ env -> env.arguments.arg }, { env -> env.arguments.arg }, { env -> AtomicReference> publisherRef = new AtomicReference<>() @@ -304,9 +304,8 @@ class AbstractGraphQLHttpServletSpec extends Specification { servlet.doGet(request, response) then: - response.getStatus() == STATUS_OK - response.getContentType() == CONTENT_TYPE_JSON_UTF8 - getBatchedResponseContent().size() == 2 + response.getStatus() == STATUS_BAD_REQUEST + response.getErrorMessage() == TestBatchExecutionHandler.BATCH_ERROR_MESSAGE } def "Default Execution Result Handler does not limit number of queries"() { diff --git a/src/test/groovy/graphql/servlet/TestBatchExecutionHandler.java b/src/test/groovy/graphql/servlet/TestBatchExecutionHandler.java index cc42a303..31b5ffdd 100644 --- a/src/test/groovy/graphql/servlet/TestBatchExecutionHandler.java +++ b/src/test/groovy/graphql/servlet/TestBatchExecutionHandler.java @@ -3,6 +3,7 @@ import graphql.ExecutionInput; import graphql.ExecutionResult; +import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.io.Writer; import java.util.Iterator; @@ -12,18 +13,34 @@ public class TestBatchExecutionHandler implements BatchExecutionHandler { + public static String BATCH_ERROR_MESSAGE = "Batch limit exceeded"; + @Override - public void handleBatch(GraphQLBatchedInvocationInput batchedInvocationInput, Writer writer, GraphQLObjectMapper graphQLObjectMapper, + public void handleBatch(GraphQLBatchedInvocationInput batchedInvocationInput, HttpServletResponse response, GraphQLObjectMapper graphQLObjectMapper, BiFunction queryFunction) { - List results = batchedInvocationInput.getExecutionInputs().parallelStream() - .limit(2) + List inputs = batchedInvocationInput.getExecutionInputs(); + if (inputs.size() > 2) { + handleBadInput(response); + } + List results = inputs.parallelStream() .map(input -> queryFunction.apply(batchedInvocationInput, input)) .collect(Collectors.toList()); - writeResults(results, writer, graphQLObjectMapper); + writeResults(results, response, graphQLObjectMapper); + } + + private void handleBadInput(HttpServletResponse response) { + try { + response.sendError(HttpServletResponse.SC_BAD_REQUEST, BATCH_ERROR_MESSAGE); + } catch (IOException e) { + throw new RuntimeException(e); + } } - private void writeResults(List results, Writer writer, GraphQLObjectMapper mapper) { + private void writeResults(List results, HttpServletResponse response, GraphQLObjectMapper mapper) { + response.setContentType(AbstractGraphQLHttpServlet.APPLICATION_JSON_UTF8); + response.setStatus(AbstractGraphQLHttpServlet.STATUS_OK); try { + Writer writer = response.getWriter(); writer.write("["); Iterator iter = results.iterator(); while (iter.hasNext()) {