Skip to content

Commit 7a99d3c

Browse files
committed
Update binary functions multiply and subtract to use experimental SYCL complex type
1 parent ac6bea6 commit 7a99d3c

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,20 @@ template <typename argT, typename resT> struct MultiplyInplaceFunctor
419419
using supports_vec = std::negation<
420420
std::disjunction<tu_ns::is_complex<argT>, tu_ns::is_complex<resT>>>;
421421

422-
void operator()(resT &res, const argT &in) { res *= in; }
422+
void operator()(resT &res, const argT &in)
423+
{
424+
if constexpr (tu_ns::is_complex_v<resT> && tu_ns::is_complex_v<argT>) {
425+
using res_rT = typename resT::value_type;
426+
using arg_rT = typename argT::value_type;
427+
428+
auto res1 = exprm_ns::complex<res_rT>(res);
429+
res1 *= exprm_ns::complex<arg_rT>(in);
430+
res = res1;
431+
}
432+
else {
433+
res *= in;
434+
}
435+
}
423436

424437
template <int vec_sz>
425438
void operator()(sycl::vec<resT, vec_sz> &res,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

32+
#include "sycl_complex.hpp"
3233
#include "vec_size_util.hpp"
3334

3435
#include "utils/offset_utils.hpp"
@@ -62,7 +63,17 @@ template <typename argT1, typename argT2, typename resT> struct SubtractFunctor
6263

6364
resT operator()(const argT1 &in1, const argT2 &in2) const
6465
{
65-
return in1 - in2;
66+
if constexpr (tu_ns::is_complex_v<argT1> && tu_ns::is_complex_v<argT2>)
67+
{
68+
using realT1 = typename argT1::value_type;
69+
using realT2 = typename argT2::value_type;
70+
71+
return exprm_ns::complex<realT1>(in1) -
72+
exprm_ns::complex<realT2>(in2);
73+
}
74+
else {
75+
return in1 - in2;
76+
}
6677
}
6778

6879
template <int vec_sz>
@@ -424,7 +435,17 @@ template <typename argT, typename resT> struct SubtractInplaceFunctor
424435
void operator()(sycl::vec<resT, vec_sz> &res,
425436
const sycl::vec<argT, vec_sz> &in)
426437
{
427-
res -= in;
438+
if constexpr (tu_ns::is_complex_v<resT> && tu_ns::is_complex_v<argT>) {
439+
using res_rT = typename resT::value_type;
440+
using arg_rT = typename argT::value_type;
441+
442+
auto res1 = exprm_ns::complex<res_rT>(res);
443+
res1 -= exprm_ns::complex<arg_rT>(in);
444+
res = res1;
445+
}
446+
else {
447+
res -= in;
448+
}
428449
}
429450
};
430451

0 commit comments

Comments
 (0)