diff --git a/Lib/asyncio/taskgroups.py b/Lib/asyncio/taskgroups.py index 0fdea3697ece3d..4d20598388b151 100644 --- a/Lib/asyncio/taskgroups.py +++ b/Lib/asyncio/taskgroups.py @@ -23,9 +23,11 @@ class TaskGroup: Any exceptions other than `asyncio.CancelledError` raised within a task will cancel all remaining tasks and wait for them to exit. + You can prevent this behavior by passing `defer_errors=True` to + the constructor. The exceptions are then combined and raised as an `ExceptionGroup`. """ - def __init__(self): + def __init__(self, defer_errors=False): self._entered = False self._exiting = False self._aborting = False @@ -36,6 +38,7 @@ def __init__(self): self._errors = [] self._base_error = None self._on_completed_fut = None + self._defer_errors = defer_errors def __repr__(self): info = [''] @@ -198,7 +201,8 @@ def _on_task_done(self, task): return self._errors.append(exc) - if self._is_base_error(exc) and self._base_error is None: + is_base_error = self._is_base_error(exc) + if is_base_error and self._base_error is None: self._base_error = exc if self._parent_task.done(): @@ -231,6 +235,7 @@ def _on_task_done(self, task): # pass # await something_else # this line has to be called # # after TaskGroup is finished. - self._abort() - self._parent_cancel_requested = True - self._parent_task.cancel() + if not self._defer_errors or is_base_error: + self._abort() + self._parent_cancel_requested = True + self._parent_task.cancel() diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py index 6a0231f2859a62..55357a803c7d3f 100644 --- a/Lib/test/test_asyncio/test_taskgroups.py +++ b/Lib/test/test_asyncio/test_taskgroups.py @@ -779,6 +779,440 @@ async def main(): await asyncio.create_task(main()) + async def test_children_complete_on_child_error(self): + async def zero_division(): + 1 / 0 + + async def foo1(): + await asyncio.sleep(0.1) + return 42 + + async def foo2(): + await asyncio.sleep(0.2) + return 11 + + with self.assertRaises(ExceptionGroup) as cm: + async with taskgroups.TaskGroup(defer_errors=True) as g: + t1 = g.create_task(foo1()) + t2 = g.create_task(foo2()) + g.create_task(zero_division()) + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + self.assertEqual(t1.result(), 42) + self.assertEqual(t2.result(), 11) + + async def test_inner_complete_on_child_error(self): + async def zero_division(): + 1 / 0 + + async def foo1(): + await asyncio.sleep(0.1) + return 42 + + async def foo2(): + await asyncio.sleep(0.2) + return 11 + + with self.assertRaises(ExceptionGroup) as cm: + async with taskgroups.TaskGroup(defer_errors=True) as g: + t1 = g.create_task(foo1()) + g.create_task(zero_division()) + r1 = await foo2() + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + self.assertEqual(t1.result(), 42) + self.assertEqual(r1, 11) + + async def test_children_exceptions_propagate(self): + async def zero_division(): + 1 / 0 + + async def value_error(): + await asyncio.sleep(0.2) + raise ValueError + + async def foo1(): + await asyncio.sleep(0.4) + return 42 + + with self.assertRaises(ExceptionGroup) as cm: + async with taskgroups.TaskGroup(defer_errors=True) as g: + g.create_task(zero_division()) + g.create_task(value_error()) + t1 = g.create_task(foo1()) + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError, ValueError}) + + self.assertEqual(t1.result(), 42) + + async def test_children_cancel_on_inner_failure(self): + async def zero_division(): + 1 / 0 + + async def foo1(): + await asyncio.sleep(0.2) + return 42 + + with self.assertRaises(ExceptionGroup) as cm: + async with taskgroups.TaskGroup(defer_errors=True) as g: + t1 = g.create_task(foo1()) + await zero_division() + + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + self.assertTrue(t1.cancelled()) + + async def test_cancellation_01(self): + + NUM = 0 + + async def foo(): + nonlocal NUM + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + NUM += 1 + raise + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g: + for _ in range(5): + g.create_task(foo()) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError) as cm: + await r + + self.assertEqual(NUM, 5) + + async def test_taskgroup_35(self): + + NUM = 0 + + async def foo(): + nonlocal NUM + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + NUM += 1 + raise + + async def runner(): + nonlocal NUM + async with taskgroups.TaskGroup(defer_errors=True) as g: + for _ in range(5): + g.create_task(foo()) + + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + NUM += 10 + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError): + await r + + self.assertEqual(NUM, 15) + + async def test_taskgroup_36(self): + + async def foo(): + try: + await asyncio.sleep(10) + finally: + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g: + for _ in range(5): + g.create_task(foo()) + + await asyncio.sleep(10) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + async def test_taskgroup_37(self): + + async def foo(): + try: + await asyncio.sleep(10) + finally: + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True): + async with taskgroups.TaskGroup(defer_errors=True) as g2: + for _ in range(5): + g2.create_task(foo()) + + await asyncio.sleep(10) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + + self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) + self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError}) + + async def test_taskgroup_37a(self): + + async def foo(): + try: + await asyncio.sleep(10) + finally: + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup(): + async with taskgroups.TaskGroup(defer_errors=True) as g2: + for _ in range(5): + g2.create_task(foo()) + + await asyncio.sleep(10) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + + self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) + self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError}) + + async def test_taskgroup_38(self): + + async def foo(): + try: + await asyncio.sleep(10) + finally: + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g1: + g1.create_task(asyncio.sleep(10)) + + async with taskgroups.TaskGroup(defer_errors=True) as g2: + for _ in range(5): + g2.create_task(foo()) + + await asyncio.sleep(10) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + + self.assertEqual(get_error_types(cm.exception), {ExceptionGroup}) + self.assertEqual(get_error_types(cm.exception.exceptions[0]), {ZeroDivisionError}) + + async def test_taskgroup_39(self): + + async def crash_soon(): + await asyncio.sleep(0.3) + 1 / 0 + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g1: + g1.create_task(crash_soon()) + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + await asyncio.sleep(0.5) + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + async def test_taskgroup_40(self): + + async def crash_soon(): + await asyncio.sleep(0.3) + 1 / 0 + + async def nested_runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g1: + g1.create_task(crash_soon()) + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + await asyncio.sleep(0.5) + raise + + async def runner(): + t = asyncio.create_task(nested_runner()) + await t + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(ExceptionGroup) as cm: + await r + self.assertEqual(get_error_types(cm.exception), {ZeroDivisionError}) + + async def test_taskgroup_41(self): + + NUM = 0 + + async def runner(): + nonlocal NUM + async with taskgroups.TaskGroup(defer_errors=True): + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + NUM += 10 + raise + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + with self.assertRaises(asyncio.CancelledError): + await r + + self.assertEqual(NUM, 10) + + async def test_taskgroup_42(self): + + NUM = 0 + + async def runner(): + nonlocal NUM + async with taskgroups.TaskGroup(defer_errors=True): + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + NUM += 10 + # This isn't a good idea, but we have to support + # this weird case. + raise MyExc + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.1) + + self.assertFalse(r.done()) + r.cancel() + + try: + await r + except ExceptionGroup as t: + self.assertEqual(get_error_types(t),{MyExc}) + else: + self.fail('ExceptionGroup was not raised') + + self.assertEqual(NUM, 10) + + async def test_taskgroup_43(self): + + async def crash_soon(): + await asyncio.sleep(0.1) + 1 / 0 + + async def nested(): + try: + await asyncio.sleep(10) + finally: + raise MyExc + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g: + g.create_task(crash_soon()) + await nested() + + r = asyncio.create_task(runner()) + try: + await r + except ExceptionGroup as t: + self.assertEqual(get_error_types(t), {MyExc, ZeroDivisionError}) + else: + self.fail('TasgGroupError was not raised') + + async def test_taskgroup_44(self): + + async def foo1(): + await asyncio.sleep(1) + return 42 + + async def foo2(): + await asyncio.sleep(2) + return 11 + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g: + g.create_task(foo1()) + g.create_task(foo2()) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.05) + r.cancel() + + with self.assertRaises(asyncio.CancelledError): + await r + + async def test_taskgroup_45(self): + + NUM = 0 + + async def foo1(): + nonlocal NUM + await asyncio.sleep(0.2) + NUM += 1 + + async def foo2(): + nonlocal NUM + await asyncio.sleep(0.3) + NUM += 2 + + async def runner(): + async with taskgroups.TaskGroup(defer_errors=True) as g: + g.create_task(foo1()) + g.create_task(foo2()) + + r = asyncio.create_task(runner()) + await asyncio.sleep(0.05) + r.cancel() + + with self.assertRaises(asyncio.CancelledError): + await r + + self.assertEqual(NUM, 0) + + if __name__ == "__main__": unittest.main() diff --git a/Misc/NEWS.d/next/Library/2023-02-07-18-08-32.gh-issue-101581.lxw8WY.rst b/Misc/NEWS.d/next/Library/2023-02-07-18-08-32.gh-issue-101581.lxw8WY.rst new file mode 100644 index 00000000000000..e179ea9f680129 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2023-02-07-18-08-32.gh-issue-101581.lxw8WY.rst @@ -0,0 +1,2 @@ +Adds ``defer_errors`` flag to :class:`asyncio.TaskGroup` to optionally +prevent child tasks from being cancelled.