Skip to content

Commit 03fd737

Browse files
Implements dpctl.tensor.logsumexp and dpctl.tensor.reduce_hypot (#1446)
* Implements logsumexp and reduce_hypot
* Implements dedicated kernels for temp reductions over axes 1 and 0 in contiguous arrays
* logsumexp and reduce_hypot no longer use atomics. This change was made to improve the accuracy of these functions
* Adds tests for reduce_hypot and logsumexp
* Arithmetic reductions no longer use atomics for inexact types. This change is intended to improve the numerical stability of sum and prod
* Removed support of atomic reduction for max and min
* Adds new tests for reductions
* Split reductions into multiple source files
* Remove unnecessary imports of reduction init functions
* Added functions for querying reduction atomic support per type per function
* Corrected ``min`` contig variant typo. These variants were using ``sycl::maximum`` rather than ``sycl::minimum``
* Removes _tree_reduction_over_axis. Use lambdas to ignore atomic-specific arguments to hypot and logsumexp dtype_supported functions
* Always use atomic implementation for min/max if available. For add/multiplies reductions, use tree reduction for FP types, real and complex, to get better round-off accumulation properties.
* ``logaddexp`` implementation moved to math_utils. Reduces code repetition between logsumexp and logaddexp
---------
Co-authored-by: Oleksandr Pavlyk <oleksandr.pavlyk@intel.com>
1 parent 2eba93e commit 03fd737

32 files changed

+6735
-2402
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,17 @@ set(_elementwise_sources
102102
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/true_divide.cpp
103103
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/trunc.cpp
104104
)
105+
set(_reduction_sources
106+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduction_common.cpp
107+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmax.cpp
108+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmin.cpp
109+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/logsumexp.cpp
110+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/max.cpp
111+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/min.cpp
112+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/prod.cpp
113+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduce_hypot.cpp
114+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp
115+
)
105116
set(_tensor_impl_sources
106117
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_py.cpp
107118
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators.cpp
@@ -120,11 +131,11 @@ set(_tensor_impl_sources
120131
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/boolean_reductions.cpp
121132
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/device_support_queries.cpp
122133
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp
123-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
124134
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
125135
)
126136
list(APPEND _tensor_impl_sources
127137
${_elementwise_sources}
138+
${_reduction_sources}
128139
)
129140

130141
set(python_module_name _tensor_impl)
@@ -138,12 +149,13 @@ endif()
138149
set(_no_fast_math_sources
139150
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
140151
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
141-
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reduction_over_axis.cpp
142152
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp
143153
)
144154
list(APPEND _no_fast_math_sources
145155
${_elementwise_sources}
156+
${_reduction_sources}
146157
)
158+
147159
foreach(_src_fn ${_no_fast_math_sources})
148160
get_source_file_property(_cmpl_options_prop ${_src_fn} COMPILE_OPTIONS)
149161
set(_combined_options_prop ${_cmpl_options_prop} "${_clang_prefix}-fno-fast-math")

dpctl/tensor/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,16 @@
165165
tanh,
166166
trunc,
167167
)
168-
from ._reduction import argmax, argmin, max, min, prod, sum
168+
from ._reduction import (
169+
argmax,
170+
argmin,
171+
logsumexp,
172+
max,
173+
min,
174+
prod,
175+
reduce_hypot,
176+
sum,
177+
)
169178
from ._testing import allclose
170179

171180
__all__ = [
@@ -324,4 +333,6 @@
324333
"copysign",
325334
"rsqrt",
326335
"clip",
336+
"logsumexp",
337+
"reduce_hypot",
327338
]

dpctl/tensor/_reduction.py

Lines changed: 148 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,28 @@ def _default_reduction_dtype(inp_dt, q):
5252
return res_dt
5353

5454

55+
def _default_reduction_dtype_fp_types(inp_dt, q):
56+
"""Gives default output data type for given input data
57+
type `inp_dt` when reduction is performed on queue `q`
58+
and the reduction supports only floating-point data types
59+
"""
60+
inp_kind = inp_dt.kind
61+
if inp_kind in "biu":
62+
res_dt = dpt.dtype(ti.default_device_fp_type(q))
63+
can_cast_v = dpt.can_cast(inp_dt, res_dt)
64+
if not can_cast_v:
65+
_fp64 = q.sycl_device.has_aspect_fp64
66+
res_dt = dpt.float64 if _fp64 else dpt.float32
67+
elif inp_kind in "f":
68+
res_dt = dpt.dtype(ti.default_device_fp_type(q))
69+
if res_dt.itemsize < inp_dt.itemsize:
70+
res_dt = inp_dt
71+
elif inp_kind in "c":
72+
raise TypeError("reduction not defined for complex types")
73+
74+
return res_dt
75+
76+
5577
def _reduction_over_axis(
5678
x,
5779
axis,
@@ -91,12 +113,15 @@ def _reduction_over_axis(
91113
res_shape = res_shape + (1,) * red_nd
92114
inv_perm = sorted(range(nd), key=lambda d: perm[d])
93115
res_shape = tuple(res_shape[i] for i in inv_perm)
94-
return dpt.full(
95-
res_shape,
96-
_identity,
97-
dtype=res_dt,
98-
usm_type=res_usm_type,
99-
sycl_queue=q,
116+
return dpt.astype(
117+
dpt.full(
118+
res_shape,
119+
_identity,
120+
dtype=_default_reduction_type_fn(inp_dt, q),
121+
usm_type=res_usm_type,
122+
sycl_queue=q,
123+
),
124+
res_dt,
100125
)
101126
if red_nd == 0:
102127
return dpt.astype(x, res_dt, copy=False)
@@ -116,7 +141,7 @@ def _reduction_over_axis(
116141
"Automatically determined reduction data type does not "
117142
"have direct implementation"
118143
)
119-
tmp_dt = _default_reduction_dtype(inp_dt, q)
144+
tmp_dt = _default_reduction_type_fn(inp_dt, q)
120145
tmp = dpt.empty(
121146
res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q
122147
)
@@ -161,13 +186,13 @@ def sum(x, axis=None, dtype=None, keepdims=False):
161186
the returned array will have the default real-valued
162187
floating-point data type for the device where input
163188
array `x` is allocated.
164-
* If x` has signed integral data type, the returned array
189+
* If `x` has signed integral data type, the returned array
165190
will have the default signed integral type for the device
166191
where input array `x` is allocated.
167192
* If `x` has unsigned integral data type, the returned array
168193
will have the default unsigned integral type for the device
169194
where input array `x` is allocated.
170-
* If `x` has a complex-valued floating-point data typee,
195+
* If `x` has a complex-valued floating-point data type,
171196
the returned array will have the default complex-valued
172197
floating-point data type for the device where input
173198
array `x` is allocated.
@@ -222,13 +247,13 @@ def prod(x, axis=None, dtype=None, keepdims=False):
222247
the returned array will have the default real-valued
223248
floating-point data type for the device where input
224249
array `x` is allocated.
225-
* If x` has signed integral data type, the returned array
250+
* If `x` has signed integral data type, the returned array
226251
will have the default signed integral type for the device
227252
where input array `x` is allocated.
228253
* If `x` has unsigned integral data type, the returned array
229254
will have the default unsigned integral type for the device
230255
where input array `x` is allocated.
231-
* If `x` has a complex-valued floating-point data typee,
256+
* If `x` has a complex-valued floating-point data type,
232257
the returned array will have the default complex-valued
233258
floating-point data type for the device where input
234259
array `x` is allocated.
@@ -263,6 +288,118 @@ def prod(x, axis=None, dtype=None, keepdims=False):
263288
)
264289

265290

291+
def logsumexp(x, axis=None, dtype=None, keepdims=False):
292+
"""logsumexp(x, axis=None, dtype=None, keepdims=False)
293+
294+
Calculates the logarithm of the sum of exponentials of elements in the
295+
input array `x`.
296+
297+
Args:
298+
x (usm_ndarray):
299+
input array.
300+
axis (Optional[int, Tuple[int, ...]]):
301+
axis or axes along which values must be computed. If a tuple
302+
of unique integers, values are computed over multiple axes.
303+
If `None`, the result is computed over the entire array.
304+
Default: `None`.
305+
dtype (Optional[dtype]):
306+
data type of the returned array. If `None`, the default data
307+
type is inferred from the "kind" of the input array data type.
308+
* If `x` has a real-valued floating-point data type,
309+
the returned array will have the default real-valued
310+
floating-point data type for the device where input
311+
array `x` is allocated.
312+
* If `x` has a boolean or integral data type, the returned array
313+
will have the default floating point data type for the device
314+
where input array `x` is allocated.
315+
* If `x` has a complex-valued floating-point data type,
316+
an error is raised.
317+
If the data type (either specified or resolved) differs from the
318+
data type of `x`, the input array elements are cast to the
319+
specified data type before computing the result. Default: `None`.
320+
keepdims (Optional[bool]):
321+
if `True`, the reduced axes (dimensions) are included in the result
322+
as singleton dimensions, so that the returned array remains
323+
compatible with the input arrays according to Array Broadcasting
324+
rules. Otherwise, if `False`, the reduced axes are not included in
325+
the returned array. Default: `False`.
326+
Returns:
327+
usm_ndarray:
328+
an array containing the results. If the result was computed over
329+
the entire array, a zero-dimensional array is returned. The returned
330+
array has the data type as described in the `dtype` parameter
331+
description above.
332+
"""
333+
return _reduction_over_axis(
334+
x,
335+
axis,
336+
dtype,
337+
keepdims,
338+
ti._logsumexp_over_axis,
339+
lambda inp_dt, res_dt, *_: ti._logsumexp_over_axis_dtype_supported(
340+
inp_dt, res_dt
341+
),
342+
_default_reduction_dtype_fp_types,
343+
_identity=-dpt.inf,
344+
)
345+
346+
347+
def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
348+
"""reduce_hypot(x, axis=None, dtype=None, keepdims=False)
349+
350+
Calculates the square root of the sum of squares of elements in the input
351+
array `x`.
352+
353+
Args:
354+
x (usm_ndarray):
355+
input array.
356+
axis (Optional[int, Tuple[int, ...]]):
357+
axis or axes along which values must be computed. If a tuple
358+
of unique integers, values are computed over multiple axes.
359+
If `None`, the result is computed over the entire array.
360+
Default: `None`.
361+
dtype (Optional[dtype]):
362+
data type of the returned array. If `None`, the default data
363+
type is inferred from the "kind" of the input array data type.
364+
* If `x` has a real-valued floating-point data type,
365+
the returned array will have the default real-valued
366+
floating-point data type for the device where input
367+
array `x` is allocated.
368+
* If `x` has a boolean or integral data type, the returned array
369+
will have the default floating point data type for the device
370+
where input array `x` is allocated.
371+
* If `x` has a complex-valued floating-point data type,
372+
an error is raised.
373+
If the data type (either specified or resolved) differs from the
374+
data type of `x`, the input array elements are cast to the
375+
specified data type before computing the result. Default: `None`.
376+
keepdims (Optional[bool]):
377+
if `True`, the reduced axes (dimensions) are included in the result
378+
as singleton dimensions, so that the returned array remains
379+
compatible with the input arrays according to Array Broadcasting
380+
rules. Otherwise, if `False`, the reduced axes are not included in
381+
the returned array. Default: `False`.
382+
Returns:
383+
usm_ndarray:
384+
an array containing the results. If the result was computed over
385+
the entire array, a zero-dimensional array is returned. The returned
386+
array has the data type as described in the `dtype` parameter
387+
description above.
388+
"""
389+
return _reduction_over_axis(
390+
x,
391+
axis,
392+
dtype,
393+
keepdims,
394+
ti._hypot_over_axis,
395+
lambda inp_dt, res_dt, *_: ti._hypot_over_axis_dtype_supported(
396+
inp_dt, res_dt
397+
),
398+
_default_reduction_dtype_fp_types,
399+
_identity=0,
400+
)
401+
402+
266403
def _comparison_over_axis(x, axis, keepdims, _reduction_fn):
267404
if not isinstance(x, dpt.usm_ndarray):
268405
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <limits>
3232
#include <type_traits>
3333

34+
#include "utils/math_utils.hpp"
3435
#include "utils/offset_utils.hpp"
3536
#include "utils/type_dispatch.hpp"
3637
#include "utils/type_utils.hpp"
@@ -61,7 +62,8 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
6162

6263
resT operator()(const argT1 &in1, const argT2 &in2) const
6364
{
64-
return impl<resT>(in1, in2);
65+
using dpctl::tensor::math_utils::logaddexp;
66+
return logaddexp<resT>(in1, in2);
6567
}
6668

6769
template <int vec_sz>
@@ -79,34 +81,15 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
7981
impl_finite<resT>(-std::abs(diff[i]));
8082
}
8183
else {
82-
res[i] = impl<resT>(in1[i], in2[i]);
84+
using dpctl::tensor::math_utils::logaddexp;
85+
res[i] = logaddexp<resT>(in1[i], in2[i]);
8386
}
8487
}
8588

8689
return res;
8790
}
8891

8992
private:
90-
template <typename T> T impl(T const &in1, T const &in2) const
91-
{
92-
if (in1 == in2) { // handle signed infinities
93-
const T log2 = std::log(T(2));
94-
return in1 + log2;
95-
}
96-
else {
97-
const T tmp = in1 - in2;
98-
if (tmp > 0) {
99-
return in1 + std::log1p(std::exp(-tmp));
100-
}
101-
else if (tmp <= 0) {
102-
return in2 + std::log1p(std::exp(tmp));
103-
}
104-
else {
105-
return std::numeric_limits<T>::quiet_NaN();
106-
}
107-
}
108-
}
109-
11093
template <typename T> T impl_finite(T const &in) const
11194
{
11295
return (in > 0) ? (in + std::log1p(std::exp(-in)))

0 commit comments

Comments
 (0)