1
1
from contextvars import ContextVar
2
+ from typing import Callable , Iterator
2
3
3
4
import anyio
4
5
import pytest
8
9
from starlette .requests import Request
9
10
from starlette .responses import Response
10
11
from starlette .routing import Route
12
+ from starlette .testclient import TestClient
13
+
14
+ TestClientFactory = Callable [..., TestClient ]
11
15
12
16
13
17
@pytest .mark .anyio
14
- async def test_run_until_first_complete ():
18
+ async def test_run_until_first_complete () -> None :
15
19
task1_finished = anyio .Event ()
16
20
task2_finished = anyio .Event ()
17
21
18
- async def task1 ():
22
+ async def task1 () -> None :
19
23
task1_finished .set ()
20
24
21
- async def task2 ():
25
+ async def task2 () -> None :
22
26
await task1_finished .wait ()
23
27
await anyio .sleep (0 ) # pragma: nocover
24
28
task2_finished .set () # pragma: nocover
@@ -28,7 +32,9 @@ async def task2():
28
32
assert not task2_finished .is_set ()
29
33
30
34
31
- def test_accessing_context_from_threaded_sync_endpoint (test_client_factory ) -> None :
35
+ def test_accessing_context_from_threaded_sync_endpoint (
36
+ test_client_factory : TestClientFactory ,
37
+ ) -> None :
32
38
ctxvar : ContextVar [bytes ] = ContextVar ("ctxvar" )
33
39
ctxvar .set (b"data" )
34
40
@@ -45,7 +51,7 @@ def endpoint(request: Request) -> Response:
45
51
@pytest .mark .anyio
46
52
async def test_iterate_in_threadpool () -> None :
47
53
class CustomIterable :
48
- def __iter__ (self ):
54
+ def __iter__ (self ) -> Iterator [ int ] :
49
55
yield from range (3 )
50
56
51
57
assert [v async for v in iterate_in_threadpool (CustomIterable ())] == [0 , 1 , 2 ]
0 commit comments