3
3
import uuid
4
4
from datetime import datetime , timedelta , timezone
5
5
from time import monotonic
6
- from typing import Optional , Union
6
+ from typing import Any , List , Optional , Union
7
7
8
8
import pytest
9
9
10
10
from temporalio import activity , workflow
11
- from temporalio .client import Client , WorkflowFailureError
11
+ from temporalio .client import (
12
+ Client ,
13
+ Interceptor ,
14
+ OutboundInterceptor ,
15
+ StartWorkflowInput ,
16
+ WorkflowFailureError ,
17
+ WorkflowHandle ,
18
+ )
12
19
from temporalio .common import RetryPolicy
13
20
from temporalio .exceptions import (
14
21
ActivityError ,
@@ -176,7 +183,36 @@ def some_signal(self) -> None:
176
183
assert "foo" == "bar"
177
184
178
185
186
+ class SimpleClientInterceptor (Interceptor ):
187
+ def __init__ (self ) -> None :
188
+ self .events : List [str ] = []
189
+
190
+ def intercept_client (self , next : OutboundInterceptor ) -> OutboundInterceptor :
191
+ return SimpleClientOutboundInterceptor (self , super ().intercept_client (next ))
192
+
193
+
194
+ class SimpleClientOutboundInterceptor (OutboundInterceptor ):
195
+ def __init__ (
196
+ self , root : SimpleClientInterceptor , next : OutboundInterceptor
197
+ ) -> None :
198
+ super ().__init__ (next )
199
+ self .root = root
200
+
201
+ async def start_workflow (
202
+ self , input : StartWorkflowInput
203
+ ) -> WorkflowHandle [Any , Any ]:
204
+ self .root .events .append (f"start: { input .workflow } " )
205
+ return await super ().start_workflow (input )
206
+
207
+
179
208
async def test_workflow_env_assert (client : Client ):
209
+ # Set the interceptor on the client. This used to fail for being
210
+ # accidentally overridden.
211
+ client_config = client .config ()
212
+ interceptor = SimpleClientInterceptor ()
213
+ client_config ["interceptors" ] = [interceptor ]
214
+ client = Client (** client_config )
215
+
180
216
def assert_proper_error (err : Optional [BaseException ]) -> None :
181
217
assert isinstance (err , ApplicationError )
182
218
# In unsandboxed workflows, this message has extra diff info appended
@@ -195,6 +231,7 @@ def assert_proper_error(err: Optional[BaseException]) -> None:
195
231
task_queue = worker .task_queue ,
196
232
)
197
233
assert_proper_error (err .value .cause )
234
+ assert interceptor .events
198
235
199
236
# Start a new one and check signal
200
237
handle = await env .client .start_workflow (
0 commit comments