@@ -288,98 +288,6 @@ def prod(x, axis=None, dtype=None, keepdims=False):
288
288
)
289
289
290
290
291
def _tree_reduction_over_axis(
    x,
    axis,
    dtype,
    keepdims,
    _reduction_fn,
    _dtype_supported,
    _default_reduction_type_fn,
    _identity=None,
):
    """Reduce ``x`` over ``axis`` by delegating to ``_reduction_fn``.

    The reduced axes are moved to the trailing positions, the reduction
    callable is invoked on the permuted array, and (for ``keepdims``) the
    result is reshaped and permuted back so the reduced axes reappear with
    extent one in their original positions.

    Args:
        x: Input array; must be a ``dpt.usm_ndarray``.
        axis: ``None`` (reduce over every axis), a single axis, or a
            tuple/list of axes. Normalized with ``normalize_axis_tuple``.
        dtype: Requested result data type, or ``None`` to use
            ``_default_reduction_type_fn(x.dtype, x.sycl_queue)``.
        keepdims: If true, reduced axes are kept with size one.
        _reduction_fn: Callable invoked with keywords ``src``,
            ``trailing_dims_to_reduce``, ``dst`` and ``sycl_queue``; it must
            return a pair of events (the first is waited on before
            returning, the second is used as a dependency of the copy).
        _dtype_supported: Callable ``(inp_dt, res_dt)``; a truthy result
            selects the direct path writing straight into the result.
        _default_reduction_type_fn: Callable ``(inp_dt, queue)`` returning
            the default accumulation/result data type.
        _identity: Fill value used for zero-size inputs; ``None`` means
            zero-size inputs are rejected.

    Returns:
        The reduced array, allocated on ``x.sycl_queue`` with ``x.usm_type``
        (except when no axis is reduced, where ``dpt.astype(x, ...,
        copy=False)`` is returned).

    Raises:
        TypeError: ``x`` is not a ``dpt.usm_ndarray``.
        ValueError: ``x`` has zero size and ``_identity`` is ``None``.
        RuntimeError: ``dtype`` is ``None`` and ``_dtype_supported`` rejects
            the automatically chosen result type.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)} ")

    ndim = x.ndim
    if axis is None:
        axis = tuple(range(ndim))
    elif not isinstance(axis, (tuple, list)):
        axis = (axis,)
    axis = normalize_axis_tuple(axis, ndim, "axis")
    n_reduced = len(axis)

    # Kept axes first, reduced axes last, so the kernel reduces trailing dims.
    kept_axes = [d for d in range(ndim) if d not in axis]
    perm = kept_axes + list(axis)
    permuted = dpt.permute_dims(x, perm)
    out_shape = permuted.shape[: ndim - n_reduced]

    exec_q = x.sycl_queue
    src_dt = x.dtype
    if dtype is not None:
        out_dt = _to_device_supported_dtype(
            dpt.dtype(dtype), exec_q.sycl_device
        )
    else:
        out_dt = _default_reduction_type_fn(src_dt, exec_q)

    out_usm_type = x.usm_type

    def _allocate(dt):
        # Uninitialized array of the (non-keepdims) result shape.
        return dpt.empty(
            out_shape, dtype=dt, usm_type=out_usm_type, sycl_queue=exec_q
        )

    if x.size == 0:
        if _identity is None:
            raise ValueError("reduction does not support zero-size arrays")
        if keepdims:
            padded = out_shape + (1,) * n_reduced
            # argsort of perm == inverse permutation
            undo = sorted(range(ndim), key=perm.__getitem__)
            out_shape = tuple(padded[j] for j in undo)
        filled = dpt.full(
            out_shape,
            _identity,
            dtype=_default_reduction_type_fn(src_dt, exec_q),
            usm_type=out_usm_type,
            sycl_queue=exec_q,
        )
        return dpt.astype(filled, out_dt)

    if n_reduced == 0:
        # Nothing to reduce: only a (possibly no-op) type conversion.
        return dpt.astype(x, out_dt, copy=False)

    pending = []
    if _dtype_supported(src_dt, out_dt):
        # Direct path: reduce straight into the result array.
        result = _allocate(out_dt)
        host_ev, _ = _reduction_fn(
            src=permuted,
            trailing_dims_to_reduce=n_reduced,
            dst=result,
            sycl_queue=exec_q,
        )
        pending.append(host_ev)
    elif dtype is None:
        raise RuntimeError(
            "Automatically determined reduction data type does not "
            "have direct implementation"
        )
    else:
        # Indirect path: reduce into the default type, then copy/cast into
        # the requested type once the reduction has completed.
        scratch = _allocate(_default_reduction_type_fn(src_dt, exec_q))
        host_ev_reduce, reduce_ev = _reduction_fn(
            src=permuted,
            trailing_dims_to_reduce=n_reduced,
            dst=scratch,
            sycl_queue=exec_q,
        )
        pending.append(host_ev_reduce)
        result = _allocate(out_dt)
        host_ev_copy, _ = ti._copy_usm_ndarray_into_usm_ndarray(
            src=scratch, dst=result, sycl_queue=exec_q, depends=[reduce_ev]
        )
        pending.append(host_ev_copy)

    if keepdims:
        padded_shape = out_shape + (1,) * n_reduced
        undo = sorted(range(ndim), key=perm.__getitem__)
        result = dpt.permute_dims(dpt.reshape(result, padded_shape), undo)
    dpctl.SyclEvent.wait_for(pending)

    return result
381
-
382
-
383
291
def logsumexp (x , axis = None , dtype = None , keepdims = False ):
384
292
"""logsumexp(x, axis=None, dtype=None, keepdims=False)
385
293
@@ -422,13 +330,15 @@ def logsumexp(x, axis=None, dtype=None, keepdims=False):
422
330
array has the data type as described in the `dtype` parameter
423
331
description above.
424
332
"""
425
- return _tree_reduction_over_axis (
333
+ return _reduction_over_axis (
426
334
x ,
427
335
axis ,
428
336
dtype ,
429
337
keepdims ,
430
338
ti ._logsumexp_over_axis ,
431
- ti ._logsumexp_over_axis_dtype_supported ,
339
+ lambda inp_dt , res_dt , * _ : ti ._logsumexp_over_axis_dtype_supported (
340
+ inp_dt , res_dt
341
+ ),
432
342
_default_reduction_dtype_fp_types ,
433
343
_identity = - dpt .inf ,
434
344
)
@@ -476,13 +386,15 @@ def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
476
386
array has the data type as described in the `dtype` parameter
477
387
description above.
478
388
"""
479
- return _tree_reduction_over_axis (
389
+ return _reduction_over_axis (
480
390
x ,
481
391
axis ,
482
392
dtype ,
483
393
keepdims ,
484
394
ti ._hypot_over_axis ,
485
- ti ._hypot_over_axis_dtype_supported ,
395
+ lambda inp_dt , res_dt , * _ : ti ._hypot_over_axis_dtype_supported (
396
+ inp_dt , res_dt
397
+ ),
486
398
_default_reduction_dtype_fp_types ,
487
399
_identity = 0 ,
488
400
)
0 commit comments