Skip to content

Commit 7a56304

Browse files
Update dpnp.linalg.inv() function (#1665)
* Impl dpnp.linalg.inv for 2d array * Remove an old impl of dpnp_inv * Add batch implementation of dpnp.linalg.inv func * Add cupy tests for dpnp.linalg.inv * Add dpnp tests for dpnp.linalg.inv * Add check_lapack_dev_info func * Add dev_info size check for getri_batch and getrf_batch * Add size check dev_info and error_matrices_ids * Remove dpnp_inv_ext_c * Rename check_lapack_dev_info to _check_lapack_dev_info * Skip test_inv in TestInvInvalid --------- Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
1 parent a75e599 commit 7a56304

File tree

16 files changed

+871
-115
lines changed

16 files changed

+871
-115
lines changed

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ set(_module_src
3030
${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp
3131
${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp
3232
${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
33+
${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp
3334
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
3435
${CMAKE_CURRENT_SOURCE_DIR}/potrf.cpp
3536
${CMAKE_CURRENT_SOURCE_DIR}/potrf_batch.cpp

dpnp/backend/extensions/lapack/getrf_batch.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,15 @@ static sycl::event getrf_batch_impl(sycl::queue exec_q,
116116
// Get the indices of the first zero diagonal elements of these matrices
117117
auto error_info = be.exceptions();
118118

119+
auto error_matrices_ids_size = error_matrices_ids.size();
120+
auto dev_info_size = static_cast<std::size_t>(py::len(dev_info));
121+
if (error_matrices_ids_size != dev_info_size) {
122+
throw py::value_error("The size of `dev_info` must be equal to" +
123+
std::to_string(error_matrices_ids_size) +
124+
", but currently it is " +
125+
std::to_string(dev_info_size) + ".");
126+
}
127+
119128
for (size_t i = 0; i < error_matrices_ids.size(); ++i) {
120129
// Assign the index of the first zero diagonal element in each
121130
// error matrix to the corresponding index in 'dev_info'
@@ -190,6 +199,14 @@ std::pair<sycl::event, sycl::event>
190199
", but a 2-dimensional array is expected.");
191200
}
192201

202+
const int dev_info_size = py::len(dev_info);
203+
if (dev_info_size != batch_size) {
204+
throw py::value_error("The size of 'dev_info' (" +
205+
std::to_string(dev_info_size) +
206+
") does not match the expected batch size (" +
207+
std::to_string(batch_size) + ").");
208+
}
209+
193210
// check compatibility of execution queue and allocation queue
194211
if (!dpctl::utils::queues_are_compatible(exec_q, {a_array, ipiv_array})) {
195212
throw py::value_error(
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
#include <oneapi/mkl.hpp>
30+
31+
#include <dpctl4pybind11.hpp>
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace lapack
40+
{
41+
extern std::pair<sycl::event, sycl::event>
42+
getri_batch(sycl::queue exec_q,
43+
dpctl::tensor::usm_ndarray a_array,
44+
dpctl::tensor::usm_ndarray ipiv_array,
45+
py::list dev_info,
46+
std::int64_t n,
47+
std::int64_t stride_a,
48+
std::int64_t stride_ipiv,
49+
std::int64_t batch_size,
50+
const std::vector<sycl::event> &depends = {});
51+
52+
extern void init_getri_batch_dispatch_vector(void);
53+
} // namespace lapack
54+
} // namespace ext
55+
} // namespace backend
56+
} // namespace dpnp
Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <pybind11/pybind11.h>
27+
28+
// dpctl tensor headers
29+
#include "utils/memory_overlap.hpp"
30+
#include "utils/type_utils.hpp"
31+
32+
#include "getri.hpp"
33+
#include "types_matrix.hpp"
34+
35+
#include "dpnp_utils.hpp"
36+
37+
namespace dpnp
38+
{
39+
namespace backend
40+
{
41+
namespace ext
42+
{
43+
namespace lapack
44+
{
45+
namespace mkl_lapack = oneapi::mkl::lapack;
46+
namespace py = pybind11;
47+
namespace type_utils = dpctl::tensor::type_utils;
48+
49+
typedef sycl::event (*getri_batch_impl_fn_ptr_t)(
50+
sycl::queue,
51+
std::int64_t,
52+
char *,
53+
std::int64_t,
54+
std::int64_t,
55+
std::int64_t *,
56+
std::int64_t,
57+
std::int64_t,
58+
py::list,
59+
std::vector<sycl::event> &,
60+
const std::vector<sycl::event> &);
61+
62+
static getri_batch_impl_fn_ptr_t
63+
getri_batch_dispatch_vector[dpctl_td_ns::num_types];
64+
65+
template <typename T>
66+
static sycl::event getri_batch_impl(sycl::queue exec_q,
67+
std::int64_t n,
68+
char *in_a,
69+
std::int64_t lda,
70+
std::int64_t stride_a,
71+
std::int64_t *ipiv,
72+
std::int64_t stride_ipiv,
73+
std::int64_t batch_size,
74+
py::list dev_info,
75+
std::vector<sycl::event> &host_task_events,
76+
const std::vector<sycl::event> &depends)
77+
{
78+
type_utils::validate_type_for_device<T>(exec_q);
79+
80+
T *a = reinterpret_cast<T *>(in_a);
81+
82+
const std::int64_t scratchpad_size =
83+
mkl_lapack::getri_batch_scratchpad_size<T>(exec_q, n, lda, stride_a,
84+
stride_ipiv, batch_size);
85+
T *scratchpad = nullptr;
86+
87+
std::stringstream error_msg;
88+
std::int64_t info = 0;
89+
bool is_exception_caught = false;
90+
91+
sycl::event getri_batch_event;
92+
try {
93+
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
94+
95+
getri_batch_event = mkl_lapack::getri_batch(
96+
exec_q,
97+
n, // The order of each square matrix in the batch; (0 ≤ n).
98+
// It must be a non-negative integer.
99+
a, // Pointer to the batch of square matrices, each of size (n x n).
100+
lda, // The leading dimension of each matrix in the batch.
101+
stride_a, // Stride between consecutive matrices in the batch.
102+
ipiv, // Pointer to the array of pivot indices for each matrix in
103+
// the batch.
104+
stride_ipiv, // Stride between pivot indices: Spacing between pivot
105+
// arrays in 'ipiv'.
106+
batch_size, // Total number of matrices in the batch.
107+
scratchpad, // Pointer to scratchpad memory to be used by MKL
108+
// routine for storing intermediate results.
109+
scratchpad_size, depends);
110+
} catch (mkl_lapack::batch_error const &be) {
111+
// Get the indices of matrices within the batch that encountered an
112+
// error
113+
auto error_matrices_ids = be.ids();
114+
// Get the indices of the first zero diagonal elements of these matrices
115+
auto error_info = be.exceptions();
116+
117+
auto error_matrices_ids_size = error_matrices_ids.size();
118+
auto dev_info_size = static_cast<std::size_t>(py::len(dev_info));
119+
if (error_matrices_ids_size != dev_info_size) {
120+
throw py::value_error("The size of `dev_info` must be equal to" +
121+
std::to_string(error_matrices_ids_size) +
122+
", but currently it is " +
123+
std::to_string(dev_info_size) + ".");
124+
}
125+
126+
for (size_t i = 0; i < error_matrices_ids.size(); ++i) {
127+
// Assign the index of the first zero diagonal element in each
128+
// error matrix to the corresponding index in 'dev_info'
129+
dev_info[error_matrices_ids[i]] = error_info[i];
130+
}
131+
} catch (mkl_lapack::exception const &e) {
132+
is_exception_caught = true;
133+
info = e.info();
134+
135+
if (info < 0) {
136+
error_msg << "Parameter number " << -info
137+
<< " had an illegal value.";
138+
}
139+
else if (info == scratchpad_size && e.detail() != 0) {
140+
error_msg
141+
<< "Insufficient scratchpad size. Required size is at least "
142+
<< e.detail();
143+
}
144+
else {
145+
error_msg << "Unexpected MKL exception caught during getri_batch() "
146+
"call:\nreason: "
147+
<< e.what() << "\ninfo: " << e.info();
148+
}
149+
} catch (sycl::exception const &e) {
150+
is_exception_caught = true;
151+
error_msg
152+
<< "Unexpected SYCL exception caught during getri_batch() call:\n"
153+
<< e.what();
154+
}
155+
156+
if (is_exception_caught) // an unexpected error occurs
157+
{
158+
if (scratchpad != nullptr) {
159+
sycl::free(scratchpad, exec_q);
160+
}
161+
162+
throw std::runtime_error(error_msg.str());
163+
}
164+
165+
sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
166+
cgh.depends_on(getri_batch_event);
167+
auto ctx = exec_q.get_context();
168+
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
169+
});
170+
host_task_events.push_back(clean_up_event);
171+
return getri_batch_event;
172+
}
173+
174+
std::pair<sycl::event, sycl::event>
175+
getri_batch(sycl::queue exec_q,
176+
dpctl::tensor::usm_ndarray a_array,
177+
dpctl::tensor::usm_ndarray ipiv_array,
178+
py::list dev_info,
179+
std::int64_t n,
180+
std::int64_t stride_a,
181+
std::int64_t stride_ipiv,
182+
std::int64_t batch_size,
183+
const std::vector<sycl::event> &depends)
184+
{
185+
const int a_array_nd = a_array.get_ndim();
186+
const int ipiv_array_nd = ipiv_array.get_ndim();
187+
188+
if (a_array_nd < 3) {
189+
throw py::value_error(
190+
"The input array has ndim=" + std::to_string(a_array_nd) +
191+
", but an array with ndim >= 3 is expected.");
192+
}
193+
194+
if (ipiv_array_nd != 2) {
195+
throw py::value_error("The array of pivot indices has ndim=" +
196+
std::to_string(ipiv_array_nd) +
197+
", but a 2-dimensional array is expected.");
198+
}
199+
200+
const int dev_info_size = py::len(dev_info);
201+
if (dev_info_size != batch_size) {
202+
throw py::value_error("The size of 'dev_info' (" +
203+
std::to_string(dev_info_size) +
204+
") does not match the expected batch size (" +
205+
std::to_string(batch_size) + ").");
206+
}
207+
208+
// check compatibility of execution queue and allocation queue
209+
if (!dpctl::utils::queues_are_compatible(exec_q, {a_array, ipiv_array})) {
210+
throw py::value_error(
211+
"Execution queue is not compatible with allocation queues");
212+
}
213+
214+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
215+
if (overlap(a_array, ipiv_array)) {
216+
throw py::value_error("The input array and the array of pivot indices "
217+
"are overlapping segments of memory");
218+
}
219+
220+
bool is_a_array_c_contig = a_array.is_c_contiguous();
221+
bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous();
222+
if (!is_a_array_c_contig) {
223+
throw py::value_error("The input array "
224+
"must be C-contiguous");
225+
}
226+
if (!is_ipiv_array_c_contig) {
227+
throw py::value_error("The array of pivot indices "
228+
"must be C-contiguous");
229+
}
230+
231+
auto array_types = dpctl_td_ns::usm_ndarray_types();
232+
int a_array_type_id =
233+
array_types.typenum_to_lookup_id(a_array.get_typenum());
234+
235+
getri_batch_impl_fn_ptr_t getri_batch_fn =
236+
getri_batch_dispatch_vector[a_array_type_id];
237+
if (getri_batch_fn == nullptr) {
238+
throw py::value_error(
239+
"No getri_batch implementation defined for the provided type "
240+
"of the input matrix.");
241+
}
242+
243+
auto ipiv_types = dpctl_td_ns::usm_ndarray_types();
244+
int ipiv_array_type_id =
245+
ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());
246+
247+
if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) {
248+
throw py::value_error("The type of 'ipiv_array' must be int64.");
249+
}
250+
251+
char *a_array_data = a_array.get_data();
252+
const std::int64_t lda = std::max<size_t>(1UL, n);
253+
254+
char *ipiv_array_data = ipiv_array.get_data();
255+
std::int64_t *d_ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data);
256+
257+
std::vector<sycl::event> host_task_events;
258+
sycl::event getri_batch_ev = getri_batch_fn(
259+
exec_q, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv, batch_size,
260+
dev_info, host_task_events, depends);
261+
262+
sycl::event args_ev = dpctl::utils::keep_args_alive(
263+
exec_q, {a_array, ipiv_array}, host_task_events);
264+
265+
return std::make_pair(args_ev, getri_batch_ev);
266+
}
267+
268+
template <typename fnT, typename T>
269+
struct GetriBatchContigFactory
270+
{
271+
fnT get()
272+
{
273+
if constexpr (types::GetriBatchTypePairSupportFactory<T>::is_defined) {
274+
return getri_batch_impl<T>;
275+
}
276+
else {
277+
return nullptr;
278+
}
279+
}
280+
};
281+
282+
void init_getri_batch_dispatch_vector(void)
283+
{
284+
dpctl_td_ns::DispatchVectorBuilder<getri_batch_impl_fn_ptr_t,
285+
GetriBatchContigFactory,
286+
dpctl_td_ns::num_types>
287+
contig;
288+
contig.populate_dispatch_vector(getri_batch_dispatch_vector);
289+
}
290+
} // namespace lapack
291+
} // namespace ext
292+
} // namespace backend
293+
} // namespace dpnp

0 commit comments

Comments
 (0)