@@ -97,11 +97,23 @@ struct SequentialDotProduct
97
97
auto lhs_reduction_offset = reduction_offsets.get_first_offset ();
98
98
auto rhs_reduction_offset = reduction_offsets.get_second_offset ();
99
99
100
- using tu_ns::convert_impl;
101
- red_val += convert_impl<outT, lhsT>(
102
- lhs_[lhs_batch_offset + lhs_reduction_offset]) *
103
- convert_impl<outT, rhsT>(
104
- rhs_[rhs_batch_offset + rhs_reduction_offset]);
100
+ if constexpr (tu_ns::is_complex_v<outT>) {
101
+ using realT = typename outT::value_type;
102
+ using sycl_complex = exprm_ns::complex<realT>;
103
+
104
+ auto tmp = sycl_complex (red_val);
105
+ tmp += sycl_complex (tu_ns::convert_impl<outT, lhsT>(
106
+ lhs_[lhs_batch_offset + lhs_reduction_offset])) *
107
+ sycl_complex (tu_ns::convert_impl<outT, rhsT>(
108
+ rhs_[rhs_batch_offset + rhs_reduction_offset]));
109
+ red_val = outT (tmp);
110
+ }
111
+ else {
112
+ red_val += tu_ns::convert_impl<outT, lhsT>(
113
+ lhs_[lhs_batch_offset + lhs_reduction_offset]) *
114
+ tu_ns::convert_impl<outT, rhsT>(
115
+ rhs_[rhs_batch_offset + rhs_reduction_offset]);
116
+ }
105
117
}
106
118
107
119
out_[out_batch_offset] = red_val;
@@ -180,10 +192,9 @@ struct DotProductFunctor
180
192
const auto &rhs_reduction_offset =
181
193
reduction_offsets_.get_second_offset ();
182
194
183
- using tu_ns::convert_impl;
184
- outT val = convert_impl<outT, lhsT>(
195
+ outT val = tu_ns::convert_impl<outT, lhsT>(
185
196
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
186
- convert_impl<outT, rhsT>(
197
+ tu_ns:: convert_impl<outT, rhsT>(
187
198
rhs_[rhs_batch_offset + rhs_reduction_offset]);
188
199
189
200
local_red_val += val;
@@ -278,10 +289,9 @@ struct DotProductCustomFunctor
278
289
const auto &rhs_reduction_offset =
279
290
reduction_offsets_.get_second_offset ();
280
291
281
- using tu_ns::convert_impl;
282
- outT val = convert_impl<outT, lhsT>(
292
+ outT val = tu_ns::convert_impl<outT, lhsT>(
283
293
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
284
- convert_impl<outT, rhsT>(
294
+ tu_ns:: convert_impl<outT, rhsT>(
285
295
rhs_[rhs_batch_offset + rhs_reduction_offset]);
286
296
287
297
local_red_val += val;
@@ -723,23 +733,21 @@ struct DotProductNoAtomicFunctor
723
733
const auto &rhs_reduction_offset =
724
734
reduction_offsets_.get_second_offset ();
725
735
726
- using tu_ns::convert_impl;
727
- using tu_ns::is_complex_v;
728
- if constexpr (is_complex_v<outT>) {
736
+ if constexpr (tu_ns::is_complex_v<outT>) {
729
737
using realT = typename outT::value_type;
730
738
using sycl_complexT = exprm_ns::complex<realT>;
731
739
732
740
sycl_complexT val =
733
- sycl_complexT (convert_impl<outT, lhsT>(
741
+ sycl_complexT (tu_ns:: convert_impl<outT, lhsT>(
734
742
lhs_[lhs_batch_offset + lhs_reduction_offset])) *
735
- sycl_complexT (convert_impl<outT, rhsT>(
743
+ sycl_complexT (tu_ns:: convert_impl<outT, rhsT>(
736
744
rhs_[rhs_batch_offset + rhs_reduction_offset]));
737
745
local_red_val = outT (sycl_complexT (local_red_val) + val);
738
746
}
739
747
else {
740
- outT val = convert_impl<outT, lhsT>(
748
+ outT val = tu_ns:: convert_impl<outT, lhsT>(
741
749
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
742
- convert_impl<outT, rhsT>(
750
+ tu_ns:: convert_impl<outT, rhsT>(
743
751
rhs_[rhs_batch_offset + rhs_reduction_offset]);
744
752
local_red_val += val;
745
753
}
@@ -837,23 +845,21 @@ struct DotProductNoAtomicCustomFunctor
837
845
const auto &rhs_reduction_offset =
838
846
reduction_offsets_.get_second_offset ();
839
847
840
- using tu_ns::convert_impl;
841
- using tu_ns::is_complex_v;
842
- if constexpr (is_complex_v<outT>) {
848
+ if constexpr (tu_ns::is_complex_v<outT>) {
843
849
using realT = typename outT::value_type;
844
850
using sycl_complexT = exprm_ns::complex<realT>;
845
851
846
852
sycl_complexT val =
847
- sycl_complexT (convert_impl<outT, lhsT>(
853
+ sycl_complexT (tu_ns:: convert_impl<outT, lhsT>(
848
854
lhs_[lhs_batch_offset + lhs_reduction_offset])) *
849
- sycl_complexT (convert_impl<outT, rhsT>(
855
+ sycl_complexT (tu_ns:: convert_impl<outT, rhsT>(
850
856
rhs_[rhs_batch_offset + rhs_reduction_offset]));
851
857
local_red_val = outT (sycl_complexT (local_red_val) + val);
852
858
}
853
859
else {
854
- outT val = convert_impl<outT, lhsT>(
860
+ outT val = tu_ns:: convert_impl<outT, lhsT>(
855
861
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
856
- convert_impl<outT, rhsT>(
862
+ tu_ns:: convert_impl<outT, rhsT>(
857
863
rhs_[rhs_batch_offset + rhs_reduction_offset]);
858
864
local_red_val += val;
859
865
}
0 commit comments