Skip to content

Commit c94f9f8

Browse files
authored
Implement dpnp.logaddexp2 function (#1955)
* Add implementation of dpnp.logaddexp2 * Update third party tests * Add more tests * Update description based on review comments * Add missing space in docstring
1 parent 4498f86 commit c94f9f8

File tree

13 files changed

+413
-34
lines changed

13 files changed

+413
-34
lines changed

dpnp/backend/extensions/ufunc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ set(_elementwise_sources
3030
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmax.cpp
3131
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmin.cpp
3232
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/fmod.cpp
33+
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp
3334
${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/radians.cpp
3435
)
3536

dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "fmax.hpp"
3131
#include "fmin.hpp"
3232
#include "fmod.hpp"
33+
#include "logaddexp2.hpp"
3334
#include "radians.hpp"
3435

3536
namespace py = pybind11;
@@ -46,6 +47,7 @@ void init_elementwise_functions(py::module_ m)
4647
init_fmax(m);
4748
init_fmin(m);
4849
init_fmod(m);
50+
init_logaddexp2(m);
4951
init_radians(m);
5052
}
5153
} // namespace dpnp::extensions::ufunc
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// maxification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <sycl/sycl.hpp>
27+
28+
#include "dpctl4pybind11.hpp"
29+
30+
#include "kernels/elementwise_functions/logaddexp2.hpp"
31+
#include "logaddexp2.hpp"
32+
#include "populate.hpp"
33+
34+
// include a local copy of elementwise common header from dpctl tensor:
35+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
36+
// TODO: replace by including dpctl header once available
37+
#include "../../elementwise_functions/elementwise_functions.hpp"
38+
39+
// dpctl tensor headers
40+
#include "kernels/elementwise_functions/common.hpp"
41+
#include "kernels/elementwise_functions/logaddexp.hpp"
42+
#include "utils/type_dispatch.hpp"
43+
44+
namespace dpnp::extensions::ufunc
45+
{
46+
namespace py = pybind11;
47+
namespace py_int = dpnp::extensions::py_internal;
48+
namespace td_ns = dpctl::tensor::type_dispatch;
49+
50+
namespace impl
51+
{
52+
namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common;
53+
namespace logaddexp_ns = dpctl::tensor::kernels::logaddexp;
54+
55+
// Supports the same types table as for logaddexp function in dpctl
56+
template <typename T1, typename T2>
57+
using OutputType = logaddexp_ns::LogAddExpOutputType<T1, T2>;
58+
59+
using dpnp::kernels::logaddexp2::Logaddexp2Functor;
60+
61+
template <typename argT1,
62+
typename argT2,
63+
typename resT,
64+
unsigned int vec_sz = 4,
65+
unsigned int n_vecs = 2,
66+
bool enable_sg_loadstore = true>
67+
using ContigFunctor =
68+
ew_cmn_ns::BinaryContigFunctor<argT1,
69+
argT2,
70+
resT,
71+
Logaddexp2Functor<argT1, argT2, resT>,
72+
vec_sz,
73+
n_vecs,
74+
enable_sg_loadstore>;
75+
76+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
77+
using StridedFunctor =
78+
ew_cmn_ns::BinaryStridedFunctor<argT1,
79+
argT2,
80+
resT,
81+
IndexerT,
82+
Logaddexp2Functor<argT1, argT2, resT>>;
83+
84+
using ew_cmn_ns::binary_contig_impl_fn_ptr_t;
85+
using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t;
86+
using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t;
87+
using ew_cmn_ns::binary_strided_impl_fn_ptr_t;
88+
89+
static binary_contig_impl_fn_ptr_t
90+
logaddexp2_contig_dispatch_table[td_ns::num_types][td_ns::num_types];
91+
static int logaddexp2_output_typeid_table[td_ns::num_types][td_ns::num_types];
92+
static binary_strided_impl_fn_ptr_t
93+
logaddexp2_strided_dispatch_table[td_ns::num_types][td_ns::num_types];
94+
95+
MACRO_POPULATE_DISPATCH_TABLES(logaddexp2);
96+
} // namespace impl
97+
98+
void init_logaddexp2(py::module_ m)
99+
{
100+
using arrayT = dpctl::tensor::usm_ndarray;
101+
using event_vecT = std::vector<sycl::event>;
102+
{
103+
impl::populate_logaddexp2_dispatch_tables();
104+
using impl::logaddexp2_contig_dispatch_table;
105+
using impl::logaddexp2_output_typeid_table;
106+
using impl::logaddexp2_strided_dispatch_table;
107+
108+
auto logaddexp2_pyapi = [&](const arrayT &src1, const arrayT &src2,
109+
const arrayT &dst, sycl::queue &exec_q,
110+
const event_vecT &depends = {}) {
111+
return py_int::py_binary_ufunc(
112+
src1, src2, dst, exec_q, depends,
113+
logaddexp2_output_typeid_table,
114+
logaddexp2_contig_dispatch_table,
115+
logaddexp2_strided_dispatch_table,
116+
// no support of C-contig row with broadcasting in OneMKL
117+
td_ns::NullPtrTable<
118+
impl::
119+
binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{},
120+
td_ns::NullPtrTable<
121+
impl::
122+
binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{});
123+
};
124+
m.def("_logaddexp2", logaddexp2_pyapi, "", py::arg("src1"),
125+
py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"),
126+
py::arg("depends") = py::list());
127+
128+
auto logaddexp2_result_type_pyapi = [&](const py::dtype &dtype1,
129+
const py::dtype &dtype2) {
130+
return py_int::py_binary_ufunc_result_type(
131+
dtype1, dtype2, logaddexp2_output_typeid_table);
132+
};
133+
m.def("_logaddexp2_result_type", logaddexp2_result_type_pyapi);
134+
}
135+
}
136+
} // namespace dpnp::extensions::ufunc
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <pybind11/pybind11.h>
29+
30+
namespace py = pybind11;
31+
32+
namespace dpnp::extensions::ufunc
33+
{
34+
void init_logaddexp2(py::module_ m);
35+
} // namespace dpnp::extensions::ufunc
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <cmath>
29+
#include <sycl/sycl.hpp>
30+
31+
namespace dpnp::kernels::logaddexp2
32+
{
33+
constexpr double log2e = 1.442695040888963407359924681001892137;
34+
35+
template <typename T>
36+
inline T log2_1p(T x)
37+
{
38+
return T(log2e) * sycl::log1p(x);
39+
}
40+
41+
template <typename T>
42+
inline T logaddexp2(T x, T y)
43+
{
44+
if (x == y) {
45+
// handles infinities of the same sign
46+
return x + 1;
47+
}
48+
49+
const T tmp = x - y;
50+
if (tmp > 0) {
51+
return x + log2_1p(sycl::exp2(-tmp));
52+
}
53+
else if (tmp <= 0) {
54+
return y + log2_1p(sycl::exp2(tmp));
55+
}
56+
return std::numeric_limits<T>::quiet_NaN();
57+
}
58+
59+
template <typename argT1, typename argT2, typename resT>
60+
struct Logaddexp2Functor
61+
{
62+
using supports_sg_loadstore = std::true_type;
63+
using supports_vec = std::false_type;
64+
65+
resT operator()(const argT1 &in1, const argT2 &in2) const
66+
{
67+
return logaddexp2<resT>(in1, in2);
68+
}
69+
};
70+
} // namespace dpnp::kernels::logaddexp2

dpnp/dpnp_iface_trigonometric.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
"log1p",
8383
"log2",
8484
"logaddexp",
85+
"logaddexp2",
8586
"logsumexp",
8687
"rad2deg",
8788
"radians",
@@ -1201,7 +1202,7 @@ def cumlogsumexp(
12011202
See Also
12021203
--------
12031204
:obj:`dpnp.log` : Natural logarithm, element-wise.
1204-
:obj:`dpnp.log2` : Return the base 2 logarithm of the input array, element-wise.
1205+
:obj:`dpnp.log2` : Return the base-2 logarithm of the input array, element-wise.
12051206
:obj:`dpnp.log1p` : Return the natural logarithm of one plus the input array, element-wise.
12061207
12071208
Examples
@@ -1262,7 +1263,7 @@ def cumlogsumexp(
12621263
:obj:`dpnp.expm1` : ``exp(x) - 1``, the inverse of :obj:`dpnp.log1p`.
12631264
:obj:`dpnp.log` : Natural logarithm, element-wise.
12641265
:obj:`dpnp.log10` : Return the base 10 logarithm of the input array, element-wise.
1265-
:obj:`dpnp.log2` : Return the base 2 logarithm of the input array, element-wise.
1266+
:obj:`dpnp.log2` : Return the base-2 logarithm of the input array, element-wise.
12661267
12671268
Examples
12681269
--------
@@ -1390,7 +1391,10 @@ def cumlogsumexp(
13901391
--------
13911392
:obj:`dpnp.log` : Natural logarithm, element-wise.
13921393
:obj:`dpnp.exp` : Exponential, element-wise.
1393-
:obj:`dpnp.logsumdexp` : Logarithm of the sum of exponents of elements in the input array.
1394+
:obj:`dpnp.logaddexp2`: Logarithm of the sum of exponentiations of inputs in
1395+
base-2, element-wise.
1396+
:obj:`dpnp.logsumexp` : Logarithm of the sum of exponents of elements in the
1397+
input array.
13941398
13951399
Examples
13961400
--------
@@ -1412,6 +1416,75 @@ def cumlogsumexp(
14121416
)
14131417

14141418

1419+
_LOGADDEXP2_DOCSTRING = """
1420+
Calculates the logarithm of the sum of exponents in base-2 for each element
1421+
`x1_i` of the input array `x1` with the respective element `x2_i` of the input
1422+
array `x2`.
1423+
1424+
This function calculates `log2(2**x1 + 2**x2)`. It is useful in machine
1425+
learning when the calculated probabilities of events may be so small as
1426+
to exceed the range of normal floating point numbers. In such cases the base-2
1427+
logarithm of the calculated probability can be used instead. This function
1428+
allows adding probabilities stored in such a fashion.
1429+
1430+
For full documentation refer to :obj:`numpy.logaddexp2`.
1431+
1432+
Parameters
1433+
----------
1434+
x1 : {dpnp.ndarray, usm_ndarray, scalar}
1435+
First input array, expected to have a real-valued floating-point
1436+
data type.
1437+
Both inputs `x1` and `x2` can not be scalars at the same time.
1438+
x2 : {dpnp.ndarray, usm_ndarray, scalar}
1439+
Second input array, also expected to have a real-valued
1440+
floating-point data type.
1441+
Both inputs `x1` and `x2` can not be scalars at the same time.
1442+
out : {None, dpnp.ndarray, usm_ndarray}, optional
1443+
Output array to populate.
1444+
Array must have the correct shape and the expected data type.
1445+
Default: ``None``.
1446+
order : {"C", "F", "A", "K"}, optional
1447+
Memory layout of the newly output array, if parameter `out` is ``None``.
1448+
Default: ``"K"``.
1449+
1450+
Returns
1451+
-------
1452+
out : dpnp.ndarray
1453+
An array containing the element-wise results. The data type
1454+
of the returned array is determined by the Type Promotion Rules.
1455+
1456+
Limitations
1457+
-----------
1458+
Parameters `where` and `subok` are supported with their default values.
1459+
Keyword arguments `kwargs` are currently unsupported.
1460+
Otherwise ``NotImplementedError`` exception will be raised.
1461+
1462+
See Also
1463+
--------
1464+
:obj:`dpnp.logaddexp`: Natural logarithm of the sum of exponentiations of
1465+
inputs, element-wise.
1466+
:obj:`dpnp.logsumexp` : Logarithm of the sum of exponentiations of the inputs.
1467+
1468+
Examples
1469+
--------
1470+
>>> import dpnp as np
1471+
>>> prob1 = np.log2(np.array(1e-50))
1472+
>>> prob2 = np.log2(np.array(2.5e-50))
1473+
>>> prob12 = np.logaddexp2(prob1, prob2)
1474+
>>> prob1, prob2, prob12
1475+
(array(-166.09640474), array(-164.77447665), array(-164.28904982))
1476+
>>> 2**prob12
1477+
array(3.5e-50)
1478+
"""
1479+
1480+
logaddexp2 = DPNPBinaryFunc(
1481+
"logaddexp2",
1482+
ufi._logaddexp2_result_type,
1483+
ufi._logaddexp2,
1484+
_LOGADDEXP2_DOCSTRING,
1485+
)
1486+
1487+
14151488
def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
14161489
"""
14171490
Calculates the logarithm of the sum of exponents of elements in
@@ -1472,6 +1545,8 @@ def logsumexp(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
14721545
:obj:`dpnp.exp` : Exponential, element-wise.
14731546
:obj:`dpnp.logaddexp` : Logarithm of the sum of exponents of
14741547
the inputs, element-wise.
1548+
:obj:`dpnp.logaddexp2` : Logarithm of the sum of exponents of
1549+
the inputs in base-2, element-wise.
14751550
14761551
Examples
14771552
--------

0 commit comments

Comments
 (0)