Skip to content

Commit 4dd054f

Browse files
committed
Removes _tree_reduction_over_axis
Use lambdas to ignore atomic-specific arguments to hypot and logsumexp dtype_supported functions
1 parent 4e2789d commit 4dd054f

File tree

1 file changed

+8
-96
lines changed

1 file changed

+8
-96
lines changed

dpctl/tensor/_reduction.py

Lines changed: 8 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -288,98 +288,6 @@ def prod(x, axis=None, dtype=None, keepdims=False):
288288
)
289289

290290

291-
def _tree_reduction_over_axis(
292-
x,
293-
axis,
294-
dtype,
295-
keepdims,
296-
_reduction_fn,
297-
_dtype_supported,
298-
_default_reduction_type_fn,
299-
_identity=None,
300-
):
301-
if not isinstance(x, dpt.usm_ndarray):
302-
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
303-
nd = x.ndim
304-
if axis is None:
305-
axis = tuple(range(nd))
306-
if not isinstance(axis, (tuple, list)):
307-
axis = (axis,)
308-
axis = normalize_axis_tuple(axis, nd, "axis")
309-
red_nd = len(axis)
310-
perm = [i for i in range(nd) if i not in axis] + list(axis)
311-
arr2 = dpt.permute_dims(x, perm)
312-
res_shape = arr2.shape[: nd - red_nd]
313-
q = x.sycl_queue
314-
inp_dt = x.dtype
315-
if dtype is None:
316-
res_dt = _default_reduction_type_fn(inp_dt, q)
317-
else:
318-
res_dt = dpt.dtype(dtype)
319-
res_dt = _to_device_supported_dtype(res_dt, q.sycl_device)
320-
321-
res_usm_type = x.usm_type
322-
if x.size == 0:
323-
if _identity is None:
324-
raise ValueError("reduction does not support zero-size arrays")
325-
else:
326-
if keepdims:
327-
res_shape = res_shape + (1,) * red_nd
328-
inv_perm = sorted(range(nd), key=lambda d: perm[d])
329-
res_shape = tuple(res_shape[i] for i in inv_perm)
330-
return dpt.astype(
331-
dpt.full(
332-
res_shape,
333-
_identity,
334-
dtype=_default_reduction_type_fn(inp_dt, q),
335-
usm_type=res_usm_type,
336-
sycl_queue=q,
337-
),
338-
res_dt,
339-
)
340-
if red_nd == 0:
341-
return dpt.astype(x, res_dt, copy=False)
342-
343-
host_tasks_list = []
344-
if _dtype_supported(inp_dt, res_dt):
345-
res = dpt.empty(
346-
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
347-
)
348-
ht_e, _ = _reduction_fn(
349-
src=arr2, trailing_dims_to_reduce=red_nd, dst=res, sycl_queue=q
350-
)
351-
host_tasks_list.append(ht_e)
352-
else:
353-
if dtype is None:
354-
raise RuntimeError(
355-
"Automatically determined reduction data type does not "
356-
"have direct implementation"
357-
)
358-
tmp_dt = _default_reduction_type_fn(inp_dt, q)
359-
tmp = dpt.empty(
360-
res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q
361-
)
362-
ht_e_tmp, r_e = _reduction_fn(
363-
src=arr2, trailing_dims_to_reduce=red_nd, dst=tmp, sycl_queue=q
364-
)
365-
host_tasks_list.append(ht_e_tmp)
366-
res = dpt.empty(
367-
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
368-
)
369-
ht_e, _ = ti._copy_usm_ndarray_into_usm_ndarray(
370-
src=tmp, dst=res, sycl_queue=q, depends=[r_e]
371-
)
372-
host_tasks_list.append(ht_e)
373-
374-
if keepdims:
375-
res_shape = res_shape + (1,) * red_nd
376-
inv_perm = sorted(range(nd), key=lambda d: perm[d])
377-
res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm)
378-
dpctl.SyclEvent.wait_for(host_tasks_list)
379-
380-
return res
381-
382-
383291
def logsumexp(x, axis=None, dtype=None, keepdims=False):
384292
"""logsumexp(x, axis=None, dtype=None, keepdims=False)
385293
@@ -422,13 +330,15 @@ def logsumexp(x, axis=None, dtype=None, keepdims=False):
422330
array has the data type as described in the `dtype` parameter
423331
description above.
424332
"""
425-
return _tree_reduction_over_axis(
333+
return _reduction_over_axis(
426334
x,
427335
axis,
428336
dtype,
429337
keepdims,
430338
ti._logsumexp_over_axis,
431-
ti._logsumexp_over_axis_dtype_supported,
339+
lambda inp_dt, res_dt, *_: ti._logsumexp_over_axis_dtype_supported(
340+
inp_dt, res_dt
341+
),
432342
_default_reduction_dtype_fp_types,
433343
_identity=-dpt.inf,
434344
)
@@ -476,13 +386,15 @@ def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
476386
array has the data type as described in the `dtype` parameter
477387
description above.
478388
"""
479-
return _tree_reduction_over_axis(
389+
return _reduction_over_axis(
480390
x,
481391
axis,
482392
dtype,
483393
keepdims,
484394
ti._hypot_over_axis,
485-
ti._hypot_over_axis_dtype_supported,
395+
lambda inp_dt, res_dt, *_: ti._hypot_over_axis_dtype_supported(
396+
inp_dt, res_dt
397+
),
486398
_default_reduction_dtype_fp_types,
487399
_identity=0,
488400
)

0 commit comments

Comments
 (0)