Skip to content

Commit b9afb64

Browse files
committed
sycl::vec overload for sine
1 parent a0959d0 commit b9afb64

File tree

1 file changed

+16
-1
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+16
-1
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ template <typename argT, typename resT> struct SinFunctor
5858
// constant value, if constant
5959
// constexpr resT constant_value = resT{};
6060
// is function defined for sycl::vec
61-
using supports_vec = typename std::false_type;
61+
using supports_vec = typename std::negation<
62+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6263
// do both argTy and resTy support subgroup store/load operation
6364
using supports_sg_loadstore = typename std::negation<
6465
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -181,6 +182,20 @@ template <typename argT, typename resT> struct SinFunctor
181182
return std::sin(in);
182183
}
183184
}
185+
186+
template <int vec_sz>
187+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
188+
{
189+
auto const &res_vec = sycl::sin(in);
190+
using deducedT = typename std::remove_cv_t<
191+
std::remove_reference_t<decltype(res_vec)>>::element_type;
192+
if constexpr (std::is_same_v<resT, deducedT>) {
193+
return res_vec;
194+
}
195+
else {
196+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
197+
}
198+
}
184199
};
185200

186201
template <typename argTy,

0 commit comments

Comments
 (0)