Skip to content

Commit 448a7f1

Browse files
oleksandr-pavlykndgrigorian
authored andcommitted
Always use atomic implementation for min/max if available
For add/multiplies reductions, use tree reduction for FP types, real and complex, to get better round-off accumulation properties.
1 parent 4dd054f commit 448a7f1

File tree

1 file changed

+29
-29
lines changed

1 file changed

+29
-29
lines changed

dpctl/tensor/libtensor/source/reductions/reduction_atomic_support.hpp

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,21 @@ bool check_atomic_support(const sycl::queue &exec_q,
8585
}
8686
}
8787

88-
template <typename fnT, typename T> struct MaxAtomicSupportFactory
88+
template <typename fnT, typename T> struct ArithmeticAtomicSupportFactory
8989
{
9090
fnT get()
9191
{
92-
if constexpr (std::is_floating_point_v<T>) {
92+
using dpctl::tensor::type_utils::is_complex;
93+
if constexpr (std::is_floating_point_v<T> ||
94+
std::is_same_v<T, sycl::half> || is_complex<T>::value)
95+
{
96+
// for real- and complex- floating point types, tree reduction has
97+
// better round-off accumulation properties (round-off error is
98+
// proportional to the log2(reduction_size), while naive elementwise
99+
// summation used by atomic implementation has round-off error
100+
// growing proportional to the reduction_size.), hence reduction
101+
// over floating point types should always use tree_reduction
102+
// algorithm, even though atomic implementation may be applicable
93103
return fixed_decision<false>;
94104
}
95105
else {
@@ -98,43 +108,33 @@ template <typename fnT, typename T> struct MaxAtomicSupportFactory
98108
}
99109
};
100110

101-
template <typename fnT, typename T> struct MinAtomicSupportFactory
111+
template <typename fnT, typename T> struct MinMaxAtomicSupportFactory
102112
{
103113
fnT get()
104114
{
105-
if constexpr (std::is_floating_point_v<T>) {
106-
return fixed_decision<false>;
107-
}
108-
else {
109-
return check_atomic_support<T>;
110-
}
115+
return check_atomic_support<T>;
111116
}
112117
};
113118

114-
template <typename fnT, typename T> struct SumAtomicSupportFactory
119+
template <typename fnT, typename T>
120+
struct MaxAtomicSupportFactory : public ArithmeticAtomicSupportFactory<fnT, T>
115121
{
116-
fnT get()
117-
{
118-
if constexpr (std::is_floating_point_v<T>) {
119-
return fixed_decision<false>;
120-
}
121-
else {
122-
return check_atomic_support<T>;
123-
}
124-
}
125122
};
126123

127-
template <typename fnT, typename T> struct ProductAtomicSupportFactory
124+
template <typename fnT, typename T>
125+
struct MinAtomicSupportFactory : public ArithmeticAtomicSupportFactory<fnT, T>
126+
{
127+
};
128+
129+
template <typename fnT, typename T>
130+
struct SumAtomicSupportFactory : public ArithmeticAtomicSupportFactory<fnT, T>
131+
{
132+
};
133+
134+
template <typename fnT, typename T>
135+
struct ProductAtomicSupportFactory
136+
: public ArithmeticAtomicSupportFactory<fnT, T>
128137
{
129-
fnT get()
130-
{
131-
if constexpr (std::is_floating_point_v<T>) {
132-
return fixed_decision<false>;
133-
}
134-
else {
135-
return check_atomic_support<T>;
136-
}
137-
}
138138
};
139139

140140
} // namespace atomic_support

0 commit comments

Comments
 (0)