Skip to content

Commit 71810a4

Browse files
Merge master into test_diff_order_fft
2 parents d82c026 + 904227e commit 71810a4

File tree

8 files changed

+37
-62
lines changed

8 files changed

+37
-62
lines changed

.github/workflows/conda-package.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ jobs:
148148
149149
- name: Test conda channel
150150
run: |
151-
mamba search ${{ env.PACKAGE_NAME }} -c ${{ env.channel-path }} --override-channels --info --json > ${{ env.ver-json-path }}
151+
conda search ${{ env.PACKAGE_NAME }} -c ${{ env.channel-path }} --override-channels --info --json > ${{ env.ver-json-path }}
152152
cat ${{ env.ver-json-path }}
153153
154154
- name: Get package version
@@ -182,7 +182,7 @@ jobs:
182182
id: run_tests_linux
183183
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
184184
with:
185-
timeout_minutes: 10
185+
timeout_minutes: 12
186186
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
187187
retry_on: any
188188
command: |
@@ -264,7 +264,7 @@ jobs:
264264
- name: Test conda channel
265265
run: |
266266
@echo on
267-
mamba search ${{ env.PACKAGE_NAME }} -c ${{ env.channel-path }} --override-channels --info --json > ${{ env.ver-json-path }}
267+
conda search ${{ env.PACKAGE_NAME }} -c ${{ env.channel-path }} --override-channels --info --json > ${{ env.ver-json-path }}
268268
269269
- name: Dump version.json
270270
run: more ${{ env.ver-json-path }}

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
* sycl::ext::oneapi::experimental::properties was added.
4141
*/
4242
#ifndef __SYCL_COMPILER_REDUCTION_PROPERTIES_SUPPORT
43-
#define __SYCL_COMPILER_REDUCTION_PROPERTIES_SUPPORT 20241210L
43+
#define __SYCL_COMPILER_REDUCTION_PROPERTIES_SUPPORT 20241208L
4444
#endif
4545

4646
namespace mkl_blas = oneapi::mkl::blas;

dpnp/backend/kernels/elementwise_functions/i0.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
* sycl::ext::intel::math::cyl_bessel_i0(x) is fully resolved.
3333
*/
3434
#ifndef __SYCL_COMPILER_BESSEL_I0_SUPPORT
35-
#define __SYCL_COMPILER_BESSEL_I0_SUPPORT 20241210L
35+
#define __SYCL_COMPILER_BESSEL_I0_SUPPORT 20241208L
3636
#endif
3737

3838
#if __SYCL_COMPILER_VERSION >= __SYCL_COMPILER_BESSEL_I0_SUPPORT

dpnp/dpnp_iface_mathematical.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def _get_max_min(dtype):
152152
return f.max, f.min
153153

154154

155-
def _get_reduction_res_dt(a, dtype, _out):
155+
def _get_reduction_res_dt(a, dtype):
156156
"""Get a data type used by dpctl for result array in reduction function."""
157157

158158
if dtype is None:
@@ -1106,11 +1106,10 @@ def cumprod(a, axis=None, dtype=None, out=None):
11061106
usm_a = dpnp.get_usm_ndarray(a)
11071107

11081108
return dpnp_wrap_reduction_call(
1109-
a,
1109+
usm_a,
11101110
out,
11111111
dpt.cumulative_prod,
1112-
_get_reduction_res_dt,
1113-
usm_a,
1112+
_get_reduction_res_dt(a, dtype),
11141113
axis=axis,
11151114
dtype=dtype,
11161115
)
@@ -1196,11 +1195,10 @@ def cumsum(a, axis=None, dtype=None, out=None):
11961195
usm_a = dpnp.get_usm_ndarray(a)
11971196

11981197
return dpnp_wrap_reduction_call(
1199-
a,
1198+
usm_a,
12001199
out,
12011200
dpt.cumulative_sum,
1202-
_get_reduction_res_dt,
1203-
usm_a,
1201+
_get_reduction_res_dt(a, dtype),
12041202
axis=axis,
12051203
dtype=dtype,
12061204
)
@@ -1281,11 +1279,10 @@ def cumulative_prod(
12811279
"""
12821280

12831281
return dpnp_wrap_reduction_call(
1284-
x,
1282+
dpnp.get_usm_ndarray(x),
12851283
out,
12861284
dpt.cumulative_prod,
1287-
_get_reduction_res_dt,
1288-
dpnp.get_usm_ndarray(x),
1285+
_get_reduction_res_dt(x, dtype),
12891286
axis=axis,
12901287
dtype=dtype,
12911288
include_initial=include_initial,
@@ -1373,11 +1370,10 @@ def cumulative_sum(
13731370
"""
13741371

13751372
return dpnp_wrap_reduction_call(
1376-
x,
1373+
dpnp.get_usm_ndarray(x),
13771374
out,
13781375
dpt.cumulative_sum,
1379-
_get_reduction_res_dt,
1380-
dpnp.get_usm_ndarray(x),
1376+
_get_reduction_res_dt(x, dtype),
13811377
axis=axis,
13821378
dtype=dtype,
13831379
include_initial=include_initial,
@@ -3524,11 +3520,10 @@ def prod(
35243520
usm_a = dpnp.get_usm_ndarray(a)
35253521

35263522
return dpnp_wrap_reduction_call(
3527-
a,
3523+
usm_a,
35283524
out,
35293525
dpt.prod,
3530-
_get_reduction_res_dt,
3531-
usm_a,
3526+
_get_reduction_res_dt(a, dtype),
35323527
axis=axis,
35333528
dtype=dtype,
35343529
keepdims=keepdims,
@@ -4297,11 +4292,10 @@ def sum(
42974292

42984293
usm_a = dpnp.get_usm_ndarray(a)
42994294
return dpnp_wrap_reduction_call(
4300-
a,
4295+
usm_a,
43014296
out,
43024297
dpt.sum,
4303-
_get_reduction_res_dt,
4304-
usm_a,
4298+
_get_reduction_res_dt(a, dtype),
43054299
axis=axis,
43064300
dtype=dtype,
43074301
keepdims=keepdims,

dpnp/dpnp_iface_searching.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@
4848
__all__ = ["argmax", "argmin", "argwhere", "searchsorted", "where"]
4949

5050

51-
def _get_search_res_dt(a, _dtype, out):
51+
def _get_search_res_dt(a, out):
5252
"""Get a data type used by dpctl for result array in search function."""
5353

5454
# get a data type used by dpctl for result array in search function
5555
res_dt = dti.default_device_index_type(a.sycl_device)
5656

5757
# numpy raises TypeError if "out" data type mismatch default index type
58-
if not dpnp.can_cast(out.dtype, res_dt, casting="safe"):
58+
if out is not None and not dpnp.can_cast(out.dtype, res_dt, casting="safe"):
5959
raise TypeError(
6060
f"Cannot cast from {out.dtype} to {res_dt} "
6161
"according to the rule safe."
@@ -143,11 +143,10 @@ def argmax(a, axis=None, out=None, *, keepdims=False):
143143

144144
usm_a = dpnp.get_usm_ndarray(a)
145145
return dpnp_wrap_reduction_call(
146-
a,
146+
usm_a,
147147
out,
148148
dpt.argmax,
149-
_get_search_res_dt,
150-
usm_a,
149+
_get_search_res_dt(a, out),
151150
axis=axis,
152151
keepdims=keepdims,
153152
)
@@ -234,11 +233,10 @@ def argmin(a, axis=None, out=None, *, keepdims=False):
234233

235234
usm_a = dpnp.get_usm_ndarray(a)
236235
return dpnp_wrap_reduction_call(
237-
a,
236+
usm_a,
238237
out,
239238
dpt.argmin,
240-
_get_search_res_dt,
241-
usm_a,
239+
_get_search_res_dt(a, out),
242240
axis=axis,
243241
keepdims=keepdims,
244242
)

dpnp/dpnp_iface_statistics.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,6 @@ def _count_reduce_items(arr, axis, where=True):
115115
return items
116116

117117

118-
def _get_comparison_res_dt(a, _dtype, _out):
119-
"""Get a data type used by dpctl for result array in comparison function."""
120-
121-
return a.dtype
122-
123-
124118
def amax(a, axis=None, out=None, keepdims=False, initial=None, where=True):
125119
"""
126120
Return the maximum of an array or maximum along an axis.
@@ -760,11 +754,10 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True):
760754
usm_a = dpnp.get_usm_ndarray(a)
761755

762756
return dpnp_wrap_reduction_call(
763-
a,
757+
usm_a,
764758
out,
765759
dpt.max,
766-
_get_comparison_res_dt,
767-
usm_a,
760+
a.dtype,
768761
axis=axis,
769762
keepdims=keepdims,
770763
)
@@ -1026,11 +1019,10 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True):
10261019
usm_a = dpnp.get_usm_ndarray(a)
10271020

10281021
return dpnp_wrap_reduction_call(
1029-
a,
1022+
usm_a,
10301023
out,
10311024
dpt.min,
1032-
_get_comparison_res_dt,
1033-
usm_a,
1025+
a.dtype,
10341026
axis=axis,
10351027
keepdims=keepdims,
10361028
)

dpnp/dpnp_iface_trigonometric.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898
]
9999

100100

101-
def _get_accumulation_res_dt(a, dtype, _out):
101+
def _get_accumulation_res_dt(a, dtype):
102102
"""Get a dtype used by dpctl for result array in accumulation function."""
103103

104104
if dtype is None:
@@ -893,11 +893,10 @@ def cumlogsumexp(
893893
usm_x = dpnp.get_usm_ndarray(x)
894894

895895
return dpnp_wrap_reduction_call(
896-
x,
896+
usm_x,
897897
out,
898898
dpt.cumulative_logsumexp,
899-
_get_accumulation_res_dt,
900-
usm_x,
899+
_get_accumulation_res_dt(x, dtype),
901900
axis=axis,
902901
dtype=dtype,
903902
include_initial=include_initial,
@@ -1705,11 +1704,10 @@ def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
17051704

17061705
usm_x = dpnp.get_usm_ndarray(x)
17071706
return dpnp_wrap_reduction_call(
1708-
x,
1707+
usm_x,
17091708
out,
17101709
dpt.logsumexp,
1711-
_get_accumulation_res_dt,
1712-
usm_x,
1710+
_get_accumulation_res_dt(x, dtype),
17131711
axis=axis,
17141712
dtype=dtype,
17151713
keepdims=keepdims,
@@ -1952,11 +1950,10 @@ def reduce_hypot(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
19521950

19531951
usm_x = dpnp.get_usm_ndarray(x)
19541952
return dpnp_wrap_reduction_call(
1955-
x,
1953+
usm_x,
19561954
out,
19571955
dpt.reduce_hypot,
1958-
_get_accumulation_res_dt,
1959-
usm_x,
1956+
_get_accumulation_res_dt(x, dtype),
19601957
axis=axis,
19611958
dtype=dtype,
19621959
keepdims=keepdims,

dpnp/dpnp_utils/dpnp_utils_reduction.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@
2929
__all__ = ["dpnp_wrap_reduction_call"]
3030

3131

32-
def dpnp_wrap_reduction_call(
33-
a, out, _reduction_fn, _get_res_dt_fn, *args, **kwargs
34-
):
32+
def dpnp_wrap_reduction_call(usm_a, out, _reduction_fn, res_dt, **kwargs):
3533
"""Wrap a reduction call from dpctl.tensor interface."""
3634

3735
input_out = out
@@ -40,16 +38,12 @@ def dpnp_wrap_reduction_call(
4038
else:
4139
dpnp.check_supported_arrays_type(out)
4240

43-
# fetch dtype from the passed kwargs to the reduction call
44-
dtype = kwargs.get("dtype", None)
45-
4641
# dpctl requires strict data type matching of out array with the result
47-
res_dt = _get_res_dt_fn(a, dtype, out)
4842
if out.dtype != res_dt:
4943
out = dpnp.astype(out, dtype=res_dt, copy=False)
5044

5145
usm_out = dpnp.get_usm_ndarray(out)
5246

5347
kwargs["out"] = usm_out
54-
res_usm = _reduction_fn(*args, **kwargs)
48+
res_usm = _reduction_fn(usm_a, **kwargs)
5549
return dpnp.get_result_array(res_usm, input_out, casting="unsafe")

0 commit comments

Comments
 (0)