diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 7b8c63ee..86f1f98a 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -544,7 +544,7 @@ def fetch_all(self): messages fetched """ detail_count = summary_count = 0 - while self.responses: + while not self._closed and self.responses: response = self.responses[0] while not response.complete: detail_delta, summary_delta = self.fetch_message() diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index 21f71b18..dc630fd8 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -46,6 +46,7 @@ from neo4j.io._common import ( CommitResponse, InitResponse, + ResetResponse, Response, ) @@ -294,15 +295,13 @@ def rollback(self, **handlers): self._append(b"\x13", (), Response(self, "rollback", **handlers)) def reset(self): - """ Add a RESET message to the outgoing queue, send - it and consume all remaining messages. - """ - - def fail(metadata): - raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address) + """Reset the connection. + Add a RESET message to the outgoing queue, send it and consume all + remaining messages. + """ log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + self._append(b"\x0F", response=ResetResponse(self, "reset")) self.send_all() self.fetch_all() diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index 2feeea14..1a4e8b48 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -47,6 +47,7 @@ from neo4j.io._common import ( CommitResponse, InitResponse, + ResetResponse, Response, ) from neo4j.io._bolt3 import ( @@ -246,15 +247,13 @@ def rollback(self, **handlers): self._append(b"\x13", (), Response(self, "rollback", **handlers)) def reset(self): - """ Add a RESET message to the outgoing queue, send - it and consume all remaining messages. - """ - - def fail(metadata): - raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address) + """Reset the connection. + Add a RESET message to the outgoing queue, send it and consume all + remaining messages. + """ log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + self._append(b"\x0F", response=ResetResponse(self, "reset")) self.send_all() self.fetch_all() diff --git a/neo4j/io/_common.py b/neo4j/io/_common.py index f11ddd8e..971547f7 100644 --- a/neo4j/io/_common.py +++ b/neo4j/io/_common.py @@ -267,6 +267,26 @@ def on_failure(self, metadata): ) +class ResetResponse(Response): + def _unexpected_message(self, response): + log.warning("[#%04X] RESET received %s (unexpected response) " + "=> dropping connection", + self.connection.local_port, response) + self.connection.close() + + def on_records(self, records): + self._unexpected_message("RECORD") + + def on_success(self, metadata): + pass + + def on_failure(self, metadata): + self._unexpected_message("FAILURE") + + def on_ignored(self, metadata=None): + self._unexpected_message("IGNORED") + + class CommitResponse(Response): pass diff --git a/tests/unit/io/test__common.py b/tests/unit/io/test__common.py index 3b61c710..72f9ee92 100644 --- a/tests/unit/io/test__common.py +++ b/tests/unit/io/test__common.py @@ -1,6 +1,31 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + import pytest -from neo4j.io._common import Outbox +from neo4j.io._common import ( + Outbox, + ResetResponse, +) + +from ..work import fake_connection @pytest.mark.parametrize(("chunk_size", "data", "result"), ( @@ -30,3 +55,100 @@ def test_outbox_chunking(chunk_size, data, result): assert bytes(outbox.view()) == result outbox.clear() assert bytes(outbox.view()) == b"" + + +def get_handler_arg(response): + if response == "RECORD": + return [] + elif response == "IGNORED": + return {} + elif response == "FAILURE": + return {} + elif response == "SUCCESS": + return {} + else: + raise ValueError(f"Unexpected response: {response}") + + +def call_handler(handler, response, arg=None): + if arg is None: + arg = get_handler_arg(response) + + if response == "RECORD": + return handler.on_records(arg) + elif response == "IGNORED": + return handler.on_ignored(arg) + elif response == "FAILURE": + return handler.on_failure(arg) + elif response == "SUCCESS": + return handler.on_success(arg) + else: + raise ValueError(f"Unexpected response: {response}") + + +@pytest.mark.parametrize( + ("response", "unexpected"), + ( + ("RECORD", True), + ("IGNORED", True), + ("FAILURE", True), + ("SUCCESS", False), + ) +) +def test_reset_response_closes_connection_on_unexpected_responses( + response, unexpected, fake_connection +): + handler = ResetResponse(fake_connection, "reset") + fake_connection.close.assert_not_called() + + call_handler(handler, response) + + if unexpected: + fake_connection.close.assert_called_once() + else: + fake_connection.close.assert_not_called() + + +@pytest.mark.parametrize( + ("response", "unexpected"), + ( + ("RECORD", True), + ("IGNORED", True), + ("FAILURE", True), + ("SUCCESS", False), + ) +) +def test_reset_response_logs_warning_on_unexpected_responses( + response, unexpected, fake_connection, caplog +): + handler = ResetResponse(fake_connection, "reset") + + with caplog.at_level(logging.WARNING): + call_handler(handler, response) + + log_message_found = any("RESET" in msg and "unexpected response" in msg + for msg in caplog.messages) + if unexpected: + assert log_message_found + else: + assert not log_message_found + + +@pytest.mark.parametrize("response", + ("RECORD", "IGNORED", "FAILURE", "SUCCESS")) +def test_reset_response_never_calls_handlers( + response, fake_connection, mocker +): + handlers = { + key: mocker.MagicMock(name=key) + for key in + ("on_records", "on_ignored", "on_failure", "on_success", "on_summary") + } + + handler = ResetResponse(fake_connection, "reset", **handlers) + + arg = get_handler_arg(response) + call_handler(handler, response, arg) + + for handler in handlers.values(): + handler.assert_not_called()