Skip to content

Commit e3d4727

Browse files
committed
address comments
1 parent 683d79e commit e3d4727

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

tests/test_mathematical.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
assert_dtype_allclose,
1717
get_all_dtypes,
1818
get_complex_dtypes,
19+
get_float_complex_dtypes,
1920
get_float_dtypes,
2021
has_support_aspect64,
2122
is_cpu_device,
@@ -966,7 +967,7 @@ def test_invalid_out(self, out):
966967

967968

968969
class TestDivide:
969-
@pytest.mark.parametrize("dtype", get_float_dtypes() + get_complex_dtypes())
970+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
970971
def test_divide(self, dtype):
971972
array1_data = numpy.arange(10)
972973
array2_data = numpy.arange(5, 15)
@@ -983,12 +984,11 @@ def test_divide(self, dtype):
983984
np_array2 = numpy.array(array2_data, dtype=dtype)
984985
expected = numpy.divide(np_array1, np_array2, out=out)
985986

986-
tol = 1e-07
987-
assert_allclose(expected, result, rtol=tol, atol=tol)
988-
assert_allclose(out, dp_out, rtol=tol, atol=tol)
987+
assert_dtype_allclose(result, expected)
988+
assert_dtype_allclose(dp_out, out)
989989

990990
@pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
991-
@pytest.mark.parametrize("dtype", get_float_dtypes() + get_complex_dtypes())
991+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
992992
def test_out_dtypes(self, dtype):
993993
size = 10
994994

@@ -1010,12 +1010,10 @@ def test_out_dtypes(self, dtype):
10101010
dp_out = dpnp.empty(size, dtype=dtype)
10111011

10121012
result = dpnp.divide(dp_array1, dp_array2, out=dp_out)
1013-
1014-
tol = 1e-07
1015-
assert_allclose(expected, result, rtol=tol, atol=tol)
1013+
assert_dtype_allclose(result, expected)
10161014

10171015
@pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
1018-
@pytest.mark.parametrize("dtype", get_float_dtypes() + get_complex_dtypes())
1016+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
10191017
def test_out_overlap(self, dtype):
10201018
size = 15
10211019
# DPNP
@@ -1026,10 +1024,9 @@ def test_out_overlap(self, dtype):
10261024
np_a = numpy.arange(2 * size, dtype=dtype)
10271025
numpy.divide(np_a[size::], np_a[::2], out=np_a[:size:])
10281026

1029-
tol = 1e-07
1030-
assert_allclose(np_a, dp_a, rtol=tol, atol=tol)
1027+
assert_dtype_allclose(dp_a, np_a)
10311028

1032-
@pytest.mark.parametrize("dtype", get_float_dtypes() + get_complex_dtypes())
1029+
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
10331030
def test_inplace_strided_out(self, dtype):
10341031
size = 21
10351032

0 commit comments

Comments
 (0)