Skip to content

Commit e8c088c

Browse files
Update TestSpacing and TestTensordot (#2251)
This PR suggests using `assert_allclose` instead of `assert_equal` with a resolution-based tolerance in `TestSpacing::test_zeros` to fix test failures on CUDA caused by excessive precision for float32. And updates `TestTensordot::test_axes` and `TestTensordot::test_linalg` to increase a factor parameter in `assert_dtype_allclose` for scaling tolerance
1 parent 1bed3bd commit e8c088c

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

dpnp/tests/test_mathematical.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2111,14 +2111,15 @@ def test_zeros(self, dt):
21112111

21122112
result = dpnp.spacing(ia)
21132113
expected = numpy.spacing(a)
2114+
tol = numpy.finfo(expected.dtype).resolution
21142115
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
2115-
assert_equal(result, expected)
2116+
assert_allclose(result, expected, rtol=tol, atol=tol)
21162117
else:
21172118
# numpy.spacing(-0.0) == numpy.spacing(0.0), i.e. NumPy returns
21182119
# positive value (looks as a bug in NumPy), because for any other
21192120
# negative input the NumPy result will be also a negative value.
21202121
expected[1] *= -1
2121-
assert_equal(result, expected)
2122+
assert_allclose(result, expected, rtol=tol, atol=tol)
21222123

21232124
@pytest.mark.parametrize("dt", get_float_dtypes(no_float16=False))
21242125
@pytest.mark.parametrize("val", [1, 1e-5, 1000])

dpnp/tests/test_product.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,7 @@ def test_axes(self, dtype, axes):
799799

800800
result = dpnp.tensordot(ia, ib, axes=axes)
801801
expected = numpy.tensordot(a, b, axes=axes)
802-
assert_dtype_allclose(result, expected)
802+
assert_dtype_allclose(result, expected, factor=9)
803803

804804
@pytest.mark.parametrize("dtype1", get_all_dtypes())
805805
@pytest.mark.parametrize("dtype2", get_all_dtypes())
@@ -844,7 +844,7 @@ def test_linalg(self, axes):
844844

845845
result = dpnp.linalg.tensordot(ia, ib, axes=axes)
846846
expected = numpy.linalg.tensordot(a, b, axes=axes)
847-
assert_dtype_allclose(result, expected)
847+
assert_dtype_allclose(result, expected, factor=9)
848848

849849
def test_error(self):
850850
a = 5

0 commit comments

Comments
 (0)