Skip to content

Commit 7e54eb8

Browse files
Implement of dpnp.linalg.slogdet() (#1607)
* Add a new impl of dpnp.linalg._lu_factor * Get dev_info_array after calling getrf * Add an extra dev_info return to _lu_factor * qwe * Add a logic for a.ndim > 2 in _lu_factor * Add an implementation of dpnp.linalg.slogdet * Add a new test_norms.py file in cupy tests * Expand test scope in public CI * A small update _lu_factor func * Remove w/a for dpnp.count_nonzero in slogdet * getrf returns pair of events and uses dpctl.utils.keep_args_alive * Update dpnp.linalg.det using slogdet * Add new cupy tests for dpnp.linalg.det * Add ipiv_vecs and dev_info_vecs in _lu_factor for the batch case * Skip test_det on CPU due to bug in MKL * Small update of cupy tests in test_norms.py * Add support of complex dtype for dpnp.diagonal and update test_diagonal * lu_factor func returns the result of LU decomposition as c-contiguous and add explanatory comments * Add getrf_batch MKL extension * Update docstring for slogdet * Add more tests * Remove accidentally added file * Modify sign parameter calculation * Remove the old backend implementation of dpnp_det * qwe * Keep lexographical order * Add dpnp_slogdet to dpnp_utils_linalg * Move _lu_factor above * A minor update * A minor changes for _lu_factor * Remove trash files * Use getrf_batch only on CPU * Update tests for dpnp.linalg.slogdet * Address remarks * Add _real_type func * Add test_det in test_usm_type * Add more checks in getrf and getf_batch functions * Improve error handler in getrf_impl * Improve error handler in getrf_batch_impl * dev_info is allocated as zeros * Remove skipif for singular tests * Implement _lu_factor logic with dev_info as a python list * Update getrf_rf error handler with mkl_lapack::batch_error * Remove passing n parameter to _getrf * Add a new test_slogdet_singular_matrix_3D test * Update tests for dpnp.linalg.det * Use is_exception_caught flag in getrf and getrf_batch error handler * Update gesv error handler * Reshape results after calling getrf_batch * Add a new dpnp.linalg.det impl and refresh dpnp_utils_linalg * Remove Limitations from dpnp_det and dpnp_slogdet docstings * Address remarks * Remove det_dtype variable and use the abs val of diag for det * Expand cupy tests for dpnp.linalg.det() * Update TestDet and TestSlogdet
1 parent 3fdb921 commit 7e54eb8

20 files changed

+1625
-135
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ env:
3030
test_umath.py
3131
test_usm_type.py
3232
third_party/cupy/core_tests
33+
third_party/cupy/linalg_tests/test_norms.py
3334
third_party/cupy/linalg_tests/test_product.py
3435
third_party/cupy/linalg_tests/test_solve.py
3536
third_party/cupy/logic_tests/test_comparison.py

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ set(python_module_name _lapack_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/lapack_py.cpp
3030
${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp
31+
${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp
32+
${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
3133
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
3234
${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp
3335
)

dpnp/backend/extensions/lapack/gesv.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ static sycl::event gesv_impl(sycl::queue exec_q,
8484

8585
std::stringstream error_msg;
8686
std::int64_t info = 0;
87-
bool sycl_exception_caught = false;
87+
bool is_exception_caught = false;
8888

8989
sycl::event gesv_event;
9090
try {
@@ -106,12 +106,18 @@ static sycl::event gesv_impl(sycl::queue exec_q,
106106
// routine for storing intermediate results.
107107
scratchpad_size, depends);
108108
} catch (mkl_lapack::exception const &e) {
109+
is_exception_caught = true;
109110
info = e.info();
110111

111112
if (info < 0) {
112113
error_msg << "Parameter number " << -info
113114
<< " had an illegal value.";
114115
}
116+
else if (info == scratchpad_size && e.detail() != 0) {
117+
error_msg
118+
<< "Insufficient scratchpad size. Required size is at least "
119+
<< e.detail();
120+
}
115121
else if (info > 0) {
116122
T host_U;
117123
exec_q.memcpy(&host_U, &a[(info - 1) * lda + info - 1], sizeof(T))
@@ -131,23 +137,18 @@ static sycl::event gesv_impl(sycl::queue exec_q,
131137
<< e.what() << "\ninfo: " << e.info();
132138
}
133139
}
134-
else if (info == scratchpad_size && e.detail() != 0) {
135-
error_msg
136-
<< "Insufficient scratchpad size. Required size is at least "
137-
<< e.detail();
138-
}
139140
else {
140141
error_msg << "Unexpected MKL exception caught during gesv() "
141142
"call:\nreason: "
142143
<< e.what() << "\ninfo: " << e.info();
143144
}
144145
} catch (sycl::exception const &e) {
146+
is_exception_caught = true;
145147
error_msg << "Unexpected SYCL exception caught during gesv() call:\n"
146148
<< e.what();
147-
sycl_exception_caught = true;
148149
}
149150

150-
if (info != 0 || sycl_exception_caught) // an unexpected error occurs
151+
if (is_exception_caught) // an unexpected error occurs
151152
{
152153
if (scratchpad != nullptr) {
153154
sycl::free(scratchpad, exec_q);
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <pybind11/pybind11.h>
27+
28+
// dpctl tensor headers
29+
#include "utils/memory_overlap.hpp"
30+
#include "utils/type_utils.hpp"
31+
32+
#include "getrf.hpp"
33+
#include "types_matrix.hpp"
34+
35+
#include "dpnp_utils.hpp"
36+
37+
namespace dpnp
38+
{
39+
namespace backend
40+
{
41+
namespace ext
42+
{
43+
namespace lapack
44+
{
45+
namespace mkl_lapack = oneapi::mkl::lapack;
46+
namespace py = pybind11;
47+
namespace type_utils = dpctl::tensor::type_utils;
48+
49+
typedef sycl::event (*getrf_impl_fn_ptr_t)(sycl::queue,
50+
const std::int64_t,
51+
char *,
52+
std::int64_t,
53+
std::int64_t *,
54+
py::list,
55+
std::vector<sycl::event> &,
56+
const std::vector<sycl::event> &);
57+
58+
static getrf_impl_fn_ptr_t getrf_dispatch_vector[dpctl_td_ns::num_types];
59+
60+
template <typename T>
61+
static sycl::event getrf_impl(sycl::queue exec_q,
62+
const std::int64_t n,
63+
char *in_a,
64+
std::int64_t lda,
65+
std::int64_t *ipiv,
66+
py::list dev_info,
67+
std::vector<sycl::event> &host_task_events,
68+
const std::vector<sycl::event> &depends)
69+
{
70+
type_utils::validate_type_for_device<T>(exec_q);
71+
72+
T *a = reinterpret_cast<T *>(in_a);
73+
74+
const std::int64_t scratchpad_size =
75+
mkl_lapack::getrf_scratchpad_size<T>(exec_q, n, n, lda);
76+
T *scratchpad = nullptr;
77+
78+
std::stringstream error_msg;
79+
std::int64_t info = 0;
80+
bool is_exception_caught = false;
81+
82+
sycl::event getrf_event;
83+
try {
84+
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
85+
86+
getrf_event = mkl_lapack::getrf(
87+
exec_q,
88+
n, // The order of the square matrix A (0 ≤ n).
89+
// It must be a non-negative integer.
90+
n, // The number of columns in the square matrix A (0 ≤ n).
91+
// It must be a non-negative integer.
92+
a, // Pointer to the square matrix A (n x n).
93+
lda, // The leading dimension of matrix A.
94+
// It must be at least max(1, n).
95+
ipiv, // Pointer to the output array of pivot indices.
96+
scratchpad, // Pointer to scratchpad memory to be used by MKL
97+
// routine for storing intermediate results.
98+
scratchpad_size, depends);
99+
} catch (mkl_lapack::exception const &e) {
100+
is_exception_caught = true;
101+
info = e.info();
102+
103+
if (info < 0) {
104+
error_msg << "Parameter number " << -info
105+
<< " had an illegal value.";
106+
}
107+
else if (info == scratchpad_size && e.detail() != 0) {
108+
error_msg
109+
<< "Insufficient scratchpad size. Required size is at least "
110+
<< e.detail();
111+
}
112+
else if (info > 0) {
113+
// Store the positive 'info' value in the first element of
114+
// 'dev_info'. This indicates that the factorization has been
115+
// completed, but the factor U (upper triangular matrix) is exactly
116+
// singular. The 'info' value here is the index of the first zero
117+
// element in the diagonal of U.
118+
is_exception_caught = false;
119+
dev_info[0] = info;
120+
}
121+
else {
122+
error_msg << "Unexpected MKL exception caught during getrf() "
123+
"call:\nreason: "
124+
<< e.what() << "\ninfo: " << e.info();
125+
}
126+
} catch (sycl::exception const &e) {
127+
is_exception_caught = true;
128+
error_msg << "Unexpected SYCL exception caught during getrf() call:\n"
129+
<< e.what();
130+
}
131+
132+
if (is_exception_caught) // an unexpected error occurs
133+
{
134+
if (scratchpad != nullptr) {
135+
sycl::free(scratchpad, exec_q);
136+
}
137+
138+
throw std::runtime_error(error_msg.str());
139+
}
140+
141+
sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
142+
cgh.depends_on(getrf_event);
143+
auto ctx = exec_q.get_context();
144+
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
145+
});
146+
host_task_events.push_back(clean_up_event);
147+
return getrf_event;
148+
}
149+
150+
std::pair<sycl::event, sycl::event>
151+
getrf(sycl::queue exec_q,
152+
dpctl::tensor::usm_ndarray a_array,
153+
dpctl::tensor::usm_ndarray ipiv_array,
154+
py::list dev_info,
155+
const std::vector<sycl::event> &depends)
156+
{
157+
const int a_array_nd = a_array.get_ndim();
158+
const int ipiv_array_nd = ipiv_array.get_ndim();
159+
160+
if (a_array_nd != 2) {
161+
throw py::value_error(
162+
"The input array has ndim=" + std::to_string(a_array_nd) +
163+
", but a 2-dimensional array is expected.");
164+
}
165+
166+
if (ipiv_array_nd != 1) {
167+
throw py::value_error("The array of pivot indices has ndim=" +
168+
std::to_string(ipiv_array_nd) +
169+
", but a 1-dimensional array is expected.");
170+
}
171+
172+
// check compatibility of execution queue and allocation queue
173+
if (!dpctl::utils::queues_are_compatible(exec_q, {a_array, ipiv_array})) {
174+
throw py::value_error(
175+
"Execution queue is not compatible with allocation queues");
176+
}
177+
178+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
179+
if (overlap(a_array, ipiv_array)) {
180+
throw py::value_error("The input array and the array of pivot indices "
181+
"are overlapping segments of memory");
182+
}
183+
184+
bool is_a_array_c_contig = a_array.is_c_contiguous();
185+
bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous();
186+
if (!is_a_array_c_contig) {
187+
throw py::value_error("The input array "
188+
"must be C-contiguous");
189+
}
190+
if (!is_ipiv_array_c_contig) {
191+
throw py::value_error("The array of pivot indices "
192+
"must be C-contiguous");
193+
}
194+
195+
auto array_types = dpctl_td_ns::usm_ndarray_types();
196+
int a_array_type_id =
197+
array_types.typenum_to_lookup_id(a_array.get_typenum());
198+
199+
getrf_impl_fn_ptr_t getrf_fn = getrf_dispatch_vector[a_array_type_id];
200+
if (getrf_fn == nullptr) {
201+
throw py::value_error(
202+
"No getrf implementation defined for the provided type "
203+
"of the input matrix.");
204+
}
205+
206+
auto ipiv_types = dpctl_td_ns::usm_ndarray_types();
207+
int ipiv_array_type_id =
208+
ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());
209+
210+
if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) {
211+
throw py::value_error("The type of 'ipiv_array' must be int64.");
212+
}
213+
214+
const std::int64_t n = a_array.get_shape_raw()[0];
215+
216+
char *a_array_data = a_array.get_data();
217+
const std::int64_t lda = std::max<size_t>(1UL, n);
218+
219+
char *ipiv_array_data = ipiv_array.get_data();
220+
std::int64_t *d_ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data);
221+
222+
std::vector<sycl::event> host_task_events;
223+
sycl::event getrf_ev = getrf_fn(exec_q, n, a_array_data, lda, d_ipiv,
224+
dev_info, host_task_events, depends);
225+
226+
sycl::event args_ev = dpctl::utils::keep_args_alive(
227+
exec_q, {a_array, ipiv_array}, host_task_events);
228+
229+
return std::make_pair(args_ev, getrf_ev);
230+
}
231+
232+
template <typename fnT, typename T>
233+
struct GetrfContigFactory
234+
{
235+
fnT get()
236+
{
237+
if constexpr (types::GetrfTypePairSupportFactory<T>::is_defined) {
238+
return getrf_impl<T>;
239+
}
240+
else {
241+
return nullptr;
242+
}
243+
}
244+
};
245+
246+
void init_getrf_dispatch_vector(void)
247+
{
248+
dpctl_td_ns::DispatchVectorBuilder<getrf_impl_fn_ptr_t, GetrfContigFactory,
249+
dpctl_td_ns::num_types>
250+
contig;
251+
contig.populate_dispatch_vector(getrf_dispatch_vector);
252+
}
253+
} // namespace lapack
254+
} // namespace ext
255+
} // namespace backend
256+
} // namespace dpnp
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
#include <oneapi/mkl.hpp>
30+
31+
#include <dpctl4pybind11.hpp>
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace lapack
40+
{
41+
extern std::pair<sycl::event, sycl::event>
42+
getrf(sycl::queue exec_q,
43+
dpctl::tensor::usm_ndarray a_array,
44+
dpctl::tensor::usm_ndarray ipiv_array,
45+
py::list dev_info,
46+
const std::vector<sycl::event> &depends = {});
47+
48+
extern std::pair<sycl::event, sycl::event>
49+
getrf_batch(sycl::queue exec_q,
50+
dpctl::tensor::usm_ndarray a_array,
51+
dpctl::tensor::usm_ndarray ipiv_array,
52+
py::list dev_info,
53+
std::int64_t n,
54+
std::int64_t stride_a,
55+
std::int64_t stride_ipiv,
56+
std::int64_t batch_size,
57+
const std::vector<sycl::event> &depends = {});
58+
59+
extern void init_getrf_dispatch_vector(void);
60+
extern void init_getrf_batch_dispatch_vector(void);
61+
} // namespace lapack
62+
} // namespace ext
63+
} // namespace backend
64+
} // namespace dpnp

0 commit comments

Comments
 (0)