diff --git a/pyproject.toml b/pyproject.toml index 180b4764..4d881171 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -277,6 +277,9 @@ select = [ # allow async functions without await to enable type checking, pretending to be async, matching type signatures "RUF029", ] +"tests/**" = [ + "B011", # allow `assert False` in tests, they won't be run with -O anyway +] "bin/**" = [ "T20", # print statements are ok in our helper scripts ] diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 6ca539d1..f5cdedb2 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -35,6 +35,7 @@ from ..._exceptions import ( BoltError, BoltHandshakeError, + SocketDeadlineExceededError, ) from ..._io import BoltProtocolVersion from ..._meta import USER_AGENT @@ -887,6 +888,15 @@ async def _set_defunct_write(self, error=None, silent=False): async def _set_defunct(self, message, error=None, silent=False): direct_driver = getattr(self.pool, "is_direct_pool", False) user_cancelled = isinstance(error, asyncio.CancelledError) + connection_failed = isinstance( + error, + ( + ServiceUnavailable, + SessionExpired, + OSError, + SocketDeadlineExceededError, + ), + ) if not (user_cancelled or self._closing): log_call = log.error @@ -913,6 +923,12 @@ async def _set_defunct(self, message, error=None, silent=False): if user_cancelled: self.kill() raise error # cancellation error should not be re-written + if not connection_failed: + # Something else but the connection failed + # => we're not sure which state we're in + # => ditch the connection and raise the error for user-awareness + await self.close() + raise error if not self._closing: # If we fail while closing the connection, there is no need to # remove the connection from the pool, nor to try to close the diff --git a/src/neo4j/_async/io/_common.py b/src/neo4j/_async/io/_common.py index 491cb60b..5e412178 100644 --- a/src/neo4j/_async/io/_common.py +++ b/src/neo4j/_async/io/_common.py @@ -81,6 +81,15 @@ async def pop(self, hydration_hooks): self._unpacker.unpack(hydration_hooks) for _ in range(size) ] return tag, fields + except Exception as error: + log.debug( + "[#%04X] _: Failed to unpack response: %r", + self._local_port, + error, + ) + self._broken = True + await AsyncUtil.callback(self.on_error, error) + raise finally: # Reset for new message self._unpacker.reset() diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index c836e924..64d92c3f 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -35,6 +35,7 @@ from ..._exceptions import ( BoltError, BoltHandshakeError, + SocketDeadlineExceededError, ) from ..._io import BoltProtocolVersion from ..._meta import USER_AGENT @@ -887,6 +888,15 @@ def _set_defunct_write(self, error=None, silent=False): def _set_defunct(self, message, error=None, silent=False): direct_driver = getattr(self.pool, "is_direct_pool", False) user_cancelled = isinstance(error, asyncio.CancelledError) + connection_failed = isinstance( + error, + ( + ServiceUnavailable, + SessionExpired, + OSError, + SocketDeadlineExceededError, + ), + ) if not (user_cancelled or self._closing): log_call = log.error @@ -913,6 +923,12 @@ def _set_defunct(self, message, error=None, silent=False): if user_cancelled: self.kill() raise error # cancellation error should not be re-written + if not connection_failed: + # Something else but the connection failed + # => we're not sure which state we're in + # => ditch the connection and raise the error for user-awareness + self.close() + raise error if not self._closing: # If we fail while closing the connection, there is no need to # remove the connection from the pool, nor to try to close the diff --git a/src/neo4j/_sync/io/_common.py b/src/neo4j/_sync/io/_common.py index d863091a..b97008dd 100644 --- a/src/neo4j/_sync/io/_common.py +++ b/src/neo4j/_sync/io/_common.py @@ -81,6 +81,15 @@ def pop(self, hydration_hooks): self._unpacker.unpack(hydration_hooks) for _ in range(size) ] return tag, fields + except Exception as error: + log.debug( + "[#%04X] _: Failed to unpack response: %r", + self._local_port, + error, + ) + self._broken = True + Util.callback(self.on_error, error) + raise finally: # Reset for new message self._unpacker.reset() diff --git a/tests/unit/async_/io/_common/__init__.py b/tests/unit/async_/io/_common/__init__.py new file mode 100644 index 00000000..3f968099 --- /dev/null +++ b/tests/unit/async_/io/_common/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/async_/io/_common/test_inbox.py b/tests/unit/async_/io/_common/test_inbox.py new file mode 100644 index 00000000..1699fa5e --- /dev/null +++ b/tests/unit/async_/io/_common/test_inbox.py @@ -0,0 +1,140 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio + +import pytest + +from neo4j._async.io._common import AsyncInbox +from neo4j._codec.packstream.v1 import Unpacker +from neo4j._exceptions import SocketDeadlineExceededError + +from ....._async_compat import mark_async_test + + +class InboxMockHolder: + def __init__(self, mocker): + self.socket_mock = mocker.Mock() + self.socket_mock.getsockname.return_value = ("host", 1234) + self.on_error = mocker.AsyncMock() + self.inbox = AsyncInbox(self.socket_mock, self.on_error, Unpacker) + self.unpacker = mocker.Mock(wraps=self.inbox._unpacker) + self.inbox._unpacker = self.unpacker + # plenty of nonsense messages to read + self.mock_set_data(b"\x00\x01\xff\x00\x00" * 1000) + + def mock_set_data(self, data): + async def side_effect(buffer, n): + nonlocal data + + if not data: + pytest.fail("Read more data than mocked") + + n = min(len(data), len(buffer), n) + buffer[:n] = data[:n] + data = data[n:] + return n + + self.socket_mock.recv_into.side_effect = side_effect + + def assert_no_error(self): + self.on_error.assert_not_called() + assert not self.inbox._broken + + def mock_receive_failure(self, exception): + self.socket_mock.recv_into.side_effect = exception + + def mock_unpack_failure(self, exception): + self.unpacker.unpack_structure_header.side_effect = exception + + +@pytest.mark.parametrize( + ("data", "result"), + ( + ( + bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14, 0, 0)), + bytes(range(10, 15)), + ), + ( + bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 0)), + bytes(range(10, 14)), + ), + ( + bytes((0, 1, 5, 0, 0)), + bytes((5,)), + ), + ), +) +@mark_async_test +async def test_inbox_dechunking(data, result, mocker): + # Given + mocks = InboxMockHolder(mocker) + mocks.mock_set_data(data) + inbox = mocks.inbox + buffer = inbox._buffer + + # When + await inbox._buffer_one_chunk() + + # Then + mocks.assert_no_error() + assert buffer.used == len(result) + assert buffer.data[: len(result)] == result + + +@pytest.mark.parametrize( + "error", + ( + asyncio.CancelledError("test"), + SocketDeadlineExceededError("test"), + OSError("test"), + ), +) +@mark_async_test +async def test_inbox_receive_failure_error_handler(mocker, error): + mocks = InboxMockHolder(mocker) + mocks.mock_receive_failure(error) + inbox = mocks.inbox + + with pytest.raises(type(error)) as exc: + await inbox.pop({}) + + assert exc.value is error + mocks.on_error.assert_awaited_once_with(error) + assert inbox._broken + + +@pytest.mark.parametrize( + "error", + ( + SocketDeadlineExceededError("test"), + OSError("test"), + RecursionError("2deep4u"), + RuntimeError("something funny happened"), + ), +) +@mark_async_test +async def test_inbox_unpack_failure(mocker, error): + mocks = InboxMockHolder(mocker) + mocks.mock_unpack_failure(error) + inbox = mocks.inbox + + with pytest.raises(type(error)) as exc: + await inbox.pop({}) + + assert exc.value is error + mocks.on_error.assert_awaited_once_with(error) + assert inbox._broken diff --git a/tests/unit/async_/io/_common/test_oubox.py b/tests/unit/async_/io/_common/test_oubox.py new file mode 100644 index 00000000..57762224 --- /dev/null +++ b/tests/unit/async_/io/_common/test_oubox.py @@ -0,0 +1,147 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio + +import pytest + +from neo4j._async.io._common import AsyncOutbox +from neo4j._codec.packstream.v1 import PackableBuffer +from neo4j._exceptions import SocketDeadlineExceededError + +from ....._async_compat import mark_async_test + + +class OutboxMockHolder: + def __init__(self, mocker, *, max_chunk_size=16384): + self.buffer = PackableBuffer() + self.socket_mock = mocker.AsyncMock() + self.packer_mock = mocker.Mock() + self.packer_mock.return_value = self.packer_mock + self.packer_mock.new_packable_buffer.return_value = self.buffer + self.on_error = mocker.AsyncMock() + self.outbox = AsyncOutbox( + self.socket_mock, + self.on_error, + self.packer_mock, + max_chunk_size=max_chunk_size, + ) + + def mock_write_message(self, data): + def side_effect(*_args, **_kwargs): + self.buffer.write(data) + + self.packer_mock.pack_struct.side_effect = side_effect + + def assert_no_error(self): + self.on_error.assert_not_called() + + def mock_pack_failure(self, exception): + def side_effect(*_args, **_kwargs): + self.buffer.write(b"some data") + raise exception + + self.packer_mock.pack_struct.side_effect = side_effect + + def mock_send_failure(self, exception): + self.socket_mock.sendall.side_effect = exception + + +@pytest.mark.parametrize( + ("chunk_size", "data", "result"), + ( + ( + 2, + bytes(range(10, 15)), + bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14)), + ), + ( + 2, + bytes(range(10, 14)), + bytes((0, 2, 10, 11, 0, 2, 12, 13)), + ), + ( + 2, + bytes((5,)), + bytes((0, 1, 5)), + ), + ), +) +@mark_async_test +async def test_async_outbox_chunking(chunk_size, data, result, mocker): + # Given + mocks = OutboxMockHolder(mocker, max_chunk_size=chunk_size) + mocks.mock_write_message(data) + outbox = mocks.outbox + socket_mock = mocks.socket_mock + + # When + outbox.append_message(None, None, None) + + # Then + mocks.assert_no_error() + socket_mock.sendall.assert_not_called() + assert await outbox.flush() + socket_mock.sendall.assert_awaited_once_with(result + b"\x00\x00") + + assert not await outbox.flush() + socket_mock.sendall.assert_awaited_once() + + +@pytest.mark.parametrize( + "error", + ( + asyncio.CancelledError("test"), + SocketDeadlineExceededError("test"), + OSError("test"), + ), +) +@mark_async_test +async def test_outbox_send_failure_error_handler(mocker, error): + mocks = OutboxMockHolder(mocker, max_chunk_size=12345) + mocks.mock_send_failure(error) + outbox = mocks.outbox + + outbox.append_message(None, None, None) + assert not await outbox.flush() + + mocks.on_error.assert_awaited_once_with(error) + + +@pytest.mark.parametrize( + "error", + ( + asyncio.CancelledError("test"), + SocketDeadlineExceededError("test"), + OSError("test"), + RecursionError("2deep4u"), + RuntimeError("something funny happened"), + ), +) +@mark_async_test +async def test_outbox_pack_failure(mocker, error): + mocks = OutboxMockHolder(mocker, max_chunk_size=12345) + mocks.mock_pack_failure(error) + outbox = mocks.outbox + socket_mock = mocks.socket_mock + + with pytest.raises(type(error)) as exc: + outbox.append_message(None, None, None) + assert not await outbox.flush() + + assert exc.value is error + mocks.on_error.assert_not_called() + socket_mock.sendall.assert_not_called() diff --git a/tests/unit/async_/io/test__common.py b/tests/unit/async_/io/_common/test_response_handler.py similarity index 71% rename from tests/unit/async_/io/test__common.py rename to tests/unit/async_/io/_common/test_response_handler.py index b4556edf..1f458b52 100644 --- a/tests/unit/async_/io/test__common.py +++ b/tests/unit/async_/io/_common/test_response_handler.py @@ -18,53 +18,9 @@ import pytest -from neo4j._async.io._common import ( - AsyncOutbox, - ResetResponse, -) -from neo4j._codec.packstream.v1 import PackableBuffer - -from ...._async_compat import mark_async_test - +from neo4j._async.io._common import ResetResponse -@pytest.mark.parametrize( - ("chunk_size", "data", "result"), - ( - ( - 2, - bytes(range(10, 15)), - bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14)), - ), - ( - 2, - bytes(range(10, 14)), - bytes((0, 2, 10, 11, 0, 2, 12, 13)), - ), - ( - 2, - bytes((5,)), - bytes((0, 1, 5)), - ), - ), -) -@mark_async_test -async def test_async_outbox_chunking(chunk_size, data, result, mocker): - buffer = PackableBuffer() - socket_mock = mocker.AsyncMock() - packer_mock = mocker.Mock() - packer_mock.return_value = packer_mock - packer_mock.new_packable_buffer.return_value = buffer - packer_mock.pack_struct.side_effect = lambda *args, **kwargs: buffer.write( - data - ) - outbox = AsyncOutbox(socket_mock, pytest.fail, packer_mock, chunk_size) - outbox.append_message(None, None, None) - socket_mock.sendall.assert_not_called() - assert await outbox.flush() - socket_mock.sendall.assert_awaited_once_with(result + b"\x00\x00") - - assert not await outbox.flush() - socket_mock.sendall.assert_awaited_once() +from ....._async_compat import mark_async_test def get_handler_arg(response): diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 010916eb..45f2af1e 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -15,13 +15,26 @@ import asyncio +import random import pytest import neo4j.auth_management from neo4j._async.io import AsyncBolt +from neo4j._async.io._bolt5 import AsyncBolt5x8 from neo4j._async.io._bolt_socket import AsyncBoltSocket -from neo4j.exceptions import UnsupportedServerProduct +from neo4j._async.io._common import ( + CommitResponse, + ResetResponse, + Response, +) +from neo4j._exceptions import SocketDeadlineExceededError +from neo4j.exceptions import ( + IncompleteCommit, + ServiceUnavailable, + SessionExpired, + UnsupportedServerProduct, +) from ...._async_compat import ( AsyncTestDecorators, @@ -206,7 +219,7 @@ async def test_failing_version_negotiation(mocker, bolt_version, none_auth): @AsyncTestDecorators.mark_async_only_test -async def test_cancel_manager_in_open(mocker): +async def test_cancel_auth_manager_in_open(mocker): address = ("localhost", 7687) socket_mock = mocker.AsyncMock(spec=AsyncBoltSocket) @@ -228,7 +241,7 @@ async def test_cancel_manager_in_open(mocker): @AsyncTestDecorators.mark_async_only_test -async def test_fail_manager_in_open(mocker): +async def test_fail_auth_manager_in_open(mocker): address = ("localhost", 7687) socket_mock = mocker.AsyncMock(spec=AsyncBoltSocket) @@ -248,3 +261,135 @@ async def test_fail_manager_in_open(mocker): assert exc.value is auth_manager.get_auth.side_effect socket_mock.close.assert_called_once_with() + + +@pytest.mark.parametrize("mode", ("r", "w")) +@pytest.mark.parametrize( + "error", + ( + RuntimeError("test error"), + RecursionError("How deep is your ~~love~~ recursion?"), + asyncio.CancelledError("STOP! Cancel time!"), + ), +) +@pytest.mark.parametrize("queued_commit", (None, 0, 1, 10)) +@mark_async_test +async def test_error_handler_bubbling( + mocker, fake_socket, mode, error, queued_commit +): + mocks = ErrorHandlerTestMockHolder(mocker) + if queued_commit is not None: + mocks.queue_commit_message_at(queued_commit) + + connection = mocks.connection + handler = mocks.get_error_handler(mode) + + with pytest.raises(type(error)) as exc: + await handler(error) + assert exc.value is error + + if isinstance(error, asyncio.CancelledError): + connection.socket.kill.assert_called_once() + connection.socket.close.assert_not_called() + else: + connection.socket.close.assert_awaited_once() + + assert connection.closed() + assert connection.defunct() + + +@pytest.mark.parametrize("mode", ("r", "w")) +@pytest.mark.parametrize( + "error", + ( + OSError("computer says no! *cough*"), + SocketDeadlineExceededError("too late, too little"), + ServiceUnavailable("borked connection"), + SessionExpired("nobody at home"), + ), +) +@pytest.mark.parametrize("routing", (True, False)) +@pytest.mark.parametrize("queued_commit", (None, 0, 1, 10)) +@mark_async_test +async def test_error_handler_rewritten( + mocker, fake_socket, mode, error, routing, queued_commit +): + mocks = ErrorHandlerTestMockHolder(mocker) + mocks.mock_driver_routing(routing) + if queued_commit is not None: + mocks.queue_commit_message_at(queued_commit) + + connection = mocks.connection + handler = mocks.get_error_handler(mode) + + if queued_commit is not None: + expected_error = IncompleteCommit + elif routing: + expected_error = SessionExpired + else: + expected_error = ServiceUnavailable + + with pytest.raises(expected_error) as exc: + await handler(error) + assert exc.value.__cause__ is error + connection.socket.close.assert_awaited_once() + + assert connection.closed() + assert connection.defunct() + + +class ErrorHandlerTestMockHolder: + def __init__(self, mocker): + self.address = neo4j.Address(("127.0.0.1", 7687)) + self.socket_mock = mocker.AsyncMock(spec=AsyncBoltSocket) + self.socket_mock.getpeername.return_value = self.address + self.connection = AsyncBolt5x8(self.address, self.socket_mock, 108000) + self.pool = mocker.AsyncMock() + self.connection.pool = self.pool + + def mock_driver_routing(self, routing): + self.pool.is_direct_pool = not routing + + def queue_random_non_commit_response(self): + resp_cls = random.choice((ResetResponse, Response)) + resp = resp_cls(self.connection, "MESSAGE", {}) + self.connection.responses.append(resp) + + def queue_commit_message(self): + resp = CommitResponse(self.connection, "MESSAGE", {}) + self.connection.responses.append(resp) + + def queue_commit_message_at(self, position): + self.connection.responses.clear() + for _ in range(position - 1): + self.queue_random_non_commit_response() + self.queue_commit_message() + self.queue_random_non_commit_response() + + def get_error_handler(self, mode): + if mode == "r": + return self.connection._set_defunct_read + elif mode == "w": + return self.connection._set_defunct_write + else: + raise ValueError(f"Invalid handler mode {mode!r}") + + +def test_configures_inbox_error_handler(mocker): + inbox_cls_mock = mocker.patch( + "neo4j._async.io._bolt.AsyncInbox", autospec=True + ) + mocks = ErrorHandlerTestMockHolder(mocker) + inbox_cls_mock.assert_called_once() + call_args = inbox_cls_mock.call_args + assert call_args.kwargs["on_error"] == mocks.connection._set_defunct_read + + +def test_configures_outbox_error_handler(mocker): + inbox_cls_mock = mocker.patch( + "neo4j._async.io._bolt.AsyncOutbox", autospec=True + ) + mocks = ErrorHandlerTestMockHolder(mocker) + inbox_cls_mock.assert_called_once() + call_args = inbox_cls_mock.call_args + assert call_args.kwargs["on_error"] == mocks.connection._set_defunct_write diff --git a/tests/unit/sync/io/_common/__init__.py b/tests/unit/sync/io/_common/__init__.py new file mode 100644 index 00000000..3f968099 --- /dev/null +++ b/tests/unit/sync/io/_common/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/sync/io/_common/test_inbox.py b/tests/unit/sync/io/_common/test_inbox.py new file mode 100644 index 00000000..5b82cb65 --- /dev/null +++ b/tests/unit/sync/io/_common/test_inbox.py @@ -0,0 +1,140 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio + +import pytest + +from neo4j._codec.packstream.v1 import Unpacker +from neo4j._exceptions import SocketDeadlineExceededError +from neo4j._sync.io._common import Inbox + +from ....._async_compat import mark_sync_test + + +class InboxMockHolder: + def __init__(self, mocker): + self.socket_mock = mocker.Mock() + self.socket_mock.getsockname.return_value = ("host", 1234) + self.on_error = mocker.MagicMock() + self.inbox = Inbox(self.socket_mock, self.on_error, Unpacker) + self.unpacker = mocker.Mock(wraps=self.inbox._unpacker) + self.inbox._unpacker = self.unpacker + # plenty of nonsense messages to read + self.mock_set_data(b"\x00\x01\xff\x00\x00" * 1000) + + def mock_set_data(self, data): + def side_effect(buffer, n): + nonlocal data + + if not data: + pytest.fail("Read more data than mocked") + + n = min(len(data), len(buffer), n) + buffer[:n] = data[:n] + data = data[n:] + return n + + self.socket_mock.recv_into.side_effect = side_effect + + def assert_no_error(self): + self.on_error.assert_not_called() + assert not self.inbox._broken + + def mock_receive_failure(self, exception): + self.socket_mock.recv_into.side_effect = exception + + def mock_unpack_failure(self, exception): + self.unpacker.unpack_structure_header.side_effect = exception + + +@pytest.mark.parametrize( + ("data", "result"), + ( + ( + bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14, 0, 0)), + bytes(range(10, 15)), + ), + ( + bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 0)), + bytes(range(10, 14)), + ), + ( + bytes((0, 1, 5, 0, 0)), + bytes((5,)), + ), + ), +) +@mark_sync_test +def test_inbox_dechunking(data, result, mocker): + # Given + mocks = InboxMockHolder(mocker) + mocks.mock_set_data(data) + inbox = mocks.inbox + buffer = inbox._buffer + + # When + inbox._buffer_one_chunk() + + # Then + mocks.assert_no_error() + assert buffer.used == len(result) + assert buffer.data[: len(result)] == result + + +@pytest.mark.parametrize( + "error", + ( + asyncio.CancelledError("test"), + SocketDeadlineExceededError("test"), + OSError("test"), + ), +) +@mark_sync_test +def test_inbox_receive_failure_error_handler(mocker, error): + mocks = InboxMockHolder(mocker) + mocks.mock_receive_failure(error) + inbox = mocks.inbox + + with pytest.raises(type(error)) as exc: + inbox.pop({}) + + assert exc.value is error + mocks.on_error.assert_called_once_with(error) + assert inbox._broken + + +@pytest.mark.parametrize( + "error", + ( + SocketDeadlineExceededError("test"), + OSError("test"), + RecursionError("2deep4u"), + RuntimeError("something funny happened"), + ), +) +@mark_sync_test +def test_inbox_unpack_failure(mocker, error): + mocks = InboxMockHolder(mocker) + mocks.mock_unpack_failure(error) + inbox = mocks.inbox + + with pytest.raises(type(error)) as exc: + inbox.pop({}) + + assert exc.value is error + mocks.on_error.assert_called_once_with(error) + assert inbox._broken diff --git a/tests/unit/sync/io/_common/test_oubox.py b/tests/unit/sync/io/_common/test_oubox.py new file mode 100644 index 00000000..24a46b6e --- /dev/null +++ b/tests/unit/sync/io/_common/test_oubox.py @@ -0,0 +1,147 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio + +import pytest + +from neo4j._codec.packstream.v1 import PackableBuffer +from neo4j._exceptions import SocketDeadlineExceededError +from neo4j._sync.io._common import Outbox + +from ....._async_compat import mark_sync_test + + +class OutboxMockHolder: + def __init__(self, mocker, *, max_chunk_size=16384): + self.buffer = PackableBuffer() + self.socket_mock = mocker.MagicMock() + self.packer_mock = mocker.Mock() + self.packer_mock.return_value = self.packer_mock + self.packer_mock.new_packable_buffer.return_value = self.buffer + self.on_error = mocker.MagicMock() + self.outbox = Outbox( + self.socket_mock, + self.on_error, + self.packer_mock, + max_chunk_size=max_chunk_size, + ) + + def mock_write_message(self, data): + def side_effect(*_args, **_kwargs): + self.buffer.write(data) + + self.packer_mock.pack_struct.side_effect = side_effect + + def assert_no_error(self): + self.on_error.assert_not_called() + + def mock_pack_failure(self, exception): + def side_effect(*_args, **_kwargs): + self.buffer.write(b"some data") + raise exception + + self.packer_mock.pack_struct.side_effect = side_effect + + def mock_send_failure(self, exception): + self.socket_mock.sendall.side_effect = exception + + +@pytest.mark.parametrize( + ("chunk_size", "data", "result"), + ( + ( + 2, + bytes(range(10, 15)), + bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14)), + ), + ( + 2, + bytes(range(10, 14)), + bytes((0, 2, 10, 11, 0, 2, 12, 13)), + ), + ( + 2, + bytes((5,)), + bytes((0, 1, 5)), + ), + ), +) +@mark_sync_test +def test_async_outbox_chunking(chunk_size, data, result, mocker): + # Given + mocks = OutboxMockHolder(mocker, max_chunk_size=chunk_size) + mocks.mock_write_message(data) + outbox = mocks.outbox + socket_mock = mocks.socket_mock + + # When + outbox.append_message(None, None, None) + + # Then + mocks.assert_no_error() + socket_mock.sendall.assert_not_called() + assert outbox.flush() + socket_mock.sendall.assert_called_once_with(result + b"\x00\x00") + + assert not outbox.flush() + socket_mock.sendall.assert_called_once() + + +@pytest.mark.parametrize( + "error", + ( + asyncio.CancelledError("test"), + SocketDeadlineExceededError("test"), + OSError("test"), + ), +) +@mark_sync_test +def test_outbox_send_failure_error_handler(mocker, error): + mocks = OutboxMockHolder(mocker, max_chunk_size=12345) + mocks.mock_send_failure(error) + outbox = mocks.outbox + + outbox.append_message(None, None, None) + assert not outbox.flush() + + mocks.on_error.assert_called_once_with(error) + + +@pytest.mark.parametrize( + "error", + ( + asyncio.CancelledError("test"), + SocketDeadlineExceededError("test"), + OSError("test"), + RecursionError("2deep4u"), + RuntimeError("something funny happened"), + ), +) +@mark_sync_test +def test_outbox_pack_failure(mocker, error): + mocks = OutboxMockHolder(mocker, max_chunk_size=12345) + mocks.mock_pack_failure(error) + outbox = mocks.outbox + socket_mock = mocks.socket_mock + + with pytest.raises(type(error)) as exc: + outbox.append_message(None, None, None) + assert not outbox.flush() + + assert exc.value is error + mocks.on_error.assert_not_called() + socket_mock.sendall.assert_not_called() diff --git a/tests/unit/sync/io/_common/test_response_handler.py b/tests/unit/sync/io/_common/test_response_handler.py new file mode 100644 index 00000000..31df1105 --- /dev/null +++ b/tests/unit/sync/io/_common/test_response_handler.py @@ -0,0 +1,127 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +import pytest + +from neo4j._sync.io._common import ResetResponse + +from ....._async_compat import mark_sync_test + + +def get_handler_arg(response): + if response == "RECORD": + return [] + elif response in {"IGNORED", "FAILURE", "SUCCESS"}: + return {} + else: + raise ValueError(f"Unexpected response: {response}") + + +def call_handler(handler, response, arg=None): + if arg is None: + arg = get_handler_arg(response) + + if response == "RECORD": + return handler.on_records(arg) + elif response == "IGNORED": + return handler.on_ignored(arg) + elif response == "FAILURE": + return handler.on_failure(arg) + elif response == "SUCCESS": + return handler.on_success(arg) + else: + raise ValueError(f"Unexpected response: {response}") + + +@pytest.mark.parametrize( + ("response", "unexpected"), + ( + ("RECORD", True), + ("IGNORED", True), + ("FAILURE", True), + ("SUCCESS", False), + ), +) +@mark_sync_test +def test_reset_response_closes_connection_on_unexpected_responses( + response, unexpected, fake_connection +): + handler = ResetResponse(fake_connection, "reset", {}) + fake_connection.close.assert_not_called() + + call_handler(handler, response) + + if unexpected: + fake_connection.close.assert_called_once() + else: + fake_connection.close.assert_not_called() + + +@pytest.mark.parametrize( + ("response", "unexpected"), + ( + ("RECORD", True), + ("IGNORED", True), + ("FAILURE", True), + ("SUCCESS", False), + ), +) +@mark_sync_test +def test_reset_response_logs_warning_on_unexpected_responses( + response, unexpected, fake_connection, caplog +): + handler = ResetResponse(fake_connection, "reset", {}) + + with caplog.at_level(logging.WARNING): + call_handler(handler, response) + + log_message_found = any( + "RESET" in msg and "unexpected response" in msg + for msg in caplog.messages + ) + if unexpected: + assert log_message_found + else: + assert not log_message_found + + +@pytest.mark.parametrize( + "response", ("RECORD", "IGNORED", "FAILURE", "SUCCESS") +) +@mark_sync_test +def test_reset_response_never_calls_handlers( + response, fake_connection, mocker +): + handlers = { + key: mocker.MagicMock(name=key) + for key in ( + "on_records", + "on_ignored", + "on_failure", + "on_success", + "on_summary", + ) + } + + handler = ResetResponse(fake_connection, "reset", {}, **handlers) + + arg = get_handler_arg(response) + call_handler(handler, response, arg) + + for handler in handlers.values(): + handler.assert_not_called() diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index ef23bbf6..6f5f1a66 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -15,13 +15,26 @@ import asyncio +import random import pytest import neo4j.auth_management +from neo4j._exceptions import SocketDeadlineExceededError from neo4j._sync.io import Bolt +from neo4j._sync.io._bolt5 import Bolt5x8 from neo4j._sync.io._bolt_socket import BoltSocket -from neo4j.exceptions import UnsupportedServerProduct +from neo4j._sync.io._common import ( + CommitResponse, + ResetResponse, + Response, +) +from neo4j.exceptions import ( + IncompleteCommit, + ServiceUnavailable, + SessionExpired, + UnsupportedServerProduct, +) from ...._async_compat import ( mark_sync_test, @@ -206,7 +219,7 @@ def test_failing_version_negotiation(mocker, bolt_version, none_auth): @TestDecorators.mark_async_only_test -def test_cancel_manager_in_open(mocker): +def test_cancel_auth_manager_in_open(mocker): address = ("localhost", 7687) socket_mock = mocker.MagicMock(spec=BoltSocket) @@ -228,7 +241,7 @@ def test_cancel_manager_in_open(mocker): @TestDecorators.mark_async_only_test -def test_fail_manager_in_open(mocker): +def test_fail_auth_manager_in_open(mocker): address = ("localhost", 7687) socket_mock = mocker.MagicMock(spec=BoltSocket) @@ -248,3 +261,135 @@ def test_fail_manager_in_open(mocker): assert exc.value is auth_manager.get_auth.side_effect socket_mock.close.assert_called_once_with() + + +@pytest.mark.parametrize("mode", ("r", "w")) +@pytest.mark.parametrize( + "error", + ( + RuntimeError("test error"), + RecursionError("How deep is your ~~love~~ recursion?"), + asyncio.CancelledError("STOP! Cancel time!"), + ), +) +@pytest.mark.parametrize("queued_commit", (None, 0, 1, 10)) +@mark_sync_test +def test_error_handler_bubbling( + mocker, fake_socket, mode, error, queued_commit +): + mocks = ErrorHandlerTestMockHolder(mocker) + if queued_commit is not None: + mocks.queue_commit_message_at(queued_commit) + + connection = mocks.connection + handler = mocks.get_error_handler(mode) + + with pytest.raises(type(error)) as exc: + handler(error) + assert exc.value is error + + if isinstance(error, asyncio.CancelledError): + connection.socket.kill.assert_called_once() + connection.socket.close.assert_not_called() + else: + connection.socket.close.assert_called_once() + + assert connection.closed() + assert connection.defunct() + + +@pytest.mark.parametrize("mode", ("r", "w")) +@pytest.mark.parametrize( + "error", + ( + OSError("computer says no! *cough*"), + SocketDeadlineExceededError("too late, too little"), + ServiceUnavailable("borked connection"), + SessionExpired("nobody at home"), + ), +) +@pytest.mark.parametrize("routing", (True, False)) +@pytest.mark.parametrize("queued_commit", (None, 0, 1, 10)) +@mark_sync_test +def test_error_handler_rewritten( + mocker, fake_socket, mode, error, routing, queued_commit +): + mocks = ErrorHandlerTestMockHolder(mocker) + mocks.mock_driver_routing(routing) + if queued_commit is not None: + mocks.queue_commit_message_at(queued_commit) + + connection = mocks.connection + handler = mocks.get_error_handler(mode) + + if queued_commit is not None: + expected_error = IncompleteCommit + elif routing: + expected_error = SessionExpired + else: + expected_error = ServiceUnavailable + + with pytest.raises(expected_error) as exc: + handler(error) + assert exc.value.__cause__ is error + connection.socket.close.assert_called_once() + + assert connection.closed() + assert connection.defunct() + + +class ErrorHandlerTestMockHolder: + def __init__(self, mocker): + self.address = neo4j.Address(("127.0.0.1", 7687)) + self.socket_mock = mocker.MagicMock(spec=BoltSocket) + self.socket_mock.getpeername.return_value = self.address + self.connection = Bolt5x8(self.address, self.socket_mock, 108000) + self.pool = mocker.MagicMock() + self.connection.pool = self.pool + + def mock_driver_routing(self, routing): + self.pool.is_direct_pool = not routing + + def queue_random_non_commit_response(self): + resp_cls = random.choice((ResetResponse, Response)) + resp = resp_cls(self.connection, "MESSAGE", {}) + self.connection.responses.append(resp) + + def queue_commit_message(self): + resp = CommitResponse(self.connection, "MESSAGE", {}) + self.connection.responses.append(resp) + + def queue_commit_message_at(self, position): + self.connection.responses.clear() + for _ in range(position - 1): + self.queue_random_non_commit_response() + self.queue_commit_message() + self.queue_random_non_commit_response() + + def get_error_handler(self, mode): + if mode == "r": + return self.connection._set_defunct_read + elif mode == "w": + return self.connection._set_defunct_write + else: + raise ValueError(f"Invalid handler mode {mode!r}") + + +def test_configures_inbox_error_handler(mocker): + inbox_cls_mock = mocker.patch( + "neo4j._sync.io._bolt.Inbox", autospec=True + ) + mocks = ErrorHandlerTestMockHolder(mocker) + inbox_cls_mock.assert_called_once() + call_args = inbox_cls_mock.call_args + assert call_args.kwargs["on_error"] == mocks.connection._set_defunct_read + + +def test_configures_outbox_error_handler(mocker): + inbox_cls_mock = mocker.patch( + "neo4j._sync.io._bolt.Outbox", autospec=True + ) + mocks = ErrorHandlerTestMockHolder(mocker) + inbox_cls_mock.assert_called_once() + call_args = inbox_cls_mock.call_args + assert call_args.kwargs["on_error"] == mocks.connection._set_defunct_write