diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 4182b406..f9690eed 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -586,6 +586,15 @@ def _set_defunct_write(self, error=None, silent=False): def _set_defunct(self, message, error=None, silent=False): direct_driver = isinstance(self.pool, BoltPool) + connection_failed = isinstance( + error, + ( + ServiceUnavailable, + SessionExpired, + OSError, + SocketDeadlineExceeded, + ), + ) if error: log.debug("[#%04X] %r", self.local_port, error) @@ -595,6 +604,12 @@ def _set_defunct(self, message, error=None, silent=False): # connection from the client side, and remove the address # from the connection pool. self._defunct = True + if not connection_failed: + # Something else but the connection failed + # => we're not sure which state we're in + # => ditch the connection and raise the error for user-awareness + self.close() + raise error if not self._closing: # If we fail while closing the connection, there is no need to # remove the connection from the pool, nor to try to close the diff --git a/neo4j/io/_common.py b/neo4j/io/_common.py index a8bff75b..8be0abe7 100644 --- a/neo4j/io/_common.py +++ b/neo4j/io/_common.py @@ -46,12 +46,11 @@ def __init__(self, s, on_error): self._messages = self._yield_messages(s) def _yield_messages(self, sock): - try: - buffer = UnpackableBuffer() - unpacker = Unpacker(buffer) - chunk_size = 0 - while True: - + buffer = UnpackableBuffer() + unpacker = Unpacker(buffer) + chunk_size = 0 + while True: + try: while chunk_size == 0: # Determine the chunk size and skip noop buffer.receive(sock, 2) @@ -61,18 +60,27 @@ def _yield_messages(self, sock): buffer.receive(sock, chunk_size + 2) chunk_size = buffer.pop_u16() + except (OSError, socket.timeout, SocketDeadlineExceeded) as error: + self.on_error(error) - if chunk_size == 0: - # chunk_size was the end marker for the message + if chunk_size == 0: + # chunk_size was the end marker for the message + try: size, tag = unpacker.unpack_structure_header() fields = [unpacker.unpack() for _ in range(size)] yield tag, fields + except Exception as error: + log.debug( + "[#%04X] _: Failed to unpack response: %r", + self._local_port, + error, + ) + self.on_error(error) + raise + finally: # Reset for new message unpacker.reset() - except (OSError, socket.timeout, SocketDeadlineExceeded) as error: - self.on_error(error) - def pop(self): return next(self._messages) diff --git a/tests/unit/io/_common/__init__.py b/tests/unit/io/_common/__init__.py new file mode 100644 index 00000000..b81a309d --- /dev/null +++ b/tests/unit/io/_common/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/io/_common/test_inbox.py b/tests/unit/io/_common/test_inbox.py new file mode 100644 index 00000000..dbd15dc1 --- /dev/null +++ b/tests/unit/io/_common/test_inbox.py @@ -0,0 +1,118 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.import asyncio + + +import pytest + +from neo4j._exceptions import SocketDeadlineExceeded +from neo4j.io._common import Inbox +from neo4j.packstream import Unpacker + + +def _on_error_side_effect(e): + raise e + + +class InboxMockHolder: + def __init__(self, mocker): + self.socket_mock = mocker.Mock() + self.socket_mock.getsockname.return_value = ("host", 1234) + self.on_error = mocker.MagicMock() + self.on_error.side_effect = _on_error_side_effect + self.inbox = Inbox(self.socket_mock, self.on_error) + self.unpacker = None + self._unpacker_failure = None + mocker.patch( + "neo4j.io._common.Unpacker", + new=self._make_unpacker, + ) + self.inbox._unpacker = self.unpacker + # plenty of nonsense messages to read + self.mock_set_data(b"\x00\x01\xff\x00\x00" * 1000) + self._mocker = mocker + + def _make_unpacker(self, buffer): + if self.unpacker is not None: + pytest.fail("Unexpected 2nd instantiation of Unpacker") + self.unpacker = self._mocker.Mock(wraps=Unpacker(buffer)) + if self._unpacker_failure is not None: + self.mock_unpack_failure(self._unpacker_failure) + return self.unpacker + + def mock_set_data(self, data): + def side_effect(buffer, n): + nonlocal data + + if not data: + pytest.fail("Read more data than mocked") + + n = min(len(data), len(buffer), n) + buffer[:n] = data[:n] + data = data[n:] + return n + + self.socket_mock.recv_into.side_effect = side_effect + + def assert_no_error(self): + self.on_error.assert_not_called() + assert next(self.inbox, None) is not None + + def mock_receive_failure(self, exception): + self.socket_mock.recv_into.side_effect = exception + + def mock_unpack_failure(self, exception): + self._unpacker_failure = exception + if self.unpacker is not None: + self.unpacker.unpack_structure_header.side_effect = exception + + +@pytest.mark.parametrize( + "error", + ( + SocketDeadlineExceeded("test"), + OSError("test"), + ), +) +def test_inbox_receive_failure_error_handler(mocker, error): + mocks = InboxMockHolder(mocker) + mocks.mock_receive_failure(error) + inbox = mocks.inbox + + with pytest.raises(type(error)) as exc: + next(inbox) + + assert exc.value is error + mocks.on_error.assert_called_once_with(error) + + +@pytest.mark.parametrize( + "error", + ( + SocketDeadlineExceeded("test"), + OSError("test"), + RecursionError("2deep4u"), + RuntimeError("something funny happened"), + ), +) +def test_inbox_unpack_failure(mocker, error): + mocks = InboxMockHolder(mocker) + mocks.mock_unpack_failure(error) + inbox = mocks.inbox + + with pytest.raises(type(error)) as exc: + next(inbox) + + assert exc.value is error + mocks.on_error.assert_called_once_with(error) diff --git a/tests/unit/io/_common/test_oubox.py b/tests/unit/io/_common/test_oubox.py new file mode 100644 index 00000000..220d0ce6 --- /dev/null +++ b/tests/unit/io/_common/test_oubox.py @@ -0,0 +1,48 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j.io._common import Outbox + + +@pytest.mark.parametrize(("chunk_size", "data", "result"), ( + ( + 2, + (bytes(range(10, 15)),), + bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14)) + ), + ( + 2, + (bytes(range(10, 14)),), + bytes((0, 2, 10, 11, 0, 2, 12, 13)) + ), + ( + 2, + (bytes((5, 6, 7)), bytes((8, 9))), + bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9)) + ), +)) +def test_outbox_chunking(chunk_size, data, result): + outbox = Outbox(max_chunk_size=chunk_size) + assert bytes(outbox.view()) == b"" + for d in data: + outbox.write(d) + assert bytes(outbox.view()) == result + # make sure this works multiple times + assert bytes(outbox.view()) == result + outbox.clear() + assert bytes(outbox.view()) == b"" diff --git a/tests/unit/io/test__common.py b/tests/unit/io/_common/test_response_handler.py similarity index 76% rename from tests/unit/io/test__common.py rename to tests/unit/io/_common/test_response_handler.py index b70357c1..755c8f90 100644 --- a/tests/unit/io/test__common.py +++ b/tests/unit/io/_common/test_response_handler.py @@ -1,13 +1,11 @@ # Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. +# Neo4j Sweden AB [https://neo4j.com] # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -20,40 +18,7 @@ import pytest -from neo4j.io._common import ( - Outbox, - ResetResponse, -) - - -@pytest.mark.parametrize(("chunk_size", "data", "result"), ( - ( - 2, - (bytes(range(10, 15)),), - bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14)) - ), - ( - 2, - (bytes(range(10, 14)),), - bytes((0, 2, 10, 11, 0, 2, 12, 13)) - ), - ( - 2, - (bytes((5, 6, 7)), bytes((8, 9))), - bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9)) - ), -)) -def test_outbox_chunking(chunk_size, data, result): - outbox = Outbox(max_chunk_size=chunk_size) - assert bytes(outbox.view()) == b"" - for d in data: - outbox.write(d) - assert bytes(outbox.view()) == result - # make sure this works multiple times - assert bytes(outbox.view()) == result - outbox.clear() - assert bytes(outbox.view()) == b"" - +from neo4j.io._common import ResetResponse def get_handler_arg(response): if response == "RECORD": diff --git a/tests/unit/io/test_class_bolt.py b/tests/unit/io/test_class_bolt.py index 2001546c..ee794ca6 100644 --- a/tests/unit/io/test_class_bolt.py +++ b/tests/unit/io/test_class_bolt.py @@ -19,8 +19,31 @@ # limitations under the License. +import random +from socket import socket + import pytest -from neo4j.io import Bolt + +import neo4j +from neo4j._exceptions import SocketDeadlineExceeded +from neo4j.io import ( + Bolt, + BoltPool, + IOPool, + Neo4jPool, +) +from neo4j.io._bolt4 import Bolt4x4 +from neo4j.io._common import ( + CommitResponse, + ResetResponse, + Response, +) +from neo4j.exceptions import ( + IncompleteCommit, + ServiceUnavailable, + SessionExpired, +) + # python -m pytest tests/unit/io/test_class_bolt.py -s -v @@ -61,3 +84,142 @@ def test_magic_preamble(): preamble = 0x6060B017 preamble_bytes = preamble.to_bytes(4, byteorder="big") assert Bolt.MAGIC_PREAMBLE == preamble_bytes + +@pytest.mark.parametrize("mode", ("r", "w")) +@pytest.mark.parametrize( + "error", + ( + RuntimeError("test error"), + RecursionError("How deep is your ~~love~~ recursion?"), + ), +) +@pytest.mark.parametrize("queued_commit", (None, 0, 1, 10)) +def test_error_handler_bubbling( + mocker, fake_socket, mode, error, queued_commit +): + mocks = ErrorHandlerTestMockHolder(mocker) + if queued_commit is not None: + mocks.queue_commit_message_at(queued_commit) + + connection = mocks.connection + handler = mocks.get_error_handler(mode) + + with pytest.raises(type(error)) as exc: + handler(error) + assert exc.value is error + + connection.socket.close.assert_called_once() + + assert connection.closed() + assert connection.defunct() + + +@pytest.mark.parametrize("mode", ("r", "w")) +@pytest.mark.parametrize( + "error", + ( + OSError("computer says no! *cough*"), + SocketDeadlineExceeded("too late, too little"), + ServiceUnavailable("borked connection"), + SessionExpired("nobody at home"), + ), +) +@pytest.mark.parametrize("routing", (True, False)) +@pytest.mark.parametrize("queued_commit", (None, 0, 1, 10)) +def test_error_handler_rewritten( + mocker, fake_socket, mode, error, routing, queued_commit +): + mocks = ErrorHandlerTestMockHolder(mocker) + mocks.mock_driver_routing(routing) + if queued_commit is not None: + mocks.queue_commit_message_at(queued_commit) + + connection = mocks.connection + handler = mocks.get_error_handler(mode) + + if queued_commit is not None: + expected_error = IncompleteCommit + elif routing: + expected_error = SessionExpired + else: + expected_error = ServiceUnavailable + + print(expected_error) + with pytest.raises(expected_error) as exc: + handler(error) + assert exc.value.__cause__ is error + connection.socket.close.assert_called_once() + + assert connection.closed() + assert connection.defunct() + + +def make_pool_mock_cls(mocker, routing): + class PoolMock(mocker.MagicMock): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __getattribute__(self, item): + if item == "__class__": + if routing: + return Neo4jPool + return BoltPool + return super().__getattribute__(item) + + def __setattr__(self, key, value): + print(key) + if key == "_PoolMock__routing": + nonlocal routing + routing = value + return + super().__setattr__(key, value) + + return PoolMock + + +class ErrorHandlerTestMockHolder: + def __init__(self, mocker): + self.address = neo4j.Address(("127.0.0.1", 7687)) + self.socket_mock = mocker.MagicMock(spec=socket) + self.socket_mock.getpeername.return_value = self.address + self.connection = Bolt4x4(self.address, self.socket_mock, 108000) + self.pool = make_pool_mock_cls(mocker, False)() + self.connection.pool = self.pool + + def mock_driver_routing(self, routing): + print("++++++ SETTING +++++++") + self.pool._PoolMock__routing = routing + + def queue_random_non_commit_response(self): + resp_cls = random.choice((ResetResponse, Response)) + resp = resp_cls(self.connection, "MESSAGE") + self.connection.responses.append(resp) + + def queue_commit_message(self): + resp = CommitResponse(self.connection, "MESSAGE") + self.connection.responses.append(resp) + + def queue_commit_message_at(self, position): + self.connection.responses.clear() + for _ in range(position - 1): + self.queue_random_non_commit_response() + self.queue_commit_message() + self.queue_random_non_commit_response() + + def get_error_handler(self, mode): + if mode == "r": + return self.connection._set_defunct_read + elif mode == "w": + return self.connection._set_defunct_write + else: + raise ValueError(f"Invalid handler mode {mode!r}") + + +def test_configures_inbox_error_handler(mocker): + inbox_cls_mock = mocker.patch( + "neo4j.io.Inbox", autospec=True + ) + mocks = ErrorHandlerTestMockHolder(mocker) + inbox_cls_mock.assert_called_once() + call_args = inbox_cls_mock.call_args + assert call_args.kwargs["on_error"] == mocks.connection._set_defunct_read