Skip to content

Commit 46c5cc9

Browse files
committed
Use experimental namespace in sequential dot product
1 parent 649724a commit 46c5cc9

File tree

1 file changed

31 additions and 25 deletions

dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp

Lines changed: 31 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -97,11 +97,23 @@ struct SequentialDotProduct
9797
auto lhs_reduction_offset = reduction_offsets.get_first_offset();
9898
auto rhs_reduction_offset = reduction_offsets.get_second_offset();
9999

100-
using tu_ns::convert_impl;
101-
red_val += convert_impl<outT, lhsT>(
102-
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
103-
convert_impl<outT, rhsT>(
104-
rhs_[rhs_batch_offset + rhs_reduction_offset]);
100+
if constexpr (tu_ns::is_complex_v<outT>) {
101+
using realT = typename outT::value_type;
102+
using sycl_complex = exprm_ns::complex<realT>;
103+
104+
auto tmp = sycl_complex(red_val);
105+
tmp += sycl_complex(tu_ns::convert_impl<outT, lhsT>(
106+
lhs_[lhs_batch_offset + lhs_reduction_offset])) *
107+
sycl_complex(tu_ns::convert_impl<outT, rhsT>(
108+
rhs_[rhs_batch_offset + rhs_reduction_offset]));
109+
red_val = outT(tmp);
110+
}
111+
else {
112+
red_val += tu_ns::convert_impl<outT, lhsT>(
113+
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
114+
tu_ns::convert_impl<outT, rhsT>(
115+
rhs_[rhs_batch_offset + rhs_reduction_offset]);
116+
}
105117
}
106118

107119
out_[out_batch_offset] = red_val;
@@ -180,10 +192,9 @@ struct DotProductFunctor
180192
const auto &rhs_reduction_offset =
181193
reduction_offsets_.get_second_offset();
182194

183-
using tu_ns::convert_impl;
184-
outT val = convert_impl<outT, lhsT>(
195+
outT val = tu_ns::convert_impl<outT, lhsT>(
185196
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
186-
convert_impl<outT, rhsT>(
197+
tu_ns::convert_impl<outT, rhsT>(
187198
rhs_[rhs_batch_offset + rhs_reduction_offset]);
188199

189200
local_red_val += val;
@@ -278,10 +289,9 @@ struct DotProductCustomFunctor
278289
const auto &rhs_reduction_offset =
279290
reduction_offsets_.get_second_offset();
280291

281-
using tu_ns::convert_impl;
282-
outT val = convert_impl<outT, lhsT>(
292+
outT val = tu_ns::convert_impl<outT, lhsT>(
283293
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
284-
convert_impl<outT, rhsT>(
294+
tu_ns::convert_impl<outT, rhsT>(
285295
rhs_[rhs_batch_offset + rhs_reduction_offset]);
286296

287297
local_red_val += val;
@@ -723,23 +733,21 @@ struct DotProductNoAtomicFunctor
723733
const auto &rhs_reduction_offset =
724734
reduction_offsets_.get_second_offset();
725735

726-
using tu_ns::convert_impl;
727-
using tu_ns::is_complex_v;
728-
if constexpr (is_complex_v<outT>) {
736+
if constexpr (tu_ns::is_complex_v<outT>) {
729737
using realT = typename outT::value_type;
730738
using sycl_complexT = exprm_ns::complex<realT>;
731739

732740
sycl_complexT val =
733-
sycl_complexT(convert_impl<outT, lhsT>(
741+
sycl_complexT(tu_ns::convert_impl<outT, lhsT>(
734742
lhs_[lhs_batch_offset + lhs_reduction_offset])) *
735-
sycl_complexT(convert_impl<outT, rhsT>(
743+
sycl_complexT(tu_ns::convert_impl<outT, rhsT>(
736744
rhs_[rhs_batch_offset + rhs_reduction_offset]));
737745
local_red_val = outT(sycl_complexT(local_red_val) + val);
738746
}
739747
else {
740-
outT val = convert_impl<outT, lhsT>(
748+
outT val = tu_ns::convert_impl<outT, lhsT>(
741749
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
742-
convert_impl<outT, rhsT>(
750+
tu_ns::convert_impl<outT, rhsT>(
743751
rhs_[rhs_batch_offset + rhs_reduction_offset]);
744752
local_red_val += val;
745753
}
@@ -837,23 +845,21 @@ struct DotProductNoAtomicCustomFunctor
837845
const auto &rhs_reduction_offset =
838846
reduction_offsets_.get_second_offset();
839847

840-
using tu_ns::convert_impl;
841-
using tu_ns::is_complex_v;
842-
if constexpr (is_complex_v<outT>) {
848+
if constexpr (tu_ns::is_complex_v<outT>) {
843849
using realT = typename outT::value_type;
844850
using sycl_complexT = exprm_ns::complex<realT>;
845851

846852
sycl_complexT val =
847-
sycl_complexT(convert_impl<outT, lhsT>(
853+
sycl_complexT(tu_ns::convert_impl<outT, lhsT>(
848854
lhs_[lhs_batch_offset + lhs_reduction_offset])) *
849-
sycl_complexT(convert_impl<outT, rhsT>(
855+
sycl_complexT(tu_ns::convert_impl<outT, rhsT>(
850856
rhs_[rhs_batch_offset + rhs_reduction_offset]));
851857
local_red_val = outT(sycl_complexT(local_red_val) + val);
852858
}
853859
else {
854-
outT val = convert_impl<outT, lhsT>(
860+
outT val = tu_ns::convert_impl<outT, lhsT>(
855861
lhs_[lhs_batch_offset + lhs_reduction_offset]) *
856-
convert_impl<outT, rhsT>(
862+
tu_ns::convert_impl<outT, rhsT>(
857863
rhs_[rhs_batch_offset + rhs_reduction_offset]);
858864
local_red_val += val;
859865
}

0 commit comments

Comments
 (0)