diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index 58689584..8e60272a 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -381,16 +381,17 @@ def hello(self): def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - recv_timeout = self.configuration_hints.get( - "connection.recv_timeout_seconds" - ) - if isinstance(recv_timeout, int) and recv_timeout > 0: - self.socket.settimeout(recv_timeout) - else: - log.info("[#%04X] Server supplied an invalid value for " - "connection.recv_timeout_seconds (%r). Make sure the " - "server and network is set up correctly.", - self.local_port, recv_timeout) + if "connection.recv_timeout_seconds" in self.configuration_hints: + recv_timeout = self.configuration_hints[ + "connection.recv_timeout_seconds" + ] + if isinstance(recv_timeout, int) and recv_timeout > 0: + self.socket.settimeout(recv_timeout) + else: + log.info("[#%04X] Server supplied an invalid value for " + "connection.recv_timeout_seconds (%r). Make sure " + "the server and network is set up correctly.", + self.local_port, recv_timeout) headers = self.get_base_headers() headers.update(self.auth_dict) diff --git a/tests/unit/io/test_class_bolt4x3.py b/tests/unit/io/test_class_bolt4x3.py index ec23f1d5..fc08f5b9 100644 --- a/tests/unit/io/test_class_bolt4x3.py +++ b/tests/unit/io/test_class_bolt4x3.py @@ -199,35 +199,41 @@ def test_hello_passes_routing_metadata(fake_socket_pair): assert fields[0]["routing"] == {"foo": "bar"} -@pytest.mark.parametrize(("recv_timeout", "valid"), ( - (1, True), - (42, True), - (-1, False), - (0, False), - (2.5, False), - (None, False), - ("1", False), +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), )) -def test_hint_recv_timeout_seconds(fake_socket_pair, recv_timeout, valid, +def test_hint_recv_timeout_seconds(fake_socket_pair, hints, valid, caplog): address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address) sockets.client.settimeout = MagicMock() - sockets.server.send_message(0x70, { - "server": "Neo4j/4.2.0", - "hints": {"connection.recv_timeout_seconds": recv_timeout}, - }) + sockets.server.send_message(0x70, {"server": "Neo4j/4.3.0", "hints": hints}) connection = Bolt4x3(address, sockets.client, PoolConfig.max_connection_lifetime) with caplog.at_level(logging.INFO): connection.hello() - invalid_value_logged = any(repr(recv_timeout) in msg - and "recv_timeout_seconds" in msg - and "invalid" in msg - for msg in caplog.messages) if valid: - sockets.client.settimeout.assert_called_once_with(recv_timeout) - assert not invalid_value_logged + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) else: sockets.client.settimeout.assert_not_called() - assert invalid_value_logged + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages)