Skip to content

Commit 80a77e0

Browse files
committed
Implements dpt.cumulative_logsumexp, dpt.cumulative_prod, and dpt.cumulative_sum
The Python bindings for these functions are implemented in a new submodule `_tensor_accumulation_impl`
1 parent 85ccecb commit 80a77e0

16 files changed

+2730
-154
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,17 @@ set(_tensor_linalg_impl_sources
158158
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
159159
${_linalg_sources}
160160
)
161+
set(_accumulator_sources
162+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp
163+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp
164+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
165+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
166+
)
167+
set(_tensor_accumulation_impl_sources
168+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp
169+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
170+
${_accumulator_sources}
171+
)
161172

162173
set(_py_trgts)
163174

@@ -186,6 +197,11 @@ pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources})
186197
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources})
187198
list(APPEND _py_trgts ${python_module_name})
188199

200+
set(python_module_name _tensor_accumulation_impl)
201+
pybind11_add_module(${python_module_name} MODULE ${_tensor_accumulation_impl_sources})
202+
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_accumulation_impl_sources})
203+
list(APPEND _py_trgts ${python_module_name})
204+
189205
set(_clang_prefix "")
190206
if (WIN32)
191207
set(_clang_prefix "/clang:")
@@ -203,6 +219,7 @@ list(APPEND _no_fast_math_sources
203219
${_reduction_sources}
204220
${_sorting_sources}
205221
${_linalg_sources}
222+
${_accumulator_sources}
206223
)
207224

208225
foreach(_src_fn ${_no_fast_math_sources})

dpctl/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
from dpctl.tensor._usmarray import usm_ndarray
9797
from dpctl.tensor._utility_functions import all, any
9898

99+
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
99100
from ._array_api import __array_api_version__, __array_namespace_info__
100101
from ._clip import clip
101102
from ._constants import e, inf, nan, newaxis, pi
@@ -367,4 +368,7 @@
367368
"tensordot",
368369
"vecdot",
369370
"searchsorted",
371+
"cumulative_logsumexp",
372+
"cumulative_prod",
373+
"cumulative_sum",
370374
]

dpctl/tensor/_accumulation.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2024 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import operator
18+
19+
from numpy.core.numeric import normalize_axis_index
20+
21+
import dpctl
22+
import dpctl.tensor as dpt
23+
import dpctl.tensor._tensor_accumulation_impl as tai
24+
import dpctl.tensor._tensor_impl as ti
25+
from dpctl.tensor._type_utils import _to_device_supported_dtype
26+
27+
28+
def _default_accumulation_dtype(inp_dt, q):
    """Returns the default accumulation data type for input data
    type `inp_dt` when accumulation is performed on queue `q`
    """
    kind = inp_dt.kind
    if kind in "bi":
        # boolean/signed integers accumulate in the device default
        # integer type unless the input type is wider
        res_dt = dpt.dtype(ti.default_device_int_type(q))
        if inp_dt.itemsize > res_dt.itemsize:
            res_dt = inp_dt
    elif kind in "u":
        # unsigned integers: unsigned variant of the default integer
        # type, provided it can represent every input value
        res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
        res_ii = dpt.iinfo(res_dt)
        inp_ii = dpt.iinfo(inp_dt)
        if inp_ii.min < res_ii.min or inp_ii.max > res_ii.max:
            res_dt = inp_dt
    elif kind in "fc":
        # floating-point and complex inputs accumulate as themselves
        res_dt = inp_dt

    return res_dt
51+
52+
53+
def _default_accumulation_dtype_fp_types(inp_dt, q):
    """Returns the default accumulation data type for input data
    type `inp_dt` when accumulation is performed on queue `q` and
    the accumulation supports only floating-point data types
    """
    kind = inp_dt.kind
    if kind in "biu":
        # integral input accumulates in the device default fp type;
        # widen when that type cannot hold all input values
        res_dt = dpt.dtype(ti.default_device_fp_type(q))
        if not dpt.can_cast(inp_dt, res_dt):
            has_fp64 = q.sycl_device.has_aspect_fp64
            res_dt = dpt.float64 if has_fp64 else dpt.float32
    elif kind in "f":
        res_dt = inp_dt
    elif kind in "c":
        raise TypeError("reduction not defined for complex types")

    return res_dt
71+
72+
73+
def _accumulate_over_axis(
    x,
    axis,
    dtype,
    include_initial,
    _accumulate_fn,
    _accumulate_include_initial_fn,
    _dtype_supported,
    _default_accumulation_type_fn,
):
    """Shared driver for cumulative accumulation functions.

    Moves `axis` to the last position, selects the result data type,
    dispatches to the accumulation kernel (possibly through a
    type-converting temporary buffer), and restores the original axis
    order.

    Args:
        x (usm_ndarray): input array
        axis (int or None): axis along which to accumulate; `None` is
            only valid for arrays with at most one dimension
        dtype: requested result data type, or `None` to use the
            default for `x.dtype`
        include_initial (bool): when True the result has one extra
            element along `axis`
        _accumulate_fn: kernel performing accumulation without the
            initial value
        _accumulate_include_initial_fn: kernel performing accumulation
            with the initial value included
        _dtype_supported: predicate `(src_dt, dst_dt) -> bool` telling
            whether a direct kernel implementation exists
        _default_accumulation_type_fn: maps (input dtype, queue) to
            the default accumulation dtype

    Returns:
        usm_ndarray: the accumulation result

    Raises:
        TypeError: if `x` is not a usm_ndarray
        ValueError: if `axis` is None for an array of dimension
            greater than one
        RuntimeError: if the automatically determined accumulation
            dtype has no direct implementation
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    nd = x.ndim
    if axis is None:
        if nd > 1:
            # was a bare `raise ValueError`; give callers a message
            raise ValueError(
                f"`axis` cannot be `None` for array of dimension {nd}"
            )
        axis = 0
    else:
        axis = operator.index(axis)
        axis = normalize_axis_index(axis, nd, "axis")
    a1 = axis + 1
    if a1 == nd:
        perm = list(range(nd))
        arr = x
    else:
        # move the accumulation axis to the end; kernels accumulate
        # over trailing dimensions
        perm = [i for i in range(nd) if i != axis] + [
            axis,
        ]
        arr = dpt.permute_dims(x, perm)
    q = x.sycl_queue
    inp_dt = x.dtype
    if dtype is None:
        res_dt = _default_accumulation_type_fn(inp_dt, q)
    else:
        res_dt = dpt.dtype(dtype)
        res_dt = _to_device_supported_dtype(res_dt, q.sycl_device)
    sh = arr.shape
    # including the initial value adds one element along the axis
    res_sh = sh[:-1] + (sh[-1] + 1,) if include_initial else sh
    res_usm_type = x.usm_type

    host_tasks_list = []
    if _dtype_supported(inp_dt, res_dt):
        # direct kernel support for (inp_dt -> res_dt)
        res = dpt.empty(
            res_sh, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
        )
        if not include_initial:
            ht_e, _ = _accumulate_fn(
                src=arr,
                trailing_dims_to_accumulate=1,
                dst=res,
                sycl_queue=q,
            )
        else:
            ht_e, _ = _accumulate_include_initial_fn(
                src=arr,
                dst=res,
                sycl_queue=q,
            )
        host_tasks_list.append(ht_e)
    else:
        if dtype is None:
            raise RuntimeError(
                "Automatically determined accumulation data type does not "
                "have direct implementation"
            )
        if _dtype_supported(res_dt, res_dt):
            # cast the input to the requested dtype first, then
            # accumulate entirely in that dtype
            tmp = dpt.empty(
                arr.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
            )
            ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
                src=arr, dst=tmp, sycl_queue=q
            )
            host_tasks_list.append(ht_e_cpy)
            res = dpt.empty(
                res_sh, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
            )
            if not include_initial:
                ht_e, _ = _accumulate_fn(
                    src=tmp,
                    trailing_dims_to_accumulate=1,
                    dst=res,
                    sycl_queue=q,
                    depends=[cpy_e],
                )
            else:
                ht_e, _ = _accumulate_include_initial_fn(
                    src=tmp,
                    dst=res,
                    sycl_queue=q,
                    depends=[cpy_e],
                )
            host_tasks_list.append(ht_e)
        else:
            # accumulate in an intermediate buffer dtype, then cast
            # the result to res_dt
            buf_dt = _default_accumulation_dtype(inp_dt, q)
            tmp = dpt.empty(
                arr.shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q
            )
            ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
                src=arr, dst=tmp, sycl_queue=q
            )
            tmp_res = dpt.empty(
                res_sh, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q
            )
            host_tasks_list.append(ht_e_cpy)
            res = dpt.empty(
                res_sh, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
            )
            if not include_initial:
                # BUGFIX: accumulate from the type-converted copy
                # `tmp`, not `arr`. The original passed `src=arr`,
                # which left the copy into `tmp` unused (despite the
                # `depends=[cpy_e]` edge) and handed the kernel a
                # source dtype this branch exists to avoid.
                ht_e, a_e = _accumulate_fn(
                    src=tmp,
                    trailing_dims_to_accumulate=1,
                    dst=tmp_res,
                    sycl_queue=q,
                    depends=[cpy_e],
                )
            else:
                ht_e, a_e = _accumulate_include_initial_fn(
                    src=tmp,
                    dst=tmp_res,
                    sycl_queue=q,
                    depends=[cpy_e],
                )
            host_tasks_list.append(ht_e)
            ht_e_cpy2, _ = ti._copy_usm_ndarray_into_usm_ndarray(
                src=tmp_res, dst=res, sycl_queue=q, depends=[a_e]
            )
            host_tasks_list.append(ht_e_cpy2)
    if a1 != nd:
        # undo the axis permutation applied above
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        res = dpt.permute_dims(res, inv_perm)
    dpctl.SyclEvent.wait_for(host_tasks_list)

    return res
206+
207+
208+
def cumulative_sum(x, axis=None, dtype=None, include_initial=False):
    """Computes the cumulative sum of elements of `x` along `axis`
    by delegating to the shared accumulation driver with the
    cumulative-sum kernels.
    """
    return _accumulate_over_axis(
        x,
        axis,
        dtype,
        include_initial,
        _accumulate_fn=tai._cumsum_over_axis,
        _accumulate_include_initial_fn=tai._cumsum_final_axis_include_initial,
        _dtype_supported=tai._cumsum_dtype_supported,
        _default_accumulation_type_fn=_default_accumulation_dtype,
    )
219+
220+
221+
def cumulative_prod(x, axis=None, dtype=None, include_initial=False):
    """Computes the cumulative product of elements of `x` along
    `axis` by delegating to the shared accumulation driver with the
    cumulative-product kernels.
    """
    return _accumulate_over_axis(
        x,
        axis,
        dtype,
        include_initial,
        _accumulate_fn=tai._cumprod_over_axis,
        _accumulate_include_initial_fn=tai._cumprod_final_axis_include_initial,
        _dtype_supported=tai._cumprod_dtype_supported,
        _default_accumulation_type_fn=_default_accumulation_dtype,
    )
232+
233+
234+
def cumulative_logsumexp(x, axis=None, dtype=None, include_initial=False):
    """Computes the cumulative logsumexp of elements of `x` along
    `axis` by delegating to the shared accumulation driver with the
    cumulative-logsumexp kernels (floating-point accumulation types
    only).
    """
    return _accumulate_over_axis(
        x,
        axis,
        dtype,
        include_initial,
        _accumulate_fn=tai._cumlogsumexp_over_axis,
        _accumulate_include_initial_fn=(
            tai._cumlogsumexp_final_axis_include_initial
        ),
        _dtype_supported=tai._cumlogsumexp_dtype_supported,
        _default_accumulation_type_fn=_default_accumulation_dtype_fp_types,
    )

0 commit comments

Comments
 (0)