Skip to content

Commit 649724a

Browse files
committed
Use experimental SYCL complex in dot product
1 parent 7a99d3c commit 649724a

File tree

1 file changed

+48
-17
lines changed

1 file changed

+48
-17
lines changed

dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
#include "utils/sycl_utils.hpp"
4141
#include "utils/type_utils.hpp"
4242

43+
#define SYCL_EXT_ONEAPI_COMPLEX
44+
#include <sycl/ext/oneapi/experimental/complex/complex.hpp>
45+
4346
namespace dpctl
4447
{
4548
namespace tensor
@@ -49,6 +52,8 @@ namespace kernels
4952

5053
using dpctl::tensor::ssize_t;
5154
namespace su_ns = dpctl::tensor::sycl_utils;
55+
namespace tu_ns = dpctl::tensor::type_utils;
56+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5257

5358
template <typename lhsT,
5459
typename rhsT,
@@ -92,7 +97,7 @@ struct SequentialDotProduct
9297
auto lhs_reduction_offset = reduction_offsets.get_first_offset();
9398
auto rhs_reduction_offset = reduction_offsets.get_second_offset();
9499

95-
using dpctl::tensor::type_utils::convert_impl;
100+
using tu_ns::convert_impl;
96101
red_val += convert_impl<outT, lhsT>(
97102
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
98103
convert_impl<outT, rhsT>(
@@ -175,7 +180,7 @@ struct DotProductFunctor
175180
const auto &rhs_reduction_offset =
176181
reduction_offsets_.get_second_offset();
177182

178-
using dpctl::tensor::type_utils::convert_impl;
183+
using tu_ns::convert_impl;
179184
outT val = convert_impl<outT, lhsT>(
180185
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
181186
convert_impl<outT, rhsT>(
@@ -273,7 +278,7 @@ struct DotProductCustomFunctor
273278
const auto &rhs_reduction_offset =
274279
reduction_offsets_.get_second_offset();
275280

276-
using dpctl::tensor::type_utils::convert_impl;
281+
using tu_ns::convert_impl;
277282
outT val = convert_impl<outT, lhsT>(
278283
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
279284
convert_impl<outT, rhsT>(
@@ -718,13 +723,26 @@ struct DotProductNoAtomicFunctor
718723
const auto &rhs_reduction_offset =
719724
reduction_offsets_.get_second_offset();
720725

721-
using dpctl::tensor::type_utils::convert_impl;
722-
outT val = convert_impl<outT, lhsT>(
723-
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
724-
convert_impl<outT, rhsT>(
725-
rhs_[rhs_batch_offset + rhs_reduction_offset]);
726-
727-
local_red_val += val;
726+
using tu_ns::convert_impl;
727+
using tu_ns::is_complex_v;
728+
if constexpr (is_complex_v<outT>) {
729+
using realT = typename outT::value_type;
730+
using sycl_complexT = exprm_ns::complex<realT>;
731+
732+
sycl_complexT val =
733+
sycl_complexT(convert_impl<outT, lhsT>(
734+
lhs_[lhs_batch_offset + lhs_reduction_offset])) *
735+
sycl_complexT(convert_impl<outT, rhsT>(
736+
rhs_[rhs_batch_offset + rhs_reduction_offset]));
737+
local_red_val = outT(sycl_complexT(local_red_val) + val);
738+
}
739+
else {
740+
outT val = convert_impl<outT, lhsT>(
741+
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
742+
convert_impl<outT, rhsT>(
743+
rhs_[rhs_batch_offset + rhs_reduction_offset]);
744+
local_red_val += val;
745+
}
728746
}
729747

730748
auto work_group = it.get_group();
@@ -819,13 +837,26 @@ struct DotProductNoAtomicCustomFunctor
819837
const auto &rhs_reduction_offset =
820838
reduction_offsets_.get_second_offset();
821839

822-
using dpctl::tensor::type_utils::convert_impl;
823-
outT val = convert_impl<outT, lhsT>(
824-
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
825-
convert_impl<outT, rhsT>(
826-
rhs_[rhs_batch_offset + rhs_reduction_offset]);
827-
828-
local_red_val += val;
840+
using tu_ns::convert_impl;
841+
using tu_ns::is_complex_v;
842+
if constexpr (is_complex_v<outT>) {
843+
using realT = typename outT::value_type;
844+
using sycl_complexT = exprm_ns::complex<realT>;
845+
846+
sycl_complexT val =
847+
sycl_complexT(convert_impl<outT, lhsT>(
848+
lhs_[lhs_batch_offset + lhs_reduction_offset])) *
849+
sycl_complexT(convert_impl<outT, rhsT>(
850+
rhs_[rhs_batch_offset + rhs_reduction_offset]));
851+
local_red_val = outT(sycl_complexT(local_red_val) + val);
852+
}
853+
else {
854+
outT val = convert_impl<outT, lhsT>(
855+
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
856+
convert_impl<outT, rhsT>(
857+
rhs_[rhs_batch_offset + rhs_reduction_offset]);
858+
local_red_val += val;
859+
}
829860
}
830861

831862
auto work_group = it.get_group();

0 commit comments

Comments
 (0)