Skip to content

Commit 356c829

Browse files
authored
Harden driver against unexpected RESET responses (#1042)
The server has been observed to reply with `FAILURE` and `IGNORED` to `RESET` requests. The former is according to spec and the driver should drop the connection (which it didn't), the latter isn't. The right combination of those two unexpected responses at the right time could get the driver stuck in an infinite loop. This change makes the driver drop the connection in either case to gracefully handle the situation.
1 parent dcb337c commit 356c829

File tree

5 files changed

+156
-16
lines changed

5 files changed

+156
-16
lines changed

neo4j/io/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ def fetch_all(self):
544544
messages fetched
545545
"""
546546
detail_count = summary_count = 0
547-
while self.responses:
547+
while not self._closed and self.responses:
548548
response = self.responses[0]
549549
while not response.complete:
550550
detail_delta, summary_delta = self.fetch_message()

neo4j/io/_bolt3.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from neo4j.io._common import (
4747
CommitResponse,
4848
InitResponse,
49+
ResetResponse,
4950
Response,
5051
tx_timeout_as_ms,
5152
)
@@ -288,15 +289,13 @@ def rollback(self, **handlers):
288289
self._append(b"\x13", (), Response(self, "rollback", **handlers))
289290

290291
def reset(self):
291-
""" Add a RESET message to the outgoing queue, send
292-
it and consume all remaining messages.
293-
"""
294-
295-
def fail(metadata):
296-
raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address)
292+
"""Reset the connection.
297293
294+
Add a RESET message to the outgoing queue, send it and consume all
295+
remaining messages.
296+
"""
298297
log.debug("[#%04X] C: RESET", self.local_port)
299-
self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
298+
self._append(b"\x0F", response=ResetResponse(self, "reset"))
300299
self.send_all()
301300
self.fetch_all()
302301

neo4j/io/_bolt4.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from neo4j.io._common import (
4747
CommitResponse,
4848
InitResponse,
49+
ResetResponse,
4950
Response,
5051
tx_timeout_as_ms,
5152
)
@@ -240,15 +241,13 @@ def rollback(self, **handlers):
240241
self._append(b"\x13", (), Response(self, "rollback", **handlers))
241242

242243
def reset(self):
243-
""" Add a RESET message to the outgoing queue, send
244-
it and consume all remaining messages.
245-
"""
246-
247-
def fail(metadata):
248-
raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address)
244+
"""Reset the connection.
249245
246+
Add a RESET message to the outgoing queue, send it and consume all
247+
remaining messages.
248+
"""
250249
log.debug("[#%04X] C: RESET", self.local_port)
251-
self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
250+
self._append(b"\x0F", response=ResetResponse(self, "reset"))
252251
self.send_all()
253252
self.fetch_all()
254253

neo4j/io/_common.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,26 @@ def on_failure(self, metadata):
267267
)
268268

269269

270+
class ResetResponse(Response):
271+
def _unexpected_message(self, response):
272+
log.warning("[#%04X] RESET received %s (unexpected response) "
273+
"=> dropping connection",
274+
self.connection.local_port, response)
275+
self.connection.close()
276+
277+
def on_records(self, records):
278+
self._unexpected_message("RECORD")
279+
280+
def on_success(self, metadata):
281+
pass
282+
283+
def on_failure(self, metadata):
284+
self._unexpected_message("FAILURE")
285+
286+
def on_ignored(self, metadata=None):
287+
self._unexpected_message("IGNORED")
288+
289+
270290
class CommitResponse(Response):
271291

272292
pass

tests/unit/io/test__common.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,31 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [http://neo4j.com]
3+
#
4+
# This file is part of Neo4j.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
19+
import logging
20+
121
import pytest
222

3-
from neo4j.io._common import Outbox
23+
from neo4j.io._common import (
24+
Outbox,
25+
ResetResponse,
26+
)
27+
28+
from ..work import fake_connection
429

530

631
@pytest.mark.parametrize(("chunk_size", "data", "result"), (
@@ -30,3 +55,100 @@ def test_outbox_chunking(chunk_size, data, result):
3055
assert bytes(outbox.view()) == result
3156
outbox.clear()
3257
assert bytes(outbox.view()) == b""
58+
59+
60+
def get_handler_arg(response):
61+
if response == "RECORD":
62+
return []
63+
elif response == "IGNORED":
64+
return {}
65+
elif response == "FAILURE":
66+
return {}
67+
elif response == "SUCCESS":
68+
return {}
69+
else:
70+
raise ValueError(f"Unexpected response: {response}")
71+
72+
73+
def call_handler(handler, response, arg=None):
74+
if arg is None:
75+
arg = get_handler_arg(response)
76+
77+
if response == "RECORD":
78+
return handler.on_records(arg)
79+
elif response == "IGNORED":
80+
return handler.on_ignored(arg)
81+
elif response == "FAILURE":
82+
return handler.on_failure(arg)
83+
elif response == "SUCCESS":
84+
return handler.on_success(arg)
85+
else:
86+
raise ValueError(f"Unexpected response: {response}")
87+
88+
89+
@pytest.mark.parametrize(
90+
("response", "unexpected"),
91+
(
92+
("RECORD", True),
93+
("IGNORED", True),
94+
("FAILURE", True),
95+
("SUCCESS", False),
96+
)
97+
)
98+
def test_reset_response_closes_connection_on_unexpected_responses(
99+
response, unexpected, fake_connection
100+
):
101+
handler = ResetResponse(fake_connection, "reset")
102+
fake_connection.close.assert_not_called()
103+
104+
call_handler(handler, response)
105+
106+
if unexpected:
107+
fake_connection.close.assert_called_once()
108+
else:
109+
fake_connection.close.assert_not_called()
110+
111+
112+
@pytest.mark.parametrize(
113+
("response", "unexpected"),
114+
(
115+
("RECORD", True),
116+
("IGNORED", True),
117+
("FAILURE", True),
118+
("SUCCESS", False),
119+
)
120+
)
121+
def test_reset_response_logs_warning_on_unexpected_responses(
122+
response, unexpected, fake_connection, caplog
123+
):
124+
handler = ResetResponse(fake_connection, "reset")
125+
126+
with caplog.at_level(logging.WARNING):
127+
call_handler(handler, response)
128+
129+
log_message_found = any("RESET" in msg and "unexpected response" in msg
130+
for msg in caplog.messages)
131+
if unexpected:
132+
assert log_message_found
133+
else:
134+
assert not log_message_found
135+
136+
137+
@pytest.mark.parametrize("response",
138+
("RECORD", "IGNORED", "FAILURE", "SUCCESS"))
139+
def test_reset_response_never_calls_handlers(
140+
response, fake_connection, mocker
141+
):
142+
handlers = {
143+
key: mocker.MagicMock(name=key)
144+
for key in
145+
("on_records", "on_ignored", "on_failure", "on_success", "on_summary")
146+
}
147+
148+
handler = ResetResponse(fake_connection, "reset", **handlers)
149+
150+
arg = get_handler_arg(response)
151+
call_handler(handler, response, arg)
152+
153+
for handler in handlers.values():
154+
handler.assert_not_called()

0 commit comments

Comments
 (0)