From acbe50ff5358e2b773ff7d1832b6c7b997925b59 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 11 Apr 2024 15:32:43 +0200 Subject: [PATCH] Harden driver against unexpected RESET responses The server has been observed to reply with `FAILURE` and `IGNORED` to `RESET` requests. The former is according to spec and the driver should drop the connection (which it didn't), the latter isn't. The right combination of those two unexpected responses at the right time could get the driver stuck in an infinite loop. This change makes the driver drop the connection in either case to gracefully handle the situation. --- neo4j/io/__init__.py | 2 +- neo4j/io/_bolt3.py | 13 ++-- neo4j/io/_bolt4.py | 13 ++-- neo4j/io/_common.py | 20 ++++++ tests/unit/io/test__common.py | 124 +++++++++++++++++++++++++++++++++- 5 files changed, 156 insertions(+), 16 deletions(-) diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 7b8c63ee..86f1f98a 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -544,7 +544,7 @@ def fetch_all(self): messages fetched """ detail_count = summary_count = 0 - while self.responses: + while not self._closed and self.responses: response = self.responses[0] while not response.complete: detail_delta, summary_delta = self.fetch_message() diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index 21f71b18..dc630fd8 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -46,6 +46,7 @@ from neo4j.io._common import ( CommitResponse, InitResponse, + ResetResponse, Response, ) @@ -294,15 +295,13 @@ def rollback(self, **handlers): self._append(b"\x13", (), Response(self, "rollback", **handlers)) def reset(self): - """ Add a RESET message to the outgoing queue, send - it and consume all remaining messages. - """ - - def fail(metadata): - raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address) + """Reset the connection. + Add a RESET message to the outgoing queue, send it and consume all + remaining messages. + """ log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + self._append(b"\x0F", response=ResetResponse(self, "reset")) self.send_all() self.fetch_all() diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index 2feeea14..1a4e8b48 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -47,6 +47,7 @@ from neo4j.io._common import ( CommitResponse, InitResponse, + ResetResponse, Response, ) from neo4j.io._bolt3 import ( @@ -246,15 +247,13 @@ def rollback(self, **handlers): self._append(b"\x13", (), Response(self, "rollback", **handlers)) def reset(self): - """ Add a RESET message to the outgoing queue, send - it and consume all remaining messages. - """ - - def fail(metadata): - raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address) + """Reset the connection. + Add a RESET message to the outgoing queue, send it and consume all + remaining messages. + """ log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) + self._append(b"\x0F", response=ResetResponse(self, "reset")) self.send_all() self.fetch_all() diff --git a/neo4j/io/_common.py b/neo4j/io/_common.py index f11ddd8e..971547f7 100644 --- a/neo4j/io/_common.py +++ b/neo4j/io/_common.py @@ -267,6 +267,26 @@ def on_failure(self, metadata): ) +class ResetResponse(Response): + def _unexpected_message(self, response): + log.warning("[#%04X] RESET received %s (unexpected response) " + "=> dropping connection", + self.connection.local_port, response) + self.connection.close() + + def on_records(self, records): + self._unexpected_message("RECORD") + + def on_success(self, metadata): + pass + + def on_failure(self, metadata): + self._unexpected_message("FAILURE") + + def on_ignored(self, metadata=None): + self._unexpected_message("IGNORED") + + class CommitResponse(Response): pass diff --git a/tests/unit/io/test__common.py b/tests/unit/io/test__common.py index 3b61c710..72f9ee92 100644 --- a/tests/unit/io/test__common.py +++ b/tests/unit/io/test__common.py @@ -1,6 +1,31 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + import pytest -from neo4j.io._common import Outbox +from neo4j.io._common import ( + Outbox, + ResetResponse, +) + +from ..work import fake_connection @pytest.mark.parametrize(("chunk_size", "data", "result"), ( @@ -30,3 +55,100 @@ def test_outbox_chunking(chunk_size, data, result): assert bytes(outbox.view()) == result outbox.clear() assert bytes(outbox.view()) == b"" + + +def get_handler_arg(response): + if response == "RECORD": + return [] + elif response == "IGNORED": + return {} + elif response == "FAILURE": + return {} + elif response == "SUCCESS": + return {} + else: + raise ValueError(f"Unexpected response: {response}") + + +def call_handler(handler, response, arg=None): + if arg is None: + arg = get_handler_arg(response) + + if response == "RECORD": + return handler.on_records(arg) + elif response == "IGNORED": + return handler.on_ignored(arg) + elif response == "FAILURE": + return handler.on_failure(arg) + elif response == "SUCCESS": + return handler.on_success(arg) + else: + raise ValueError(f"Unexpected response: {response}") + + +@pytest.mark.parametrize( + ("response", "unexpected"), + ( + ("RECORD", True), + ("IGNORED", True), + ("FAILURE", True), + ("SUCCESS", False), + ) +) +def test_reset_response_closes_connection_on_unexpected_responses( + response, unexpected, fake_connection +): + handler = ResetResponse(fake_connection, "reset") + fake_connection.close.assert_not_called() + + call_handler(handler, response) + + if unexpected: + fake_connection.close.assert_called_once() + else: + fake_connection.close.assert_not_called() + + +@pytest.mark.parametrize( + ("response", "unexpected"), + ( + ("RECORD", True), + ("IGNORED", True), + ("FAILURE", True), + ("SUCCESS", False), + ) +) +def test_reset_response_logs_warning_on_unexpected_responses( + response, unexpected, fake_connection, caplog +): + handler = ResetResponse(fake_connection, "reset") + + with caplog.at_level(logging.WARNING): + call_handler(handler, response) + + log_message_found = any("RESET" in msg and "unexpected response" in msg + for msg in caplog.messages) + if unexpected: + assert log_message_found + else: + assert not log_message_found + + +@pytest.mark.parametrize("response", + ("RECORD", "IGNORED", "FAILURE", "SUCCESS")) +def test_reset_response_never_calls_handlers( + response, fake_connection, mocker +): + handlers = { + key: mocker.MagicMock(name=key) + for key in + ("on_records", "on_ignored", "on_failure", "on_success", "on_summary") + } + + handler = ResetResponse(fake_connection, "reset", **handlers) + + arg = get_handler_arg(response) + call_handler(handler, response, arg) + + for handler in handlers.values(): + handler.assert_not_called()