Skip to content

Commit 946ff08

Browse files
authored
re-write dpnp.hypot (#1560)
* re-write dpnp.hypot * address comments * fix docstring and precommit
1 parent e3be611 commit 946ff08

13 files changed

+308
-87
lines changed

dpnp/backend/extensions/vm/hypot.hpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event hypot_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
const char *in_b,
46+
char *out_y,
47+
const std::vector<sycl::event> &depends)
48+
{
49+
type_utils::validate_type_for_device<T>(exec_q);
50+
51+
const T *a = reinterpret_cast<const T *>(in_a);
52+
const T *b = reinterpret_cast<const T *>(in_b);
53+
T *y = reinterpret_cast<T *>(out_y);
54+
55+
return mkl_vm::hypot(exec_q,
56+
n, // number of elements to be calculated
57+
a, // pointer `a` containing 1st input vector of size n
58+
b, // pointer `b` containing 2nd input vector of size n
59+
y, // pointer `y` to the output vector of size n
60+
depends);
61+
}
62+
63+
template <typename fnT, typename T>
64+
struct HypotContigFactory
65+
{
66+
fnT get()
67+
{
68+
if constexpr (std::is_same_v<
69+
typename types::HypotOutputType<T>::value_type, void>)
70+
{
71+
return nullptr;
72+
}
73+
else {
74+
return hypot_contig_impl<T>;
75+
}
76+
}
77+
};
78+
} // namespace vm
79+
} // namespace ext
80+
} // namespace backend
81+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,21 @@ struct FloorOutputType
291291
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
292292
};
293293

294+
/**
295+
* @brief A factory to define pairs of supported types for which
296+
* MKL VM library provides support in oneapi::mkl::vm::hypot<T> function.
297+
*
298+
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
299+
*/
300+
template <typename T>
301+
struct HypotOutputType
302+
{
303+
using value_type = typename std::disjunction<
304+
dpctl_td_ns::BinaryTypeMapResultEntry<T, double, T, double, double>,
305+
dpctl_td_ns::BinaryTypeMapResultEntry<T, float, T, float, float>,
306+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
307+
};
308+
294309
/**
295310
* @brief A factory to define pairs of supported types for which
296311
* MKL VM library provides support in oneapi::mkl::vm::ln<T> function.

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "cosh.hpp"
4646
#include "div.hpp"
4747
#include "floor.hpp"
48+
#include "hypot.hpp"
4849
#include "ln.hpp"
4950
#include "mul.hpp"
5051
#include "pow.hpp"
@@ -74,11 +75,12 @@ static unary_impl_fn_ptr_t atan_dispatch_vector[dpctl_td_ns::num_types];
7475
static binary_impl_fn_ptr_t atan2_dispatch_vector[dpctl_td_ns::num_types];
7576
static unary_impl_fn_ptr_t atanh_dispatch_vector[dpctl_td_ns::num_types];
7677
static unary_impl_fn_ptr_t ceil_dispatch_vector[dpctl_td_ns::num_types];
78+
static unary_impl_fn_ptr_t conj_dispatch_vector[dpctl_td_ns::num_types];
7779
static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
7880
static unary_impl_fn_ptr_t cosh_dispatch_vector[dpctl_td_ns::num_types];
7981
static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
8082
static unary_impl_fn_ptr_t floor_dispatch_vector[dpctl_td_ns::num_types];
81-
static unary_impl_fn_ptr_t conj_dispatch_vector[dpctl_td_ns::num_types];
83+
static binary_impl_fn_ptr_t hypot_dispatch_vector[dpctl_td_ns::num_types];
8284
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
8385
static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types];
8486
static binary_impl_fn_ptr_t pow_dispatch_vector[dpctl_td_ns::num_types];
@@ -494,6 +496,35 @@ PYBIND11_MODULE(_vm_impl, m)
494496
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
495497
}
496498

499+
// BinaryUfunc: ==== Hypot(x1, x2) ====
500+
{
501+
vm_ext::init_ufunc_dispatch_vector<binary_impl_fn_ptr_t,
502+
vm_ext::HypotContigFactory>(
503+
hypot_dispatch_vector);
504+
505+
auto hypot_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
506+
arrayT dst, const event_vecT &depends = {}) {
507+
return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends,
508+
hypot_dispatch_vector);
509+
};
510+
m.def("_hypot", hypot_pyapi,
511+
"Call `hypot` function from OneMKL VM library to compute element "
512+
"by element hypotenuse of `x`",
513+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
514+
py::arg("dst"), py::arg("depends") = py::list());
515+
516+
auto hypot_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1,
517+
arrayT src2, arrayT dst) {
518+
return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst,
519+
hypot_dispatch_vector);
520+
};
521+
m.def("_mkl_hypot_to_call", hypot_need_to_call_pyapi,
522+
"Check input arguments to answer if `hypot` function from "
523+
"OneMKL VM library can be used",
524+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
525+
py::arg("dst"));
526+
}
527+
497528
// UnaryUfunc: ==== Ln(x) ====
498529
{
499530
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,6 @@ enum class DPNPFuncName : size_t
203203
DPNP_FN_GREATER_EQUAL_EXT, /**< Used in numpy.greater_equal() impl, requires
204204
extra parameters */
205205
DPNP_FN_HYPOT, /**< Used in numpy.hypot() impl */
206-
DPNP_FN_HYPOT_EXT, /**< Used in numpy.hypot() impl, requires extra
207-
parameters */
208206
DPNP_FN_IDENTITY, /**< Used in numpy.identity() impl */
209207
DPNP_FN_IDENTITY_EXT, /**< Used in numpy.identity() impl, requires extra
210208
parameters */

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,19 +1514,6 @@ static void func_map_elemwise_2arg_3type_short_core(func_map_t &fmap)
15141514
func_type_map_t::find_type<FT1>,
15151515
func_type_map_t::find_type<FTs>>}),
15161516
...);
1517-
((fmap[DPNPFuncName::DPNP_FN_HYPOT_EXT][FT1][FTs] =
1518-
{get_floating_res_type<FT1, FTs>(),
1519-
(void *)dpnp_hypot_c_ext<
1520-
func_type_map_t::find_type<get_floating_res_type<FT1, FTs>()>,
1521-
func_type_map_t::find_type<FT1>,
1522-
func_type_map_t::find_type<FTs>>,
1523-
get_floating_res_type<FT1, FTs, std::false_type>(),
1524-
(void *)dpnp_hypot_c_ext<
1525-
func_type_map_t::find_type<
1526-
get_floating_res_type<FT1, FTs, std::false_type>()>,
1527-
func_type_map_t::find_type<FT1>,
1528-
func_type_map_t::find_type<FTs>>}),
1529-
...);
15301517
((fmap[DPNPFuncName::DPNP_FN_MAXIMUM_EXT][FT1][FTs] =
15311518
{get_floating_res_type<FT1, FTs, std::true_type, std::true_type>(),
15321519
(void *)dpnp_maximum_c_ext<

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
108108
DPNP_FN_FMOD_EXT
109109
DPNP_FN_FULL
110110
DPNP_FN_FULL_LIKE
111-
DPNP_FN_HYPOT
112-
DPNP_FN_HYPOT_EXT
113111
DPNP_FN_IDENTITY
114112
DPNP_FN_IDENTITY_EXT
115113
DPNP_FN_INV
@@ -384,8 +382,6 @@ cpdef dpnp_descriptor dpnp_copy(dpnp_descriptor x1)
384382
"""
385383
Mathematical functions
386384
"""
387-
cpdef dpnp_descriptor dpnp_hypot(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
388-
dpnp_descriptor out=*, object where=*)
389385
cpdef dpnp_descriptor dpnp_fmax(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
390386
dpnp_descriptor out=*, object where=*)
391387
cpdef dpnp_descriptor dpnp_fmin(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,

dpnp/dpnp_algo/dpnp_algo_mathematical.pxi

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ __all__ += [
4646
"dpnp_fabs",
4747
"dpnp_fmod",
4848
"dpnp_gradient",
49-
'dpnp_hypot',
5049
"dpnp_fmax",
5150
"dpnp_fmin",
5251
"dpnp_modf",
@@ -273,14 +272,6 @@ cpdef utils.dpnp_descriptor dpnp_gradient(utils.dpnp_descriptor y1, int dx=1):
273272
return result
274273

275274

276-
cpdef utils.dpnp_descriptor dpnp_hypot(utils.dpnp_descriptor x1_obj,
277-
utils.dpnp_descriptor x2_obj,
278-
object dtype=None,
279-
utils.dpnp_descriptor out=None,
280-
object where=True):
281-
return call_fptr_2in_1out_strides(DPNP_FN_HYPOT_EXT, x1_obj, x2_obj, dtype, out, where)
282-
283-
284275
cpdef utils.dpnp_descriptor dpnp_fmax(utils.dpnp_descriptor x1_obj,
285276
utils.dpnp_descriptor x2_obj,
286277
object dtype=None,

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
"dpnp_floor_divide",
6464
"dpnp_greater",
6565
"dpnp_greater_equal",
66+
"dpnp_hypot",
6667
"dpnp_imag",
6768
"dpnp_invert",
6869
"dpnp_isfinite",
@@ -1264,6 +1265,66 @@ def dpnp_greater_equal(x1, x2, out=None, order="K"):
12641265
return dpnp_array._create_from_usm_ndarray(res_usm)
12651266

12661267

1268+
_hypot_docstring_ = """
1269+
hypot(x1, x2, out=None, order="K")
1270+
Calculates the hypotenuse for a right triangle with "legs" `x1_i` and `x2_i` of
1271+
input arrays `x1` and `x2`.
1272+
Args:
1273+
x1 (dpnp.ndarray):
1274+
First input array, expected to have a real-valued data type.
1275+
x2 (dpnp.ndarray):
1276+
Second input array, also expected to have a real-valued data type.
1277+
out ({None, dpnp.ndarray}, optional):
1278+
Output array to populate.
1279+
Array have the correct shape and the expected data type.
1280+
order ("C","F","A","K", None, optional):
1281+
Memory layout of the newly output array, if parameter `out` is `None`.
1282+
Default: "K".
1283+
Returns:
1284+
dpnp.ndarray:
1285+
An array containing the element-wise hypotenuse. The data type
1286+
of the returned array is determined by the Type Promotion Rules.
1287+
"""
1288+
1289+
1290+
def _call_hypot(src1, src2, dst, sycl_queue, depends=None):
1291+
"""A callback to register in BinaryElementwiseFunc class of dpctl.tensor"""
1292+
1293+
if depends is None:
1294+
depends = []
1295+
1296+
if vmi._mkl_hypot_to_call(sycl_queue, src1, src2, dst):
1297+
# call pybind11 extension for hypot() function from OneMKL VM
1298+
return vmi._hypot(sycl_queue, src1, src2, dst, depends)
1299+
return ti._hypot(src1, src2, dst, sycl_queue, depends)
1300+
1301+
1302+
hypot_func = BinaryElementwiseFunc(
1303+
"hypot",
1304+
ti._hypot_result_type,
1305+
_call_hypot,
1306+
_hypot_docstring_,
1307+
)
1308+
1309+
1310+
def dpnp_hypot(x1, x2, out=None, order="K"):
1311+
"""
1312+
Invokes hypot() function from pybind11 extension of OneMKL VM if possible.
1313+
1314+
Otherwise fully relies on dpctl.tensor implementation for hypot() function.
1315+
"""
1316+
1317+
# dpctl.tensor only works with usm_ndarray or scalar
1318+
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
1319+
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
1320+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
1321+
1322+
res_usm = hypot_func(
1323+
x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order
1324+
)
1325+
return dpnp_array._create_from_usm_ndarray(res_usm)
1326+
1327+
12671328
_imag_docstring = """
12681329
imag(x, out=None, order="K")
12691330

0 commit comments

Comments
 (0)