Skip to content

Commit a0959d0

Browse files
committed
Adds sycl::vec overloads to abs, cos, expm1, log, log1p, and sqrt
1 parent 5ec9fd5 commit a0959d0

File tree

6 files changed

+122
-6
lines changed

6 files changed

+122
-6
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,15 @@ namespace py = pybind11;
5252
namespace td_ns = dpctl::tensor::type_dispatch;
5353

5454
using dpctl::tensor::type_utils::is_complex;
55+
using dpctl::tensor::type_utils::vec_cast;
5556

5657
template <typename argT, typename resT> struct AbsFunctor
5758
{
5859

5960
using is_constant = typename std::false_type;
6061
// constexpr resT constant_value = resT{};
61-
using supports_vec = typename std::false_type;
62+
using supports_vec = typename std::negation<
63+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6264
using supports_sg_loadstore = typename std::negation<
6365
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6466

@@ -127,6 +129,40 @@ template <typename argT, typename resT> struct AbsFunctor
127129
#endif
128130
}
129131
}
132+
133+
template <int vec_sz>
134+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
135+
{
136+
if constexpr (std::is_integral<argT>::value) {
137+
if constexpr (std::is_same_v<argT, bool> ||
138+
std::is_unsigned<argT>::value) {
139+
return in;
140+
}
141+
else {
142+
auto const &res_vec = sycl::abs(in);
143+
using deducedT = typename std::remove_cv_t<
144+
std::remove_reference_t<decltype(res_vec)>>::element_type;
145+
if constexpr (std::is_same_v<resT, deducedT>) {
146+
return res_vec;
147+
}
148+
else {
149+
150+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
151+
}
152+
}
153+
}
154+
else {
155+
auto const &res_vec = sycl::fabs(in);
156+
using deducedT = typename std::remove_cv_t<
157+
std::remove_reference_t<decltype(res_vec)>>::element_type;
158+
if constexpr (std::is_same_v<resT, deducedT>) {
159+
return res_vec;
160+
}
161+
else {
162+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
163+
}
164+
}
165+
}
130166
};
131167

132168
template <typename argT,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
5151

5252
using dpctl::tensor::type_utils::is_complex;
53+
using dpctl::tensor::type_utils::vec_cast;
5354

5455
template <typename argT, typename resT> struct CosFunctor
5556
{
@@ -59,7 +60,8 @@ template <typename argT, typename resT> struct CosFunctor
5960
// constant value, if constant
6061
// constexpr resT constant_value = resT{};
6162
// is function defined for sycl::vec
62-
using supports_vec = typename std::false_type;
63+
using supports_vec = typename std::negation<
64+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6365
// do both argTy and resTy support sugroup store/load operation
6466
using supports_sg_loadstore = typename std::negation<
6567
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -165,6 +167,20 @@ template <typename argT, typename resT> struct CosFunctor
165167
return std::cos(in);
166168
}
167169
}
170+
171+
template <int vec_sz>
172+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
173+
{
174+
auto const &res_vec = sycl::cos(in);
175+
using deducedT = typename std::remove_cv_t<
176+
std::remove_reference_t<decltype(res_vec)>>::element_type;
177+
if constexpr (std::is_same_v<resT, deducedT>) {
178+
return res_vec;
179+
}
180+
else {
181+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
182+
}
183+
}
168184
};
169185

170186
template <typename argTy,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ namespace py = pybind11;
5151
namespace td_ns = dpctl::tensor::type_dispatch;
5252

5353
using dpctl::tensor::type_utils::is_complex;
54+
using dpctl::tensor::type_utils::vec_cast;
5455

5556
template <typename argT, typename resT> struct Expm1Functor
5657
{
@@ -60,7 +61,8 @@ template <typename argT, typename resT> struct Expm1Functor
6061
// constant value, if constant
6162
// constexpr resT constant_value = resT{};
6263
// is function defined for sycl::vec
63-
using supports_vec = typename std::false_type;
64+
using supports_vec = typename std::negation<
65+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6466
// do both argTy and resTy support sugroup store/load operation
6567
using supports_sg_loadstore = typename std::negation<
6668
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -132,6 +134,20 @@ template <typename argT, typename resT> struct Expm1Functor
132134
return std::expm1(in);
133135
}
134136
}
137+
138+
template <int vec_sz>
139+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
140+
{
141+
auto const &res_vec = sycl::expm1(in);
142+
using deducedT = typename std::remove_cv_t<
143+
std::remove_reference_t<decltype(res_vec)>>::element_type;
144+
if constexpr (std::is_same_v<resT, deducedT>) {
145+
return res_vec;
146+
}
147+
else {
148+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
149+
}
150+
}
135151
};
136152

137153
template <typename argTy,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ namespace py = pybind11;
5151
namespace td_ns = dpctl::tensor::type_dispatch;
5252

5353
using dpctl::tensor::type_utils::is_complex;
54+
using dpctl::tensor::type_utils::vec_cast;
5455

5556
template <typename argT, typename resT> struct LogFunctor
5657
{
@@ -60,7 +61,8 @@ template <typename argT, typename resT> struct LogFunctor
6061
// constant value, if constant
6162
// constexpr resT constant_value = resT{};
6263
// is function defined for sycl::vec
63-
using supports_vec = typename std::false_type;
64+
using supports_vec = typename std::negation<
65+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6466
// do both argTy and resTy support sugroup store/load operation
6567
using supports_sg_loadstore = typename std::negation<
6668
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -79,6 +81,20 @@ template <typename argT, typename resT> struct LogFunctor
7981
return std::log(in);
8082
}
8183
}
84+
85+
template <int vec_sz>
86+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
87+
{
88+
auto const &res_vec = sycl::log(in);
89+
using deducedT = typename std::remove_cv_t<
90+
std::remove_reference_t<decltype(res_vec)>>::element_type;
91+
if constexpr (std::is_same_v<resT, deducedT>) {
92+
return res_vec;
93+
}
94+
else {
95+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
96+
}
97+
}
8298
};
8399

84100
template <typename argTy,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
5151

5252
using dpctl::tensor::type_utils::is_complex;
53+
using dpctl::tensor::type_utils::vec_cast;
5354

5455
// TODO: evaluate precision against alternatives
5556
template <typename argT, typename resT> struct Log1pFunctor
@@ -60,7 +61,8 @@ template <typename argT, typename resT> struct Log1pFunctor
6061
// constant value, if constant
6162
// constexpr resT constant_value = resT{};
6263
// is function defined for sycl::vec
63-
using supports_vec = typename std::false_type;
64+
using supports_vec = typename std::negation<
65+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6466
// do both argTy and resTy support sugroup store/load operation
6567
using supports_sg_loadstore = typename std::negation<
6668
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -99,6 +101,20 @@ template <typename argT, typename resT> struct Log1pFunctor
99101
return std::log1p(in);
100102
}
101103
}
104+
105+
template <int vec_sz>
106+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
107+
{
108+
auto const &res_vec = sycl::log1p(in);
109+
using deducedT = typename std::remove_cv_t<
110+
std::remove_reference_t<decltype(res_vec)>>::element_type;
111+
if constexpr (std::is_same_v<resT, deducedT>) {
112+
return res_vec;
113+
}
114+
else {
115+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
116+
}
117+
}
102118
};
103119

104120
template <typename argTy,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ namespace py = pybind11;
5353
namespace td_ns = dpctl::tensor::type_dispatch;
5454

5555
using dpctl::tensor::type_utils::is_complex;
56+
using dpctl::tensor::type_utils::vec_cast;
5657

5758
template <typename argT, typename resT> struct SqrtFunctor
5859
{
@@ -62,7 +63,8 @@ template <typename argT, typename resT> struct SqrtFunctor
6263
// constant value, if constant
6364
// constexpr resT constant_value = resT{};
6465
// is function defined for sycl::vec
65-
using supports_vec = typename std::false_type;
66+
using supports_vec = typename std::negation<
67+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6668
// do both argTy and resTy support sugroup store/load operation
6769
using supports_sg_loadstore = typename std::negation<
6870
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -263,6 +265,20 @@ template <typename argT, typename resT> struct SqrtFunctor
263265
? csqrt_finite_unscaled(x, y)
264266
: csqrt_finite_scaled(x, y);
265267
}
268+
269+
template <int vec_sz>
270+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
271+
{
272+
auto const &res_vec = sycl::sqrt(in);
273+
using deducedT = typename std::remove_cv_t<
274+
std::remove_reference_t<decltype(res_vec)>>::element_type;
275+
if constexpr (std::is_same_v<resT, deducedT>) {
276+
return res_vec;
277+
}
278+
else {
279+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
280+
}
281+
}
266282
};
267283

268284
template <typename argTy,

0 commit comments

Comments
 (0)