@@ -203,19 +203,35 @@ class CountingDaskScheduler(SchedulerGetCallable):
203
203
"""
204
204
Dask scheduler that counts how many times `dask.compute` is called.
205
205
206
+ If the number of times exceeds 'max_count', it raises an error.
206
207
This is a wrapper around Dask's own 'synchronous' scheduler.
208
+
209
+ Parameters
210
+ ----------
211
+ max_count : int
212
+ Maximum number of allowed calls to `dask.compute`.
213
+ msg : str
214
+ Assertion to raise when the count exceeds `max_count`.
207
215
"""
208
216
209
217
count : int
218
+ max_count : int
219
+ msg : str
210
220
211
- def __init__ (self ): # numpydoc ignore=GL08
221
+ def __init__ (self , max_count : int , msg : str ): # numpydoc ignore=GL08
212
222
self .count = 0
223
+ self .max_count = max_count
224
+ self .msg = msg
213
225
214
226
@override
215
227
def __call__ (self , dsk : Graph , keys : Sequence [Key ] | Key , ** kwargs : Any ) -> Any : # type: ignore[no-any-decorated,no-any-explicit] # numpydoc ignore=GL08
216
228
import dask
217
229
218
230
self .count += 1
231
+ # This should yield a nice traceback to the
232
+ # offending line in the user's code
233
+ assert self .count <= self .max_count , self .msg
234
+
219
235
return dask .get (dsk , keys , ** kwargs ) # type: ignore[attr-defined,no-untyped-call] # pyright: ignore[reportPrivateImportUsage]
220
236
221
237
@@ -228,21 +244,18 @@ def _allow_dask_compute(
228
244
import dask .config
229
245
230
246
func_name = getattr (func , "__name__" , str (func ))
247
+ n_str = f"only up to { n } " if n else "no"
248
+ msg = (
249
+ f"Called `dask.compute()` or `dask.persist()` { n + 1 } times, "
250
+ f"but { n_str } calls are allowed. Set "
251
+ f"`lazy_xp_function({ func_name } , allow_dask_compute={ n + 1 } ) "
252
+ "to allow for more (but note that this will harm performance). "
253
+ )
231
254
232
255
@wraps (func )
233
256
def wrapper (* args : P .args , ** kwargs : P .kwargs ) -> T : # numpydoc ignore=GL08
234
- scheduler = CountingDaskScheduler ()
257
+ scheduler = CountingDaskScheduler (n , msg )
235
258
with dask .config .set ({"scheduler" : scheduler }):
236
- out = func (* args , ** kwargs )
237
- if scheduler .count > n :
238
- n_str = f"only up to { n } " if n else "no"
239
- msg = (
240
- f"Called `dask.compute()` or `dask.persist()` { scheduler .count } times, "
241
- f"but { n_str } calls are allowed. Set "
242
- f"`lazy_xp_function({ func_name } , allow_dask_compute={ scheduler .count } ) "
243
- "to allow for more (but note that this will harm performance). "
244
- )
245
- raise AssertionError (msg )
246
- return out
259
+ return func (* args , ** kwargs )
247
260
248
261
return wrapper
0 commit comments