@@ -52,6 +52,28 @@ def _default_reduction_dtype(inp_dt, q):
52
52
return res_dt
53
53
54
54
55
+ def _default_reduction_dtype_fp_types (inp_dt , q ):
56
+ """Gives default output data type for given input data
57
+ type `inp_dt` when reduction is performed on queue `q`
58
+ and the reduction supports only floating-point data types
59
+ """
60
+ inp_kind = inp_dt .kind
61
+ if inp_kind in "biu" :
62
+ res_dt = dpt .dtype (ti .default_device_fp_type (q ))
63
+ can_cast_v = dpt .can_cast (inp_dt , res_dt )
64
+ if not can_cast_v :
65
+ _fp64 = q .sycl_device .has_aspect_fp64
66
+ res_dt = dpt .float64 if _fp64 else dpt .float32
67
+ elif inp_kind in "f" :
68
+ res_dt = dpt .dtype (ti .default_device_fp_type (q ))
69
+ if res_dt .itemsize < inp_dt .itemsize :
70
+ res_dt = inp_dt
71
+ elif inp_kind in "c" :
72
+ raise TypeError ("reduction not defined for complex types" )
73
+
74
+ return res_dt
75
+
76
+
55
77
def _reduction_over_axis (
56
78
x ,
57
79
axis ,
@@ -91,12 +113,15 @@ def _reduction_over_axis(
91
113
res_shape = res_shape + (1 ,) * red_nd
92
114
inv_perm = sorted (range (nd ), key = lambda d : perm [d ])
93
115
res_shape = tuple (res_shape [i ] for i in inv_perm )
94
- return dpt .full (
95
- res_shape ,
96
- _identity ,
97
- dtype = res_dt ,
98
- usm_type = res_usm_type ,
99
- sycl_queue = q ,
116
+ return dpt .astype (
117
+ dpt .full (
118
+ res_shape ,
119
+ _identity ,
120
+ dtype = _default_reduction_type_fn (inp_dt , q ),
121
+ usm_type = res_usm_type ,
122
+ sycl_queue = q ,
123
+ ),
124
+ res_dt ,
100
125
)
101
126
if red_nd == 0 :
102
127
return dpt .astype (x , res_dt , copy = False )
@@ -116,7 +141,7 @@ def _reduction_over_axis(
116
141
"Automatically determined reduction data type does not "
117
142
"have direct implementation"
118
143
)
119
- tmp_dt = _default_reduction_dtype (inp_dt , q )
144
+ tmp_dt = _default_reduction_type_fn (inp_dt , q )
120
145
tmp = dpt .empty (
121
146
res_shape , dtype = tmp_dt , usm_type = res_usm_type , sycl_queue = q
122
147
)
@@ -161,13 +186,13 @@ def sum(x, axis=None, dtype=None, keepdims=False):
161
186
the returned array will have the default real-valued
162
187
floating-point data type for the device where input
163
188
array `x` is allocated.
164
- * If x` has signed integral data type, the returned array
189
+ * If ` x` has signed integral data type, the returned array
165
190
will have the default signed integral type for the device
166
191
where input array `x` is allocated.
167
192
* If `x` has unsigned integral data type, the returned array
168
193
will have the default unsigned integral type for the device
169
194
where input array `x` is allocated.
170
- * If `x` has a complex-valued floating-point data typee ,
195
+ * If `x` has a complex-valued floating-point data type ,
171
196
the returned array will have the default complex-valued
172
197
floating-pointer data type for the device where input
173
198
array `x` is allocated.
@@ -222,13 +247,13 @@ def prod(x, axis=None, dtype=None, keepdims=False):
222
247
the returned array will have the default real-valued
223
248
floating-point data type for the device where input
224
249
array `x` is allocated.
225
- * If x` has signed integral data type, the returned array
250
+ * If ` x` has signed integral data type, the returned array
226
251
will have the default signed integral type for the device
227
252
where input array `x` is allocated.
228
253
* If `x` has unsigned integral data type, the returned array
229
254
will have the default unsigned integral type for the device
230
255
where input array `x` is allocated.
231
- * If `x` has a complex-valued floating-point data typee ,
256
+ * If `x` has a complex-valued floating-point data type ,
232
257
the returned array will have the default complex-valued
233
258
floating-pointer data type for the device where input
234
259
array `x` is allocated.
@@ -263,6 +288,118 @@ def prod(x, axis=None, dtype=None, keepdims=False):
263
288
)
264
289
265
290
291
+ def logsumexp (x , axis = None , dtype = None , keepdims = False ):
292
+ """logsumexp(x, axis=None, dtype=None, keepdims=False)
293
+
294
+ Calculates the logarithm of the sum of exponentials of elements in the
295
+ input array `x`.
296
+
297
+ Args:
298
+ x (usm_ndarray):
299
+ input array.
300
+ axis (Optional[int, Tuple[int, ...]]):
301
+ axis or axes along which values must be computed. If a tuple
302
+ of unique integers, values are computed over multiple axes.
303
+ If `None`, the result is computed over the entire array.
304
+ Default: `None`.
305
+ dtype (Optional[dtype]):
306
+ data type of the returned array. If `None`, the default data
307
+ type is inferred from the "kind" of the input array data type.
308
+ * If `x` has a real-valued floating-point data type,
309
+ the returned array will have the default real-valued
310
+ floating-point data type for the device where input
311
+ array `x` is allocated.
312
+ * If `x` has a boolean or integral data type, the returned array
313
+ will have the default floating point data type for the device
314
+ where input array `x` is allocated.
315
+ * If `x` has a complex-valued floating-point data type,
316
+ an error is raised.
317
+ If the data type (either specified or resolved) differs from the
318
+ data type of `x`, the input array elements are cast to the
319
+ specified data type before computing the result. Default: `None`.
320
+ keepdims (Optional[bool]):
321
+ if `True`, the reduced axes (dimensions) are included in the result
322
+ as singleton dimensions, so that the returned array remains
323
+ compatible with the input arrays according to Array Broadcasting
324
+ rules. Otherwise, if `False`, the reduced axes are not included in
325
+ the returned array. Default: `False`.
326
+ Returns:
327
+ usm_ndarray:
328
+ an array containing the results. If the result was computed over
329
+ the entire array, a zero-dimensional array is returned. The returned
330
+ array has the data type as described in the `dtype` parameter
331
+ description above.
332
+ """
333
+ return _reduction_over_axis (
334
+ x ,
335
+ axis ,
336
+ dtype ,
337
+ keepdims ,
338
+ ti ._logsumexp_over_axis ,
339
+ lambda inp_dt , res_dt , * _ : ti ._logsumexp_over_axis_dtype_supported (
340
+ inp_dt , res_dt
341
+ ),
342
+ _default_reduction_dtype_fp_types ,
343
+ _identity = - dpt .inf ,
344
+ )
345
+
346
+
347
+ def reduce_hypot (x , axis = None , dtype = None , keepdims = False ):
348
+ """reduce_hypot(x, axis=None, dtype=None, keepdims=False)
349
+
350
+ Calculates the square root of the sum of squares of elements in the input
351
+ array `x`.
352
+
353
+ Args:
354
+ x (usm_ndarray):
355
+ input array.
356
+ axis (Optional[int, Tuple[int, ...]]):
357
+ axis or axes along which values must be computed. If a tuple
358
+ of unique integers, values are computed over multiple axes.
359
+ If `None`, the result is computed over the entire array.
360
+ Default: `None`.
361
+ dtype (Optional[dtype]):
362
+ data type of the returned array. If `None`, the default data
363
+ type is inferred from the "kind" of the input array data type.
364
+ * If `x` has a real-valued floating-point data type,
365
+ the returned array will have the default real-valued
366
+ floating-point data type for the device where input
367
+ array `x` is allocated.
368
+ * If `x` has a boolean or integral data type, the returned array
369
+ will have the default floating point data type for the device
370
+ where input array `x` is allocated.
371
+ * If `x` has a complex-valued floating-point data type,
372
+ an error is raised.
373
+ If the data type (either specified or resolved) differs from the
374
+ data type of `x`, the input array elements are cast to the
375
+ specified data type before computing the result. Default: `None`.
376
+ keepdims (Optional[bool]):
377
+ if `True`, the reduced axes (dimensions) are included in the result
378
+ as singleton dimensions, so that the returned array remains
379
+ compatible with the input arrays according to Array Broadcasting
380
+ rules. Otherwise, if `False`, the reduced axes are not included in
381
+ the returned array. Default: `False`.
382
+ Returns:
383
+ usm_ndarray:
384
+ an array containing the results. If the result was computed over
385
+ the entire array, a zero-dimensional array is returned. The returned
386
+ array has the data type as described in the `dtype` parameter
387
+ description above.
388
+ """
389
+ return _reduction_over_axis (
390
+ x ,
391
+ axis ,
392
+ dtype ,
393
+ keepdims ,
394
+ ti ._hypot_over_axis ,
395
+ lambda inp_dt , res_dt , * _ : ti ._hypot_over_axis_dtype_supported (
396
+ inp_dt , res_dt
397
+ ),
398
+ _default_reduction_dtype_fp_types ,
399
+ _identity = 0 ,
400
+ )
401
+
402
+
266
403
def _comparison_over_axis (x , axis , keepdims , _reduction_fn ):
267
404
if not isinstance (x , dpt .usm_ndarray ):
268
405
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
0 commit comments