Skip to content

Commit b47039f

Browse files
committed
Add helpful error messages to assert_raises calls in test_array_object.py
1 parent 5379bd5 commit b47039f

File tree

1 file changed

+32
-33
lines changed

1 file changed

+32
-33
lines changed

array_api_strict/tests/test_array_object.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import operator
22
from builtins import all as all_
33

4-
from numpy.testing import assert_raises, suppress_warnings
4+
import numpy.testing
55
import numpy as np
66
import pytest
77

@@ -29,6 +29,10 @@
2929

3030
import array_api_strict
3131

32+
def assert_raises(exception, func, msg=None):
33+
with numpy.testing.assert_raises(exception, msg=msg):
34+
func()
35+
3236
def test_validate_index():
3337
# The indexing tests in the official array API test suite test that the
3438
# array object correctly handles the subset of indices that are required
@@ -111,6 +115,7 @@ def test_operators():
111115
"__truediv__": "floating",
112116
"__xor__": "integer_or_boolean",
113117
}
118+
comparison_ops = ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]
114119
# Recompute each time because of in-place ops
115120
def _array_vals():
116121
for d in _integer_dtypes:
@@ -124,7 +129,7 @@ def _array_vals():
124129
BIG_INT = int(1e30)
125130
for op, dtypes in binary_op_dtypes.items():
126131
ops = [op]
127-
if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]:
132+
if op not in comparison_ops:
128133
rop = "__r" + op[2:]
129134
iop = "__i" + op[2:]
130135
ops += [rop, iop]
@@ -155,16 +160,16 @@ def _array_vals():
155160
or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int]
156161
)):
157162
if a.dtype in _integer_dtypes and s == BIG_INT:
158-
assert_raises(OverflowError, lambda: getattr(a, _op)(s))
163+
assert_raises(OverflowError, lambda: getattr(a, _op)(s), _op)
159164
else:
160165
# Only test for no error
161-
with suppress_warnings() as sup:
166+
with numpy.testing.suppress_warnings() as sup:
162167
# ignore warnings from pow(BIG_INT)
163168
sup.filter(RuntimeWarning,
164169
"invalid value encountered in power")
165170
getattr(a, _op)(s)
166171
else:
167-
assert_raises(TypeError, lambda: getattr(a, _op)(s))
172+
assert_raises(TypeError, lambda: getattr(a, _op)(s), _op)
168173

169174
# Test array op array.
170175
for _op in ops:
@@ -188,7 +193,7 @@ def _array_vals():
188193
_op.startswith("__i")
189194
and result_type(x.dtype, y.dtype) != x.dtype
190195
):
191-
assert_raises(TypeError, lambda: getattr(x, _op)(y))
196+
assert_raises(TypeError, lambda: getattr(x, _op)(y), _op)
192197
# Ensure only those dtypes that are required for every operator are allowed.
193198
elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
194199
or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
@@ -202,7 +207,7 @@ def _array_vals():
202207
):
203208
getattr(x, _op)(y)
204209
else:
205-
assert_raises(TypeError, lambda: getattr(x, _op)(y))
210+
assert_raises(TypeError, lambda: getattr(x, _op)(y), _op)
206211

207212
unary_op_dtypes = {
208213
"__abs__": "numeric",
@@ -221,7 +226,7 @@ def _array_vals():
221226
# Only test for no error
222227
getattr(a, op)()
223228
else:
224-
assert_raises(TypeError, lambda: getattr(a, op)())
229+
assert_raises(TypeError, lambda: getattr(a, op)(), _op)
225230

226231
# Finally, matmul() must be tested separately, because it works a bit
227232
# different from the other operations.
@@ -240,9 +245,9 @@ def _matmul_array_vals():
240245
or type(s) == int and a.dtype in _integer_dtypes):
241246
# Type promotion is valid, but @ is not allowed on 0-D
242247
# inputs, so the error is a ValueError
243-
assert_raises(ValueError, lambda: getattr(a, _op)(s))
248+
assert_raises(ValueError, lambda: getattr(a, _op)(s), _op)
244249
else:
245-
assert_raises(TypeError, lambda: getattr(a, _op)(s))
250+
assert_raises(TypeError, lambda: getattr(a, _op)(s), _op)
246251

247252
for x in _matmul_array_vals():
248253
for y in _matmul_array_vals():
@@ -356,20 +361,17 @@ def test_allow_newaxis():
356361

357362
def test_disallow_flat_indexing_with_newaxis():
358363
a = ones((3, 3, 3))
359-
with pytest.raises(IndexError):
360-
a[None, 0, 0]
364+
assert_raises(IndexError, lambda: a[None, 0, 0])
361365

362366
def test_disallow_mask_with_newaxis():
363367
a = ones((3, 3, 3))
364-
with pytest.raises(IndexError):
365-
a[None, asarray(True)]
368+
assert_raises(IndexError, lambda: a[None, asarray(True)])
366369

367370
@pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)])
368371
@pytest.mark.parametrize("index", ["string", False, True])
369372
def test_error_on_invalid_index(shape, index):
370373
a = ones(shape)
371-
with pytest.raises(IndexError):
372-
a[index]
374+
assert_raises(IndexError, lambda: a[index])
373375

374376
def test_mask_0d_array_without_errors():
375377
a = ones(())
@@ -380,10 +382,8 @@ def test_mask_0d_array_without_errors():
380382
)
381383
def test_error_on_invalid_index_with_ellipsis(i):
382384
a = ones((3, 3, 3))
383-
with pytest.raises(IndexError):
384-
a[..., i]
385-
with pytest.raises(IndexError):
386-
a[i, ...]
385+
assert_raises(IndexError, lambda: a[..., i])
386+
assert_raises(IndexError, lambda: a[i, ...])
387387

388388
def test_array_keys_use_private_array():
389389
"""
@@ -400,8 +400,7 @@ def test_array_keys_use_private_array():
400400

401401
a = ones((0,), dtype=bool_)
402402
key = ones((0, 0), dtype=bool_)
403-
with pytest.raises(IndexError):
404-
a[key]
403+
assert_raises(IndexError, lambda: a[key])
405404

406405
def test_array_namespace():
407406
a = ones((3, 3))
@@ -422,16 +421,16 @@ def test_array_namespace():
422421
assert a.__array_namespace__(api_version="2021.12") is array_api_strict
423422
assert array_api_strict.__array_api_version__ == "2021.12"
424423

425-
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
426-
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))
424+
assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
425+
assert_raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))
427426

428427
def test_iter():
429-
pytest.raises(TypeError, lambda: iter(asarray(3)))
428+
assert_raises(TypeError, lambda: iter(asarray(3)))
430429
assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)]
431430
assert all_(isinstance(a, Array) for a in iter(ones(3)))
432431
assert all_(a.shape == () for a in iter(ones(3)))
433432
assert all_(a.dtype == float64 for a in iter(ones(3)))
434-
pytest.raises(TypeError, lambda: iter(ones((3, 3))))
433+
assert_raises(TypeError, lambda: iter(ones((3, 3))))
435434

436435
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
437436
def dlpack_2023_12(api_version):
@@ -447,17 +446,17 @@ def dlpack_2023_12(api_version):
447446

448447

449448
exception = NotImplementedError if api_version >= '2023.12' else ValueError
450-
pytest.raises(exception, lambda:
449+
assert_raises(exception, lambda:
451450
a.__dlpack__(dl_device=CPU_DEVICE))
452-
pytest.raises(exception, lambda:
451+
assert_raises(exception, lambda:
453452
a.__dlpack__(dl_device=None))
454-
pytest.raises(exception, lambda:
453+
assert_raises(exception, lambda:
455454
a.__dlpack__(max_version=(1, 0)))
456-
pytest.raises(exception, lambda:
455+
assert_raises(exception, lambda:
457456
a.__dlpack__(max_version=None))
458-
pytest.raises(exception, lambda:
457+
assert_raises(exception, lambda:
459458
a.__dlpack__(copy=False))
460-
pytest.raises(exception, lambda:
459+
assert_raises(exception, lambda:
461460
a.__dlpack__(copy=True))
462-
pytest.raises(exception, lambda:
461+
assert_raises(exception, lambda:
463462
a.__dlpack__(copy=None))

0 commit comments

Comments
 (0)