|
25 | 25 | #pragma once
|
26 | 26 | #include <CL/sycl.hpp>
|
27 | 27 | #include <cmath>
|
| 28 | +#include <complex> |
28 | 29 | #include <cstddef>
|
29 | 30 | #include <cstdint>
|
| 31 | +#include <limits> |
30 | 32 | #include <type_traits>
|
31 | 33 |
|
32 | 34 | #include "kernels/elementwise_functions/common.hpp"
|
|
36 | 38 | #include "utils/type_utils.hpp"
|
37 | 39 | #include <pybind11/pybind11.h>
|
38 | 40 |
|
39 |
| -#include <iostream> |
40 |
| - |
41 | 41 | namespace dpctl
|
42 | 42 | {
|
43 | 43 | namespace tensor
|
@@ -72,7 +72,58 @@ template <typename argT, typename resT> struct AbsFunctor
|
72 | 72 | return x;
|
73 | 73 | }
|
74 | 74 | else {
|
75 |
| - return std::abs(x); |
| 75 | + if constexpr (is_complex<argT>::value) { |
| 76 | + return cabs(x); |
| 77 | + } |
| 78 | + else if constexpr (std::is_same_v<argT, sycl::half> || |
| 79 | + std::is_floating_point_v<argT>) |
| 80 | + { |
| 81 | + return (std::signbit(x) ? -x : x); |
| 82 | + } |
| 83 | + else { |
| 84 | + return std::abs(x); |
| 85 | + } |
| 86 | + } |
| 87 | + } |
| 88 | + |
| 89 | +private: |
| 90 | + template <typename realT> realT cabs(std::complex<realT> const &z) const |
| 91 | + { |
| 92 | + // Special values for cabs( x + y * 1j): |
| 93 | + // * If x is either +infinity or -infinity and y is any value |
| 94 | + // (including NaN), the result is +infinity. |
| 95 | + // * If x is any value (including NaN) and y is either +infinity or |
| 96 | + // -infinity, the result is +infinity. |
| 97 | + // * If x is either +0 or -0, the result is equal to abs(y). |
| 98 | + // * If y is either +0 or -0, the result is equal to abs(x). |
| 99 | + // * If x is NaN and y is a finite number, the result is NaN. |
| 100 | + // * If x is a finite number and y is NaN, the result is NaN. |
| 101 | + // * If x is NaN and y is NaN, the result is NaN. |
| 102 | + |
| 103 | + const realT x = std::real(z); |
| 104 | + const realT y = std::imag(z); |
| 105 | + |
| 106 | + constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN(); |
| 107 | + constexpr realT p_inf = std::numeric_limits<realT>::infinity(); |
| 108 | + |
| 109 | + if (std::isinf(x)) { |
| 110 | + return p_inf; |
| 111 | + } |
| 112 | + else if (std::isinf(y)) { |
| 113 | + return p_inf; |
| 114 | + } |
| 115 | + else if (std::isnan(x)) { |
| 116 | + return q_nan; |
| 117 | + } |
| 118 | + else if (std::isnan(y)) { |
| 119 | + return q_nan; |
| 120 | + } |
| 121 | + else { |
| 122 | +#ifdef USE_STD_ABS_FOR_COMPLEX_TYPES |
| 123 | + return std::abs(z); |
| 124 | +#else |
| 125 | + return std::hypot(std::real(z), std::imag(z)); |
| 126 | +#endif |
76 | 127 | }
|
77 | 128 | }
|
78 | 129 | };
|
|
0 commit comments