Skip to content

[4.4] Fix driver stuck on RecursionError on COMMIT SUCCESS #1203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,15 @@ def _set_defunct_write(self, error=None, silent=False):

def _set_defunct(self, message, error=None, silent=False):
direct_driver = isinstance(self.pool, BoltPool)
connection_failed = isinstance(
error,
(
ServiceUnavailable,
SessionExpired,
OSError,
SocketDeadlineExceeded,
),
)

if error:
log.debug("[#%04X] %r", self.local_port, error)
Expand All @@ -595,6 +604,12 @@ def _set_defunct(self, message, error=None, silent=False):
# connection from the client side, and remove the address
# from the connection pool.
self._defunct = True
if not connection_failed:
# Something else but the connection failed
# => we're not sure which state we're in
# => ditch the connection and raise the error for user-awareness
self.close()
raise error
if not self._closing:
# If we fail while closing the connection, there is no need to
# remove the connection from the pool, nor to try to close the
Expand Down
30 changes: 19 additions & 11 deletions neo4j/io/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,11 @@ def __init__(self, s, on_error):
self._messages = self._yield_messages(s)

def _yield_messages(self, sock):
try:
buffer = UnpackableBuffer()
unpacker = Unpacker(buffer)
chunk_size = 0
while True:

buffer = UnpackableBuffer()
unpacker = Unpacker(buffer)
chunk_size = 0
while True:
try:
while chunk_size == 0:
# Determine the chunk size and skip noop
buffer.receive(sock, 2)
Expand All @@ -61,18 +60,27 @@ def _yield_messages(self, sock):

buffer.receive(sock, chunk_size + 2)
chunk_size = buffer.pop_u16()
except (OSError, socket.timeout, SocketDeadlineExceeded) as error:
self.on_error(error)

if chunk_size == 0:
# chunk_size was the end marker for the message
if chunk_size == 0:
# chunk_size was the end marker for the message
try:
size, tag = unpacker.unpack_structure_header()
fields = [unpacker.unpack() for _ in range(size)]
yield tag, fields
except Exception as error:
log.debug(
"[#%04X] _: Failed to unpack response: %r",
self._local_port,
error,
)
self.on_error(error)
raise
finally:
# Reset for new message
unpacker.reset()

except (OSError, socket.timeout, SocketDeadlineExceeded) as error:
self.on_error(error)

def pop(self):
return next(self._messages)

Expand Down
16 changes: 16 additions & 0 deletions tests/unit/io/_common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
# This file is part of Neo4j.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
118 changes: 118 additions & 0 deletions tests/unit/io/_common/test_inbox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.import asyncio


import pytest

from neo4j._exceptions import SocketDeadlineExceeded
from neo4j.io._common import Inbox
from neo4j.packstream import Unpacker


def _on_error_side_effect(e):
raise e


class InboxMockHolder:
def __init__(self, mocker):
self.socket_mock = mocker.Mock()
self.socket_mock.getsockname.return_value = ("host", 1234)
self.on_error = mocker.MagicMock()
self.on_error.side_effect = _on_error_side_effect
self.inbox = Inbox(self.socket_mock, self.on_error)
self.unpacker = None
self._unpacker_failure = None
mocker.patch(
"neo4j.io._common.Unpacker",
new=self._make_unpacker,
)
self.inbox._unpacker = self.unpacker
# plenty of nonsense messages to read
self.mock_set_data(b"\x00\x01\xff\x00\x00" * 1000)
self._mocker = mocker

def _make_unpacker(self, buffer):
if self.unpacker is not None:
pytest.fail("Unexpected 2nd instantiation of Unpacker")
self.unpacker = self._mocker.Mock(wraps=Unpacker(buffer))
if self._unpacker_failure is not None:
self.mock_unpack_failure(self._unpacker_failure)
return self.unpacker

def mock_set_data(self, data):
def side_effect(buffer, n):
nonlocal data

if not data:
pytest.fail("Read more data than mocked")

n = min(len(data), len(buffer), n)
buffer[:n] = data[:n]
data = data[n:]
return n

self.socket_mock.recv_into.side_effect = side_effect

def assert_no_error(self):
self.on_error.assert_not_called()
assert next(self.inbox, None) is not None

def mock_receive_failure(self, exception):
self.socket_mock.recv_into.side_effect = exception

def mock_unpack_failure(self, exception):
self._unpacker_failure = exception
if self.unpacker is not None:
self.unpacker.unpack_structure_header.side_effect = exception


@pytest.mark.parametrize(
"error",
(
SocketDeadlineExceeded("test"),
OSError("test"),
),
)
def test_inbox_receive_failure_error_handler(mocker, error):
mocks = InboxMockHolder(mocker)
mocks.mock_receive_failure(error)
inbox = mocks.inbox

with pytest.raises(type(error)) as exc:
next(inbox)

assert exc.value is error
mocks.on_error.assert_called_once_with(error)


@pytest.mark.parametrize(
"error",
(
SocketDeadlineExceeded("test"),
OSError("test"),
RecursionError("2deep4u"),
RuntimeError("something funny happened"),
),
)
def test_inbox_unpack_failure(mocker, error):
mocks = InboxMockHolder(mocker)
mocks.mock_unpack_failure(error)
inbox = mocks.inbox

with pytest.raises(type(error)) as exc:
next(inbox)

assert exc.value is error
mocks.on_error.assert_called_once_with(error)
48 changes: 48 additions & 0 deletions tests/unit/io/_common/test_oubox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest

from neo4j.io._common import Outbox


@pytest.mark.parametrize(("chunk_size", "data", "result"), (
(
2,
(bytes(range(10, 15)),),
bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14))
),
(
2,
(bytes(range(10, 14)),),
bytes((0, 2, 10, 11, 0, 2, 12, 13))
),
(
2,
(bytes((5, 6, 7)), bytes((8, 9))),
bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9))
),
))
def test_outbox_chunking(chunk_size, data, result):
outbox = Outbox(max_chunk_size=chunk_size)
assert bytes(outbox.view()) == b""
for d in data:
outbox.write(d)
assert bytes(outbox.view()) == result
# make sure this works multiple times
assert bytes(outbox.view()) == result
outbox.clear()
assert bytes(outbox.view()) == b""
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [http://neo4j.com]
#
# This file is part of Neo4j.
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
Expand All @@ -20,40 +18,7 @@

import pytest

from neo4j.io._common import (
Outbox,
ResetResponse,
)


@pytest.mark.parametrize(("chunk_size", "data", "result"), (
(
2,
(bytes(range(10, 15)),),
bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14))
),
(
2,
(bytes(range(10, 14)),),
bytes((0, 2, 10, 11, 0, 2, 12, 13))
),
(
2,
(bytes((5, 6, 7)), bytes((8, 9))),
bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9))
),
))
def test_outbox_chunking(chunk_size, data, result):
outbox = Outbox(max_chunk_size=chunk_size)
assert bytes(outbox.view()) == b""
for d in data:
outbox.write(d)
assert bytes(outbox.view()) == result
# make sure this works multiple times
assert bytes(outbox.view()) == result
outbox.clear()
assert bytes(outbox.view()) == b""

from neo4j.io._common import ResetResponse

def get_handler_arg(response):
if response == "RECORD":
Expand Down
Loading