diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/AbstractGraphQlSourceBuilder.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/AbstractGraphQlSourceBuilder.java index ce0684b96..1fd48b76a 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/AbstractGraphQlSourceBuilder.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/AbstractGraphQlSourceBuilder.java @@ -16,12 +16,6 @@ package org.springframework.graphql.execution; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; - import graphql.GraphQL; import graphql.execution.instrumentation.ChainedInstrumentation; import graphql.execution.instrumentation.Instrumentation; @@ -30,6 +24,9 @@ import graphql.schema.GraphQLTypeVisitor; import graphql.schema.SchemaTraverser; +import java.util.*; +import java.util.function.Consumer; + /** * Implementation of {@link GraphQlSource.Builder} that leaves it to subclasses @@ -43,6 +40,8 @@ abstract class AbstractGraphQlSourceBuilder> private final List exceptionResolvers = new ArrayList<>(); + private final List subscriptionExceptionResolvers = new ArrayList<>(); + private final List typeVisitors = new ArrayList<>(); private final List instrumentations = new ArrayList<>(); @@ -57,6 +56,12 @@ public B exceptionResolvers(List resolvers) { return self(); } + @Override + public B subscriptionExceptionResolvers(List subscriptionExceptionResolvers) { + this.subscriptionExceptionResolvers.addAll(subscriptionExceptionResolvers); + return self(); + } + @Override public B typeVisitors(List typeVisitors) { this.typeVisitors.addAll(typeVisitors); @@ -105,8 +110,12 @@ public GraphQlSource build() { protected abstract GraphQLSchema initGraphQlSchema(); private GraphQLSchema applyTypeVisitors(GraphQLSchema schema) { + SubscriptionExceptionResolver subscriptionExceptionResolver = new DelegatingSubscriptionExceptionResolver( + subscriptionExceptionResolvers); + GraphQLTypeVisitor visitor = ContextDataFetcherDecorator.createVisitor(subscriptionExceptionResolver); + List visitors = new ArrayList<>(this.typeVisitors); - visitors.add(ContextDataFetcherDecorator.TYPE_VISITOR); + visitors.add(visitor); GraphQLCodeRegistry.Builder codeRegistry = GraphQLCodeRegistry.newCodeRegistry(schema.getCodeRegistry()); Map, Object> vars = Collections.singletonMap(GraphQLCodeRegistry.Builder.class, codeRegistry); diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java index 52def63ee..3c1580025 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/ContextDataFetcherDecorator.java @@ -17,22 +17,16 @@ package org.springframework.graphql.execution; import graphql.ExecutionInput; -import graphql.schema.DataFetcher; -import graphql.schema.DataFetchingEnvironment; -import graphql.schema.GraphQLCodeRegistry; -import graphql.schema.GraphQLFieldDefinition; -import graphql.schema.GraphQLFieldsContainer; -import graphql.schema.GraphQLSchemaElement; -import graphql.schema.GraphQLTypeVisitor; -import graphql.schema.GraphQLTypeVisitorStub; +import graphql.schema.*; import graphql.util.TraversalControl; import graphql.util.TraverserContext; import org.reactivestreams.Publisher; +import org.springframework.util.Assert; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.util.context.ContextView; -import org.springframework.util.Assert; +import java.util.function.Function; /** * Wrap a {@link DataFetcher} to enable the following: @@ -51,10 +45,16 @@ final class ContextDataFetcherDecorator implements DataFetcher { private final boolean subscription; - private ContextDataFetcherDecorator(DataFetcher delegate, boolean subscription) { + private final SubscriptionExceptionResolver subscriptionExceptionResolver; + + private ContextDataFetcherDecorator( + DataFetcher delegate, boolean subscription, + SubscriptionExceptionResolver subscriptionExceptionResolver) { Assert.notNull(delegate, "'delegate' DataFetcher is required"); + Assert.notNull(subscriptionExceptionResolver, "'subscriptionExceptionResolver' is required"); this.delegate = delegate; this.subscription = subscription; + this.subscriptionExceptionResolver = subscriptionExceptionResolver; } @Override @@ -66,7 +66,8 @@ public Object get(DataFetchingEnvironment environment) throws Exception { ContextView contextView = ReactorContextManager.getReactorContext(environment.getGraphQlContext()); if (this.subscription) { - return (!contextView.isEmpty() ? Flux.from((Publisher) value).contextWrite(contextView) : value); + Publisher publisher = interceptSubscriptionPublisherWithExceptionHandler((Publisher) value); + return (!contextView.isEmpty() ? Flux.from(publisher).contextWrite(contextView) : publisher); } if (value instanceof Flux) { @@ -84,29 +85,48 @@ public Object get(DataFetchingEnvironment environment) throws Exception { return value; } + @SuppressWarnings("unchecked") + private Publisher interceptSubscriptionPublisherWithExceptionHandler(Publisher publisher) { + Function> onErrorResumeFunction = e -> + subscriptionExceptionResolver.resolveException(e) + .flatMap(errors -> Mono.error(new SubscriptionStreamException(errors))); + + if (publisher instanceof Flux) { + return ((Flux) publisher).onErrorResume(onErrorResumeFunction); + } + + if (publisher instanceof Mono) { + return ((Mono) publisher).onErrorResume(onErrorResumeFunction); + } + + throw new IllegalArgumentException("Unknown publisher type: '" + publisher.getClass().getName() +"'. " + + "Expected reactor.core.publisher.Mono or reactor.core.publisher.Flux"); + } + /** * {@link GraphQLTypeVisitor} that wraps non-GraphQL data fetchers and adapts them if * they return {@link Flux} or {@link Mono}. */ - static GraphQLTypeVisitor TYPE_VISITOR = new GraphQLTypeVisitorStub() { - - @Override - public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition fieldDefinition, - TraverserContext context) { - - GraphQLCodeRegistry.Builder codeRegistry = context.getVarFromParents(GraphQLCodeRegistry.Builder.class); - GraphQLFieldsContainer parent = (GraphQLFieldsContainer) context.getParentNode(); - DataFetcher dataFetcher = codeRegistry.getDataFetcher(parent, fieldDefinition); - - if (dataFetcher.getClass().getPackage().getName().startsWith("graphql.")) { + static GraphQLTypeVisitor createVisitor(SubscriptionExceptionResolver subscriptionExceptionResolver) { + return new GraphQLTypeVisitorStub() { + @Override + public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition fieldDefinition, + TraverserContext context) { + + GraphQLCodeRegistry.Builder codeRegistry = context.getVarFromParents(GraphQLCodeRegistry.Builder.class); + GraphQLFieldsContainer parent = (GraphQLFieldsContainer) context.getParentNode(); + DataFetcher dataFetcher = codeRegistry.getDataFetcher(parent, fieldDefinition); + + if (dataFetcher.getClass().getPackage().getName().startsWith("graphql.")) { + return TraversalControl.CONTINUE; + } + + boolean handlesSubscription = parent.getName().equals("Subscription"); + dataFetcher = new ContextDataFetcherDecorator(dataFetcher, handlesSubscription, subscriptionExceptionResolver); + codeRegistry.dataFetcher(parent, fieldDefinition, dataFetcher); return TraversalControl.CONTINUE; } - - boolean handlesSubscription = parent.getName().equals("Subscription"); - dataFetcher = new ContextDataFetcherDecorator(dataFetcher, handlesSubscription); - codeRegistry.dataFetcher(parent, fieldDefinition, dataFetcher); - return TraversalControl.CONTINUE; - } - }; + }; + } } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/DelegatingSubscriptionExceptionResolver.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/DelegatingSubscriptionExceptionResolver.java new file mode 100644 index 000000000..3b2024a44 --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/DelegatingSubscriptionExceptionResolver.java @@ -0,0 +1,74 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.graphql.execution; + +import graphql.ErrorType; +import graphql.GraphQLError; +import graphql.GraphqlErrorBuilder; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.util.Collections; +import java.util.List; + +/** + * An implementation of {@link SubscriptionExceptionResolver} that is trying to map exception to GraphQL error + * using provided implementation of {@link SubscriptionExceptionResolver}. + *
+ * If none of provided implementations resolve exception to error or if any of implementation throw an exception, + * this {@link SubscriptionExceptionResolver} will return a default error. + * + * @author Mykyta Ivchenko + * @see SubscriptionExceptionResolver + */ +public class DelegatingSubscriptionExceptionResolver implements SubscriptionExceptionResolver { + private static final Log logger = LogFactory.getLog(DelegatingSubscriptionExceptionResolver.class); + private final List resolvers; + + public DelegatingSubscriptionExceptionResolver(List resolvers) { + Assert.notNull(resolvers, "'resolvers' list must be not null."); + this.resolvers = resolvers; + } + + @Override + public Mono> resolveException(Throwable exception) { + return Flux.fromIterable(resolvers) + .flatMap(resolver -> resolver.resolveException(exception)) + .next() + .onErrorResume(error -> Mono.just(handleMappingException(error, exception))) + .defaultIfEmpty(createDefaultErrors()); + } + + private List handleMappingException(Throwable resolverException, Throwable originalException) { + if (logger.isWarnEnabled()) { + logger.warn("Failure while resolving " + originalException.getClass().getName(), resolverException); + } + return createDefaultErrors(); + } + + private List createDefaultErrors() { + GraphQLError error = GraphqlErrorBuilder.newError() + .message("Unknown error") + .errorType(ErrorType.DataFetchingException) + .build(); + + return Collections.singletonList(error); + } +} diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/GraphQlSource.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/GraphQlSource.java index 6a7311e4a..7f4a076b6 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/GraphQlSource.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/GraphQlSource.java @@ -16,11 +16,6 @@ package org.springframework.graphql.execution; -import java.io.InputStream; -import java.util.List; -import java.util.function.BiFunction; -import java.util.function.Consumer; - import graphql.GraphQL; import graphql.execution.instrumentation.Instrumentation; import graphql.schema.GraphQLSchema; @@ -28,9 +23,13 @@ import graphql.schema.TypeResolver; import graphql.schema.idl.RuntimeWiring; import graphql.schema.idl.TypeDefinitionRegistry; - import org.springframework.core.io.Resource; +import java.io.InputStream; +import java.util.List; +import java.util.function.BiFunction; +import java.util.function.Consumer; + /** * Strategy to resolve a {@link GraphQL} and a {@link GraphQLSchema}. @@ -91,6 +90,14 @@ interface Builder> { */ B exceptionResolvers(List resolvers); + /** + * Add {@link SubscriptionExceptionResolver}s to map exceptions, thrown by + * GraphQL Subscription publisher. + * @param subscriptionExceptionResolver the subscription exception resolver + * @return the current builder + */ + B subscriptionExceptionResolvers(List subscriptionExceptionResolvers); + /** * Add {@link GraphQLTypeVisitor}s to visit all element of the created * {@link graphql.schema.GraphQLSchema}. diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolver.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolver.java new file mode 100644 index 000000000..63fa9e005 --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolver.java @@ -0,0 +1,50 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.graphql.execution; + +import graphql.GraphQLError; +import reactor.core.publisher.Mono; + +import java.util.List; + +/** + * Contract to resolve exceptions, that are thrown by subscription publisher. + * Implementations are typically declared as beans in Spring configuration and + * are invoked sequentially until one emits a List of {@link GraphQLError}s. + *
+ * Usually, it is enough to implement this interface by extending {@link SubscriptionExceptionResolverAdapter} + * and overriding one of its {@link SubscriptionExceptionResolverAdapter#resolveToSingleError(Throwable)} + * or {@link SubscriptionExceptionResolverAdapter#resolveToMultipleErrors(Throwable)} + * + * @author Mykyta Ivchenko + * @see SubscriptionExceptionResolverAdapter + * @see DelegatingSubscriptionExceptionResolver + * @see org.springframework.graphql.server.webflux.GraphQlWebSocketHandler + */ +@FunctionalInterface +public interface SubscriptionExceptionResolver { + /** + * Resolve given exception as list of {@link GraphQLError}s and send them as WebSocket message. + * @param exception the exception to resolve + * @return a {@code Mono} with errors to send in a WebSocket message; + * if the {@code Mono} completes with an empty List, the exception is resolved + * without any errors added to the response; if the {@code Mono} completes + * empty, without emitting a List, the exception remains unresolved and gives + * other resolvers a chance. + */ + Mono> resolveException(Throwable exception); +} diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolverAdapter.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolverAdapter.java new file mode 100644 index 000000000..3b366f904 --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolverAdapter.java @@ -0,0 +1,48 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.graphql.execution; + +import graphql.GraphQLError; +import reactor.core.publisher.Mono; + +import java.util.Collections; +import java.util.List; + +/** + * Abstract class for {@link SubscriptionExceptionResolver} implementations. + * This class provide an easy way to map an exception as GraphQL error synchronously. + *
+ * To use this class, you need to override either {@link SubscriptionExceptionResolverAdapter#resolveToSingleError(Throwable)} + * or {@link SubscriptionExceptionResolverAdapter#resolveToMultipleErrors(Throwable)}. + * + * @author Mykyta Ivchenko + * @see SubscriptionExceptionResolver + */ +public abstract class SubscriptionExceptionResolverAdapter implements SubscriptionExceptionResolver { + @Override + public Mono> resolveException(Throwable exception) { + return Mono.just(resolveToMultipleErrors(exception)); + } + + protected List resolveToMultipleErrors(Throwable exception) { + return Collections.singletonList(resolveToSingleError(exception)); + } + + protected GraphQLError resolveToSingleError(Throwable exception) { + return null; + } +} diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionStreamException.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionStreamException.java new file mode 100644 index 000000000..71e5ae468 --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionStreamException.java @@ -0,0 +1,20 @@ +package org.springframework.graphql.execution; + +import graphql.GraphQLError; +import org.springframework.core.NestedRuntimeException; + +import java.util.List; + +@SuppressWarnings("serial") +public class SubscriptionStreamException extends NestedRuntimeException { + private final List errors; + + public SubscriptionStreamException(List errors) { + super("An exception happened in GraphQL subscription stream."); + this.errors = errors; + } + + public List getErrors() { + return errors; + } +} diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/support/GraphQlWebSocketMessage.java b/spring-graphql/src/main/java/org/springframework/graphql/server/support/GraphQlWebSocketMessage.java index 027a3e928..957aac211 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/support/GraphQlWebSocketMessage.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/support/GraphQlWebSocketMessage.java @@ -19,6 +19,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import graphql.GraphQLError; @@ -188,10 +189,13 @@ public static GraphQlWebSocketMessage next(String id, Map respon * @param id unique request id * @param error the error to add as the message payload */ - public static GraphQlWebSocketMessage error(String id, GraphQLError error) { - Assert.notNull(error, "GraphQlError is required"); - List> errors = Collections.singletonList(error.toSpecification()); - return new GraphQlWebSocketMessage(id, GraphQlWebSocketMessageType.ERROR, errors); + public static GraphQlWebSocketMessage error(String id, List errors) { + Assert.notNull(errors, "GraphQlErrors list is required"); + List> errorsMap = errors.stream() + .map(GraphQLError::toSpecification) + .collect(Collectors.toList()); + + return new GraphQlWebSocketMessage(id, GraphQlWebSocketMessageType.ERROR, errorsMap); } /** diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/CodecDelegate.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/CodecDelegate.java index 60676fcf1..a6a685bed 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/CodecDelegate.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/CodecDelegate.java @@ -15,6 +15,8 @@ */ package org.springframework.graphql.server.webflux; +import java.util.Collections; +import java.util.List; import java.util.Map; import graphql.GraphQLError; @@ -97,9 +99,13 @@ public WebSocketMessage encodeNext(WebSocketSession session, String id, Map errors) { + return encode(session, GraphQlWebSocketMessage.error(id, errors)); } public WebSocketMessage encodeComplete(WebSocketSession session, String id) { diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandler.java index 03f935dff..70629558f 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandler.java @@ -28,10 +28,12 @@ import java.util.concurrent.atomic.AtomicReference; import graphql.ExecutionResult; +import graphql.GraphQLError; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; +import org.springframework.graphql.execution.SubscriptionStreamException; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -216,7 +218,11 @@ private Flux handleResponse(WebSocketSession session, String i CloseStatus status = new CloseStatus(4409, "Subscriber for " + id + " already exists"); return GraphQlStatus.close(session, status); } - return Mono.fromCallable(() -> this.codecDelegate.encodeError(session, id, ex)); + if (ex instanceof SubscriptionStreamException) { + List errors = ((SubscriptionStreamException) ex).getErrors(); + return Mono.fromCallable(() -> this.codecDelegate.encodeError(session, id, errors)); + } + return Mono.fromCallable(() -> this.codecDelegate.encodeUnknownError(session, id, ex)); }); } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler.java index aecf07b3b..e3f485ee6 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler.java @@ -41,6 +41,7 @@ import org.apache.commons.logging.LogFactory; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; +import org.springframework.graphql.execution.SubscriptionStreamException; import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -286,9 +287,13 @@ private Flux handleResponse(WebSocketSession session, String id, We GraphQlStatus.closeSession(session, status); return Flux.empty(); } + if (ex instanceof SubscriptionStreamException) { + List errors = ((SubscriptionStreamException) ex).getErrors(); + return Mono.just(encode(GraphQlWebSocketMessage.error(id, errors))); + } String message = ex.getMessage(); GraphQLError error = GraphqlErrorBuilder.newError().message(message).build(); - return Mono.just(encode(GraphQlWebSocketMessage.error(id, error))); + return Mono.just(encode(GraphQlWebSocketMessage.error(id, Collections.singletonList(error)))); }); } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/client/MockGraphQlWebSocketServer.java b/spring-graphql/src/test/java/org/springframework/graphql/client/MockGraphQlWebSocketServer.java index d5facd87f..d236176b8 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/client/MockGraphQlWebSocketServer.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/client/MockGraphQlWebSocketServer.java @@ -16,6 +16,7 @@ package org.springframework.graphql.client; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; import java.util.function.Function; @@ -105,7 +106,7 @@ private Publisher handleMessage(GraphQlWebSocketMessage .map(response -> GraphQlWebSocketMessage.next(id, response.toMap())) .concatWithValues( request.getError() != null ? - GraphQlWebSocketMessage.error(id, request.getError()) : + GraphQlWebSocketMessage.error(id, Collections.singletonList(request.getError())) : GraphQlWebSocketMessage.complete(id)); case COMPLETE: return Flux.empty(); diff --git a/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java b/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java index 902b8155f..8dac171c1 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/execution/ContextDataFetcherDecoratorTests.java @@ -16,24 +16,25 @@ package org.springframework.graphql.execution; -import java.time.Duration; -import java.util.List; - -import graphql.ExecutionInput; -import graphql.ExecutionResult; -import graphql.GraphQL; +import graphql.*; import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.graphql.GraphQlSetup; +import org.springframework.graphql.ResponseHelper; +import org.springframework.graphql.TestThreadLocalAccessor; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import reactor.util.context.Context; import reactor.util.context.ContextView; -import org.springframework.graphql.ResponseHelper; -import org.springframework.graphql.GraphQlSetup; -import org.springframework.graphql.TestThreadLocalAccessor; +import java.time.Duration; +import java.util.Collections; +import java.util.List; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; /** * Tests for {@link ContextDataFetcherDecorator}. @@ -104,6 +105,97 @@ void fluxDataFetcherSubscription() throws Exception { .verifyComplete(); } + @Test + void fluxDataFetcherSubscriptionThrowException() throws Exception { + GraphQLError expectedError = GraphqlErrorBuilder.newError() + .message("Error: Example Error") + .errorType(ErrorType.INTERNAL_ERROR) + .extensions(Collections.singletonMap("a", "b")) + .build(); + + SubscriptionExceptionResolver subscriptionSingleExceptionResolverAdapter = Mockito.spy( + new SubscriptionExceptionResolverAdapter() { + @Override + protected GraphQLError resolveToSingleError(Throwable exception) { + return GraphqlErrorBuilder.newError() + .message("Error: " + exception.getMessage()) + .errorType(ErrorType.INTERNAL_ERROR) + .extensions(Collections.singletonMap("a", "b")) + .build(); + } + } + ); + + GraphQL graphQl = GraphQlSetup.schemaContent("type Query { greeting: String } type Subscription { greetings: String }") + .subscriptionExceptionResolvers(subscriptionSingleExceptionResolverAdapter) + .subscriptionFetcher("greetings", (env) -> + Mono.delay(Duration.ofMillis(50)) + .flatMapMany((aLong) -> Flux.create(sink -> { + sink.next("Hi!"); + sink.error(new RuntimeException("Example Error")); + }))) + .toGraphQl(); + + ExecutionInput input = ExecutionInput.newExecutionInput().query("subscription { greetings }").build(); + + ExecutionResult executionResult = graphQl.executeAsync(input).get(); + + Flux greetingsFlux = ResponseHelper.forSubscription(executionResult) + .map(message -> message.toEntity("greetings", String.class)); + + StepVerifier.create(greetingsFlux) + .expectNext("Hi!") + .expectErrorSatisfies(error -> assertThat(error) + .usingRecursiveComparison() + .isEqualTo(new SubscriptionStreamException(Collections.singletonList(expectedError)))) + .verify(); + + verify(subscriptionSingleExceptionResolverAdapter).resolveException(any(RuntimeException.class)); + } + + @Test + void monoDataFetcherSubscriptionThrowException() throws Exception { + GraphQLError expectedError = GraphqlErrorBuilder.newError() + .message("Error: Example Error") + .errorType(ErrorType.INTERNAL_ERROR) + .extensions(Collections.singletonMap("a", "b")) + .build(); + + SubscriptionExceptionResolver subscriptionSingleExceptionResolverAdapter = Mockito.spy( + new SubscriptionExceptionResolverAdapter() { + @Override + protected GraphQLError resolveToSingleError(Throwable exception) { + return GraphqlErrorBuilder.newError() + .message("Error: " + exception.getMessage()) + .errorType(ErrorType.INTERNAL_ERROR) + .extensions(Collections.singletonMap("a", "b")) + .build(); + } + } + ); + + GraphQL graphQl = GraphQlSetup.schemaContent("type Query { greeting: String } type Subscription { greetings: String }") + .subscriptionExceptionResolvers(subscriptionSingleExceptionResolverAdapter) + .subscriptionFetcher("greetings", (env) -> + Mono.delay(Duration.ofMillis(50)) + .then(Mono.error(new RuntimeException("Example Error")))) + .toGraphQl(); + + ExecutionInput input = ExecutionInput.newExecutionInput().query("subscription { greetings }").build(); + + ExecutionResult executionResult = graphQl.executeAsync(input).get(); + + Flux greetingsFlux = ResponseHelper.forSubscription(executionResult); + + StepVerifier.create(greetingsFlux) + .expectErrorSatisfies(error -> assertThat(error) + .usingRecursiveComparison() + .isEqualTo(new SubscriptionStreamException(Collections.singletonList(expectedError)))) + .verify(); + + verify(subscriptionSingleExceptionResolverAdapter).resolveException(any(RuntimeException.class)); + } + @Test void dataFetcherWithThreadLocalContext() { ThreadLocal nameThreadLocal = new ThreadLocal<>(); diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandlerTests.java index 5b83f8907..95b70bc4e 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandlerTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandlerTests.java @@ -16,42 +16,45 @@ package org.springframework.graphql.server.webflux; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.BiConsumer; - +import graphql.GraphQLError; +import graphql.GraphqlErrorBuilder; import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.test.StepVerifier; - +import org.mockito.Mockito; import org.springframework.core.ResolvableType; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.graphql.GraphQlSetup; -import org.springframework.graphql.server.ConsumeOneAndNeverCompleteInterceptor; -import org.springframework.graphql.server.WebGraphQlHandler; -import org.springframework.graphql.server.WebGraphQlInterceptor; -import org.springframework.graphql.server.WebSocketHandlerTestSupport; -import org.springframework.graphql.server.WebSocketGraphQlInterceptor; -import org.springframework.graphql.server.WebSocketSessionInfo; +import org.springframework.graphql.execution.ErrorType; +import org.springframework.graphql.execution.SubscriptionExceptionResolver; +import org.springframework.graphql.execution.SubscriptionExceptionResolverAdapter; +import org.springframework.graphql.server.*; import org.springframework.graphql.server.support.GraphQlWebSocketMessage; import org.springframework.graphql.server.support.GraphQlWebSocketMessageType; import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.http.codec.json.Jackson2JsonDecoder; import org.springframework.web.reactive.socket.CloseStatus; import org.springframework.web.reactive.socket.WebSocketMessage; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; import static org.assertj.core.api.Assertions.as; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; /** * Unit tests for {@link GraphQlWebSocketHandler}. @@ -356,7 +359,7 @@ void errorMessagePayloadIsArray() { .asInstanceOf(InstanceOfAssertFactories.map(String.class, Object.class)) .hasSize(3) .hasEntrySatisfying("locations", loc -> assertThat(loc).asList().isEmpty()) - .hasEntrySatisfying("message", msg -> assertThat(msg).asString().contains("null")) + .hasEntrySatisfying("message", msg -> assertThat(msg).asString().contains("Unknown error")) .extractingByKey("extensions", as(InstanceOfAssertFactories.map(String.class, Object.class))) .containsEntry("classification", "DataFetchingException")); }) @@ -364,6 +367,77 @@ void errorMessagePayloadIsArray() { .verify(TIMEOUT); } + @Test + void subscriptionStreamException() { + final String GREETING_QUERY = "{" + + "\"id\":\"" + SUBSCRIPTION_ID + "\"," + + "\"type\":\"subscribe\"," + + "\"payload\":{\"query\": \"" + + " subscription TestTypenameSubscription {" + + " greeting" + + " }\"}" + + "}"; + + String schema = "type Subscription { greeting: String! } type Query { greetingUnused: String! }"; + + WebGraphQlHandler initHandler = GraphQlSetup.schemaContent(schema) + .subscriptionFetcher("greeting", env -> Flux.create(emitter -> { + emitter.next("a"); + emitter.error(new RuntimeException("Test Exception")); + emitter.next("b"); + })) + .subscriptionExceptionResolvers(new SubscriptionExceptionResolverAdapter() { + @Override + protected GraphQLError resolveToSingleError(Throwable exception) { + return GraphqlErrorBuilder.newError() + .errorType(ErrorType.INTERNAL_ERROR) + .message("Error: " + exception.getMessage()) + .extensions(Collections.singletonMap("key", "value")) + .build(); + } + }) + .interceptor() + .toWebGraphQlHandler(); + + GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler( + initHandler, + ServerCodecConfigurer.create(), + Duration.ofSeconds(60)); + + TestWebSocketSession session = new TestWebSocketSession(Flux.just( + toWebSocketMessage("{\"type\":\"connection_init\"}"), + toWebSocketMessage(GREETING_QUERY))); + handler.handle(session).block(TIMEOUT); + + StepVerifier.create(session.getOutput()) + .consumeNextWith((message) -> assertMessageType(message, GraphQlWebSocketMessageType.CONNECTION_ACK)) + .consumeNextWith((message) -> { + GraphQlWebSocketMessage actual = decode(message); + assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID); + assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT); + assertThat(actual.>getPayload()) + .extractingByKey("data", as(InstanceOfAssertFactories.map(String.class, Object.class))) + .containsEntry("greeting", "a"); + }) + .consumeNextWith((message) -> { + GraphQlWebSocketMessage actual = decode(message); + assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID); + assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.ERROR); + assertThat(actual.>>getPayload()) + .asList().hasSize(1) + .allSatisfy(theError -> assertThat(theError) + .asInstanceOf(InstanceOfAssertFactories.map(String.class, Object.class)) + .hasSize(3) + .hasEntrySatisfying("locations", loc -> assertThat(loc).asList().isEmpty()) + .hasEntrySatisfying("message", msg -> assertThat(msg).asString().isEqualTo("Error: Test Exception")) + .extractingByKey("extensions", as(InstanceOfAssertFactories.map(String.class, Object.class))) + .containsEntry("classification", "INTERNAL_ERROR") + .containsEntry("key", "value")); + }) + .expectComplete() + .verify(TIMEOUT); + } + private TestWebSocketSession handle(Flux input, WebGraphQlInterceptor... interceptors) { GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler( initHandler(interceptors), diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandlerTests.java index 4a54d73bc..f142e5828 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandlerTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandlerTests.java @@ -28,8 +28,12 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; +import graphql.GraphQLError; +import graphql.GraphqlErrorBuilder; import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; +import org.springframework.graphql.execution.ErrorType; +import org.springframework.graphql.execution.SubscriptionExceptionResolverAdapter; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -366,7 +370,7 @@ void errorMessagePayloadIsCorrectArray() throws Exception { .asInstanceOf(InstanceOfAssertFactories.map(String.class, Object.class)) .hasSize(3) .hasEntrySatisfying("locations", loc -> assertThat(loc).asList().isEmpty()) - .hasEntrySatisfying("message", msg -> assertThat(msg).asString().contains("null")) + .hasEntrySatisfying("message", msg -> assertThat(msg).asString().contains("Unknown error")) .extractingByKey("extensions", as(InstanceOfAssertFactories.map(String.class, Object.class))) .containsEntry("classification", "DataFetchingException")); }) @@ -375,6 +379,74 @@ void errorMessagePayloadIsCorrectArray() throws Exception { .verify(TIMEOUT); } + @Test + void subscriptionStreamException() throws Exception { + final String GREETING_QUERY = "{" + + "\"id\":\"" + SUBSCRIPTION_ID + "\"," + + "\"type\":\"subscribe\"," + + "\"payload\":{\"query\": \"" + + " subscription TestTypenameSubscription {" + + " greeting" + + " }\"}" + + "}"; + + String schema = "type Subscription { greeting: String! }type Query { greetingUnused: String! }"; + + WebGraphQlHandler initHandler = GraphQlSetup.schemaContent(schema) + .subscriptionFetcher("greeting", env -> Flux.create(emitter -> { + emitter.next("a"); + emitter.error(new RuntimeException("Test Exception")); + emitter.next("b"); + })) + .subscriptionExceptionResolvers(new SubscriptionExceptionResolverAdapter() { + @Override + protected GraphQLError resolveToSingleError(Throwable exception) { + return GraphqlErrorBuilder.newError() + .message("Error: " + exception.getMessage()) + .errorType(ErrorType.INTERNAL_ERROR) + .extensions(Collections.singletonMap("key", "value")) + .build(); + } + }) + .interceptor() + .toWebGraphQlHandler(); + + GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler(initHandler, converter, Duration.ofSeconds(60)); + + handle(handler, + new TextMessage("{\"type\":\"connection_init\"}"), + new TextMessage(GREETING_QUERY)); + + StepVerifier.create(this.session.getOutput()) + .consumeNextWith((message) -> assertMessageType(message, GraphQlWebSocketMessageType.CONNECTION_ACK)) + .consumeNextWith((message) -> { + GraphQlWebSocketMessage actual = decode(message); + assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID); + assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT); + assertThat(actual.>getPayload()) + .extractingByKey("data", as(InstanceOfAssertFactories.map(String.class, Object.class))) + .containsEntry("greeting", "a"); + }) + .consumeNextWith((message) -> { + GraphQlWebSocketMessage actual = decode(message); + assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID); + assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.ERROR); + assertThat(actual.>>getPayload()) + .asList().hasSize(1) + .allSatisfy(theError -> assertThat(theError) + .asInstanceOf(InstanceOfAssertFactories.map(String.class, Object.class)) + .hasSize(3) + .hasEntrySatisfying("locations", loc -> assertThat(loc).asList().isEmpty()) + .hasEntrySatisfying("message", msg -> assertThat(msg).asString().contains("Error: Test Exception")) + .extractingByKey("extensions", as(InstanceOfAssertFactories.map(String.class, Object.class))) + .containsEntry("classification", "INTERNAL_ERROR") + .containsEntry("key", "value")); + }) + .then(this.session::close) + .expectComplete() + .verify(TIMEOUT); + } + @Test void contextPropagation() throws Exception { ThreadLocal threadLocal = new ThreadLocal<>(); diff --git a/spring-graphql/src/testFixtures/java/org/springframework/graphql/GraphQlSetup.java b/spring-graphql/src/testFixtures/java/org/springframework/graphql/GraphQlSetup.java index 1db0fbb04..b3a08548b 100644 --- a/spring-graphql/src/testFixtures/java/org/springframework/graphql/GraphQlSetup.java +++ b/spring-graphql/src/testFixtures/java/org/springframework/graphql/GraphQlSetup.java @@ -15,32 +15,26 @@ */ package org.springframework.graphql; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - import graphql.GraphQL; import graphql.schema.DataFetcher; import graphql.schema.GraphQLTypeVisitor; import graphql.schema.TypeResolver; - import org.springframework.context.ApplicationContext; import org.springframework.core.io.ByteArrayResource; import org.springframework.core.io.Resource; import org.springframework.graphql.data.method.annotation.support.AnnotatedControllerConfigurer; -import org.springframework.graphql.execution.DataFetcherExceptionResolver; -import org.springframework.graphql.execution.DataLoaderRegistrar; -import org.springframework.graphql.execution.DefaultExecutionGraphQlService; -import org.springframework.graphql.execution.GraphQlSource; -import org.springframework.graphql.execution.RuntimeWiringConfigurer; -import org.springframework.graphql.execution.ThreadLocalAccessor; +import org.springframework.graphql.execution.*; import org.springframework.graphql.server.WebGraphQlHandler; import org.springframework.graphql.server.WebGraphQlInterceptor; import org.springframework.graphql.server.WebGraphQlSetup; import org.springframework.graphql.server.webflux.GraphQlHttpHandler; import org.springframework.lang.Nullable; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + /** * Workflow for GraphQL tests setup that starts with {@link GraphQlSource.Builder} * related input, and then optionally moving on to the creation of a @@ -99,6 +93,11 @@ public GraphQlSetup exceptionResolver(DataFetcherExceptionResolver... resolvers) return this; } + public GraphQlSetup subscriptionExceptionResolvers(SubscriptionExceptionResolver... resolvers) { + this.graphQlSourceBuilder.subscriptionExceptionResolvers(Arrays.asList(resolvers)); + return this; + } + public GraphQlSetup typeResolver(TypeResolver typeResolver) { this.graphQlSourceBuilder.defaultTypeResolver(typeResolver); return this;