File tree Expand file tree Collapse file tree 13 files changed +47
-13
lines changed
libtensor/include/kernels/elementwise_functions Expand file tree Collapse file tree 13 files changed +47
-13
lines changed Original file line number Diff line number Diff line change @@ -64,7 +64,7 @@ set_source_files_properties(
64
64
if (UNIX )
65
65
set_source_files_properties (
66
66
${CMAKE_CURRENT_SOURCE_DIR} /libtensor/source/elementwise_functions.cpp
67
- PROPERTIES COMPILE_DEFINITIONS "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES" )
67
+ PROPERTIES COMPILE_DEFINITIONS "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES;SYCL_EXT_ONEAPI_COMPLEX " )
68
68
endif ()
69
69
target_compile_options (${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int )
70
70
target_link_options (${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel )
Original file line number Diff line number Diff line change 27
27
#include < cmath>
28
28
#include < cstddef>
29
29
#include < cstdint>
30
+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
30
31
#include < type_traits>
31
32
32
33
#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace acos
47
48
48
49
namespace py = pybind11;
49
50
namespace td_ns = dpctl::tensor::type_dispatch;
51
+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
50
52
51
53
using dpctl::tensor::type_utils::is_complex;
52
54
@@ -114,7 +116,8 @@ template <typename argT, typename resT> struct AcosFunctor
114
116
}
115
117
116
118
/* ordinary cases */
117
- return std::acos (in);
119
+ return cmplx_ns::acos (
120
+ cmplx_ns::complex<realT>(in)); // std::acos(in);
118
121
}
119
122
else {
120
123
static_assert (std::is_floating_point_v<argT> ||
Original file line number Diff line number Diff line change 27
27
#include < cmath>
28
28
#include < cstddef>
29
29
#include < cstdint>
30
+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
30
31
#include < type_traits>
31
32
32
33
#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace acosh
47
48
48
49
namespace py = pybind11;
49
50
namespace td_ns = dpctl::tensor::type_dispatch;
51
+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
50
52
51
53
using dpctl::tensor::type_utils::is_complex;
52
54
@@ -118,7 +120,8 @@ template <typename argT, typename resT> struct AcoshFunctor
118
120
}
119
121
else {
120
122
/* ordinary cases */
121
- acos_in = std::acos (in);
123
+ acos_in = cmplx_ns::acos (
124
+ cmplx_ns::complex<realT>(in)); // std::acos(in);
122
125
}
123
126
124
127
/* Now we calculate acosh(z) */
Original file line number Diff line number Diff line change 27
27
#include < cmath>
28
28
#include < cstddef>
29
29
#include < cstdint>
30
+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
30
31
#include < type_traits>
31
32
32
33
#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace asin
47
48
48
49
namespace py = pybind11;
49
50
namespace td_ns = dpctl::tensor::type_dispatch;
51
+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
50
52
51
53
using dpctl::tensor::type_utils::is_complex;
52
54
@@ -134,7 +136,8 @@ template <typename argT, typename resT> struct AsinFunctor
134
136
return resT{asinh_im, asinh_re};
135
137
}
136
138
/* ordinary cases */
137
- return std::asin (in);
139
+ return cmplx_ns::asin (
140
+ cmplx_ns::complex<realT>(in)); // std::asin(in);
138
141
}
139
142
else {
140
143
static_assert (std::is_floating_point_v<argT> ||
Original file line number Diff line number Diff line change 27
27
#include < cmath>
28
28
#include < cstddef>
29
29
#include < cstdint>
30
+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
30
31
#include < type_traits>
31
32
32
33
#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace asinh
47
48
48
49
namespace py = pybind11;
49
50
namespace td_ns = dpctl::tensor::type_dispatch;
51
+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
50
52
51
53
using dpctl::tensor::type_utils::is_complex;
52
54
@@ -115,7 +117,8 @@ template <typename argT, typename resT> struct AsinhFunctor
115
117
}
116
118
117
119
/* ordinary cases */
118
- return std::asinh (in);
120
+ return cmplx_ns::asinh (
121
+ cmplx_ns::complex<realT>(in)); // std::asinh(in);
119
122
}
120
123
else {
121
124
static_assert (std::is_floating_point_v<argT> ||
Original file line number Diff line number Diff line change 28
28
#include < complex>
29
29
#include < cstddef>
30
30
#include < cstdint>
31
+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
31
32
#include < type_traits>
32
33
33
34
#include " kernels/elementwise_functions/common.hpp"
@@ -48,6 +49,7 @@ namespace atan
48
49
49
50
namespace py = pybind11;
50
51
namespace td_ns = dpctl::tensor::type_dispatch;
52
+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
51
53
52
54
using dpctl::tensor::type_utils::is_complex;
53
55
@@ -126,7 +128,8 @@ template <typename argT, typename resT> struct AtanFunctor
126
128
return resT{atanh_im, atanh_re};
127
129
}
128
130
/* ordinary cases */
129
- return std::atan (in);
131
+ return cmplx_ns::atan (
132
+ cmplx_ns::complex<realT>(in)); // std::atan(in);
130
133
}
131
134
else {
132
135
static_assert (std::is_floating_point_v<argT> ||
Original file line number Diff line number Diff line change 28
28
#include < complex>
29
29
#include < cstddef>
30
30
#include < cstdint>
31
+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
31
32
#include < type_traits>
32
33
33
34
#include " kernels/elementwise_functions/common.hpp"
@@ -48,6 +49,7 @@ namespace atanh
48
49
49
50
namespace py = pybind11;
50
51
namespace td_ns = dpctl::tensor::type_dispatch;
52
+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
51
53
52
54
using dpctl::tensor::type_utils::is_complex;
53
55
@@ -119,7 +121,8 @@ template <typename argT, typename resT> struct AtanhFunctor
119
121
return resT{res_re, res_im};
120
122
}
121
123
/* ordinary cases */
122
- return std::atanh (in);
124
+ return cmplx_ns::atanh (
125
+ cmplx_ns::complex<realT>(in)); // std::atanh(in);
123
126
}
124
127
else {
125
128
static_assert (std::is_floating_point_v<argT> ||
Original file line number Diff line number Diff line change 27
27
#include < cmath>
28
28
#include < cstddef>
29
29
#include < cstdint>
30
+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
30
31
#include < type_traits>
31
32
32
33
#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace cos
47
48
48
49
namespace py = pybind11;
49
50
namespace td_ns = dpctl::tensor::type_dispatch;
51
+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
50
52
51
53
using dpctl::tensor::type_utils::is_complex;
52
54
@@ -81,7 +83,8 @@ template <typename argT, typename resT> struct CosFunctor
81
83
* real and imaginary parts of input are finite.
82
84
*/
83
85
if (in_re_finite && in_im_finite) {
84
- return std::cos (in);
86
+ return cmplx_ns::cos (
87
+ cmplx_ns::complex<realT>(in)); // std::cos(in);
85
88
}
86
89
87
90
/*
Original file line number Diff line number Diff line change 27
27
#include < cmath>
28
28
#include < cstddef>
29
29
#include < cstdint>
30
+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
30
31
#include < type_traits>
31
32
32
33
#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace cosh
47
48
48
49
namespace py = pybind11;
49
50
namespace td_ns = dpctl::tensor::type_dispatch;
51
+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
50
52
51
53
using dpctl::tensor::type_utils::is_complex;
52
54
@@ -81,7 +83,8 @@ template <typename argT, typename resT> struct CoshFunctor
81
83
* real and imaginary parts of input are finite.
82
84
*/
83
85
if (xfinite && yfinite) {
84
- return std::cosh (in);
86
+ return cmplx_ns::cosh (
87
+ cmplx_ns::complex<realT>(in)); // std::cosh(in);
85
88
}
86
89
87
90
/*
Original file line number Diff line number Diff line change 23
23
// ===---------------------------------------------------------------------===//
24
24
25
25
#pragma once
26
- #include < CL/sycl.hpp>
27
26
#include < cmath>
28
27
#include < cstddef>
29
28
#include < cstdint>
29
+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
30
+ #include < sycl/sycl.hpp>
30
31
#include < type_traits>
31
32
32
33
#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace sin
47
48
48
49
namespace py = pybind11;
49
50
namespace td_ns = dpctl::tensor::type_dispatch;
51
+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
50
52
51
53
using dpctl::tensor::type_utils::is_complex;
52
54
@@ -79,7 +81,8 @@ template <typename argT, typename resT> struct SinFunctor
79
81
* real and imaginary parts of input are finite.
80
82
*/
81
83
if (in_re_finite && in_im_finite) {
82
- return std::sin (in);
84
+ return cmplx_ns::sin (
85
+ cmplx_ns::complex<realT>(in)); // std::sin(in);
83
86
}
84
87
85
88
/*
Original file line number Diff line number Diff line change 27
27
#include < cmath>
28
28
#include < cstddef>
29
29
#include < cstdint>
30
+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
30
31
#include < type_traits>
31
32
32
33
#include " kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace sinh
47
48
48
49
namespace py = pybind11;
49
50
namespace td_ns = dpctl::tensor::type_dispatch;
51
+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
50
52
51
53
using dpctl::tensor::type_utils::is_complex;
52
54
Original file line number Diff line number Diff line change 28
28
#include < complex>
29
29
#include < cstddef>
30
30
#include < cstdint>
31
+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
31
32
#include < type_traits>
32
33
33
34
#include " kernels/elementwise_functions/common.hpp"
@@ -48,6 +49,7 @@ namespace tan
48
49
49
50
namespace py = pybind11;
50
51
namespace td_ns = dpctl::tensor::type_dispatch;
52
+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
51
53
52
54
using dpctl::tensor::type_utils::is_complex;
53
55
@@ -118,7 +120,7 @@ template <typename argT, typename resT> struct TanFunctor
118
120
return resT{q_nan, q_nan};
119
121
}
120
122
/* ordinary cases */
121
- return std::tan (in);
123
+ return cmplx_ns::tan (cmplx_ns::complex<realT>(in)); // std::tan(in);
122
124
}
123
125
else {
124
126
static_assert (std::is_floating_point_v<argT> ||
Original file line number Diff line number Diff line change 29
29
#include < complex>
30
30
#include < cstddef>
31
31
#include < cstdint>
32
+ #include < sycl/ext/oneapi/experimental/sycl_complex.hpp>
32
33
#include < type_traits>
33
34
34
35
#include " kernels/elementwise_functions/common.hpp"
@@ -49,6 +50,7 @@ namespace tanh
49
50
50
51
namespace py = pybind11;
51
52
namespace td_ns = dpctl::tensor::type_dispatch;
53
+ namespace cmplx_ns = sycl::ext::oneapi::experimental;
52
54
53
55
using dpctl::tensor::type_utils::is_complex;
54
56
@@ -112,7 +114,8 @@ template <typename argT, typename resT> struct TanhFunctor
112
114
return resT{q_nan, q_nan};
113
115
}
114
116
/* ordinary cases */
115
- return std::tanh (in);
117
+ return cmplx_ns::tanh (
118
+ cmplx_ns::complex<realT>(in)); // std::tanh(in);
116
119
}
117
120
else {
118
121
static_assert (std::is_floating_point_v<argT> ||
You can’t perform that action at this time.
0 commit comments