Skip to content

Commit 9d77e8d

Browse files
authored
Fix pulling results in parallel (#561)
Consuming two results in the same transaction could cause the driver to send too many PULL requests to the server, which led to a FAILURE.
1 parent 76b399d commit 9d77e8d

File tree

5 files changed

+148
-76
lines changed

5 files changed

+148
-76
lines changed

neo4j/io/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def send_all(self):
498498

499499
@abc.abstractmethod
500500
def fetch_message(self):
501-
""" Receive at least one message from the server, if available.
501+
""" Receive at most one message from the server, if available.
502502
503503
:return: 2-tuple of number of detail messages and number of summary
504504
messages fetched

neo4j/io/_bolt3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def fail(metadata):
219219
self._is_reset = True
220220

221221
def fetch_message(self):
222-
""" Receive at least one message from the server, if available.
222+
""" Receive at most one message from the server, if available.
223223
224224
:return: 2-tuple of number of detail messages and number of summary
225225
messages fetched

neo4j/io/_bolt4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def fail(metadata):
231231
self._is_reset = True
232232

233233
def fetch_message(self):
234-
""" Receive at least one message from the server, if available.
234+
""" Receive at most one message from the server, if available.
235235
236236
:return: 2-tuple of number of detail messages and number of summary
237237
messages fetched

neo4j/work/result.py

Lines changed: 23 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,12 @@ def __init__(self, connection, hydrant, fetch_size, on_closed,
8989
# states
9090
self._discarding = False # discard the remainder of records
9191
self._attached = False # attached to a connection
92-
self._streaming = False # there is still more records to buffer upp on the wire
93-
self._has_more = False # there is more records available to pull from the server
94-
self._closed = False # the result have been properly iterated or consumed fully
92+
# there are still more response messages we are waiting for
93+
self._streaming = False
94+
# there are more records available to pull from the server
95+
self._has_more = False
96+
# the result has been fully iterated or consumed
97+
self._closed = False
9598

9699
def _tx_ready_run(self, query, parameters, **kwparameters):
97100
# BEGIN+RUN does not carry any extra on the RUN message.
@@ -112,11 +115,6 @@ def _run(self, query, parameters, db, access_mode, bookmarks, **kwparameters):
112115
"server": self._connection.server_info,
113116
}
114117

115-
run_metadata = {
116-
"metadata": query_metadata,
117-
"timeout": query_timeout,
118-
}
119-
120118
def on_attached(metadata):
121119
self._metadata.update(metadata)
122120
self._qid = metadata.get("qid", -1) # For auto-commit there is no qid and Bolt 3 do not support qid
@@ -144,9 +142,7 @@ def on_failed_attach(metadata):
144142
self._attach()
145143

146144
def _pull(self):
147-
148145
def on_records(records):
149-
self._streaming = True
150146
if not self._discarding:
151147
self._record_buffer.extend(self._hydrant.hydrate_records(self._keys, records))
152148

@@ -159,14 +155,11 @@ def on_failure(metadata):
159155
self._on_closed()
160156

161157
def on_success(summary_metadata):
158+
self._streaming = False
162159
has_more = summary_metadata.get("has_more")
160+
self._has_more = bool(has_more)
163161
if has_more:
164-
self._has_more = True
165-
self._streaming = False
166162
return
167-
else:
168-
self._has_more = False
169-
170163
self._metadata.update(summary_metadata)
171164
self._bookmark = summary_metadata.get("bookmark")
172165

@@ -178,11 +171,9 @@ def on_success(summary_metadata):
178171
on_failure=on_failure,
179172
on_summary=on_summary,
180173
)
174+
self._streaming = True
181175

182176
def _discard(self):
183-
def on_records(records):
184-
pass
185-
186177
def on_summary():
187178
self._attached = False
188179
self._on_closed()
@@ -193,47 +184,41 @@ def on_failure(metadata):
193184
self._on_closed()
194185

195186
def on_success(summary_metadata):
187+
self._streaming = False
196188
has_more = summary_metadata.get("has_more")
189+
self._has_more = bool(has_more)
197190
if has_more:
198-
self._has_more = True
199-
self._streaming = False
200-
else:
201-
self._has_more = False
202-
self._discarding = False
203-
191+
return
192+
self._discarding = False
204193
self._metadata.update(summary_metadata)
205194
self._bookmark = summary_metadata.get("bookmark")
206195

207196
# This was the last page received, discard the rest
208197
self._connection.discard(
209198
n=-1,
210199
qid=self._qid,
211-
on_records=on_records,
212200
on_success=on_success,
213201
on_failure=on_failure,
214202
on_summary=on_summary,
215203
)
204+
self._streaming = True
216205

217206
def __iter__(self):
218207
"""Iterator returning Records.
219208
:returns: Record, it is an immutable ordered collection of key-value pairs.
220209
:rtype: :class:`neo4j.Record`
221210
"""
222211
while self._record_buffer or self._attached:
223-
while self._record_buffer:
212+
if self._record_buffer:
224213
yield self._record_buffer.popleft()
225-
226-
while self._attached is True: # _attached is set to False for _pull on_summary and _discard on_summary
227-
self._connection.fetch_message() # Receive at least one message from the server, if available.
228-
if self._attached:
229-
if self._record_buffer:
230-
yield self._record_buffer.popleft()
231-
elif self._discarding and self._streaming is False:
232-
self._discard()
233-
self._connection.send_all()
234-
elif self._has_more and self._streaming is False:
235-
self._pull()
236-
self._connection.send_all()
214+
elif self._streaming:
215+
self._connection.fetch_message()
216+
elif self._discarding:
217+
self._discard()
218+
self._connection.send_all()
219+
elif self._has_more:
220+
self._pull()
221+
self._connection.send_all()
237222

238223
self._closed = True
239224

tests/unit/work/test_result.py

Lines changed: 122 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,20 @@ def __eq__(self, other):
7878
def __repr__(self):
7979
return "Message(%s)" % self.message
8080

81-
def __init__(self, records=None, run_meta=None, summary_meta=None):
82-
self._records = records
81+
def __init__(self, records=None, run_meta=None, summary_meta=None,
82+
force_qid=False):
83+
self._multi_result = isinstance(records, (list, tuple))
84+
if self._multi_result:
85+
self._records = records
86+
self._use_qid = True
87+
else:
88+
self._records = records,
89+
self._use_qid = force_qid
8390
self.fetch_idx = 0
84-
self.record_idx = 0
85-
self.to_pull = None
91+
self._qid = -1
92+
self.record_idxs = [0] * len(self._records)
93+
self.to_pull = [None] * len(self._records)
94+
self._exhausted = [False] * len(self._records)
8695
self.queued = []
8796
self.sent = []
8897
self.run_meta = run_meta
@@ -99,36 +108,54 @@ def fetch_message(self):
99108
msg = self.sent[self.fetch_idx]
100109
if msg == "RUN":
101110
self.fetch_idx += 1
102-
msg.on_success({"fields": self._records.fields,
103-
**(self.run_meta or {})})
111+
self._qid += 1
112+
meta = {"fields": self._records[self._qid].fields,
113+
**(self.run_meta or {})}
114+
if self._use_qid:
115+
meta.update(qid=self._qid)
116+
msg.on_success(meta)
104117
elif msg == "DISCARD":
105118
self.fetch_idx += 1
106-
self.record_idx = len(self._records)
119+
qid = msg.kwargs.get("qid", -1)
120+
if qid < 0:
121+
qid = self._qid
122+
self.record_idxs[qid] = len(self._records[qid])
107123
msg.on_success(self.summary_meta or {})
108124
msg.on_summary()
109125
elif msg == "PULL":
110-
if self.to_pull is None:
126+
qid = msg.kwargs.get("qid", -1)
127+
if qid < 0:
128+
qid = self._qid
129+
if self._exhausted[qid]:
130+
pytest.fail("PULLing exhausted result")
131+
if self.to_pull[qid] is None:
111132
n = msg.kwargs.get("n", -1)
112133
if n < 0:
113-
n = len(self._records)
114-
self.to_pull = min(n, len(self._records) - self.record_idx)
134+
n = len(self._records[qid])
135+
self.to_pull[qid] = \
136+
min(n, len(self._records[qid]) - self.record_idxs[qid])
115137
# if to == len(self._records):
116138
# self.fetch_idx += 1
117-
if self.to_pull > 0:
118-
record = self._records[self.record_idx]
119-
self.record_idx += 1
120-
self.to_pull -= 1
139+
if self.to_pull[qid] > 0:
140+
record = self._records[qid][self.record_idxs[qid]]
141+
self.record_idxs[qid] += 1
142+
self.to_pull[qid] -= 1
121143
msg.on_records([record])
122-
elif self.to_pull == 0:
123-
self.to_pull = None
144+
elif self.to_pull[qid] == 0:
145+
self.to_pull[qid] = None
124146
self.fetch_idx += 1
125-
if self.record_idx < len(self._records):
147+
if self.record_idxs[qid] < len(self._records[qid]):
126148
msg.on_success({"has_more": True})
127149
else:
128150
msg.on_success({"bookmark": "foo",
129151
**(self.summary_meta or {})})
152+
self._exhausted[qid] = True
130153
msg.on_summary()
131154

155+
def fetch_all(self):
156+
while self.fetch_idx < len(self.sent):
157+
self.fetch_message()
158+
132159
def run(self, *args, **kwargs):
133160
self.queued.append(ConnectionStub.Message("RUN", *args, **kwargs))
134161

@@ -153,30 +180,90 @@ def noop(*_, **__):
153180
pass
154181

155182

156-
def test_result_iteration():
157-
records = [[1], [2], [3], [4], [5]]
158-
connection = ConnectionStub(records=Records(["x"], records))
159-
result = Result(connection, HydratorStub(), 2, noop, noop)
160-
result._run("CYPHER", {}, None, "r", None)
161-
received = []
162-
for record in result:
163-
assert isinstance(record, Record)
164-
received.append([record.data().get("x", None)])
165-
assert received == records
183+
def _fetch_and_compare_all_records(result, key, expected_records, method,
184+
limit=None):
185+
received_records = []
186+
if method == "for loop":
187+
for record in result:
188+
assert isinstance(record, Record)
189+
received_records.append([record.data().get(key, None)])
190+
if limit is not None and len(received_records) == limit:
191+
break
192+
elif method == "next":
193+
iter_ = iter(result)
194+
n = len(expected_records) if limit is None else limit
195+
for _ in range(n):
196+
received_records.append([next(iter_).get(key, None)])
197+
if limit is None:
198+
with pytest.raises(StopIteration):
199+
received_records.append([next(iter_).get(key, None)])
200+
elif method == "new iter":
201+
n = len(expected_records) if limit is None else limit
202+
for _ in range(n):
203+
received_records.append([next(iter(result)).get(key, None)])
204+
if limit is None:
205+
with pytest.raises(StopIteration):
206+
received_records.append([next(iter(result)).get(key, None)])
207+
else:
208+
raise ValueError()
209+
assert received_records == expected_records
166210

167211

168-
def test_result_next():
212+
@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
213+
def test_result_iteration(method):
169214
records = [[1], [2], [3], [4], [5]]
170215
connection = ConnectionStub(records=Records(["x"], records))
171216
result = Result(connection, HydratorStub(), 2, noop, noop)
172217
result._run("CYPHER", {}, None, "r", None)
173-
iter_ = iter(result)
174-
received = []
175-
for _ in range(len(records)):
176-
received.append([next(iter_).get("x", None)])
177-
with pytest.raises(StopIteration):
178-
received.append([next(iter_).get("x", None)])
179-
assert received == records
218+
_fetch_and_compare_all_records(result, "x", records, method)
219+
220+
221+
@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
222+
@pytest.mark.parametrize("invert_fetch", (True, False))
223+
def test_parallel_result_iteration(method, invert_fetch):
224+
records1 = [[i] for i in range(1, 6)]
225+
records2 = [[i] for i in range(6, 11)]
226+
connection = ConnectionStub(
227+
records=(Records(["x"], records1), Records(["x"], records2))
228+
)
229+
result1 = Result(connection, HydratorStub(), 2, noop, noop)
230+
result1._run("CYPHER1", {}, None, "r", None)
231+
result2 = Result(connection, HydratorStub(), 2, noop, noop)
232+
result2._run("CYPHER2", {}, None, "r", None)
233+
if invert_fetch:
234+
_fetch_and_compare_all_records(result2, "x", records2, method)
235+
_fetch_and_compare_all_records(result1, "x", records1, method)
236+
else:
237+
_fetch_and_compare_all_records(result1, "x", records1, method)
238+
_fetch_and_compare_all_records(result2, "x", records2, method)
239+
240+
241+
@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
242+
@pytest.mark.parametrize("invert_fetch", (True, False))
243+
def test_interwoven_result_iteration(method, invert_fetch):
244+
records1 = [[i] for i in range(1, 10)]
245+
records2 = [[i] for i in range(11, 20)]
246+
connection = ConnectionStub(
247+
records=(Records(["x"], records1), Records(["y"], records2))
248+
)
249+
result1 = Result(connection, HydratorStub(), 2, noop, noop)
250+
result1._run("CYPHER1", {}, None, "r", None)
251+
result2 = Result(connection, HydratorStub(), 2, noop, noop)
252+
result2._run("CYPHER2", {}, None, "r", None)
253+
start = 0
254+
for n in (1, 2, 3, 1, None):
255+
end = n if n is None else start + n
256+
if invert_fetch:
257+
_fetch_and_compare_all_records(result2, "y", records2[start:end],
258+
method, n)
259+
_fetch_and_compare_all_records(result1, "x", records1[start:end],
260+
method, n)
261+
else:
262+
_fetch_and_compare_all_records(result1, "x", records1[start:end],
263+
method, n)
264+
_fetch_and_compare_all_records(result2, "y", records2[start:end],
265+
method, n)
266+
start = end
180267

181268

182269
@pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))

0 commit comments

Comments
 (0)