|
1 | 1 | from contextlib import contextmanager
|
2 |
| -from typing import Any, Generator |
| 2 | +from typing import Any, Dict, Generator, Tuple |
3 | 3 | from unittest.mock import patch
|
4 | 4 |
|
| 5 | +from httpcore import AsyncByteStream, SyncByteStream |
5 | 6 | from httpcore._async.connection import AsyncHTTPConnection
|
6 | 7 | from httpcore._sync.connection import SyncHTTPConnection
|
7 |
| -from httpx import URL |
| 8 | +from httpcore._types import URL, Headers |
8 | 9 |
|
| 10 | +from pytest_httpx_blockage.contextvar import is_blockage_enabled |
9 | 11 | from pytest_httpx_blockage.exceptions import RequestBlockageException
|
10 | 12 |
|
| 13 | +base_request_sync = SyncHTTPConnection.request |
| 14 | +base_request_async = AsyncHTTPConnection.arequest |
| 15 | + |
11 | 16 |
|
12 | 17 | def side_effect(
|
| 18 | + self: SyncHTTPConnection, |
13 | 19 | method: bytes,
|
14 | 20 | url: URL,
|
15 | 21 | *args: Any,
|
16 | 22 | **kwargs: Any,
|
17 |
| -) -> None: |
18 |
| - raise RequestBlockageException(f'Unmocked "{method.decode()}" request to host="{url}"') |
| 23 | +) -> Tuple[int, Headers, SyncByteStream, Dict[Any, Any]]: |
| 24 | + if is_blockage_enabled.get(): |
| 25 | + raise RequestBlockageException(f'Unmocked "{method.decode()}" request to host="{url}"') |
| 26 | + else: |
| 27 | + return base_request_sync( |
| 28 | + self, |
| 29 | + method, |
| 30 | + url, |
| 31 | + *args, |
| 32 | + **kwargs, |
| 33 | + ) |
19 | 34 |
|
20 | 35 |
|
21 | 36 | async def async_side_effect(
|
| 37 | + self: AsyncHTTPConnection, |
22 | 38 | method: bytes,
|
23 | 39 | url: URL,
|
24 | 40 | *args: Any,
|
25 | 41 | **kwargs: Any,
|
26 |
| -) -> None: |
27 |
| - side_effect(method=method, url=url) |
| 42 | +) -> Tuple[int, Headers, AsyncByteStream, Dict[Any, Any]]: |
| 43 | + if is_blockage_enabled.get(): |
| 44 | + raise RequestBlockageException(f'Unmocked "{method.decode()}" request to host="{url}"') |
| 45 | + else: |
| 46 | + return await base_request_async( |
| 47 | + self, |
| 48 | + method, |
| 49 | + url, |
| 50 | + *args, |
| 51 | + **kwargs, |
| 52 | + ) |
28 | 53 |
|
29 | 54 |
|
30 | 55 | @contextmanager
|
31 | 56 | def blockage() -> Generator[None, None, None]:
|
32 |
| - patch_sync = patch.object(SyncHTTPConnection, 'request') |
33 |
| - patch_async = patch.object(AsyncHTTPConnection, 'arequest') |
| 57 | + patch_sync = patch.object(SyncHTTPConnection, 'request', autospec=True) |
| 58 | + patch_async = patch.object(AsyncHTTPConnection, 'arequest', autospec=True) |
| 59 | + |
34 | 60 | with patch_sync as mocked_sync, patch_async as mocked_async:
|
35 | 61 | mocked_sync.side_effect = side_effect
|
36 | 62 | mocked_async.side_effect = async_side_effect
|
|
0 commit comments