diff --git a/neo4j/work/result.py b/neo4j/work/result.py index 6465a65d..c449a5d8 100644 --- a/neo4j/work/result.py +++ b/neo4j/work/result.py @@ -28,6 +28,7 @@ SessionExpired, ) from neo4j.work.summary import ResultSummary +from neo4j.exceptions import ResultConsumedError class _ConnectionErrorHandler: @@ -223,20 +224,37 @@ def __iter__(self): self._closed = True def _attach(self): - """Sets the Result object in an attached state by fetching messages from the connection to the buffer. + """Sets the Result object in an attached state by fetching messages from + the connection to the buffer. """ if self._closed is False: while self._attached is False: self._connection.fetch_message() - def _buffer_all(self): - """Sets the Result object in an detached state by fetching all records from the connection to the buffer. + def _buffer(self, n=None): + """Try to fill `self_record_buffer` with n records. + + Might end up with more records in the buffer if the fetch size makes it + overshoot. + Might ent up with fewer records in the buffer if there are not enough + records available. """ record_buffer = deque() for record in self: record_buffer.append(record) + if n is not None and len(record_buffer) >= n: + break self._closed = False - self._record_buffer = record_buffer + if n is None: + self._record_buffer = record_buffer + else: + self._record_buffer.extend(record_buffer) + + def _buffer_all(self): + """Sets the Result object in an detached state by fetching all records + from the connection to the buffer. + """ + self._buffer() def _obtain_summary(self): """Obtain the summary of this result, buffering any remaining records. @@ -309,6 +327,13 @@ def single(self): :returns: the next :class:`neo4j.Record` or :const:`None` if none remain :warns: if more than one record is available """ + # TODO in 5.0 replace with this code that raises an error if there's not + # exactly one record in the left result stream. + # self._buffer(2). + # if len(self._record_buffer) != 1: + # raise SomeError("Expected exactly 1 record, found %i" + # % len(self._record_buffer)) + # return self._record_buffer.popleft() records = list(self) # TODO: exhausts the result with self.consume if there are more records. size = len(records) if size == 0: @@ -323,16 +348,9 @@ def peek(self): :returns: the next :class:`.Record` or :const:`None` if none remain """ + self._buffer(1) if self._record_buffer: return self._record_buffer[0] - if not self._attached: - return None - while self._attached: - self._connection.fetch_message() - if self._record_buffer: - return self._record_buffer[0] - - return None def graph(self): """Return a :class:`neo4j.graph.Graph` instance containing all the graph objects diff --git a/tests/unit/work/test_result.py b/tests/unit/work/test_result.py index d7f49ece..936d3be8 100644 --- a/tests/unit/work/test_result.py +++ b/tests/unit/work/test_result.py @@ -272,12 +272,14 @@ def test_result_peek(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) result = Result(connection, HydratorStub(), fetch_size, noop, noop) result._run("CYPHER", {}, None, "r", None) - record = result.peek() - if not records: - assert record is None - else: - assert isinstance(record, Record) - assert record.get("x") == records[0][0] + for i in range(len(records) + 1): + record = result.peek() + if i == len(records): + assert record is None + else: + assert isinstance(record, Record) + assert record.get("x") == records[i][0] + next(iter(result)) # consume the record @pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))