Skip to content

Commit 5da1cd6

Browse files
committed
Fix pulling results in parallel
Consuming two results in the same TX could cause the driver sending too many PULL request to the server which led to FAILURE
1 parent 76b399d commit 5da1cd6

File tree

5 files changed

+149
-74
lines changed

5 files changed

+149
-74
lines changed

neo4j/io/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def send_all(self):
498498

499499
@abc.abstractmethod
500500
def fetch_message(self):
501-
""" Receive at least one message from the server, if available.
501+
""" Receive at most one message from the server, if available.
502502
503503
:return: 2-tuple of number of detail messages and number of summary
504504
messages fetched

neo4j/io/_bolt3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def fail(metadata):
219219
self._is_reset = True
220220

221221
def fetch_message(self):
222-
""" Receive at least one message from the server, if available.
222+
""" Receive at most one message from the server, if available.
223223
224224
:return: 2-tuple of number of detail messages and number of summary
225225
messages fetched

neo4j/io/_bolt4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def fail(metadata):
231231
self._is_reset = True
232232

233233
def fetch_message(self):
234-
""" Receive at least one message from the server, if available.
234+
""" Receive at most one message from the server, if available.
235235
236236
:return: 2-tuple of number of detail messages and number of summary
237237
messages fetched

neo4j/work/result.py

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,12 @@ def __init__(self, connection, hydrant, fetch_size, on_closed,
8989
# states
9090
self._discarding = False # discard the remainder of records
9191
self._attached = False # attached to a connection
92-
self._streaming = False # there is still more records to buffer upp on the wire
93-
self._has_more = False # there is more records available to pull from the server
94-
self._closed = False # the result have been properly iterated or consumed fully
92+
# there are still more response messages we wait for
93+
self._streaming = False
94+
# there ar more records available to pull from the server
95+
self._has_more = False
96+
# the result has been fully iterated or consumed
97+
self._closed = False
9598

9699
def _tx_ready_run(self, query, parameters, **kwparameters):
97100
# BEGIN+RUN does not carry any extra on the RUN message.
@@ -112,11 +115,6 @@ def _run(self, query, parameters, db, access_mode, bookmarks, **kwparameters):
112115
"server": self._connection.server_info,
113116
}
114117

115-
run_metadata = {
116-
"metadata": query_metadata,
117-
"timeout": query_timeout,
118-
}
119-
120118
def on_attached(metadata):
121119
self._metadata.update(metadata)
122120
self._qid = metadata.get("qid", -1) # For auto-commit there is no qid and Bolt 3 do not support qid
@@ -144,9 +142,7 @@ def on_failed_attach(metadata):
144142
self._attach()
145143

146144
def _pull(self):
147-
148145
def on_records(records):
149-
self._streaming = True
150146
if not self._discarding:
151147
self._record_buffer.extend(self._hydrant.hydrate_records(self._keys, records))
152148

@@ -159,13 +155,11 @@ def on_failure(metadata):
159155
self._on_closed()
160156

161157
def on_success(summary_metadata):
158+
self._streaming = False
162159
has_more = summary_metadata.get("has_more")
160+
self._has_more = bool(has_more)
163161
if has_more:
164-
self._has_more = True
165-
self._streaming = False
166162
return
167-
else:
168-
self._has_more = False
169163

170164
self._metadata.update(summary_metadata)
171165
self._bookmark = summary_metadata.get("bookmark")
@@ -178,11 +172,9 @@ def on_success(summary_metadata):
178172
on_failure=on_failure,
179173
on_summary=on_summary,
180174
)
175+
self._streaming = True
181176

182177
def _discard(self):
183-
def on_records(records):
184-
pass
185-
186178
def on_summary():
187179
self._attached = False
188180
self._on_closed()
@@ -193,13 +185,13 @@ def on_failure(metadata):
193185
self._on_closed()
194186

195187
def on_success(summary_metadata):
188+
self._streaming = False
196189
has_more = summary_metadata.get("has_more")
190+
self._has_more = bool(has_more)
197191
if has_more:
198-
self._has_more = True
199-
self._streaming = False
200-
else:
201-
self._has_more = False
202-
self._discarding = False
192+
return
193+
self._discarding = False
194+
self._discarding = False
203195

204196
self._metadata.update(summary_metadata)
205197
self._bookmark = summary_metadata.get("bookmark")
@@ -208,32 +200,28 @@ def on_success(summary_metadata):
208200
self._connection.discard(
209201
n=-1,
210202
qid=self._qid,
211-
on_records=on_records,
212203
on_success=on_success,
213204
on_failure=on_failure,
214205
on_summary=on_summary,
215206
)
207+
self._streaming = True
216208

217209
def __iter__(self):
218210
"""Iterator returning Records.
219211
:returns: Record, it is an immutable ordered collection of key-value pairs.
220212
:rtype: :class:`neo4j.Record`
221213
"""
222214
while self._record_buffer or self._attached:
223-
while self._record_buffer:
215+
if self._record_buffer:
224216
yield self._record_buffer.popleft()
225-
226-
while self._attached is True: # _attached is set to False for _pull on_summary and _discard on_summary
227-
self._connection.fetch_message() # Receive at least one message from the server, if available.
228-
if self._attached:
229-
if self._record_buffer:
230-
yield self._record_buffer.popleft()
231-
elif self._discarding and self._streaming is False:
232-
self._discard()
233-
self._connection.send_all()
234-
elif self._has_more and self._streaming is False:
235-
self._pull()
236-
self._connection.send_all()
217+
elif self._streaming:
218+
self._connection.fetch_message()
219+
elif self._discarding:
220+
self._discard()
221+
self._connection.send_all()
222+
elif self._has_more:
223+
self._pull()
224+
self._connection.send_all()
237225

238226
self._closed = True
239227

tests/unit/work/test_result.py

Lines changed: 122 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,20 @@ def __eq__(self, other):
7878
def __repr__(self):
7979
return "Message(%s)" % self.message
8080

81-
def __init__(self, records=None, run_meta=None, summary_meta=None):
82-
self._records = records
81+
def __init__(self, records=None, run_meta=None, summary_meta=None,
82+
force_qid=False):
83+
self._multi_result = isinstance(records, (list, tuple))
84+
if self._multi_result:
85+
self._records = records
86+
self._use_qid = True
87+
else:
88+
self._records = records,
89+
self._use_qid = force_qid
8390
self.fetch_idx = 0
84-
self.record_idx = 0
85-
self.to_pull = None
91+
self._qid = -1
92+
self.record_idxs = [0] * len(self._records)
93+
self.to_pull = [None] * len(self._records)
94+
self._exhausted = [False] * len(self._records)
8695
self.queued = []
8796
self.sent = []
8897
self.run_meta = run_meta
@@ -99,36 +108,54 @@ def fetch_message(self):
99108
msg = self.sent[self.fetch_idx]
100109
if msg == "RUN":
101110
self.fetch_idx += 1
102-
msg.on_success({"fields": self._records.fields,
103-
**(self.run_meta or {})})
111+
self._qid += 1
112+
meta = {"fields": self._records[self._qid].fields,
113+
**(self.run_meta or {})}
114+
if self._use_qid:
115+
meta.update(qid=self._qid)
116+
msg.on_success(meta)
104117
elif msg == "DISCARD":
105118
self.fetch_idx += 1
106-
self.record_idx = len(self._records)
119+
qid = msg.kwargs.get("qid", -1)
120+
if qid < 0:
121+
qid = self._qid
122+
self.record_idxs[qid] = len(self._records[qid])
107123
msg.on_success(self.summary_meta or {})
108124
msg.on_summary()
109125
elif msg == "PULL":
110-
if self.to_pull is None:
126+
qid = msg.kwargs.get("qid", -1)
127+
if qid < 0:
128+
qid = self._qid
129+
if self._exhausted[qid]:
130+
pytest.fail("PULLing exhausted result")
131+
if self.to_pull[qid] is None:
111132
n = msg.kwargs.get("n", -1)
112133
if n < 0:
113-
n = len(self._records)
114-
self.to_pull = min(n, len(self._records) - self.record_idx)
134+
n = len(self._records[qid])
135+
self.to_pull[qid] = \
136+
min(n, len(self._records[qid]) - self.record_idxs[qid])
115137
# if to == len(self._records):
116138
# self.fetch_idx += 1
117-
if self.to_pull > 0:
118-
record = self._records[self.record_idx]
119-
self.record_idx += 1
120-
self.to_pull -= 1
139+
if self.to_pull[qid] > 0:
140+
record = self._records[qid][self.record_idxs[qid]]
141+
self.record_idxs[qid] += 1
142+
self.to_pull[qid] -= 1
121143
msg.on_records([record])
122-
elif self.to_pull == 0:
123-
self.to_pull = None
144+
elif self.to_pull[qid] == 0:
145+
self.to_pull[qid] = None
124146
self.fetch_idx += 1
125-
if self.record_idx < len(self._records):
147+
if self.record_idxs[qid] < len(self._records[qid]):
126148
msg.on_success({"has_more": True})
127149
else:
128150
msg.on_success({"bookmark": "foo",
129151
**(self.summary_meta or {})})
152+
self._exhausted[qid] = True
130153
msg.on_summary()
131154

155+
def fetch_all(self):
156+
while self.fetch_idx < len(self.sent):
157+
self.fetch_message()
158+
132159
def run(self, *args, **kwargs):
133160
self.queued.append(ConnectionStub.Message("RUN", *args, **kwargs))
134161

@@ -153,30 +180,90 @@ def noop(*_, **__):
153180
pass
154181

155182

156-
def test_result_iteration():
157-
records = [[1], [2], [3], [4], [5]]
158-
connection = ConnectionStub(records=Records(["x"], records))
159-
result = Result(connection, HydratorStub(), 2, noop, noop)
160-
result._run("CYPHER", {}, None, "r", None)
161-
received = []
162-
for record in result:
163-
assert isinstance(record, Record)
164-
received.append([record.data().get("x", None)])
165-
assert received == records
183+
def _fetch_and_compare_all_records(result, key, expected_records, method,
184+
limit=None):
185+
received_records = []
186+
if method == "for loop":
187+
for record in result:
188+
assert isinstance(record, Record)
189+
received_records.append([record.data().get(key, None)])
190+
if limit is not None and len(received_records) == limit:
191+
break
192+
elif method == "next":
193+
iter_ = iter(result)
194+
n = len(expected_records) if limit is None else limit
195+
for _ in range(n):
196+
received_records.append([next(iter_).get(key, None)])
197+
if limit is None:
198+
with pytest.raises(StopIteration):
199+
received_records.append([next(iter_).get(key, None)])
200+
elif method == "new iter":
201+
n = len(expected_records) if limit is None else limit
202+
for _ in range(n):
203+
received_records.append([next(iter(result)).get(key, None)])
204+
if limit is None:
205+
with pytest.raises(StopIteration):
206+
received_records.append([next(iter(result)).get(key, None)])
207+
else:
208+
raise ValueError()
209+
assert received_records == expected_records
166210

167211

168-
def test_result_next():
212+
@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
213+
def test_result_iteration(method):
169214
records = [[1], [2], [3], [4], [5]]
170215
connection = ConnectionStub(records=Records(["x"], records))
171216
result = Result(connection, HydratorStub(), 2, noop, noop)
172217
result._run("CYPHER", {}, None, "r", None)
173-
iter_ = iter(result)
174-
received = []
175-
for _ in range(len(records)):
176-
received.append([next(iter_).get("x", None)])
177-
with pytest.raises(StopIteration):
178-
received.append([next(iter_).get("x", None)])
179-
assert received == records
218+
_fetch_and_compare_all_records(result, "x", records, method)
219+
220+
221+
@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
222+
@pytest.mark.parametrize("invert_fetch", (True, False))
223+
def test_parallel_result_iteration(method, invert_fetch):
224+
records1 = [[i] for i in range(1, 6)]
225+
records2 = [[i] for i in range(6, 11)]
226+
connection = ConnectionStub(
227+
records=(Records(["x"], records1), Records(["x"], records2))
228+
)
229+
result1 = Result(connection, HydratorStub(), 2, noop, noop)
230+
result1._run("CYPHER1", {}, None, "r", None)
231+
result2 = Result(connection, HydratorStub(), 2, noop, noop)
232+
result2._run("CYPHER2", {}, None, "r", None)
233+
if invert_fetch:
234+
_fetch_and_compare_all_records(result2, "x", records2, method)
235+
_fetch_and_compare_all_records(result1, "x", records1, method)
236+
else:
237+
_fetch_and_compare_all_records(result1, "x", records1, method)
238+
_fetch_and_compare_all_records(result2, "x", records2, method)
239+
240+
241+
@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
242+
@pytest.mark.parametrize("invert_fetch", (True, False))
243+
def test_interwoven_result_iteration(method, invert_fetch):
244+
records1 = [[i] for i in range(1, 10)]
245+
records2 = [[i] for i in range(11, 20)]
246+
connection = ConnectionStub(
247+
records=(Records(["x"], records1), Records(["y"], records2))
248+
)
249+
result1 = Result(connection, HydratorStub(), 2, noop, noop)
250+
result1._run("CYPHER1", {}, None, "r", None)
251+
result2 = Result(connection, HydratorStub(), 2, noop, noop)
252+
result2._run("CYPHER2", {}, None, "r", None)
253+
start = 0
254+
for n in (1, 2, 3, 1, None):
255+
end = n if n is None else start + n
256+
if invert_fetch:
257+
_fetch_and_compare_all_records(result2, "y", records2[start:end],
258+
method, n)
259+
_fetch_and_compare_all_records(result1, "x", records1[start:end],
260+
method, n)
261+
else:
262+
_fetch_and_compare_all_records(result1, "x", records1[start:end],
263+
method, n)
264+
_fetch_and_compare_all_records(result2, "y", records2[start:end],
265+
method, n)
266+
start = end
180267

181268

182269
@pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))

0 commit comments

Comments
 (0)