From bf7d7ee1f06813ac4ef9e77d4153ccaaa39f82e7 Mon Sep 17 00:00:00 2001 From: wuyuanyi Date: Sat, 1 May 2021 01:30:12 -0400 Subject: [PATCH 1/5] Fix MapAsyncIterator when it is cancelled due to client exiting. The tasks cancel exception is not caught and the internal __aiter__ task will not be stopped. --- src/graphql/subscription/map_async_iterator.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/graphql/subscription/map_async_iterator.py b/src/graphql/subscription/map_async_iterator.py index 35fd42a8..cd36e703 100644 --- a/src/graphql/subscription/map_async_iterator.py +++ b/src/graphql/subscription/map_async_iterator.py @@ -1,4 +1,4 @@ -from asyncio import Event, ensure_future, Future, wait +from asyncio import Event, ensure_future, Future, wait, CancelledError from concurrent.futures import FIRST_COMPLETED from inspect import isasyncgen, isawaitable from typing import cast, Any, AsyncIterable, Callable, Optional, Set, Type, Union @@ -43,9 +43,18 @@ async def __anext__(self) -> Any: aclose = ensure_future(self._close_event.wait()) anext = ensure_future(self.iterator.__anext__()) - pending: Set[Future] = ( - await wait([aclose, anext], return_when=FIRST_COMPLETED) - )[1] + # Suppress the StopAsyncIteration exception warning when the iterator is cancelled. + anext.add_done_callback(lambda *args: anext.exception()) + try: + pending: Set[Future] = ( + await wait([aclose, anext], return_when=FIRST_COMPLETED) + )[1] + except CancelledError: + # The iterator is cancelled + aclose.cancel() + anext.cancel() + raise StopAsyncIteration + for task in pending: task.cancel() From 2a8cd00028851fcd0934221ac62dde23883184e8 Mon Sep 17 00:00:00 2001 From: wuyuanyi Date: Sun, 2 May 2021 00:40:09 -0400 Subject: [PATCH 2/5] code style, coverage (skip), close iterator when cancelled --- src/graphql/subscription/map_async_iterator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/graphql/subscription/map_async_iterator.py b/src/graphql/subscription/map_async_iterator.py index cd36e703..0393b06c 100644 --- a/src/graphql/subscription/map_async_iterator.py +++ b/src/graphql/subscription/map_async_iterator.py @@ -43,17 +43,19 @@ async def __anext__(self) -> Any: aclose = ensure_future(self._close_event.wait()) anext = ensure_future(self.iterator.__anext__()) - # Suppress the StopAsyncIteration exception warning when the iterator is cancelled. + # Suppress the StopAsyncIteration exception warning when the + # iterator is cancelled. anext.add_done_callback(lambda *args: anext.exception()) try: pending: Set[Future] = ( await wait([aclose, anext], return_when=FIRST_COMPLETED) )[1] - except CancelledError: + except CancelledError as e: # pragma: no cover # The iterator is cancelled aclose.cancel() anext.cancel() - raise StopAsyncIteration + self.is_closed = True + raise StopAsyncIteration from e for task in pending: task.cancel() From c1ce0820a60e51c4fe5bc830df74de86f7c33c2b Mon Sep 17 00:00:00 2001 From: wuyuanyi Date: Sun, 2 May 2021 00:47:47 -0400 Subject: [PATCH 3/5] Add test when MapAsyncIterator is cancelled. --- tests/subscription/test_map_async_iterator.py | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/tests/subscription/test_map_async_iterator.py b/tests/subscription/test_map_async_iterator.py index 4df371af..2a76ed7b 100644 --- a/tests/subscription/test_map_async_iterator.py +++ b/tests/subscription/test_map_async_iterator.py @@ -1,5 +1,5 @@ import sys -from asyncio import Event, ensure_future, sleep +from asyncio import Event, ensure_future, CancelledError, create_task, sleep, Queue from pytest import mark, raises # type: ignore @@ -457,3 +457,41 @@ async def aclose(self): await anext(doubles) assert not doubles.is_closed assert not iterator.is_closed + + + @mark.asyncio + async def cancel_async_iterator_while_waiting(): + class Iterator: + def __init__(self): + self.queue: Queue[int] = Queue() + self.cancelled = False + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return await self.queue.get() + except BaseException as ex: + self.cancelled = True + + iterator = Iterator() + doubles = MapAsyncIterator(iterator, lambda x: x + x) + + async def iterator_task(): + try: + async for double in doubles: + pass + # If cancellation is handled using StopAsyncIteration, it will reach + # here. + except CancelledError: + # Otherwise it should reach here + pass + + task = create_task(iterator_task()) + await sleep(0.1) + await doubles.aclose() + task.cancel() + await sleep(0.1) + assert iterator.cancelled + assert doubles.is_closed From 7c08ab14f6fe513ffda7b3b5f69239e007d35dac Mon Sep 17 00:00:00 2001 From: wuyuanyi Date: Sun, 2 May 2021 01:07:59 -0400 Subject: [PATCH 4/5] Enable coverage. suppress coverage errors. --- src/graphql/subscription/map_async_iterator.py | 2 +- tests/subscription/test_map_async_iterator.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/graphql/subscription/map_async_iterator.py b/src/graphql/subscription/map_async_iterator.py index 0393b06c..39d6d321 100644 --- a/src/graphql/subscription/map_async_iterator.py +++ b/src/graphql/subscription/map_async_iterator.py @@ -50,7 +50,7 @@ async def __anext__(self) -> Any: pending: Set[Future] = ( await wait([aclose, anext], return_when=FIRST_COMPLETED) )[1] - except CancelledError as e: # pragma: no cover + except CancelledError as e: # The iterator is cancelled aclose.cancel() anext.cancel() diff --git a/tests/subscription/test_map_async_iterator.py b/tests/subscription/test_map_async_iterator.py index 2a76ed7b..e8960882 100644 --- a/tests/subscription/test_map_async_iterator.py +++ b/tests/subscription/test_map_async_iterator.py @@ -464,6 +464,7 @@ async def cancel_async_iterator_while_waiting(): class Iterator: def __init__(self): self.queue: Queue[int] = Queue() + self.queue.put_nowait(1) # suppress coverage warning self.cancelled = False def __aiter__(self): @@ -484,7 +485,7 @@ async def iterator_task(): pass # If cancellation is handled using StopAsyncIteration, it will reach # here. - except CancelledError: + except CancelledError: # pragma: no cover # Otherwise it should reach here pass From 169ae7bced0f515603e97f1def925f3d062e5009 Mon Sep 17 00:00:00 2001 From: wuyuanyi Date: Sun, 2 May 2021 01:17:29 -0400 Subject: [PATCH 5/5] Python 3.6 compatible; code style. --- tests/subscription/test_map_async_iterator.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/subscription/test_map_async_iterator.py b/tests/subscription/test_map_async_iterator.py index e8960882..c9e3fca0 100644 --- a/tests/subscription/test_map_async_iterator.py +++ b/tests/subscription/test_map_async_iterator.py @@ -1,5 +1,5 @@ import sys -from asyncio import Event, ensure_future, CancelledError, create_task, sleep, Queue +from asyncio import Event, ensure_future, CancelledError, sleep, Queue from pytest import mark, raises # type: ignore @@ -458,7 +458,6 @@ async def aclose(self): assert not doubles.is_closed assert not iterator.is_closed - @mark.asyncio async def cancel_async_iterator_while_waiting(): class Iterator: @@ -473,7 +472,7 @@ def __aiter__(self): async def __anext__(self): try: return await self.queue.get() - except BaseException as ex: + except BaseException: self.cancelled = True iterator = Iterator() @@ -489,7 +488,7 @@ async def iterator_task(): # Otherwise it should reach here pass - task = create_task(iterator_task()) + task = ensure_future(iterator_task()) await sleep(0.1) await doubles.aclose() task.cancel()