Skip to content

Commit 44b779d

Browse files
committed
Implements logsumexp and reduce_hypot
1 parent b437c47 commit 44b779d

File tree

5 files changed

+777
-12
lines changed

5 files changed

+777
-12
lines changed

dpctl/tensor/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,16 @@
164164
tanh,
165165
trunc,
166166
)
167-
from ._reduction import argmax, argmin, max, min, prod, sum
167+
from ._reduction import (
168+
argmax,
169+
argmin,
170+
logsumexp,
171+
max,
172+
min,
173+
prod,
174+
reduce_hypot,
175+
sum,
176+
)
168177
from ._testing import allclose
169178

170179
__all__ = [
@@ -322,4 +331,6 @@
322331
"exp2",
323332
"copysign",
324333
"rsqrt",
334+
"logsumexp",
335+
"reduce_hypot",
325336
]

dpctl/tensor/_reduction.py

Lines changed: 144 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,28 @@ def _default_reduction_dtype(inp_dt, q):
5252
return res_dt
5353

5454

55+
def _default_reduction_dtype_fp_types(inp_dt, q):
56+
"""Gives default output data type for given input data
57+
type `inp_dt` when reduction is performed on queue `q`
58+
and the reduction supports only floating-point data types
59+
"""
60+
inp_kind = inp_dt.kind
61+
if inp_kind in "biu":
62+
res_dt = dpt.dtype(ti.default_device_fp_type(q))
63+
can_cast_v = dpt.can_cast(inp_dt, res_dt)
64+
if not can_cast_v:
65+
_fp64 = q.sycl_device.has_aspect_fp64
66+
res_dt = dpt.float64 if _fp64 else dpt.float32
67+
elif inp_kind in "f":
68+
res_dt = dpt.dtype(ti.default_device_fp_type(q))
69+
if res_dt.itemsize < inp_dt.itemsize:
70+
res_dt = inp_dt
71+
elif inp_kind in "c":
72+
raise TypeError("reduction not defined for complex types")
73+
74+
return res_dt
75+
76+
5577
def _reduction_over_axis(
5678
x,
5779
axis,
@@ -91,12 +113,15 @@ def _reduction_over_axis(
91113
res_shape = res_shape + (1,) * red_nd
92114
inv_perm = sorted(range(nd), key=lambda d: perm[d])
93115
res_shape = tuple(res_shape[i] for i in inv_perm)
94-
return dpt.full(
95-
res_shape,
96-
_identity,
97-
dtype=res_dt,
98-
usm_type=res_usm_type,
99-
sycl_queue=q,
116+
return dpt.astype(
117+
dpt.full(
118+
res_shape,
119+
_identity,
120+
dtype=_default_reduction_type_fn(inp_dt, q),
121+
usm_type=res_usm_type,
122+
sycl_queue=q,
123+
),
124+
res_dt,
100125
)
101126
if red_nd == 0:
102127
return dpt.astype(x, res_dt, copy=False)
@@ -116,7 +141,7 @@ def _reduction_over_axis(
116141
"Automatically determined reduction data type does not "
117142
"have direct implementation"
118143
)
119-
tmp_dt = _default_reduction_dtype(inp_dt, q)
144+
tmp_dt = _default_reduction_type_fn(inp_dt, q)
120145
tmp = dpt.empty(
121146
res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q
122147
)
@@ -161,13 +186,13 @@ def sum(x, axis=None, dtype=None, keepdims=False):
161186
the returned array will have the default real-valued
162187
floating-point data type for the device where input
163188
array `x` is allocated.
164-
* If x` has signed integral data type, the returned array
189+
* If `x` has signed integral data type, the returned array
165190
will have the default signed integral type for the device
166191
where input array `x` is allocated.
167192
* If `x` has unsigned integral data type, the returned array
168193
will have the default unsigned integral type for the device
169194
where input array `x` is allocated.
170-
* If `x` has a complex-valued floating-point data typee,
195+
* If `x` has a complex-valued floating-point data type,
171196
the returned array will have the default complex-valued
172197
floating-pointer data type for the device where input
173198
array `x` is allocated.
@@ -222,13 +247,13 @@ def prod(x, axis=None, dtype=None, keepdims=False):
222247
the returned array will have the default real-valued
223248
floating-point data type for the device where input
224249
array `x` is allocated.
225-
* If x` has signed integral data type, the returned array
250+
* If `x` has signed integral data type, the returned array
226251
will have the default signed integral type for the device
227252
where input array `x` is allocated.
228253
* If `x` has unsigned integral data type, the returned array
229254
will have the default unsigned integral type for the device
230255
where input array `x` is allocated.
231-
* If `x` has a complex-valued floating-point data typee,
256+
* If `x` has a complex-valued floating-point data type,
232257
the returned array will have the default complex-valued
233258
floating-pointer data type for the device where input
234259
array `x` is allocated.
@@ -263,6 +288,114 @@ def prod(x, axis=None, dtype=None, keepdims=False):
263288
)
264289

265290

291+
def logsumexp(x, axis=None, dtype=None, keepdims=False):
292+
"""logsumexp(x, axis=None, dtype=None, keepdims=False)
293+
294+
Calculates the logarithm of the sum of exponentials of elements in the
295+
input array `x`.
296+
297+
Args:
298+
x (usm_ndarray):
299+
input array.
300+
axis (Optional[int, Tuple[int, ...]]):
301+
axis or axes along which values must be computed. If a tuple
302+
of unique integers, values are computed over multiple axes.
303+
If `None`, the result is computed over the entire array.
304+
Default: `None`.
305+
dtype (Optional[dtype]):
306+
data type of the returned array. If `None`, the default data
307+
type is inferred from the "kind" of the input array data type.
308+
* If `x` has a real-valued floating-point data type,
309+
the returned array will have the default real-valued
310+
floating-point data type for the device where input
311+
array `x` is allocated.
312+
* If `x` has a boolean or integral data type, the returned array
313+
will have the default floating point data type for the device
314+
where input array `x` is allocated.
315+
* If `x` has a complex-valued floating-point data type,
316+
an error is raised.
317+
If the data type (either specified or resolved) differs from the
318+
data type of `x`, the input array elements are cast to the
319+
specified data type before computing the result. Default: `None`.
320+
keepdims (Optional[bool]):
321+
if `True`, the reduced axes (dimensions) are included in the result
322+
as singleton dimensions, so that the returned array remains
323+
compatible with the input arrays according to Array Broadcasting
324+
rules. Otherwise, if `False`, the reduced axes are not included in
325+
the returned array. Default: `False`.
326+
Returns:
327+
usm_ndarray:
328+
an array containing the results. If the result was computed over
329+
the entire array, a zero-dimensional array is returned. The returned
330+
array has the data type as described in the `dtype` parameter
331+
description above.
332+
"""
333+
return _reduction_over_axis(
334+
x,
335+
axis,
336+
dtype,
337+
keepdims,
338+
ti._logsumexp_over_axis,
339+
ti._logsumexp_over_axis_dtype_supported,
340+
_default_reduction_dtype_fp_types,
341+
_identity=-dpt.inf,
342+
)
343+
344+
345+
def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
346+
"""reduce_hypot(x, axis=None, dtype=None, keepdims=False)
347+
348+
Calculates the square root of the sum of squares of elements in the input
349+
array `x`.
350+
351+
Args:
352+
x (usm_ndarray):
353+
input array.
354+
axis (Optional[int, Tuple[int, ...]]):
355+
axis or axes along which values must be computed. If a tuple
356+
of unique integers, values are computed over multiple axes.
357+
If `None`, the result is computed over the entire array.
358+
Default: `None`.
359+
dtype (Optional[dtype]):
360+
data type of the returned array. If `None`, the default data
361+
type is inferred from the "kind" of the input array data type.
362+
* If `x` has a real-valued floating-point data type,
363+
the returned array will have the default real-valued
364+
floating-point data type for the device where input
365+
array `x` is allocated.
366+
* If `x` has a boolean or integral data type, the returned array
367+
will have the default floating point data type for the device
368+
where input array `x` is allocated.
369+
* If `x` has a complex-valued floating-point data type,
370+
an error is raised.
371+
If the data type (either specified or resolved) differs from the
372+
data type of `x`, the input array elements are cast to the
373+
specified data type before computing the result. Default: `None`.
374+
keepdims (Optional[bool]):
375+
if `True`, the reduced axes (dimensions) are included in the result
376+
as singleton dimensions, so that the returned array remains
377+
compatible with the input arrays according to Array Broadcasting
378+
rules. Otherwise, if `False`, the reduced axes are not included in
379+
the returned array. Default: `False`.
380+
Returns:
381+
usm_ndarray:
382+
an array containing the results. If the result was computed over
383+
the entire array, a zero-dimensional array is returned. The returned
384+
array has the data type as described in the `dtype` parameter
385+
description above.
386+
"""
387+
return _reduction_over_axis(
388+
x,
389+
axis,
390+
dtype,
391+
keepdims,
392+
ti._hypot_over_axis,
393+
ti._hypot_over_axis_dtype_supported,
394+
_default_reduction_dtype_fp_types,
395+
_identity=0,
396+
)
397+
398+
266399
def _comparison_over_axis(x, axis, keepdims, _reduction_fn):
267400
if not isinstance(x, dpt.usm_ndarray):
268401
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

0 commit comments

Comments
 (0)