Skip to content

Commit aa24801

Browse files
committed
Aligns reductions with 2023.12 array API spec
Floating point data types are no longer promoted based on item size
1 parent f6caaa8 commit aa24801

File tree

4 files changed

+61
-102
lines changed

4 files changed

+61
-102
lines changed

dpctl/tensor/_accumulation.py

Lines changed: 5 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,53 +20,14 @@
2020
import dpctl.tensor as dpt
2121
import dpctl.tensor._tensor_accumulation_impl as tai
2222
import dpctl.tensor._tensor_impl as ti
23-
from dpctl.tensor._type_utils import _to_device_supported_dtype
23+
from dpctl.tensor._type_utils import (
24+
_default_accumulation_dtype,
25+
_default_accumulation_dtype_fp_types,
26+
_to_device_supported_dtype,
27+
)
2428
from dpctl.utils import ExecutionPlacementError
2529

2630

27-
def _default_accumulation_dtype(inp_dt, q):
28-
"""Gives default output data type for given input data
29-
type `inp_dt` when accumulation is performed on queue `q`
30-
"""
31-
inp_kind = inp_dt.kind
32-
if inp_kind in "bi":
33-
res_dt = dpt.dtype(ti.default_device_int_type(q))
34-
if inp_dt.itemsize > res_dt.itemsize:
35-
res_dt = inp_dt
36-
elif inp_kind in "u":
37-
res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
38-
res_ii = dpt.iinfo(res_dt)
39-
inp_ii = dpt.iinfo(inp_dt)
40-
if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
41-
pass
42-
else:
43-
res_dt = inp_dt
44-
elif inp_kind in "fc":
45-
res_dt = inp_dt
46-
47-
return res_dt
48-
49-
50-
def _default_accumulation_dtype_fp_types(inp_dt, q):
51-
"""Gives default output data type for given input data
52-
type `inp_dt` when accumulation is performed on queue `q`
53-
and the accumulation supports only floating-point data types
54-
"""
55-
inp_kind = inp_dt.kind
56-
if inp_kind in "biu":
57-
res_dt = dpt.dtype(ti.default_device_fp_type(q))
58-
can_cast_v = dpt.can_cast(inp_dt, res_dt)
59-
if not can_cast_v:
60-
_fp64 = q.sycl_device.has_aspect_fp64
61-
res_dt = dpt.float64 if _fp64 else dpt.float32
62-
elif inp_kind in "f":
63-
res_dt = inp_dt
64-
elif inp_kind in "c":
65-
raise ValueError("function not defined for complex types")
66-
67-
return res_dt
68-
69-
7031
def _accumulate_common(
7132
x,
7233
axis,

dpctl/tensor/_reduction.py

Lines changed: 9 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -21,58 +21,11 @@
2121
import dpctl.tensor._tensor_impl as ti
2222
import dpctl.tensor._tensor_reductions_impl as tri
2323

24-
from ._type_utils import _to_device_supported_dtype
25-
26-
27-
def _default_reduction_dtype(inp_dt, q):
28-
"""Gives default output data type for given input data
29-
type `inp_dt` when reduction is performed on queue `q`
30-
"""
31-
inp_kind = inp_dt.kind
32-
if inp_kind in "bi":
33-
res_dt = dpt.dtype(ti.default_device_int_type(q))
34-
if inp_dt.itemsize > res_dt.itemsize:
35-
res_dt = inp_dt
36-
elif inp_kind in "u":
37-
res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
38-
res_ii = dpt.iinfo(res_dt)
39-
inp_ii = dpt.iinfo(inp_dt)
40-
if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
41-
pass
42-
else:
43-
res_dt = inp_dt
44-
elif inp_kind in "f":
45-
res_dt = dpt.dtype(ti.default_device_fp_type(q))
46-
if res_dt.itemsize < inp_dt.itemsize:
47-
res_dt = inp_dt
48-
elif inp_kind in "c":
49-
res_dt = dpt.dtype(ti.default_device_complex_type(q))
50-
if res_dt.itemsize < inp_dt.itemsize:
51-
res_dt = inp_dt
52-
53-
return res_dt
54-
55-
56-
def _default_reduction_dtype_fp_types(inp_dt, q):
57-
"""Gives default output data type for given input data
58-
type `inp_dt` when reduction is performed on queue `q`
59-
and the reduction supports only floating-point data types
60-
"""
61-
inp_kind = inp_dt.kind
62-
if inp_kind in "biu":
63-
res_dt = dpt.dtype(ti.default_device_fp_type(q))
64-
can_cast_v = dpt.can_cast(inp_dt, res_dt)
65-
if not can_cast_v:
66-
_fp64 = q.sycl_device.has_aspect_fp64
67-
res_dt = dpt.float64 if _fp64 else dpt.float32
68-
elif inp_kind in "f":
69-
res_dt = dpt.dtype(ti.default_device_fp_type(q))
70-
if res_dt.itemsize < inp_dt.itemsize:
71-
res_dt = inp_dt
72-
elif inp_kind in "c":
73-
raise TypeError("reduction not defined for complex types")
74-
75-
return res_dt
24+
from ._type_utils import (
25+
_default_accumulation_dtype,
26+
_default_accumulation_dtype_fp_types,
27+
_to_device_supported_dtype,
28+
)
7629

7730

7831
def _reduction_over_axis(
@@ -237,7 +190,7 @@ def sum(x, axis=None, dtype=None, keepdims=False):
237190
keepdims,
238191
tri._sum_over_axis,
239192
tri._sum_over_axis_dtype_supported,
240-
_default_reduction_dtype,
193+
_default_accumulation_dtype,
241194
)
242195

243196

@@ -299,7 +252,7 @@ def prod(x, axis=None, dtype=None, keepdims=False):
299252
keepdims,
300253
tri._prod_over_axis,
301254
tri._prod_over_axis_dtype_supported,
302-
_default_reduction_dtype,
255+
_default_accumulation_dtype,
303256
)
304257

305258

@@ -356,7 +309,7 @@ def logsumexp(x, axis=None, dtype=None, keepdims=False):
356309
lambda inp_dt, res_dt, *_: tri._logsumexp_over_axis_dtype_supported(
357310
inp_dt, res_dt
358311
),
359-
_default_reduction_dtype_fp_types,
312+
_default_accumulation_dtype_fp_types,
360313
)
361314

362315

@@ -413,7 +366,7 @@ def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
413366
lambda inp_dt, res_dt, *_: tri._hypot_over_axis_dtype_supported(
414367
inp_dt, res_dt
415368
),
416-
_default_reduction_dtype_fp_types,
369+
_default_accumulation_dtype_fp_types,
417370
)
418371

419372

dpctl/tensor/_type_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,49 @@ def isdtype(dtype, kind):
733733
raise TypeError(f"Unsupported data type kind: {kind}")
734734

735735

736+
def _default_accumulation_dtype(inp_dt, q):
737+
"""Gives default output data type for given input data
738+
type `inp_dt` when accumulation is performed on queue `q`
739+
"""
740+
inp_kind = inp_dt.kind
741+
if inp_kind in "bi":
742+
res_dt = dpt.dtype(ti.default_device_int_type(q))
743+
if inp_dt.itemsize > res_dt.itemsize:
744+
res_dt = inp_dt
745+
elif inp_kind in "u":
746+
res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
747+
res_ii = dpt.iinfo(res_dt)
748+
inp_ii = dpt.iinfo(inp_dt)
749+
if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
750+
pass
751+
else:
752+
res_dt = inp_dt
753+
elif inp_kind in "fc":
754+
res_dt = inp_dt
755+
756+
return res_dt
757+
758+
759+
def _default_accumulation_dtype_fp_types(inp_dt, q):
760+
"""Gives default output data type for given input data
761+
type `inp_dt` when accumulation is performed on queue `q`
762+
and the accumulation supports only floating-point data types
763+
"""
764+
inp_kind = inp_dt.kind
765+
if inp_kind in "biu":
766+
res_dt = dpt.dtype(ti.default_device_fp_type(q))
767+
can_cast_v = dpt.can_cast(inp_dt, res_dt)
768+
if not can_cast_v:
769+
_fp64 = q.sycl_device.has_aspect_fp64
770+
res_dt = dpt.float64 if _fp64 else dpt.float32
771+
elif inp_kind in "f":
772+
res_dt = inp_dt
773+
elif inp_kind in "c":
774+
raise ValueError("function not defined for complex types")
775+
776+
return res_dt
777+
778+
736779
__all__ = [
737780
"_find_buf_dtype",
738781
"_find_buf_dtype2",
@@ -753,4 +796,6 @@ def isdtype(dtype, kind):
753796
"WeakIntegralType",
754797
"WeakFloatingType",
755798
"WeakComplexType",
799+
"_default_accumulation_dtype",
800+
"_default_accumulation_dtype_fp_types",
756801
]

dpctl/tests/test_usm_ndarray_reductions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def test_logsumexp_complex():
406406
get_queue_or_skip()
407407

408408
x = dpt.zeros(1, dtype="c8")
409-
with pytest.raises(TypeError):
409+
with pytest.raises(ValueError):
410410
dpt.logsumexp(x)
411411

412412

@@ -470,7 +470,7 @@ def test_hypot_complex():
470470
get_queue_or_skip()
471471

472472
x = dpt.zeros(1, dtype="c8")
473-
with pytest.raises(TypeError):
473+
with pytest.raises(ValueError):
474474
dpt.reduce_hypot(x)
475475

476476

0 commit comments

Comments
 (0)