Skip to content

Commit 5f298e6

Browse files
Closes gh-1279
Provide cabs private method implementating abs for complex types, paying attention to array-API mandated special values. To work-around gh-1279, use std::hypot to compute value for finite inputs. Compile with -DUSE_STD_ABS_FOR_COMPLEX_TYPES to use std::abs(z) instead of std::hypot(std::real(z), std::imag(z)).
1 parent 7d9974e commit 5f298e6

File tree

1 file changed

+54
-3
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+54
-3
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525
#pragma once
2626
#include <CL/sycl.hpp>
2727
#include <cmath>
28+
#include <complex>
2829
#include <cstddef>
2930
#include <cstdint>
31+
#include <limits>
3032
#include <type_traits>
3133

3234
#include "kernels/elementwise_functions/common.hpp"
@@ -36,8 +38,6 @@
3638
#include "utils/type_utils.hpp"
3739
#include <pybind11/pybind11.h>
3840

39-
#include <iostream>
40-
4141
namespace dpctl
4242
{
4343
namespace tensor
@@ -72,7 +72,58 @@ template <typename argT, typename resT> struct AbsFunctor
7272
return x;
7373
}
7474
else {
75-
return std::abs(x);
75+
if constexpr (is_complex<argT>::value) {
76+
return cabs(x);
77+
}
78+
else if constexpr (std::is_same_v<argT, sycl::half> ||
79+
std::is_floating_point_v<argT>)
80+
{
81+
return (std::signbit(x) ? -x : x);
82+
}
83+
else {
84+
return std::abs(x);
85+
}
86+
}
87+
}
88+
89+
private:
90+
template <typename realT> realT cabs(std::complex<realT> const &z) const
91+
{
92+
// Special values for cabs( x + y * 1j):
93+
// * If x is either +infinity or -infinity and y is any value
94+
// (including NaN), the result is +infinity.
95+
// * If x is any value (including NaN) and y is either +infinity or
96+
// -infinity, the result is +infinity.
97+
// * If x is either +0 or -0, the result is equal to abs(y).
98+
// * If y is either +0 or -0, the result is equal to abs(x).
99+
// * If x is NaN and y is a finite number, the result is NaN.
100+
// * If x is a finite number and y is NaN, the result is NaN.
101+
// * If x is NaN and y is NaN, the result is NaN.
102+
103+
const realT x = std::real(z);
104+
const realT y = std::imag(z);
105+
106+
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
107+
constexpr realT p_inf = std::numeric_limits<realT>::infinity();
108+
109+
if (std::isinf(x)) {
110+
return p_inf;
111+
}
112+
else if (std::isinf(y)) {
113+
return p_inf;
114+
}
115+
else if (std::isnan(x)) {
116+
return q_nan;
117+
}
118+
else if (std::isnan(y)) {
119+
return q_nan;
120+
}
121+
else {
122+
#ifdef USE_STD_ABS_FOR_COMPLEX_TYPES
123+
return std::abs(z);
124+
#else
125+
return std::hypot(std::real(z), std::imag(z));
126+
#endif
76127
}
77128
}
78129
};

0 commit comments

Comments
 (0)