Skip to content

Commit ac6bea6

Browse files
committed
Use sycl complex extension throughout element-wise and utils
1 parent 29eeac7 commit ac6bea6

File tree

26 files changed

+230
-157
lines changed

26 files changed

+230
-157
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,10 @@ template <typename argT, typename resT> struct AcosFunctor
7272
using realT = typename argT::value_type;
7373

7474
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
75-
76-
const realT x = std::real(in);
77-
const realT y = std::imag(in);
75+
using sycl_complexT = exprm_ns::complex<realT>;
76+
sycl_complexT z = sycl_complexT(in);
77+
const realT x = exprm_ns::real(z);
78+
const realT y = exprm_ns::imag(z);
7879

7980
if (std::isnan(x)) {
8081
/* acos(NaN + I*+-Inf) = NaN + I*-+Inf */
@@ -106,20 +107,18 @@ template <typename argT, typename resT> struct AcosFunctor
106107
constexpr realT r_eps =
107108
realT(1) / std::numeric_limits<realT>::epsilon();
108109
if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) {
109-
using sycl_complexT = exprm_ns::complex<realT>;
110-
sycl_complexT log_in =
111-
exprm_ns::log(exprm_ns::complex<realT>(in));
110+
sycl_complexT log_z = exprm_ns::log(z);
112111

113-
const realT wx = log_in.real();
114-
const realT wy = log_in.imag();
112+
const realT wx = log_z.real();
113+
const realT wy = log_z.imag();
115114
const realT rx = sycl::fabs(wy);
116115

117116
realT ry = wx + sycl::log(realT(2));
118117
return resT{rx, (sycl::signbit(y)) ? ry : -ry};
119118
}
120119

121120
/* ordinary cases */
122-
return exprm_ns::acos(exprm_ns::complex<realT>(in)); // acos(in);
121+
return exprm_ns::acos(z); // acos(z);
123122
}
124123
else {
125124
static_assert(std::is_floating_point_v<argT> ||

dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -77,33 +77,35 @@ template <typename argT, typename resT> struct AcoshFunctor
7777
* where the sign is chosen so Re(acosh(in)) >= 0.
7878
* So, we first calculate acos(in) and then acosh(in).
7979
*/
80-
const realT x = std::real(in);
81-
const realT y = std::imag(in);
80+
using sycl_complexT = exprm_ns::complex<realT>;
81+
sycl_complexT z = sycl_complexT(in);
82+
const realT x = exprm_ns::real(z);
83+
const realT y = exprm_ns::imag(z);
8284

83-
resT acos_in;
85+
sycl_complexT acos_z;
8486
if (std::isnan(x)) {
8587
/* acos(NaN + I*+-Inf) = NaN + I*-+Inf */
8688
if (std::isinf(y)) {
87-
acos_in = resT{q_nan, -y};
89+
acos_z = resT{q_nan, -y};
8890
}
8991
else {
90-
acos_in = resT{q_nan, q_nan};
92+
acos_z = resT{q_nan, q_nan};
9193
}
9294
}
9395
else if (std::isnan(y)) {
9496
/* acos(+-Inf + I*NaN) = NaN + I*opt(-)Inf */
9597
constexpr realT inf = std::numeric_limits<realT>::infinity();
9698

9799
if (std::isinf(x)) {
98-
acos_in = resT{q_nan, -inf};
100+
acos_z = resT{q_nan, -inf};
99101
}
100102
/* acos(0 + I*NaN) = Pi/2 + I*NaN with inexact */
101103
else if (x == realT(0)) {
102104
const realT pi_half = sycl::atan(realT(1)) * 2;
103-
acos_in = resT{pi_half, q_nan};
105+
acos_z = resT{pi_half, q_nan};
104106
}
105107
else {
106-
acos_in = resT{q_nan, q_nan};
108+
acos_z = resT{q_nan, q_nan};
107109
}
108110
}
109111

@@ -113,23 +115,21 @@ template <typename argT, typename resT> struct AcoshFunctor
113115
* For large x or y including acos(+-Inf + I*+-Inf)
114116
*/
115117
if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) {
116-
using sycl_complexT = typename exprm_ns::complex<realT>;
117-
const sycl_complexT log_in = exprm_ns::log(sycl_complexT(in));
118-
const realT wx = log_in.real();
119-
const realT wy = log_in.imag();
118+
const sycl_complexT log_z = exprm_ns::log(z);
119+
const realT wx = log_z.real();
120+
const realT wy = log_z.imag();
120121
const realT rx = sycl::fabs(wy);
121122
realT ry = wx + sycl::log(realT(2));
122-
acos_in = resT{rx, (sycl::signbit(y)) ? ry : -ry};
123+
acos_z = resT{rx, (sycl::signbit(y)) ? ry : -ry};
123124
}
124125
else {
125126
/* ordinary cases */
126-
acos_in =
127-
exprm_ns::acos(exprm_ns::complex<realT>(in)); // acos(in);
127+
acos_z = exprm_ns::acos(z); // acos(z);
128128
}
129129

130130
/* Now we calculate acosh(z) */
131-
const realT rx = std::real(acos_in);
132-
const realT ry = std::imag(acos_in);
131+
const realT rx = exprm_ns::real(acos_z);
132+
const realT ry = exprm_ns::imag(acos_z);
133133

134134
/* acosh(NaN + I*NaN) = NaN + I*NaN */
135135
if (std::isnan(rx) && std::isnan(ry)) {
@@ -145,7 +145,7 @@ template <typename argT, typename resT> struct AcoshFunctor
145145
return resT{ry, ry};
146146
}
147147
/* ordinary cases */
148-
const realT res_im = sycl::copysign(rx, std::imag(in));
148+
const realT res_im = sycl::copysign(rx, exprm_ns::imag(z));
149149
return resT{sycl::fabs(ry), res_im};
150150
}
151151
else {

dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,10 @@ template <typename argT, typename resT> struct AsinFunctor
8080
* y = imag(I * conj(in)) = real(in)
8181
* and then return {imag(w), real(w)} which is asin(in)
8282
*/
83-
const realT x = std::imag(in);
84-
const realT y = std::real(in);
83+
using sycl_complexT = exprm_ns::complex<realT>;
84+
sycl_complexT z = sycl_complexT(in);
85+
const realT x = exprm_ns::imag(z);
86+
const realT y = exprm_ns::real(z);
8587

8688
if (std::isnan(x)) {
8789
/* asinh(NaN + I*+-Inf) = opt(+-)Inf + I*NaN */
@@ -120,26 +122,24 @@ template <typename argT, typename resT> struct AsinFunctor
120122
constexpr realT r_eps =
121123
realT(1) / std::numeric_limits<realT>::epsilon();
122124
if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) {
123-
using sycl_complexT = exprm_ns::complex<realT>;
124-
const sycl_complexT z{x, y};
125+
const sycl_complexT z1{x, y};
125126
realT wx, wy;
126127
if (!sycl::signbit(x)) {
127-
const auto log_z = exprm_ns::log(z);
128-
wx = log_z.real() + sycl::log(realT(2));
129-
wy = log_z.imag();
128+
const auto log_z1 = exprm_ns::log(z1);
129+
wx = log_z1.real() + sycl::log(realT(2));
130+
wy = log_z1.imag();
130131
}
131132
else {
132-
const auto log_mz = exprm_ns::log(-z);
133-
wx = log_mz.real() + sycl::log(realT(2));
134-
wy = log_mz.imag();
133+
const auto log_mz1 = exprm_ns::log(-z1);
134+
wx = log_mz1.real() + sycl::log(realT(2));
135+
wy = log_mz1.imag();
135136
}
136137
const realT asinh_re = sycl::copysign(wx, x);
137138
const realT asinh_im = sycl::copysign(wy, y);
138139
return resT{asinh_im, asinh_re};
139140
}
140141
/* ordinary cases */
141-
return exprm_ns::asin(
142-
exprm_ns::complex<realT>(in)); // sycl::asin(in);
142+
return exprm_ns::asin(z); // sycl::asin(z);
143143
}
144144
else {
145145
static_assert(std::is_floating_point_v<argT> ||

dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,10 @@ template <typename argT, typename resT> struct AsinhFunctor
7272
using realT = typename argT::value_type;
7373

7474
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
75-
76-
const realT x = std::real(in);
77-
const realT y = std::imag(in);
75+
using sycl_complexT = exprm_ns::complex<realT>;
76+
sycl_complexT z = sycl_complexT(in);
77+
const realT x = exprm_ns::real(z);
78+
const realT y = exprm_ns::imag(z);
7879

7980
if (std::isnan(x)) {
8081
/* asinh(NaN + I*+-Inf) = opt(+-)Inf + I*NaN */
@@ -109,20 +110,18 @@ template <typename argT, typename resT> struct AsinhFunctor
109110
realT(1) / std::numeric_limits<realT>::epsilon();
110111

111112
if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) {
112-
using sycl_complexT = exprm_ns::complex<realT>;
113-
sycl_complexT log_in = (sycl::signbit(x))
114-
? exprm_ns::log(sycl_complexT(-in))
115-
: exprm_ns::log(sycl_complexT(in));
116-
realT wx = log_in.real() + sycl::log(realT(2));
117-
realT wy = log_in.imag();
113+
sycl_complexT log_in =
114+
(sycl::signbit(x)) ? exprm_ns::log(-z) : exprm_ns::log(z);
115+
realT wx = exprm_ns::real(log_in) + sycl::log(realT(2));
116+
realT wy = exprm_ns::imag(log_in);
118117

119118
const realT res_re = sycl::copysign(wx, x);
120119
const realT res_im = sycl::copysign(wy, y);
121120
return resT{res_re, res_im};
122121
}
123122

124123
/* ordinary cases */
125-
return exprm_ns::asinh(exprm_ns::complex<realT>(in)); // asinh(in);
124+
return exprm_ns::asinh(z); // asinh(z);
126125
}
127126
else {
128127
static_assert(std::is_floating_point_v<argT> ||

dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,11 @@ template <typename argT, typename resT> struct AtanFunctor
8383
* y = imag(I * conj(in)) = real(in)
8484
* and then return {imag(w), real(w)} which is atan(in)
8585
*/
86-
const realT x = std::imag(in);
87-
const realT y = std::real(in);
86+
using sycl_complexT = exprm_ns::complex<realT>;
87+
sycl_complexT z = sycl_complexT(in);
88+
const realT x = exprm_ns::imag(z);
89+
const realT y = exprm_ns::real(z);
90+
8891
if (std::isnan(x)) {
8992
/* atanh(NaN + I*+-Inf) = sign(NaN)*0 + I*+-Pi/2 */
9093
if (std::isinf(y)) {
@@ -132,7 +135,7 @@ template <typename argT, typename resT> struct AtanFunctor
132135
return resT{atanh_im, atanh_re};
133136
}
134137
/* ordinary cases */
135-
return exprm_ns::atan(exprm_ns::complex<realT>(in)); // atan(in);
138+
return exprm_ns::atan(z); // atan(z);
136139
}
137140
else {
138141
static_assert(std::is_floating_point_v<argT> ||

dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ template <typename argT, typename resT> struct AtanhFunctor
7373
using realT = typename argT::value_type;
7474
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
7575

76-
const realT x = std::real(in);
77-
const realT y = std::imag(in);
76+
using sycl_complexT = exprm_ns::complex<realT>;
77+
sycl_complexT z = sycl_complexT(in);
78+
const realT x = exprm_ns::real(z);
79+
const realT y = exprm_ns::imag(z);
7880

7981
if (std::isnan(x)) {
8082
/* atanh(NaN + I*+-Inf) = sign(NaN)0 + I*+-PI/2 */
@@ -123,7 +125,7 @@ template <typename argT, typename resT> struct AtanhFunctor
123125
return resT{res_re, res_im};
124126
}
125127
/* ordinary cases */
126-
return exprm_ns::atanh(exprm_ns::complex<realT>(in)); // atanh(in);
128+
return exprm_ns::atanh(z); // atanh(z);
127129
}
128130
else {
129131
static_assert(std::is_floating_point_v<argT> ||

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,19 @@ template <typename realT> realT cabs(std::complex<realT> const &z)
5151
// * If x is a finite number and y is NaN, the result is NaN.
5252
// * If x is NaN and y is NaN, the result is NaN.
5353

54-
const realT x = std::real(z);
55-
const realT y = std::imag(z);
54+
using sycl_complexT = exprm_ns::complex<realT>;
55+
sycl_complexT _z = exprm_ns::complex<realT>(z);
56+
const realT x = exprm_ns::real(_z);
57+
const realT y = exprm_ns::imag(_z);
5658

5759
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
5860
constexpr realT p_inf = std::numeric_limits<realT>::infinity();
5961

6062
const realT res =
6163
std::isinf(x)
6264
? p_inf
63-
: ((std::isinf(y)
64-
? p_inf
65-
: ((std::isnan(x)
66-
? q_nan
67-
: exprm_ns::abs(exprm_ns::complex<realT>(z))))));
65+
: ((std::isinf(y) ? p_inf
66+
: ((std::isnan(x) ? q_nan : exprm_ns::abs(_z)))));
6867

6968
return res;
7069
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -72,30 +72,31 @@ template <typename argT, typename resT> struct CosFunctor
7272
using realT = typename argT::value_type;
7373

7474
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
75+
using sycl_complexT = exprm_ns::complex<realT>;
76+
sycl_complexT z = sycl_complexT(in);
77+
const realT z_re = exprm_ns::real(z);
78+
const realT z_im = exprm_ns::imag(z);
7579

76-
realT const &in_re = std::real(in);
77-
realT const &in_im = std::imag(in);
78-
79-
const bool in_re_finite = std::isfinite(in_re);
80-
const bool in_im_finite = std::isfinite(in_im);
80+
const bool z_re_finite = std::isfinite(z_re);
81+
const bool z_im_finite = std::isfinite(z_im);
8182

8283
/*
8384
* Handle the nearly-non-exceptional cases where
8485
* real and imaginary parts of input are finite.
8586
*/
86-
if (in_re_finite && in_im_finite) {
87-
return exprm_ns::cos(exprm_ns::complex<realT>(in)); // cos(in);
87+
if (z_re_finite && z_im_finite) {
88+
return exprm_ns::cos(z); // cos(z);
8889
}
8990

9091
/*
91-
* since cos(in) = cosh(I * in), for special cases,
92-
* we return cosh(I * in).
92+
* since cos(z) = cosh(I * z), for special cases,
93+
* we return cosh(I * z).
9394
*/
94-
const realT x = -in_im;
95-
const realT y = in_re;
95+
const realT x = -z_im;
96+
const realT y = z_re;
9697

97-
const bool xfinite = in_im_finite;
98-
const bool yfinite = in_re_finite;
98+
const bool xfinite = z_im_finite;
99+
const bool yfinite = z_re_finite;
99100
/*
100101
* cosh(+-0 +- I Inf) = dNaN + I sign(d(+-0, dNaN))0.
101102
* The sign of 0 in the result is unspecified. Choice = normally

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ template <typename argT, typename resT> struct CoshFunctor
7373

7474
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
7575

76-
const realT x = std::real(in);
77-
const realT y = std::imag(in);
76+
using sycl_complexT = exprm_ns::complex<realT>;
77+
sycl_complexT z = sycl_complexT(in);
78+
const realT x = exprm_ns::real(z);
79+
const realT y = exprm_ns::imag(z);
7880

7981
const bool xfinite = std::isfinite(x);
8082
const bool yfinite = std::isfinite(y);
@@ -84,8 +86,7 @@ template <typename argT, typename resT> struct CoshFunctor
8486
* real and imaginary parts of input are finite.
8587
*/
8688
if (xfinite && yfinite) {
87-
return exprm_ns::cosh(
88-
exprm_ns::complex<realT>(in)); // cosh(in);
89+
return exprm_ns::cosh(z); // cosh(z);
8990
}
9091

9192
/*

dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,13 @@ template <typename argT, typename resT> struct ExpFunctor
7272

7373
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
7474

75-
const realT x = std::real(in);
76-
const realT y = std::imag(in);
75+
using sycl_complexT = exprm_ns::complex<realT>;
76+
sycl_complexT z = sycl_complexT(in);
77+
const realT x = exprm_ns::real(z);
78+
const realT y = exprm_ns::imag(z);
7779
if (std::isfinite(x)) {
7880
if (std::isfinite(y)) {
79-
return exprm_ns::exp(
80-
exprm_ns::complex<realT>(in)); // exp(in);
81+
return exprm_ns::exp(z); // exp(z);
8182
}
8283
else {
8384
return resT{q_nan, q_nan};
@@ -86,7 +87,7 @@ template <typename argT, typename resT> struct ExpFunctor
8687
else if (std::isnan(x)) {
8788
/* x is nan */
8889
if (y == realT(0)) {
89-
return resT{in};
90+
return resT{z};
9091
}
9192
else {
9293
return resT{x, q_nan};

0 commit comments

Comments
 (0)