@@ -78,11 +78,20 @@ def __eq__(self, other):
78
78
def __repr__ (self ):
79
79
return "Message(%s)" % self .message
80
80
81
- def __init__ (self , records = None , run_meta = None , summary_meta = None ):
82
- self ._records = records
81
+ def __init__ (self , records = None , run_meta = None , summary_meta = None ,
82
+ force_qid = False ):
83
+ self ._multi_result = isinstance (records , (list , tuple ))
84
+ if self ._multi_result :
85
+ self ._records = records
86
+ self ._use_qid = True
87
+ else :
88
+ self ._records = records ,
89
+ self ._use_qid = force_qid
83
90
self .fetch_idx = 0
84
- self .record_idx = 0
85
- self .to_pull = None
91
+ self ._qid = - 1
92
+ self .record_idxs = [0 ] * len (self ._records )
93
+ self .to_pull = [None ] * len (self ._records )
94
+ self ._exhausted = [False ] * len (self ._records )
86
95
self .queued = []
87
96
self .sent = []
88
97
self .run_meta = run_meta
@@ -99,36 +108,54 @@ def fetch_message(self):
99
108
msg = self .sent [self .fetch_idx ]
100
109
if msg == "RUN" :
101
110
self .fetch_idx += 1
102
- msg .on_success ({"fields" : self ._records .fields ,
103
- ** (self .run_meta or {})})
111
+ self ._qid += 1
112
+ meta = {"fields" : self ._records [self ._qid ].fields ,
113
+ ** (self .run_meta or {})}
114
+ if self ._use_qid :
115
+ meta .update (qid = self ._qid )
116
+ msg .on_success (meta )
104
117
elif msg == "DISCARD" :
105
118
self .fetch_idx += 1
106
- self .record_idx = len (self ._records )
119
+ qid = msg .kwargs .get ("qid" , - 1 )
120
+ if qid < 0 :
121
+ qid = self ._qid
122
+ self .record_idxs [qid ] = len (self ._records [qid ])
107
123
msg .on_success (self .summary_meta or {})
108
124
msg .on_summary ()
109
125
elif msg == "PULL" :
110
- if self .to_pull is None :
126
+ qid = msg .kwargs .get ("qid" , - 1 )
127
+ if qid < 0 :
128
+ qid = self ._qid
129
+ if self ._exhausted [qid ]:
130
+ pytest .fail ("PULLing exhausted result" )
131
+ if self .to_pull [qid ] is None :
111
132
n = msg .kwargs .get ("n" , - 1 )
112
133
if n < 0 :
113
- n = len (self ._records )
114
- self .to_pull = min (n , len (self ._records ) - self .record_idx )
134
+ n = len (self ._records [qid ])
135
+ self .to_pull [qid ] = \
136
+ min (n , len (self ._records [qid ]) - self .record_idxs [qid ])
115
137
# if to == len(self._records):
116
138
# self.fetch_idx += 1
117
- if self .to_pull > 0 :
118
- record = self ._records [self .record_idx ]
119
- self .record_idx += 1
120
- self .to_pull -= 1
139
+ if self .to_pull [ qid ] > 0 :
140
+ record = self ._records [qid ][ self .record_idxs [ qid ] ]
141
+ self .record_idxs [ qid ] += 1
142
+ self .to_pull [ qid ] -= 1
121
143
msg .on_records ([record ])
122
- elif self .to_pull == 0 :
123
- self .to_pull = None
144
+ elif self .to_pull [ qid ] == 0 :
145
+ self .to_pull [ qid ] = None
124
146
self .fetch_idx += 1
125
- if self .record_idx < len (self ._records ):
147
+ if self .record_idxs [ qid ] < len (self ._records [ qid ] ):
126
148
msg .on_success ({"has_more" : True })
127
149
else :
128
150
msg .on_success ({"bookmark" : "foo" ,
129
151
** (self .summary_meta or {})})
152
+ self ._exhausted [qid ] = True
130
153
msg .on_summary ()
131
154
155
+ def fetch_all (self ):
156
+ while self .fetch_idx < len (self .sent ):
157
+ self .fetch_message ()
158
+
132
159
def run (self , * args , ** kwargs ):
133
160
self .queued .append (ConnectionStub .Message ("RUN" , * args , ** kwargs ))
134
161
@@ -153,30 +180,90 @@ def noop(*_, **__):
153
180
pass
154
181
155
182
156
- def test_result_iteration ():
157
- records = [[1 ], [2 ], [3 ], [4 ], [5 ]]
158
- connection = ConnectionStub (records = Records (["x" ], records ))
159
- result = Result (connection , HydratorStub (), 2 , noop , noop )
160
- result ._run ("CYPHER" , {}, None , "r" , None )
161
- received = []
162
- for record in result :
163
- assert isinstance (record , Record )
164
- received .append ([record .data ().get ("x" , None )])
165
- assert received == records
183
+ def _fetch_and_compare_all_records (result , key , expected_records , method ,
184
+ limit = None ):
185
+ received_records = []
186
+ if method == "for loop" :
187
+ for record in result :
188
+ assert isinstance (record , Record )
189
+ received_records .append ([record .data ().get (key , None )])
190
+ if limit is not None and len (received_records ) == limit :
191
+ break
192
+ elif method == "next" :
193
+ iter_ = iter (result )
194
+ n = len (expected_records ) if limit is None else limit
195
+ for _ in range (n ):
196
+ received_records .append ([next (iter_ ).get (key , None )])
197
+ if limit is None :
198
+ with pytest .raises (StopIteration ):
199
+ received_records .append ([next (iter_ ).get (key , None )])
200
+ elif method == "new iter" :
201
+ n = len (expected_records ) if limit is None else limit
202
+ for _ in range (n ):
203
+ received_records .append ([next (iter (result )).get (key , None )])
204
+ if limit is None :
205
+ with pytest .raises (StopIteration ):
206
+ received_records .append ([next (iter (result )).get (key , None )])
207
+ else :
208
+ raise ValueError ()
209
+ assert received_records == expected_records
166
210
167
211
168
- def test_result_next ():
212
+ @pytest .mark .parametrize ("method" , ("for loop" , "next" , "new iter" ))
213
+ def test_result_iteration (method ):
169
214
records = [[1 ], [2 ], [3 ], [4 ], [5 ]]
170
215
connection = ConnectionStub (records = Records (["x" ], records ))
171
216
result = Result (connection , HydratorStub (), 2 , noop , noop )
172
217
result ._run ("CYPHER" , {}, None , "r" , None )
173
- iter_ = iter (result )
174
- received = []
175
- for _ in range (len (records )):
176
- received .append ([next (iter_ ).get ("x" , None )])
177
- with pytest .raises (StopIteration ):
178
- received .append ([next (iter_ ).get ("x" , None )])
179
- assert received == records
218
+ _fetch_and_compare_all_records (result , "x" , records , method )
219
+
220
+
221
+ @pytest .mark .parametrize ("method" , ("for loop" , "next" , "new iter" ))
222
+ @pytest .mark .parametrize ("invert_fetch" , (True , False ))
223
+ def test_parallel_result_iteration (method , invert_fetch ):
224
+ records1 = [[i ] for i in range (1 , 6 )]
225
+ records2 = [[i ] for i in range (6 , 11 )]
226
+ connection = ConnectionStub (
227
+ records = (Records (["x" ], records1 ), Records (["x" ], records2 ))
228
+ )
229
+ result1 = Result (connection , HydratorStub (), 2 , noop , noop )
230
+ result1 ._run ("CYPHER1" , {}, None , "r" , None )
231
+ result2 = Result (connection , HydratorStub (), 2 , noop , noop )
232
+ result2 ._run ("CYPHER2" , {}, None , "r" , None )
233
+ if invert_fetch :
234
+ _fetch_and_compare_all_records (result2 , "x" , records2 , method )
235
+ _fetch_and_compare_all_records (result1 , "x" , records1 , method )
236
+ else :
237
+ _fetch_and_compare_all_records (result1 , "x" , records1 , method )
238
+ _fetch_and_compare_all_records (result2 , "x" , records2 , method )
239
+
240
+
241
+ @pytest .mark .parametrize ("method" , ("for loop" , "next" , "new iter" ))
242
+ @pytest .mark .parametrize ("invert_fetch" , (True , False ))
243
+ def test_interwoven_result_iteration (method , invert_fetch ):
244
+ records1 = [[i ] for i in range (1 , 10 )]
245
+ records2 = [[i ] for i in range (11 , 20 )]
246
+ connection = ConnectionStub (
247
+ records = (Records (["x" ], records1 ), Records (["y" ], records2 ))
248
+ )
249
+ result1 = Result (connection , HydratorStub (), 2 , noop , noop )
250
+ result1 ._run ("CYPHER1" , {}, None , "r" , None )
251
+ result2 = Result (connection , HydratorStub (), 2 , noop , noop )
252
+ result2 ._run ("CYPHER2" , {}, None , "r" , None )
253
+ start = 0
254
+ for n in (1 , 2 , 3 , 1 , None ):
255
+ end = n if n is None else start + n
256
+ if invert_fetch :
257
+ _fetch_and_compare_all_records (result2 , "y" , records2 [start :end ],
258
+ method , n )
259
+ _fetch_and_compare_all_records (result1 , "x" , records1 [start :end ],
260
+ method , n )
261
+ else :
262
+ _fetch_and_compare_all_records (result1 , "x" , records1 [start :end ],
263
+ method , n )
264
+ _fetch_and_compare_all_records (result2 , "y" , records2 [start :end ],
265
+ method , n )
266
+ start = end
180
267
181
268
182
269
@pytest .mark .parametrize ("records" , ([[1 ], [2 ]], [[1 ]], []))
0 commit comments