Skip to content

Commit 3a36970

Browse files
Use sycl::ext::oneapi::experimental for complex trig/trigh and inverses
Use sycl_complex extension to implement complex-valued trigonometric, hyperbolic functions and their inverses. This works around use of double precision functions/literals in implementations of these functions in MSVC headers, causing failures to offload on Iris Xe for single precision input citing lack of fp64 support by the hardware.
1 parent b1c19fe commit 3a36970

File tree

13 files changed

+47
-13
lines changed

13 files changed

+47
-13
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ set_source_files_properties(
6464
if (UNIX)
6565
set_source_files_properties(
6666
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions.cpp
67-
PROPERTIES COMPILE_DEFINITIONS "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES")
67+
PROPERTIES COMPILE_DEFINITIONS "USE_STD_ABS_FOR_COMPLEX_TYPES;USE_STD_SQRT_FOR_COMPLEX_TYPES;SYCL_EXT_ONEAPI_COMPLEX")
6868
endif()
6969
target_compile_options(${python_module_name} PRIVATE -fno-sycl-id-queries-fit-in-int)
7070
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)

dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <cmath>
2828
#include <cstddef>
2929
#include <cstdint>
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3031
#include <type_traits>
3132

3233
#include "kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace acos
4748

4849
namespace py = pybind11;
4950
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052

5153
using dpctl::tensor::type_utils::is_complex;
5254

@@ -114,7 +116,8 @@ template <typename argT, typename resT> struct AcosFunctor
114116
}
115117

116118
/* ordinary cases */
117-
return std::acos(in);
119+
return cmplx_ns::acos(
120+
cmplx_ns::complex<realT>(in)); // std::acos(in);
118121
}
119122
else {
120123
static_assert(std::is_floating_point_v<argT> ||

dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <cmath>
2828
#include <cstddef>
2929
#include <cstdint>
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3031
#include <type_traits>
3132

3233
#include "kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace acosh
4748

4849
namespace py = pybind11;
4950
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052

5153
using dpctl::tensor::type_utils::is_complex;
5254

@@ -118,7 +120,8 @@ template <typename argT, typename resT> struct AcoshFunctor
118120
}
119121
else {
120122
/* ordinary cases */
121-
acos_in = std::acos(in);
123+
acos_in = cmplx_ns::acos(
124+
cmplx_ns::complex<realT>(in)); // std::acos(in);
122125
}
123126

124127
/* Now we calculate acosh(z) */

dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <cmath>
2828
#include <cstddef>
2929
#include <cstdint>
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3031
#include <type_traits>
3132

3233
#include "kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace asin
4748

4849
namespace py = pybind11;
4950
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052

5153
using dpctl::tensor::type_utils::is_complex;
5254

@@ -134,7 +136,8 @@ template <typename argT, typename resT> struct AsinFunctor
134136
return resT{asinh_im, asinh_re};
135137
}
136138
/* ordinary cases */
137-
return std::asin(in);
139+
return cmplx_ns::asin(
140+
cmplx_ns::complex<realT>(in)); // std::asin(in);
138141
}
139142
else {
140143
static_assert(std::is_floating_point_v<argT> ||

dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <cmath>
2828
#include <cstddef>
2929
#include <cstdint>
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3031
#include <type_traits>
3132

3233
#include "kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace asinh
4748

4849
namespace py = pybind11;
4950
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052

5153
using dpctl::tensor::type_utils::is_complex;
5254

@@ -115,7 +117,8 @@ template <typename argT, typename resT> struct AsinhFunctor
115117
}
116118

117119
/* ordinary cases */
118-
return std::asinh(in);
120+
return cmplx_ns::asinh(
121+
cmplx_ns::complex<realT>(in)); // std::asinh(in);
119122
}
120123
else {
121124
static_assert(std::is_floating_point_v<argT> ||

dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <complex>
2929
#include <cstddef>
3030
#include <cstdint>
31+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3132
#include <type_traits>
3233

3334
#include "kernels/elementwise_functions/common.hpp"
@@ -48,6 +49,7 @@ namespace atan
4849

4950
namespace py = pybind11;
5051
namespace td_ns = dpctl::tensor::type_dispatch;
52+
namespace cmplx_ns = sycl::ext::oneapi::experimental;
5153

5254
using dpctl::tensor::type_utils::is_complex;
5355

@@ -126,7 +128,8 @@ template <typename argT, typename resT> struct AtanFunctor
126128
return resT{atanh_im, atanh_re};
127129
}
128130
/* ordinary cases */
129-
return std::atan(in);
131+
return cmplx_ns::atan(
132+
cmplx_ns::complex<realT>(in)); // std::atan(in);
130133
}
131134
else {
132135
static_assert(std::is_floating_point_v<argT> ||

dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <complex>
2929
#include <cstddef>
3030
#include <cstdint>
31+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3132
#include <type_traits>
3233

3334
#include "kernels/elementwise_functions/common.hpp"
@@ -48,6 +49,7 @@ namespace atanh
4849

4950
namespace py = pybind11;
5051
namespace td_ns = dpctl::tensor::type_dispatch;
52+
namespace cmplx_ns = sycl::ext::oneapi::experimental;
5153

5254
using dpctl::tensor::type_utils::is_complex;
5355

@@ -119,7 +121,8 @@ template <typename argT, typename resT> struct AtanhFunctor
119121
return resT{res_re, res_im};
120122
}
121123
/* ordinary cases */
122-
return std::atanh(in);
124+
return cmplx_ns::atanh(
125+
cmplx_ns::complex<realT>(in)); // std::atanh(in);
123126
}
124127
else {
125128
static_assert(std::is_floating_point_v<argT> ||

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <cmath>
2828
#include <cstddef>
2929
#include <cstdint>
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3031
#include <type_traits>
3132

3233
#include "kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace cos
4748

4849
namespace py = pybind11;
4950
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052

5153
using dpctl::tensor::type_utils::is_complex;
5254

@@ -81,7 +83,8 @@ template <typename argT, typename resT> struct CosFunctor
8183
* real and imaginary parts of input are finite.
8284
*/
8385
if (in_re_finite && in_im_finite) {
84-
return std::cos(in);
86+
return cmplx_ns::cos(
87+
cmplx_ns::complex<realT>(in)); // std::cos(in);
8588
}
8689

8790
/*

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <cmath>
2828
#include <cstddef>
2929
#include <cstdint>
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3031
#include <type_traits>
3132

3233
#include "kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace cosh
4748

4849
namespace py = pybind11;
4950
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052

5153
using dpctl::tensor::type_utils::is_complex;
5254

@@ -81,7 +83,8 @@ template <typename argT, typename resT> struct CoshFunctor
8183
* real and imaginary parts of input are finite.
8284
*/
8385
if (xfinite && yfinite) {
84-
return std::cosh(in);
86+
return cmplx_ns::cosh(
87+
cmplx_ns::complex<realT>(in)); // std::cosh(in);
8588
}
8689

8790
/*

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323
//===---------------------------------------------------------------------===//
2424

2525
#pragma once
26-
#include <CL/sycl.hpp>
2726
#include <cmath>
2827
#include <cstddef>
2928
#include <cstdint>
29+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
30+
#include <sycl/sycl.hpp>
3031
#include <type_traits>
3132

3233
#include "kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace sin
4748

4849
namespace py = pybind11;
4950
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052

5153
using dpctl::tensor::type_utils::is_complex;
5254

@@ -79,7 +81,8 @@ template <typename argT, typename resT> struct SinFunctor
7981
* real and imaginary parts of input are finite.
8082
*/
8183
if (in_re_finite && in_im_finite) {
82-
return std::sin(in);
84+
return cmplx_ns::sin(
85+
cmplx_ns::complex<realT>(in)); // std::sin(in);
8386
}
8487

8588
/*

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <cmath>
2828
#include <cstddef>
2929
#include <cstdint>
30+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3031
#include <type_traits>
3132

3233
#include "kernels/elementwise_functions/common.hpp"
@@ -47,6 +48,7 @@ namespace sinh
4748

4849
namespace py = pybind11;
4950
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace cmplx_ns = sycl::ext::oneapi::experimental;
5052

5153
using dpctl::tensor::type_utils::is_complex;
5254

dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <complex>
2929
#include <cstddef>
3030
#include <cstdint>
31+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3132
#include <type_traits>
3233

3334
#include "kernels/elementwise_functions/common.hpp"
@@ -48,6 +49,7 @@ namespace tan
4849

4950
namespace py = pybind11;
5051
namespace td_ns = dpctl::tensor::type_dispatch;
52+
namespace cmplx_ns = sycl::ext::oneapi::experimental;
5153

5254
using dpctl::tensor::type_utils::is_complex;
5355

@@ -118,7 +120,7 @@ template <typename argT, typename resT> struct TanFunctor
118120
return resT{q_nan, q_nan};
119121
}
120122
/* ordinary cases */
121-
return std::tan(in);
123+
return cmplx_ns::tan(cmplx_ns::complex<realT>(in)); // std::tan(in);
122124
}
123125
else {
124126
static_assert(std::is_floating_point_v<argT> ||

dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <complex>
3030
#include <cstddef>
3131
#include <cstdint>
32+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3233
#include <type_traits>
3334

3435
#include "kernels/elementwise_functions/common.hpp"
@@ -49,6 +50,7 @@ namespace tanh
4950

5051
namespace py = pybind11;
5152
namespace td_ns = dpctl::tensor::type_dispatch;
53+
namespace cmplx_ns = sycl::ext::oneapi::experimental;
5254

5355
using dpctl::tensor::type_utils::is_complex;
5456

@@ -112,7 +114,8 @@ template <typename argT, typename resT> struct TanhFunctor
112114
return resT{q_nan, q_nan};
113115
}
114116
/* ordinary cases */
115-
return std::tanh(in);
117+
return cmplx_ns::tanh(
118+
cmplx_ns::complex<realT>(in)); // std::tanh(in);
116119
}
117120
else {
118121
static_assert(std::is_floating_point_v<argT> ||

0 commit comments

Comments
 (0)