Skip to content

Commit 1ba8b80

Browse files
committed
Fix driver stuck on RecursionError on COMMIT SUCCESS
The driver could get stuck when the result summary was so deeply nested that it caused an `RecursionError`. It would also get stuck on certain protocol violations in the same place (e.g., unknown PackStream tag). However, the latter is much more unlikely to occur. This patch will make the driver ditch the connection when it's not able to properly read a message for unexpected reasons and let the error bubble up.
1 parent 2694fdc commit 1ba8b80

File tree

7 files changed

+382
-50
lines changed

7 files changed

+382
-50
lines changed

neo4j/io/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,15 @@ def _set_defunct_write(self, error=None, silent=False):
586586

587587
def _set_defunct(self, message, error=None, silent=False):
588588
direct_driver = isinstance(self.pool, BoltPool)
589+
connection_failed = isinstance(
590+
error,
591+
(
592+
ServiceUnavailable,
593+
SessionExpired,
594+
OSError,
595+
SocketDeadlineExceeded,
596+
),
597+
)
589598

590599
if error:
591600
log.debug("[#%04X] %r", self.local_port, error)
@@ -595,6 +604,12 @@ def _set_defunct(self, message, error=None, silent=False):
595604
# connection from the client side, and remove the address
596605
# from the connection pool.
597606
self._defunct = True
607+
if not connection_failed:
608+
# Something else but the connection failed
609+
# => we're not sure which state we're in
610+
# => ditch the connection and raise the error for user-awareness
611+
self.close()
612+
raise error
598613
if not self._closing:
599614
# If we fail while closing the connection, there is no need to
600615
# remove the connection from the pool, nor to try to close the

neo4j/io/_common.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,11 @@ def __init__(self, s, on_error):
4646
self._messages = self._yield_messages(s)
4747

4848
def _yield_messages(self, sock):
49-
try:
50-
buffer = UnpackableBuffer()
51-
unpacker = Unpacker(buffer)
52-
chunk_size = 0
53-
while True:
54-
49+
buffer = UnpackableBuffer()
50+
unpacker = Unpacker(buffer)
51+
chunk_size = 0
52+
while True:
53+
try:
5554
while chunk_size == 0:
5655
# Determine the chunk size and skip noop
5756
buffer.receive(sock, 2)
@@ -61,18 +60,27 @@ def _yield_messages(self, sock):
6160

6261
buffer.receive(sock, chunk_size + 2)
6362
chunk_size = buffer.pop_u16()
63+
except (OSError, socket.timeout, SocketDeadlineExceeded) as error:
64+
self.on_error(error)
6465

65-
if chunk_size == 0:
66-
# chunk_size was the end marker for the message
66+
if chunk_size == 0:
67+
# chunk_size was the end marker for the message
68+
try:
6769
size, tag = unpacker.unpack_structure_header()
6870
fields = [unpacker.unpack() for _ in range(size)]
6971
yield tag, fields
72+
except Exception as error:
73+
log.debug(
74+
"[#%04X] _: Failed to unpack response: %r",
75+
self._local_port,
76+
error,
77+
)
78+
self.on_error(error)
79+
raise
80+
finally:
7081
# Reset for new message
7182
unpacker.reset()
7283

73-
except (OSError, socket.timeout, SocketDeadlineExceeded) as error:
74-
self.on_error(error)
75-
7684
def pop(self):
7785
return next(self._messages)
7886

tests/unit/io/_common/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [http://neo4j.com]
3+
#
4+
# This file is part of Neo4j.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.

tests/unit/io/_common/test_inbox.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.import asyncio
15+
16+
17+
import pytest
18+
19+
from neo4j._exceptions import SocketDeadlineExceeded
20+
from neo4j.io._common import Inbox
21+
from neo4j.packstream import Unpacker
22+
23+
24+
def _on_error_side_effect(e):
25+
raise e
26+
27+
28+
class InboxMockHolder:
29+
def __init__(self, mocker):
30+
self.socket_mock = mocker.Mock()
31+
self.socket_mock.getsockname.return_value = ("host", 1234)
32+
self.on_error = mocker.MagicMock()
33+
self.on_error.side_effect = _on_error_side_effect
34+
self.inbox = Inbox(self.socket_mock, self.on_error)
35+
self.unpacker = None
36+
self._unpacker_failure = None
37+
mocker.patch(
38+
"neo4j.io._common.Unpacker",
39+
new=self._make_unpacker,
40+
)
41+
self.inbox._unpacker = self.unpacker
42+
# plenty of nonsense messages to read
43+
self.mock_set_data(b"\x00\x01\xff\x00\x00" * 1000)
44+
self._mocker = mocker
45+
46+
def _make_unpacker(self, buffer):
47+
if self.unpacker is not None:
48+
pytest.fail("Unexpected 2nd instantiation of Unpacker")
49+
self.unpacker = self._mocker.Mock(wraps=Unpacker(buffer))
50+
if self._unpacker_failure is not None:
51+
self.mock_unpack_failure(self._unpacker_failure)
52+
return self.unpacker
53+
54+
def mock_set_data(self, data):
55+
def side_effect(buffer, n):
56+
nonlocal data
57+
58+
if not data:
59+
pytest.fail("Read more data than mocked")
60+
61+
n = min(len(data), len(buffer), n)
62+
buffer[:n] = data[:n]
63+
data = data[n:]
64+
return n
65+
66+
self.socket_mock.recv_into.side_effect = side_effect
67+
68+
def assert_no_error(self):
69+
self.on_error.assert_not_called()
70+
assert next(self.inbox, None) is not None
71+
72+
def mock_receive_failure(self, exception):
73+
self.socket_mock.recv_into.side_effect = exception
74+
75+
def mock_unpack_failure(self, exception):
76+
self._unpacker_failure = exception
77+
if self.unpacker is not None:
78+
self.unpacker.unpack_structure_header.side_effect = exception
79+
80+
81+
@pytest.mark.parametrize(
82+
"error",
83+
(
84+
SocketDeadlineExceeded("test"),
85+
OSError("test"),
86+
),
87+
)
88+
def test_inbox_receive_failure_error_handler(mocker, error):
89+
mocks = InboxMockHolder(mocker)
90+
mocks.mock_receive_failure(error)
91+
inbox = mocks.inbox
92+
93+
with pytest.raises(type(error)) as exc:
94+
next(inbox)
95+
96+
assert exc.value is error
97+
mocks.on_error.assert_called_once_with(error)
98+
99+
100+
@pytest.mark.parametrize(
101+
"error",
102+
(
103+
SocketDeadlineExceeded("test"),
104+
OSError("test"),
105+
RecursionError("2deep4u"),
106+
RuntimeError("something funny happened"),
107+
),
108+
)
109+
def test_inbox_unpack_failure(mocker, error):
110+
mocks = InboxMockHolder(mocker)
111+
mocks.mock_unpack_failure(error)
112+
inbox = mocks.inbox
113+
114+
with pytest.raises(type(error)) as exc:
115+
next(inbox)
116+
117+
assert exc.value is error
118+
mocks.on_error.assert_called_once_with(error)

tests/unit/io/_common/test_oubox.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
import pytest
18+
19+
from neo4j.io._common import Outbox
20+
21+
22+
@pytest.mark.parametrize(("chunk_size", "data", "result"), (
23+
(
24+
2,
25+
(bytes(range(10, 15)),),
26+
bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14))
27+
),
28+
(
29+
2,
30+
(bytes(range(10, 14)),),
31+
bytes((0, 2, 10, 11, 0, 2, 12, 13))
32+
),
33+
(
34+
2,
35+
(bytes((5, 6, 7)), bytes((8, 9))),
36+
bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9))
37+
),
38+
))
39+
def test_outbox_chunking(chunk_size, data, result):
40+
outbox = Outbox(max_chunk_size=chunk_size)
41+
assert bytes(outbox.view()) == b""
42+
for d in data:
43+
outbox.write(d)
44+
assert bytes(outbox.view()) == result
45+
# make sure this works multiple times
46+
assert bytes(outbox.view()) == result
47+
outbox.clear()
48+
assert bytes(outbox.view()) == b""

tests/unit/io/test__common.py renamed to tests/unit/io/_common/test_response_handler.py

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
# Copyright (c) "Neo4j"
2-
# Neo4j Sweden AB [http://neo4j.com]
3-
#
4-
# This file is part of Neo4j.
2+
# Neo4j Sweden AB [https://neo4j.com]
53
#
64
# Licensed under the Apache License, Version 2.0 (the "License");
75
# you may not use this file except in compliance with the License.
86
# You may obtain a copy of the License at
97
#
10-
# http://www.apache.org/licenses/LICENSE-2.0
8+
# https://www.apache.org/licenses/LICENSE-2.0
119
#
1210
# Unless required by applicable law or agreed to in writing, software
1311
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -20,40 +18,7 @@
2018

2119
import pytest
2220

23-
from neo4j.io._common import (
24-
Outbox,
25-
ResetResponse,
26-
)
27-
28-
29-
@pytest.mark.parametrize(("chunk_size", "data", "result"), (
30-
(
31-
2,
32-
(bytes(range(10, 15)),),
33-
bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14))
34-
),
35-
(
36-
2,
37-
(bytes(range(10, 14)),),
38-
bytes((0, 2, 10, 11, 0, 2, 12, 13))
39-
),
40-
(
41-
2,
42-
(bytes((5, 6, 7)), bytes((8, 9))),
43-
bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9))
44-
),
45-
))
46-
def test_outbox_chunking(chunk_size, data, result):
47-
outbox = Outbox(max_chunk_size=chunk_size)
48-
assert bytes(outbox.view()) == b""
49-
for d in data:
50-
outbox.write(d)
51-
assert bytes(outbox.view()) == result
52-
# make sure this works multiple times
53-
assert bytes(outbox.view()) == result
54-
outbox.clear()
55-
assert bytes(outbox.view()) == b""
56-
21+
from neo4j.io._common import ResetResponse
5722

5823
def get_handler_arg(response):
5924
if response == "RECORD":

0 commit comments

Comments
 (0)