40
40
#include " utils/sycl_utils.hpp"
41
41
#include " utils/type_utils.hpp"
42
42
43
+ #define SYCL_EXT_ONEAPI_COMPLEX
44
+ #include < sycl/ext/oneapi/experimental/complex/complex.hpp>
45
+
43
46
namespace dpctl
44
47
{
45
48
namespace tensor
@@ -49,6 +52,8 @@ namespace kernels
49
52
50
53
using dpctl::tensor::ssize_t ;
51
54
namespace su_ns = dpctl::tensor::sycl_utils;
55
+ namespace tu_ns = dpctl::tensor::type_utils;
56
+ namespace exprm_ns = sycl::ext::oneapi::experimental;
52
57
53
58
template <typename lhsT,
54
59
typename rhsT,
@@ -92,7 +97,7 @@ struct SequentialDotProduct
92
97
auto lhs_reduction_offset = reduction_offsets.get_first_offset ();
93
98
auto rhs_reduction_offset = reduction_offsets.get_second_offset ();
94
99
95
- using dpctl::tensor::type_utils ::convert_impl;
100
+ using tu_ns ::convert_impl;
96
101
red_val += convert_impl<outT, lhsT>(
97
102
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
98
103
convert_impl<outT, rhsT>(
@@ -175,7 +180,7 @@ struct DotProductFunctor
175
180
const auto &rhs_reduction_offset =
176
181
reduction_offsets_.get_second_offset ();
177
182
178
- using dpctl::tensor::type_utils ::convert_impl;
183
+ using tu_ns ::convert_impl;
179
184
outT val = convert_impl<outT, lhsT>(
180
185
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
181
186
convert_impl<outT, rhsT>(
@@ -273,7 +278,7 @@ struct DotProductCustomFunctor
273
278
const auto &rhs_reduction_offset =
274
279
reduction_offsets_.get_second_offset ();
275
280
276
- using dpctl::tensor::type_utils ::convert_impl;
281
+ using tu_ns ::convert_impl;
277
282
outT val = convert_impl<outT, lhsT>(
278
283
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
279
284
convert_impl<outT, rhsT>(
@@ -718,13 +723,26 @@ struct DotProductNoAtomicFunctor
718
723
const auto &rhs_reduction_offset =
719
724
reduction_offsets_.get_second_offset ();
720
725
721
- using dpctl::tensor::type_utils::convert_impl;
722
- outT val = convert_impl<outT, lhsT>(
723
- lhs_[lhs_batch_offset + lhs_reduction_offset]) *
724
- convert_impl<outT, rhsT>(
725
- rhs_[rhs_batch_offset + rhs_reduction_offset]);
726
-
727
- local_red_val += val;
726
+ using tu_ns::convert_impl;
727
+ using tu_ns::is_complex_v;
728
+ if constexpr (is_complex_v<outT>) {
729
+ using realT = typename outT::value_type;
730
+ using sycl_complexT = exprm_ns::complex<realT>;
731
+
732
+ sycl_complexT val =
733
+ sycl_complexT (convert_impl<outT, lhsT>(
734
+ lhs_[lhs_batch_offset + lhs_reduction_offset])) *
735
+ sycl_complexT (convert_impl<outT, rhsT>(
736
+ rhs_[rhs_batch_offset + rhs_reduction_offset]));
737
+ local_red_val = outT (sycl_complexT (local_red_val) + val);
738
+ }
739
+ else {
740
+ outT val = convert_impl<outT, lhsT>(
741
+ lhs_[lhs_batch_offset + lhs_reduction_offset]) *
742
+ convert_impl<outT, rhsT>(
743
+ rhs_[rhs_batch_offset + rhs_reduction_offset]);
744
+ local_red_val += val;
745
+ }
728
746
}
729
747
730
748
auto work_group = it.get_group ();
@@ -819,13 +837,26 @@ struct DotProductNoAtomicCustomFunctor
819
837
const auto &rhs_reduction_offset =
820
838
reduction_offsets_.get_second_offset ();
821
839
822
- using dpctl::tensor::type_utils::convert_impl;
823
- outT val = convert_impl<outT, lhsT>(
824
- lhs_[lhs_batch_offset + lhs_reduction_offset]) *
825
- convert_impl<outT, rhsT>(
826
- rhs_[rhs_batch_offset + rhs_reduction_offset]);
827
-
828
- local_red_val += val;
840
+ using tu_ns::convert_impl;
841
+ using tu_ns::is_complex_v;
842
+ if constexpr (is_complex_v<outT>) {
843
+ using realT = typename outT::value_type;
844
+ using sycl_complexT = exprm_ns::complex<realT>;
845
+
846
+ sycl_complexT val =
847
+ sycl_complexT (convert_impl<outT, lhsT>(
848
+ lhs_[lhs_batch_offset + lhs_reduction_offset])) *
849
+ sycl_complexT (convert_impl<outT, rhsT>(
850
+ rhs_[rhs_batch_offset + rhs_reduction_offset]));
851
+ local_red_val = outT (sycl_complexT (local_red_val) + val);
852
+ }
853
+ else {
854
+ outT val = convert_impl<outT, lhsT>(
855
+ lhs_[lhs_batch_offset + lhs_reduction_offset]) *
856
+ convert_impl<outT, rhsT>(
857
+ rhs_[rhs_batch_offset + rhs_reduction_offset]);
858
+ local_red_val += val;
859
+ }
829
860
}
830
861
831
862
auto work_group = it.get_group ();
0 commit comments