Skip to content

Commit 97d09e3

Browse files
committed
Fix pulling results in parallel
Consuming two results in the same transaction could cause the driver to send too many PULL requests to the server, which led to a FAILURE response.
1 parent 76b399d commit 97d09e3

File tree

5 files changed

+148
-76
lines changed

5 files changed

+148
-76
lines changed

neo4j/io/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def send_all(self):
498498

499499
@abc.abstractmethod
500500
def fetch_message(self):
501-
""" Receive at least one message from the server, if available.
501+
""" Receive at most one message from the server, if available.
502502
503503
:return: 2-tuple of number of detail messages and number of summary
504504
messages fetched

neo4j/io/_bolt3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def fail(metadata):
219219
self._is_reset = True
220220

221221
def fetch_message(self):
222-
""" Receive at least one message from the server, if available.
222+
""" Receive at most one message from the server, if available.
223223
224224
:return: 2-tuple of number of detail messages and number of summary
225225
messages fetched

neo4j/io/_bolt4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def fail(metadata):
231231
self._is_reset = True
232232

233233
def fetch_message(self):
234-
""" Receive at least one message from the server, if available.
234+
""" Receive at most one message from the server, if available.
235235
236236
:return: 2-tuple of number of detail messages and number of summary
237237
messages fetched

neo4j/work/result.py

Lines changed: 23 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,12 @@ def __init__(self, connection, hydrant, fetch_size, on_closed,
8989
# states
9090
self._discarding = False # discard the remainder of records
9191
self._attached = False # attached to a connection
92-
self._streaming = False # there is still more records to buffer upp on the wire
93-
self._has_more = False # there is more records available to pull from the server
94-
self._closed = False # the result have been properly iterated or consumed fully
92+
# there are still more response messages we wait for
93+
self._streaming = False
94+
# there are more records available to pull from the server
95+
self._has_more = False
96+
# the result has been fully iterated or consumed
97+
self._closed = False
9598

9699
def _tx_ready_run(self, query, parameters, **kwparameters):
97100
# BEGIN+RUN does not carry any extra on the RUN message.
@@ -112,11 +115,6 @@ def _run(self, query, parameters, db, access_mode, bookmarks, **kwparameters):
112115
"server": self._connection.server_info,
113116
}
114117

115-
run_metadata = {
116-
"metadata": query_metadata,
117-
"timeout": query_timeout,
118-
}
119-
120118
def on_attached(metadata):
121119
self._metadata.update(metadata)
122120
self._qid = metadata.get("qid", -1) # For auto-commit there is no qid and Bolt 3 do not support qid
@@ -144,9 +142,7 @@ def on_failed_attach(metadata):
144142
self._attach()
145143

146144
def _pull(self):
147-
148145
def on_records(records):
149-
self._streaming = True
150146
if not self._discarding:
151147
self._record_buffer.extend(self._hydrant.hydrate_records(self._keys, records))
152148

@@ -159,14 +155,11 @@ def on_failure(metadata):
159155
self._on_closed()
160156

161157
def on_success(summary_metadata):
158+
self._streaming = False
162159
has_more = summary_metadata.get("has_more")
160+
self._has_more = bool(has_more)
163161
if has_more:
164-
self._has_more = True
165-
self._streaming = False
166162
return
167-
else:
168-
self._has_more = False
169-
170163
self._metadata.update(summary_metadata)
171164
self._bookmark = summary_metadata.get("bookmark")
172165

@@ -178,11 +171,9 @@ def on_success(summary_metadata):
178171
on_failure=on_failure,
179172
on_summary=on_summary,
180173
)
174+
self._streaming = True
181175

182176
def _discard(self):
183-
def on_records(records):
184-
pass
185-
186177
def on_summary():
187178
self._attached = False
188179
self._on_closed()
@@ -193,47 +184,41 @@ def on_failure(metadata):
193184
self._on_closed()
194185

195186
def on_success(summary_metadata):
187+
self._streaming = False
196188
has_more = summary_metadata.get("has_more")
189+
self._has_more = bool(has_more)
197190
if has_more:
198-
self._has_more = True
199-
self._streaming = False
200-
else:
201-
self._has_more = False
202-
self._discarding = False
203-
191+
return
192+
self._discarding = False
204193
self._metadata.update(summary_metadata)
205194
self._bookmark = summary_metadata.get("bookmark")
206195

207196
# This was the last page received, discard the rest
208197
self._connection.discard(
209198
n=-1,
210199
qid=self._qid,
211-
on_records=on_records,
212200
on_success=on_success,
213201
on_failure=on_failure,
214202
on_summary=on_summary,
215203
)
204+
self._streaming = True
216205

217206
def __iter__(self):
218207
"""Iterator returning Records.
219208
:returns: Record, it is an immutable ordered collection of key-value pairs.
220209
:rtype: :class:`neo4j.Record`
221210
"""
222211
while self._record_buffer or self._attached:
223-
while self._record_buffer:
212+
if self._record_buffer:
224213
yield self._record_buffer.popleft()
225-
226-
while self._attached is True: # _attached is set to False for _pull on_summary and _discard on_summary
227-
self._connection.fetch_message() # Receive at least one message from the server, if available.
228-
if self._attached:
229-
if self._record_buffer:
230-
yield self._record_buffer.popleft()
231-
elif self._discarding and self._streaming is False:
232-
self._discard()
233-
self._connection.send_all()
234-
elif self._has_more and self._streaming is False:
235-
self._pull()
236-
self._connection.send_all()
214+
elif self._streaming:
215+
self._connection.fetch_message()
216+
elif self._discarding:
217+
self._discard()
218+
self._connection.send_all()
219+
elif self._has_more:
220+
self._pull()
221+
self._connection.send_all()
237222

238223
self._closed = True
239224

tests/unit/work/test_result.py

Lines changed: 122 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,20 @@ def __eq__(self, other):
7878
def __repr__(self):
7979
return "Message(%s)" % self.message
8080

81-
def __init__(self, records=None, run_meta=None, summary_meta=None):
82-
self._records = records
81+
def __init__(self, records=None, run_meta=None, summary_meta=None,
82+
force_qid=False):
83+
self._multi_result = isinstance(records, (list, tuple))
84+
if self._multi_result:
85+
self._records = records
86+
self._use_qid = True
87+
else:
88+
self._records = records,
89+
self._use_qid = force_qid
8390
self.fetch_idx = 0
84-
self.record_idx = 0
85-
self.to_pull = None
91+
self._qid = -1
92+
self.record_idxs = [0] * len(self._records)
93+
self.to_pull = [None] * len(self._records)
94+
self._exhausted = [False] * len(self._records)
8695
self.queued = []
8796
self.sent = []
8897
self.run_meta = run_meta
@@ -99,36 +108,54 @@ def fetch_message(self):
99108
msg = self.sent[self.fetch_idx]
100109
if msg == "RUN":
101110
self.fetch_idx += 1
102-
msg.on_success({"fields": self._records.fields,
103-
**(self.run_meta or {})})
111+
self._qid += 1
112+
meta = {"fields": self._records[self._qid].fields,
113+
**(self.run_meta or {})}
114+
if self._use_qid:
115+
meta.update(qid=self._qid)
116+
msg.on_success(meta)
104117
elif msg == "DISCARD":
105118
self.fetch_idx += 1
106-
self.record_idx = len(self._records)
119+
qid = msg.kwargs.get("qid", -1)
120+
if qid < 0:
121+
qid = self._qid
122+
self.record_idxs[qid] = len(self._records[qid])
107123
msg.on_success(self.summary_meta or {})
108124
msg.on_summary()
109125
elif msg == "PULL":
110-
if self.to_pull is None:
126+
qid = msg.kwargs.get("qid", -1)
127+
if qid < 0:
128+
qid = self._qid
129+
if self._exhausted[qid]:
130+
pytest.fail("PULLing exhausted result")
131+
if self.to_pull[qid] is None:
111132
n = msg.kwargs.get("n", -1)
112133
if n < 0:
113-
n = len(self._records)
114-
self.to_pull = min(n, len(self._records) - self.record_idx)
134+
n = len(self._records[qid])
135+
self.to_pull[qid] = \
136+
min(n, len(self._records[qid]) - self.record_idxs[qid])
115137
# if to == len(self._records):
116138
# self.fetch_idx += 1
117-
if self.to_pull > 0:
118-
record = self._records[self.record_idx]
119-
self.record_idx += 1
120-
self.to_pull -= 1
139+
if self.to_pull[qid] > 0:
140+
record = self._records[qid][self.record_idxs[qid]]
141+
self.record_idxs[qid] += 1
142+
self.to_pull[qid] -= 1
121143
msg.on_records([record])
122-
elif self.to_pull == 0:
123-
self.to_pull = None
144+
elif self.to_pull[qid] == 0:
145+
self.to_pull[qid] = None
124146
self.fetch_idx += 1
125-
if self.record_idx < len(self._records):
147+
if self.record_idxs[qid] < len(self._records[qid]):
126148
msg.on_success({"has_more": True})
127149
else:
128150
msg.on_success({"bookmark": "foo",
129151
**(self.summary_meta or {})})
152+
self._exhausted[qid] = True
130153
msg.on_summary()
131154

155+
def fetch_all(self):
156+
while self.fetch_idx < len(self.sent):
157+
self.fetch_message()
158+
132159
def run(self, *args, **kwargs):
133160
self.queued.append(ConnectionStub.Message("RUN", *args, **kwargs))
134161

@@ -153,30 +180,90 @@ def noop(*_, **__):
153180
pass
154181

155182

156-
def test_result_iteration():
157-
records = [[1], [2], [3], [4], [5]]
158-
connection = ConnectionStub(records=Records(["x"], records))
159-
result = Result(connection, HydratorStub(), 2, noop, noop)
160-
result._run("CYPHER", {}, None, "r", None)
161-
received = []
162-
for record in result:
163-
assert isinstance(record, Record)
164-
received.append([record.data().get("x", None)])
165-
assert received == records
183+
def _fetch_and_compare_all_records(result, key, expected_records, method,
184+
limit=None):
185+
received_records = []
186+
if method == "for loop":
187+
for record in result:
188+
assert isinstance(record, Record)
189+
received_records.append([record.data().get(key, None)])
190+
if limit is not None and len(received_records) == limit:
191+
break
192+
elif method == "next":
193+
iter_ = iter(result)
194+
n = len(expected_records) if limit is None else limit
195+
for _ in range(n):
196+
received_records.append([next(iter_).get(key, None)])
197+
if limit is None:
198+
with pytest.raises(StopIteration):
199+
received_records.append([next(iter_).get(key, None)])
200+
elif method == "new iter":
201+
n = len(expected_records) if limit is None else limit
202+
for _ in range(n):
203+
received_records.append([next(iter(result)).get(key, None)])
204+
if limit is None:
205+
with pytest.raises(StopIteration):
206+
received_records.append([next(iter(result)).get(key, None)])
207+
else:
208+
raise ValueError()
209+
assert received_records == expected_records
166210

167211

168-
def test_result_next():
212+
@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
213+
def test_result_iteration(method):
169214
records = [[1], [2], [3], [4], [5]]
170215
connection = ConnectionStub(records=Records(["x"], records))
171216
result = Result(connection, HydratorStub(), 2, noop, noop)
172217
result._run("CYPHER", {}, None, "r", None)
173-
iter_ = iter(result)
174-
received = []
175-
for _ in range(len(records)):
176-
received.append([next(iter_).get("x", None)])
177-
with pytest.raises(StopIteration):
178-
received.append([next(iter_).get("x", None)])
179-
assert received == records
218+
_fetch_and_compare_all_records(result, "x", records, method)
219+
220+
221+
@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
222+
@pytest.mark.parametrize("invert_fetch", (True, False))
223+
def test_parallel_result_iteration(method, invert_fetch):
224+
records1 = [[i] for i in range(1, 6)]
225+
records2 = [[i] for i in range(6, 11)]
226+
connection = ConnectionStub(
227+
records=(Records(["x"], records1), Records(["x"], records2))
228+
)
229+
result1 = Result(connection, HydratorStub(), 2, noop, noop)
230+
result1._run("CYPHER1", {}, None, "r", None)
231+
result2 = Result(connection, HydratorStub(), 2, noop, noop)
232+
result2._run("CYPHER2", {}, None, "r", None)
233+
if invert_fetch:
234+
_fetch_and_compare_all_records(result2, "x", records2, method)
235+
_fetch_and_compare_all_records(result1, "x", records1, method)
236+
else:
237+
_fetch_and_compare_all_records(result1, "x", records1, method)
238+
_fetch_and_compare_all_records(result2, "x", records2, method)
239+
240+
241+
@pytest.mark.parametrize("method", ("for loop", "next", "new iter"))
242+
@pytest.mark.parametrize("invert_fetch", (True, False))
243+
def test_interwoven_result_iteration(method, invert_fetch):
244+
records1 = [[i] for i in range(1, 10)]
245+
records2 = [[i] for i in range(11, 20)]
246+
connection = ConnectionStub(
247+
records=(Records(["x"], records1), Records(["y"], records2))
248+
)
249+
result1 = Result(connection, HydratorStub(), 2, noop, noop)
250+
result1._run("CYPHER1", {}, None, "r", None)
251+
result2 = Result(connection, HydratorStub(), 2, noop, noop)
252+
result2._run("CYPHER2", {}, None, "r", None)
253+
start = 0
254+
for n in (1, 2, 3, 1, None):
255+
end = n if n is None else start + n
256+
if invert_fetch:
257+
_fetch_and_compare_all_records(result2, "y", records2[start:end],
258+
method, n)
259+
_fetch_and_compare_all_records(result1, "x", records1[start:end],
260+
method, n)
261+
else:
262+
_fetch_and_compare_all_records(result1, "x", records1[start:end],
263+
method, n)
264+
_fetch_and_compare_all_records(result2, "y", records2[start:end],
265+
method, n)
266+
start = end
180267

181268

182269
@pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))

0 commit comments

Comments
 (0)