Skip to content

Commit 4a2578f

Browse files
Closes gh-1279 for dpt.sqrt
This change provides private method csqrt to evaluate square-root for complex types. It handles special values as mandated by array API. The finite input, it provides its own implementation based on std::hypot and std::sqrt for real types instead of calling std::sqrt on finite input of complex type. Compile with -DUSE_STD_SQRT_FOR_COMPLEX_TYPES to use std::sqrt instead of custom implementation. Cursory performance study suggests that custom implementation is at least not worse than std::sqrt one.
1 parent 5f298e6 commit 4a2578f

File tree

1 file changed

+93
-1
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+93
-1
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
#pragma once
2727
#include <CL/sycl.hpp>
2828
#include <cmath>
29+
#include <complex>
2930
#include <cstddef>
3031
#include <cstdint>
32+
#include <limits>
3133
#include <type_traits>
3234

3335
#include "kernels/elementwise_functions/common.hpp"
@@ -66,7 +68,97 @@ template <typename argT, typename resT> struct SqrtFunctor
6668

6769
resT operator()(const argT &in)
6870
{
69-
return std::sqrt(in);
71+
if constexpr (is_complex<argT>::value) {
72+
// #ifdef _WINDOWS
73+
// return csqrt(in);
74+
// #else
75+
// return std::sqrt(in);
76+
// #endif
77+
return csqrt(in);
78+
}
79+
else {
80+
return std::sqrt(in);
81+
}
82+
}
83+
84+
private:
85+
template <typename T> std::complex<T> csqrt(std::complex<T> const &z) const
86+
{
87+
// csqrt(x + y*1j)
88+
// * csqrt(x - y * 1j) = conj(csqrt(x + y * 1j))
89+
// * If x is either +0 or -0 and y is +0, the result is +0 + 0j.
90+
// * If x is any value (including NaN) and y is +infinity, the result
91+
// is +infinity + infinity j.
92+
// * If x is a finite number and y is NaN, the result is NaN + NaN j.
93+
94+
// * If x -infinity and y is a positive (i.e., greater than 0) finite
95+
// number, the result is NaN + NaN j.
96+
// * If x is +infinity and y is a positive (i.e., greater than 0)
97+
// finite number, the result is +0 + infinity j.
98+
// * If x is -infinity and y is NaN, the result is NaN + infinity j
99+
// (sign of the imaginary component is unspecified).
100+
// * If x is +infinity and y is NaN, the result is +infinity + NaN j.
101+
// * If x is NaN and y is any value, the result is NaN + NaN j.
102+
103+
using realT = T;
104+
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
105+
constexpr realT p_inf = std::numeric_limits<realT>::infinity();
106+
constexpr realT zero = realT(0);
107+
108+
realT x = std::real(z);
109+
realT y = std::imag(z);
110+
111+
if (std::isinf(y)) {
112+
return {p_inf, y};
113+
}
114+
else if (std::isnan(x)) {
115+
return {x, q_nan};
116+
}
117+
else if (std::isinf(x)) { // x is an infinity
118+
// y is either finite, or nan
119+
if (std::signbit(x)) { // x == -inf
120+
return {(std::isfinite(y) ? zero : y), std::copysign(p_inf, y)};
121+
}
122+
else {
123+
return {p_inf, (std::isfinite(y) ? std::copysign(zero, y) : y)};
124+
}
125+
}
126+
else { // x is finite
127+
if (std::isfinite(y)) {
128+
#ifdef USE_STD_SQRT_FOR_COMPLEX_TYPES
129+
return std::sqrt(z);
130+
#else
131+
return csqrt_finite(x, y);
132+
#endif
133+
}
134+
else {
135+
return {q_nan, y};
136+
}
137+
}
138+
}
139+
140+
template <typename T>
141+
std::complex<T> csqrt_finite(T const &x, T const &y) const
142+
{
143+
// csqrt(x + y*1j) =
144+
// sqrt((cabs(x, y) + x) / 2) +
145+
// 1j * copysign(sqrt((cabs(x, y) - x) / 2), y)
146+
147+
using realT = T;
148+
constexpr realT half = realT(0x1.0p-1f); // 1/2
149+
constexpr realT zero = realT(0);
150+
151+
if (std::signbit(x)) {
152+
realT m = std::hypot(x, y);
153+
realT d = std::sqrt((m - x) * half);
154+
return {(d == zero ? zero : std::abs(y) / d * half),
155+
std::copysign(d, y)};
156+
}
157+
else {
158+
realT m = std::hypot(x, y);
159+
realT d = std::sqrt((m + x) * half);
160+
return {d, (d == zero) ? std::copysign(zero, y) : y * half / d};
161+
}
70162
}
71163
};
72164

0 commit comments

Comments
 (0)