diff --git a/sentry_sdk/integrations/asyncio.py b/sentry_sdk/integrations/asyncio.py index 03e320adc7..7f9b5b0c6d 100644 --- a/sentry_sdk/integrations/asyncio.py +++ b/sentry_sdk/integrations/asyncio.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from typing import Any + from collections.abc import Coroutine from sentry_sdk._types import ExcInfo @@ -37,8 +38,8 @@ def patch_asyncio(): loop = asyncio.get_running_loop() orig_task_factory = loop.get_task_factory() - def _sentry_task_factory(loop, coro): - # type: (Any, Any) -> Any + def _sentry_task_factory(loop, coro, **kwargs): + # type: (asyncio.AbstractEventLoop, Coroutine[Any, Any, Any], Any) -> asyncio.Future[Any] async def _coro_creating_hub_and_span(): # type: () -> Any @@ -56,7 +57,7 @@ async def _coro_creating_hub_and_span(): # Trying to use user set task factory (if there is one) if orig_task_factory: - return orig_task_factory(loop, _coro_creating_hub_and_span()) + return orig_task_factory(loop, _coro_creating_hub_and_span(), **kwargs) # The default task factory in `asyncio` does not have its own function # but is just a couple of lines in `asyncio.base_events.create_task()` @@ -65,13 +66,13 @@ async def _coro_creating_hub_and_span(): # WARNING: # If the default behavior of the task creation in asyncio changes, # this will break! - task = Task(_coro_creating_hub_and_span(), loop=loop) + task = Task(_coro_creating_hub_and_span(), loop=loop, **kwargs) if task._source_traceback: # type: ignore del task._source_traceback[-1] # type: ignore return task - loop.set_task_factory(_sentry_task_factory) + loop.set_task_factory(_sentry_task_factory) # type: ignore except RuntimeError: # When there is no running loop, we have nothing to patch. pass diff --git a/tests/integrations/asyncio/test_asyncio_py3.py b/tests/integrations/asyncio/test_asyncio_py3.py index 98106ed01f..c563f37b7d 100644 --- a/tests/integrations/asyncio/test_asyncio_py3.py +++ b/tests/integrations/asyncio/test_asyncio_py3.py @@ -1,11 +1,22 @@ import asyncio +import inspect import sys import pytest import sentry_sdk from sentry_sdk.consts import OP -from sentry_sdk.integrations.asyncio import AsyncioIntegration +from sentry_sdk.integrations.asyncio import AsyncioIntegration, patch_asyncio + +try: + from unittest.mock import MagicMock, patch +except ImportError: + from mock import MagicMock, patch + +try: + from contextvars import Context, ContextVar +except ImportError: + pass # All tests will be skipped with incompatible versions minimum_python_37 = pytest.mark.skipif( @@ -13,6 +24,12 @@ ) +minimum_python_311 = pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Asyncio task context parameter was introduced in Python 3.11", +) + + async def foo(): await asyncio.sleep(0.01) @@ -33,6 +50,17 @@ def event_loop(request): loop.close() +def get_sentry_task_factory(mock_get_running_loop): + """ + Patches (mocked) asyncio and gets the sentry_task_factory. + """ + mock_loop = mock_get_running_loop.return_value + patch_asyncio() + patched_factory = mock_loop.set_task_factory.call_args[0][0] + + return patched_factory + + @minimum_python_37 @pytest.mark.asyncio async def test_create_task( @@ -170,3 +198,173 @@ async def add(a, b): result = await asyncio.create_task(add(1, 2)) assert result == 3, result + + +@minimum_python_311 +@pytest.mark.asyncio +async def test_task_with_context(sentry_init): + """ + Integration test to ensure working context parameter in Python 3.11+ + """ + sentry_init( + integrations=[ + AsyncioIntegration(), + ], + ) + + var = ContextVar("var") + var.set("original value") + + async def change_value(): + var.set("changed value") + + async def retrieve_value(): + return var.get() + + # Create a context and run both tasks within the context + ctx = Context() + async with asyncio.TaskGroup() as tg: + tg.create_task(change_value(), context=ctx) + retrieve_task = tg.create_task(retrieve_value(), context=ctx) + + assert retrieve_task.result() == "changed value" + + +@minimum_python_37 +@patch("asyncio.get_running_loop") +def test_patch_asyncio(mock_get_running_loop): + """ + Test that the patch_asyncio function will patch the task factory. + """ + mock_loop = mock_get_running_loop.return_value + + patch_asyncio() + + assert mock_loop.set_task_factory.called + + set_task_factory_args, _ = mock_loop.set_task_factory.call_args + assert len(set_task_factory_args) == 1 + + sentry_task_factory, *_ = set_task_factory_args + assert callable(sentry_task_factory) + + +@minimum_python_37 +@pytest.mark.forked +@patch("asyncio.get_running_loop") +@patch("sentry_sdk.integrations.asyncio.Task") +def test_sentry_task_factory_no_factory(MockTask, mock_get_running_loop): # noqa: N803 + mock_loop = mock_get_running_loop.return_value + mock_coro = MagicMock() + + # Set the original task factory to None + mock_loop.get_task_factory.return_value = None + + # Retieve sentry task factory (since it is an inner function within patch_asyncio) + sentry_task_factory = get_sentry_task_factory(mock_get_running_loop) + + # The call we are testing + ret_val = sentry_task_factory(mock_loop, mock_coro) + + assert MockTask.called + assert ret_val == MockTask.return_value + + task_args, task_kwargs = MockTask.call_args + assert len(task_args) == 1 + + coro_param, *_ = task_args + assert inspect.iscoroutine(coro_param) + + assert "loop" in task_kwargs + assert task_kwargs["loop"] == mock_loop + + +@minimum_python_37 +@pytest.mark.forked +@patch("asyncio.get_running_loop") +def test_sentry_task_factory_with_factory(mock_get_running_loop): + mock_loop = mock_get_running_loop.return_value + mock_coro = MagicMock() + + # The original task factory will be mocked out here, let's retrieve the value for later + orig_task_factory = mock_loop.get_task_factory.return_value + + # Retieve sentry task factory (since it is an inner function within patch_asyncio) + sentry_task_factory = get_sentry_task_factory(mock_get_running_loop) + + # The call we are testing + ret_val = sentry_task_factory(mock_loop, mock_coro) + + assert orig_task_factory.called + assert ret_val == orig_task_factory.return_value + + task_factory_args, _ = orig_task_factory.call_args + assert len(task_factory_args) == 2 + + loop_arg, coro_arg = task_factory_args + assert loop_arg == mock_loop + assert inspect.iscoroutine(coro_arg) + + +@minimum_python_311 +@patch("asyncio.get_running_loop") +@patch("sentry_sdk.integrations.asyncio.Task") +def test_sentry_task_factory_context_no_factory( + MockTask, mock_get_running_loop # noqa: N803 +): + mock_loop = mock_get_running_loop.return_value + mock_coro = MagicMock() + mock_context = MagicMock() + + # Set the original task factory to None + mock_loop.get_task_factory.return_value = None + + # Retieve sentry task factory (since it is an inner function within patch_asyncio) + sentry_task_factory = get_sentry_task_factory(mock_get_running_loop) + + # The call we are testing + ret_val = sentry_task_factory(mock_loop, mock_coro, context=mock_context) + + assert MockTask.called + assert ret_val == MockTask.return_value + + task_args, task_kwargs = MockTask.call_args + assert len(task_args) == 1 + + coro_param, *_ = task_args + assert inspect.iscoroutine(coro_param) + + assert "loop" in task_kwargs + assert task_kwargs["loop"] == mock_loop + assert "context" in task_kwargs + assert task_kwargs["context"] == mock_context + + +@minimum_python_311 +@patch("asyncio.get_running_loop") +def test_sentry_task_factory_context_with_factory(mock_get_running_loop): + mock_loop = mock_get_running_loop.return_value + mock_coro = MagicMock() + mock_context = MagicMock() + + # The original task factory will be mocked out here, let's retrieve the value for later + orig_task_factory = mock_loop.get_task_factory.return_value + + # Retieve sentry task factory (since it is an inner function within patch_asyncio) + sentry_task_factory = get_sentry_task_factory(mock_get_running_loop) + + # The call we are testing + ret_val = sentry_task_factory(mock_loop, mock_coro, context=mock_context) + + assert orig_task_factory.called + assert ret_val == orig_task_factory.return_value + + task_factory_args, task_factory_kwargs = orig_task_factory.call_args + assert len(task_factory_args) == 2 + + loop_arg, coro_arg = task_factory_args + assert loop_arg == mock_loop + assert inspect.iscoroutine(coro_arg) + + assert "context" in task_factory_kwargs + assert task_factory_kwargs["context"] == mock_context