Skip to content

Commit 69c4d46

Browse files
authored
[4.3] Fix endless loop in Result.peek with fetch_size=1 (#590)
Backport of fix included in #587
1 parent 192793b commit 69c4d46

File tree

2 files changed

+38
-18
lines changed

2 files changed

+38
-18
lines changed

neo4j/work/result.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
SessionExpired,
2929
)
3030
from neo4j.work.summary import ResultSummary
31+
from neo4j.exceptions import ResultConsumedError
3132

3233

3334
class _ConnectionErrorHandler:
@@ -223,20 +224,37 @@ def __iter__(self):
223224
self._closed = True
224225

225226
def _attach(self):
226-
"""Sets the Result object in an attached state by fetching messages from the connection to the buffer.
227+
"""Sets the Result object in an attached state by fetching messages from
228+
the connection to the buffer.
227229
"""
228230
if self._closed is False:
229231
while self._attached is False:
230232
self._connection.fetch_message()
231233

232-
def _buffer_all(self):
233-
"""Sets the Result object in an detached state by fetching all records from the connection to the buffer.
234+
def _buffer(self, n=None):
235+
"""Try to fill `self_record_buffer` with n records.
236+
237+
Might end up with more records in the buffer if the fetch size makes it
238+
overshoot.
239+
Might ent up with fewer records in the buffer if there are not enough
240+
records available.
234241
"""
235242
record_buffer = deque()
236243
for record in self:
237244
record_buffer.append(record)
245+
if n is not None and len(record_buffer) >= n:
246+
break
238247
self._closed = False
239-
self._record_buffer = record_buffer
248+
if n is None:
249+
self._record_buffer = record_buffer
250+
else:
251+
self._record_buffer.extend(record_buffer)
252+
253+
def _buffer_all(self):
254+
"""Sets the Result object in an detached state by fetching all records
255+
from the connection to the buffer.
256+
"""
257+
self._buffer()
240258

241259
def _obtain_summary(self):
242260
"""Obtain the summary of this result, buffering any remaining records.
@@ -309,6 +327,13 @@ def single(self):
309327
:returns: the next :class:`neo4j.Record` or :const:`None` if none remain
310328
:warns: if more than one record is available
311329
"""
330+
# TODO in 5.0 replace with this code that raises an error if there's not
331+
# exactly one record in the left result stream.
332+
# self._buffer(2).
333+
# if len(self._record_buffer) != 1:
334+
# raise SomeError("Expected exactly 1 record, found %i"
335+
# % len(self._record_buffer))
336+
# return self._record_buffer.popleft()
312337
records = list(self) # TODO: exhausts the result with self.consume if there are more records.
313338
size = len(records)
314339
if size == 0:
@@ -323,16 +348,9 @@ def peek(self):
323348
324349
:returns: the next :class:`.Record` or :const:`None` if none remain
325350
"""
351+
self._buffer(1)
326352
if self._record_buffer:
327353
return self._record_buffer[0]
328-
if not self._attached:
329-
return None
330-
while self._attached:
331-
self._connection.fetch_message()
332-
if self._record_buffer:
333-
return self._record_buffer[0]
334-
335-
return None
336354

337355
def graph(self):
338356
"""Return a :class:`neo4j.graph.Graph` instance containing all the graph objects

tests/unit/work/test_result.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,14 @@ def test_result_peek(records, fetch_size):
272272
connection = ConnectionStub(records=Records(["x"], records))
273273
result = Result(connection, HydratorStub(), fetch_size, noop, noop)
274274
result._run("CYPHER", {}, None, "r", None)
275-
record = result.peek()
276-
if not records:
277-
assert record is None
278-
else:
279-
assert isinstance(record, Record)
280-
assert record.get("x") == records[0][0]
275+
for i in range(len(records) + 1):
276+
record = result.peek()
277+
if i == len(records):
278+
assert record is None
279+
else:
280+
assert isinstance(record, Record)
281+
assert record.get("x") == records[i][0]
282+
next(iter(result)) # consume the record
281283

282284

283285
@pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))

0 commit comments

Comments
 (0)