diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index 3fcea6fc..c0b65da2 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -125,14 +125,18 @@ def pytest_pyfunc_call(pyfuncitem): if 'asyncio' in pyfuncitem.keywords: if getattr(pyfuncitem.obj, 'is_hypothesis_test', False): pyfuncitem.obj.hypothesis.inner_test = wrap_in_sync( - pyfuncitem.obj.hypothesis.inner_test + pyfuncitem.obj.hypothesis.inner_test, + _loop=pyfuncitem.funcargs['event_loop'] ) else: - pyfuncitem.obj = wrap_in_sync(pyfuncitem.obj) + pyfuncitem.obj = wrap_in_sync( + pyfuncitem.obj, + _loop=pyfuncitem.funcargs['event_loop'] + ) yield -def wrap_in_sync(func): +def wrap_in_sync(func, _loop): """Return a sync wrapper around an async function executing it in the current event loop.""" @@ -140,9 +144,15 @@ def wrap_in_sync(func): def inner(**kwargs): coro = func(**kwargs) if coro is not None: - task = asyncio.ensure_future(coro) try: - asyncio.get_event_loop().run_until_complete(task) + loop = asyncio.get_event_loop() + except RuntimeError as exc: + if 'no current event loop' not in str(exc): + raise + loop = _loop + task = asyncio.ensure_future(coro, loop=loop) + try: + loop.run_until_complete(task) except BaseException: # run_until_complete doesn't get the result from exceptions # that are not subclasses of `Exception`. Consume all @@ -154,9 +164,11 @@ def inner(**kwargs): def pytest_runtest_setup(item): - if 'asyncio' in item.keywords and 'event_loop' not in item.fixturenames: + if 'asyncio' in item.keywords: # inject an event loop fixture for all async tests - item.fixturenames.append('event_loop') + if 'event_loop' in item.fixturenames: + item.fixturenames.remove('event_loop') + item.fixturenames.insert(0, 'event_loop') if item.get_closest_marker("asyncio") is not None \ and not getattr(item.obj, 'hypothesis', False) \ and getattr(item.obj, 'is_hypothesis_test', False): diff --git a/tests/test_simple.py b/tests/test_simple.py index 00c07fcb..c8dccaf8 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -128,6 +128,29 @@ async def test_asyncio_marker_without_loop(self, remove_loop): assert ret == 'ok' +class TestEventLoopStartedBeforeFixtures: + @pytest.fixture + async def loop(self): + return asyncio.get_event_loop() + + @staticmethod + def foo(): + return 1 + + @pytest.mark.asyncio + async def test_no_event_loop(self, loop): + assert await loop.run_in_executor(None, self.foo) == 1 + + @pytest.mark.asyncio + async def test_event_loop_after_fixture(self, loop, event_loop): + assert await loop.run_in_executor(None, self.foo) == 1 + + @pytest.mark.asyncio + async def test_event_loop_before_fixture(self, event_loop, loop): + assert await loop.run_in_executor(None, self.foo) == 1 + + + @pytest.mark.asyncio async def test_no_warning_on_skip(): pytest.skip("Test a skip error inside asyncio")