Skip to content

Add type checking in assert_dtype_allclose for inexact dtypes #1634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 5, 2023
Merged
50 changes: 45 additions & 5 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,65 @@
import dpnp


def assert_dtype_allclose(dpnp_arr, numpy_arr, check_type=True):
def assert_dtype_allclose(
dpnp_arr, numpy_arr, check_type=True, check_only_type_kind=False
):
"""
Assert DPNP and NumPy array based on maximum dtype resolution of input arrays
for floating and complex types.
For other dtypes the assertion is based on exact matching of the arrays.

"""

When 'check_type' is True (default), the function asserts:
- Equal dtypes for exact types.
For inexact types:
- If the numpy array's dtype is `numpy.float16`, checks if the device
of the `dpnp_arr` supports 64-bit precision floating point operations.
If supported, asserts equal dtypes.
Otherwise, asserts equal type kinds.
- For other inexact types, asserts equal dtypes if the device of the `dpnp_arr`
supports 64-bit precision floating point operations or if the numpy array's inexact
dtype is not a double precision type.
Otherwise, asserts equal type kinds.
The 'check_only_type_kind' parameter (False by default) asserts only equal type kinds
for all data types supported by DPNP when set to True.
It is effective only when 'check_type' is also set to True.

"""

list_64bit_types = [numpy.float64, numpy.complex128]
is_inexact = lambda x: dpnp.issubdtype(x.dtype, dpnp.inexact)
if is_inexact(dpnp_arr) or is_inexact(numpy_arr):
tol = 8 * max(
dpnp.finfo(dpnp_arr).resolution,
numpy.finfo(numpy_arr.dtype).resolution,
)
assert_allclose(dpnp_arr.asnumpy(), numpy_arr, atol=tol, rtol=tol)
if check_type:
numpy_arr_dtype = numpy_arr.dtype
dpnp_arr_dtype = dpnp_arr.dtype
dpnp_arr_dev = dpnp_arr.sycl_device

if check_only_type_kind:
assert dpnp_arr_dtype.kind == numpy_arr_dtype.kind
else:
is_np_arr_f2 = numpy_arr_dtype == numpy.float16

if is_np_arr_f2:
if has_support_aspect16(dpnp_arr_dev):
assert dpnp_arr_dtype == numpy_arr_dtype
elif (
numpy_arr_dtype not in list_64bit_types
or has_support_aspect64(dpnp_arr_dev)
):
assert dpnp_arr_dtype == numpy_arr_dtype
else:
assert dpnp_arr_dtype.kind == numpy_arr_dtype.kind
else:
assert_array_equal(dpnp_arr.asnumpy(), numpy_arr)
if check_type:
assert dpnp_arr.dtype == numpy_arr.dtype
if check_only_type_kind:
assert dpnp_arr.dtype.kind == numpy_arr.dtype.kind
else:
assert dpnp_arr.dtype == numpy_arr.dtype


def get_complex_dtypes(device=None):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_fft(dtype, norm):
np_res = numpy.fft.fft(data, norm=norm)
dpnp_res = dpnp.fft.fft(dpnp_data, norm=norm)

assert_dtype_allclose(dpnp_res, np_res)
assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True)


@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
Expand All @@ -29,7 +29,7 @@ def test_fft_ndim(dtype, shape, norm):
np_res = numpy.fft.fft(np_data, norm=norm)
dpnp_res = dpnp.fft.fft(dpnp_data, norm=norm)

assert_dtype_allclose(dpnp_res, np_res)
assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True)


@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
Expand All @@ -44,7 +44,7 @@ def test_fft_ifft(dtype, shape, norm):
np_res = numpy.fft.ifft(np_data, norm=norm)
dpnp_res = dpnp.fft.ifft(dpnp_data, norm=norm)

assert_dtype_allclose(dpnp_res, np_res)
assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True)


@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
Expand All @@ -58,7 +58,7 @@ def test_fft_rfft(dtype, shape):
np_res = numpy.fft.rfft(np_data)
dpnp_res = dpnp.fft.rfft(dpnp_data)

assert_dtype_allclose(dpnp_res, np_res)
assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True)


@pytest.mark.parametrize(
Expand Down
5 changes: 4 additions & 1 deletion tests/test_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,16 +1105,19 @@ def test_out_dtypes(self, dtype):
dp_array2 = dpnp.arange(size, dtype=dtype)

dp_out = dpnp.empty(size, dtype=dpnp.complex64)
check_dtype = True
if dtype != dpnp.complex64:
# dtype of out mismatches types of input arrays
with pytest.raises(TypeError):
dpnp.divide(dp_array1, dp_array2, out=dp_out)

# allocate new out with expected type
dp_out = dpnp.empty(size, dtype=dtype)
# Set check_dtype to False as dtype does not match
check_dtype = False

result = dpnp.divide(dp_array1, dp_array2, out=dp_out)
assert_dtype_allclose(result, expected)
assert_dtype_allclose(result, expected, check_type=check_dtype)

@pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
Expand Down
8 changes: 7 additions & 1 deletion tests/test_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,16 @@ def test_sum_float(dtype):
)
ia = dpnp.array(a)

# Flag for type check in special cases
# Use only type kinds checks when dpnp handles float32 arrays
# as `dpnp.sum()` and `numpy.sum()` return different dtypes
check_type_kind = dtype == dpnp.float32
for axis in range(len(a)):
result = dpnp.sum(ia, axis=axis)
expected = numpy.sum(a, axis=axis)
assert_dtype_allclose(result, expected)
assert_dtype_allclose(
result, expected, check_only_type_kind=check_type_kind
)


def test_sum_int():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def test_fft_rfft(type, shape, device):
np_res = numpy.fft.rfft(np_data)
dpnp_res = dpnp.fft.rfft(dpnp_data)

assert_dtype_allclose(dpnp_res, np_res)
assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True)

expected_queue = dpnp_data.get_array().sycl_queue
result_queue = dpnp_res.get_array().sycl_queue
Expand Down