Skip to content

Commit a32e403

Browse files
authored
Fix driver stuck on RecursionError on COMMIT SUCCESS (#1192)
The driver could get stuck when the result summary was so deeply nested that it caused an `RecursionError`. It would also get stuck on certain protocol violations in the same place (e.g., unknown PackStream tag). However, the latter is much more unlikely to occur. This patch will make the driver ditch the connection when it's not able to properly read a message for unexpected reasons and let the error bubble up. * Add unit tests
1 parent 5b53d72 commit a32e403

File tree

15 files changed

+1080
-52
lines changed

15 files changed

+1080
-52
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,9 @@ select = [
277277
# allow async functions without await to enable type checking, pretending to be async, matching type signatures
278278
"RUF029",
279279
]
280+
"tests/**" = [
281+
"B011", # allow `assert False` in tests, they won't be run with -O anyway
282+
]
280283
"bin/**" = [
281284
"T20", # print statements are ok in our helper scripts
282285
]

src/neo4j/_async/io/_bolt.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from ..._exceptions import (
3636
BoltError,
3737
BoltHandshakeError,
38+
SocketDeadlineExceededError,
3839
)
3940
from ..._io import BoltProtocolVersion
4041
from ..._meta import USER_AGENT
@@ -887,6 +888,15 @@ async def _set_defunct_write(self, error=None, silent=False):
887888
async def _set_defunct(self, message, error=None, silent=False):
888889
direct_driver = getattr(self.pool, "is_direct_pool", False)
889890
user_cancelled = isinstance(error, asyncio.CancelledError)
891+
connection_failed = isinstance(
892+
error,
893+
(
894+
ServiceUnavailable,
895+
SessionExpired,
896+
OSError,
897+
SocketDeadlineExceededError,
898+
),
899+
)
890900

891901
if not (user_cancelled or self._closing):
892902
log_call = log.error
@@ -913,6 +923,12 @@ async def _set_defunct(self, message, error=None, silent=False):
913923
if user_cancelled:
914924
self.kill()
915925
raise error # cancellation error should not be re-written
926+
if not connection_failed:
927+
# Something else but the connection failed
928+
# => we're not sure which state we're in
929+
# => ditch the connection and raise the error for user-awareness
930+
await self.close()
931+
raise error
916932
if not self._closing:
917933
# If we fail while closing the connection, there is no need to
918934
# remove the connection from the pool, nor to try to close the

src/neo4j/_async/io/_common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,15 @@ async def pop(self, hydration_hooks):
8181
self._unpacker.unpack(hydration_hooks) for _ in range(size)
8282
]
8383
return tag, fields
84+
except Exception as error:
85+
log.debug(
86+
"[#%04X] _: Failed to unpack response: %r",
87+
self._local_port,
88+
error,
89+
)
90+
self._broken = True
91+
await AsyncUtil.callback(self.on_error, error)
92+
raise
8493
finally:
8594
# Reset for new message
8695
self._unpacker.reset()

src/neo4j/_sync/io/_bolt.py

Lines changed: 16 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/neo4j/_sync/io/_common.py

Lines changed: 9 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
import asyncio
18+
19+
import pytest
20+
21+
from neo4j._async.io._common import AsyncInbox
22+
from neo4j._codec.packstream.v1 import Unpacker
23+
from neo4j._exceptions import SocketDeadlineExceededError
24+
25+
from ....._async_compat import mark_async_test
26+
27+
28+
class InboxMockHolder:
29+
def __init__(self, mocker):
30+
self.socket_mock = mocker.Mock()
31+
self.socket_mock.getsockname.return_value = ("host", 1234)
32+
self.on_error = mocker.AsyncMock()
33+
self.inbox = AsyncInbox(self.socket_mock, self.on_error, Unpacker)
34+
self.unpacker = mocker.Mock(wraps=self.inbox._unpacker)
35+
self.inbox._unpacker = self.unpacker
36+
# plenty of nonsense messages to read
37+
self.mock_set_data(b"\x00\x01\xff\x00\x00" * 1000)
38+
39+
def mock_set_data(self, data):
40+
async def side_effect(buffer, n):
41+
nonlocal data
42+
43+
if not data:
44+
pytest.fail("Read more data than mocked")
45+
46+
n = min(len(data), len(buffer), n)
47+
buffer[:n] = data[:n]
48+
data = data[n:]
49+
return n
50+
51+
self.socket_mock.recv_into.side_effect = side_effect
52+
53+
def assert_no_error(self):
54+
self.on_error.assert_not_called()
55+
assert not self.inbox._broken
56+
57+
def mock_receive_failure(self, exception):
58+
self.socket_mock.recv_into.side_effect = exception
59+
60+
def mock_unpack_failure(self, exception):
61+
self.unpacker.unpack_structure_header.side_effect = exception
62+
63+
64+
@pytest.mark.parametrize(
65+
("data", "result"),
66+
(
67+
(
68+
bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14, 0, 0)),
69+
bytes(range(10, 15)),
70+
),
71+
(
72+
bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 0)),
73+
bytes(range(10, 14)),
74+
),
75+
(
76+
bytes((0, 1, 5, 0, 0)),
77+
bytes((5,)),
78+
),
79+
),
80+
)
81+
@mark_async_test
82+
async def test_inbox_dechunking(data, result, mocker):
83+
# Given
84+
mocks = InboxMockHolder(mocker)
85+
mocks.mock_set_data(data)
86+
inbox = mocks.inbox
87+
buffer = inbox._buffer
88+
89+
# When
90+
await inbox._buffer_one_chunk()
91+
92+
# Then
93+
mocks.assert_no_error()
94+
assert buffer.used == len(result)
95+
assert buffer.data[: len(result)] == result
96+
97+
98+
@pytest.mark.parametrize(
99+
"error",
100+
(
101+
asyncio.CancelledError("test"),
102+
SocketDeadlineExceededError("test"),
103+
OSError("test"),
104+
),
105+
)
106+
@mark_async_test
107+
async def test_inbox_receive_failure_error_handler(mocker, error):
108+
mocks = InboxMockHolder(mocker)
109+
mocks.mock_receive_failure(error)
110+
inbox = mocks.inbox
111+
112+
with pytest.raises(type(error)) as exc:
113+
await inbox.pop({})
114+
115+
assert exc.value is error
116+
mocks.on_error.assert_awaited_once_with(error)
117+
assert inbox._broken
118+
119+
120+
@pytest.mark.parametrize(
121+
"error",
122+
(
123+
SocketDeadlineExceededError("test"),
124+
OSError("test"),
125+
RecursionError("2deep4u"),
126+
RuntimeError("something funny happened"),
127+
),
128+
)
129+
@mark_async_test
130+
async def test_inbox_unpack_failure(mocker, error):
131+
mocks = InboxMockHolder(mocker)
132+
mocks.mock_unpack_failure(error)
133+
inbox = mocks.inbox
134+
135+
with pytest.raises(type(error)) as exc:
136+
await inbox.pop({})
137+
138+
assert exc.value is error
139+
mocks.on_error.assert_awaited_once_with(error)
140+
assert inbox._broken

0 commit comments

Comments
 (0)