Skip to content

Commit 2c8cbb5

Browse files
Add type checking in assert_dtype_allclose for inexact dtypes (#1634)
* Add dtype checking for inexact dtype in assert_dtype_allclose * Update test_out_dtypes in TestDivide * Add a check for support of 16 bit types * Add an empty line after the description * fix condition when numpy`s array is not float16 * Address the remarks * Update test_sum_float in test_sum.py * Add a new check_only_type_kind param to assert_dtype_allclose * Update test_sum and test_fft * Use check_only_type_kind in test_fft_rfft
1 parent ad90f66 commit 2c8cbb5

File tree

5 files changed

+61
-12
lines changed

5 files changed

+61
-12
lines changed

tests/helper.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,65 @@
77
import dpnp
88

99

10-
def assert_dtype_allclose(dpnp_arr, numpy_arr, check_type=True):
10+
def assert_dtype_allclose(
11+
dpnp_arr, numpy_arr, check_type=True, check_only_type_kind=False
12+
):
1113
"""
1214
Assert DPNP and NumPy array based on maximum dtype resolution of input arrays
1315
for floating and complex types.
1416
For other dtypes the assertion is based on exact matching of the arrays.
15-
16-
"""
17-
17+
When 'check_type' is True (default), the function asserts:
18+
- Equal dtypes for exact types.
19+
For inexact types:
20+
- If the numpy array's dtype is `numpy.float16`, checks if the device
21+
of the `dpnp_arr` supports 64-bit precision floating point operations.
22+
If supported, asserts equal dtypes.
23+
Otherwise, asserts equal type kinds.
24+
- For other inexact types, asserts equal dtypes if the device of the `dpnp_arr`
25+
supports 64-bit precision floating point operations or if the numpy array's inexact
26+
dtype is not a double precision type.
27+
Otherwise, asserts equal type kinds.
28+
The 'check_only_type_kind' parameter (False by default) asserts only equal type kinds
29+
for all data types supported by DPNP when set to True.
30+
It is effective only when 'check_type' is also set to True.
31+
32+
"""
33+
34+
list_64bit_types = [numpy.float64, numpy.complex128]
1835
is_inexact = lambda x: dpnp.issubdtype(x.dtype, dpnp.inexact)
1936
if is_inexact(dpnp_arr) or is_inexact(numpy_arr):
2037
tol = 8 * max(
2138
dpnp.finfo(dpnp_arr).resolution,
2239
numpy.finfo(numpy_arr.dtype).resolution,
2340
)
2441
assert_allclose(dpnp_arr.asnumpy(), numpy_arr, atol=tol, rtol=tol)
42+
if check_type:
43+
numpy_arr_dtype = numpy_arr.dtype
44+
dpnp_arr_dtype = dpnp_arr.dtype
45+
dpnp_arr_dev = dpnp_arr.sycl_device
46+
47+
if check_only_type_kind:
48+
assert dpnp_arr_dtype.kind == numpy_arr_dtype.kind
49+
else:
50+
is_np_arr_f2 = numpy_arr_dtype == numpy.float16
51+
52+
if is_np_arr_f2:
53+
if has_support_aspect16(dpnp_arr_dev):
54+
assert dpnp_arr_dtype == numpy_arr_dtype
55+
elif (
56+
numpy_arr_dtype not in list_64bit_types
57+
or has_support_aspect64(dpnp_arr_dev)
58+
):
59+
assert dpnp_arr_dtype == numpy_arr_dtype
60+
else:
61+
assert dpnp_arr_dtype.kind == numpy_arr_dtype.kind
2562
else:
2663
assert_array_equal(dpnp_arr.asnumpy(), numpy_arr)
2764
if check_type:
28-
assert dpnp_arr.dtype == numpy_arr.dtype
65+
if check_only_type_kind:
66+
assert dpnp_arr.dtype.kind == numpy_arr.dtype.kind
67+
else:
68+
assert dpnp_arr.dtype == numpy_arr.dtype
2969

3070

3171
def get_complex_dtypes(device=None):

tests/test_fft.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_fft(dtype, norm):
1616
np_res = numpy.fft.fft(data, norm=norm)
1717
dpnp_res = dpnp.fft.fft(dpnp_data, norm=norm)
1818

19-
assert_dtype_allclose(dpnp_res, np_res)
19+
assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True)
2020

2121

2222
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
@@ -29,7 +29,7 @@ def test_fft_ndim(dtype, shape, norm):
2929
np_res = numpy.fft.fft(np_data, norm=norm)
3030
dpnp_res = dpnp.fft.fft(dpnp_data, norm=norm)
3131

32-
assert_dtype_allclose(dpnp_res, np_res)
32+
assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True)
3333

3434

3535
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
@@ -44,7 +44,7 @@ def test_fft_ifft(dtype, shape, norm):
4444
np_res = numpy.fft.ifft(np_data, norm=norm)
4545
dpnp_res = dpnp.fft.ifft(dpnp_data, norm=norm)
4646

47-
assert_dtype_allclose(dpnp_res, np_res)
47+
assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True)
4848

4949

5050
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
@@ -58,7 +58,7 @@ def test_fft_rfft(dtype, shape):
5858
np_res = numpy.fft.rfft(np_data)
5959
dpnp_res = dpnp.fft.rfft(dpnp_data)
6060

61-
assert_dtype_allclose(dpnp_res, np_res)
61+
assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True)
6262

6363

6464
@pytest.mark.parametrize(

tests/test_mathematical.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1105,16 +1105,19 @@ def test_out_dtypes(self, dtype):
11051105
dp_array2 = dpnp.arange(size, dtype=dtype)
11061106

11071107
dp_out = dpnp.empty(size, dtype=dpnp.complex64)
1108+
check_dtype = True
11081109
if dtype != dpnp.complex64:
11091110
# dtype of out mismatches types of input arrays
11101111
with pytest.raises(TypeError):
11111112
dpnp.divide(dp_array1, dp_array2, out=dp_out)
11121113

11131114
# allocate new out with expected type
11141115
dp_out = dpnp.empty(size, dtype=dtype)
1116+
# Set check_dtype to False as dtype does not match
1117+
check_dtype = False
11151118

11161119
result = dpnp.divide(dp_array1, dp_array2, out=dp_out)
1117-
assert_dtype_allclose(result, expected)
1120+
assert_dtype_allclose(result, expected, check_type=check_dtype)
11181121

11191122
@pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
11201123
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())

tests/test_sum.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,16 @@ def test_sum_float(dtype):
2727
)
2828
ia = dpnp.array(a)
2929

30+
# Flag for type check in special cases
31+
# Use only type kinds checks when dpnp handles float32 arrays
32+
# as `dpnp.sum()` and `numpy.sum()` return different dtypes
33+
check_type_kind = dtype == dpnp.float32
3034
for axis in range(len(a)):
3135
result = dpnp.sum(ia, axis=axis)
3236
expected = numpy.sum(a, axis=axis)
33-
assert_dtype_allclose(result, expected)
37+
assert_dtype_allclose(
38+
result, expected, check_only_type_kind=check_type_kind
39+
)
3440

3541

3642
def test_sum_int():

tests/test_sycl_queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,7 @@ def test_fft_rfft(type, shape, device):
875875
np_res = numpy.fft.rfft(np_data)
876876
dpnp_res = dpnp.fft.rfft(dpnp_data)
877877

878-
assert_dtype_allclose(dpnp_res, np_res)
878+
assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True)
879879

880880
expected_queue = dpnp_data.get_array().sycl_queue
881881
result_queue = dpnp_res.get_array().sycl_queue

0 commit comments

Comments
 (0)