diff --git a/src/neo4j/_async_compat/network/_bolt_socket.py b/src/neo4j/_async_compat/network/_bolt_socket.py index 83a663e57..357e7410f 100644 --- a/src/neo4j/_async_compat/network/_bolt_socket.py +++ b/src/neo4j/_async_compat/network/_bolt_socket.py @@ -98,7 +98,7 @@ def __init__(self, reader, protocol, writer): self._timeout = None self._deadline = None - async def _wait_for_io(self, io_fut): + async def _wait_for_io(self, io_async_fn, *args, **kwargs): timeout = self._timeout to_raise = SocketTimeout if self._deadline is not None: @@ -109,6 +109,7 @@ async def _wait_for_io(self, io_fut): timeout = deadline_timeout to_raise = SocketDeadlineExceededError + io_fut = io_async_fn(*args, **kwargs) if timeout is not None and timeout <= 0: # give the io-operation time for one loop cycle to do its thing io_fut = asyncio.create_task(io_fut) @@ -157,20 +158,17 @@ def settimeout(self, timeout): self._timeout = timeout async def recv(self, n): - io_fut = self._reader.read(n) - return await self._wait_for_io(io_fut) + return await self._wait_for_io(self._reader.read, n) async def recv_into(self, buffer, nbytes): # FIXME: not particularly memory or time efficient - io_fut = self._reader.read(nbytes) - res = await self._wait_for_io(io_fut) + res = await self._wait_for_io(self._reader.read, nbytes) buffer[: len(res)] = res return len(res) async def sendall(self, data): self._writer.write(data) - io_fut = self._writer.drain() - return await self._wait_for_io(io_fut) + return await self._wait_for_io(self._writer.drain) async def close(self): self._writer.close() diff --git a/tests/unit/mixed/async_compat/test_network.py b/tests/unit/mixed/async_compat/test_network.py new file mode 100644 index 000000000..ee4c290d9 --- /dev/null +++ b/tests/unit/mixed/async_compat/test_network.py @@ -0,0 +1,164 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import asyncio +import socket +import typing as t + +import freezegun +import pytest + +from neo4j._async_compat.network import AsyncBoltSocket +from neo4j._exceptions import SocketDeadlineExceededError + +from ...._async_compat.mark_decorator import mark_async_test + + +if t.TYPE_CHECKING: + import typing_extensions as te + from freezegun.api import ( + FrozenDateTimeFactory, + StepTickTimeFactory, + TickingDateTimeFactory, + ) + + TFreezeTime: te.TypeAlias = ( + StepTickTimeFactory | TickingDateTimeFactory | FrozenDateTimeFactory + ) + + +@pytest.fixture +def reader_factory(mocker): + def factory(): + return mocker.create_autospec(asyncio.StreamReader) + + return factory + + +@pytest.fixture +def writer_factory(mocker): + def factory(): + return mocker.create_autospec(asyncio.StreamWriter) + + return factory + + +@pytest.fixture +def socket_factory(reader_factory, writer_factory): + def factory(): + protocol = None + return AsyncBoltSocket(reader_factory(), protocol, writer_factory()) + + return factory + + +def reader(s: AsyncBoltSocket): + return s._reader + + +def writer(s: AsyncBoltSocket): + return s._writer + + +@pytest.mark.parametrize( + ("timeout", "deadline", "pre_tick", "tick", "exception"), + ( + (None, None, 60 * 60 * 10, 60 * 60 * 10, None), + # test timeout + (5, None, 0, 4, None), + # timeout is not affected by time passed before the call + (5, None, 7, 4, None), + (5, None, 0, 6, socket.timeout), + # test deadline + (None, 5, 0, 4, None), + (None, 5, 2, 2, None), + # deadline is affected by time passed before the call + (None, 5, 2, 4, SocketDeadlineExceededError), + (None, 5, 6, 0, SocketDeadlineExceededError), + (None, 5, 0, 6, SocketDeadlineExceededError), + # test combination + (5, 5, 0, 4, None), + (5, 5, 2, 2, None), + # deadline triggered by time passed before + (5, 5, 2, 4, SocketDeadlineExceededError), + # the shorter one determines the error + (4, 5, 0, 6, socket.timeout), + (5, 4, 0, 6, SocketDeadlineExceededError), + ), +) +@pytest.mark.parametrize("method", ("recv", "recv_into", "sendall")) +@mark_async_test +async def test_async_bolt_socket_read_timeout( + socket_factory, timeout, deadline, pre_tick, tick, exception, method +): + def make_read_side_effect(freeze_time: TFreezeTime): + async def read_side_effect(n): + assert n == 1 + freeze_time.tick(tick) + for _ in range(10): + await asyncio.sleep(0) + return b"y" + + return read_side_effect + + def make_drain_side_effect(freeze_time: TFreezeTime): + async def drain_side_effect(): + freeze_time.tick(tick) + for _ in range(10): + await asyncio.sleep(0) + + return drain_side_effect + + async def call_method(s: AsyncBoltSocket): + if method == "recv": + res = await s.recv(1) + assert res == b"y" + elif method == "recv_into": + b = bytearray(1) + await s.recv_into(b, 1) + assert b == b"y" + elif method == "sendall": + await s.sendall(b"y") + else: + raise NotImplementedError(f"method: {method}") + + with freezegun.freeze_time("1970-01-01T00:00:00") as frozen_time: + socket = socket_factory() + if timeout is not None: + socket.settimeout(timeout) + if deadline is not None: + socket.set_deadline(deadline) + if pre_tick: + frozen_time.tick(pre_tick) + + if method in {"recv", "recv_into"}: + reader(socket).read.side_effect = make_read_side_effect( + frozen_time + ) + elif method == "sendall": + writer(socket).drain.side_effect = make_drain_side_effect( + frozen_time + ) + else: + raise NotImplementedError(f"method: {method}") + + if exception: + with pytest.raises(exception): + await call_method(socket) + else: + await call_method(socket)