1
1
import cmath
2
2
import math
3
+ from functools import wraps
3
4
from inspect import getfullargspec
4
- from typing import Any , Dict , Optional , Sequence , Tuple , Union
5
+ from typing import Any , Callable , Dict , Optional , Sequence , Tuple , Union
5
6
6
7
from . import _array_module as xp
7
8
from . import dtype_helpers as dh
@@ -122,6 +123,7 @@ def assert_dtype(
122
123
>>> assert_dtype('sum', in_dtype=x, out_dtype=out.dtype, expected=default_int)
123
124
124
125
"""
126
+ __tracebackhide__ = True
125
127
in_dtypes = in_dtype if isinstance (in_dtype , Sequence ) and not isinstance (in_dtype , str ) else [in_dtype ]
126
128
f_in_dtypes = dh .fmt_types (tuple (in_dtypes ))
127
129
f_out_dtype = dh .dtype_to_name [out_dtype ]
@@ -149,6 +151,7 @@ def assert_kw_dtype(
149
151
>>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype)
150
152
151
153
"""
154
+ __tracebackhide__ = True
152
155
f_kw_dtype = dh .dtype_to_name [kw_dtype ]
153
156
f_out_dtype = dh .dtype_to_name [out_dtype ]
154
157
msg = (
@@ -166,6 +169,7 @@ def assert_default_float(func_name: str, out_dtype: DataType):
166
169
>>> assert_default_float('ones', out.dtype)
167
170
168
171
"""
172
+ __tracebackhide__ = True
169
173
f_dtype = dh .dtype_to_name [out_dtype ]
170
174
f_default = dh .dtype_to_name [dh .default_float ]
171
175
msg = (
@@ -183,6 +187,7 @@ def assert_default_complex(func_name: str, out_dtype: DataType):
183
187
>>> assert_default_complex('asarray', out.dtype)
184
188
185
189
"""
190
+ __tracebackhide__ = True
186
191
f_dtype = dh .dtype_to_name [out_dtype ]
187
192
f_default = dh .dtype_to_name [dh .default_complex ]
188
193
msg = (
@@ -200,6 +205,7 @@ def assert_default_int(func_name: str, out_dtype: DataType):
200
205
>>> assert_default_int('full', out.dtype)
201
206
202
207
"""
208
+ __tracebackhide__ = True
203
209
f_dtype = dh .dtype_to_name [out_dtype ]
204
210
f_default = dh .dtype_to_name [dh .default_int ]
205
211
msg = (
@@ -217,6 +223,7 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty
217
223
>>> assert_default_int('argmax', out.dtype)
218
224
219
225
"""
226
+ __tracebackhide__ = True
220
227
f_dtype = dh .dtype_to_name [out_dtype ]
221
228
msg = (
222
229
f"{ repr_name } ={ f_dtype } , should be the default index dtype, "
@@ -240,6 +247,7 @@ def assert_shape(
240
247
>>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3))
241
248
242
249
"""
250
+ __tracebackhide__ = True
243
251
if isinstance (out_shape , int ):
244
252
out_shape = (out_shape ,)
245
253
if isinstance (expected , int ):
@@ -273,6 +281,7 @@ def assert_result_shape(
273
281
>>> assert out.shape == (3, 3)
274
282
275
283
"""
284
+ __tracebackhide__ = True
276
285
if expected is None :
277
286
expected = sh .broadcast_shapes (* in_shapes )
278
287
f_in_shapes = " . " .join (str (s ) for s in in_shapes )
@@ -307,6 +316,7 @@ def assert_keepdimable_shape(
307
316
>>> assert out2.shape == (1, 1)
308
317
309
318
"""
319
+ __tracebackhide__ = True
310
320
if keepdims :
311
321
shape = tuple (1 if axis in axes else side for axis , side in enumerate (in_shape ))
312
322
else :
@@ -337,6 +347,7 @@ def assert_0d_equals(
337
347
>>> assert res[0] == x[0]
338
348
339
349
"""
350
+ __tracebackhide__ = True
340
351
msg = (
341
352
f"{ out_repr } ={ out_val } , but should be { x_repr } ={ x_val } "
342
353
f"[{ func_name } ({ fmt_kw (kw )} )]"
@@ -369,6 +380,7 @@ def assert_scalar_equals(
369
380
>>> assert int(out) == 5
370
381
371
382
"""
383
+ __tracebackhide__ = True
372
384
repr_name = repr_name if idx == () else f"{ repr_name } [{ idx } ]"
373
385
f_func = f"{ func_name } ({ fmt_kw (kw )} )"
374
386
if type_ in [bool , int ]:
@@ -401,6 +413,7 @@ def assert_fill(
401
413
>>> assert xp.all(out == 42)
402
414
403
415
"""
416
+ __tracebackhide__ = True
404
417
msg = f"out not filled with { fill_value } [{ func_name } ({ fmt_kw (kw )} )]\n { out = } "
405
418
if cmath .isnan (fill_value ):
406
419
assert xp .all (xp .isnan (out )), msg
@@ -443,6 +456,7 @@ def assert_array_elements(
443
456
>>> assert xp.all(out == x)
444
457
445
458
"""
459
+ __tracebackhide__ = True
446
460
dh .result_type (out .dtype , expected .dtype ) # sanity check
447
461
assert_shape (func_name , out_shape = out .shape , expected = expected .shape , kw = kw ) # sanity check
448
462
f_func = f"[{ func_name } ({ fmt_kw (kw )} )]"
@@ -470,3 +484,18 @@ def assert_array_elements(
470
484
assert xp .all (
471
485
out == expected
472
486
), f"{ out_repr } not as expected { f_func } \n { out_repr } ={ out !r} \n { expected = } "
487
+
488
+
489
+ def _make_wrapped_assert_helper (assert_helper : Callable ) -> Callable :
490
+ @wraps (assert_helper )
491
+ def wrapped_assert_helper (* args , ** kwargs ):
492
+ __tracebackhide__ = True
493
+ assert_helper (* args , ** kwargs )
494
+
495
+ return wrapped_assert_helper
496
+
497
+
498
+ for func_name in __all__ :
499
+ if func_name .startswith ("assert" ):
500
+ assert_helper = globals ()[func_name ]
501
+ globals ()[func_name ] = _make_wrapped_assert_helper (assert_helper )
0 commit comments