Skip to content

Commit 9b450f0

Browse files
Implementation of dpnp.linalg.cholesky() (#1638)
* Add a new impl of dpnp.linalg.cholesky * Add cupy tests for dpnp.linalg.cholesky * Add a batch impl of dpnp.linalg.cholesky * Remove an old impl of dpnp_cholesky * Remove DPNP_FN_CHOLESKY_EXT in dpnp_iface_fptr * Remove dpnp_cholesky_ext_c * Add a new _dpnp_cholesky_batch func * Update test_cholesky in test_sycl_queue * Expand test scope in public CI * Add more tests for dpnp.linalg.cholesky * Remove TODOs in cholesky() and update docstrings * Use _common_type in dpnp_cholesky * Update dpnp_cholesky and dpnp_cholesky_batch * Keep the lexicographic order * Remove passing n parameter to _potrf * Add additional checks to potrf and potrf_batch * Extend potrf error handler * Extend potrf_batch error handler * Update tests for dpnp.linalg.cholesky * Update license year * Update cholesky docstrings * Add support for the upper parameter in potrf * Add support for the upper parameter in potrf_batch and update dpnp_cholesky * Add tests for upper parameter of dpnp.linalg.cholesky * Address remarks * Fix validation check --------- Co-authored-by: vtavana <120411540+vtavana@users.noreply.github.com>
1 parent 0a5a2bd commit 9b450f0

21 files changed

+1114
-137
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ env:
3030
test_umath.py
3131
test_usm_type.py
3232
third_party/cupy/core_tests
33+
third_party/cupy/linalg_tests/test_decomposition.py
3334
third_party/cupy/linalg_tests/test_norms.py
3435
third_party/cupy/linalg_tests/test_product.py
3536
third_party/cupy/linalg_tests/test_solve.py

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ set(_module_src
3131
${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp
3232
${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
3333
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
34+
${CMAKE_CURRENT_SOURCE_DIR}/potrf.cpp
35+
${CMAKE_CURRENT_SOURCE_DIR}/potrf_batch.cpp
3436
${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp
3537
)
3638

dpnp/backend/extensions/lapack/getrf.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//*****************************************************************************
2-
// Copyright (c) 2023, Intel Corporation
2+
// Copyright (c) 2024, Intel Corporation
33
// All rights reserved.
44
//
55
// Redistribution and use in source and binary forms, with or without

dpnp/backend/extensions/lapack/getrf.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//*****************************************************************************
2-
// Copyright (c) 2023, Intel Corporation
2+
// Copyright (c) 2024, Intel Corporation
33
// All rights reserved.
44
//
55
// Redistribution and use in source and binary forms, with or without

dpnp/backend/extensions/lapack/getrf_batch.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//*****************************************************************************
2-
// Copyright (c) 2023, Intel Corporation
2+
// Copyright (c) 2024, Intel Corporation
33
// All rights reserved.
44
//
55
// Redistribution and use in source and binary forms, with or without

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "getrf.hpp"
3535
#include "heevd.hpp"
3636
#include "linalg_exceptions.hpp"
37+
#include "potrf.hpp"
3738
#include "syevd.hpp"
3839

3940
namespace lapack_ext = dpnp::backend::ext::lapack;
@@ -45,6 +46,8 @@ void init_dispatch_vectors(void)
4546
lapack_ext::init_gesv_dispatch_vector();
4647
lapack_ext::init_getrf_batch_dispatch_vector();
4748
lapack_ext::init_getrf_dispatch_vector();
49+
lapack_ext::init_potrf_batch_dispatch_vector();
50+
lapack_ext::init_potrf_dispatch_vector();
4851
lapack_ext::init_syevd_dispatch_vector();
4952
}
5053

@@ -92,6 +95,20 @@ PYBIND11_MODULE(_lapack_impl, m)
9295
py::arg("eig_vecs"), py::arg("eig_vals"),
9396
py::arg("depends") = py::list());
9497

98+
m.def("_potrf", &lapack_ext::potrf,
99+
"Call `potrf` from OneMKL LAPACK library to return "
100+
"the Cholesky factorization of a symmetric positive-definite matrix",
101+
py::arg("sycl_queue"), py::arg("a_array"), py::arg("upper_lower"),
102+
py::arg("depends") = py::list());
103+
104+
m.def("_potrf_batch", &lapack_ext::potrf_batch,
105+
"Call `potrf_batch` from OneMKL LAPACK library to return "
106+
"the Cholesky factorization of a batch of symmetric "
107+
"positive-definite matrix",
108+
py::arg("sycl_queue"), py::arg("a_array"), py::arg("upper_lower"),
109+
py::arg("n"), py::arg("stride_a"), py::arg("batch_size"),
110+
py::arg("depends") = py::list());
111+
95112
m.def("_syevd", &lapack_ext::syevd,
96113
"Call `syevd` from OneMKL LAPACK library to return "
97114
"the eigenvalues and eigenvectors of a real symmetric matrix",
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <pybind11/pybind11.h>
27+
28+
// dpctl tensor headers
29+
#include "utils/memory_overlap.hpp"
30+
#include "utils/type_utils.hpp"
31+
32+
#include "linalg_exceptions.hpp"
33+
#include "potrf.hpp"
34+
#include "types_matrix.hpp"
35+
36+
#include "dpnp_utils.hpp"
37+
38+
namespace dpnp
39+
{
40+
namespace backend
41+
{
42+
namespace ext
43+
{
44+
namespace lapack
45+
{
46+
namespace mkl_lapack = oneapi::mkl::lapack;
47+
namespace py = pybind11;
48+
namespace type_utils = dpctl::tensor::type_utils;
49+
50+
typedef sycl::event (*potrf_impl_fn_ptr_t)(sycl::queue,
51+
const oneapi::mkl::uplo,
52+
const std::int64_t,
53+
char *,
54+
std::int64_t,
55+
std::vector<sycl::event> &,
56+
const std::vector<sycl::event> &);
57+
58+
static potrf_impl_fn_ptr_t potrf_dispatch_vector[dpctl_td_ns::num_types];
59+
60+
template <typename T>
61+
static sycl::event potrf_impl(sycl::queue exec_q,
62+
const oneapi::mkl::uplo upper_lower,
63+
const std::int64_t n,
64+
char *in_a,
65+
std::int64_t lda,
66+
std::vector<sycl::event> &host_task_events,
67+
const std::vector<sycl::event> &depends)
68+
{
69+
type_utils::validate_type_for_device<T>(exec_q);
70+
71+
T *a = reinterpret_cast<T *>(in_a);
72+
73+
const std::int64_t scratchpad_size =
74+
mkl_lapack::potrf_scratchpad_size<T>(exec_q, upper_lower, n, lda);
75+
T *scratchpad = nullptr;
76+
77+
std::stringstream error_msg;
78+
std::int64_t info = 0;
79+
bool is_exception_caught = false;
80+
81+
sycl::event potrf_event;
82+
try {
83+
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
84+
85+
potrf_event = mkl_lapack::potrf(
86+
exec_q,
87+
upper_lower, // An enumeration value of type oneapi::mkl::uplo:
88+
// oneapi::mkl::uplo::upper for the upper triangular
89+
// part; oneapi::mkl::uplo::lower for the lower
90+
// triangular part.
91+
n, // Order of the square matrix; (0 ≤ n).
92+
a, // Pointer to the n-by-n matrix.
93+
lda, // The leading dimension of `a`.
94+
scratchpad, // Pointer to scratchpad memory to be used by MKL
95+
// routine for storing intermediate results.
96+
scratchpad_size, depends);
97+
} catch (mkl_lapack::exception const &e) {
98+
is_exception_caught = true;
99+
info = e.info();
100+
if (info < 0) {
101+
error_msg << "Parameter number " << -info
102+
<< " had an illegal value.";
103+
}
104+
else if (info == scratchpad_size && e.detail() != 0) {
105+
error_msg
106+
<< "Insufficient scratchpad size. Required size is at least "
107+
<< e.detail();
108+
}
109+
else if (info > 0 && e.detail() == 0) {
110+
sycl::free(scratchpad, exec_q);
111+
throw LinAlgError("Matrix is not positive definite.");
112+
}
113+
else {
114+
error_msg << "Unexpected MKL exception caught during getrf() "
115+
"call:\nreason: "
116+
<< e.what() << "\ninfo: " << e.info();
117+
}
118+
} catch (sycl::exception const &e) {
119+
is_exception_caught = true;
120+
error_msg << "Unexpected SYCL exception caught during potrf() call:\n"
121+
<< e.what();
122+
}
123+
124+
if (is_exception_caught) // an unexpected error occurs
125+
{
126+
if (scratchpad != nullptr) {
127+
sycl::free(scratchpad, exec_q);
128+
}
129+
throw std::runtime_error(error_msg.str());
130+
}
131+
132+
sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
133+
cgh.depends_on(potrf_event);
134+
auto ctx = exec_q.get_context();
135+
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
136+
});
137+
host_task_events.push_back(clean_up_event);
138+
return potrf_event;
139+
}
140+
141+
std::pair<sycl::event, sycl::event>
142+
potrf(sycl::queue q,
143+
dpctl::tensor::usm_ndarray a_array,
144+
const std::int8_t upper_lower,
145+
const std::vector<sycl::event> &depends)
146+
{
147+
const int a_array_nd = a_array.get_ndim();
148+
149+
if (a_array_nd != 2) {
150+
throw py::value_error(
151+
"The input array has ndim=" + std::to_string(a_array_nd) +
152+
", but a 2-dimensional array is expected.");
153+
}
154+
155+
const py::ssize_t *a_array_shape = a_array.get_shape_raw();
156+
157+
if (a_array_shape[0] != a_array_shape[1]) {
158+
throw py::value_error("The input array must be square,"
159+
" but got a shape of (" +
160+
std::to_string(a_array_shape[0]) + ", " +
161+
std::to_string(a_array_shape[1]) + ").");
162+
}
163+
164+
bool is_a_array_c_contig = a_array.is_c_contiguous();
165+
if (!is_a_array_c_contig) {
166+
throw py::value_error("The input array "
167+
"must be C-contiguous");
168+
}
169+
170+
auto array_types = dpctl_td_ns::usm_ndarray_types();
171+
int a_array_type_id =
172+
array_types.typenum_to_lookup_id(a_array.get_typenum());
173+
174+
potrf_impl_fn_ptr_t potrf_fn = potrf_dispatch_vector[a_array_type_id];
175+
if (potrf_fn == nullptr) {
176+
throw py::value_error(
177+
"No potrf implementation defined for the provided type "
178+
"of the input matrix.");
179+
}
180+
181+
char *a_array_data = a_array.get_data();
182+
const std::int64_t n = a_array_shape[0];
183+
const std::int64_t lda = std::max<size_t>(1UL, n);
184+
const oneapi::mkl::uplo uplo_val =
185+
static_cast<oneapi::mkl::uplo>(upper_lower);
186+
187+
std::vector<sycl::event> host_task_events;
188+
sycl::event potrf_ev =
189+
potrf_fn(q, uplo_val, n, a_array_data, lda, host_task_events, depends);
190+
191+
sycl::event args_ev =
192+
dpctl::utils::keep_args_alive(q, {a_array}, host_task_events);
193+
194+
return std::make_pair(args_ev, potrf_ev);
195+
}
196+
197+
template <typename fnT, typename T>
198+
struct PotrfContigFactory
199+
{
200+
fnT get()
201+
{
202+
if constexpr (types::PotrfTypePairSupportFactory<T>::is_defined) {
203+
return potrf_impl<T>;
204+
}
205+
else {
206+
return nullptr;
207+
}
208+
}
209+
};
210+
211+
void init_potrf_dispatch_vector(void)
212+
{
213+
dpctl_td_ns::DispatchVectorBuilder<potrf_impl_fn_ptr_t, PotrfContigFactory,
214+
dpctl_td_ns::num_types>
215+
contig;
216+
contig.populate_dispatch_vector(potrf_dispatch_vector);
217+
}
218+
} // namespace lapack
219+
} // namespace ext
220+
} // namespace backend
221+
} // namespace dpnp
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
#include <oneapi/mkl.hpp>
30+
31+
#include <dpctl4pybind11.hpp>
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace lapack
40+
{
41+
extern std::pair<sycl::event, sycl::event>
42+
potrf(sycl::queue exec_q,
43+
dpctl::tensor::usm_ndarray a_array,
44+
const std::int8_t upper_lower,
45+
const std::vector<sycl::event> &depends = {});
46+
47+
extern std::pair<sycl::event, sycl::event>
48+
potrf_batch(sycl::queue exec_q,
49+
dpctl::tensor::usm_ndarray a_array,
50+
const std::int8_t upper_lower,
51+
const std::int64_t n,
52+
const std::int64_t stride_a,
53+
const std::int64_t batch_size,
54+
const std::vector<sycl::event> &depends = {});
55+
56+
// Set up the per-type dispatch tables consulted by potrf() and potrf_batch()
// respectively; both are invoked from the module's init_dispatch_vectors().
void init_potrf_dispatch_vector();
void init_potrf_batch_dispatch_vector();
58+
} // namespace lapack
59+
} // namespace ext
60+
} // namespace backend
61+
} // namespace dpnp

0 commit comments

Comments
 (0)