Skip to content

Commit 809a993

Browse files
Merge 12a5cb5 into d9c1ca1
2 parents d9c1ca1 + 12a5cb5 commit 809a993

File tree

16 files changed

+1286
-273
lines changed

16 files changed

+1286
-273
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ env:
3030
test_umath.py
3131
test_usm_type.py
3232
third_party/cupy/core_tests
33+
third_party/cupy/linalg_tests/test_decomposition.py
3334
third_party/cupy/linalg_tests/test_product.py
3435
third_party/cupy/linalg_tests/test_solve.py
3536
third_party/cupy/logic_tests/test_comparison.py

dpnp/backend/extensions/lapack/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ set(python_module_name _lapack_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/lapack_py.cpp
3030
${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp
31+
${CMAKE_CURRENT_SOURCE_DIR}/gesvd.cpp
3132
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
3233
${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp
3334
)
Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <pybind11/pybind11.h>
27+
28+
// dpctl tensor headers
29+
#include "utils/memory_overlap.hpp"
30+
#include "utils/type_utils.hpp"
31+
32+
#include "gesvd.hpp"
33+
#include "types_matrix.hpp"
34+
35+
#include "dpnp_utils.hpp"
36+
37+
namespace dpnp
38+
{
39+
namespace backend
40+
{
41+
namespace ext
42+
{
43+
namespace lapack
44+
{
45+
namespace mkl_lapack = oneapi::mkl::lapack;
46+
namespace py = pybind11;
47+
namespace type_utils = dpctl::tensor::type_utils;
48+
49+
typedef sycl::event (*gesvd_impl_fn_ptr_t)(sycl::queue,
50+
const oneapi::mkl::jobsvd,
51+
const oneapi::mkl::jobsvd,
52+
const std::int64_t,
53+
const std::int64_t,
54+
char *,
55+
const std::int64_t,
56+
char *,
57+
char *,
58+
const std::int64_t,
59+
char *,
60+
const std::int64_t,
61+
std::vector<sycl::event> &,
62+
const std::vector<sycl::event> &);
63+
64+
static gesvd_impl_fn_ptr_t gesvd_dispatch_table[dpctl_td_ns::num_types]
65+
[dpctl_td_ns::num_types];
66+
67+
// Converts a given character code (ord) to the corresponding
68+
// oneapi::mkl::jobsvd enumeration value
69+
static oneapi::mkl::jobsvd process_job(std::int8_t job_val)
70+
{
71+
switch (job_val) {
72+
case 'A':
73+
return oneapi::mkl::jobsvd::vectors;
74+
case 'S':
75+
return oneapi::mkl::jobsvd::somevec;
76+
case 'O':
77+
return oneapi::mkl::jobsvd::vectorsina;
78+
case 'N':
79+
return oneapi::mkl::jobsvd::novec;
80+
default:
81+
throw std::invalid_argument("Unknown value for job");
82+
}
83+
}
84+
85+
template <typename T, typename RealT>
86+
static sycl::event gesvd_impl(sycl::queue exec_q,
87+
const oneapi::mkl::jobsvd jobu,
88+
const oneapi::mkl::jobsvd jobvt,
89+
const std::int64_t m,
90+
const std::int64_t n,
91+
char *in_a,
92+
const std::int64_t lda,
93+
char *out_s,
94+
char *out_u,
95+
const std::int64_t ldu,
96+
char *out_vt,
97+
const std::int64_t ldvt,
98+
std::vector<sycl::event> &host_task_events,
99+
const std::vector<sycl::event> &depends)
100+
{
101+
type_utils::validate_type_for_device<T>(exec_q);
102+
type_utils::validate_type_for_device<RealT>(exec_q);
103+
104+
T *a = reinterpret_cast<T *>(in_a);
105+
RealT *s = reinterpret_cast<RealT *>(out_s);
106+
T *u = reinterpret_cast<T *>(out_u);
107+
T *vt = reinterpret_cast<T *>(out_vt);
108+
109+
const std::int64_t scratchpad_size = mkl_lapack::gesvd_scratchpad_size<T>(
110+
exec_q, jobu, jobvt, m, n, lda, ldu, ldvt);
111+
T *scratchpad = nullptr;
112+
113+
std::stringstream error_msg;
114+
std::int64_t info = 0;
115+
std::int64_t detail = 0;
116+
117+
sycl::event gesvd_event;
118+
try {
119+
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);
120+
121+
gesvd_event = mkl_lapack::gesvd(
122+
exec_q,
123+
jobu, // Character specifying how to compute the matrix U:
124+
// 'A' computes all columns of U,
125+
// 'S' computes the first min(m,n) columns of U,
126+
// 'O' overwrites A with the columns of U,
127+
// 'N' does not compute U.
128+
jobvt, // Character specifying how to compute the matrix VT:
129+
// 'A' computes all rows of VT,
130+
// 'S' computes the first min(m,n) rows of VT,
131+
// 'O' overwrites A with the rows of VT,
132+
// 'N' does not compute VT.
133+
m, // The number of rows in the input matrix A (0 <= m).
134+
n, // The number of columns in the input matrix A (0 <= n).
135+
a, // Pointer to the input matrix A of size (m x n).
136+
lda, // The leading dimension of A, must be at least max(1, m).
137+
s, // Pointer to the array containing the singular values.
138+
u, // Pointer to the matrix U in the singular value decomposition.
139+
ldu, // The leading dimension of U, must be at least max(1, m).
140+
vt, // Pointer to the matrix VT in the singular value decomposition.
141+
ldvt, // The leading dimension of VT, must be at least max(1, n).
142+
scratchpad, // Pointer to scratchpad memory to be used by MKL
143+
// routine for storing intermediate results.
144+
scratchpad_size, depends);
145+
} catch (mkl_lapack::exception const &e) {
146+
info = e.info();
147+
detail = e.detail();
148+
error_msg << "MKL LAPACK exception caught during gesvd() call:\n"
149+
<< "Reason: " << e.what() << "\n"
150+
<< "Info: " << info << "\n";
151+
if (info < 0) {
152+
error_msg << "Parameter " << -info << " had an illegal value.\n";
153+
}
154+
else if (info > 0) {
155+
error_msg << "The algorithm computing SVD failed to converge; "
156+
<< info << " off-diagonal elements of an intermediate "
157+
<< "bidiagonal form did not converge to zero.\n";
158+
}
159+
else if (info == scratchpad_size && detail != 0) {
160+
error_msg << "Insufficient scratchpad size. Required size: "
161+
<< detail << ".\n";
162+
}
163+
} catch (sycl::exception const &e) {
164+
error_msg << "Unexpected SYCL exception caught during gesvd() call:\n"
165+
<< e.what();
166+
info = -1;
167+
}
168+
169+
if (info != 0) // an unexpected error occurs
170+
{
171+
if (scratchpad != nullptr) {
172+
sycl::free(scratchpad, exec_q);
173+
}
174+
throw std::runtime_error(error_msg.str());
175+
}
176+
177+
sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
178+
cgh.depends_on(gesvd_event);
179+
auto ctx = exec_q.get_context();
180+
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
181+
});
182+
host_task_events.push_back(clean_up_event);
183+
return gesvd_event;
184+
}
185+
186+
std::pair<sycl::event, sycl::event>
187+
gesvd(sycl::queue exec_q,
188+
const std::int8_t jobu_val,
189+
const std::int8_t jobvt_val,
190+
const std::int64_t m,
191+
const std::int64_t n,
192+
dpctl::tensor::usm_ndarray a_array,
193+
dpctl::tensor::usm_ndarray out_s,
194+
dpctl::tensor::usm_ndarray out_u,
195+
dpctl::tensor::usm_ndarray out_vt,
196+
const std::vector<sycl::event> &depends)
197+
{
198+
const int a_array_nd = a_array.get_ndim();
199+
200+
if (a_array_nd != 2) {
201+
throw py::value_error(
202+
"The input array has ndim=" + std::to_string(a_array_nd) +
203+
", but a 2-dimensional array is expected.");
204+
}
205+
206+
// check compatibility of execution queue and allocation queue
207+
if (!dpctl::utils::queues_are_compatible(
208+
exec_q, {a_array.get_queue(), out_s.get_queue(), out_u.get_queue(),
209+
out_vt.get_queue()}))
210+
{
211+
throw std::runtime_error(
212+
"USM allocations are not compatible with the execution queue.");
213+
}
214+
215+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
216+
if (overlap(a_array, out_s) || overlap(a_array, out_u) ||
217+
overlap(a_array, out_vt) || overlap(out_s, out_u) ||
218+
overlap(out_s, out_vt) || overlap(out_u, out_vt))
219+
{
220+
throw py::value_error("Arrays have overlapping segments of memory");
221+
}
222+
223+
bool is_a_array_c_contig = a_array.is_c_contiguous();
224+
if (!is_a_array_c_contig) {
225+
throw py::value_error("The input array must be C-contiguous");
226+
}
227+
228+
auto array_types = dpctl_td_ns::usm_ndarray_types();
229+
int a_array_type_id =
230+
array_types.typenum_to_lookup_id(a_array.get_typenum());
231+
int out_u_type_id = array_types.typenum_to_lookup_id(out_u.get_typenum());
232+
int out_s_type_id = array_types.typenum_to_lookup_id(out_s.get_typenum());
233+
int out_vt_type_id = array_types.typenum_to_lookup_id(out_vt.get_typenum());
234+
235+
if (a_array_type_id != out_u_type_id || a_array_type_id != out_vt_type_id) {
236+
throw py::type_error(
237+
"Input array, output left singular vectors array, "
238+
"and outpuy right singular vectors array must have "
239+
"the same data type");
240+
}
241+
242+
gesvd_impl_fn_ptr_t gesvd_fn =
243+
gesvd_dispatch_table[a_array_type_id][out_s_type_id];
244+
if (gesvd_fn == nullptr) {
245+
throw py::value_error(
246+
"No gesvd implementation is defined for the given pair "
247+
"of array type and output singular values type.");
248+
}
249+
250+
char *a_array_data = a_array.get_data();
251+
char *out_s_data = out_s.get_data();
252+
char *out_u_data = out_u.get_data();
253+
char *out_vt_data = out_vt.get_data();
254+
255+
const std::int64_t lda = std::max<size_t>(1UL, m);
256+
const std::int64_t ldu = std::max<size_t>(1UL, m);
257+
const std::int64_t ldvt = std::max<size_t>(1UL, n);
258+
259+
const oneapi::mkl::jobsvd jobu = process_job(jobu_val);
260+
const oneapi::mkl::jobsvd jobvt = process_job(jobvt_val);
261+
262+
std::vector<sycl::event> host_task_events;
263+
sycl::event gesvd_ev =
264+
gesvd_fn(exec_q, jobu, jobvt, m, n, a_array_data, lda, out_s_data,
265+
out_u_data, ldu, out_vt_data, ldvt, host_task_events, depends);
266+
267+
sycl::event args_ev = dpctl::utils::keep_args_alive(
268+
exec_q, {a_array, out_s, out_u, out_vt}, host_task_events);
269+
270+
return std::make_pair(args_ev, gesvd_ev);
271+
}
272+
273+
template <typename fnT, typename T, typename RealT>
274+
struct GesvdContigFactory
275+
{
276+
fnT get()
277+
{
278+
if constexpr (types::GesvdTypePairSupportFactory<T, RealT>::is_defined)
279+
{
280+
return gesvd_impl<T, RealT>;
281+
}
282+
else {
283+
return nullptr;
284+
}
285+
}
286+
};
287+
288+
void init_gesvd_dispatch_table(void)
289+
{
290+
dpctl_td_ns::DispatchTableBuilder<gesvd_impl_fn_ptr_t, GesvdContigFactory,
291+
dpctl_td_ns::num_types>
292+
contig;
293+
contig.populate_dispatch_table(gesvd_dispatch_table);
294+
}
295+
} // namespace lapack
296+
} // namespace ext
297+
} // namespace backend
298+
} // namespace dpnp
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
#include <oneapi/mkl.hpp>
30+
31+
#include <dpctl4pybind11.hpp>
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace lapack
40+
{
41+
extern std::pair<sycl::event, sycl::event>
42+
gesvd(sycl::queue exec_q,
43+
const std::int8_t jobu_val,
44+
const std::int8_t jobvt_val,
45+
const std::int64_t m,
46+
const std::int64_t n,
47+
dpctl::tensor::usm_ndarray a_array,
48+
dpctl::tensor::usm_ndarray out_s,
49+
dpctl::tensor::usm_ndarray out_u,
50+
dpctl::tensor::usm_ndarray out_vt,
51+
const std::vector<sycl::event> &depends);
52+
53+
extern void init_gesvd_dispatch_table(void);
54+
} // namespace lapack
55+
} // namespace ext
56+
} // namespace backend
57+
} // namespace dpnp

0 commit comments

Comments
 (0)