diff --git a/tests/helper.py b/tests/helper.py index b3d816e769ac..243c61504a50 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -7,14 +7,31 @@ import dpnp -def assert_dtype_allclose(dpnp_arr, numpy_arr, check_type=True): +def assert_dtype_allclose( + dpnp_arr, numpy_arr, check_type=True, check_only_type_kind=False +): """ Assert DPNP and NumPy array based on maximum dtype resolution of input arrays for floating and complex types. For other dtypes the assertion is based on exact matching of the arrays. - - """ - + When 'check_type' is True (default), the function asserts: + - Equal dtypes for exact types. + For inexact types: + - If the numpy array's dtype is `numpy.float16`, checks if the device + of the `dpnp_arr` supports 64-bit precision floating point operations. + If supported, asserts equal dtypes. + Otherwise, asserts equal type kinds. + - For other inexact types, asserts equal dtypes if the device of the `dpnp_arr` + supports 64-bit precision floating point operations or if the numpy array's inexact + dtype is not a double precision type. + Otherwise, asserts equal type kinds. + The 'check_only_type_kind' parameter (False by default) asserts only equal type kinds + for all data types supported by DPNP when set to True. + It is effective only when 'check_type' is also set to True. + + """ + + list_64bit_types = [numpy.float64, numpy.complex128] is_inexact = lambda x: dpnp.issubdtype(x.dtype, dpnp.inexact) if is_inexact(dpnp_arr) or is_inexact(numpy_arr): tol = 8 * max( @@ -22,10 +39,33 @@ def assert_dtype_allclose(dpnp_arr, numpy_arr, check_type=True): numpy.finfo(numpy_arr.dtype).resolution, ) assert_allclose(dpnp_arr.asnumpy(), numpy_arr, atol=tol, rtol=tol) + if check_type: + numpy_arr_dtype = numpy_arr.dtype + dpnp_arr_dtype = dpnp_arr.dtype + dpnp_arr_dev = dpnp_arr.sycl_device + + if check_only_type_kind: + assert dpnp_arr_dtype.kind == numpy_arr_dtype.kind + else: + is_np_arr_f2 = numpy_arr_dtype == numpy.float16 + + if is_np_arr_f2: + if has_support_aspect16(dpnp_arr_dev): + assert dpnp_arr_dtype == numpy_arr_dtype + elif ( + numpy_arr_dtype not in list_64bit_types + or has_support_aspect64(dpnp_arr_dev) + ): + assert dpnp_arr_dtype == numpy_arr_dtype + else: + assert dpnp_arr_dtype.kind == numpy_arr_dtype.kind else: assert_array_equal(dpnp_arr.asnumpy(), numpy_arr) if check_type: - assert dpnp_arr.dtype == numpy_arr.dtype + if check_only_type_kind: + assert dpnp_arr.dtype.kind == numpy_arr.dtype.kind + else: + assert dpnp_arr.dtype == numpy_arr.dtype def get_complex_dtypes(device=None): diff --git a/tests/test_fft.py b/tests/test_fft.py index 0d2ea664b58d..b439ef38cce6 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -16,7 +16,7 @@ def test_fft(dtype, norm): np_res = numpy.fft.fft(data, norm=norm) dpnp_res = dpnp.fft.fft(dpnp_data, norm=norm) - assert_dtype_allclose(dpnp_res, np_res) + assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True) @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) @@ -29,7 +29,7 @@ def test_fft_ndim(dtype, shape, norm): np_res = numpy.fft.fft(np_data, norm=norm) dpnp_res = dpnp.fft.fft(dpnp_data, norm=norm) - assert_dtype_allclose(dpnp_res, np_res) + assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True) @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) @@ -44,7 +44,7 @@ def test_fft_ifft(dtype, shape, norm): np_res = numpy.fft.ifft(np_data, norm=norm) dpnp_res = dpnp.fft.ifft(dpnp_data, norm=norm) - assert_dtype_allclose(dpnp_res, np_res) + assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True) @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True)) @@ -58,7 +58,7 @@ def test_fft_rfft(dtype, shape): np_res = numpy.fft.rfft(np_data) dpnp_res = dpnp.fft.rfft(dpnp_data) - assert_dtype_allclose(dpnp_res, np_res) + assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True) @pytest.mark.parametrize( diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index 9c8850aad188..4f751b697fef 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -1105,6 +1105,7 @@ def test_out_dtypes(self, dtype): dp_array2 = dpnp.arange(size, dtype=dtype) dp_out = dpnp.empty(size, dtype=dpnp.complex64) + check_dtype = True if dtype != dpnp.complex64: # dtype of out mismatches types of input arrays with pytest.raises(TypeError): @@ -1112,9 +1113,11 @@ def test_out_dtypes(self, dtype): # allocate new out with expected type dp_out = dpnp.empty(size, dtype=dtype) + # Set check_dtype to False as dtype does not match + check_dtype = False result = dpnp.divide(dp_array1, dp_array2, out=dp_out) - assert_dtype_allclose(result, expected) + assert_dtype_allclose(result, expected, check_type=check_dtype) @pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings") @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) diff --git a/tests/test_sum.py b/tests/test_sum.py index 16b17847f270..4104b33a6248 100644 --- a/tests/test_sum.py +++ b/tests/test_sum.py @@ -27,10 +27,16 @@ def test_sum_float(dtype): ) ia = dpnp.array(a) + # Flag for type check in special cases + # Use only type kinds checks when dpnp handles float32 arrays + # as `dpnp.sum()` and `numpy.sum()` return different dtypes + check_type_kind = dtype == dpnp.float32 for axis in range(len(a)): result = dpnp.sum(ia, axis=axis) expected = numpy.sum(a, axis=axis) - assert_dtype_allclose(result, expected) + assert_dtype_allclose( + result, expected, check_only_type_kind=check_type_kind + ) def test_sum_int(): diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 265d412b2ac1..3618a9bb4c54 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -875,7 +875,7 @@ def test_fft_rfft(type, shape, device): np_res = numpy.fft.rfft(np_data) dpnp_res = dpnp.fft.rfft(dpnp_data) - assert_dtype_allclose(dpnp_res, np_res) + assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True) expected_queue = dpnp_data.get_array().sycl_queue result_queue = dpnp_res.get_array().sycl_queue