Skip to content

Commit d88e78f

Browse files
committed
logaddexp implementation moved to math_utils
Reduces code repetition between logsumexp and logaddexp
1 parent 448a7f1 commit d88e78f

File tree

3 files changed

+27
-38
lines changed

3 files changed

+27
-38
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <limits>
3232
#include <type_traits>
3333

34+
#include "utils/math_utils.hpp"
3435
#include "utils/offset_utils.hpp"
3536
#include "utils/type_dispatch.hpp"
3637
#include "utils/type_utils.hpp"
@@ -61,7 +62,8 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
6162

6263
resT operator()(const argT1 &in1, const argT2 &in2) const
6364
{
64-
return impl<resT>(in1, in2);
65+
using dpctl::tensor::math_utils::logaddexp;
66+
return logaddexp<resT>(in1, in2);
6567
}
6668

6769
template <int vec_sz>
@@ -79,34 +81,15 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
7981
impl_finite<resT>(-std::abs(diff[i]));
8082
}
8183
else {
82-
res[i] = impl<resT>(in1[i], in2[i]);
84+
using dpctl::tensor::math_utils::logaddexp;
85+
res[i] = logaddexp<resT>(in1[i], in2[i]);
8386
}
8487
}
8588

8689
return res;
8790
}
8891

8992
private:
90-
template <typename T> T impl(T const &in1, T const &in2) const
91-
{
92-
if (in1 == in2) { // handle signed infinities
93-
const T log2 = std::log(T(2));
94-
return in1 + log2;
95-
}
96-
else {
97-
const T tmp = in1 - in2;
98-
if (tmp > 0) {
99-
return in1 + std::log1p(std::exp(-tmp));
100-
}
101-
else if (tmp <= 0) {
102-
return in2 + std::log1p(std::exp(tmp));
103-
}
104-
else {
105-
return std::numeric_limits<T>::quiet_NaN();
106-
}
107-
}
108-
}
109-
11093
template <typename T> T impl_finite(T const &in) const
11194
{
11295
return (in > 0) ? (in + std::log1p(std::exp(-in)))

dpctl/tensor/libtensor/include/utils/math_utils.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,26 @@ template <typename T> T min_complex(const T &x1, const T &x2)
115115
return (std::isnan(real1) || isnan_imag1 || lt) ? x1 : x2;
116116
}
117117

118+
template <typename T> T logaddexp(T x, T y)
119+
{
120+
if (x == y) { // handle signed infinities
121+
const T log2 = std::log(T(2));
122+
return x + log2;
123+
}
124+
else {
125+
const T tmp = x - y;
126+
if (tmp > 0) {
127+
return x + std::log1p(std::exp(-tmp));
128+
}
129+
else if (tmp <= 0) {
130+
return y + std::log1p(std::exp(tmp));
131+
}
132+
else {
133+
return std::numeric_limits<T>::quiet_NaN();
134+
}
135+
}
136+
}
137+
118138
} // namespace math_utils
119139
} // namespace tensor
120140
} // namespace dpctl

dpctl/tensor/libtensor/include/utils/sycl_utils.hpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -292,22 +292,8 @@ template <typename T> struct LogSumExp
292292
{
293293
T operator()(const T &x, const T &y) const
294294
{
295-
if (x == y) {
296-
const T log2 = std::log(T(2));
297-
return x + log2;
298-
}
299-
else {
300-
const T tmp = x - y;
301-
if (tmp > 0) {
302-
return x + std::log1p(std::exp(-tmp));
303-
}
304-
else if (tmp <= 0) {
305-
return y + std::log1p(std::exp(tmp));
306-
}
307-
else {
308-
return std::numeric_limits<T>::quiet_NaN();
309-
}
310-
}
295+
using dpctl::tensor::math_utils::logaddexp;
296+
return logaddexp<T>(x, y);
311297
}
312298
};
313299

0 commit comments

Comments
 (0)