Skip to content

[4.3] Fix endless loop in Result.peek with fetch_size=1 #590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions neo4j/work/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
SessionExpired,
)
from neo4j.work.summary import ResultSummary
from neo4j.exceptions import ResultConsumedError


class _ConnectionErrorHandler:
Expand Down Expand Up @@ -223,20 +224,37 @@ def __iter__(self):
self._closed = True

def _attach(self):
"""Sets the Result object in an attached state by fetching messages from the connection to the buffer.
"""Sets the Result object in an attached state by fetching messages from
the connection to the buffer.
"""
if self._closed is False:
while self._attached is False:
self._connection.fetch_message()

def _buffer_all(self):
"""Sets the Result object in an detached state by fetching all records from the connection to the buffer.
def _buffer(self, n=None):
"""Try to fill `self._record_buffer` with n records.

Might end up with more records in the buffer if the fetch size makes it
overshoot.
Might end up with fewer records in the buffer if there are not enough
records available.
"""
record_buffer = deque()
for record in self:
record_buffer.append(record)
if n is not None and len(record_buffer) >= n:
break
self._closed = False
self._record_buffer = record_buffer
if n is None:
self._record_buffer = record_buffer
else:
self._record_buffer.extend(record_buffer)

def _buffer_all(self):
"""Sets the Result object in a detached state by fetching all records
from the connection to the buffer.
"""
self._buffer()

def _obtain_summary(self):
"""Obtain the summary of this result, buffering any remaining records.
Expand Down Expand Up @@ -309,6 +327,13 @@ def single(self):
:returns: the next :class:`neo4j.Record` or :const:`None` if none remain
:warns: if more than one record is available
"""
# TODO in 5.0 replace with this code that raises an error if there's not
# exactly one record left in the result stream.
# self._buffer(2)
# if len(self._record_buffer) != 1:
# raise SomeError("Expected exactly 1 record, found %i"
# % len(self._record_buffer))
# return self._record_buffer.popleft()
records = list(self) # TODO: exhausts the result with self.consume if there are more records.
size = len(records)
if size == 0:
Expand All @@ -323,16 +348,9 @@ def peek(self):

:returns: the next :class:`.Record` or :const:`None` if none remain
"""
self._buffer(1)
if self._record_buffer:
return self._record_buffer[0]
if not self._attached:
return None
while self._attached:
self._connection.fetch_message()
if self._record_buffer:
return self._record_buffer[0]

return None

def graph(self):
"""Return a :class:`neo4j.graph.Graph` instance containing all the graph objects
Expand Down
14 changes: 8 additions & 6 deletions tests/unit/work/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,14 @@ def test_result_peek(records, fetch_size):
connection = ConnectionStub(records=Records(["x"], records))
result = Result(connection, HydratorStub(), fetch_size, noop, noop)
result._run("CYPHER", {}, None, "r", None)
record = result.peek()
if not records:
assert record is None
else:
assert isinstance(record, Record)
assert record.get("x") == records[0][0]
for i in range(len(records) + 1):
record = result.peek()
if i == len(records):
assert record is None
else:
assert isinstance(record, Record)
assert record.get("x") == records[i][0]
next(iter(result)) # consume the record


@pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))
Expand Down