From 07c075bdb1e087ea15e14ffa11846e49679c45c6 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 7 Nov 2023 10:31:41 -0800 Subject: [PATCH 1/4] Corrected argmin/argmax docstring Removed mention of dtype kwarg in usage line --- dpctl/tensor/_reduction.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dpctl/tensor/_reduction.py b/dpctl/tensor/_reduction.py index 0cd302cccc..9b078211fc 100644 --- a/dpctl/tensor/_reduction.py +++ b/dpctl/tensor/_reduction.py @@ -445,7 +445,7 @@ def _comparison_over_axis(x, axis, keepdims, _reduction_fn): def max(x, axis=None, keepdims=False): - """max(x, axis=None, dtype=None, keepdims=False) + """max(x, axis=None, keepdims=False) Calculates the maximum value of the input array `x`. @@ -473,7 +473,7 @@ def max(x, axis=None, keepdims=False): def min(x, axis=None, keepdims=False): - """min(x, axis=None, dtype=None, keepdims=False) + """min(x, axis=None, keepdims=False) Calculates the minimum value of the input array `x`. @@ -550,7 +550,7 @@ def _search_over_axis(x, axis, keepdims, _reduction_fn): def argmax(x, axis=None, keepdims=False): - """argmax(x, axis=None, dtype=None, keepdims=False) + """argmax(x, axis=None, keepdims=False) Returns the indices of the maximum values of the input array `x` along a specified axis. @@ -582,7 +582,7 @@ def argmax(x, axis=None, keepdims=False): def argmin(x, axis=None, keepdims=False): - """argmin(x, axis=None, dtype=None, keepdims=False) + """argmin(x, axis=None, keepdims=False) Returns the indices of the minimum values of the input array `x` along a specified axis. From 80e2f29e4dfbb4f7a9a3c198f0d5314b132d1044 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 7 Nov 2023 10:47:31 -0800 Subject: [PATCH 2/4] Fixed gh-1468 Function _reduce_over_axis promotes input array to requested result data type and carries out reduction computation in that data type. This is done in dtype if implementation supports it. If implementation does not support the requested dtype, we reduce in the default_dtype, and cast to the request dtype afterwards. --- dpctl/tensor/_reduction.py | 62 ++++++++++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/dpctl/tensor/_reduction.py b/dpctl/tensor/_reduction.py index 9b078211fc..79ce231901 100644 --- a/dpctl/tensor/_reduction.py +++ b/dpctl/tensor/_reduction.py @@ -118,7 +118,7 @@ def _reduction_over_axis( dpt.full( res_shape, _identity, - dtype=_default_reduction_type_fn(inp_dt, q), + dtype=dtype, usm_type=res_usm_type, sycl_queue=q, ), @@ -142,21 +142,51 @@ def _reduction_over_axis( "Automatically determined reduction data type does not " "have direct implementation" ) - tmp_dt = _default_reduction_type_fn(inp_dt, q) - tmp = dpt.empty( - res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q - ) - ht_e_tmp, r_e = _reduction_fn( - src=arr2, trailing_dims_to_reduce=red_nd, dst=tmp, sycl_queue=q - ) - host_tasks_list.append(ht_e_tmp) - res = dpt.empty( - res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q - ) - ht_e, _ = ti._copy_usm_ndarray_into_usm_ndarray( - src=tmp, dst=res, sycl_queue=q, depends=[r_e] - ) - host_tasks_list.append(ht_e) + if _dtype_supported(res_dt, res_dt, res_usm_type, q): + tmp = dpt.empty( + arr2.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + ) + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr2, dst=tmp, sycl_queue=q + ) + host_tasks_list.append(ht_e_cpy) + res = dpt.empty( + res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + ) + ht_e_red, _ = _reduction_fn( + src=tmp, + trailing_dims_to_reduce=red_nd, + dst=res, + sycl_queue=q, + depends=[cpy_e], + ) + host_tasks_list.append(ht_e_red) + else: + buf_dt = _default_reduction_type_fn(inp_dt, q) + tmp = dpt.empty( + arr2.shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q + ) + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr2, dst=tmp, sycl_queue=q + ) + tmp_res = dpt.empty( + res_shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q + ) + host_tasks_list.append(ht_e_cpy) + res = dpt.empty( + res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + ) + ht_e_red, r_e = _reduction_fn( + src=tmp, + trailing_dims_to_reduce=red_nd, + dst=tmp_res, + sycl_queue=q, + depends=[cpy_e], + ) + ht_e_cpy2, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=tmp_res, dst=res, sycl_queue=q, depends=[r_e] + ) + host_tasks_list.append(ht_e_cpy2) if keepdims: res_shape = res_shape + (1,) * red_nd From ff9b5ebac31c874b91ef8b834907a3145a9c1c49 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 7 Nov 2023 11:06:16 -0600 Subject: [PATCH 3/4] Added a test based on gh-1468 --- dpctl/tests/test_tensor_sum.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py index 749ca055b9..33fe4a8b4f 100644 --- a/dpctl/tests/test_tensor_sum.py +++ b/dpctl/tests/test_tensor_sum.py @@ -329,3 +329,12 @@ def test_prod_arg_out_dtype_matrix(arg_dtype, out_dtype): assert isinstance(r, dpt.usm_ndarray) assert r.dtype == dpt.dtype(out_dtype) assert dpt.all(r == 1) + + +def test_gh_1468(): + "See https://github.com/IntelPython/dpctl/issues/1468" + get_queue_or_skip() + + a = dpt.full((2, 3, 4), 123456789, dtype=dpt.int32) + t = dpt.sum(a, dtype="f4") + assert t > 0 From ca2c6aa9f9a0df34471c93c7b80fdec950576f80 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 7 Nov 2023 17:10:21 -0600 Subject: [PATCH 4/4] Removed redundant asdtype function call --- dpctl/tensor/_reduction.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/dpctl/tensor/_reduction.py b/dpctl/tensor/_reduction.py index 79ce231901..f797d24b0b 100644 --- a/dpctl/tensor/_reduction.py +++ b/dpctl/tensor/_reduction.py @@ -114,15 +114,12 @@ def _reduction_over_axis( res_shape = res_shape + (1,) * red_nd inv_perm = sorted(range(nd), key=lambda d: perm[d]) res_shape = tuple(res_shape[i] for i in inv_perm) - return dpt.astype( - dpt.full( - res_shape, - _identity, - dtype=dtype, - usm_type=res_usm_type, - sycl_queue=q, - ), - res_dt, + return dpt.full( + res_shape, + _identity, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=q, ) if red_nd == 0: return dpt.astype(x, res_dt, copy=False)