diff --git a/neo4j/api.py b/neo4j/api.py index e4af3933..8ed2ce26 100644 --- a/neo4j/api.py +++ b/neo4j/api.py @@ -219,6 +219,11 @@ def version_info(self): :return: Server Version or None :rtype: tuple + + .. deprecated:: 4.3 + `version_info` will be removed in version 5.0. Use + :meth:`~ServerInfo.agent`, :meth:`~ServerInfo.protocol_version`, + or call the `dbms.components` procedure instead. """ if not self.agent: return None diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 8ebd3e6c..4b6a2b93 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -141,6 +141,10 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No self.unresolved_address = unresolved_address self.socket = sock self.server_info = ServerInfo(Address(sock.getpeername()), self.PROTOCOL_VERSION) + # so far `connection.recv_timeout_seconds` is the only available + # configuration hint that exists. Therefore, all hints can be stored at + # connection level. This might change in the future. + self.configuration_hints = {} self.outbox = Outbox() self.inbox = Inbox(self.socket, on_error=self._set_defunct_read) self.packer = Packer(self.outbox) diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index 9444c7e8..58689584 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -376,3 +376,30 @@ def fail(md): self.send_all() self.fetch_all() return [metadata.get("rt")] + + def hello(self): + def on_success(metadata): + self.configuration_hints.update(metadata.pop("hints", {})) + self.server_info.update(metadata) + recv_timeout = self.configuration_hints.get( + "connection.recv_timeout_seconds" + ) + if isinstance(recv_timeout, int) and recv_timeout > 0: + self.socket.settimeout(recv_timeout) + else: + log.info("[#%04X] Server supplied an invalid value for " + "connection.recv_timeout_seconds (%r). Make sure the " + "server and network is set up correctly.", + self.local_port, recv_timeout) + + headers = self.get_base_headers() + headers.update(self.auth_dict) + logged_headers = dict(headers) + if "credentials" in logged_headers: + logged_headers["credentials"] = "*******" + log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) + self._append(b"\x01", (headers,), + response=InitResponse(self, on_success=on_success)) + self.send_all() + self.fetch_all() + check_supported_server_product(self.server_info.agent) diff --git a/neo4j/io/_common.py b/neo4j/io/_common.py index 38e2dfaa..fc543499 100644 --- a/neo4j/io/_common.py +++ b/neo4j/io/_common.py @@ -19,6 +19,7 @@ # limitations under the License. +import socket from struct import pack as struct_pack from neo4j.exceptions import ( @@ -67,7 +68,7 @@ def _yield_messages(self, sock): # Reset for new message unpacker.reset() - except OSError as error: + except (OSError, socket.timeout) as error: self.on_error(error) def pop(self): diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 6eb9e2f3..6aeb71a1 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -35,6 +35,7 @@ "Optimization:PullPipelining": true, "Temporary:ResultKeys": true, "Temporary:FullSummary": true, - "Temporary:CypherPathAndRelationship": true + "Temporary:CypherPathAndRelationship": true, + "ConfHint:connection.recv_timeout_seconds": true } } diff --git a/tests/unit/io/test_class_bolt3.py b/tests/unit/io/test_class_bolt3.py index 79bc0cf4..f7d63e85 100644 --- a/tests/unit/io/test_class_bolt3.py +++ b/tests/unit/io/test_class_bolt3.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import MagicMock import pytest @@ -94,3 +95,18 @@ def test_simple_pull(fake_socket): tag, fields = socket.pop_message() assert tag == b"\x3F" assert len(fields) == 0 + + +@pytest.mark.parametrize("recv_timeout", (1, -1)) +def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = MagicMock() + sockets.server.send_message(0x70, { + "server": "Neo4j/3.5.0", + "hints": {"connection.recv_timeout_seconds": recv_timeout}, + }) + connection = Bolt3(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.hello() + sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/io/test_class_bolt4x0.py b/tests/unit/io/test_class_bolt4x0.py index 333fc158..3879acb0 100644 --- a/tests/unit/io/test_class_bolt4x0.py +++ b/tests/unit/io/test_class_bolt4x0.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import MagicMock import pytest @@ -181,3 +182,18 @@ def test_n_and_qid_extras_in_pull(fake_socket): assert tag == b"\x3F" assert len(fields) == 1 assert fields[0] == {"n": 666, "qid": 777} + + +@pytest.mark.parametrize("recv_timeout", (1, -1)) +def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = MagicMock() + sockets.server.send_message(0x70, { + "server": "Neo4j/4.0.0", + "hints": {"connection.recv_timeout_seconds": recv_timeout}, + }) + connection = Bolt4x0(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.hello() + sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/io/test_class_bolt4x1.py b/tests/unit/io/test_class_bolt4x1.py index aee69e68..663d3cbe 100644 --- a/tests/unit/io/test_class_bolt4x1.py +++ b/tests/unit/io/test_class_bolt4x1.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import MagicMock import pytest @@ -194,3 +195,18 @@ def test_hello_passes_routing_metadata(fake_socket_pair): assert tag == 0x01 assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("recv_timeout", (1, -1)) +def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = MagicMock() + sockets.server.send_message(0x70, { + "server": "Neo4j/4.1.0", + "hints": {"connection.recv_timeout_seconds": recv_timeout}, + }) + connection = Bolt4x1(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.hello() + sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/io/test_class_bolt4x2.py b/tests/unit/io/test_class_bolt4x2.py index 0c0b1a9a..470adf5c 100644 --- a/tests/unit/io/test_class_bolt4x2.py +++ b/tests/unit/io/test_class_bolt4x2.py @@ -19,6 +19,8 @@ # limitations under the License. +from unittest.mock import MagicMock + import pytest from neo4j.io._bolt4 import Bolt4x2 @@ -194,3 +196,18 @@ def test_hello_passes_routing_metadata(fake_socket_pair): assert tag == 0x01 assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("recv_timeout", (1, -1)) +def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = MagicMock() + sockets.server.send_message(0x70, { + "server": "Neo4j/4.2.0", + "hints": {"connection.recv_timeout_seconds": recv_timeout}, + }) + connection = Bolt4x2(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.hello() + sockets.client.settimeout.assert_not_called() diff --git a/tests/unit/io/test_class_bolt4x3.py b/tests/unit/io/test_class_bolt4x3.py index b82a4f0b..ec23f1d5 100644 --- a/tests/unit/io/test_class_bolt4x3.py +++ b/tests/unit/io/test_class_bolt4x3.py @@ -18,6 +18,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +from unittest.mock import MagicMock import pytest @@ -186,7 +188,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): def test_hello_passes_routing_metadata(fake_socket_pair): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) - sockets.server.send_message(0x70, {"server": "Neo4j/4.2.0"}) + sockets.server.send_message(0x70, {"server": "Neo4j/4.3.0"}) connection = Bolt4x3(address, sockets.client, PoolConfig.max_connection_lifetime, routing_context={"foo": "bar"}) @@ -195,3 +197,37 @@ def test_hello_passes_routing_metadata(fake_socket_pair): assert tag == 0x01 assert len(fields) == 1 assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize(("recv_timeout", "valid"), ( + (1, True), + (42, True), + (-1, False), + (0, False), + (2.5, False), + (None, False), + ("1", False), +)) +def test_hint_recv_timeout_seconds(fake_socket_pair, recv_timeout, valid, + caplog): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.client.settimeout = MagicMock() + sockets.server.send_message(0x70, { + "server": "Neo4j/4.2.0", + "hints": {"connection.recv_timeout_seconds": recv_timeout}, + }) + connection = Bolt4x3(address, sockets.client, + PoolConfig.max_connection_lifetime) + with caplog.at_level(logging.INFO): + connection.hello() + invalid_value_logged = any(repr(recv_timeout) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + if valid: + sockets.client.settimeout.assert_called_once_with(recv_timeout) + assert not invalid_value_logged + else: + sockets.client.settimeout.assert_not_called() + assert invalid_value_logged