Skip to content

Commit 039f5bc

Browse files
committed
logsumexp and reduce_hypot no longer use atomics
This change was made to improve the accuracy of these functions
1 parent 4bf688f commit 039f5bc

File tree

4 files changed

+454
-334
lines changed

4 files changed

+454
-334
lines changed

dpctl/tensor/_reduction.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,98 @@ def prod(x, axis=None, dtype=None, keepdims=False):
288288
)
289289

290290

291+
def _tree_reduction_over_axis(
292+
x,
293+
axis,
294+
dtype,
295+
keepdims,
296+
_reduction_fn,
297+
_dtype_supported,
298+
_default_reduction_type_fn,
299+
_identity=None,
300+
):
301+
if not isinstance(x, dpt.usm_ndarray):
302+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
303+
nd = x.ndim
304+
if axis is None:
305+
axis = tuple(range(nd))
306+
if not isinstance(axis, (tuple, list)):
307+
axis = (axis,)
308+
axis = normalize_axis_tuple(axis, nd, "axis")
309+
red_nd = len(axis)
310+
perm = [i for i in range(nd) if i not in axis] + list(axis)
311+
arr2 = dpt.permute_dims(x, perm)
312+
res_shape = arr2.shape[: nd - red_nd]
313+
q = x.sycl_queue
314+
inp_dt = x.dtype
315+
if dtype is None:
316+
res_dt = _default_reduction_type_fn(inp_dt, q)
317+
else:
318+
res_dt = dpt.dtype(dtype)
319+
res_dt = _to_device_supported_dtype(res_dt, q.sycl_device)
320+
321+
res_usm_type = x.usm_type
322+
if x.size == 0:
323+
if _identity is None:
324+
raise ValueError("reduction does not support zero-size arrays")
325+
else:
326+
if keepdims:
327+
res_shape = res_shape + (1,) * red_nd
328+
inv_perm = sorted(range(nd), key=lambda d: perm[d])
329+
res_shape = tuple(res_shape[i] for i in inv_perm)
330+
return dpt.astype(
331+
dpt.full(
332+
res_shape,
333+
_identity,
334+
dtype=_default_reduction_type_fn(inp_dt, q),
335+
usm_type=res_usm_type,
336+
sycl_queue=q,
337+
),
338+
res_dt,
339+
)
340+
if red_nd == 0:
341+
return dpt.astype(x, res_dt, copy=False)
342+
343+
host_tasks_list = []
344+
if _dtype_supported(inp_dt, res_dt):
345+
res = dpt.empty(
346+
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
347+
)
348+
ht_e, _ = _reduction_fn(
349+
src=arr2, trailing_dims_to_reduce=red_nd, dst=res, sycl_queue=q
350+
)
351+
host_tasks_list.append(ht_e)
352+
else:
353+
if dtype is None:
354+
raise RuntimeError(
355+
"Automatically determined reduction data type does not "
356+
"have direct implementation"
357+
)
358+
tmp_dt = _default_reduction_type_fn(inp_dt, q)
359+
tmp = dpt.empty(
360+
res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q
361+
)
362+
ht_e_tmp, r_e = _reduction_fn(
363+
src=arr2, trailing_dims_to_reduce=red_nd, dst=tmp, sycl_queue=q
364+
)
365+
host_tasks_list.append(ht_e_tmp)
366+
res = dpt.empty(
367+
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
368+
)
369+
ht_e, _ = ti._copy_usm_ndarray_into_usm_ndarray(
370+
src=tmp, dst=res, sycl_queue=q, depends=[r_e]
371+
)
372+
host_tasks_list.append(ht_e)
373+
374+
if keepdims:
375+
res_shape = res_shape + (1,) * red_nd
376+
inv_perm = sorted(range(nd), key=lambda d: perm[d])
377+
res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm)
378+
dpctl.SyclEvent.wait_for(host_tasks_list)
379+
380+
return res
381+
382+
291383
def logsumexp(x, axis=None, dtype=None, keepdims=False):
292384
"""logsumexp(x, axis=None, dtype=None, keepdims=False)
293385
@@ -330,7 +422,7 @@ def logsumexp(x, axis=None, dtype=None, keepdims=False):
330422
array has the data type as described in the `dtype` parameter
331423
description above.
332424
"""
333-
return _reduction_over_axis(
425+
return _tree_reduction_over_axis(
334426
x,
335427
axis,
336428
dtype,
@@ -384,7 +476,7 @@ def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
384476
array has the data type as described in the `dtype` parameter
385477
description above.
386478
"""
387-
return _reduction_over_axis(
479+
return _tree_reduction_over_axis(
388480
x,
389481
axis,
390482
dtype,

dpctl/tensor/libtensor/include/kernels/reductions.hpp

Lines changed: 0 additions & 208 deletions
Original file line numberDiff line numberDiff line change
@@ -3069,53 +3069,6 @@ struct ProductOverAxis0TempsContigFactory
30693069
}
30703070
};
30713071

3072-
/* @brief Types supported by hypot-reduction code based on atomic_ref */
3073-
template <typename argTy, typename outTy>
3074-
struct TypePairSupportDataForHypotReductionAtomic
3075-
{
3076-
3077-
/* value if true a kernel for <argTy, outTy> must be instantiated, false
3078-
* otherwise */
3079-
static constexpr bool is_defined = std::disjunction< // disjunction is C++17
3080-
// feature, supported
3081-
// by DPC++ input bool
3082-
// input bool
3083-
td_ns::TypePairDefinedEntry<argTy, bool, outTy, float>,
3084-
td_ns::TypePairDefinedEntry<argTy, bool, outTy, double>,
3085-
// input int8
3086-
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, float>,
3087-
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, double>,
3088-
// input uint8
3089-
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, float>,
3090-
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, double>,
3091-
// input int16
3092-
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, float>,
3093-
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, double>,
3094-
// input uint16
3095-
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, float>,
3096-
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, double>,
3097-
// input int32
3098-
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, float>,
3099-
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, double>,
3100-
// input uint32
3101-
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, float>,
3102-
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, double>,
3103-
// input int64
3104-
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, double>,
3105-
// input uint64
3106-
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, double>,
3107-
// input half
3108-
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float>,
3109-
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
3110-
// input float
3111-
td_ns::TypePairDefinedEntry<argTy, float, outTy, float>,
3112-
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
3113-
// input double
3114-
td_ns::TypePairDefinedEntry<argTy, double, outTy, double>,
3115-
// fall-through
3116-
td_ns::NotDefinedEntry>::is_defined;
3117-
};
3118-
31193072
template <typename argTy, typename outTy>
31203073
struct TypePairSupportDataForHypotReductionTemps
31213074
{
@@ -3177,25 +3130,6 @@ struct TypePairSupportDataForHypotReductionTemps
31773130
td_ns::NotDefinedEntry>::is_defined;
31783131
};
31793132

3180-
template <typename fnT, typename srcTy, typename dstTy>
3181-
struct HypotOverAxisAtomicStridedFactory
3182-
{
3183-
fnT get() const
3184-
{
3185-
if constexpr (TypePairSupportDataForHypotReductionAtomic<
3186-
srcTy, dstTy>::is_defined)
3187-
{
3188-
using ReductionOpT = su_ns::Hypot<dstTy>;
3189-
return dpctl::tensor::kernels::
3190-
reduction_over_group_with_atomics_strided_impl<srcTy, dstTy,
3191-
ReductionOpT>;
3192-
}
3193-
else {
3194-
return nullptr;
3195-
}
3196-
}
3197-
};
3198-
31993133
template <typename fnT, typename srcTy, typename dstTy>
32003134
struct HypotOverAxisTempsStridedFactory
32013135
{
@@ -3215,44 +3149,6 @@ struct HypotOverAxisTempsStridedFactory
32153149
}
32163150
};
32173151

3218-
template <typename fnT, typename srcTy, typename dstTy>
3219-
struct HypotOverAxis1AtomicContigFactory
3220-
{
3221-
fnT get() const
3222-
{
3223-
if constexpr (TypePairSupportDataForHypotReductionAtomic<
3224-
srcTy, dstTy>::is_defined)
3225-
{
3226-
using ReductionOpT = su_ns::Hypot<dstTy>;
3227-
return dpctl::tensor::kernels::
3228-
reduction_axis1_over_group_with_atomics_contig_impl<
3229-
srcTy, dstTy, ReductionOpT>;
3230-
}
3231-
else {
3232-
return nullptr;
3233-
}
3234-
}
3235-
};
3236-
3237-
template <typename fnT, typename srcTy, typename dstTy>
3238-
struct HypotOverAxis0AtomicContigFactory
3239-
{
3240-
fnT get() const
3241-
{
3242-
if constexpr (TypePairSupportDataForHypotReductionAtomic<
3243-
srcTy, dstTy>::is_defined)
3244-
{
3245-
using ReductionOpT = su_ns::Hypot<dstTy>;
3246-
return dpctl::tensor::kernels::
3247-
reduction_axis0_over_group_with_atomics_contig_impl<
3248-
srcTy, dstTy, ReductionOpT>;
3249-
}
3250-
else {
3251-
return nullptr;
3252-
}
3253-
}
3254-
};
3255-
32563152
template <typename fnT, typename srcTy, typename dstTy>
32573153
struct HypotOverAxis1TempsContigFactory
32583154
{
@@ -3291,53 +3187,6 @@ struct HypotOverAxis0TempsContigFactory
32913187
}
32923188
};
32933189

3294-
/* @brief Types supported by logsumexp-reduction code based on atomic_ref */
3295-
template <typename argTy, typename outTy>
3296-
struct TypePairSupportDataForLogSumExpReductionAtomic
3297-
{
3298-
3299-
/* value if true a kernel for <argTy, outTy> must be instantiated, false
3300-
* otherwise */
3301-
static constexpr bool is_defined = std::disjunction< // disjunction is C++17
3302-
// feature, supported
3303-
// by DPC++ input bool
3304-
// input bool
3305-
td_ns::TypePairDefinedEntry<argTy, bool, outTy, float>,
3306-
td_ns::TypePairDefinedEntry<argTy, bool, outTy, double>,
3307-
// input int8
3308-
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, float>,
3309-
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, double>,
3310-
// input uint8
3311-
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, float>,
3312-
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, double>,
3313-
// input int16
3314-
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, float>,
3315-
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, double>,
3316-
// input uint16
3317-
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, float>,
3318-
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, double>,
3319-
// input int32
3320-
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, float>,
3321-
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, double>,
3322-
// input uint32
3323-
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, float>,
3324-
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, double>,
3325-
// input int64
3326-
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, double>,
3327-
// input uint64
3328-
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, double>,
3329-
// input half
3330-
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float>,
3331-
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
3332-
// input float
3333-
td_ns::TypePairDefinedEntry<argTy, float, outTy, float>,
3334-
td_ns::TypePairDefinedEntry<argTy, float, outTy, double>,
3335-
// input double
3336-
td_ns::TypePairDefinedEntry<argTy, double, outTy, double>,
3337-
// fall-through
3338-
td_ns::NotDefinedEntry>::is_defined;
3339-
};
3340-
33413190
template <typename argTy, typename outTy>
33423191
struct TypePairSupportDataForLogSumExpReductionTemps
33433192
{
@@ -3399,25 +3248,6 @@ struct TypePairSupportDataForLogSumExpReductionTemps
33993248
td_ns::NotDefinedEntry>::is_defined;
34003249
};
34013250

3402-
template <typename fnT, typename srcTy, typename dstTy>
3403-
struct LogSumExpOverAxisAtomicStridedFactory
3404-
{
3405-
fnT get() const
3406-
{
3407-
if constexpr (TypePairSupportDataForLogSumExpReductionAtomic<
3408-
srcTy, dstTy>::is_defined)
3409-
{
3410-
using ReductionOpT = su_ns::LogSumExp<dstTy>;
3411-
return dpctl::tensor::kernels::
3412-
reduction_over_group_with_atomics_strided_impl<srcTy, dstTy,
3413-
ReductionOpT>;
3414-
}
3415-
else {
3416-
return nullptr;
3417-
}
3418-
}
3419-
};
3420-
34213251
template <typename fnT, typename srcTy, typename dstTy>
34223252
struct LogSumExpOverAxisTempsStridedFactory
34233253
{
@@ -3437,44 +3267,6 @@ struct LogSumExpOverAxisTempsStridedFactory
34373267
}
34383268
};
34393269

3440-
template <typename fnT, typename srcTy, typename dstTy>
3441-
struct LogSumExpOverAxis1AtomicContigFactory
3442-
{
3443-
fnT get() const
3444-
{
3445-
if constexpr (TypePairSupportDataForLogSumExpReductionAtomic<
3446-
srcTy, dstTy>::is_defined)
3447-
{
3448-
using ReductionOpT = su_ns::LogSumExp<dstTy>;
3449-
return dpctl::tensor::kernels::
3450-
reduction_axis1_over_group_with_atomics_contig_impl<
3451-
srcTy, dstTy, ReductionOpT>;
3452-
}
3453-
else {
3454-
return nullptr;
3455-
}
3456-
}
3457-
};
3458-
3459-
template <typename fnT, typename srcTy, typename dstTy>
3460-
struct LogSumExpOverAxis0AtomicContigFactory
3461-
{
3462-
fnT get() const
3463-
{
3464-
if constexpr (TypePairSupportDataForLogSumExpReductionAtomic<
3465-
srcTy, dstTy>::is_defined)
3466-
{
3467-
using ReductionOpT = su_ns::LogSumExp<dstTy>;
3468-
return dpctl::tensor::kernels::
3469-
reduction_axis0_over_group_with_atomics_contig_impl<
3470-
srcTy, dstTy, ReductionOpT>;
3471-
}
3472-
else {
3473-
return nullptr;
3474-
}
3475-
}
3476-
};
3477-
34783270
template <typename fnT, typename srcTy, typename dstTy>
34793271
struct LogSumExpOverAxis1TempsContigFactory
34803272
{

0 commit comments

Comments
 (0)