From 6500a6c39006650468db92219ed4cfbf7563d745 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 5 Oct 2023 19:34:58 +0200 Subject: [PATCH 01/14] Add dpnp.linalg.solve() function --- dpnp/backend/extensions/lapack/CMakeLists.txt | 1 + dpnp/backend/extensions/lapack/gesv.cpp | 236 ++++++++++++++++++ dpnp/backend/extensions/lapack/gesv.hpp | 50 ++++ dpnp/backend/extensions/lapack/lapack_py.cpp | 10 + .../extensions/lapack/types_matrix.hpp | 26 ++ dpnp/linalg/dpnp_iface_linalg.py | 56 ++++- dpnp/linalg/dpnp_utils_linalg.py | 71 ++++++ 7 files changed, 449 insertions(+), 1 deletion(-) create mode 100644 dpnp/backend/extensions/lapack/gesv.cpp create mode 100644 dpnp/backend/extensions/lapack/gesv.hpp diff --git a/dpnp/backend/extensions/lapack/CMakeLists.txt b/dpnp/backend/extensions/lapack/CMakeLists.txt index 0c90b4f0ca52..81196a78ab98 100644 --- a/dpnp/backend/extensions/lapack/CMakeLists.txt +++ b/dpnp/backend/extensions/lapack/CMakeLists.txt @@ -29,6 +29,7 @@ pybind11_add_module(${python_module_name} MODULE lapack_py.cpp heevd.cpp syevd.cpp + gesv.cpp ) if (WIN32) diff --git a/dpnp/backend/extensions/lapack/gesv.cpp b/dpnp/backend/extensions/lapack/gesv.cpp new file mode 100644 index 000000000000..277886b78954 --- /dev/null +++ b/dpnp/backend/extensions/lapack/gesv.cpp @@ -0,0 +1,236 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/type_utils.hpp" + +#include "gesv.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +namespace mkl_lapack = oneapi::mkl::lapack; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*gesv_impl_fn_ptr_t)(sycl::queue, + const std::int64_t, + const std::int64_t, + char *, + std::int64_t, + std::int64_t *, + char *, + std::int64_t, + std::vector &, + const std::vector &); + +static gesv_impl_fn_ptr_t gesv_dispatch_vector[dpctl_td_ns::num_types]; + +template +static sycl::event gesv_impl(sycl::queue exec_q, + const std::int64_t n, + const std::int64_t nrhs, + char *in_a, + std::int64_t lda, + std::int64_t *ipiv, + char *in_b, + std::int64_t ldb, + std::vector &host_task_events, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + T *a = reinterpret_cast(in_a); + T *b = reinterpret_cast(in_b); + + const std::int64_t scratchpad_size = + mkl_lapack::gesv_scratchpad_size(exec_q, n, nrhs, lda, ldb); + T *scratchpad = nullptr; + + std::stringstream error_msg; + std::int64_t info = 0; + + sycl::event gesv_event; + try { + scratchpad = sycl::malloc_device(scratchpad_size, exec_q); + + gesv_event = mkl_lapack::gesv( + exec_q, + n, // The order of the matrix A (0 ≤ n). + nrhs, // The number of right-hand sides B (0 ≤ nrhs). + a, // Pointer to the square coefficient matrix A (n x n). + lda, // The leading dimension of a, must be at least max(1, n). + ipiv, // The pivot indices that define the permutation matrix P; + // row i of the matrix was interchanged with row ipiv(i), + // must be at least max(1, n). + b, // Pointer to the right hand side matrix B (n x nrhs). + ldb, // The leading dimension of b, must be at least max(1, n). + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + } catch (mkl_lapack::exception const &e) { + error_msg + << "Unexpected MKL exception caught during gesv() call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + info = e.info(); + } catch (sycl::exception const &e) { + error_msg << "Unexpected SYCL exception caught during gesv() call:\n" + << e.what(); + info = -1; + } + + if (info != 0) // an unexected error occurs + { + if (scratchpad != nullptr) { + sycl::free(scratchpad, exec_q); + } + throw std::runtime_error(error_msg.str()); + } + + sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gesv_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); + }); + host_task_events.push_back(clean_up_event); + + return gesv_event; +} + +sycl::event gesv(sycl::queue exec_q, + dpctl::tensor::usm_ndarray coeff_matrix, + dpctl::tensor::usm_ndarray hand_sides, + const std::vector &depends) +{ + // check ndim hand_sides + + const py::ssize_t *coeff_matrix_shape = coeff_matrix.get_shape_raw(); + const py::ssize_t *hand_sides_shape = hand_sides.get_shape_raw(); + + if (coeff_matrix_shape[0] != coeff_matrix_shape[1]) { + throw py::value_error("The input coefficients array must be square "); + } + + // elif check shape coeff_matrix and hand_sides + + // check compatibility of execution queue and allocation queue + if (!dpctl::utils::queues_are_compatible(exec_q, + {coeff_matrix, hand_sides})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(coeff_matrix, hand_sides)) { + throw py::value_error("Arrays with coefficients and hand sides are " + "overlapping segments of memory"); + } + + bool is_coeff_matrix_f_contig = coeff_matrix.is_f_contiguous(); + bool is_hand_sides_c_contig = hand_sides.is_c_contiguous(); + if (!is_coeff_matrix_f_contig) { + throw py::value_error("An array with coefficients " + "must be F-contiguous"); + } + else if (!is_hand_sides_c_contig) { + throw py::value_error( + "An array with the output solutions of the coefficient matrix" + "must be C-contiguous"); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int coeff_matrix_type_id = + array_types.typenum_to_lookup_id(coeff_matrix.get_typenum()); + int hand_sides_type_id = + array_types.typenum_to_lookup_id(hand_sides.get_typenum()); + + if (coeff_matrix_type_id != hand_sides_type_id) { + throw py::value_error( + "Types of coefficients and hand sides are missmatched"); + } + + gesv_impl_fn_ptr_t gesv_fn = gesv_dispatch_vector[coeff_matrix_type_id]; + if (gesv_fn == nullptr) { + throw py::value_error("No gesv implementation defined for a type of " + "coefficient matrix and hand sides"); + } + + char *coeff_matrix_data = coeff_matrix.get_data(); + char *hand_sides_data = hand_sides.get_data(); + + const std::int64_t n = coeff_matrix_shape[0]; + const std::int64_t nrhs = hand_sides_shape[0]; + const std::int64_t m = hand_sides_shape[0]; + + const std::int64_t lda = std::max(1UL, n); + const std::int64_t ldb = std::max(1UL, m); + + std::vector ipiv(n); + std::int64_t *d_ipiv = sycl::malloc_device(n, exec_q); + + std::vector host_task_events; + sycl::event gesv_ev = + gesv_fn(exec_q, n, nrhs, coeff_matrix_data, lda, d_ipiv, + hand_sides_data, ldb, host_task_events, depends); + + return gesv_ev; +} + +template +struct GesvContigFactory +{ + fnT get() + { + if constexpr (types::GesvTypePairSupportFactory::is_defined) { + return gesv_impl; + } + else { + return nullptr; + } + } +}; + +void init_gesv_dispatch_vector(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(gesv_dispatch_vector); +} +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/gesv.hpp b/dpnp/backend/extensions/lapack/gesv.hpp new file mode 100644 index 000000000000..6c84b08acaf1 --- /dev/null +++ b/dpnp/backend/extensions/lapack/gesv.hpp @@ -0,0 +1,50 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include +#include + +#include + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +extern sycl::event gesv(sycl::queue exec_q, + dpctl::tensor::usm_ndarray coeff_matrix, + dpctl::tensor::usm_ndarray hand_sides, + const std::vector &depends); + +extern void init_gesv_dispatch_vector(void); +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 97b67d59e24e..e7c419942cdb 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -30,6 +30,7 @@ #include #include +#include "gesv.hpp" #include "heevd.hpp" #include "syevd.hpp" @@ -40,6 +41,7 @@ namespace py = pybind11; void init_dispatch_vectors(void) { lapack_ext::init_syevd_dispatch_vector(); + lapack_ext::init_gesv_dispatch_vector(); } // populate dispatch tables @@ -66,4 +68,12 @@ PYBIND11_MODULE(_lapack_impl, m) py::arg("sycl_queue"), py::arg("jobz"), py::arg("upper_lower"), py::arg("eig_vecs"), py::arg("eig_vals"), py::arg("depends") = py::list()); + + m.def("_gesv", &lapack_ext::gesv, + "Call `gesv` from OneMKL LAPACK library to return " + "solution to the system of linear equations with a square " + "coefficient matrix A" + "and multiple right-hand sides", + py::arg("sycl_queue"), py::arg("coeff_matrix"), py::arg("hand_sides"), + py::arg("depends") = py::list()); } diff --git a/dpnp/backend/extensions/lapack/types_matrix.hpp b/dpnp/backend/extensions/lapack/types_matrix.hpp index 3cab18d3c63d..9d0a55c41ccf 100644 --- a/dpnp/backend/extensions/lapack/types_matrix.hpp +++ b/dpnp/backend/extensions/lapack/types_matrix.hpp @@ -80,6 +80,32 @@ struct SyevdTypePairSupportFactory // fall-through dpctl_td_ns::NotDefinedEntry>::is_defined; }; + +/** + * @brief A factory to define pairs of supported types for which + * MKL LAPACK library provides support in oneapi::mkl::lapack::gesv + * function. + * + * @tparam T Type of array containing input matrix A and an output arrays with + * coefficient matrix and hand sides. + */ +template +struct GesvTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; } // namespace types } // namespace lapack } // namespace ext diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index 3addcdc32585..a089473fffb0 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -47,7 +47,7 @@ from dpnp.dpnp_utils import * from dpnp.linalg.dpnp_algo_linalg import * -from .dpnp_utils_linalg import dpnp_eigh +from .dpnp_utils_linalg import dpnp_eigh, dpnp_solve __all__ = [ "cholesky", @@ -62,6 +62,7 @@ "multi_dot", "norm", "qr", + "solve", "svd", ] @@ -498,6 +499,59 @@ def qr(x1, mode="reduced"): return call_origin(numpy.linalg.qr, x1, mode) +def solve(a, b): + """ + Solve a linear matrix equation, or system of linear scalar equations. + + For full documentation refer to :obj:`numpy.linalg.solve`. + + Returns + ------- + out : {(…, M,), (…, M, K)} dpnp.ndarray + Solution to the system ax = b. Returned shape is identical to b. + + Limitations + ----------- + Parameter `a` is supported as :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`. + Input array data types are limited by supported DPNP :ref:`Data types`. + + See Also + -------- + :obj:`dpnp.dot` : Returns the dot product of two arrays. + + Examples + -------- + >>> import dpnp as dp + >>> a = dp.array([[1, 2], [3, 5]]) + >>> b = dp.array([1, 2]) + >>> x = dp.linalg.solve(a, b) + >>> x + array([-1., 1.]) + + """ + if not dpnp.is_supported_array_type(a): + raise TypeError( + "An array must be any of supported type, but got {}".format(type(a)) + ) + + if not dpnp.is_supported_array_type(b): + raise TypeError( + "An array must be any of supported type, but got {}".format(type(b)) + ) + + if a.ndim < 2: + raise ValueError( + f"{a.ndim}-dimensional array given. Array must be " + "at least two-dimensional" + ) + + m, n = a.shape[-2:] + if m != n: + raise ValueError("Last 2 dimensions of the array must be square") + + return dpnp_solve(a, b) + + def svd(x1, full_matrices=True, compute_uv=True, hermitian=False): """ Singular Value Decomposition. diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index e818835ecbee..baca6d43ce3e 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -27,6 +27,7 @@ # ***************************************************************************** +import dpctl import dpctl.tensor._tensor_impl as ti import dpnp @@ -164,3 +165,73 @@ def dpnp_eigh(a, UPLO): ht_copy_ev.wait() return w, out_v + + +def dpnp_solve(a, b): + """ + dpnp_solve(a, b) + + Return the the solution to the system of linear equations with + a square coefficient matrix `a` and multiple right-hand sides `b`. + + """ + + a_usm_arr = dpnp.get_usm_ndarray(a) + b_usm_arr = dpnp.get_usm_ndarray(b) + + b_order = "C" if b.flags.c_contiguous else "F" + + if a.dtype != b.dtype: + raise ValueError("a and b must be of the same type") + + exec_q = dpctl.utils.get_execution_queue((a.sycl_queue, b.sycl_queue)) + if exec_q is None: + raise ValueError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + + if dpnp.issubdtype(a.dtype, dpnp.floating): + res_type = ( + a.dtype if exec_q.sycl_device.has_aspect_fp64 else dpnp.float32 + ) + elif dpnp.issubdtype(a.dtype, dpnp.complexfloating): + res_type = ( + a.dtype if exec_q.sycl_device.has_aspect_fp64 else dpnp.complex64 + ) + else: + res_type = ( + dpnp.float64 if exec_q.sycl_device.has_aspect_fp64 else dpnp.float32 + ) + + a_f = dpnp.empty_like(a, order="F", dtype=res_type) + b_f = dpnp.empty_like(b, order="F", dtype=res_type) + + a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, dst=a_f.get_array(), sycl_queue=a.sycl_queue + ) + b_ht_copy_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=b_usm_arr, dst=b_f.get_array(), sycl_queue=b.sycl_queue + ) + + lapack_ev = li._gesv( + exec_q, a_f.get_array(), b_f.get_array(), [a_copy_ev, b_copy_ev] + ) + + if b_order != "F": + out_v = dpnp.empty_like(b_f, order=b_order) + ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=b_f.get_array(), + dst=out_v.get_array(), + sycl_queue=b.sycl_queue, + depends=[lapack_ev], + ) + ht_copy_out_ev.wait() + else: + out_v = b_f + + lapack_ev.wait() + b_ht_copy_ev.wait() + a_ht_copy_ev.wait() + + return out_v From 51663d0d038a6c2dac1ba2d79635b88043bcbdfc Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 6 Oct 2023 13:58:29 +0200 Subject: [PATCH 02/14] Check validity of input array shapes --- dpnp/linalg/dpnp_iface_linalg.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index a089473fffb0..ae007ee9e823 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -549,6 +549,15 @@ def solve(a, b): if m != n: raise ValueError("Last 2 dimensions of the array must be square") + if not ( + (a.ndim == b.ndim or a.ndim == b.ndim + 1) + and a.shape[:-1] == b.shape[: a.ndim - 1] + ): + raise ValueError( + "a must have (..., M, M) shape and b must have (..., M) " + "or (..., M, K)" + ) + return dpnp_solve(a, b) From 00b7041e39f024e85b1b46b832d2c916b0f1e44d Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 6 Oct 2023 14:16:04 +0200 Subject: [PATCH 03/14] Add logic for a.ndim > 2 --- dpnp/linalg/dpnp_utils_linalg.py | 132 ++++++++++++++++++++++++------- 1 file changed, 104 insertions(+), 28 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index baca6d43ce3e..48940c15a193 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -29,6 +29,7 @@ import dpctl import dpctl.tensor._tensor_impl as ti +from numpy import prod import dpnp import dpnp.backend.extensions.lapack._lapack_impl as li @@ -204,34 +205,109 @@ def dpnp_solve(a, b): dpnp.float64 if exec_q.sycl_device.has_aspect_fp64 else dpnp.float32 ) - a_f = dpnp.empty_like(a, order="F", dtype=res_type) - b_f = dpnp.empty_like(b, order="F", dtype=res_type) - - a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a_usm_arr, dst=a_f.get_array(), sycl_queue=a.sycl_queue - ) - b_ht_copy_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=b_usm_arr, dst=b_f.get_array(), sycl_queue=b.sycl_queue - ) - - lapack_ev = li._gesv( - exec_q, a_f.get_array(), b_f.get_array(), [a_copy_ev, b_copy_ev] - ) - - if b_order != "F": - out_v = dpnp.empty_like(b_f, order=b_order) - ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( - src=b_f.get_array(), - dst=out_v.get_array(), - sycl_queue=b.sycl_queue, - depends=[lapack_ev], - ) - ht_copy_out_ev.wait() + if a.ndim > 2: + reshape = False + orig_shape_b = b.shape + if a.ndim > 3: + # get 3d input arrays by reshape + if a.ndim == b.ndim: + b = b.reshape(prod(b.shape[:-2]), b.shape[-2], b.shape[-1]) + else: + b = b.reshape(prod(b.shape[:-1]), b.shape[-1]) + + a = a.reshape(prod(a.shape[:-2]), a.shape[-2], a.shape[-1]) + + a_usm_arr = dpnp.get_usm_ndarray(a) + b_usm_arr = dpnp.get_usm_ndarray(b) + reshape = True + + op_count = a.shape[0] + if op_count == 0: + return dpnp.empty_like(b, dtype=res_type) + + coeff_vecs = [None] * op_count + val_vecs = [None] * op_count + a_ht_copy_ev = [None] * op_count + b_ht_copy_ev = [None] * op_count + ht_lapack_ev = [None] * op_count + + for i in range(op_count): + # oneMKL LAPACK assumes fortran-like array as input, so + # allocate a memory with 'F' order for dpnp array of coefficient matrix + # and multiple right-hand sides + coeff_vecs[i] = dpnp.empty_like(a[i], order="F", dtype=res_type) + val_vecs[i] = dpnp.empty_like(b[i], order="F", dtype=res_type) + + # use DPCTL tensor function to fill the array of coefficient matrix + # and multiple right-hand sides with content of input array + a_ht_copy_ev[i], a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr[i], + dst=coeff_vecs[i].get_array(), + sycl_queue=a.sycl_queue, + ) + b_ht_copy_ev[i], b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=b_usm_arr[i], + dst=val_vecs[i].get_array(), + sycl_queue=b.sycl_queue, + ) + + # call LAPACK extension function to get the solution of the system of linear + # equations with a portion of the coefficients square matrix + ht_lapack_ev[i] = li._gesv( + exec_q, + coeff_vecs[i].get_array(), + val_vecs[i].get_array(), + depends=[a_copy_ev, b_copy_ev], + ) + + for i in range(op_count): + ht_lapack_ev[i].wait() + b_ht_copy_ev[i].wait() + a_ht_copy_ev[i].wait() + + # combine the list of solutions into a single array + out_v = dpnp.array(val_vecs, order=b_order) + if reshape: + # shape of the out_t must be equal to the shape of the right-hand sides + out_v = out_v.reshape(orig_shape_b) + return out_v else: - out_v = b_f + # oneMKL LAPACK assumes fortran-like array as input, so + # allocate a memory with 'F' order for dpnp array of coefficient matrix + # and multiple right-hand sides + a_f = dpnp.empty_like(a, order="F", dtype=res_type) + b_f = dpnp.empty_like(b, order="F", dtype=res_type) + + # use DPCTL tensor function to fill the array of coefficient matrix + # and multiple right-hand sides with content of input array + a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, dst=a_f.get_array(), sycl_queue=a.sycl_queue + ) + b_ht_copy_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=b_usm_arr, dst=b_f.get_array(), sycl_queue=b.sycl_queue + ) + + # call LAPACK extension function to get the solution of the system of linear + # equations with the coefficients square matrix + lapack_ev = li._gesv( + exec_q, a_f.get_array(), b_f.get_array(), [a_copy_ev, b_copy_ev] + ) + + if b_order != "F": + # need to align order of the result of solutions with the right-hand sides + out_v = dpnp.empty_like(b_f, order=b_order) + ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( + src=b_f.get_array(), + dst=out_v.get_array(), + sycl_queue=b.sycl_queue, + depends=[lapack_ev], + ) + ht_copy_out_ev.wait() + else: + out_v = b_f - lapack_ev.wait() - b_ht_copy_ev.wait() - a_ht_copy_ev.wait() + lapack_ev.wait() + b_ht_copy_ev.wait() + a_ht_copy_ev.wait() - return out_v + return out_v From e4e15ad2d2faad28cf5169ce5bda21b9c7118e7c Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 6 Oct 2023 14:19:50 +0200 Subject: [PATCH 04/14] Raise value_error if coeff_matrix_nd != 2 in gesv --- dpnp/backend/extensions/lapack/gesv.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/dpnp/backend/extensions/lapack/gesv.cpp b/dpnp/backend/extensions/lapack/gesv.cpp index 277886b78954..d84f51a0005e 100644 --- a/dpnp/backend/extensions/lapack/gesv.cpp +++ b/dpnp/backend/extensions/lapack/gesv.cpp @@ -135,7 +135,13 @@ sycl::event gesv(sycl::queue exec_q, dpctl::tensor::usm_ndarray hand_sides, const std::vector &depends) { - // check ndim hand_sides + const int coeff_matrix_nd = coeff_matrix.get_ndim(); + + if (coeff_matrix_nd != 2) { + throw py::value_error( + "Unexpected ndim=" + std::to_string(coeff_matrix_nd) + + " of an input array with coefficients"); + } const py::ssize_t *coeff_matrix_shape = coeff_matrix.get_shape_raw(); const py::ssize_t *hand_sides_shape = hand_sides.get_shape_raw(); @@ -160,16 +166,10 @@ sycl::event gesv(sycl::queue exec_q, } bool is_coeff_matrix_f_contig = coeff_matrix.is_f_contiguous(); - bool is_hand_sides_c_contig = hand_sides.is_c_contiguous(); if (!is_coeff_matrix_f_contig) { throw py::value_error("An array with coefficients " "must be F-contiguous"); } - else if (!is_hand_sides_c_contig) { - throw py::value_error( - "An array with the output solutions of the coefficient matrix" - "must be C-contiguous"); - } auto array_types = dpctl_td_ns::usm_ndarray_types(); int coeff_matrix_type_id = From e762c66fa0fecd718868d8df388f0274f64c63fc Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 6 Oct 2023 14:28:38 +0200 Subject: [PATCH 05/14] Add cupy tests for dpnp.linalg.solve() --- .../cupy/linalg_tests/test_solve.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 tests/third_party/cupy/linalg_tests/test_solve.py diff --git a/tests/third_party/cupy/linalg_tests/test_solve.py b/tests/third_party/cupy/linalg_tests/test_solve.py new file mode 100644 index 000000000000..17cc11be461e --- /dev/null +++ b/tests/third_party/cupy/linalg_tests/test_solve.py @@ -0,0 +1,80 @@ +import unittest + +import numpy +import pytest + +import dpnp as cupy +from tests.helper import has_support_aspect64 +from tests.third_party.cupy import testing + + +@testing.parameterize( + *testing.product( + { + "order": ["C", "F"], + } + ) +) +class TestSolve(unittest.TestCase): + # TODO: add get_batched_gesv_limit + # def setUp(self): + # if self.batched_gesv_limit is not None: + # self.old_limit = get_batched_gesv_limit() + # set_batched_gesv_limit(self.batched_gesv_limit) + + # def tearDown(self): + # if self.batched_gesv_limit is not None: + # set_batched_gesv_limit(self.old_limit) + + @testing.for_dtypes("ifdFD") + @testing.numpy_cupy_allclose( + atol=1e-3, contiguous_check=False, type_check=has_support_aspect64() + ) + def check_x(self, a_shape, b_shape, xp, dtype): + a = testing.shaped_random(a_shape, xp, dtype=dtype, seed=0, scale=20) + b = testing.shaped_random(b_shape, xp, dtype=dtype, seed=1) + a = a.copy(order=self.order) + b = b.copy(order=self.order) + a_copy = a.copy() + b_copy = b.copy() + result = xp.linalg.solve(a, b) + numpy.testing.assert_array_equal(a_copy, a) + numpy.testing.assert_array_equal(b_copy, b) + return result + + def test_solve(self): + self.check_x((4, 4), (4,)) + self.check_x((5, 5), (5, 2)) + self.check_x((2, 4, 4), (2, 4)) + self.check_x((2, 5, 5), (2, 5, 2)) + self.check_x((2, 3, 2, 2), (2, 3, 2)) + self.check_x((2, 3, 3, 3), (2, 3, 3, 2)) + self.check_x((0, 0), (0,)) + self.check_x((0, 0), (0, 2)) + self.check_x((0, 2, 2), (0, 2)) + self.check_x((0, 2, 2), (0, 2, 3)) + + def check_shape(self, a_shape, b_shape, error_type): + for xp in (numpy, cupy): + a = xp.random.rand(*a_shape) + b = xp.random.rand(*b_shape) + with pytest.raises(error_type): + xp.linalg.solve(a, b) + + # dpnp.linalg.solve() raises RuntimeError instead of numpy.linalg.LinAlgError + # @testing.numpy_cupy_allclose() + # def test_solve_singular_empty(self, xp): + # a = xp.zeros((3, 3)) # singular + # b = xp.empty((3, 0)) # nrhs = 0 + # # LinAlgError("Singular matrix") is not raised + # return xp.linalg.solve(a, b) + + # dpnp.linalg.solve() raises RuntimeError instead of numpy.linalg.LinAlgError + def test_invalid_shape(self): + # self.check_shape((2, 3), (4,), numpy.linalg.LinAlgError) + self.check_shape((3, 3), (2,), ValueError) + self.check_shape((3, 3), (2, 2), ValueError) + # self.check_shape((3, 3, 4), (3,), numpy.linalg.LinAlgError) + self.check_shape((2, 3, 3), (3,), ValueError) + self.check_shape((3, 3), (0,), ValueError) + # self.check_shape((0, 3, 4), (3,), numpy.linalg.LinAlgError) From a0f76d529d12c57aea5fddeae0066bb655edb08d Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 10 Oct 2023 22:34:46 +0200 Subject: [PATCH 06/14] Add LinAlgError exception and extend error handling for mkl::lapack::gesv --- dpnp/backend/extensions/lapack/gesv.cpp | 46 +++++++++++++++++-- dpnp/backend/extensions/lapack/lapack_py.cpp | 4 ++ .../extensions/lapack/linalg_exceptions.hpp | 29 ++++++++++++ 3 files changed, 75 insertions(+), 4 deletions(-) create mode 100644 dpnp/backend/extensions/lapack/linalg_exceptions.hpp diff --git a/dpnp/backend/extensions/lapack/gesv.cpp b/dpnp/backend/extensions/lapack/gesv.cpp index d84f51a0005e..87a9941df879 100644 --- a/dpnp/backend/extensions/lapack/gesv.cpp +++ b/dpnp/backend/extensions/lapack/gesv.cpp @@ -30,6 +30,7 @@ #include "utils/type_utils.hpp" #include "gesv.hpp" +#include "linalg_exceptions.hpp" #include "types_matrix.hpp" #include "dpnp_utils.hpp" @@ -102,14 +103,51 @@ static sycl::event gesv_impl(sycl::queue exec_q, // routine for storing intermediate results. scratchpad_size, depends); } catch (mkl_lapack::exception const &e) { - error_msg - << "Unexpected MKL exception caught during gesv() call:\nreason: " - << e.what() << "\ninfo: " << e.info(); info = e.info(); + + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info > 0) { + T host_U; + exec_q.memcpy(&host_U, &a[(info - 1) * lda + info - 1], sizeof(T)) + .wait(); + + using ThresholdType = typename std::conditional< + std::is_same::value, float, + typename std::conditional< + std::is_same::value, double, + typename std::conditional< + std::is_same>::value, float, + double>::type>::type>::type; + + const auto threshold = + std::numeric_limits::epsilon() * 100; + if (std::abs(host_U) < threshold) { + sycl::free(scratchpad, exec_q); + throw LinAlgError("The input coefficient matrix is singular."); + } + else { + error_msg << "Unexpected MKL exception caught during gesv() " + "call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + } + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail(); + } + else { + error_msg << "Unexpected MKL exception caught during gesv() " + "call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + } } catch (sycl::exception const &e) { error_msg << "Unexpected SYCL exception caught during gesv() call:\n" << e.what(); - info = -1; + info = -11; } if (info != 0) // an unexected error occurs diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index e7c419942cdb..8a9c07c20c12 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -32,6 +32,7 @@ #include "gesv.hpp" #include "heevd.hpp" +#include "linalg_exceptions.hpp" #include "syevd.hpp" namespace lapack_ext = dpnp::backend::ext::lapack; @@ -52,6 +53,9 @@ void init_dispatch_tables(void) PYBIND11_MODULE(_lapack_impl, m) { + py::register_local_exception(m, "LinAlgError", + PyExc_ValueError); + init_dispatch_vectors(); init_dispatch_tables(); diff --git a/dpnp/backend/extensions/lapack/linalg_exceptions.hpp b/dpnp/backend/extensions/lapack/linalg_exceptions.hpp new file mode 100644 index 000000000000..9963b2aa834e --- /dev/null +++ b/dpnp/backend/extensions/lapack/linalg_exceptions.hpp @@ -0,0 +1,29 @@ +#pragma once +#include +#include + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +class LinAlgError : public std::exception +{ +public: + explicit LinAlgError(const char *message) : msg_(message) {} + + const char *what() const noexcept override + { + return msg_.c_str(); + } + +private: + std::string msg_; +}; +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp From 89f47c75c42c239963e4431a72199b2a8b3fda97 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 10 Oct 2023 22:36:00 +0200 Subject: [PATCH 07/14] Update test_solve --- .../cupy/linalg_tests/test_solve.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/tests/third_party/cupy/linalg_tests/test_solve.py b/tests/third_party/cupy/linalg_tests/test_solve.py index 17cc11be461e..68c4a0ea1b60 100644 --- a/tests/third_party/cupy/linalg_tests/test_solve.py +++ b/tests/third_party/cupy/linalg_tests/test_solve.py @@ -61,20 +61,24 @@ def check_shape(self, a_shape, b_shape, error_type): with pytest.raises(error_type): xp.linalg.solve(a, b) - # dpnp.linalg.solve() raises RuntimeError instead of numpy.linalg.LinAlgError - # @testing.numpy_cupy_allclose() - # def test_solve_singular_empty(self, xp): - # a = xp.zeros((3, 3)) # singular - # b = xp.empty((3, 0)) # nrhs = 0 - # # LinAlgError("Singular matrix") is not raised - # return xp.linalg.solve(a, b) + def test_solve_singular_empty(self): + for xp in (numpy, cupy): + a = xp.zeros((3, 3)) # singular + b = xp.empty((3, 0)) # nrhs = 0 + with pytest.raises((numpy.linalg.LinAlgError, ValueError)): + xp.linalg.solve(a, b) - # dpnp.linalg.solve() raises RuntimeError instead of numpy.linalg.LinAlgError + # dpnp.linalg.solve() raises a LinAlgError which is defined + # through a ValueError in the C++ bindings using pybind11 def test_invalid_shape(self): - # self.check_shape((2, 3), (4,), numpy.linalg.LinAlgError) + self.check_shape((2, 3), (4,), (numpy.linalg.LinAlgError, ValueError)) self.check_shape((3, 3), (2,), ValueError) self.check_shape((3, 3), (2, 2), ValueError) - # self.check_shape((3, 3, 4), (3,), numpy.linalg.LinAlgError) + self.check_shape( + (3, 3, 4), (3,), (numpy.linalg.LinAlgError, ValueError) + ) self.check_shape((2, 3, 3), (3,), ValueError) self.check_shape((3, 3), (0,), ValueError) - # self.check_shape((0, 3, 4), (3,), numpy.linalg.LinAlgError) + self.check_shape( + (0, 3, 4), (3,), (numpy.linalg.LinAlgError, ValueError) + ) From 159c46019189de5900b8abb7d8344bc239dd7ac6 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 10 Oct 2023 22:45:24 +0200 Subject: [PATCH 08/14] Add test_solve to test scope --- .github/workflows/conda-package.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index 365d8490ec67..f0c28edaaa58 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -29,6 +29,7 @@ env: test_usm_type.py third_party/cupy/core_tests/test_ndarray_complex_ops.py third_party/cupy/linalg_tests/test_product.py + third_party/cupy/linalg_tests/test_solve.py third_party/cupy/logic_tests/test_comparison.py third_party/cupy/logic_tests/test_truth.py third_party/cupy/manipulation_tests/test_basic.py From 2ad3aa816867f9ef842a50e757d6c1306bdf7534 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 12 Oct 2023 16:39:06 +0200 Subject: [PATCH 09/14] Fix getting nrhs to avoid CPU falling tests --- dpnp/backend/extensions/lapack/gesv.cpp | 5 ++--- dpnp/linalg/dpnp_utils_linalg.py | 3 +++ tests/third_party/cupy/linalg_tests/test_solve.py | 10 ++++++++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/dpnp/backend/extensions/lapack/gesv.cpp b/dpnp/backend/extensions/lapack/gesv.cpp index 87a9941df879..5b8b781d7d4b 100644 --- a/dpnp/backend/extensions/lapack/gesv.cpp +++ b/dpnp/backend/extensions/lapack/gesv.cpp @@ -174,6 +174,7 @@ sycl::event gesv(sycl::queue exec_q, const std::vector &depends) { const int coeff_matrix_nd = coeff_matrix.get_ndim(); + const int hand_sides_nd = hand_sides.get_ndim(); if (coeff_matrix_nd != 2) { throw py::value_error( @@ -188,8 +189,6 @@ sycl::event gesv(sycl::queue exec_q, throw py::value_error("The input coefficients array must be square "); } - // elif check shape coeff_matrix and hand_sides - // check compatibility of execution queue and allocation queue if (!dpctl::utils::queues_are_compatible(exec_q, {coeff_matrix, hand_sides})) { @@ -230,8 +229,8 @@ sycl::event gesv(sycl::queue exec_q, char *hand_sides_data = hand_sides.get_data(); const std::int64_t n = coeff_matrix_shape[0]; - const std::int64_t nrhs = hand_sides_shape[0]; const std::int64_t m = hand_sides_shape[0]; + const std::int64_t nrhs = (hand_sides_nd > 1) ? hand_sides_shape[1] : 1; const std::int64_t lda = std::max(1UL, n); const std::int64_t ldb = std::max(1UL, m); diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 48940c15a193..d6bcd92fb788 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -205,6 +205,9 @@ def dpnp_solve(a, b): dpnp.float64 if exec_q.sycl_device.has_aspect_fp64 else dpnp.float32 ) + if b.size == 0: + return dpnp.empty(b.shape, dtype=res_type) + if a.ndim > 2: reshape = False orig_shape_b = b.shape diff --git a/tests/third_party/cupy/linalg_tests/test_solve.py b/tests/third_party/cupy/linalg_tests/test_solve.py index 68c4a0ea1b60..6ebb1f213b69 100644 --- a/tests/third_party/cupy/linalg_tests/test_solve.py +++ b/tests/third_party/cupy/linalg_tests/test_solve.py @@ -65,8 +65,14 @@ def test_solve_singular_empty(self): for xp in (numpy, cupy): a = xp.zeros((3, 3)) # singular b = xp.empty((3, 0)) # nrhs = 0 - with pytest.raises((numpy.linalg.LinAlgError, ValueError)): - xp.linalg.solve(a, b) + # numpy <= 1.24.* raises LinAlgError when b.size == 0 + # numpy >= 1.25 returns an empty array + if xp == numpy: + with pytest.raises(numpy.linalg.LinAlgError): + xp.linalg.solve(a, b) + else: + result = xp.linalg.solve(a, b) + assert result.size == 0 # dpnp.linalg.solve() raises a LinAlgError which is defined # through a ValueError in the C++ bindings using pybind11 From ac02cf2368b43da5f5e800bc69252d32abe12140 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 12 Oct 2023 17:00:44 +0200 Subject: [PATCH 10/14] Add test_solve to test_sycl_queue --- tests/test_sycl_queue.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index a11880d9d444..b8d94abb54b9 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -1194,3 +1194,27 @@ def test_take(device): result_queue = result.get_array().sycl_queue assert_sycl_queue_equal(result_queue, expected_queue) + + +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_solve(device): + x = [[1.0, 2.0], [3.0, 5.0]] + y = [1.0, 2.0] + + numpy_x = numpy.array(x) + numpy_y = numpy.array(y) + dpnp_x = dpnp.array(x, device=device) + dpnp_y = dpnp.array(y, device=device) + + result = dpnp.linalg.solve(dpnp_x, dpnp_y) + expected = numpy.linalg.solve(numpy_x, numpy_y) + assert_allclose(expected, result, rtol=1e-06) + + result_queue = result.sycl_queue + + assert_sycl_queue_equal(result_queue, dpnp_x.sycl_queue) + assert_sycl_queue_equal(result_queue, dpnp_y.sycl_queue) From 035d9831d3a31898d15242dd11fd6d0db03a5d63 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 13 Oct 2023 00:15:41 +0200 Subject: [PATCH 11/14] Add more tests for solve() --- tests/test_linalg.py | 39 ++++++++++++++++++- .../cupy/linalg_tests/test_solve.py | 2 +- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 2327ca2e9401..2e9ea5986d8c 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -1,7 +1,7 @@ import dpctl import numpy import pytest -from numpy.testing import assert_allclose, assert_array_equal +from numpy.testing import assert_allclose, assert_array_equal, assert_raises import dpnp as inp @@ -446,3 +446,40 @@ def test_svd(type, shape): assert_allclose( inp.asnumpy(dpnp_vt)[i, :], np_vt[i, :], rtol=tol, atol=tol ) + + +class TestSolve: + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_solve(self, dtype): + a_np = numpy.array([[1, 0.5], [0.5, 1]]) + a_dp = inp.array(a_np, dtype=dtype) + + expected = numpy.linalg.solve(a_np, a_np) + result = inp.linalg.solve(a_dp, a_dp) + + assert_allclose(expected, result, rtol=1e-06) + + def test_solve_strides(self): + a_np = numpy.array( + [ + [2, 3, 1, 4, 5], + [5, 6, 7, 8, 9], + [9, 7, 7, 2, 3], + [1, 4, 5, 1, 8], + [8, 9, 8, 5, 3], + ] + ) + b_np = numpy.array([5, 8, 9, 2, 1]) + + a_dp = inp.array(a_np) + b_dp = inp.array(b_np) + + # positive strides + expected = numpy.linalg.solve(a_np[::2, ::2], b_np[::2]) + result = inp.linalg.solve(a_dp[::2, ::2], b_dp[::2]) + assert_allclose(expected, result, rtol=1e-06) + + # negative strides + expected = numpy.linalg.solve(a_np[::-2, ::-2], b_np[::-2]) + result = inp.linalg.solve(a_dp[::-2, ::-2], b_dp[::-2]) + assert_allclose(expected, result, rtol=1e-06) diff --git a/tests/third_party/cupy/linalg_tests/test_solve.py b/tests/third_party/cupy/linalg_tests/test_solve.py index 6ebb1f213b69..2af6362377ca 100644 --- a/tests/third_party/cupy/linalg_tests/test_solve.py +++ b/tests/third_party/cupy/linalg_tests/test_solve.py @@ -67,7 +67,7 @@ def test_solve_singular_empty(self): b = xp.empty((3, 0)) # nrhs = 0 # numpy <= 1.24.* raises LinAlgError when b.size == 0 # numpy >= 1.25 returns an empty array - if xp == numpy: + if xp == numpy and testing.numpy_satisfies("<1.25.0"): with pytest.raises(numpy.linalg.LinAlgError): xp.linalg.solve(a, b) else: From f80ba4b6844d6a6672082c9923766b3e051d5edd Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 13 Oct 2023 14:01:07 +0200 Subject: [PATCH 12/14] Register a LinAlgError in dpnp.linalg submodule --- dpnp/backend/extensions/lapack/lapack_py.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 8a9c07c20c12..5ae26d595cb8 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -53,8 +53,9 @@ void init_dispatch_tables(void) PYBIND11_MODULE(_lapack_impl, m) { - py::register_local_exception(m, "LinAlgError", - PyExc_ValueError); + py::module_ linalg_module = py::module_::import("dpnp.linalg"); + py::register_exception( + linalg_module, "LinAlgError", PyExc_ValueError); init_dispatch_vectors(); init_dispatch_tables(); From 9a3db74320da194594f8fc69dea1d99f85ff9708 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 13 Oct 2023 15:11:04 +0200 Subject: [PATCH 13/14] Raise dpnp.linalg.LinAlgError in solve() --- dpnp/linalg/dpnp_iface_linalg.py | 8 +++-- tests/test_linalg.py | 4 +-- .../cupy/linalg_tests/test_solve.py | 33 ++++++++++--------- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index ae007ee9e823..b9a0d889e046 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -540,20 +540,22 @@ def solve(a, b): ) if a.ndim < 2: - raise ValueError( + raise dpnp.linalg.LinAlgError( f"{a.ndim}-dimensional array given. Array must be " "at least two-dimensional" ) m, n = a.shape[-2:] if m != n: - raise ValueError("Last 2 dimensions of the array must be square") + raise dpnp.linalg.LinAlgError( + "Last 2 dimensions of the array must be square" + ) if not ( (a.ndim == b.ndim or a.ndim == b.ndim + 1) and a.shape[:-1] == b.shape[: a.ndim - 1] ): - raise ValueError( + raise dpnp.linalg.LinAlgError( "a must have (..., M, M) shape and b must have (..., M) " "or (..., M, K)" ) diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 2e9ea5986d8c..7782fde2431c 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -477,9 +477,9 @@ def test_solve_strides(self): # positive strides expected = numpy.linalg.solve(a_np[::2, ::2], b_np[::2]) result = inp.linalg.solve(a_dp[::2, ::2], b_dp[::2]) - assert_allclose(expected, result, rtol=1e-06) + assert_allclose(expected, result, rtol=1e-05) # negative strides expected = numpy.linalg.solve(a_np[::-2, ::-2], b_np[::-2]) result = inp.linalg.solve(a_dp[::-2, ::-2], b_dp[::-2]) - assert_allclose(expected, result, rtol=1e-06) + assert_allclose(expected, result, rtol=1e-05) diff --git a/tests/third_party/cupy/linalg_tests/test_solve.py b/tests/third_party/cupy/linalg_tests/test_solve.py index 2af6362377ca..5de4aacc789b 100644 --- a/tests/third_party/cupy/linalg_tests/test_solve.py +++ b/tests/third_party/cupy/linalg_tests/test_solve.py @@ -54,8 +54,8 @@ def test_solve(self): self.check_x((0, 2, 2), (0, 2)) self.check_x((0, 2, 2), (0, 2, 3)) - def check_shape(self, a_shape, b_shape, error_type): - for xp in (numpy, cupy): + def check_shape(self, a_shape, b_shape, error_types): + for xp, error_type in error_types.items(): a = xp.random.rand(*a_shape) b = xp.random.rand(*b_shape) with pytest.raises(error_type): @@ -74,17 +74,20 @@ def test_solve_singular_empty(self): result = xp.linalg.solve(a, b) assert result.size == 0 - # dpnp.linalg.solve() raises a LinAlgError which is defined - # through a ValueError in the C++ bindings using pybind11 def test_invalid_shape(self): - self.check_shape((2, 3), (4,), (numpy.linalg.LinAlgError, ValueError)) - self.check_shape((3, 3), (2,), ValueError) - self.check_shape((3, 3), (2, 2), ValueError) - self.check_shape( - (3, 3, 4), (3,), (numpy.linalg.LinAlgError, ValueError) - ) - self.check_shape((2, 3, 3), (3,), ValueError) - self.check_shape((3, 3), (0,), ValueError) - self.check_shape( - (0, 3, 4), (3,), (numpy.linalg.LinAlgError, ValueError) - ) + linalg_errors = { + numpy: numpy.linalg.LinAlgError, + cupy: cupy.linalg.LinAlgError, + } + value_errors = { + numpy: ValueError, + cupy: ValueError, + } + + self.check_shape((2, 3), (4,), linalg_errors) + self.check_shape((3, 3), (2,), value_errors) + self.check_shape((3, 3), (2, 2), value_errors) + self.check_shape((3, 3, 4), (3,), linalg_errors) + self.check_shape((2, 3, 3), (3,), value_errors) + self.check_shape((3, 3), (0,), value_errors) + self.check_shape((0, 3, 4), (3,), linalg_errors) From a8b4fec4d8d295bd2c1d8aecfef2f9a8bd51f1f4 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Sun, 15 Oct 2023 20:01:50 +0200 Subject: [PATCH 14/14] Small changes to the docstrings --- dpnp/backend/extensions/lapack/gesv.cpp | 53 +++++++++++-------- dpnp/backend/extensions/lapack/gesv.hpp | 2 +- dpnp/backend/extensions/lapack/lapack_py.cpp | 10 ++-- .../extensions/lapack/linalg_exceptions.hpp | 25 +++++++++ dpnp/linalg/dpnp_iface_linalg.py | 6 +-- dpnp/linalg/dpnp_utils_linalg.py | 32 ++++++----- 6 files changed, 83 insertions(+), 45 deletions(-) diff --git a/dpnp/backend/extensions/lapack/gesv.cpp b/dpnp/backend/extensions/lapack/gesv.cpp index 5b8b781d7d4b..94b4f326eee1 100644 --- a/dpnp/backend/extensions/lapack/gesv.cpp +++ b/dpnp/backend/extensions/lapack/gesv.cpp @@ -170,67 +170,74 @@ static sycl::event gesv_impl(sycl::queue exec_q, sycl::event gesv(sycl::queue exec_q, dpctl::tensor::usm_ndarray coeff_matrix, - dpctl::tensor::usm_ndarray hand_sides, + dpctl::tensor::usm_ndarray dependent_vals, const std::vector &depends) { const int coeff_matrix_nd = coeff_matrix.get_ndim(); - const int hand_sides_nd = hand_sides.get_ndim(); + const int dependent_vals_nd = dependent_vals.get_ndim(); if (coeff_matrix_nd != 2) { - throw py::value_error( - "Unexpected ndim=" + std::to_string(coeff_matrix_nd) + - " of an input array with coefficients"); + throw py::value_error("The coefficient matrix has ndim=" + + std::to_string(coeff_matrix_nd) + + ", but a 2-dimensional array is expected."); } const py::ssize_t *coeff_matrix_shape = coeff_matrix.get_shape_raw(); - const py::ssize_t *hand_sides_shape = hand_sides.get_shape_raw(); + const py::ssize_t *dependent_vals_shape = dependent_vals.get_shape_raw(); if (coeff_matrix_shape[0] != coeff_matrix_shape[1]) { - throw py::value_error("The input coefficients array must be square "); + throw py::value_error("The coefficient matrix must be square," + " but got a shape of (" + + std::to_string(coeff_matrix_shape[0]) + ", " + + std::to_string(coeff_matrix_shape[1]) + ")."); } // check compatibility of execution queue and allocation queue if (!dpctl::utils::queues_are_compatible(exec_q, - {coeff_matrix, hand_sides})) { + {coeff_matrix, dependent_vals})) + { throw py::value_error( "Execution queue is not compatible with allocation queues"); } auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); - if (overlap(coeff_matrix, hand_sides)) { - throw py::value_error("Arrays with coefficients and hand sides are " - "overlapping segments of memory"); + if (overlap(coeff_matrix, dependent_vals)) { + throw py::value_error( + "The arrays of coefficients and dependent variables " + "are overlapping segments of memory"); } bool is_coeff_matrix_f_contig = coeff_matrix.is_f_contiguous(); if (!is_coeff_matrix_f_contig) { - throw py::value_error("An array with coefficients " + throw py::value_error("The coefficient matrix " "must be F-contiguous"); } auto array_types = dpctl_td_ns::usm_ndarray_types(); int coeff_matrix_type_id = array_types.typenum_to_lookup_id(coeff_matrix.get_typenum()); - int hand_sides_type_id = - array_types.typenum_to_lookup_id(hand_sides.get_typenum()); + int dependent_vals_type_id = + array_types.typenum_to_lookup_id(dependent_vals.get_typenum()); - if (coeff_matrix_type_id != hand_sides_type_id) { - throw py::value_error( - "Types of coefficients and hand sides are missmatched"); + if (coeff_matrix_type_id != dependent_vals_type_id) { + throw py::value_error("The types of the coefficient matrix and " + "dependent variables are mismatched"); } gesv_impl_fn_ptr_t gesv_fn = gesv_dispatch_vector[coeff_matrix_type_id]; if (gesv_fn == nullptr) { - throw py::value_error("No gesv implementation defined for a type of " - "coefficient matrix and hand sides"); + throw py::value_error( + "No gesv implementation defined for the provided type " + "of the coefficient matrix."); } char *coeff_matrix_data = coeff_matrix.get_data(); - char *hand_sides_data = hand_sides.get_data(); + char *dependent_vals_data = dependent_vals.get_data(); const std::int64_t n = coeff_matrix_shape[0]; - const std::int64_t m = hand_sides_shape[0]; - const std::int64_t nrhs = (hand_sides_nd > 1) ? hand_sides_shape[1] : 1; + const std::int64_t m = dependent_vals_shape[0]; + const std::int64_t nrhs = + (dependent_vals_nd > 1) ? dependent_vals_shape[1] : 1; const std::int64_t lda = std::max(1UL, n); const std::int64_t ldb = std::max(1UL, m); @@ -241,7 +248,7 @@ sycl::event gesv(sycl::queue exec_q, std::vector host_task_events; sycl::event gesv_ev = gesv_fn(exec_q, n, nrhs, coeff_matrix_data, lda, d_ipiv, - hand_sides_data, ldb, host_task_events, depends); + dependent_vals_data, ldb, host_task_events, depends); return gesv_ev; } diff --git a/dpnp/backend/extensions/lapack/gesv.hpp b/dpnp/backend/extensions/lapack/gesv.hpp index 6c84b08acaf1..3b9a099e9795 100644 --- a/dpnp/backend/extensions/lapack/gesv.hpp +++ b/dpnp/backend/extensions/lapack/gesv.hpp @@ -40,7 +40,7 @@ namespace lapack { extern sycl::event gesv(sycl::queue exec_q, dpctl::tensor::usm_ndarray coeff_matrix, - dpctl::tensor::usm_ndarray hand_sides, + dpctl::tensor::usm_ndarray dependent_vals, const std::vector &depends); extern void init_gesv_dispatch_vector(void); diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 5ae26d595cb8..feb96dc0772c 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -53,6 +53,7 @@ void init_dispatch_tables(void) PYBIND11_MODULE(_lapack_impl, m) { + // Register a custom LinAlgError exception in the dpnp.linalg submodule py::module_ linalg_module = py::module_::import("dpnp.linalg"); py::register_exception( linalg_module, "LinAlgError", PyExc_ValueError); @@ -76,9 +77,8 @@ PYBIND11_MODULE(_lapack_impl, m) m.def("_gesv", &lapack_ext::gesv, "Call `gesv` from OneMKL LAPACK library to return " - "solution to the system of linear equations with a square " - "coefficient matrix A" - "and multiple right-hand sides", - py::arg("sycl_queue"), py::arg("coeff_matrix"), py::arg("hand_sides"), - py::arg("depends") = py::list()); + "the solution of a system of linear equations with " + "a square coefficient matrix A and multiple dependent variables", + py::arg("sycl_queue"), py::arg("coeff_matrix"), + py::arg("dependent_vals"), py::arg("depends") = py::list()); } diff --git a/dpnp/backend/extensions/lapack/linalg_exceptions.hpp b/dpnp/backend/extensions/lapack/linalg_exceptions.hpp index 9963b2aa834e..083be22429c0 100644 --- a/dpnp/backend/extensions/lapack/linalg_exceptions.hpp +++ b/dpnp/backend/extensions/lapack/linalg_exceptions.hpp @@ -1,3 +1,28 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + #pragma once #include #include diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index b9a0d889e046..3558010a8301 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -541,14 +541,14 @@ def solve(a, b): if a.ndim < 2: raise dpnp.linalg.LinAlgError( - f"{a.ndim}-dimensional array given. Array must be " - "at least two-dimensional" + f"{a.ndim}-dimensional array given. The input coefficient " + "array must be at least two-dimensional" ) m, n = a.shape[-2:] if m != n: raise dpnp.linalg.LinAlgError( - "Last 2 dimensions of the array must be square" + "Last 2 dimensions of the input coefficient array must be square" ) if not ( diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index d6bcd92fb788..a6c5ceadc8ab 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -173,7 +173,8 @@ def dpnp_solve(a, b): dpnp_solve(a, b) Return the the solution to the system of linear equations with - a square coefficient matrix `a` and multiple right-hand sides `b`. + a square coefficient matrix `a` and multiple dependent variables + array `b`. """ @@ -237,12 +238,13 @@ def dpnp_solve(a, b): for i in range(op_count): # oneMKL LAPACK assumes fortran-like array as input, so # allocate a memory with 'F' order for dpnp array of coefficient matrix - # and multiple right-hand sides + # and multiple dependent variables array coeff_vecs[i] = dpnp.empty_like(a[i], order="F", dtype=res_type) val_vecs[i] = dpnp.empty_like(b[i], order="F", dtype=res_type) - # use DPCTL tensor function to fill the array of coefficient matrix - # and multiple right-hand sides with content of input array + # use DPCTL tensor function to fill the coefficient matrix array + # and the array of multiple dependent variables with content + # from the input arrays a_ht_copy_ev[i], a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=a_usm_arr[i], dst=coeff_vecs[i].get_array(), @@ -254,8 +256,9 @@ def dpnp_solve(a, b): sycl_queue=b.sycl_queue, ) - # call LAPACK extension function to get the solution of the system of linear - # equations with a portion of the coefficients square matrix + # Call the LAPACK extension function _gesv to solve the system of linear + # equations using a portion of the coefficient square matrix and a + # corresponding portion of the dependent variables array. ht_lapack_ev[i] = li._gesv( exec_q, coeff_vecs[i].get_array(), @@ -271,18 +274,20 @@ def dpnp_solve(a, b): # combine the list of solutions into a single array out_v = dpnp.array(val_vecs, order=b_order) if reshape: - # shape of the out_t must be equal to the shape of the right-hand sides + # shape of the out_v must be equal to the shape of the array of + # dependent variables out_v = out_v.reshape(orig_shape_b) return out_v else: # oneMKL LAPACK assumes fortran-like array as input, so # allocate a memory with 'F' order for dpnp array of coefficient matrix - # and multiple right-hand sides + # and multiple dependent variables a_f = dpnp.empty_like(a, order="F", dtype=res_type) b_f = dpnp.empty_like(b, order="F", dtype=res_type) - # use DPCTL tensor function to fill the array of coefficient matrix - # and multiple right-hand sides with content of input array + # use DPCTL tensor function to fill the coefficient matrix array + # and the array of multiple dependent variables with content + # from the input arrays a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=a_usm_arr, dst=a_f.get_array(), sycl_queue=a.sycl_queue ) @@ -290,14 +295,15 @@ def dpnp_solve(a, b): src=b_usm_arr, dst=b_f.get_array(), sycl_queue=b.sycl_queue ) - # call LAPACK extension function to get the solution of the system of linear - # equations with the coefficients square matrix + # Call the LAPACK extension function _gesv to solve the system of linear + # equations with the coefficient square matrix and the dependent variables array. lapack_ev = li._gesv( exec_q, a_f.get_array(), b_f.get_array(), [a_copy_ev, b_copy_ev] ) if b_order != "F": - # need to align order of the result of solutions with the right-hand sides + # need to align order of the result of solutions with the + # input array of multiple dependent variables out_v = dpnp.empty_like(b_f, order=b_order) ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray( src=b_f.get_array(),