From 1ba8b8048ea7c9771bc0fb549e310605e58804b0 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 30 May 2025 10:59:53 +0200 Subject: [PATCH] Fix driver stuck on RecursionError on COMMIT SUCCESS The driver could get stuck when the result summary was so deeply nested that it caused a `RecursionError`. It would also get stuck on certain protocol violations in the same place (e.g., unknown PackStream tag). However, the latter is much less likely to occur. This patch will make the driver ditch the connection when it's not able to properly read a message for unexpected reasons and let the error bubble up. --- neo4j/io/__init__.py | 15 ++ neo4j/io/_common.py | 30 ++-- tests/unit/io/_common/__init__.py | 16 ++ tests/unit/io/_common/test_inbox.py | 118 +++++++++++++ tests/unit/io/_common/test_oubox.py | 48 +++++ .../test_response_handler.py} | 41 +---- tests/unit/io/test_class_bolt.py | 164 +++++++++++++++++- 7 files changed, 382 insertions(+), 50 deletions(-) create mode 100644 tests/unit/io/_common/__init__.py create mode 100644 tests/unit/io/_common/test_inbox.py create mode 100644 tests/unit/io/_common/test_oubox.py rename tests/unit/io/{test__common.py => _common/test_response_handler.py} (76%) diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 4182b406..f9690eed 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -586,6 +586,15 @@ def _set_defunct_write(self, error=None, silent=False): def _set_defunct(self, message, error=None, silent=False): direct_driver = isinstance(self.pool, BoltPool) + connection_failed = isinstance( + error, + ( + ServiceUnavailable, + SessionExpired, + OSError, + SocketDeadlineExceeded, + ), + ) if error: log.debug("[#%04X] %r", self.local_port, error) @@ -595,6 +604,12 @@ def _set_defunct(self, message, error=None, silent=False): # connection from the client side, and remove the address # from the connection pool. 
self._defunct = True + if not connection_failed: + # Something else but the connection failed + # => we're not sure which state we're in + # => ditch the connection and raise the error for user-awareness + self.close() + raise error if not self._closing: # If we fail while closing the connection, there is no need to # remove the connection from the pool, nor to try to close the diff --git a/neo4j/io/_common.py b/neo4j/io/_common.py index a8bff75b..8be0abe7 100644 --- a/neo4j/io/_common.py +++ b/neo4j/io/_common.py @@ -46,12 +46,11 @@ def __init__(self, s, on_error): self._messages = self._yield_messages(s) def _yield_messages(self, sock): - try: - buffer = UnpackableBuffer() - unpacker = Unpacker(buffer) - chunk_size = 0 - while True: - + buffer = UnpackableBuffer() + unpacker = Unpacker(buffer) + chunk_size = 0 + while True: + try: while chunk_size == 0: # Determine the chunk size and skip noop buffer.receive(sock, 2) @@ -61,18 +60,27 @@ def _yield_messages(self, sock): buffer.receive(sock, chunk_size + 2) chunk_size = buffer.pop_u16() + except (OSError, socket.timeout, SocketDeadlineExceeded) as error: + self.on_error(error) - if chunk_size == 0: - # chunk_size was the end marker for the message + if chunk_size == 0: + # chunk_size was the end marker for the message + try: size, tag = unpacker.unpack_structure_header() fields = [unpacker.unpack() for _ in range(size)] yield tag, fields + except Exception as error: + log.debug( + "[#%04X] _: Failed to unpack response: %r", + self._local_port, + error, + ) + self.on_error(error) + raise + finally: # Reset for new message unpacker.reset() - except (OSError, socket.timeout, SocketDeadlineExceeded) as error: - self.on_error(error) - def pop(self): return next(self._messages) diff --git a/tests/unit/io/_common/__init__.py b/tests/unit/io/_common/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/tests/unit/io/_common/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB 
[http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/io/_common/test_inbox.py b/tests/unit/io/_common/test_inbox.py new file mode 100644 index 00000000..dbd15dc1 --- /dev/null +++ b/tests/unit/io/_common/test_inbox.py @@ -0,0 +1,118 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j._exceptions import SocketDeadlineExceeded +from neo4j.io._common import Inbox +from neo4j.packstream import Unpacker + + +def _on_error_side_effect(e): + raise e + + +class InboxMockHolder: + def __init__(self, mocker): + self.socket_mock = mocker.Mock() + self.socket_mock.getsockname.return_value = ("host", 1234) + self.on_error = mocker.MagicMock() + self.on_error.side_effect = _on_error_side_effect + self.inbox = Inbox(self.socket_mock, self.on_error) + self.unpacker = None + self._unpacker_failure = None + mocker.patch( + "neo4j.io._common.Unpacker", + new=self._make_unpacker, + ) + self.inbox._unpacker = self.unpacker + # plenty of nonsense messages to read + self.mock_set_data(b"\x00\x01\xff\x00\x00" * 1000) + self._mocker = mocker + + def _make_unpacker(self, buffer): + if self.unpacker is not None: + pytest.fail("Unexpected 2nd instantiation of Unpacker") + self.unpacker = self._mocker.Mock(wraps=Unpacker(buffer)) + if self._unpacker_failure is not None: + self.mock_unpack_failure(self._unpacker_failure) + return self.unpacker + + def mock_set_data(self, data): + def side_effect(buffer, n): + nonlocal data + + if not data: + pytest.fail("Read more data than mocked") + + n = min(len(data), len(buffer), n) + buffer[:n] = data[:n] + data = data[n:] + return n + + self.socket_mock.recv_into.side_effect = side_effect + + def assert_no_error(self): + self.on_error.assert_not_called() + assert next(self.inbox, None) is not None + + def mock_receive_failure(self, exception): + self.socket_mock.recv_into.side_effect = exception + + def mock_unpack_failure(self, exception): + self._unpacker_failure = exception + if self.unpacker is not None: + self.unpacker.unpack_structure_header.side_effect = exception + + +@pytest.mark.parametrize( + "error", + ( + SocketDeadlineExceeded("test"), + OSError("test"), + ), +) +def 
test_inbox_receive_failure_error_handler(mocker, error): + mocks = InboxMockHolder(mocker) + mocks.mock_receive_failure(error) + inbox = mocks.inbox + + with pytest.raises(type(error)) as exc: + next(inbox) + + assert exc.value is error + mocks.on_error.assert_called_once_with(error) + + +@pytest.mark.parametrize( + "error", + ( + SocketDeadlineExceeded("test"), + OSError("test"), + RecursionError("2deep4u"), + RuntimeError("something funny happened"), + ), +) +def test_inbox_unpack_failure(mocker, error): + mocks = InboxMockHolder(mocker) + mocks.mock_unpack_failure(error) + inbox = mocks.inbox + + with pytest.raises(type(error)) as exc: + next(inbox) + + assert exc.value is error + mocks.on_error.assert_called_once_with(error) diff --git a/tests/unit/io/_common/test_oubox.py b/tests/unit/io/_common/test_oubox.py new file mode 100644 index 00000000..220d0ce6 --- /dev/null +++ b/tests/unit/io/_common/test_oubox.py @@ -0,0 +1,48 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest + +from neo4j.io._common import Outbox + + +@pytest.mark.parametrize(("chunk_size", "data", "result"), ( + ( + 2, + (bytes(range(10, 15)),), + bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14)) + ), + ( + 2, + (bytes(range(10, 14)),), + bytes((0, 2, 10, 11, 0, 2, 12, 13)) + ), + ( + 2, + (bytes((5, 6, 7)), bytes((8, 9))), + bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9)) + ), +)) +def test_outbox_chunking(chunk_size, data, result): + outbox = Outbox(max_chunk_size=chunk_size) + assert bytes(outbox.view()) == b"" + for d in data: + outbox.write(d) + assert bytes(outbox.view()) == result + # make sure this works multiple times + assert bytes(outbox.view()) == result + outbox.clear() + assert bytes(outbox.view()) == b"" diff --git a/tests/unit/io/test__common.py b/tests/unit/io/_common/test_response_handler.py similarity index 76% rename from tests/unit/io/test__common.py rename to tests/unit/io/_common/test_response_handler.py index b70357c1..755c8f90 100644 --- a/tests/unit/io/test__common.py +++ b/tests/unit/io/_common/test_response_handler.py @@ -1,13 +1,11 @@ # Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. +# Neo4j Sweden AB [https://neo4j.com] # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -20,40 +18,7 @@ import pytest -from neo4j.io._common import ( - Outbox, - ResetResponse, -) - - -@pytest.mark.parametrize(("chunk_size", "data", "result"), ( - ( - 2, - (bytes(range(10, 15)),), - bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14)) - ), - ( - 2, - (bytes(range(10, 14)),), - bytes((0, 2, 10, 11, 0, 2, 12, 13)) - ), - ( - 2, - (bytes((5, 6, 7)), bytes((8, 9))), - bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9)) - ), -)) -def test_outbox_chunking(chunk_size, data, result): - outbox = Outbox(max_chunk_size=chunk_size) - assert bytes(outbox.view()) == b"" - for d in data: - outbox.write(d) - assert bytes(outbox.view()) == result - # make sure this works multiple times - assert bytes(outbox.view()) == result - outbox.clear() - assert bytes(outbox.view()) == b"" - +from neo4j.io._common import ResetResponse def get_handler_arg(response): if response == "RECORD": diff --git a/tests/unit/io/test_class_bolt.py b/tests/unit/io/test_class_bolt.py index 2001546c..ee794ca6 100644 --- a/tests/unit/io/test_class_bolt.py +++ b/tests/unit/io/test_class_bolt.py @@ -19,8 +19,31 @@ # limitations under the License. 
+import random +from socket import socket + import pytest -from neo4j.io import Bolt + +import neo4j +from neo4j._exceptions import SocketDeadlineExceeded +from neo4j.io import ( + Bolt, + BoltPool, + IOPool, + Neo4jPool, +) +from neo4j.io._bolt4 import Bolt4x4 +from neo4j.io._common import ( + CommitResponse, + ResetResponse, + Response, +) +from neo4j.exceptions import ( + IncompleteCommit, + ServiceUnavailable, + SessionExpired, +) + # python -m pytest tests/unit/io/test_class_bolt.py -s -v @@ -61,3 +84,142 @@ def test_magic_preamble(): preamble = 0x6060B017 preamble_bytes = preamble.to_bytes(4, byteorder="big") assert Bolt.MAGIC_PREAMBLE == preamble_bytes + +@pytest.mark.parametrize("mode", ("r", "w")) +@pytest.mark.parametrize( + "error", + ( + RuntimeError("test error"), + RecursionError("How deep is your ~~love~~ recursion?"), + ), +) +@pytest.mark.parametrize("queued_commit", (None, 0, 1, 10)) +def test_error_handler_bubbling( + mocker, fake_socket, mode, error, queued_commit +): + mocks = ErrorHandlerTestMockHolder(mocker) + if queued_commit is not None: + mocks.queue_commit_message_at(queued_commit) + + connection = mocks.connection + handler = mocks.get_error_handler(mode) + + with pytest.raises(type(error)) as exc: + handler(error) + assert exc.value is error + + connection.socket.close.assert_called_once() + + assert connection.closed() + assert connection.defunct() + + +@pytest.mark.parametrize("mode", ("r", "w")) +@pytest.mark.parametrize( + "error", + ( + OSError("computer says no! 
*cough*"), + SocketDeadlineExceeded("too late, too little"), + ServiceUnavailable("borked connection"), + SessionExpired("nobody at home"), + ), +) +@pytest.mark.parametrize("routing", (True, False)) +@pytest.mark.parametrize("queued_commit", (None, 0, 1, 10)) +def test_error_handler_rewritten( + mocker, fake_socket, mode, error, routing, queued_commit +): + mocks = ErrorHandlerTestMockHolder(mocker) + mocks.mock_driver_routing(routing) + if queued_commit is not None: + mocks.queue_commit_message_at(queued_commit) + + connection = mocks.connection + handler = mocks.get_error_handler(mode) + + if queued_commit is not None: + expected_error = IncompleteCommit + elif routing: + expected_error = SessionExpired + else: + expected_error = ServiceUnavailable + + print(expected_error) + with pytest.raises(expected_error) as exc: + handler(error) + assert exc.value.__cause__ is error + connection.socket.close.assert_called_once() + + assert connection.closed() + assert connection.defunct() + + +def make_pool_mock_cls(mocker, routing): + class PoolMock(mocker.MagicMock): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __getattribute__(self, item): + if item == "__class__": + if routing: + return Neo4jPool + return BoltPool + return super().__getattribute__(item) + + def __setattr__(self, key, value): + print(key) + if key == "_PoolMock__routing": + nonlocal routing + routing = value + return + super().__setattr__(key, value) + + return PoolMock + + +class ErrorHandlerTestMockHolder: + def __init__(self, mocker): + self.address = neo4j.Address(("127.0.0.1", 7687)) + self.socket_mock = mocker.MagicMock(spec=socket) + self.socket_mock.getpeername.return_value = self.address + self.connection = Bolt4x4(self.address, self.socket_mock, 108000) + self.pool = make_pool_mock_cls(mocker, False)() + self.connection.pool = self.pool + + def mock_driver_routing(self, routing): + print("++++++ SETTING +++++++") + self.pool._PoolMock__routing = routing + 
+ def queue_random_non_commit_response(self): + resp_cls = random.choice((ResetResponse, Response)) + resp = resp_cls(self.connection, "MESSAGE") + self.connection.responses.append(resp) + + def queue_commit_message(self): + resp = CommitResponse(self.connection, "MESSAGE") + self.connection.responses.append(resp) + + def queue_commit_message_at(self, position): + self.connection.responses.clear() + for _ in range(position - 1): + self.queue_random_non_commit_response() + self.queue_commit_message() + self.queue_random_non_commit_response() + + def get_error_handler(self, mode): + if mode == "r": + return self.connection._set_defunct_read + elif mode == "w": + return self.connection._set_defunct_write + else: + raise ValueError(f"Invalid handler mode {mode!r}") + + +def test_configures_inbox_error_handler(mocker): + inbox_cls_mock = mocker.patch( + "neo4j.io.Inbox", autospec=True + ) + mocks = ErrorHandlerTestMockHolder(mocker) + inbox_cls_mock.assert_called_once() + call_args = inbox_cls_mock.call_args + assert call_args.kwargs["on_error"] == mocks.connection._set_defunct_read