Skip to content

Commit 74609d6

Browse files
Add dpnp.linalg.solve() function (#1598)
* Add dpnp.linalg.solve() function * Add cupy tests for dpnp.linalg.solve() * Register a LinAlgError in dpnp.linalg submodule * Implementation of dtype dispatching with _common_type for dpnp.linalg.solve * Add a common_helpers.hpp file * Add validation functions for array types and dimensions for linalg funcs * Skip test_solve_singular_empty --------- Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
1 parent e2188ed commit 74609d6

File tree

14 files changed

+1093
-4
lines changed

14 files changed

+1093
-4
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ env:
3131
test_usm_type.py
3232
third_party/cupy/core_tests
3333
third_party/cupy/linalg_tests/test_product.py
34+
third_party/cupy/linalg_tests/test_solve.py
3435
third_party/cupy/logic_tests/test_comparison.py
3536
third_party/cupy/logic_tests/test_truth.py
3637
third_party/cupy/manipulation_tests/test_basic.py

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
set(python_module_name _lapack_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/lapack_py.cpp
30+
${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp
3031
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
3132
${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp
3233
)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
#include <cstring>
28+
#include <stdexcept>
29+
30+
namespace dpnp
31+
{
32+
namespace backend
33+
{
34+
namespace ext
35+
{
36+
namespace lapack
37+
{
38+
namespace helper
39+
{
40+
template <typename T>
41+
struct value_type_of
42+
{
43+
using type = T;
44+
};
45+
46+
template <typename T>
47+
struct value_type_of<std::complex<T>>
48+
{
49+
using type = T;
50+
};
51+
} // namespace helper
52+
} // namespace lapack
53+
} // namespace ext
54+
} // namespace backend
55+
} // namespace dpnp
Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <pybind11/pybind11.h>
27+
28+
// dpctl tensor headers
29+
#include "utils/memory_overlap.hpp"
30+
#include "utils/type_utils.hpp"
31+
32+
#include "common_helpers.hpp"
33+
#include "gesv.hpp"
34+
#include "linalg_exceptions.hpp"
35+
#include "types_matrix.hpp"
36+
37+
#include "dpnp_utils.hpp"
38+
39+
namespace dpnp
40+
{
41+
namespace backend
42+
{
43+
namespace ext
44+
{
45+
namespace lapack
46+
{
47+
namespace mkl_lapack = oneapi::mkl::lapack;
48+
namespace py = pybind11;
49+
namespace type_utils = dpctl::tensor::type_utils;
50+
51+
typedef sycl::event (*gesv_impl_fn_ptr_t)(sycl::queue,
52+
const std::int64_t,
53+
const std::int64_t,
54+
char *,
55+
std::int64_t,
56+
char *,
57+
std::int64_t,
58+
std::vector<sycl::event> &,
59+
const std::vector<sycl::event> &);
60+
61+
static gesv_impl_fn_ptr_t gesv_dispatch_vector[dpctl_td_ns::num_types];
62+
63+
template <typename T>
64+
static sycl::event gesv_impl(sycl::queue exec_q,
65+
const std::int64_t n,
66+
const std::int64_t nrhs,
67+
char *in_a,
68+
std::int64_t lda,
69+
char *in_b,
70+
std::int64_t ldb,
71+
std::vector<sycl::event> &host_task_events,
72+
const std::vector<sycl::event> &depends)
73+
{
74+
type_utils::validate_type_for_device<T>(exec_q);
75+
76+
T *a = reinterpret_cast<T *>(in_a);
77+
T *b = reinterpret_cast<T *>(in_b);
78+
79+
const std::int64_t scratchpad_size =
80+
mkl_lapack::gesv_scratchpad_size<T>(exec_q, n, nrhs, lda, ldb);
81+
T *scratchpad = nullptr;
82+
83+
std::int64_t *ipiv = nullptr;
84+
85+
std::stringstream error_msg;
86+
std::int64_t info = 0;
87+
bool sycl_exception_caught = false;
88+
89+
sycl::event gesv_event;
90+
try {
91+
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
92+
ipiv = sycl::malloc_device<std::int64_t>(n, exec_q);
93+
94+
gesv_event = mkl_lapack::gesv(
95+
exec_q,
96+
n, // The order of the matrix A (0 ≤ n).
97+
nrhs, // The number of right-hand sides B (0 ≤ nrhs).
98+
a, // Pointer to the square coefficient matrix A (n x n).
99+
lda, // The leading dimension of a, must be at least max(1, n).
100+
ipiv, // The pivot indices that define the permutation matrix P;
101+
// row i of the matrix was interchanged with row ipiv(i),
102+
// must be at least max(1, n).
103+
b, // Pointer to the right hand side matrix B (n x nrhs).
104+
ldb, // The leading dimension of b, must be at least max(1, n).
105+
scratchpad, // Pointer to scratchpad memory to be used by MKL
106+
// routine for storing intermediate results.
107+
scratchpad_size, depends);
108+
} catch (mkl_lapack::exception const &e) {
109+
info = e.info();
110+
111+
if (info < 0) {
112+
error_msg << "Parameter number " << -info
113+
<< " had an illegal value.";
114+
}
115+
else if (info > 0) {
116+
T host_U;
117+
exec_q.memcpy(&host_U, &a[(info - 1) * lda + info - 1], sizeof(T))
118+
.wait();
119+
120+
using ThresholdType = typename helper::value_type_of<T>::type;
121+
122+
const auto threshold =
123+
std::numeric_limits<ThresholdType>::epsilon() * 100;
124+
if (std::abs(host_U) < threshold) {
125+
sycl::free(scratchpad, exec_q);
126+
throw LinAlgError("The input coefficient matrix is singular.");
127+
}
128+
else {
129+
error_msg << "Unexpected MKL exception caught during gesv() "
130+
"call:\nreason: "
131+
<< e.what() << "\ninfo: " << e.info();
132+
}
133+
}
134+
else if (info == scratchpad_size && e.detail() != 0) {
135+
error_msg
136+
<< "Insufficient scratchpad size. Required size is at least "
137+
<< e.detail();
138+
}
139+
else {
140+
error_msg << "Unexpected MKL exception caught during gesv() "
141+
"call:\nreason: "
142+
<< e.what() << "\ninfo: " << e.info();
143+
}
144+
} catch (sycl::exception const &e) {
145+
error_msg << "Unexpected SYCL exception caught during gesv() call:\n"
146+
<< e.what();
147+
sycl_exception_caught = true;
148+
}
149+
150+
if (info != 0 || sycl_exception_caught) // an unexpected error occurs
151+
{
152+
if (scratchpad != nullptr) {
153+
sycl::free(scratchpad, exec_q);
154+
}
155+
if (ipiv != nullptr) {
156+
sycl::free(ipiv, exec_q);
157+
}
158+
throw std::runtime_error(error_msg.str());
159+
}
160+
161+
sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
162+
cgh.depends_on(gesv_event);
163+
auto ctx = exec_q.get_context();
164+
cgh.host_task([ctx, scratchpad, ipiv]() {
165+
sycl::free(scratchpad, ctx);
166+
sycl::free(ipiv, ctx);
167+
});
168+
});
169+
host_task_events.push_back(clean_up_event);
170+
171+
return gesv_event;
172+
}
173+
174+
std::pair<sycl::event, sycl::event>
175+
gesv(sycl::queue exec_q,
176+
dpctl::tensor::usm_ndarray coeff_matrix,
177+
dpctl::tensor::usm_ndarray dependent_vals,
178+
const std::vector<sycl::event> &depends)
179+
{
180+
const int coeff_matrix_nd = coeff_matrix.get_ndim();
181+
const int dependent_vals_nd = dependent_vals.get_ndim();
182+
183+
if (coeff_matrix_nd != 2) {
184+
throw py::value_error("The coefficient matrix has ndim=" +
185+
std::to_string(coeff_matrix_nd) +
186+
", but a 2-dimensional array is expected.");
187+
}
188+
189+
if (dependent_vals_nd > 2) {
190+
throw py::value_error(
191+
"The dependent values array has ndim=" +
192+
std::to_string(dependent_vals_nd) +
193+
", but a 1-dimensional or a 2-dimensional array is expected.");
194+
}
195+
196+
const py::ssize_t *coeff_matrix_shape = coeff_matrix.get_shape_raw();
197+
const py::ssize_t *dependent_vals_shape = dependent_vals.get_shape_raw();
198+
199+
if (coeff_matrix_shape[0] != coeff_matrix_shape[1]) {
200+
throw py::value_error("The coefficient matrix must be square,"
201+
" but got a shape of (" +
202+
std::to_string(coeff_matrix_shape[0]) + ", " +
203+
std::to_string(coeff_matrix_shape[1]) + ").");
204+
}
205+
206+
// check compatibility of execution queue and allocation queue
207+
if (!dpctl::utils::queues_are_compatible(exec_q,
208+
{coeff_matrix, dependent_vals}))
209+
{
210+
throw py::value_error(
211+
"Execution queue is not compatible with allocation queues");
212+
}
213+
214+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
215+
if (overlap(coeff_matrix, dependent_vals)) {
216+
throw py::value_error(
217+
"The arrays of coefficients and dependent variables "
218+
"are overlapping segments of memory");
219+
}
220+
221+
bool is_coeff_matrix_f_contig = coeff_matrix.is_f_contiguous();
222+
if (!is_coeff_matrix_f_contig) {
223+
throw py::value_error("The coefficient matrix "
224+
"must be F-contiguous");
225+
}
226+
227+
bool is_dependent_vals_f_contig = dependent_vals.is_f_contiguous();
228+
if (!is_dependent_vals_f_contig) {
229+
throw py::value_error("The array of dependent variables "
230+
"must be F-contiguous");
231+
}
232+
233+
auto array_types = dpctl_td_ns::usm_ndarray_types();
234+
int coeff_matrix_type_id =
235+
array_types.typenum_to_lookup_id(coeff_matrix.get_typenum());
236+
int dependent_vals_type_id =
237+
array_types.typenum_to_lookup_id(dependent_vals.get_typenum());
238+
239+
if (coeff_matrix_type_id != dependent_vals_type_id) {
240+
throw py::value_error("The types of the coefficient matrix and "
241+
"dependent variables are mismatched");
242+
}
243+
244+
gesv_impl_fn_ptr_t gesv_fn = gesv_dispatch_vector[coeff_matrix_type_id];
245+
if (gesv_fn == nullptr) {
246+
throw py::value_error(
247+
"No gesv implementation defined for the provided type "
248+
"of the coefficient matrix.");
249+
}
250+
251+
char *coeff_matrix_data = coeff_matrix.get_data();
252+
char *dependent_vals_data = dependent_vals.get_data();
253+
254+
const std::int64_t n = coeff_matrix_shape[0];
255+
const std::int64_t m = dependent_vals_shape[0];
256+
const std::int64_t nrhs =
257+
(dependent_vals_nd > 1) ? dependent_vals_shape[1] : 1;
258+
259+
const std::int64_t lda = std::max<size_t>(1UL, n);
260+
const std::int64_t ldb = std::max<size_t>(1UL, m);
261+
262+
std::vector<sycl::event> host_task_events;
263+
sycl::event gesv_ev =
264+
gesv_fn(exec_q, n, nrhs, coeff_matrix_data, lda, dependent_vals_data,
265+
ldb, host_task_events, depends);
266+
267+
sycl::event args_ev = dpctl::utils::keep_args_alive(
268+
exec_q, {coeff_matrix, dependent_vals}, host_task_events);
269+
270+
return std::make_pair(args_ev, gesv_ev);
271+
}
272+
273+
template <typename fnT, typename T>
274+
struct GesvContigFactory
275+
{
276+
fnT get()
277+
{
278+
if constexpr (types::GesvTypePairSupportFactory<T>::is_defined) {
279+
return gesv_impl<T>;
280+
}
281+
else {
282+
return nullptr;
283+
}
284+
}
285+
};
286+
287+
void init_gesv_dispatch_vector(void)
288+
{
289+
dpctl_td_ns::DispatchVectorBuilder<gesv_impl_fn_ptr_t, GesvContigFactory,
290+
dpctl_td_ns::num_types>
291+
contig;
292+
contig.populate_dispatch_vector(gesv_dispatch_vector);
293+
}
294+
} // namespace lapack
295+
} // namespace ext
296+
} // namespace backend
297+
} // namespace dpnp

0 commit comments

Comments
 (0)