Skip to content

Commit a0c2aac

Browse files
authored
Update reduction data types for 2023.12 array API specification, update __array_api_version__ (#1621)
* Increase `__array_api_version__` to 2023.12 Also changes docstrings in _array_api.py * Aligns reductions with 2023.12 array API spec Floating point data types are no longer promoted based on item size * Fix `device` kwarg-only argument being used as positional for calls to `default_dtypes` throughout tests
1 parent 6abcd34 commit a0c2aac

8 files changed

+80
-110
lines changed

dpctl/tensor/_accumulation.py

Lines changed: 5 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,53 +20,14 @@
2020
import dpctl.tensor as dpt
2121
import dpctl.tensor._tensor_accumulation_impl as tai
2222
import dpctl.tensor._tensor_impl as ti
23-
from dpctl.tensor._type_utils import _to_device_supported_dtype
23+
from dpctl.tensor._type_utils import (
24+
_default_accumulation_dtype,
25+
_default_accumulation_dtype_fp_types,
26+
_to_device_supported_dtype,
27+
)
2428
from dpctl.utils import ExecutionPlacementError
2529

2630

27-
def _default_accumulation_dtype(inp_dt, q):
28-
"""Gives default output data type for given input data
29-
type `inp_dt` when accumulation is performed on queue `q`
30-
"""
31-
inp_kind = inp_dt.kind
32-
if inp_kind in "bi":
33-
res_dt = dpt.dtype(ti.default_device_int_type(q))
34-
if inp_dt.itemsize > res_dt.itemsize:
35-
res_dt = inp_dt
36-
elif inp_kind in "u":
37-
res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
38-
res_ii = dpt.iinfo(res_dt)
39-
inp_ii = dpt.iinfo(inp_dt)
40-
if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
41-
pass
42-
else:
43-
res_dt = inp_dt
44-
elif inp_kind in "fc":
45-
res_dt = inp_dt
46-
47-
return res_dt
48-
49-
50-
def _default_accumulation_dtype_fp_types(inp_dt, q):
51-
"""Gives default output data type for given input data
52-
type `inp_dt` when accumulation is performed on queue `q`
53-
and the accumulation supports only floating-point data types
54-
"""
55-
inp_kind = inp_dt.kind
56-
if inp_kind in "biu":
57-
res_dt = dpt.dtype(ti.default_device_fp_type(q))
58-
can_cast_v = dpt.can_cast(inp_dt, res_dt)
59-
if not can_cast_v:
60-
_fp64 = q.sycl_device.has_aspect_fp64
61-
res_dt = dpt.float64 if _fp64 else dpt.float32
62-
elif inp_kind in "f":
63-
res_dt = inp_dt
64-
elif inp_kind in "c":
65-
raise ValueError("function not defined for complex types")
66-
67-
return res_dt
68-
69-
7031
def _accumulate_common(
7132
x,
7233
axis,

dpctl/tensor/_array_api.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _isdtype_impl(dtype, kind):
4949
raise TypeError(f"Unsupported data type kind: {kind}")
5050

5151

52-
__array_api_version__ = "2022.12"
52+
__array_api_version__ = "2023.12"
5353

5454

5555
class Info:
@@ -80,6 +80,8 @@ def __init__(self):
8080

8181
def capabilities(self):
8282
"""
83+
capabilities()
84+
8385
Returns a dictionary of `dpctl`'s capabilities.
8486
8587
Returns:
@@ -92,12 +94,16 @@ def capabilities(self):
9294

9395
def default_device(self):
9496
"""
97+
default_device()
98+
9599
Returns the default SYCL device.
96100
"""
97101
return dpctl.select_default_device()
98102

99-
def default_dtypes(self, device=None):
103+
def default_dtypes(self, *, device=None):
100104
"""
105+
default_dtypes(*, device=None)
106+
101107
Returns a dictionary of default data types for `device`.
102108
103109
Args:
@@ -129,8 +135,10 @@ def default_dtypes(self, device=None):
129135
"indexing": dpt.dtype(default_device_index_type(device)),
130136
}
131137

132-
def dtypes(self, device=None, kind=None):
138+
def dtypes(self, *, device=None, kind=None):
133139
"""
140+
dtypes(*, device=None, kind=None)
141+
134142
Returns a dictionary of all Array API data types of a specified `kind`
135143
supported by `device`
136144
@@ -193,13 +201,16 @@ def dtypes(self, device=None, kind=None):
193201

194202
def devices(self):
195203
"""
204+
devices()
205+
196206
Returns a list of supported devices.
197207
"""
198208
return dpctl.get_devices()
199209

200210

201211
def __array_namespace_info__():
202-
"""__array_namespace_info__()
212+
"""
213+
__array_namespace_info__()
203214
204215
Returns a namespace with Array API namespace inspection utilities.
205216

dpctl/tensor/_reduction.py

Lines changed: 9 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -21,58 +21,11 @@
2121
import dpctl.tensor._tensor_impl as ti
2222
import dpctl.tensor._tensor_reductions_impl as tri
2323

24-
from ._type_utils import _to_device_supported_dtype
25-
26-
27-
def _default_reduction_dtype(inp_dt, q):
28-
"""Gives default output data type for given input data
29-
type `inp_dt` when reduction is performed on queue `q`
30-
"""
31-
inp_kind = inp_dt.kind
32-
if inp_kind in "bi":
33-
res_dt = dpt.dtype(ti.default_device_int_type(q))
34-
if inp_dt.itemsize > res_dt.itemsize:
35-
res_dt = inp_dt
36-
elif inp_kind in "u":
37-
res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
38-
res_ii = dpt.iinfo(res_dt)
39-
inp_ii = dpt.iinfo(inp_dt)
40-
if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
41-
pass
42-
else:
43-
res_dt = inp_dt
44-
elif inp_kind in "f":
45-
res_dt = dpt.dtype(ti.default_device_fp_type(q))
46-
if res_dt.itemsize < inp_dt.itemsize:
47-
res_dt = inp_dt
48-
elif inp_kind in "c":
49-
res_dt = dpt.dtype(ti.default_device_complex_type(q))
50-
if res_dt.itemsize < inp_dt.itemsize:
51-
res_dt = inp_dt
52-
53-
return res_dt
54-
55-
56-
def _default_reduction_dtype_fp_types(inp_dt, q):
57-
"""Gives default output data type for given input data
58-
type `inp_dt` when reduction is performed on queue `q`
59-
and the reduction supports only floating-point data types
60-
"""
61-
inp_kind = inp_dt.kind
62-
if inp_kind in "biu":
63-
res_dt = dpt.dtype(ti.default_device_fp_type(q))
64-
can_cast_v = dpt.can_cast(inp_dt, res_dt)
65-
if not can_cast_v:
66-
_fp64 = q.sycl_device.has_aspect_fp64
67-
res_dt = dpt.float64 if _fp64 else dpt.float32
68-
elif inp_kind in "f":
69-
res_dt = dpt.dtype(ti.default_device_fp_type(q))
70-
if res_dt.itemsize < inp_dt.itemsize:
71-
res_dt = inp_dt
72-
elif inp_kind in "c":
73-
raise TypeError("reduction not defined for complex types")
74-
75-
return res_dt
24+
from ._type_utils import (
25+
_default_accumulation_dtype,
26+
_default_accumulation_dtype_fp_types,
27+
_to_device_supported_dtype,
28+
)
7629

7730

7831
def _reduction_over_axis(
@@ -237,7 +190,7 @@ def sum(x, axis=None, dtype=None, keepdims=False):
237190
keepdims,
238191
tri._sum_over_axis,
239192
tri._sum_over_axis_dtype_supported,
240-
_default_reduction_dtype,
193+
_default_accumulation_dtype,
241194
)
242195

243196

@@ -299,7 +252,7 @@ def prod(x, axis=None, dtype=None, keepdims=False):
299252
keepdims,
300253
tri._prod_over_axis,
301254
tri._prod_over_axis_dtype_supported,
302-
_default_reduction_dtype,
255+
_default_accumulation_dtype,
303256
)
304257

305258

@@ -356,7 +309,7 @@ def logsumexp(x, axis=None, dtype=None, keepdims=False):
356309
lambda inp_dt, res_dt, *_: tri._logsumexp_over_axis_dtype_supported(
357310
inp_dt, res_dt
358311
),
359-
_default_reduction_dtype_fp_types,
312+
_default_accumulation_dtype_fp_types,
360313
)
361314

362315

@@ -413,7 +366,7 @@ def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
413366
lambda inp_dt, res_dt, *_: tri._hypot_over_axis_dtype_supported(
414367
inp_dt, res_dt
415368
),
416-
_default_reduction_dtype_fp_types,
369+
_default_accumulation_dtype_fp_types,
417370
)
418371

419372

dpctl/tensor/_type_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,49 @@ def isdtype(dtype, kind):
733733
raise TypeError(f"Unsupported data type kind: {kind}")
734734

735735

736+
def _default_accumulation_dtype(inp_dt, q):
737+
"""Gives default output data type for given input data
738+
type `inp_dt` when accumulation is performed on queue `q`
739+
"""
740+
inp_kind = inp_dt.kind
741+
if inp_kind in "bi":
742+
res_dt = dpt.dtype(ti.default_device_int_type(q))
743+
if inp_dt.itemsize > res_dt.itemsize:
744+
res_dt = inp_dt
745+
elif inp_kind in "u":
746+
res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
747+
res_ii = dpt.iinfo(res_dt)
748+
inp_ii = dpt.iinfo(inp_dt)
749+
if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
750+
pass
751+
else:
752+
res_dt = inp_dt
753+
elif inp_kind in "fc":
754+
res_dt = inp_dt
755+
756+
return res_dt
757+
758+
759+
def _default_accumulation_dtype_fp_types(inp_dt, q):
760+
"""Gives default output data type for given input data
761+
type `inp_dt` when accumulation is performed on queue `q`
762+
and the accumulation supports only floating-point data types
763+
"""
764+
inp_kind = inp_dt.kind
765+
if inp_kind in "biu":
766+
res_dt = dpt.dtype(ti.default_device_fp_type(q))
767+
can_cast_v = dpt.can_cast(inp_dt, res_dt)
768+
if not can_cast_v:
769+
_fp64 = q.sycl_device.has_aspect_fp64
770+
res_dt = dpt.float64 if _fp64 else dpt.float32
771+
elif inp_kind in "f":
772+
res_dt = inp_dt
773+
elif inp_kind in "c":
774+
raise ValueError("function not defined for complex types")
775+
776+
return res_dt
777+
778+
736779
__all__ = [
737780
"_find_buf_dtype",
738781
"_find_buf_dtype2",
@@ -753,4 +796,6 @@ def isdtype(dtype, kind):
753796
"WeakIntegralType",
754797
"WeakFloatingType",
755798
"WeakComplexType",
799+
"_default_accumulation_dtype",
800+
"_default_accumulation_dtype_fp_types",
756801
]

dpctl/tests/test_tensor_array_api_inspection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_array_api_inspection_default_dtypes():
9696

9797
info = dpt.__array_namespace_info__()
9898
default_dts_nodev = info.default_dtypes()
99-
default_dts_dev = info.default_dtypes(dev)
99+
default_dts_dev = info.default_dtypes(device=dev)
100100

101101
assert (
102102
int_dt == default_dts_nodev["integral"] == default_dts_dev["integral"]

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ def test_mixed_index_getitem():
491491
x = dpt.reshape(dpt.arange(1000, dtype="i4"), (10, 10, 10))
492492
i1b = dpt.ones(10, dtype="?")
493493
info = x.__array_namespace__().__array_namespace_info__()
494-
ind_dt = info.default_dtypes(x.device)["indexing"]
494+
ind_dt = info.default_dtypes(device=x.device)["indexing"]
495495
i0 = dpt.asarray([0, 2, 3], dtype=ind_dt)[:, dpt.newaxis]
496496
i2 = dpt.asarray([3, 4, 7], dtype=ind_dt)[:, dpt.newaxis]
497497
y = x[i0, i1b, i2]
@@ -503,7 +503,7 @@ def test_mixed_index_setitem():
503503
x = dpt.reshape(dpt.arange(1000, dtype="i4"), (10, 10, 10))
504504
i1b = dpt.ones(10, dtype="?")
505505
info = x.__array_namespace__().__array_namespace_info__()
506-
ind_dt = info.default_dtypes(x.device)["indexing"]
506+
ind_dt = info.default_dtypes(device=x.device)["indexing"]
507507
i0 = dpt.asarray([0, 2, 3], dtype=ind_dt)[:, dpt.newaxis]
508508
i2 = dpt.asarray([3, 4, 7], dtype=ind_dt)[:, dpt.newaxis]
509509
v_shape = (3, int(dpt.sum(i1b, dtype="i8")))

dpctl/tests/test_usm_ndarray_reductions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def test_logsumexp_complex():
406406
get_queue_or_skip()
407407

408408
x = dpt.zeros(1, dtype="c8")
409-
with pytest.raises(TypeError):
409+
with pytest.raises(ValueError):
410410
dpt.logsumexp(x)
411411

412412

@@ -470,7 +470,7 @@ def test_hypot_complex():
470470
get_queue_or_skip()
471471

472472
x = dpt.zeros(1, dtype="c8")
473-
with pytest.raises(TypeError):
473+
with pytest.raises(ValueError):
474474
dpt.reduce_hypot(x)
475475

476476

dpctl/tests/test_usm_ndarray_searchsorted.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def _check(hay_stack, needles, needles_np):
1212
assert hay_stack.ndim == 1
1313

1414
info_ = dpt.__array_namespace_info__()
15-
default_dts_dev = info_.default_dtypes(hay_stack.device)
15+
default_dts_dev = info_.default_dtypes(device=hay_stack.device)
1616
index_dt = default_dts_dev["indexing"]
1717

1818
p_left = dpt.searchsorted(hay_stack, needles, side="left")

0 commit comments

Comments
 (0)