Skip to content

Commit 780d686

Browse files
authored
implement dpnp.cbrt, dpnp.exp2, dpnp.copysign, dpnp.rsqrt (#1624)
* implement dpnp.cbrt, dpnp.exp2, dpnp.copysign, dpnp.rsqrt * address comments * address comments - 2nd round
1 parent 0604f18 commit 780d686

20 files changed

+1135
-471
lines changed

dpnp/backend/extensions/vm/cbrt.hpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event cbrt_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
char *out_y,
46+
const std::vector<sycl::event> &depends)
47+
{
48+
type_utils::validate_type_for_device<T>(exec_q);
49+
50+
const T *a = reinterpret_cast<const T *>(in_a);
51+
using resTy = typename types::CbrtOutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
53+
54+
return mkl_vm::cbrt(exec_q,
55+
n, // number of elements to be calculated
56+
a, // pointer `a` containing input vector of size n
57+
y, // pointer `y` to the output vector of size n
58+
depends);
59+
}
60+
61+
template <typename fnT, typename T>
62+
struct CbrtContigFactory
63+
{
64+
fnT get()
65+
{
66+
if constexpr (std::is_same_v<
67+
typename types::CbrtOutputType<T>::value_type, void>)
68+
{
69+
return nullptr;
70+
}
71+
else {
72+
return cbrt_contig_impl<T>;
73+
}
74+
}
75+
};
76+
} // namespace vm
77+
} // namespace ext
78+
} // namespace backend
79+
} // namespace dpnp

dpnp/backend/extensions/vm/exp2.hpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event exp2_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
char *out_y,
46+
const std::vector<sycl::event> &depends)
47+
{
48+
type_utils::validate_type_for_device<T>(exec_q);
49+
50+
const T *a = reinterpret_cast<const T *>(in_a);
51+
using resTy = typename types::Exp2OutputType<T>::value_type;
52+
resTy *y = reinterpret_cast<resTy *>(out_y);
53+
54+
return mkl_vm::exp2(exec_q,
55+
n, // number of elements to be calculated
56+
a, // pointer `a` containing input vector of size n
57+
y, // pointer `y` to the output vector of size n
58+
depends);
59+
}
60+
61+
template <typename fnT, typename T>
62+
struct Exp2ContigFactory
63+
{
64+
fnT get()
65+
{
66+
if constexpr (std::is_same_v<
67+
typename types::Exp2OutputType<T>::value_type, void>)
68+
{
69+
return nullptr;
70+
}
71+
else {
72+
return exp2_contig_impl<T>;
73+
}
74+
}
75+
};
76+
} // namespace vm
77+
} // namespace ext
78+
} // namespace backend
79+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,21 @@ struct AtanhOutputType
202202
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
203203
};
204204

205+
/**
206+
* @brief A factory to define pairs of supported types for which
207+
* MKL VM library provides support in oneapi::mkl::vm::cbrt<T> function.
208+
*
209+
* @tparam T Type of input vector `a` and of result vector `y`.
210+
*/
211+
template <typename T>
212+
struct CbrtOutputType
213+
{
214+
using value_type = typename std::disjunction<
215+
dpctl_td_ns::TypeMapResultEntry<T, double>,
216+
dpctl_td_ns::TypeMapResultEntry<T, float>,
217+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
218+
};
219+
205220
/**
206221
* @brief A factory to define pairs of supported types for which
207222
* MKL VM library provides support in oneapi::mkl::vm::ceil<T> function.
@@ -308,6 +323,21 @@ struct ExpOutputType
308323
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
309324
};
310325

326+
/**
327+
* @brief A factory to define pairs of supported types for which
328+
* MKL VM library provides support in oneapi::mkl::vm::exp2<T> function.
329+
*
330+
* @tparam T Type of input vector `a` and of result vector `y`.
331+
*/
332+
template <typename T>
333+
struct Exp2OutputType
334+
{
335+
using value_type = typename std::disjunction<
336+
dpctl_td_ns::TypeMapResultEntry<T, double>,
337+
dpctl_td_ns::TypeMapResultEntry<T, float>,
338+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
339+
};
340+
311341
/**
312342
* @brief A factory to define pairs of supported types for which
313343
* MKL VM library provides support in oneapi::mkl::vm::expm1<T> function.

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@
3939
#include "atan.hpp"
4040
#include "atan2.hpp"
4141
#include "atanh.hpp"
42+
#include "cbrt.hpp"
4243
#include "ceil.hpp"
4344
#include "common.hpp"
4445
#include "conj.hpp"
4546
#include "cos.hpp"
4647
#include "cosh.hpp"
4748
#include "div.hpp"
4849
#include "exp.hpp"
50+
#include "exp2.hpp"
4951
#include "expm1.hpp"
5052
#include "floor.hpp"
5153
#include "hypot.hpp"
@@ -81,12 +83,14 @@ static unary_impl_fn_ptr_t asinh_dispatch_vector[dpctl_td_ns::num_types];
8183
static unary_impl_fn_ptr_t atan_dispatch_vector[dpctl_td_ns::num_types];
8284
static binary_impl_fn_ptr_t atan2_dispatch_vector[dpctl_td_ns::num_types];
8385
static unary_impl_fn_ptr_t atanh_dispatch_vector[dpctl_td_ns::num_types];
86+
static unary_impl_fn_ptr_t cbrt_dispatch_vector[dpctl_td_ns::num_types];
8487
static unary_impl_fn_ptr_t ceil_dispatch_vector[dpctl_td_ns::num_types];
8588
static unary_impl_fn_ptr_t conj_dispatch_vector[dpctl_td_ns::num_types];
8689
static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
8790
static unary_impl_fn_ptr_t cosh_dispatch_vector[dpctl_td_ns::num_types];
8891
static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
8992
static unary_impl_fn_ptr_t exp_dispatch_vector[dpctl_td_ns::num_types];
93+
static unary_impl_fn_ptr_t exp2_dispatch_vector[dpctl_td_ns::num_types];
9094
static unary_impl_fn_ptr_t expm1_dispatch_vector[dpctl_td_ns::num_types];
9195
static unary_impl_fn_ptr_t floor_dispatch_vector[dpctl_td_ns::num_types];
9296
static binary_impl_fn_ptr_t hypot_dispatch_vector[dpctl_td_ns::num_types];
@@ -366,6 +370,34 @@ PYBIND11_MODULE(_vm_impl, m)
366370
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
367371
}
368372

373+
// UnaryUfunc: ==== Cbrt(x) ====
374+
{
375+
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
376+
vm_ext::CbrtContigFactory>(
377+
cbrt_dispatch_vector);
378+
379+
auto cbrt_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
380+
const event_vecT &depends = {}) {
381+
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
382+
cbrt_dispatch_vector);
383+
};
384+
m.def("_cbrt", cbrt_pyapi,
385+
"Call `cbrt` function from OneMKL VM library to compute "
386+
"the element-wise cube root of vector elements",
387+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
388+
py::arg("depends") = py::list());
389+
390+
auto cbrt_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
391+
arrayT dst) {
392+
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
393+
cbrt_dispatch_vector);
394+
};
395+
m.def("_mkl_cbrt_to_call", cbrt_need_to_call_pyapi,
396+
"Check input arguments to answer if `cbrt` function from "
397+
"OneMKL VM library can be used",
398+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
399+
}
400+
369401
// UnaryUfunc: ==== Ceil(x) ====
370402
{
371403
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
@@ -536,6 +568,34 @@ PYBIND11_MODULE(_vm_impl, m)
536568
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
537569
}
538570

571+
// UnaryUfunc: ==== exp2(x) ====
572+
{
573+
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
574+
vm_ext::Exp2ContigFactory>(
575+
exp2_dispatch_vector);
576+
577+
auto exp2_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
578+
const event_vecT &depends = {}) {
579+
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
580+
exp2_dispatch_vector);
581+
};
582+
m.def("_exp2", exp2_pyapi,
583+
"Call `exp2` function from OneMKL VM library to compute "
584+
"the element-wise base-2 exponential of vector elements",
585+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
586+
py::arg("depends") = py::list());
587+
588+
auto exp2_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
589+
arrayT dst) {
590+
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
591+
exp2_dispatch_vector);
592+
};
593+
m.def("_mkl_exp2_to_call", exp2_need_to_call_pyapi,
594+
"Check input arguments to answer if `exp2` function from "
595+
"OneMKL VM library can be used",
596+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
597+
}
598+
539599
// UnaryUfunc: ==== expm1(x) ====
540600
{
541601
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,8 @@ enum class DPNPFuncName : size_t
8686
DPNP_FN_BITWISE_OR, /**< Used in numpy.bitwise_or() impl */
8787
DPNP_FN_BITWISE_XOR, /**< Used in numpy.bitwise_xor() impl */
8888
DPNP_FN_CBRT, /**< Used in numpy.cbrt() impl */
89-
DPNP_FN_CBRT_EXT, /**< Used in numpy.cbrt() impl, requires extra parameters
90-
*/
91-
DPNP_FN_CEIL, /**< Used in numpy.ceil() impl */
92-
DPNP_FN_CHOLESKY, /**< Used in numpy.linalg.cholesky() impl */
89+
DPNP_FN_CEIL, /**< Used in numpy.ceil() impl */
90+
DPNP_FN_CHOLESKY, /**< Used in numpy.linalg.cholesky() impl */
9391
DPNP_FN_CHOLESKY_EXT, /**< Used in numpy.linalg.cholesky() impl, requires
9492
extra parameters */
9593
DPNP_FN_CONJUGATE, /**< Used in numpy.conjugate() impl */
@@ -100,9 +98,7 @@ enum class DPNPFuncName : size_t
10098
DPNP_FN_COPY_EXT, /**< Used in numpy.copy() impl, requires extra parameters
10199
*/
102100
DPNP_FN_COPYSIGN, /**< Used in numpy.copysign() impl */
103-
DPNP_FN_COPYSIGN_EXT, /**< Used in numpy.copysign() impl, requires extra
104-
parameters */
105-
DPNP_FN_COPYTO, /**< Used in numpy.copyto() impl */
101+
DPNP_FN_COPYTO, /**< Used in numpy.copyto() impl */
106102
DPNP_FN_COPYTO_EXT, /**< Used in numpy.copyto() impl, requires extra
107103
parameters */
108104
DPNP_FN_CORRELATE, /**< Used in numpy.correlate() impl */
@@ -154,10 +150,8 @@ enum class DPNPFuncName : size_t
154150
DPNP_FN_EYE, /**< Used in numpy.eye() impl */
155151
DPNP_FN_EXP, /**< Used in numpy.exp() impl */
156152
DPNP_FN_EXP2, /**< Used in numpy.exp2() impl */
157-
DPNP_FN_EXP2_EXT, /**< Used in numpy.exp2() impl, requires extra parameters
158-
*/
159-
DPNP_FN_EXPM1, /**< Used in numpy.expm1() impl */
160-
DPNP_FN_FABS, /**< Used in numpy.fabs() impl */
153+
DPNP_FN_EXPM1, /**< Used in numpy.expm1() impl */
154+
DPNP_FN_FABS, /**< Used in numpy.fabs() impl */
161155
DPNP_FN_FABS_EXT, /**< Used in numpy.fabs() impl, requires extra parameters
162156
*/
163157
DPNP_FN_FFT_FFT, /**< Used in numpy.fft.fft() impl */

0 commit comments

Comments
 (0)