Skip to content

Commit 05ebd63

Browse files
committed
SubscriptionExceptionResolver supports context propagation
See gh-398
1 parent a5809e9 commit 05ebd63

File tree

5 files changed

+188
-134
lines changed

5 files changed

+188
-134
lines changed

spring-graphql/src/main/java/org/springframework/graphql/execution/SubscriptionExceptionResolverAdapter.java

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,50 @@
4444
*/
4545
public abstract class SubscriptionExceptionResolverAdapter implements SubscriptionExceptionResolver {
4646

47+
private boolean threadLocalContextAware;
48+
49+
50+
/**
51+
* Subclasses can set this to indicate that ThreadLocal context from the
52+
* transport handler (e.g. HTTP handler) should be restored when resolving
53+
* exceptions.
54+
* <p><strong>Note:</strong> This property is applicable only if transports
55+
* use ThreadLocal's' (e.g. Spring MVC) and if a {@link ThreadLocalAccessor}
56+
* is registered to extract ThreadLocal values of interest. There is no
57+
* impact from setting this property otherwise.
58+
* <p>By default this is set to "false" in which case there is no attempt
59+
* to propagate ThreadLocal context.
60+
* @param threadLocalContextAware whether this resolver needs access to
61+
* ThreadLocal context or not.
62+
*/
63+
public void setThreadLocalContextAware(boolean threadLocalContextAware) {
64+
this.threadLocalContextAware = threadLocalContextAware;
65+
}
66+
67+
/**
68+
* Whether ThreadLocal context needs to be restored for this resolver.
69+
*/
70+
public boolean isThreadLocalContextAware() {
71+
return this.threadLocalContextAware;
72+
}
73+
74+
4775
@Override
4876
public final Mono<List<GraphQLError>> resolveException(Throwable exception) {
49-
return Mono.justOrEmpty(resolveToMultipleErrors(exception));
77+
if (!this.threadLocalContextAware) {
78+
return Mono.justOrEmpty(resolveToMultipleErrors(exception));
79+
}
80+
return Mono.deferContextual(contextView -> {
81+
List<GraphQLError> errors;
82+
try {
83+
ReactorContextManager.restoreThreadLocalValues(contextView);
84+
errors = resolveToMultipleErrors(exception);
85+
}
86+
finally {
87+
ReactorContextManager.resetThreadLocalValues(contextView);
88+
}
89+
return Mono.justOrEmpty(errors);
90+
});
5091
}
5192

5293
/**

spring-graphql/src/test/java/org/springframework/graphql/ResponseHelper.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ public static ResponseHelper forResponse(Mono<? extends ExecutionGraphQlResponse
150150

151151
public static Flux<ResponseHelper> forSubscription(ExecutionResult result) {
152152
assertThat(result.getErrors()).as("Errors present in GraphQL response").isEmpty();
153+
Object data = result.getData();
154+
assertThat(data).as("Expected Publisher from subscription").isNotNull();
155+
assertThat(data).as("Expected Publisher from subscription").isInstanceOf(Publisher.class);
153156
Publisher<ExecutionResult> publisher = result.getData();
154157
return Flux.from(publisher).map(ResponseHelper::forResult);
155158
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.graphql.execution;
17+
18+
import java.time.Duration;
19+
import java.util.List;
20+
21+
import graphql.ExecutionInput;
22+
import graphql.GraphQL;
23+
import graphql.GraphQLError;
24+
import graphql.GraphqlErrorBuilder;
25+
import org.junit.jupiter.api.Test;
26+
import reactor.core.publisher.Flux;
27+
import reactor.core.publisher.Mono;
28+
import reactor.test.StepVerifier;
29+
import reactor.util.context.Context;
30+
import reactor.util.context.ContextView;
31+
32+
import org.springframework.graphql.GraphQlSetup;
33+
import org.springframework.graphql.ResponseHelper;
34+
import org.springframework.graphql.TestThreadLocalAccessor;
35+
36+
import static org.assertj.core.api.Assertions.assertThat;
37+
38+
/**
39+
* Tests for resolving exceptions via {@link SubscriptionExceptionResolver}.
40+
* @author Rossen Stoyanchev
41+
*/
42+
public class CompositeSubscriptionExceptionResolverTests {
43+
44+
private static final Duration TIMEOUT = Duration.ofSeconds(5);
45+
46+
47+
@Test
48+
void subscriptionPublisherExceptionResolved() {
49+
String query = "subscription { greetings }";
50+
String schema = "type Subscription { greetings: String! } type Query { greeting: String! }";
51+
52+
GraphQL graphQL = GraphQlSetup.schemaContent(schema)
53+
.subscriptionFetcher("greetings", env ->
54+
Flux.create(emitter -> {
55+
emitter.next("a");
56+
emitter.error(new RuntimeException("Test Exception"));
57+
emitter.next("b");
58+
}))
59+
.subscriptionExceptionResolvers(SubscriptionExceptionResolver.forSingleError(exception ->
60+
GraphqlErrorBuilder.newError()
61+
.message("Error: " + exception.getMessage())
62+
.errorType(ErrorType.BAD_REQUEST)
63+
.build()))
64+
.toGraphQl();
65+
66+
ExecutionInput input = ExecutionInput.newExecutionInput(query).build();
67+
Flux<ResponseHelper> flux = Mono.fromFuture(graphQL.executeAsync(input))
68+
.map(ResponseHelper::forSubscription)
69+
.block(TIMEOUT);
70+
71+
StepVerifier.create(flux)
72+
.consumeNextWith((helper) -> assertThat(helper.toEntity("greetings", String.class)).isEqualTo("a"))
73+
.consumeErrorWith((ex) -> {
74+
SubscriptionPublisherException theEx = (SubscriptionPublisherException) ex;
75+
List<GraphQLError> errors = theEx.getErrors();
76+
assertThat(errors).hasSize(1);
77+
assertThat(errors.get(0).getMessage()).isEqualTo("Error: Test Exception");
78+
assertThat(errors.get(0).getErrorType()).isEqualTo(ErrorType.BAD_REQUEST);
79+
})
80+
.verify(TIMEOUT);
81+
}
82+
83+
@Test
84+
void resolveExceptionWithThreadLocal() {
85+
String query = "subscription { greetings }";
86+
String schema = "type Subscription { greetings: String! } type Query { greeting: String! }";
87+
88+
ThreadLocal<String> nameThreadLocal = new ThreadLocal<>();
89+
nameThreadLocal.set("007");
90+
TestThreadLocalAccessor<String> accessor = new TestThreadLocalAccessor<>(nameThreadLocal);
91+
92+
try {
93+
SubscriptionExceptionResolverAdapter resolver = SubscriptionExceptionResolver.forSingleError(exception ->
94+
GraphqlErrorBuilder.newError()
95+
.message("Error: " + exception.getMessage() + ", name=" + nameThreadLocal.get())
96+
.errorType(ErrorType.BAD_REQUEST)
97+
.build());
98+
resolver.setThreadLocalContextAware(true);
99+
100+
GraphQL graphQL = GraphQlSetup.schemaContent(schema)
101+
.subscriptionFetcher("greetings", env ->
102+
Flux.create(emitter -> {
103+
emitter.next("a");
104+
emitter.error(new RuntimeException("Test Exception"));
105+
}))
106+
.subscriptionExceptionResolvers(resolver)
107+
.toGraphQl();
108+
109+
ContextView view = ReactorContextManager.extractThreadLocalValues(accessor, Context.empty());
110+
ExecutionInput input = ExecutionInput.newExecutionInput(query).build();
111+
ReactorContextManager.setReactorContext(view, input.getGraphQLContext());
112+
113+
Flux<ResponseHelper> flux = Mono.delay(Duration.ofMillis(10))
114+
.flatMap((aLong) -> Mono.fromFuture(graphQL.executeAsync(input)).map(ResponseHelper::forSubscription))
115+
.block(TIMEOUT);
116+
117+
StepVerifier.create(flux)
118+
.consumeNextWith((helper) -> assertThat(helper.toEntity("greetings", String.class)).isEqualTo("a"))
119+
.consumeErrorWith((ex) -> {
120+
SubscriptionPublisherException theEx = (SubscriptionPublisherException) ex;
121+
List<GraphQLError> errors = theEx.getErrors();
122+
assertThat(errors).hasSize(1);
123+
assertThat(errors.get(0).getMessage()).isEqualTo("Error: Test Exception, name=007");
124+
assertThat(errors.get(0).getErrorType()).isEqualTo(ErrorType.BAD_REQUEST);
125+
})
126+
.verify(TIMEOUT);
127+
}
128+
finally {
129+
nameThreadLocal.remove();
130+
}
131+
}
132+
133+
}

spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandlerTests.java

Lines changed: 6 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import java.util.concurrent.atomic.AtomicBoolean;
2626
import java.util.function.BiConsumer;
2727

28-
import graphql.GraphqlErrorBuilder;
2928
import org.assertj.core.api.InstanceOfAssertFactories;
3029
import org.junit.jupiter.api.Test;
3130
import reactor.core.publisher.Flux;
@@ -39,7 +38,6 @@
3938
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
4039
import org.springframework.graphql.GraphQlSetup;
4140
import org.springframework.graphql.execution.ErrorType;
42-
import org.springframework.graphql.execution.SubscriptionExceptionResolver;
4341
import org.springframework.graphql.server.ConsumeOneAndNeverCompleteInterceptor;
4442
import org.springframework.graphql.server.WebGraphQlHandler;
4543
import org.springframework.graphql.server.WebGraphQlInterceptor;
@@ -313,23 +311,20 @@ void clientCompletion() {
313311

314312
@Test
315313
void subscriptionErrorPayloadIsArray() {
316-
final String GREETING_QUERY = "{" +
314+
String query = "{" +
317315
"\"id\":\"" + SUBSCRIPTION_ID + "\"," +
318316
"\"type\":\"subscribe\"," +
319-
"\"payload\":{\"query\": \"" +
320-
" subscription TestTypenameSubscription {" +
321-
" greeting" +
322-
" }\"}" +
317+
"\"payload\":{\"query\": \"subscription { greetings }\"}" +
323318
"}";
324319

325-
String schema = "type Subscription { greeting: String! } type Query { greetingUnused: String! }";
320+
String schema = "type Subscription { greetings: String! } type Query { greeting: String! }";
326321

327322
TestWebSocketSession session = new TestWebSocketSession(Flux.just(
328323
toWebSocketMessage("{\"type\":\"connection_init\"}"),
329-
toWebSocketMessage(GREETING_QUERY)));
324+
toWebSocketMessage(query)));
330325

331326
WebGraphQlHandler webHandler = GraphQlSetup.schemaContent(schema)
332-
.subscriptionFetcher("greeting", env -> Flux.just("a", null, "b"))
327+
.subscriptionFetcher("greetings", env -> Flux.just("a", null, "b"))
333328
.toWebGraphQlHandler();
334329

335330
new GraphQlWebSocketHandler(webHandler, ServerCodecConfigurer.create(), TIMEOUT)
@@ -342,7 +337,7 @@ void subscriptionErrorPayloadIsArray() {
342337
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
343338
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT);
344339
assertThat(actual.<Map<String, Object>>getPayload())
345-
.containsEntry("data", Collections.singletonMap("greeting", "a"));
340+
.containsEntry("data", Collections.singletonMap("greetings", "a"));
346341
})
347342
.consumeNextWith((message) -> {
348343
GraphQlWebSocketMessage actual = decode(message);
@@ -358,63 +353,6 @@ void subscriptionErrorPayloadIsArray() {
358353
.verify(TIMEOUT);
359354
}
360355

361-
@Test
362-
void subscriptionPublisherExceptionResolved() {
363-
final String GREETING_QUERY = "{" +
364-
"\"id\":\"" + SUBSCRIPTION_ID + "\"," +
365-
"\"type\":\"subscribe\"," +
366-
"\"payload\":{\"query\": \"" +
367-
" subscription TestTypenameSubscription {" +
368-
" greeting" +
369-
" }\"}" +
370-
"}";
371-
372-
String schema = "type Subscription { greeting: String! } type Query { greetingUnused: String! }";
373-
374-
TestWebSocketSession session = new TestWebSocketSession(Flux.just(
375-
toWebSocketMessage("{\"type\":\"connection_init\"}"),
376-
toWebSocketMessage(GREETING_QUERY)));
377-
378-
WebGraphQlHandler webHandler = GraphQlSetup.schemaContent(schema)
379-
.subscriptionFetcher("greeting", env ->
380-
Flux.create(emitter -> {
381-
emitter.next("a");
382-
emitter.error(new RuntimeException("Test Exception"));
383-
emitter.next("b");
384-
}))
385-
.subscriptionExceptionResolvers(SubscriptionExceptionResolver.forSingleError(exception ->
386-
GraphqlErrorBuilder.newError()
387-
.message("Error: " + exception.getMessage())
388-
.errorType(ErrorType.BAD_REQUEST)
389-
.build()))
390-
.toWebGraphQlHandler();
391-
392-
new GraphQlWebSocketHandler(webHandler, ServerCodecConfigurer.create(), TIMEOUT)
393-
.handle(session).block(TIMEOUT);
394-
395-
StepVerifier.create(session.getOutput())
396-
.consumeNextWith((message) -> assertMessageType(message, GraphQlWebSocketMessageType.CONNECTION_ACK))
397-
.consumeNextWith((message) -> {
398-
GraphQlWebSocketMessage actual = decode(message);
399-
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
400-
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.NEXT);
401-
assertThat(actual.<Map<String, Object>>getPayload())
402-
.containsEntry("data", Collections.singletonMap("greeting", "a"));
403-
})
404-
.consumeNextWith((message) -> {
405-
GraphQlWebSocketMessage actual = decode(message);
406-
assertThat(actual.getId()).isEqualTo(SUBSCRIPTION_ID);
407-
assertThat(actual.resolvedType()).isEqualTo(GraphQlWebSocketMessageType.ERROR);
408-
List<Map<String, Object>> errors = actual.getPayload();
409-
assertThat(errors).hasSize(1);
410-
assertThat(errors.get(0)).containsEntry("message", "Error: Test Exception");
411-
assertThat(errors.get(0)).containsEntry("extensions",
412-
Collections.singletonMap("classification", ErrorType.BAD_REQUEST.name()));
413-
})
414-
.expectComplete()
415-
.verify(TIMEOUT);
416-
}
417-
418356
private TestWebSocketSession handle(Flux<WebSocketMessage> input, WebGraphQlInterceptor... interceptors) {
419357
GraphQlWebSocketHandler handler = new GraphQlWebSocketHandler(
420358
initHandler(interceptors),

0 commit comments

Comments
 (0)