Skip to content

Commit d0d60fc

Browse files
committed
Hide traceback of assert helpers
1 parent 952b9c3 commit d0d60fc

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import cmath
22
import math
3+
from functools import wraps
34
from inspect import getfullargspec
4-
from typing import Any, Dict, Optional, Sequence, Tuple, Union
5+
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
56

67
from . import _array_module as xp
78
from . import dtype_helpers as dh
@@ -122,6 +123,7 @@ def assert_dtype(
122123
>>> assert_dtype('sum', in_dtype=x, out_dtype=out.dtype, expected=default_int)
123124
124125
"""
126+
__tracebackhide__ = True
125127
in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype]
126128
f_in_dtypes = dh.fmt_types(tuple(in_dtypes))
127129
f_out_dtype = dh.dtype_to_name[out_dtype]
@@ -149,6 +151,7 @@ def assert_kw_dtype(
149151
>>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype)
150152
151153
"""
154+
__tracebackhide__ = True
152155
f_kw_dtype = dh.dtype_to_name[kw_dtype]
153156
f_out_dtype = dh.dtype_to_name[out_dtype]
154157
msg = (
@@ -166,6 +169,7 @@ def assert_default_float(func_name: str, out_dtype: DataType):
166169
>>> assert_default_float('ones', out.dtype)
167170
168171
"""
172+
__tracebackhide__ = True
169173
f_dtype = dh.dtype_to_name[out_dtype]
170174
f_default = dh.dtype_to_name[dh.default_float]
171175
msg = (
@@ -183,6 +187,7 @@ def assert_default_complex(func_name: str, out_dtype: DataType):
183187
>>> assert_default_complex('asarray', out.dtype)
184188
185189
"""
190+
__tracebackhide__ = True
186191
f_dtype = dh.dtype_to_name[out_dtype]
187192
f_default = dh.dtype_to_name[dh.default_complex]
188193
msg = (
@@ -200,6 +205,7 @@ def assert_default_int(func_name: str, out_dtype: DataType):
200205
>>> assert_default_int('full', out.dtype)
201206
202207
"""
208+
__tracebackhide__ = True
203209
f_dtype = dh.dtype_to_name[out_dtype]
204210
f_default = dh.dtype_to_name[dh.default_int]
205211
msg = (
@@ -217,6 +223,7 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty
217223
>>> assert_default_int('argmax', out.dtype)
218224
219225
"""
226+
__tracebackhide__ = True
220227
f_dtype = dh.dtype_to_name[out_dtype]
221228
msg = (
222229
f"{repr_name}={f_dtype}, should be the default index dtype, "
@@ -240,6 +247,7 @@ def assert_shape(
240247
>>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3))
241248
242249
"""
250+
__tracebackhide__ = True
243251
if isinstance(out_shape, int):
244252
out_shape = (out_shape,)
245253
if isinstance(expected, int):
@@ -273,6 +281,7 @@ def assert_result_shape(
273281
>>> assert out.shape == (3, 3)
274282
275283
"""
284+
__tracebackhide__ = True
276285
if expected is None:
277286
expected = sh.broadcast_shapes(*in_shapes)
278287
f_in_shapes = " . ".join(str(s) for s in in_shapes)
@@ -307,6 +316,7 @@ def assert_keepdimable_shape(
307316
>>> assert out2.shape == (1, 1)
308317
309318
"""
319+
__tracebackhide__ = True
310320
if keepdims:
311321
shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape))
312322
else:
@@ -337,6 +347,7 @@ def assert_0d_equals(
337347
>>> assert res[0] == x[0]
338348
339349
"""
350+
__tracebackhide__ = True
340351
msg = (
341352
f"{out_repr}={out_val}, but should be {x_repr}={x_val} "
342353
f"[{func_name}({fmt_kw(kw)})]"
@@ -369,6 +380,7 @@ def assert_scalar_equals(
369380
>>> assert int(out) == 5
370381
371382
"""
383+
__tracebackhide__ = True
372384
repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
373385
f_func = f"{func_name}({fmt_kw(kw)})"
374386
if type_ in [bool, int]:
@@ -401,6 +413,7 @@ def assert_fill(
401413
>>> assert xp.all(out == 42)
402414
403415
"""
416+
__tracebackhide__ = True
404417
msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}"
405418
if cmath.isnan(fill_value):
406419
assert xp.all(xp.isnan(out)), msg
@@ -443,6 +456,7 @@ def assert_array_elements(
443456
>>> assert xp.all(out == x)
444457
445458
"""
459+
__tracebackhide__ = True
446460
dh.result_type(out.dtype, expected.dtype) # sanity check
447461
assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
448462
f_func = f"[{func_name}({fmt_kw(kw)})]"
@@ -470,3 +484,18 @@ def assert_array_elements(
470484
assert xp.all(
471485
out == expected
472486
), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}"
487+
488+
489+
def _make_wrapped_assert_helper(assert_helper: Callable) -> Callable:
490+
@wraps(assert_helper)
491+
def wrapped_assert_helper(*args, **kwargs):
492+
__tracebackhide__ = True
493+
assert_helper(*args, **kwargs)
494+
495+
return wrapped_assert_helper
496+
497+
498+
for func_name in __all__:
499+
if func_name.startswith("assert"):
500+
assert_helper = globals()[func_name]
501+
globals()[func_name] = _make_wrapped_assert_helper(assert_helper)

0 commit comments

Comments
 (0)