Skip to content

Commit 59fbccf

Browse files
authored
Fix interceptors in testing environment (#364)
Fixes #363
1 parent 24fea4c commit 59fbccf

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

temporalio/testing/_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,5 +562,5 @@ def _client_with_interceptors(
562562
config = client.config()
563563
config_interceptors = list(config["interceptors"])
564564
config_interceptors.extend(interceptors)
565-
config["interceptors"] = interceptors
565+
config["interceptors"] = config_interceptors
566566
return temporalio.client.Client(**config)

tests/testing/test_workflow.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,19 @@
33
import uuid
44
from datetime import datetime, timedelta, timezone
55
from time import monotonic
6-
from typing import Optional, Union
6+
from typing import Any, List, Optional, Union
77

88
import pytest
99

1010
from temporalio import activity, workflow
11-
from temporalio.client import Client, WorkflowFailureError
11+
from temporalio.client import (
12+
Client,
13+
Interceptor,
14+
OutboundInterceptor,
15+
StartWorkflowInput,
16+
WorkflowFailureError,
17+
WorkflowHandle,
18+
)
1219
from temporalio.common import RetryPolicy
1320
from temporalio.exceptions import (
1421
ActivityError,
@@ -176,7 +183,36 @@ def some_signal(self) -> None:
176183
assert "foo" == "bar"
177184

178185

186+
class SimpleClientInterceptor(Interceptor):
187+
def __init__(self) -> None:
188+
self.events: List[str] = []
189+
190+
def intercept_client(self, next: OutboundInterceptor) -> OutboundInterceptor:
191+
return SimpleClientOutboundInterceptor(self, super().intercept_client(next))
192+
193+
194+
class SimpleClientOutboundInterceptor(OutboundInterceptor):
195+
def __init__(
196+
self, root: SimpleClientInterceptor, next: OutboundInterceptor
197+
) -> None:
198+
super().__init__(next)
199+
self.root = root
200+
201+
async def start_workflow(
202+
self, input: StartWorkflowInput
203+
) -> WorkflowHandle[Any, Any]:
204+
self.root.events.append(f"start: {input.workflow}")
205+
return await super().start_workflow(input)
206+
207+
179208
async def test_workflow_env_assert(client: Client):
209+
# Set the interceptor on the client. This used to fail for being
210+
# accidentally overridden.
211+
client_config = client.config()
212+
interceptor = SimpleClientInterceptor()
213+
client_config["interceptors"] = [interceptor]
214+
client = Client(**client_config)
215+
180216
def assert_proper_error(err: Optional[BaseException]) -> None:
181217
assert isinstance(err, ApplicationError)
182218
# In unsandboxed workflows, this message has extra diff info appended
@@ -195,6 +231,7 @@ def assert_proper_error(err: Optional[BaseException]) -> None:
195231
task_queue=worker.task_queue,
196232
)
197233
assert_proper_error(err.value.cause)
234+
assert interceptor.events
198235

199236
# Start a new one and check signal
200237
handle = await env.client.start_workflow(

0 commit comments

Comments
 (0)