Skip to content

Commit 3d4bc82

Browse files
committed
Better traceback
1 parent 47ef8b5 commit 3d4bc82

File tree

2 files changed

+27
-14
lines changed

2 files changed

+27
-14
lines changed

src/array_api_extra/testing.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -203,19 +203,35 @@ class CountingDaskScheduler(SchedulerGetCallable):
203203
"""
204204
Dask scheduler that counts how many times `dask.compute` is called.
205205
206+
If the number of times exceeds 'max_count', it raises an error.
206207
This is a wrapper around Dask's own 'synchronous' scheduler.
208+
209+
Parameters
210+
----------
211+
max_count : int
212+
Maximum number of allowed calls to `dask.compute`.
213+
msg : str
214+
Assertion to raise when the count exceeds `max_count`.
207215
"""
208216

209217
count: int
218+
max_count: int
219+
msg: str
210220

211-
def __init__(self): # numpydoc ignore=GL08
221+
def __init__(self, max_count: int, msg: str): # numpydoc ignore=GL08
212222
self.count = 0
223+
self.max_count = max_count
224+
self.msg = msg
213225

214226
@override
215227
def __call__(self, dsk: Graph, keys: Sequence[Key] | Key, **kwargs: Any) -> Any: # type: ignore[no-any-decorated,no-any-explicit] # numpydoc ignore=GL08
216228
import dask
217229

218230
self.count += 1
231+
# This should yield a nice traceback to the
232+
# offending line in the user's code
233+
assert self.count <= self.max_count, self.msg
234+
219235
return dask.get(dsk, keys, **kwargs) # type: ignore[attr-defined,no-untyped-call] # pyright: ignore[reportPrivateImportUsage]
220236

221237

@@ -228,21 +244,18 @@ def _allow_dask_compute(
228244
import dask.config
229245

230246
func_name = getattr(func, "__name__", str(func))
247+
n_str = f"only up to {n}" if n else "no"
248+
msg = (
249+
f"Called `dask.compute()` or `dask.persist()` {n + 1} times, "
250+
f"but {n_str} calls are allowed. Set "
251+
f"`lazy_xp_function({func_name}, allow_dask_compute={n + 1}) "
252+
"to allow for more (but note that this will harm performance). "
253+
)
231254

232255
@wraps(func)
233256
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
234-
scheduler = CountingDaskScheduler()
257+
scheduler = CountingDaskScheduler(n, msg)
235258
with dask.config.set({"scheduler": scheduler}):
236-
out = func(*args, **kwargs)
237-
if scheduler.count > n:
238-
n_str = f"only up to {n}" if n else "no"
239-
msg = (
240-
f"Called `dask.compute()` or `dask.persist()` {scheduler.count} times, "
241-
f"but {n_str} calls are allowed. Set "
242-
f"`lazy_xp_function({func_name}, allow_dask_compute={scheduler.count}) "
243-
"to allow for more (but note that this will harm performance). "
244-
)
245-
raise AssertionError(msg)
246-
return out
259+
return func(*args, **kwargs)
247260

248261
return wrapper

tests/test_testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_lazy_xp_function(xp: ModuleType):
160160
non_materializable3(x)
161161
with pytest.raises(
162162
AssertionError,
163-
match=r"dask\.compute.* 2 times, but no calls are allowed",
163+
match=r"dask\.compute.* 1 times, but no calls are allowed",
164164
):
165165
non_materializable4(x)
166166

0 commit comments

Comments
 (0)