@@ -85,11 +85,21 @@ bool check_atomic_support(const sycl::queue &exec_q,
85
85
}
86
86
}
87
87
88
- template <typename fnT, typename T> struct MaxAtomicSupportFactory
88
+ template <typename fnT, typename T> struct ArithmeticAtomicSupportFactory
89
89
{
90
90
fnT get ()
91
91
{
92
- if constexpr (std::is_floating_point_v<T>) {
92
+ using dpctl::tensor::type_utils::is_complex;
93
+ if constexpr (std::is_floating_point_v<T> ||
94
+ std::is_same_v<T, sycl::half> || is_complex<T>::value)
95
+ {
96
+ // for real- and complex- floating point types, tree reduction has
97
+ // better round-off accumulation properties (round-off error is
98
+ // proportional to the log2(reduction_size), while naive elementwise
99
+ // summation used by atomic implementation has round-off error
100
+ // growing proportional to the reduction_size.), hence reduction
101
+ // over floating point types should always use tree_reduction
102
+ // algorithm, even though atomic implementation may be applicable
93
103
return fixed_decision<false >;
94
104
}
95
105
else {
@@ -98,43 +108,33 @@ template <typename fnT, typename T> struct MaxAtomicSupportFactory
98
108
}
99
109
};
100
110
101
- template <typename fnT, typename T> struct MinAtomicSupportFactory
111
+ template <typename fnT, typename T> struct MinMaxAtomicSupportFactory
102
112
{
103
113
fnT get ()
104
114
{
105
- if constexpr (std::is_floating_point_v<T>) {
106
- return fixed_decision<false >;
107
- }
108
- else {
109
- return check_atomic_support<T>;
110
- }
115
+ return check_atomic_support<T>;
111
116
}
112
117
};
113
118
114
- template <typename fnT, typename T> struct SumAtomicSupportFactory
119
+ template <typename fnT, typename T>
120
+ struct MaxAtomicSupportFactory : public ArithmeticAtomicSupportFactory <fnT, T>
115
121
{
116
- fnT get ()
117
- {
118
- if constexpr (std::is_floating_point_v<T>) {
119
- return fixed_decision<false >;
120
- }
121
- else {
122
- return check_atomic_support<T>;
123
- }
124
- }
125
122
};
126
123
127
- template <typename fnT, typename T> struct ProductAtomicSupportFactory
124
+ template <typename fnT, typename T>
125
+ struct MinAtomicSupportFactory : public ArithmeticAtomicSupportFactory <fnT, T>
126
+ {
127
+ };
128
+
129
+ template <typename fnT, typename T>
130
+ struct SumAtomicSupportFactory : public ArithmeticAtomicSupportFactory <fnT, T>
131
+ {
132
+ };
133
+
134
+ template <typename fnT, typename T>
135
+ struct ProductAtomicSupportFactory
136
+ : public ArithmeticAtomicSupportFactory<fnT, T>
128
137
{
129
- fnT get ()
130
- {
131
- if constexpr (std::is_floating_point_v<T>) {
132
- return fixed_decision<false >;
133
- }
134
- else {
135
- return check_atomic_support<T>;
136
- }
137
- }
138
138
};
139
139
140
140
} // namespace atomic_support
0 commit comments