Skip to content

Commit d825887

Browse files
Use sycl_complex in add, conj
1 parent c1e7ab1 commit d825887

File tree

2 files changed

+32
-2
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

2 files changed

+32
-2
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <CL/sycl.hpp>
2828
#include <cstddef>
2929
#include <cstdint>
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3031
#include <type_traits>
3132

3233
#include "utils/offset_utils.hpp"
@@ -49,6 +50,7 @@ namespace add
4950
namespace py = pybind11;
5051
namespace td_ns = dpctl::tensor::type_dispatch;
5152
namespace tu_ns = dpctl::tensor::type_utils;
53+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5254

5355
template <typename argT1, typename argT2, typename resT> struct AddFunctor
5456
{
@@ -60,7 +62,31 @@ template <typename argT1, typename argT2, typename resT> struct AddFunctor
6062

6163
resT operator()(const argT1 &in1, const argT2 &in2) const
6264
{
63-
return in1 + in2;
65+
if constexpr (tu_ns::is_complex<argT1>::value &&
66+
tu_ns::is_complex<argT2>::value)
67+
{
68+
using rT1 = typename argT1::value_type;
69+
using rT2 = typename argT2::value_type;
70+
71+
return exprm_ns::complex<rT1>(in1) + exprm_ns::complex<rT2>(in2);
72+
}
73+
else if constexpr (tu_ns::is_complex<argT1>::value &&
74+
!tu_ns::is_complex<argT2>::value)
75+
{
76+
using rT1 = typename argT1::value_type;
77+
78+
return exprm_ns::complex<rT1>(in1) + in2;
79+
}
80+
else if constexpr (!tu_ns::is_complex<argT1>::value &&
81+
tu_ns::is_complex<argT2>::value)
82+
{
83+
using rT2 = typename argT2::value_type;
84+
85+
return in1 + exprm_ns::complex<rT2>(in2);
86+
}
87+
else {
88+
return in1 + in2;
89+
}
6490
}
6591

6692
template <int vec_sz>

dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <complex>
3030
#include <cstddef>
3131
#include <cstdint>
32+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3233
#include <type_traits>
3334

3435
#include "kernels/elementwise_functions/common.hpp"
@@ -49,6 +50,7 @@ namespace conj
4950

5051
namespace py = pybind11;
5152
namespace td_ns = dpctl::tensor::type_dispatch;
53+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5254

5355
using dpctl::tensor::type_utils::is_complex;
5456

@@ -68,7 +70,9 @@ template <typename argT, typename resT> struct ConjFunctor
6870
resT operator()(const argT &in) const
6971
{
7072
if constexpr (is_complex<argT>::value) {
71-
return std::conj(in);
73+
using rT = typename argT::value_type;
74+
75+
return exprm_ns::conj(exprm_ns::complex<rT>(in)); // std::conj(in);
7276
}
7377
else {
7478
if constexpr (!std::is_same_v<argT, bool>)

0 commit comments

Comments
 (0)