From 0f7ba4608c0d93786db999e901cc3eaeb522e5c1 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 7 Dec 2023 14:28:12 +0100 Subject: [PATCH 01/26] Add a new impl of dpnp.linalg.cholesky --- dpnp/backend/extensions/lapack/CMakeLists.txt | 1 + dpnp/backend/extensions/lapack/lapack_py.cpp | 19 +- dpnp/backend/extensions/lapack/potrf.cpp | 183 ++++++++++++++++++ dpnp/backend/extensions/lapack/potrf.hpp | 63 ++++++ .../extensions/lapack/types_matrix.hpp | 26 +++ dpnp/linalg/dpnp_iface_linalg.py | 85 +++++--- dpnp/linalg/dpnp_utils_linalg.py | 71 ++++++- 7 files changed, 412 insertions(+), 36 deletions(-) create mode 100644 dpnp/backend/extensions/lapack/potrf.cpp create mode 100644 dpnp/backend/extensions/lapack/potrf.hpp diff --git a/dpnp/backend/extensions/lapack/CMakeLists.txt b/dpnp/backend/extensions/lapack/CMakeLists.txt index 7679db38d6a7..73287f9057a1 100644 --- a/dpnp/backend/extensions/lapack/CMakeLists.txt +++ b/dpnp/backend/extensions/lapack/CMakeLists.txt @@ -28,6 +28,7 @@ set(python_module_name _lapack_impl) set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/lapack_py.cpp ${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/potrf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp ) diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 97b67d59e24e..aa8fc24a4524 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -31,6 +31,7 @@ #include #include "heevd.hpp" +#include "potrf.hpp" #include "syevd.hpp" namespace lapack_ext = dpnp::backend::ext::lapack; @@ -39,6 +40,7 @@ namespace py = pybind11; // populate dispatch vectors void init_dispatch_vectors(void) { + lapack_ext::init_potrf_dispatch_vector(); lapack_ext::init_syevd_dispatch_vector(); } @@ -60,10 +62,17 @@ PYBIND11_MODULE(_lapack_impl, m) py::arg("eig_vecs"), py::arg("eig_vals"), py::arg("depends") = py::list()); - m.def("_syevd", &lapack_ext::syevd, - "Call `syevd` from OneMKL LAPACK library to return " - "the eigenvalues and eigenvectors of a real symmetric matrix", - py::arg("sycl_queue"), py::arg("jobz"), py::arg("upper_lower"), - py::arg("eig_vecs"), py::arg("eig_vals"), + m.def("_potrf", &lapack_ext::potrf, + "Call `potrf` from OneMKL LAPACK library to return " + "the Cholesky factorization of a symmetric positive-definite matrix", + py::arg("sycl_queue"), py::arg("n"), py::arg("a_array"), py::arg("depends") = py::list()); + v + + m.def("_syevd", &lapack_ext::syevd, + "Call `syevd` from OneMKL LAPACK library to return " + "the eigenvalues and eigenvectors of a real symmetric matrix", + py::arg("sycl_queue"), py::arg("jobz"), py::arg("upper_lower"), + py::arg("eig_vecs"), py::arg("eig_vals"), + py::arg("depends") = py::list()); } diff --git a/dpnp/backend/extensions/lapack/potrf.cpp b/dpnp/backend/extensions/lapack/potrf.cpp new file mode 100644 index 000000000000..88d305a9a3c1 --- /dev/null +++ b/dpnp/backend/extensions/lapack/potrf.cpp @@ -0,0 +1,183 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/type_utils.hpp" + +#include "potrf.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +namespace mkl_lapack = oneapi::mkl::lapack; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*potrf_impl_fn_ptr_t)(sycl::queue, + oneapi::mkl::uplo, + const std::int64_t, + char *, + std::int64_t, + std::vector &, + const std::vector &); + +static potrf_impl_fn_ptr_t potrf_dispatch_vector[dpctl_td_ns::num_types]; + +template +static sycl::event potrf_impl(sycl::queue exec_q, + oneapi::mkl::uplo upper_lower, + const std::int64_t n, + char *in_a, + std::int64_t lda, + std::vector &host_task_events, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + T *a = reinterpret_cast(in_a); + + const std::int64_t scratchpad_size = + oneapi::mkl::lapack::potrf_scratchpad_size(exec_q, upper_lower, n, + lda); + T *scratchpad = nullptr; + + std::stringstream error_msg; + std::int64_t info = 0; + + sycl::event potrf_event; + try { + scratchpad = sycl::malloc_device(scratchpad_size, exec_q); + + potrf_event = oneapi::mkl::lapack::potrf( + exec_q, + upper_lower, // + n, // Order of the square matrix; (0 ≤ n). + a, // Pointer to the n-by-n matrix. + lda, // The leading dimension of `a`. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + } catch (mkl_lapack::exception const &e) { + error_msg + << "Unexpected MKL exception caught during potrf() call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + info = e.info(); + } catch (sycl::exception const &e) { + error_msg << "Unexpected SYCL exception caught during potrf() call:\n" + << e.what(); + info = -1; + } + + if (info != 0) // an unexpected error occurs + { + if (scratchpad != nullptr) { + sycl::free(scratchpad, exec_q); + } + + // TODO: use LinAlgError + if (info == 2) { + throw py::value_error("Matrix is not positive definite"); + } + + throw std::runtime_error(error_msg.str()); + } + + sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(potrf_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); + }); + host_task_events.push_back(clean_up_event); + return potrf_event; +} + +std::pair + potrf(sycl::queue q, + const std::int64_t n, + dpctl::tensor::usm_ndarray a_array, + const std::vector &depends) +{ + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int a_array_type_id = + array_types.typenum_to_lookup_id(a_array.get_typenum()); + + potrf_impl_fn_ptr_t potrf_fn = potrf_dispatch_vector[a_array_type_id]; + if (potrf_fn == nullptr) { + throw py::value_error( + "No potrf implementation defined for the provided type " + "of the input matrix."); + } + + char *a_array_data = a_array.get_data(); + const std::int64_t lda = std::max(1UL, n); + oneapi::mkl::uplo upper_lower = oneapi::mkl::uplo::upper; + + std::vector host_task_events; + sycl::event potrf_ev = potrf_fn(q, upper_lower, n, a_array_data, lda, + host_task_events, depends); + + sycl::event args_ev = + dpctl::utils::keep_args_alive(q, {a_array}, host_task_events); + + return std::make_pair(args_ev, potrf_ev); +} + +template +struct PotrfContigFactory +{ + fnT get() + { + if constexpr (types::PotrfTypePairSupportFactory::is_defined) { + return potrf_impl; + } + else { + return nullptr; + } + } +}; + +void init_potrf_dispatch_vector(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(potrf_dispatch_vector); +} +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/potrf.hpp b/dpnp/backend/extensions/lapack/potrf.hpp new file mode 100644 index 000000000000..7652fe2f1403 --- /dev/null +++ b/dpnp/backend/extensions/lapack/potrf.hpp @@ -0,0 +1,63 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include +#include + +#include + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +extern std::pair + potrf(sycl::queue exec_q, + const std::int64_t n, + dpctl::tensor::usm_ndarray a_array, + const std::vector &depends = {}); + +// extern std::pair +// potrf_batch(sycl::queue exec_q, +// dpctl::tensor::usm_ndarray a_array, +// dpctl::tensor::usm_ndarray ipiv_array, +// dpctl::tensor::usm_ndarray dev_info_array, +// std::int64_t n, +// std::int64_t stride_a, +// std::int64_t stride_ipiv, +// std::int64_t batch_size, +// const std::vector &depends = {}); + +extern void init_potrf_dispatch_vector(void); +// extern void init_potrf_batch_dispatch_vector(void); +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/types_matrix.hpp b/dpnp/backend/extensions/lapack/types_matrix.hpp index 3cab18d3c63d..1e18faddfd13 100644 --- a/dpnp/backend/extensions/lapack/types_matrix.hpp +++ b/dpnp/backend/extensions/lapack/types_matrix.hpp @@ -63,6 +63,32 @@ struct HeevdTypePairSupportFactory dpctl_td_ns::NotDefinedEntry>::is_defined; }; +/** + * @brief A factory to define pairs of supported types for which + * MKL LAPACK library provides support in oneapi::mkl::lapack::potrf + * function. + * + * @tparam T Type of array containing input matrix, + * as well as the output array for storing the Cholesky factor L. + */ +template +struct PotrfTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; + /** * @brief A factory to define pairs of supported types for which * MKL LAPACK library provides support in oneapi::mkl::lapack::syevd diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index c7437b30da60..409ffe01c47d 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -47,7 +47,7 @@ from dpnp.dpnp_utils import * from dpnp.linalg.dpnp_algo_linalg import * -from .dpnp_utils_linalg import dpnp_eigh +from .dpnp_utils_linalg import dpnp_cholesky, dpnp_eigh __all__ = [ "cholesky", @@ -66,52 +66,77 @@ ] -def cholesky(input): +def cholesky(a): """ Cholesky decomposition. - Return the Cholesky decomposition, `L * L.H`, of the square matrix `input`, + Return the Cholesky decomposition, `L * L.H`, of the square matrix `a`, where `L` is lower-triangular and .H is the conjugate transpose operator - (which is the ordinary transpose if `input` is real-valued). `input` must be + (which is the ordinary transpose if `a` is real-valued). `a` must be Hermitian (symmetric if real-valued) and positive-definite. No checking is performed to verify whether `a` is Hermitian or not. - In addition, only the lower-triangular and diagonal elements of `input` + In addition, only the lower-triangular and diagonal elements of `a` are used. Only `L` is actually returned. + For full documentation refer to :obj:`numpy.linalg.cholesky`. + Parameters ---------- - input : (..., M, M) array_like + a : (..., M, M) dpnp.ndarray Hermitian (symmetric if all elements are real), positive-definite input matrix. Returns ------- - L : (..., M, M) array_like - Upper or lower-triangular Cholesky factor of `input`. Returns a - matrix object if `input` is a matrix object. + L : (..., M, M) dpnp.ndarray + Lower-triangular Cholesky factor of `a`. Returns `a` + matrix object if `a` is a matrix object. + + Limitations + ----------- + Parameter `a` is supported as :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`. + Input array data types are limited by supported DPNP :ref:`Data types`. + + See Also + -------- + :obj:`scipy.linalg.cholesky` : Similar function in SciPy. + + Examples + -------- + >>> import dpnp as np + >>> A = np.array([[1.0, 2.0],[2.0, 5.0]]) + >>> A + array([[1., 2.], + [2., 5.]]) + >>> L = np.linalg.cholesky(A) + >>> L + array([[1., 0.], + [2., 1.]]) + >>> np.dot(L, L.T.conj()) # verify that L * L.H = A + array([[1., 2.], + [2., 5.]]) + """ - x1_desc = dpnp.get_dpnp_descriptor(input, copy_when_nondefault_queue=False) - if x1_desc: - if x1_desc.shape[-1] != x1_desc.shape[-2]: - pass - else: - if input.dtype == dpnp.int32 or input.dtype == dpnp.int64: - dev = x1_desc.get_array().sycl_device - if dev.has_aspect_fp64: - dtype = dpnp.float64 - else: - dtype = dpnp.float32 - # TODO memory copy. needs to move into DPNPC - input_ = dpnp.get_dpnp_descriptor( - dpnp.astype(input, dtype=dtype), - copy_when_nondefault_queue=False, - ) - else: - input_ = x1_desc - return dpnp_cholesky(input_).get_pyobj() - - return call_origin(numpy.linalg.cholesky, input) + # TODO: use _assert_dpnp_array + if not dpnp.is_supported_array_type(a): + raise TypeError( + "An array must be any of supported type, but got {}".format(type(a)) + ) + + # TODO: use _assert_stacked_2d + if a.ndim < 2: + raise ValueError( + f"{a.ndim}-dimensional array given. The input " + "array must be at least two-dimensional" + ) + + # TODO: use _assert_stacked_square + n, m = a.shape[-2:] + if m != n: + raise ValueError("Last 2 dimensions of the input array must be square") + + return dpnp_cholesky(a) def cond(input, p=None): diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 54c01c20248e..32b202ed1fa8 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -32,12 +32,81 @@ import dpnp import dpnp.backend.extensions.lapack._lapack_impl as li -__all__ = ["dpnp_eigh"] +__all__ = ["dpnp_eigh", "dpnp_cholesky"] _jobz = {"N": 0, "V": 1} _upper_lower = {"U": 0, "L": 1} +def dpnp_cholesky(a): + """ + dpnp_cholesky(a) + + Return the Cholesky factorization. + + """ + + a_usm_type = a.usm_type + a_sycl_queue = a.sycl_queue + a_shape = a.shape + + n = a.shape[-2] + + # TODO: Use linalg_common_type from #1598 + if dpnp.issubdtype(a.dtype, dpnp.floating): + res_type = ( + a.dtype + if a_sycl_queue.sycl_device.has_aspect_fp64 + else dpnp.float32 + ) + elif dpnp.issubdtype(a.dtype, dpnp.complexfloating): + res_type = ( + a.dtype + if a_sycl_queue.sycl_device.has_aspect_fp64 + else dpnp.complex64 + ) + else: + res_type = ( + dpnp.float64 + if a_sycl_queue.sycl_device.has_aspect_fp64 + else dpnp.float32 + ) + + if a.size == 0: + return dpnp.empty(a_shape, dtype=res_type, usm_type=a_usm_type) + + if a.ndim > 2: + pass + + else: + a_usm_arr = dpnp.get_usm_ndarray(a) + + # oneMKL LAPACK potrf overwrites `a` + a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=a_usm_type) + + # use DPCTL tensor function to fill the сopy of the input array + # from the input array + a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, dst=a_h.get_array(), sycl_queue=a_sycl_queue + ) + + # Call the LAPACK extension function _potrf + # to computes the Cholesky factorization + ht_lapack_ev, _ = li._potrf( + a_sycl_queue, + n, + a_h.get_array(), + [a_copy_ev], + ) + + ht_lapack_ev.wait() + a_ht_copy_ev.wait() + + a_h = dpnp.tril(a_h) + + return a_h + + def dpnp_eigh(a, UPLO): """ dpnp_eigh(a, UPLO) From ca9ce534b3cc339b67f289ac8d5051d2b01e5502 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 7 Dec 2023 14:33:10 +0100 Subject: [PATCH 02/26] Add cupy tests for dpnp.linalg.cholesky --- .../cupy/linalg_tests/test_decomposition.py | 129 ++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 tests/third_party/cupy/linalg_tests/test_decomposition.py diff --git a/tests/third_party/cupy/linalg_tests/test_decomposition.py b/tests/third_party/cupy/linalg_tests/test_decomposition.py new file mode 100644 index 000000000000..9375554ca0f4 --- /dev/null +++ b/tests/third_party/cupy/linalg_tests/test_decomposition.py @@ -0,0 +1,129 @@ +import unittest + +import numpy +import pytest + +import dpnp as cupy +from tests.helper import has_support_aspect64, is_cpu_device +from tests.third_party.cupy import testing + + +def random_matrix(shape, dtype, scale, sym=False): + m, n = shape[-2:] + dtype = numpy.dtype(dtype) + assert dtype.kind in "iufc" + low_s, high_s = scale + bias = None + if dtype.kind in "iu": + # For an m \times n matrix M whose element is in [-0.5, 0.5], it holds + # (singular value of M) <= \sqrt{mn} / 2 + err = numpy.sqrt(m * n) / 2.0 + low_s += err + high_s -= err + if dtype.kind in "u": + assert sym, ( + "generating nonsymmetric matrix with uint cells is not" + " supported." + ) + # (singular value of numpy.ones((m, n))) <= \sqrt{mn} + high_s = bias = high_s / (1 + numpy.sqrt(m * n)) + assert low_s <= high_s + a = numpy.random.standard_normal(shape) + if dtype.kind == "c": + a = a + 1j * numpy.random.standard_normal(shape) + u, s, vh = numpy.linalg.svd(a) + if sym: + assert m == n + vh = u.conj().swapaxes(-1, -2) + new_s = numpy.random.uniform(low_s, high_s, s.shape) + new_a = numpy.einsum("...ij,...j,...jk->...ik", u, new_s, vh) + if bias is not None: + new_a += bias + if dtype.kind in "iu": + new_a = numpy.rint(new_a) + return new_a.astype(dtype) + + +class TestCholeskyDecomposition: + @testing.numpy_cupy_allclose(atol=1e-3, type_check=has_support_aspect64()) + def check_L(self, array, xp): + a = xp.asarray(array) + return xp.linalg.cholesky(a) + + @testing.for_dtypes( + [ + numpy.int32, + numpy.int64, + numpy.uint32, + numpy.uint64, + numpy.float32, + numpy.float64, + numpy.complex64, + numpy.complex128, + ] + ) + def test_decomposition(self, dtype): + # A positive definite matrix + A = random_matrix((5, 5), dtype, scale=(10, 10000), sym=True) + self.check_L(A) + # np.linalg.cholesky only uses a lower triangle of an array + self.check_L(numpy.array([[1, 2], [1, 9]], dtype)) + + # @testing.for_dtypes([ + # numpy.int32, numpy.int64, numpy.uint32, numpy.uint64, + # numpy.float32, numpy.float64, numpy.complex64, numpy.complex128]) + # def test_batched_decomposition(self, dtype): + # if not cusolver.check_availability('potrfBatched'): + # pytest.skip('potrfBatched is not available') + # Ab1 = random_matrix((3, 5, 5), dtype, scale=(10, 10000), sym=True) + # self.check_L(Ab1) + # Ab2 = random_matrix((2, 2, 5, 5), dtype, scale=(10, 10000), sym=True) + # self.check_L(Ab2) + + @pytest.mark.parametrize( + "shape", + [ + # empty square + (0, 0), + (3, 0, 0), + # empty batch + (2, 0, 3, 4, 4), + ], + ) + @testing.for_dtypes( + [ + numpy.int32, + numpy.uint16, + numpy.float32, + numpy.float64, + numpy.complex64, + numpy.complex128, + ] + ) + @testing.numpy_cupy_allclose(type_check=has_support_aspect64()) + def test_empty(self, shape, xp, dtype): + a = xp.empty(shape, dtype=dtype) + return xp.linalg.cholesky(a) + + +class TestCholeskyInvalid(unittest.TestCase): + def check_L(self, array): + for xp in (numpy, cupy): + a = xp.asarray(array) + with pytest.raises((numpy.linalg.LinAlgError, ValueError)): + xp.linalg.cholesky(a) + + @pytest.mark.skipif(is_cpu_device(), reason="MKL bug") + @testing.for_dtypes( + [ + numpy.int32, + numpy.int64, + numpy.uint32, + numpy.uint64, + numpy.float32, + numpy.float64, + ] + ) + def test_decomposition(self, dtype): + A = numpy.array([[1, -2], [-2, 1]]).astype(dtype) + self.check_L(A) From b5582529cabd5fa8d2bcc40079a7ded46246ead2 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 7 Dec 2023 17:06:38 +0100 Subject: [PATCH 03/26] Add a batch impl of dpnp.linalg.cholesky --- dpnp/backend/extensions/lapack/CMakeLists.txt | 1 + dpnp/backend/extensions/lapack/lapack_py.cpp | 22 +- dpnp/backend/extensions/lapack/potrf.hpp | 19 +- .../backend/extensions/lapack/potrf_batch.cpp | 197 ++++++++++++++++++ .../extensions/lapack/types_matrix.hpp | 26 +++ dpnp/linalg/dpnp_utils_linalg.py | 36 +++- .../cupy/linalg_tests/test_decomposition.py | 4 +- 7 files changed, 284 insertions(+), 21 deletions(-) create mode 100644 dpnp/backend/extensions/lapack/potrf_batch.cpp diff --git a/dpnp/backend/extensions/lapack/CMakeLists.txt b/dpnp/backend/extensions/lapack/CMakeLists.txt index 73287f9057a1..a992289da6ee 100644 --- a/dpnp/backend/extensions/lapack/CMakeLists.txt +++ b/dpnp/backend/extensions/lapack/CMakeLists.txt @@ -29,6 +29,7 @@ set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/lapack_py.cpp ${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/potrf.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/potrf_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp ) diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index aa8fc24a4524..3444f4fc0433 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -40,6 +40,7 @@ namespace py = pybind11; // populate dispatch vectors void init_dispatch_vectors(void) { + lapack_ext::init_potrf_batch_dispatch_vector(); lapack_ext::init_potrf_dispatch_vector(); lapack_ext::init_syevd_dispatch_vector(); } @@ -67,12 +68,19 @@ PYBIND11_MODULE(_lapack_impl, m) "the Cholesky factorization of a symmetric positive-definite matrix", py::arg("sycl_queue"), py::arg("n"), py::arg("a_array"), py::arg("depends") = py::list()); - v - m.def("_syevd", &lapack_ext::syevd, - "Call `syevd` from OneMKL LAPACK library to return " - "the eigenvalues and eigenvectors of a real symmetric matrix", - py::arg("sycl_queue"), py::arg("jobz"), py::arg("upper_lower"), - py::arg("eig_vecs"), py::arg("eig_vals"), - py::arg("depends") = py::list()); + m.def("_potrf_batch", &lapack_ext::potrf_batch, + "Call `potrf_batch` from OneMKL LAPACK library to return " + "the Cholesky factorization of a batch of symmetric " + "positive-definite matrix", + py::arg("sycl_queue"), py::arg("a_array"), py::arg("n"), + py::arg("stride_a"), py::arg("batch_size"), + py::arg("depends") = py::list()); + + m.def("_syevd", &lapack_ext::syevd, + "Call `syevd` from OneMKL LAPACK library to return " + "the eigenvalues and eigenvectors of a real symmetric matrix", + py::arg("sycl_queue"), py::arg("jobz"), py::arg("upper_lower"), + py::arg("eig_vecs"), py::arg("eig_vals"), + py::arg("depends") = py::list()); } diff --git a/dpnp/backend/extensions/lapack/potrf.hpp b/dpnp/backend/extensions/lapack/potrf.hpp index 7652fe2f1403..35efda3b9b43 100644 --- a/dpnp/backend/extensions/lapack/potrf.hpp +++ b/dpnp/backend/extensions/lapack/potrf.hpp @@ -44,19 +44,16 @@ extern std::pair dpctl::tensor::usm_ndarray a_array, const std::vector &depends = {}); -// extern std::pair -// potrf_batch(sycl::queue exec_q, -// dpctl::tensor::usm_ndarray a_array, -// dpctl::tensor::usm_ndarray ipiv_array, -// dpctl::tensor::usm_ndarray dev_info_array, -// std::int64_t n, -// std::int64_t stride_a, -// std::int64_t stride_ipiv, -// std::int64_t batch_size, -// const std::vector &depends = {}); +extern std::pair + potrf_batch(sycl::queue exec_q, + dpctl::tensor::usm_ndarray a_array, + std::int64_t n, + std::int64_t stride_a, + std::int64_t batch_size, + const std::vector &depends = {}); extern void init_potrf_dispatch_vector(void); -// extern void init_potrf_batch_dispatch_vector(void); +extern void init_potrf_batch_dispatch_vector(void); } // namespace lapack } // namespace ext } // namespace backend diff --git a/dpnp/backend/extensions/lapack/potrf_batch.cpp b/dpnp/backend/extensions/lapack/potrf_batch.cpp new file mode 100644 index 000000000000..e9e8c0c9e1b1 --- /dev/null +++ b/dpnp/backend/extensions/lapack/potrf_batch.cpp @@ -0,0 +1,197 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/type_utils.hpp" + +#include "potrf.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +namespace mkl_lapack = oneapi::mkl::lapack; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*potrf_batch_impl_fn_ptr_t)( + sycl::queue, + oneapi::mkl::uplo, + std::int64_t, + char *, + std::int64_t, + std::int64_t, + std::int64_t, + std::vector &, + const std::vector &); + +static potrf_batch_impl_fn_ptr_t + potrf_batch_dispatch_vector[dpctl_td_ns::num_types]; + +template +static sycl::event potrf_batch_impl(sycl::queue exec_q, + oneapi::mkl::uplo upper_lower, + std::int64_t n, + char *in_a, + std::int64_t lda, + std::int64_t stride_a, + std::int64_t batch_size, + std::vector &host_task_events, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + T *a = reinterpret_cast(in_a); + + const std::int64_t scratchpad_size = + oneapi::mkl::lapack::potrf_batch_scratchpad_size( + exec_q, upper_lower, n, lda, stride_a, batch_size); + T *scratchpad = nullptr; + + std::stringstream error_msg; + std::int64_t info = 0; + + sycl::event potrf_batch_event; + try { + scratchpad = sycl::malloc_device(scratchpad_size, exec_q); + + potrf_batch_event = oneapi::mkl::lapack::potrf_batch( + exec_q, + upper_lower, // + n, // Order of each square matrix in the batch; (0 ≤ n). + a, // Pointer to the batch of matrices. + lda, // The leading dimension of `a`. + stride_a, // Stride between matrices: Element spacing between + // matrices in `a`. + batch_size, // Total number of matrices in the batch. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + } catch (mkl_lapack::exception const &e) { + error_msg + << "Unexpected MKL exception caught during potrf() call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + info = e.info(); + } catch (sycl::exception const &e) { + error_msg << "Unexpected SYCL exception caught during potrf() call:\n" + << e.what(); + info = -1; + } + + if (info != 0) // an unexpected error occurs + { + if (scratchpad != nullptr) { + sycl::free(scratchpad, exec_q); + } + + // TODO: use LinAlgError + if (info == 2) { + throw py::value_error("Matrix is not positive definite"); + } + + throw std::runtime_error(error_msg.str()); + } + + sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(potrf_batch_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); + }); + host_task_events.push_back(clean_up_event); + return potrf_batch_event; +} + +std::pair + potrf_batch(sycl::queue q, + dpctl::tensor::usm_ndarray a_array, + std::int64_t n, + std::int64_t stride_a, + std::int64_t batch_size, + const std::vector &depends) +{ + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int a_array_type_id = + array_types.typenum_to_lookup_id(a_array.get_typenum()); + + potrf_batch_impl_fn_ptr_t potrf_batch_fn = + potrf_batch_dispatch_vector[a_array_type_id]; + if (potrf_batch_fn == nullptr) { + throw py::value_error( + "No potrf_batch implementation defined for the provided type " + "of the input matrix."); + } + + char *a_array_data = a_array.get_data(); + const std::int64_t lda = std::max(1UL, n); + oneapi::mkl::uplo upper_lower = oneapi::mkl::uplo::upper; + + std::vector host_task_events; + sycl::event potrf_batch_ev = + potrf_batch_fn(q, upper_lower, n, a_array_data, lda, stride_a, + batch_size, host_task_events, depends); + + sycl::event args_ev = + dpctl::utils::keep_args_alive(q, {a_array}, host_task_events); + + return std::make_pair(args_ev, potrf_batch_ev); +} + +template +struct PotrfBatchContigFactory +{ + fnT get() + { + if constexpr (types::PotrfBatchTypePairSupportFactory::is_defined) { + return potrf_batch_impl; + } + else { + return nullptr; + } + } +}; + +void init_potrf_batch_dispatch_vector(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(potrf_batch_dispatch_vector); +} +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/types_matrix.hpp b/dpnp/backend/extensions/lapack/types_matrix.hpp index 1e18faddfd13..595125c1738d 100644 --- a/dpnp/backend/extensions/lapack/types_matrix.hpp +++ b/dpnp/backend/extensions/lapack/types_matrix.hpp @@ -89,6 +89,32 @@ struct PotrfTypePairSupportFactory dpctl_td_ns::NotDefinedEntry>::is_defined; }; +/** + * @brief A factory to define pairs of supported types for which + * MKL LAPACK library provides support in oneapi::mkl::lapack::potrf + * function. + * + * @tparam T Type of array containing input matrices, + * as well as the output arrays for storing the Cholesky factor L. + */ +template +struct PotrfBatchTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; + /** * @brief A factory to define pairs of supported types for which * MKL LAPACK library provides support in oneapi::mkl::lapack::syevd diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 32b202ed1fa8..0154cdec01cb 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -76,7 +76,41 @@ def dpnp_cholesky(a): return dpnp.empty(a_shape, dtype=res_type, usm_type=a_usm_type) if a.ndim > 2: - pass + # orig_shape = a.shape + # get 3d input arrays by reshape + a = a.reshape(-1, n, n) + batch_size = a.shape[0] + a_usm_arr = dpnp.get_usm_ndarray(a) + + # oneMKL LAPACK potrf overwrites `a` + a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=a_usm_type) + + # use DPCTL tensor function to fill the сopy of the input array + # from the input array + a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, dst=a_h.get_array(), sycl_queue=a_sycl_queue + ) + + a_stride = a_h.strides[0] + + # Call the LAPACK extension function _potrf_batch + # to perform the Cholesky factorization of a batch of + # symmetric positive-definite matrix + ht_lapack_ev, _ = li._potrf_batch( + a_sycl_queue, + a_h.get_array(), + n, + a_stride, + batch_size, + [a_copy_ev], + ) + + ht_lapack_ev.wait() + a_ht_copy_ev.wait() + + a_h = dpnp.tril(a_h) + + return a_h else: a_usm_arr = dpnp.get_usm_ndarray(a) diff --git a/tests/third_party/cupy/linalg_tests/test_decomposition.py b/tests/third_party/cupy/linalg_tests/test_decomposition.py index 9375554ca0f4..8e357768ab7e 100644 --- a/tests/third_party/cupy/linalg_tests/test_decomposition.py +++ b/tests/third_party/cupy/linalg_tests/test_decomposition.py @@ -73,8 +73,8 @@ def test_decomposition(self, dtype): # numpy.int32, numpy.int64, numpy.uint32, numpy.uint64, # numpy.float32, numpy.float64, numpy.complex64, numpy.complex128]) # def test_batched_decomposition(self, dtype): - # if not cusolver.check_availability('potrfBatched'): - # pytest.skip('potrfBatched is not available') + # # if not cusolver.check_availability('potrfBatched'): + # # pytest.skip('potrfBatched is not available') # Ab1 = random_matrix((3, 5, 5), dtype, scale=(10, 10000), sym=True) # self.check_L(Ab1) # Ab2 = random_matrix((2, 2, 5, 5), dtype, scale=(10, 10000), sym=True) From 7c71b3da9543a3a0bbc37a27b4ae9321fb4f01a2 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 7 Dec 2023 17:13:43 +0100 Subject: [PATCH 04/26] Remove an old impl of dpnp_cholesky --- dpnp/backend/kernels/dpnp_krnl_linalg.cpp | 5 --- dpnp/dpnp_algo/dpnp_algo.pxd | 2 -- dpnp/linalg/dpnp_algo_linalg.pyx | 41 ----------------------- 3 files changed, 48 deletions(-) diff --git a/dpnp/backend/kernels/dpnp_krnl_linalg.cpp b/dpnp/backend/kernels/dpnp_krnl_linalg.cpp index 5bf54c2b84d3..9a2a21e18edf 100644 --- a/dpnp/backend/kernels/dpnp_krnl_linalg.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_linalg.cpp @@ -852,11 +852,6 @@ void func_map_init_linalg_func(func_map_t &fmap) fmap[DPNPFuncName::DPNP_FN_CHOLESKY][eft_DBL][eft_DBL] = { eft_DBL, (void *)dpnp_cholesky_default_c}; - fmap[DPNPFuncName::DPNP_FN_CHOLESKY_EXT][eft_FLT][eft_FLT] = { - eft_FLT, (void *)dpnp_cholesky_ext_c}; - fmap[DPNPFuncName::DPNP_FN_CHOLESKY_EXT][eft_DBL][eft_DBL] = { - eft_DBL, (void *)dpnp_cholesky_ext_c}; - fmap[DPNPFuncName::DPNP_FN_DET][eft_INT][eft_INT] = { eft_INT, (void *)dpnp_det_default_c}; fmap[DPNPFuncName::DPNP_FN_DET][eft_LNG][eft_LNG] = { diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index d49adcf0b7fc..0fa1fa510dc1 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -38,8 +38,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_ARANGE DPNP_FN_ARGSORT DPNP_FN_ARGSORT_EXT - DPNP_FN_CHOLESKY - DPNP_FN_CHOLESKY_EXT DPNP_FN_CHOOSE DPNP_FN_CHOOSE_EXT DPNP_FN_COPY diff --git a/dpnp/linalg/dpnp_algo_linalg.pyx b/dpnp/linalg/dpnp_algo_linalg.pyx index c86b869acd3c..b3abcc640ac7 100644 --- a/dpnp/linalg/dpnp_algo_linalg.pyx +++ b/dpnp/linalg/dpnp_algo_linalg.pyx @@ -45,7 +45,6 @@ cimport numpy cimport dpnp.dpnp_utils as utils __all__ = [ - "dpnp_cholesky", "dpnp_cond", "dpnp_det", "dpnp_eig", @@ -68,9 +67,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_1in_1out_func_ptr_t_)(c_dpctl. ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_1in_1out_with_size_func_ptr_t_)(c_dpctl.DPCTLSyclQueueRef, void *, void * , size_t, const c_dpctl.DPCTLEventVectorRef) -ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_1in_1out_with_2size_func_ptr_t_)(c_dpctl.DPCTLSyclQueueRef, - void *, void * , size_t, size_t, - const c_dpctl.DPCTLEventVectorRef) ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_1in_3out_shape_t)(c_dpctl.DPCTLSyclQueueRef, void *, void * , void * , void * , size_t , size_t, const c_dpctl.DPCTLEventVectorRef) @@ -79,43 +75,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_2in_1out_func_ptr_t)(c_dpctl.D const c_dpctl.DPCTLEventVectorRef) -cpdef utils.dpnp_descriptor dpnp_cholesky(utils.dpnp_descriptor input_): - size_ = input_.shape[-1] - - cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input_.dtype) - - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_CHOLESKY_EXT, param1_type, param1_type) - - input_obj = input_.get_array() - - # create result array with type given by FPTR data - cdef utils.dpnp_descriptor result = utils.create_output_descriptor(input_.shape, - kernel_data.return_type, - None, - device=input_obj.sycl_device, - usm_type=input_obj.usm_type, - sycl_queue=input_obj.sycl_queue) - - result_sycl_queue = result.get_array().sycl_queue - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - cdef custom_linalg_1in_1out_with_2size_func_ptr_t_ func = kernel_data.ptr - - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, - input_.get_data(), - result.get_data(), - input_.size, - size_, - NULL) # dep_events_ref - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - c_dpctl.DPCTLEvent_Delete(event_ref) - - return result - - cpdef object dpnp_cond(object input, object p): if p in ('f', 'fro'): # TODO: change order='K' when support is implemented From 3a41236c7bc12f0a26e016f85dae75650f55f33d Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 8 Dec 2023 12:02:11 +0100 Subject: [PATCH 05/26] Remove DPNP_FN_CHOLESKY_EXT in dpnp_iface_fptr --- dpnp/backend/include/dpnp_iface_fptr.hpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index 6a174b3b647e..1ac959017b05 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -88,8 +88,6 @@ enum class DPNPFuncName : size_t DPNP_FN_CBRT, /**< Used in numpy.cbrt() impl */ DPNP_FN_CEIL, /**< Used in numpy.ceil() impl */ DPNP_FN_CHOLESKY, /**< Used in numpy.linalg.cholesky() impl */ - DPNP_FN_CHOLESKY_EXT, /**< Used in numpy.linalg.cholesky() impl, requires - extra parameters */ DPNP_FN_CONJUGATE, /**< Used in numpy.conjugate() impl */ DPNP_FN_CHOOSE, /**< Used in numpy.choose() impl */ DPNP_FN_CHOOSE_EXT, /**< Used in numpy.choose() impl, requires extra From 06fc2078bad5b377c5abcd17bf78d496bb449083 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 8 Dec 2023 12:22:52 +0100 Subject: [PATCH 06/26] Remove dpnp_cholesky_ext_c --- dpnp/backend/kernels/dpnp_krnl_linalg.cpp | 9 --------- 1 file changed, 9 deletions(-) diff --git a/dpnp/backend/kernels/dpnp_krnl_linalg.cpp b/dpnp/backend/kernels/dpnp_krnl_linalg.cpp index 9a2a21e18edf..8898b70b67bb 100644 --- a/dpnp/backend/kernels/dpnp_krnl_linalg.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_linalg.cpp @@ -128,15 +128,6 @@ template void (*dpnp_cholesky_default_c)(void *, void *, const size_t, const size_t) = dpnp_cholesky_c<_DataType>; -template -DPCTLSyclEventRef (*dpnp_cholesky_ext_c)(DPCTLSyclQueueRef, - void *, - void *, - const size_t, - const size_t, - const DPCTLEventVectorRef) = - dpnp_cholesky_c<_DataType>; - template DPCTLSyclEventRef dpnp_det_c(DPCTLSyclQueueRef q_ref, void *array1_in, From 8e60468c10be240f3f7676554a16c31e83360fae Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 8 Dec 2023 12:58:05 +0100 Subject: [PATCH 07/26] Add a new _dpnp_cholesky_batch func --- dpnp/linalg/dpnp_utils_linalg.py | 136 ++++++++++++++++++------------- 1 file changed, 78 insertions(+), 58 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 0154cdec01cb..4cb9cb28b2d7 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -38,6 +38,56 @@ _upper_lower = {"U": 0, "L": 1} +def _dpnp_cholesky_batch(a, res_type): + """ + _dpnp_cholesky_batch(a, res_type) + + Batched Cholesky decomposition. + + """ + + a_sycl_queue = a.sycl_queue + a_usm_type = a.usm_type + + n = a.shape[-2] + + orig_shape = a.shape + # get 3d input arrays by reshape + a = a.reshape(-1, n, n) + batch_size = a.shape[0] + a_usm_arr = dpnp.get_usm_ndarray(a) + + # oneMKL LAPACK potrf overwrites `a` + a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=a_usm_type) + + # use DPCTL tensor function to fill the сopy of the input array + # from the input array + a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, dst=a_h.get_array(), sycl_queue=a_sycl_queue + ) + + a_stride = a_h.strides[0] + + # Call the LAPACK extension function _potrf_batch + # to perform the Cholesky factorization of a batch of + # symmetric positive-definite matrix + ht_lapack_ev, _ = li._potrf_batch( + a_sycl_queue, + a_h.get_array(), + n, + a_stride, + batch_size, + [a_copy_ev], + ) + + ht_lapack_ev.wait() + a_ht_copy_ev.wait() + + a_h = dpnp.tril(a_h.reshape(orig_shape)) + + return a_h + + def dpnp_cholesky(a): """ dpnp_cholesky(a) @@ -46,8 +96,8 @@ def dpnp_cholesky(a): """ - a_usm_type = a.usm_type a_sycl_queue = a.sycl_queue + a_usm_type = a.usm_type a_shape = a.shape n = a.shape[-2] @@ -73,72 +123,42 @@ def dpnp_cholesky(a): ) if a.size == 0: - return dpnp.empty(a_shape, dtype=res_type, usm_type=a_usm_type) - - if a.ndim > 2: - # orig_shape = a.shape - # get 3d input arrays by reshape - a = a.reshape(-1, n, n) - batch_size = a.shape[0] - a_usm_arr = dpnp.get_usm_ndarray(a) - - # oneMKL LAPACK potrf overwrites `a` - a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=a_usm_type) - - # use DPCTL tensor function to fill the сopy of the input array - # from the input array - a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a_usm_arr, dst=a_h.get_array(), sycl_queue=a_sycl_queue - ) - - a_stride = a_h.strides[0] - - # Call the LAPACK extension function _potrf_batch - # to perform the Cholesky factorization of a batch of - # symmetric positive-definite matrix - ht_lapack_ev, _ = li._potrf_batch( - a_sycl_queue, - a_h.get_array(), - n, - a_stride, - batch_size, - [a_copy_ev], + return dpnp.empty( + a_shape, + dtype=res_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, ) - ht_lapack_ev.wait() - a_ht_copy_ev.wait() - - a_h = dpnp.tril(a_h) - - return a_h + if a.ndim > 2: + return _dpnp_cholesky_batch(a, res_type) - else: - a_usm_arr = dpnp.get_usm_ndarray(a) + a_usm_arr = dpnp.get_usm_ndarray(a) - # oneMKL LAPACK potrf overwrites `a` - a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=a_usm_type) + # oneMKL LAPACK potrf overwrites `a` + a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=a_usm_type) - # use DPCTL tensor function to fill the сopy of the input array - # from the input array - a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a_usm_arr, dst=a_h.get_array(), sycl_queue=a_sycl_queue - ) + # use DPCTL tensor function to fill the сopy of the input array + # from the input array + a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, dst=a_h.get_array(), sycl_queue=a_sycl_queue + ) - # Call the LAPACK extension function _potrf - # to computes the Cholesky factorization - ht_lapack_ev, _ = li._potrf( - a_sycl_queue, - n, - a_h.get_array(), - [a_copy_ev], - ) + # Call the LAPACK extension function _potrf + # to computes the Cholesky factorization + ht_lapack_ev, _ = li._potrf( + a_sycl_queue, + n, + a_h.get_array(), + [a_copy_ev], + ) - ht_lapack_ev.wait() - a_ht_copy_ev.wait() + ht_lapack_ev.wait() + a_ht_copy_ev.wait() - a_h = dpnp.tril(a_h) + a_h = dpnp.tril(a_h) - return a_h + return a_h def dpnp_eigh(a, UPLO): From 98911fcc6d7dd327baa707381515f8b32708453d Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 8 Dec 2023 12:59:35 +0100 Subject: [PATCH 08/26] Update test_cholesky in test_sycl_queue --- tests/test_sycl_queue.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 3c658c14fe52..a55a502c0f97 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -883,19 +883,31 @@ def test_fft_rfft(type, shape, device): assert_sycl_queue_equal(result_queue, expected_queue) +@pytest.mark.parametrize( + "data, is_empty", + [ + ([[1, -2], [2, 5]], False), + ([[[1, -2], [2, 5]], [[1, -2], [2, 5]]], False), + ((0, 0), True), + ((3, 0, 0), True), + ], + ids=["2D", "3D", "Empty_2D", "Empty_3D"], +) @pytest.mark.parametrize( "device", valid_devices, ids=[device.filter_string for device in valid_devices], ) -def test_cholesky(device): - data = [[[1.0, -2.0], [2.0, 5.0]], [[1.0, -2.0], [2.0, 5.0]]] - numpy_data = numpy.array(data) - dpnp_data = dpnp.array(data, device=device) +def test_cholesky(data, is_empty, device): + if is_empty: + numpy_data = numpy.empty(data, dtype=dpnp.default_float_type(device)) + else: + numpy_data = numpy.array(data, dtype=dpnp.default_float_type(device)) + dpnp_data = dpnp.array(numpy_data, device=device) result = dpnp.linalg.cholesky(dpnp_data) expected = numpy.linalg.cholesky(numpy_data) - assert_array_equal(expected, result) + assert_dtype_allclose(result, expected) expected_queue = dpnp_data.get_array().sycl_queue result_queue = result.get_array().sycl_queue From 8eeeb09b070908c1296d1e6c4642b1780551408b Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 8 Dec 2023 13:18:48 +0100 Subject: [PATCH 09/26] Expand test scope in public CI --- .github/workflows/conda-package.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index 1a6650798e92..f3ae98fffd87 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -30,6 +30,7 @@ env: test_umath.py test_usm_type.py third_party/cupy/core_tests + third_party/cupy/linalg_tests/test_decomposition.py third_party/cupy/linalg_tests/test_product.py third_party/cupy/logic_tests/test_comparison.py third_party/cupy/logic_tests/test_truth.py From 72d64438f16c2d53ff9a21f27b24c0cddf2af765 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 8 Dec 2023 13:19:59 +0100 Subject: [PATCH 10/26] Add more tests for dpnp.linalg.cholesky --- tests/test_usm_type.py | 22 ++++++++++++++ .../cupy/linalg_tests/test_decomposition.py | 30 ++++++++++++------- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 4982ed424140..0725eee8e40a 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -486,3 +486,25 @@ def test_take(func, usm_type_x, usm_type_ind): assert x.usm_type == usm_type_x assert ind.usm_type == usm_type_ind assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_ind]) + + +@pytest.mark.parametrize( + "data, is_empty", + [ + ([[1, -2], [2, 5]], False), + ([[[1, -2], [2, 5]], [[1, -2], [2, 5]]], False), + ((0, 0), True), + ((3, 0, 0), True), + ], + ids=["2D", "3D", "Empty_2D", "Empty_3D"], +) +@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) +def test_cholesky(data, is_empty, usm_type): + if is_empty: + x = dp.empty(data, dtype=dp.default_float_type(), usm_type=usm_type) + else: + x = dp.array(data, dtype=dp.default_float_type(), usm_type=usm_type) + + result = dp.linalg.cholesky(x) + + assert x.usm_type == result.usm_type diff --git a/tests/third_party/cupy/linalg_tests/test_decomposition.py b/tests/third_party/cupy/linalg_tests/test_decomposition.py index 8e357768ab7e..3279faac06a1 100644 --- a/tests/third_party/cupy/linalg_tests/test_decomposition.py +++ b/tests/third_party/cupy/linalg_tests/test_decomposition.py @@ -69,16 +69,23 @@ def test_decomposition(self, dtype): # np.linalg.cholesky only uses a lower triangle of an array self.check_L(numpy.array([[1, 2], [1, 9]], dtype)) - # @testing.for_dtypes([ - # numpy.int32, numpy.int64, numpy.uint32, numpy.uint64, - # numpy.float32, numpy.float64, numpy.complex64, numpy.complex128]) - # def test_batched_decomposition(self, dtype): - # # if not cusolver.check_availability('potrfBatched'): - # # pytest.skip('potrfBatched is not available') - # Ab1 = random_matrix((3, 5, 5), dtype, scale=(10, 10000), sym=True) - # self.check_L(Ab1) - # Ab2 = random_matrix((2, 2, 5, 5), dtype, scale=(10, 10000), sym=True) - # self.check_L(Ab2) + @testing.for_dtypes( + [ + numpy.int32, + numpy.int64, + numpy.uint32, + numpy.uint64, + numpy.float32, + numpy.float64, + numpy.complex64, + numpy.complex128, + ] + ) + def test_batched_decomposition(self, dtype): + Ab1 = random_matrix((3, 5, 5), dtype, scale=(10, 10000), sym=True) + self.check_L(Ab1) + Ab2 = random_matrix((2, 2, 5, 5), dtype, scale=(10, 10000), sym=True) + self.check_L(Ab2) @pytest.mark.parametrize( "shape", @@ -113,7 +120,8 @@ def check_L(self, array): with pytest.raises((numpy.linalg.LinAlgError, ValueError)): xp.linalg.cholesky(a) - @pytest.mark.skipif(is_cpu_device(), reason="MKL bug") + # TODO: remove skipif when MKLD-16626 is resolved + @pytest.mark.skipif(is_cpu_device(), reason="MKLD-16626") @testing.for_dtypes( [ numpy.int32, From 253dc938d0da3a1fb87bb9b27a658fbaef4c4d84 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 16 Jan 2024 16:00:42 +0100 Subject: [PATCH 11/26] Remove TODOs in cholesky() and update docstings --- dpnp/linalg/dpnp_iface_linalg.py | 46 +++----------------------------- 1 file changed, 4 insertions(+), 42 deletions(-) diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index ea9e0935684f..8efecc75e89c 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -75,39 +75,15 @@ def cholesky(a): """ - Cholesky decomposition. - - Return the Cholesky decomposition, `L * L.H`, of the square matrix `a`, - where `L` is lower-triangular and .H is the conjugate transpose operator - (which is the ordinary transpose if `a` is real-valued). `a` must be - Hermitian (symmetric if real-valued) and positive-definite. No - checking is performed to verify whether `a` is Hermitian or not. - In addition, only the lower-triangular and diagonal elements of `a` - are used. Only `L` is actually returned. + Compute the Cholesky decomposition of a square array. For full documentation refer to :obj:`numpy.linalg.cholesky`. - Parameters - ---------- - a : (..., M, M) dpnp.ndarray - Hermitian (symmetric if all elements are real), positive-definite - input matrix. - - Returns - ------- - L : (..., M, M) dpnp.ndarray - Lower-triangular Cholesky factor of `a`. Returns `a` - matrix object if `a` is a matrix object. - Limitations ----------- Parameter `a` is supported as :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`. Input array data types are limited by supported DPNP :ref:`Data types`. - See Also - -------- - :obj:`scipy.linalg.cholesky` : Similar function in SciPy. - Examples -------- >>> import dpnp as np @@ -125,23 +101,9 @@ def cholesky(a): """ - # TODO: use _assert_dpnp_array - if not dpnp.is_supported_array_type(a): - raise TypeError( - "An array must be any of supported type, but got {}".format(type(a)) - ) - - # TODO: use _assert_stacked_2d - if a.ndim < 2: - raise ValueError( - f"{a.ndim}-dimensional array given. The input " - "array must be at least two-dimensional" - ) - - # TODO: use _assert_stacked_square - n, m = a.shape[-2:] - if m != n: - raise ValueError("Last 2 dimensions of the input array must be square") + dpnp.check_supported_arrays_type(a) + check_stacked_2d(a) + check_stacked_square(a) return dpnp_cholesky(a) From 0613072b5fdb68c218caf6b0725c90ba14fc5d80 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 16 Jan 2024 16:01:26 +0100 Subject: [PATCH 12/26] Use _common_type in dpnp_cholesky --- dpnp/linalg/dpnp_utils_linalg.py | 24 +++--------------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index c01e25ded714..d5c9b110a00b 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -548,29 +548,11 @@ def dpnp_cholesky(a): a_sycl_queue = a.sycl_queue a_usm_type = a.usm_type - a_shape = a.shape - n = a.shape[-2] + res_type = _common_type(a) - # TODO: Use linalg_common_type from #1598 - if dpnp.issubdtype(a.dtype, dpnp.floating): - res_type = ( - a.dtype - if a_sycl_queue.sycl_device.has_aspect_fp64 - else dpnp.float32 - ) - elif dpnp.issubdtype(a.dtype, dpnp.complexfloating): - res_type = ( - a.dtype - if a_sycl_queue.sycl_device.has_aspect_fp64 - else dpnp.complex64 - ) - else: - res_type = ( - dpnp.float64 - if a_sycl_queue.sycl_device.has_aspect_fp64 - else dpnp.float32 - ) + a_shape = a.shape + n = a.shape[-2] if a.size == 0: return dpnp.empty( From 1f87a65810bb54cce484febaa9ae546d95efdcda Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 16 Jan 2024 16:11:07 +0100 Subject: [PATCH 13/26] Update dpnp_cholesky and dpnp_cholesky_batch --- dpnp/linalg/dpnp_utils_linalg.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index d5c9b110a00b..56162e291800 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -488,11 +488,11 @@ def dpnp_det(a): return det.reshape(shape) -def _dpnp_cholesky_batch(a, res_type): +def dpnp_cholesky_batch(a, res_type): """ - _dpnp_cholesky_batch(a, res_type) + dpnp_cholesky_batch(a, res_type) - Batched Cholesky decomposition. + Return the batched Cholesky decomposition of `a` array. """ @@ -507,7 +507,7 @@ def _dpnp_cholesky_batch(a, res_type): batch_size = a.shape[0] a_usm_arr = dpnp.get_usm_ndarray(a) - # oneMKL LAPACK potrf overwrites `a` + # `a` must be copied because potrf_batch destroys the input matrix a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=a_usm_type) # use DPCTL tensor function to fill the сopy of the input array @@ -519,8 +519,8 @@ def _dpnp_cholesky_batch(a, res_type): a_stride = a_h.strides[0] # Call the LAPACK extension function _potrf_batch - # to perform the Cholesky factorization of a batch of - # symmetric positive-definite matrix + # to computes the Cholesky decomposition of a batch of + # symmetric positive-definite matrices ht_lapack_ev, _ = li._potrf_batch( a_sycl_queue, a_h.get_array(), @@ -542,7 +542,7 @@ def dpnp_cholesky(a): """ dpnp_cholesky(a) - Return the Cholesky factorization. + Return the Cholesky decomposition of `a` array. """ @@ -563,11 +563,11 @@ def dpnp_cholesky(a): ) if a.ndim > 2: - return _dpnp_cholesky_batch(a, res_type) + return dpnp_cholesky_batch(a, res_type) a_usm_arr = dpnp.get_usm_ndarray(a) - # oneMKL LAPACK potrf overwrites `a` + # `a` must be copied because potrf destroys the input matrix a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=a_usm_type) # use DPCTL tensor function to fill the сopy of the input array @@ -577,7 +577,7 @@ def dpnp_cholesky(a): ) # Call the LAPACK extension function _potrf - # to computes the Cholesky factorization + # to computes the Cholesky decomposition ht_lapack_ev, _ = li._potrf( a_sycl_queue, n, From 3c62207e8f42da3e408729af809beabf6d68f720 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 16 Jan 2024 16:20:12 +0100 Subject: [PATCH 14/26] Keep the lexicographic order --- dpnp/backend/extensions/lapack/lapack_py.cpp | 2 +- dpnp/linalg/dpnp_utils_linalg.py | 94 ++++++++++---------- 2 files changed, 48 insertions(+), 48 deletions(-) diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 2598e0ce67a4..2864d9f9d6e4 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -46,8 +46,8 @@ void init_dispatch_vectors(void) lapack_ext::init_gesv_dispatch_vector(); lapack_ext::init_getrf_batch_dispatch_vector(); lapack_ext::init_getrf_dispatch_vector(); - lapack_ext::init_potrf_dispatch_vector(); lapack_ext::init_potrf_batch_dispatch_vector(); + lapack_ext::init_potrf_dispatch_vector(); lapack_ext::init_syevd_dispatch_vector(); } diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 56162e291800..37dac5191a24 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -441,53 +441,6 @@ def _lu_factor(a, res_type): return (a_h, ipiv_h, dev_info_array) -def dpnp_det(a): - """ - dpnp_det(a) - - Returns the determinant of `a` array. - - """ - - a_usm_type = a.usm_type - a_sycl_queue = a.sycl_queue - - res_type = _common_type(a) - - a_shape = a.shape - shape = a_shape[:-2] - n = a_shape[-2] - - if a.size == 0: - # empty batch (result is empty, too) or empty matrices det([[]]) == 1 - det = dpnp.ones( - shape, - dtype=res_type, - usm_type=a_usm_type, - sycl_queue=a_sycl_queue, - ) - return det - - lu, ipiv, dev_info = _lu_factor(a, res_type) - - # Transposing 'lu' to swap the last two axes for compatibility - # with 'dpnp.diagonal' as it does not support 'axis1' and 'axis2' arguments. - # TODO: Replace with 'dpnp.diagonal(lu, axis1=-2, axis2=-1)' when supported. - lu_transposed = lu.transpose(-2, -1, *range(lu.ndim - 2)) - diag = dpnp.diagonal(lu_transposed) - - det = dpnp.prod(dpnp.abs(diag), axis=-1) - - sign = _calculate_determinant_sign(ipiv, diag, res_type, n) - - det = sign * det - det = det.astype(res_type, copy=False) - singular = dev_info > 0 - det = dpnp.where(singular, res_type.type(0), det) - - return det.reshape(shape) - - def dpnp_cholesky_batch(a, res_type): """ dpnp_cholesky_batch(a, res_type) @@ -593,6 +546,53 @@ def dpnp_cholesky(a): return a_h +def dpnp_det(a): + """ + dpnp_det(a) + + Returns the determinant of `a` array. + + """ + + a_usm_type = a.usm_type + a_sycl_queue = a.sycl_queue + + res_type = _common_type(a) + + a_shape = a.shape + shape = a_shape[:-2] + n = a_shape[-2] + + if a.size == 0: + # empty batch (result is empty, too) or empty matrices det([[]]) == 1 + det = dpnp.ones( + shape, + dtype=res_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + return det + + lu, ipiv, dev_info = _lu_factor(a, res_type) + + # Transposing 'lu' to swap the last two axes for compatibility + # with 'dpnp.diagonal' as it does not support 'axis1' and 'axis2' arguments. + # TODO: Replace with 'dpnp.diagonal(lu, axis1=-2, axis2=-1)' when supported. + lu_transposed = lu.transpose(-2, -1, *range(lu.ndim - 2)) + diag = dpnp.diagonal(lu_transposed) + + det = dpnp.prod(dpnp.abs(diag), axis=-1) + + sign = _calculate_determinant_sign(ipiv, diag, res_type, n) + + det = sign * det + det = det.astype(res_type, copy=False) + singular = dev_info > 0 + det = dpnp.where(singular, res_type.type(0), det) + + return det.reshape(shape) + + def dpnp_eigh(a, UPLO): """ dpnp_eigh(a, UPLO) From c90a0068b903355a2b48da96e19d57ae8c7fe26c Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 16 Jan 2024 16:32:47 +0100 Subject: [PATCH 15/26] Remove passing n parameter to _potrf --- dpnp/backend/extensions/lapack/lapack_py.cpp | 2 +- dpnp/backend/extensions/lapack/potrf.cpp | 3 +-- dpnp/backend/extensions/lapack/potrf.hpp | 1 - dpnp/linalg/dpnp_utils_linalg.py | 2 -- 4 files changed, 2 insertions(+), 6 deletions(-) diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 2864d9f9d6e4..0a11ee7a1706 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -98,7 +98,7 @@ PYBIND11_MODULE(_lapack_impl, m) m.def("_potrf", &lapack_ext::potrf, "Call `potrf` from OneMKL LAPACK library to return " "the Cholesky factorization of a symmetric positive-definite matrix", - py::arg("sycl_queue"), py::arg("n"), py::arg("a_array"), + py::arg("sycl_queue"), py::arg("a_array"), py::arg("depends") = py::list()); m.def("_potrf_batch", &lapack_ext::potrf_batch, diff --git a/dpnp/backend/extensions/lapack/potrf.cpp b/dpnp/backend/extensions/lapack/potrf.cpp index 88d305a9a3c1..2b84bece4722 100644 --- a/dpnp/backend/extensions/lapack/potrf.cpp +++ b/dpnp/backend/extensions/lapack/potrf.cpp @@ -126,11 +126,9 @@ static sycl::event potrf_impl(sycl::queue exec_q, std::pair potrf(sycl::queue q, - const std::int64_t n, dpctl::tensor::usm_ndarray a_array, const std::vector &depends) { - auto array_types = dpctl_td_ns::usm_ndarray_types(); int a_array_type_id = array_types.typenum_to_lookup_id(a_array.get_typenum()); @@ -143,6 +141,7 @@ std::pair } char *a_array_data = a_array.get_data(); + const std::int64_t n = a_array.get_shape_raw()[0]; const std::int64_t lda = std::max(1UL, n); oneapi::mkl::uplo upper_lower = oneapi::mkl::uplo::upper; diff --git a/dpnp/backend/extensions/lapack/potrf.hpp b/dpnp/backend/extensions/lapack/potrf.hpp index 35efda3b9b43..2713f2231a0c 100644 --- a/dpnp/backend/extensions/lapack/potrf.hpp +++ b/dpnp/backend/extensions/lapack/potrf.hpp @@ -40,7 +40,6 @@ namespace lapack { extern std::pair potrf(sycl::queue exec_q, - const std::int64_t n, dpctl::tensor::usm_ndarray a_array, const std::vector &depends = {}); diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 37dac5191a24..384b070d9370 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -505,7 +505,6 @@ def dpnp_cholesky(a): res_type = _common_type(a) a_shape = a.shape - n = a.shape[-2] if a.size == 0: return dpnp.empty( @@ -533,7 +532,6 @@ def dpnp_cholesky(a): # to computes the Cholesky decomposition ht_lapack_ev, _ = li._potrf( a_sycl_queue, - n, a_h.get_array(), [a_copy_ev], ) From ad17d19930e821f3bd71c3c35c707d5ea490f4eb Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 16 Jan 2024 16:53:06 +0100 Subject: [PATCH 16/26] Add additional checks to potrf and potrf_batch --- dpnp/backend/extensions/lapack/potrf.cpp | 25 ++++++++++++++++++- .../backend/extensions/lapack/potrf_batch.cpp | 23 +++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/dpnp/backend/extensions/lapack/potrf.cpp b/dpnp/backend/extensions/lapack/potrf.cpp index 2b84bece4722..a1eb145003e7 100644 --- a/dpnp/backend/extensions/lapack/potrf.cpp +++ b/dpnp/backend/extensions/lapack/potrf.cpp @@ -129,6 +129,29 @@ std::pair dpctl::tensor::usm_ndarray a_array, const std::vector &depends) { + const int a_array_nd = a_array.get_ndim(); + + if (a_array_nd != 2) { + throw py::value_error( + "The input array has ndim=" + std::to_string(a_array_nd) + + ", but a 2-dimensional array is expected."); + } + + const py::ssize_t *a_array_shape = a_array.get_shape_raw(); + + if (a_array_shape[0] != a_array_shape[1]) { + throw py::value_error("The input array must be square," + " but got a shape of (" + + std::to_string(a_array_shape[0]) + ", " + + std::to_string(a_array_shape[1]) + ")."); + } + + bool is_a_array_c_contig = a_array.is_c_contiguous(); + if (!is_a_array_c_contig) { + throw py::value_error("The input array " + "must be C-contiguous"); + } + auto array_types = dpctl_td_ns::usm_ndarray_types(); int a_array_type_id = array_types.typenum_to_lookup_id(a_array.get_typenum()); @@ -141,7 +164,7 @@ std::pair } char *a_array_data = a_array.get_data(); - const std::int64_t n = a_array.get_shape_raw()[0]; + const std::int64_t n = a_array_shape[0]; const std::int64_t lda = std::max(1UL, n); oneapi::mkl::uplo upper_lower = oneapi::mkl::uplo::upper; diff --git a/dpnp/backend/extensions/lapack/potrf_batch.cpp b/dpnp/backend/extensions/lapack/potrf_batch.cpp index e9e8c0c9e1b1..378a18053077 100644 --- a/dpnp/backend/extensions/lapack/potrf_batch.cpp +++ b/dpnp/backend/extensions/lapack/potrf_batch.cpp @@ -141,6 +141,29 @@ std::pair std::int64_t batch_size, const std::vector &depends) { + const int a_array_nd = a_array.get_ndim(); + + if (a_array_nd < 3) { + throw py::value_error( + "The input array has ndim=" + std::to_string(a_array_nd) + + ", but a 3-dimensional or higher array is expected."); + } + + const py::ssize_t *a_array_shape = a_array.get_shape_raw(); + + if (a_array_shape[a_array_nd - 1] != a_array_shape[a_array_nd - 2]) { + throw py::value_error( + "The last two dimensions of the input array must be square," + " but got a shape of (" + + std::to_string(a_array_shape[a_array_nd - 1]) + ", " + + std::to_string(a_array_shape[a_array_nd - 2]) + ")."); + } + + bool is_a_array_c_contig = a_array.is_c_contiguous(); + if (!is_a_array_c_contig) { + throw py::value_error("The input array " + "must be C-contiguous"); + } auto array_types = dpctl_td_ns::usm_ndarray_types(); int a_array_type_id = From e36bdcb32faf745e07134c190cd00ea4909b7d6b Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 16 Jan 2024 19:32:11 +0100 Subject: [PATCH 17/26] Extend potrf error handler --- dpnp/backend/extensions/lapack/potrf.cpp | 39 ++++++++++++++++-------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/dpnp/backend/extensions/lapack/potrf.cpp b/dpnp/backend/extensions/lapack/potrf.cpp index a1eb145003e7..2355659415bf 100644 --- a/dpnp/backend/extensions/lapack/potrf.cpp +++ b/dpnp/backend/extensions/lapack/potrf.cpp @@ -29,6 +29,7 @@ #include "utils/memory_overlap.hpp" #include "utils/type_utils.hpp" +#include "linalg_exceptions.hpp" #include "potrf.hpp" #include "types_matrix.hpp" @@ -76,6 +77,7 @@ static sycl::event potrf_impl(sycl::queue exec_q, std::stringstream error_msg; std::int64_t info = 0; + bool is_exception_caught = false; sycl::event potrf_event; try { @@ -83,7 +85,10 @@ static sycl::event potrf_impl(sycl::queue exec_q, potrf_event = oneapi::mkl::lapack::potrf( exec_q, - upper_lower, // + upper_lower, // An enumeration value of type oneapi::mkl::uplo: + // oneapi::mkl::uplo::upper for the upper triangular + // part; oneapi::mkl::uplo::lower for the lower + // triangular part. n, // Order of the square matrix; (0 ≤ n). a, // Pointer to the n-by-n matrix. lda, // The leading dimension of `a`. @@ -91,27 +96,37 @@ static sycl::event potrf_impl(sycl::queue exec_q, // routine for storing intermediate results. scratchpad_size, depends); } catch (mkl_lapack::exception const &e) { - error_msg - << "Unexpected MKL exception caught during potrf() call:\nreason: " - << e.what() << "\ninfo: " << e.info(); + is_exception_caught = true; info = e.info(); + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail(); + } + else if (info > 0 && e.detail() == 0) { + sycl::free(scratchpad, exec_q); + throw LinAlgError("Matrix is not positive definite."); + } + else { + error_msg << "Unexpected MKL exception caught during getrf() " + "call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + } } catch (sycl::exception const &e) { + is_exception_caught = true; error_msg << "Unexpected SYCL exception caught during potrf() call:\n" << e.what(); - info = -1; } - if (info != 0) // an unexpected error occurs + if (is_exception_caught) // an unexpected error occurs { if (scratchpad != nullptr) { sycl::free(scratchpad, exec_q); } - - // TODO: use LinAlgError - if (info == 2) { - throw py::value_error("Matrix is not positive definite"); - } - throw std::runtime_error(error_msg.str()); } From 8565346476aa60cd8b68b846e3e434ff6f5ccc7c Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 16 Jan 2024 20:29:51 +0100 Subject: [PATCH 18/26] Extend potrf_batch error handler --- .../backend/extensions/lapack/potrf_batch.cpp | 63 ++++++++++++++----- 1 file changed, 49 insertions(+), 14 deletions(-) diff --git a/dpnp/backend/extensions/lapack/potrf_batch.cpp b/dpnp/backend/extensions/lapack/potrf_batch.cpp index 378a18053077..cf74f7c9100c 100644 --- a/dpnp/backend/extensions/lapack/potrf_batch.cpp +++ b/dpnp/backend/extensions/lapack/potrf_batch.cpp @@ -29,6 +29,7 @@ #include "utils/memory_overlap.hpp" #include "utils/type_utils.hpp" +#include "linalg_exceptions.hpp" #include "potrf.hpp" #include "types_matrix.hpp" @@ -82,6 +83,7 @@ static sycl::event potrf_batch_impl(sycl::queue exec_q, std::stringstream error_msg; std::int64_t info = 0; + bool is_exception_caught = false; sycl::event potrf_batch_event; try { @@ -89,7 +91,10 @@ static sycl::event potrf_batch_impl(sycl::queue exec_q, potrf_batch_event = oneapi::mkl::lapack::potrf_batch( exec_q, - upper_lower, // + upper_lower, // An enumeration value of type oneapi::mkl::uplo: + // oneapi::mkl::uplo::upper for the upper triangular + // part; oneapi::mkl::uplo::lower for the lower + // triangular part. n, // Order of each square matrix in the batch; (0 ≤ n). a, // Pointer to the batch of matrices. lda, // The leading dimension of `a`. @@ -99,28 +104,58 @@ static sycl::event potrf_batch_impl(sycl::queue exec_q, scratchpad, // Pointer to scratchpad memory to be used by MKL // routine for storing intermediate results. scratchpad_size, depends); - } catch (mkl_lapack::exception const &e) { + } catch (mkl_lapack::batch_error const &be) { + // Get the indices of matrices within the batch that encountered an + // error + auto error_matrices_ids = be.ids(); + error_msg - << "Unexpected MKL exception caught during potrf() call:\nreason: " - << e.what() << "\ninfo: " << e.info(); + << "Matrix is not positive definite. Errors in matrices with IDs: "; + for (size_t i = 0; i < error_matrices_ids.size(); ++i) { + error_msg << error_matrices_ids[i]; + if (i < error_matrices_ids.size() - 1) { + error_msg << ", "; + } + } + error_msg << "."; + + sycl::free(scratchpad, exec_q); + throw LinAlgError(error_msg.str().c_str()); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; info = e.info(); + + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail(); + } + else if (info != 0 && e.detail() == 0) { + error_msg << "Error in batch processing. " + "Number of failed calculations: " + << info; + } + else { + error_msg << "Unexpected MKL exception caught during potrf_batch() " + "call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + } } catch (sycl::exception const &e) { - error_msg << "Unexpected SYCL exception caught during potrf() call:\n" - << e.what(); - info = -1; + is_exception_caught = true; + error_msg + << "Unexpected SYCL exception caught during potrf_batch() call:\n" + << e.what(); } - if (info != 0) // an unexpected error occurs + if (is_exception_caught) // an unexpected error occurs { if (scratchpad != nullptr) { sycl::free(scratchpad, exec_q); } - - // TODO: use LinAlgError - if (info == 2) { - throw py::value_error("Matrix is not positive definite"); - } - throw std::runtime_error(error_msg.str()); } From 9d7411bc06f67d8bee3a90fd4fc715e3203d38e3 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 16 Jan 2024 22:36:09 +0100 Subject: [PATCH 19/26] Update tests for dpnp.linalg.cholesky --- tests/test_linalg.py | 116 ++++++++++++------ .../cupy/linalg_tests/test_decomposition.py | 2 +- 2 files changed, 81 insertions(+), 37 deletions(-) diff --git a/tests/test_linalg.py b/tests/test_linalg.py index cb780abd9253..201c4c3f83da 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -44,44 +44,88 @@ def vvsort(val, vec, size, xp): vec[:, imax] = temp -@pytest.mark.parametrize( - "array", - [ - [[[1, -2], [2, 5]]], - [[[1.0, -2.0], [2.0, 5.0]]], - [[[1.0, -2.0], [2.0, 5.0]], [[1.0, -2.0], [2.0, 5.0]]], - ], - ids=[ - "[[[1, -2], [2, 5]]]", - "[[[1., -2.], [2., 5.]]]", - "[[[1., -2.], [2., 5.]], [[1., -2.], [2., 5.]]]", - ], -) -def test_cholesky(array): - a = numpy.array(array) - ia = inp.array(a) - result = inp.linalg.cholesky(ia) - expected = numpy.linalg.cholesky(a) - assert_array_equal(expected, result) +class TestCholesky: + @pytest.mark.parametrize( + "array", + [ + [[1, 2], [2, 5]], + [[[5, 2], [2, 6]], [[7, 3], [3, 8]], [[3, 1], [1, 4]]], + [ + [[[5, 2], [2, 5]], [[6, 3], [3, 6]]], + [[[7, 2], [2, 7]], [[8, 3], [3, 8]]], + ], + ], + ids=[ + "2D_array", + "3D_array", + "4D_array", + ], + ) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_cholesky_3d_4d(self, array, dtype): + a = numpy.array(array, dtype=dtype) + ia = inp.array(a) + result = inp.linalg.cholesky(ia) + expected = numpy.linalg.cholesky(a) + assert_dtype_allclose(result, expected) + def test_cholesky_strides(self): + a_np = numpy.array( + [ + [5, 2, 0, 0, 1], + [2, 6, 0, 0, 2], + [0, 0, 7, 0, 0], + [0, 0, 0, 4, 0], + [1, 2, 0, 0, 5], + ] + ) -@pytest.mark.parametrize( - "shape", - [ - (0, 0), - (3, 0, 0), - ], - ids=[ - "(0, 0)", - "(3, 0, 0)", - ], -) -def test_cholesky_0D(shape): - a = numpy.empty(shape) - ia = inp.array(a) - result = inp.linalg.cholesky(ia) - expected = numpy.linalg.cholesky(a) - assert_array_equal(expected, result) + a_dp = inp.array(a_np) + + # positive strides + expected = numpy.linalg.cholesky(a_np[::2, ::2]) + result = inp.linalg.cholesky(a_dp[::2, ::2]) + assert_allclose(expected, result, rtol=1e-3, atol=1e-4) + + # negative strides + expected = numpy.linalg.cholesky(a_np[::-2, ::-2]) + result = inp.linalg.cholesky(a_dp[::-2, ::-2]) + assert_allclose(expected, result, rtol=1e-3, atol=1e-4) + + @pytest.mark.parametrize( + "shape", + [ + (0, 0), + (3, 0, 0), + (0, 2, 2), + ], + ids=[ + "(0, 0)", + "(3, 0, 0)", + "(0, 2, 2)", + ], + ) + def test_cholesky_empty(self, shape): + a = numpy.empty(shape) + ia = inp.array(a) + result = inp.linalg.cholesky(ia) + expected = numpy.linalg.cholesky(a) + assert_array_equal(expected, result) + + def test_cholesky_errors(self): + a_dp = inp.array([[1, 2], [2, 5]], dtype="float32") + + # unsupported type + a_np = inp.asnumpy(a_dp) + assert_raises(TypeError, inp.linalg.cholesky, a_np) + + # a.ndim < 2 + a_dp_ndim_1 = a_dp.flatten() + assert_raises(inp.linalg.LinAlgError, inp.linalg.cholesky, a_dp_ndim_1) + + # a is not square + a_dp = inp.ones((2, 3)) + assert_raises(inp.linalg.LinAlgError, inp.linalg.cholesky, a_dp) @pytest.mark.parametrize( diff --git a/tests/third_party/cupy/linalg_tests/test_decomposition.py b/tests/third_party/cupy/linalg_tests/test_decomposition.py index 3279faac06a1..42bcf122ff40 100644 --- a/tests/third_party/cupy/linalg_tests/test_decomposition.py +++ b/tests/third_party/cupy/linalg_tests/test_decomposition.py @@ -117,7 +117,7 @@ class TestCholeskyInvalid(unittest.TestCase): def check_L(self, array): for xp in (numpy, cupy): a = xp.asarray(array) - with pytest.raises((numpy.linalg.LinAlgError, ValueError)): + with pytest.raises(xp.linalg.LinAlgError): xp.linalg.cholesky(a) # TODO: remove skipif when MKLD-16626 is resolved From 3d0484c54e69b16ae0a82353f96976f57e9dbdb9 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 17 Jan 2024 12:46:09 +0100 Subject: [PATCH 20/26] Update license year --- dpnp/backend/extensions/lapack/getrf.cpp | 2 +- dpnp/backend/extensions/lapack/getrf.hpp | 2 +- dpnp/backend/extensions/lapack/getrf_batch.cpp | 2 +- dpnp/backend/extensions/lapack/potrf.cpp | 2 +- dpnp/backend/extensions/lapack/potrf.hpp | 2 +- dpnp/backend/extensions/lapack/potrf_batch.cpp | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dpnp/backend/extensions/lapack/getrf.cpp b/dpnp/backend/extensions/lapack/getrf.cpp index faebb9f7d42f..f97d395bcd64 100644 --- a/dpnp/backend/extensions/lapack/getrf.cpp +++ b/dpnp/backend/extensions/lapack/getrf.cpp @@ -1,5 +1,5 @@ //***************************************************************************** -// Copyright (c) 2023, Intel Corporation +// Copyright (c) 2024, Intel Corporation // All rights reserved. // // Redistribution and use in source and binary forms, with or without diff --git a/dpnp/backend/extensions/lapack/getrf.hpp b/dpnp/backend/extensions/lapack/getrf.hpp index d8ea425f6804..fee9b209426e 100644 --- a/dpnp/backend/extensions/lapack/getrf.hpp +++ b/dpnp/backend/extensions/lapack/getrf.hpp @@ -1,5 +1,5 @@ //***************************************************************************** -// Copyright (c) 2023, Intel Corporation +// Copyright (c) 2024, Intel Corporation // All rights reserved. // // Redistribution and use in source and binary forms, with or without diff --git a/dpnp/backend/extensions/lapack/getrf_batch.cpp b/dpnp/backend/extensions/lapack/getrf_batch.cpp index 0944d7437aed..76977bf6628a 100644 --- a/dpnp/backend/extensions/lapack/getrf_batch.cpp +++ b/dpnp/backend/extensions/lapack/getrf_batch.cpp @@ -1,5 +1,5 @@ //***************************************************************************** -// Copyright (c) 2023, Intel Corporation +// Copyright (c) 2024, Intel Corporation // All rights reserved. // // Redistribution and use in source and binary forms, with or without diff --git a/dpnp/backend/extensions/lapack/potrf.cpp b/dpnp/backend/extensions/lapack/potrf.cpp index 2355659415bf..23533519efcb 100644 --- a/dpnp/backend/extensions/lapack/potrf.cpp +++ b/dpnp/backend/extensions/lapack/potrf.cpp @@ -1,5 +1,5 @@ //***************************************************************************** -// Copyright (c) 2023, Intel Corporation +// Copyright (c) 2024, Intel Corporation // All rights reserved. // // Redistribution and use in source and binary forms, with or without diff --git a/dpnp/backend/extensions/lapack/potrf.hpp b/dpnp/backend/extensions/lapack/potrf.hpp index 2713f2231a0c..ac88026c3057 100644 --- a/dpnp/backend/extensions/lapack/potrf.hpp +++ b/dpnp/backend/extensions/lapack/potrf.hpp @@ -1,5 +1,5 @@ //***************************************************************************** -// Copyright (c) 2023, Intel Corporation +// Copyright (c) 2024, Intel Corporation // All rights reserved. // // Redistribution and use in source and binary forms, with or without diff --git a/dpnp/backend/extensions/lapack/potrf_batch.cpp b/dpnp/backend/extensions/lapack/potrf_batch.cpp index cf74f7c9100c..4f50cc4fdd50 100644 --- a/dpnp/backend/extensions/lapack/potrf_batch.cpp +++ b/dpnp/backend/extensions/lapack/potrf_batch.cpp @@ -1,5 +1,5 @@ //***************************************************************************** -// Copyright (c) 2023, Intel Corporation +// Copyright (c) 2024, Intel Corporation // All rights reserved. // // Redistribution and use in source and binary forms, with or without From f62ea45f761258a7840bd71501a4ab730a94f0e2 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 17 Jan 2024 13:00:54 +0100 Subject: [PATCH 21/26] Update cholesky docstrings --- dpnp/linalg/dpnp_iface_linalg.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index 8efecc75e89c..148399c7e9ac 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -75,14 +75,28 @@ def cholesky(a): """ - Compute the Cholesky decomposition of a square array. + Cholesky decomposition. + + Return the Cholesky decomposition, `L * L.H`, of the square matrix `a`, + where `L` is lower-triangular and .H is the conjugate transpose operator + (which is the ordinary transpose if `a` is real-valued). `a` must be + Hermitian (symmetric if real-valued) and positive-definite. No + checking is performed to verify whether `a` is Hermitian or not. + In addition, only the lower-triangular and diagonal elements of `a` + are used. Only `L` is actually returned. For full documentation refer to :obj:`numpy.linalg.cholesky`. - Limitations - ----------- - Parameter `a` is supported as :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`. - Input array data types are limited by supported DPNP :ref:`Data types`. + Parameters + ---------- + a : (..., M, M) {dpnp.ndarray, usm_ndarray} + Hermitian (symmetric if all elements are real), positive-definite + input matrix. + + Returns + ------- + L : (..., M, M) dpnp.ndarray + Lower-triangular Cholesky factor of `a`. Examples -------- From 0668d747481cbdc34ec82e1e5197001ea1a3809c Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 17 Jan 2024 20:51:45 +0100 Subject: [PATCH 22/26] Add support upper paramenetr for potrf --- dpnp/backend/extensions/lapack/lapack_py.cpp | 2 +- dpnp/backend/extensions/lapack/potrf.cpp | 18 +++++++----- dpnp/backend/extensions/lapack/potrf.hpp | 1 + dpnp/linalg/dpnp_iface_linalg.py | 31 +++++++++++++------- dpnp/linalg/dpnp_utils_linalg.py | 19 +++++++++--- 5 files changed, 48 insertions(+), 23 deletions(-) diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 0a11ee7a1706..eea43d71ee32 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -98,7 +98,7 @@ PYBIND11_MODULE(_lapack_impl, m) m.def("_potrf", &lapack_ext::potrf, "Call `potrf` from OneMKL LAPACK library to return " "the Cholesky factorization of a symmetric positive-definite matrix", - py::arg("sycl_queue"), py::arg("a_array"), + py::arg("sycl_queue"), py::arg("a_array"), py::arg("upper_lower"), py::arg("depends") = py::list()); m.def("_potrf_batch", &lapack_ext::potrf_batch, diff --git a/dpnp/backend/extensions/lapack/potrf.cpp b/dpnp/backend/extensions/lapack/potrf.cpp index 23533519efcb..e5b62b7e059c 100644 --- a/dpnp/backend/extensions/lapack/potrf.cpp +++ b/dpnp/backend/extensions/lapack/potrf.cpp @@ -48,7 +48,7 @@ namespace py = pybind11; namespace type_utils = dpctl::tensor::type_utils; typedef sycl::event (*potrf_impl_fn_ptr_t)(sycl::queue, - oneapi::mkl::uplo, + const oneapi::mkl::uplo, const std::int64_t, char *, std::int64_t, @@ -59,7 +59,7 @@ static potrf_impl_fn_ptr_t potrf_dispatch_vector[dpctl_td_ns::num_types]; template static sycl::event potrf_impl(sycl::queue exec_q, - oneapi::mkl::uplo upper_lower, + const oneapi::mkl::uplo upper_lower, const std::int64_t n, char *in_a, std::int64_t lda, @@ -142,6 +142,7 @@ static sycl::event potrf_impl(sycl::queue exec_q, std::pair potrf(sycl::queue q, dpctl::tensor::usm_ndarray a_array, + const std::int8_t upper_lower, const std::vector &depends) { const int a_array_nd = a_array.get_ndim(); @@ -161,10 +162,10 @@ std::pair std::to_string(a_array_shape[1]) + ")."); } - bool is_a_array_c_contig = a_array.is_c_contiguous(); - if (!is_a_array_c_contig) { + bool is_a_array_f_contig = a_array.is_f_contiguous(); + if (!is_a_array_f_contig) { throw py::value_error("The input array " - "must be C-contiguous"); + "must be F-contiguous"); } auto array_types = dpctl_td_ns::usm_ndarray_types(); @@ -181,11 +182,12 @@ std::pair char *a_array_data = a_array.get_data(); const std::int64_t n = a_array_shape[0]; const std::int64_t lda = std::max(1UL, n); - oneapi::mkl::uplo upper_lower = oneapi::mkl::uplo::upper; + const oneapi::mkl::uplo uplo_val = + static_cast(upper_lower); std::vector host_task_events; - sycl::event potrf_ev = potrf_fn(q, upper_lower, n, a_array_data, lda, - host_task_events, depends); + sycl::event potrf_ev = + potrf_fn(q, uplo_val, n, a_array_data, lda, host_task_events, depends); sycl::event args_ev = dpctl::utils::keep_args_alive(q, {a_array}, host_task_events); diff --git a/dpnp/backend/extensions/lapack/potrf.hpp b/dpnp/backend/extensions/lapack/potrf.hpp index ac88026c3057..fe52b5c85aad 100644 --- a/dpnp/backend/extensions/lapack/potrf.hpp +++ b/dpnp/backend/extensions/lapack/potrf.hpp @@ -41,6 +41,7 @@ namespace lapack extern std::pair potrf(sycl::queue exec_q, dpctl::tensor::usm_ndarray a_array, + const std::int8_t upper_lower, const std::vector &depends = {}); extern std::pair diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index 148399c7e9ac..2f7420d13292 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -73,17 +73,18 @@ ] -def cholesky(a): +def cholesky(a, upper=False): """ Cholesky decomposition. - Return the Cholesky decomposition, `L * L.H`, of the square matrix `a`, - where `L` is lower-triangular and .H is the conjugate transpose operator - (which is the ordinary transpose if `a` is real-valued). `a` must be - Hermitian (symmetric if real-valued) and positive-definite. No - checking is performed to verify whether `a` is Hermitian or not. - In addition, only the lower-triangular and diagonal elements of `a` - are used. Only `L` is actually returned. + Return the lower or upper Cholesky decomposition, ``L * L.H`` or + ``U.H * U``, of the square matrix ``a``, where ``L`` is lower-triangular, + ``U`` is upper-triangular, and ``.H`` is the conjugate transpose operator + (which is the ordinary transpose if ``a`` is real-valued). ``a`` must be + Hermitian (symmetric if real-valued) and positive-definite. No checking is + performed to verify whether ``a`` is Hermitian or not. In addition, only + the lower or upper-triangular and diagonal elements of ``a`` are used. + Only ``L`` or ``U`` is actually returned. For full documentation refer to :obj:`numpy.linalg.cholesky`. @@ -92,11 +93,15 @@ def cholesky(a): a : (..., M, M) {dpnp.ndarray, usm_ndarray} Hermitian (symmetric if all elements are real), positive-definite input matrix. + upper : bool, optional + If ``True``, the result must be the upper-triangular Cholesky factor. + If ``False``, the result must be the lower-triangular Cholesky factor. + Default: ``False``. Returns ------- L : (..., M, M) dpnp.ndarray - Lower-triangular Cholesky factor of `a`. + Lower or upper-triangular Cholesky factor of `a`. Examples -------- @@ -113,13 +118,19 @@ def cholesky(a): array([[1., 2.], [2., 5.]]) + The upper-triangular Cholesky factor can also be obtained: + + >>> np.linalg.cholesky(A, upper=True) + array([[ 1.+0.j, -0.-2.j], + [ 0.+0.j, 1.+0.j]] + """ dpnp.check_supported_arrays_type(a) check_stacked_2d(a) check_stacked_square(a) - return dpnp_cholesky(a) + return dpnp_cholesky(a, upper=upper) def cond(input, p=None): diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 384b070d9370..2ea63fae9979 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -43,6 +43,9 @@ _jobz = {"N": 0, "V": 1} _upper_lower = {"U": 0, "L": 1} +# Map booleans to MKL`s `uplo`` values: +# True -> 0 (upper), False -> 1 (lower). +_upper_lower_bool = {False: 1, True: 0} _real_types_map = { "float32": "float32", # single : single @@ -491,9 +494,9 @@ def dpnp_cholesky_batch(a, res_type): return a_h -def dpnp_cholesky(a): +def dpnp_cholesky(a, upper): """ - dpnp_cholesky(a) + dpnp_cholesky(a, upper) Return the Cholesky decomposition of `a` array. @@ -514,13 +517,16 @@ def dpnp_cholesky(a): sycl_queue=a_sycl_queue, ) + # Set `uplo` value for MKL functions based on boolean input + upper_lower = _upper_lower_bool[upper] + if a.ndim > 2: return dpnp_cholesky_batch(a, res_type) a_usm_arr = dpnp.get_usm_ndarray(a) # `a` must be copied because potrf destroys the input matrix - a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=a_usm_type) + a_h = dpnp.empty_like(a, order="F", dtype=res_type, usm_type=a_usm_type) # use DPCTL tensor function to fill the сopy of the input array # from the input array @@ -533,13 +539,18 @@ def dpnp_cholesky(a): ht_lapack_ev, _ = li._potrf( a_sycl_queue, a_h.get_array(), + upper_lower, [a_copy_ev], ) ht_lapack_ev.wait() a_ht_copy_ev.wait() - a_h = dpnp.tril(a_h) + # Get upper or lower-triangular matrix part as per `upper` value + if upper: + a_h = dpnp.triu(a_h) + else: + a_h = dpnp.tril(a_h) return a_h From 153f96337dfd788fa7a2aea731d0f46e9216a80d Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 18 Jan 2024 15:29:59 +0100 Subject: [PATCH 23/26] Add support upper paramenetr for potrf_batch and update dpnp_cholesky --- dpnp/backend/extensions/lapack/lapack_py.cpp | 4 +-- dpnp/backend/extensions/lapack/potrf.cpp | 6 ++-- dpnp/backend/extensions/lapack/potrf.hpp | 7 ++-- .../backend/extensions/lapack/potrf_batch.cpp | 34 ++++++++++--------- dpnp/linalg/dpnp_utils_linalg.py | 30 ++++++++++------ 5 files changed, 47 insertions(+), 34 deletions(-) diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index eea43d71ee32..f97a0dc4a433 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -105,8 +105,8 @@ PYBIND11_MODULE(_lapack_impl, m) "Call `potrf_batch` from OneMKL LAPACK library to return " "the Cholesky factorization of a batch of symmetric " "positive-definite matrix", - py::arg("sycl_queue"), py::arg("a_array"), py::arg("n"), - py::arg("stride_a"), py::arg("batch_size"), + py::arg("sycl_queue"), py::arg("a_array"), py::arg("upper_lower"), + py::arg("n"), py::arg("stride_a"), py::arg("batch_size"), py::arg("depends") = py::list()); m.def("_syevd", &lapack_ext::syevd, diff --git a/dpnp/backend/extensions/lapack/potrf.cpp b/dpnp/backend/extensions/lapack/potrf.cpp index e5b62b7e059c..29bb98027c15 100644 --- a/dpnp/backend/extensions/lapack/potrf.cpp +++ b/dpnp/backend/extensions/lapack/potrf.cpp @@ -162,10 +162,10 @@ std::pair std::to_string(a_array_shape[1]) + ")."); } - bool is_a_array_f_contig = a_array.is_f_contiguous(); - if (!is_a_array_f_contig) { + bool is_a_array_c_contig = a_array.is_c_contiguous(); + if (!is_a_array_c_contig) { throw py::value_error("The input array " - "must be F-contiguous"); + "must be C-contiguous"); } auto array_types = dpctl_td_ns::usm_ndarray_types(); diff --git a/dpnp/backend/extensions/lapack/potrf.hpp b/dpnp/backend/extensions/lapack/potrf.hpp index fe52b5c85aad..f0850b3fd98d 100644 --- a/dpnp/backend/extensions/lapack/potrf.hpp +++ b/dpnp/backend/extensions/lapack/potrf.hpp @@ -47,9 +47,10 @@ extern std::pair extern std::pair potrf_batch(sycl::queue exec_q, dpctl::tensor::usm_ndarray a_array, - std::int64_t n, - std::int64_t stride_a, - std::int64_t batch_size, + const std::int8_t upper_lower, + const std::int64_t n, + const std::int64_t stride_a, + const std::int64_t batch_size, const std::vector &depends = {}); extern void init_potrf_dispatch_vector(void); diff --git a/dpnp/backend/extensions/lapack/potrf_batch.cpp b/dpnp/backend/extensions/lapack/potrf_batch.cpp index 4f50cc4fdd50..54fb6580147c 100644 --- a/dpnp/backend/extensions/lapack/potrf_batch.cpp +++ b/dpnp/backend/extensions/lapack/potrf_batch.cpp @@ -49,12 +49,12 @@ namespace type_utils = dpctl::tensor::type_utils; typedef sycl::event (*potrf_batch_impl_fn_ptr_t)( sycl::queue, - oneapi::mkl::uplo, - std::int64_t, + const oneapi::mkl::uplo, + const std::int64_t, char *, - std::int64_t, - std::int64_t, - std::int64_t, + const std::int64_t, + const std::int64_t, + const std::int64_t, std::vector &, const std::vector &); @@ -63,12 +63,12 @@ static potrf_batch_impl_fn_ptr_t template static sycl::event potrf_batch_impl(sycl::queue exec_q, - oneapi::mkl::uplo upper_lower, - std::int64_t n, + const oneapi::mkl::uplo upper_lower, + const std::int64_t n, char *in_a, - std::int64_t lda, - std::int64_t stride_a, - std::int64_t batch_size, + const std::int64_t lda, + const std::int64_t stride_a, + const std::int64_t batch_size, std::vector &host_task_events, const std::vector &depends) { @@ -171,9 +171,10 @@ static sycl::event potrf_batch_impl(sycl::queue exec_q, std::pair potrf_batch(sycl::queue q, dpctl::tensor::usm_ndarray a_array, - std::int64_t n, - std::int64_t stride_a, - std::int64_t batch_size, + const std::int8_t upper_lower, + const std::int64_t n, + const std::int64_t stride_a, + const std::int64_t batch_size, const std::vector &depends) { const int a_array_nd = a_array.get_ndim(); @@ -214,12 +215,13 @@ std::pair char *a_array_data = a_array.get_data(); const std::int64_t lda = std::max(1UL, n); - oneapi::mkl::uplo upper_lower = oneapi::mkl::uplo::upper; + const oneapi::mkl::uplo uplo_val = + static_cast(upper_lower); std::vector host_task_events; sycl::event potrf_batch_ev = - potrf_batch_fn(q, upper_lower, n, a_array_data, lda, stride_a, - batch_size, host_task_events, depends); + potrf_batch_fn(q, uplo_val, n, a_array_data, lda, stride_a, batch_size, + host_task_events, depends); sycl::event args_ev = dpctl::utils::keep_args_alive(q, {a_array}, host_task_events); diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 2ea63fae9979..11bccfb4740d 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -43,9 +43,6 @@ _jobz = {"N": 0, "V": 1} _upper_lower = {"U": 0, "L": 1} -# Map booleans to MKL`s `uplo`` values: -# True -> 0 (upper), False -> 1 (lower). -_upper_lower_bool = {False: 1, True: 0} _real_types_map = { "float32": "float32", # single : single @@ -444,9 +441,9 @@ def _lu_factor(a, res_type): return (a_h, ipiv_h, dev_info_array) -def dpnp_cholesky_batch(a, res_type): +def dpnp_cholesky_batch(a, upper_lower, res_type): """ - dpnp_cholesky_batch(a, res_type) + dpnp_cholesky_batch(a, upper_lower, res_type) Return the batched Cholesky decomposition of `a` array. @@ -480,6 +477,7 @@ def dpnp_cholesky_batch(a, res_type): ht_lapack_ev, _ = li._potrf_batch( a_sycl_queue, a_h.get_array(), + upper_lower, n, a_stride, batch_size, @@ -489,7 +487,12 @@ def dpnp_cholesky_batch(a, res_type): ht_lapack_ev.wait() a_ht_copy_ev.wait() - a_h = dpnp.tril(a_h.reshape(orig_shape)) + # Get upper or lower-triangular matrix part as per `upper_lower` value + # upper_lower is 0 (lower) or 1 (upper) + if upper_lower: + a_h = dpnp.triu(a_h.reshape(orig_shape)) + else: + a_h = dpnp.tril(a_h.reshape(orig_shape)) return a_h @@ -517,16 +520,23 @@ def dpnp_cholesky(a, upper): sycl_queue=a_sycl_queue, ) - # Set `uplo` value for MKL functions based on boolean input - upper_lower = _upper_lower_bool[upper] + # Set `uplo` value for `potrf` and `potrf_batch` function based on the boolean input `upper`. + # In oneapi::mkl, `uplo` value of 1 is equivalent to oneapi::mkl::uplo::lower + # and `uplo` value of 0 is equivalent to oneapi::mkl::uplo::upper. + # However, we adjust this logic based on the array's memory layout. + # Note: lower for row-major (which is used here) is upper for column-major layout. + # Reference: comment from tbmkl/tests/lapack/unit/dpcpp/potrf_usm/potrf_usm.cpp + # This means that if `upper` is False (lower triangular), + # we actually use oneapi::mkl::uplo::upper (0) for the row-major layout, and vice versa. + upper_lower = int(upper) if a.ndim > 2: - return dpnp_cholesky_batch(a, res_type) + return dpnp_cholesky_batch(a, upper_lower, res_type) a_usm_arr = dpnp.get_usm_ndarray(a) # `a` must be copied because potrf destroys the input matrix - a_h = dpnp.empty_like(a, order="F", dtype=res_type, usm_type=a_usm_type) + a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=a_usm_type) # use DPCTL tensor function to fill the сopy of the input array # from the input array From 48cc6576cecbe152d83b4d46975dc99bd17298eb Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 18 Jan 2024 16:28:43 +0100 Subject: [PATCH 24/26] Add tests for upper parameter of dpnp.linalg.cholesky --- tests/test_linalg.py | 74 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 201c4c3f83da..5f38421c6ec2 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -4,6 +4,7 @@ from numpy.testing import assert_allclose, assert_array_equal, assert_raises import dpnp as inp +from tests.third_party.cupy import testing from .helper import ( assert_dtype_allclose, @@ -62,13 +63,84 @@ class TestCholesky: ], ) @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) - def test_cholesky_3d_4d(self, array, dtype): + def test_cholesky(self, array, dtype): a = numpy.array(array, dtype=dtype) ia = inp.array(a) result = inp.linalg.cholesky(ia) expected = numpy.linalg.cholesky(a) assert_dtype_allclose(result, expected) + @pytest.mark.parametrize( + "array", + [ + [[1, 2], [2, 5]], + [[[5, 2], [2, 6]], [[7, 3], [3, 8]], [[3, 1], [1, 4]]], + [ + [[[5, 2], [2, 5]], [[6, 3], [3, 6]]], + [[[7, 2], [2, 7]], [[8, 3], [3, 8]]], + ], + ], + ids=[ + "2D_array", + "3D_array", + "4D_array", + ], + ) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_cholesky_upper(self, array, dtype): + ia = inp.array(array, dtype=dtype) + result = inp.linalg.cholesky(ia, upper=True) + + if ia.ndim > 2: + n = ia.shape[-1] + ia_reshaped = ia.reshape(-1, n, n) + res_reshaped = result.reshape(-1, n, n) + batch_size = ia_reshaped.shape[0] + for idx in range(batch_size): + # Reconstruct the matrix using the Cholesky decomposition result + if inp.issubdtype(dtype, inp.complexfloating): + reconstructed = ( + res_reshaped[idx].T.conj() @ res_reshaped[idx] + ) + else: + reconstructed = res_reshaped[idx].T @ res_reshaped[idx] + assert_dtype_allclose( + reconstructed, ia_reshaped[idx], check_type=False + ) + else: + # Reconstruct the matrix using the Cholesky decomposition result + if inp.issubdtype(dtype, inp.complexfloating): + reconstructed = result.T.conj() @ result + else: + reconstructed = result.T @ result + assert_dtype_allclose(reconstructed, ia, check_type=False) + + # upper parameter support will be added in numpy 2.0 version + @testing.with_requires("numpy>=2.0") + @pytest.mark.parametrize( + "array", + [ + [[1, 2], [2, 5]], + [[[5, 2], [2, 6]], [[7, 3], [3, 8]], [[3, 1], [1, 4]]], + [ + [[[5, 2], [2, 5]], [[6, 3], [3, 6]]], + [[[7, 2], [2, 7]], [[8, 3], [3, 8]]], + ], + ], + ids=[ + "2D_array", + "3D_array", + "4D_array", + ], + ) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_cholesky_upper_numpy(self, array, dtype): + a = numpy.array(array, dtype=dtype) + ia = inp.array(a) + result = inp.linalg.cholesky(ia, upper=True) + expected = numpy.linalg.cholesky(a, upper=True) + assert_dtype_allclose(result, expected) + def test_cholesky_strides(self): a_np = numpy.array( [ From ca423f8418658056de2130c8292a7896a2855471 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 19 Jan 2024 13:18:32 +0100 Subject: [PATCH 25/26] Address remarks --- dpnp/backend/extensions/lapack/potrf.cpp | 5 ++--- dpnp/backend/extensions/lapack/potrf_batch.cpp | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/dpnp/backend/extensions/lapack/potrf.cpp b/dpnp/backend/extensions/lapack/potrf.cpp index 29bb98027c15..610a629a9eb4 100644 --- a/dpnp/backend/extensions/lapack/potrf.cpp +++ b/dpnp/backend/extensions/lapack/potrf.cpp @@ -71,8 +71,7 @@ static sycl::event potrf_impl(sycl::queue exec_q, T *a = reinterpret_cast(in_a); const std::int64_t scratchpad_size = - oneapi::mkl::lapack::potrf_scratchpad_size(exec_q, upper_lower, n, - lda); + mkl_lapack::potrf_scratchpad_size(exec_q, upper_lower, n, lda); T *scratchpad = nullptr; std::stringstream error_msg; @@ -83,7 +82,7 @@ static sycl::event potrf_impl(sycl::queue exec_q, try { scratchpad = sycl::malloc_device(scratchpad_size, exec_q); - potrf_event = oneapi::mkl::lapack::potrf( + potrf_event = mkl_lapack::potrf( exec_q, upper_lower, // An enumeration value of type oneapi::mkl::uplo: // oneapi::mkl::uplo::upper for the upper triangular diff --git a/dpnp/backend/extensions/lapack/potrf_batch.cpp b/dpnp/backend/extensions/lapack/potrf_batch.cpp index 54fb6580147c..1a36bae4efd5 100644 --- a/dpnp/backend/extensions/lapack/potrf_batch.cpp +++ b/dpnp/backend/extensions/lapack/potrf_batch.cpp @@ -77,8 +77,8 @@ static sycl::event potrf_batch_impl(sycl::queue exec_q, T *a = reinterpret_cast(in_a); const std::int64_t scratchpad_size = - oneapi::mkl::lapack::potrf_batch_scratchpad_size( - exec_q, upper_lower, n, lda, stride_a, batch_size); + mkl_lapack::potrf_batch_scratchpad_size(exec_q, upper_lower, n, lda, + stride_a, batch_size); T *scratchpad = nullptr; std::stringstream error_msg; @@ -89,7 +89,7 @@ static sycl::event potrf_batch_impl(sycl::queue exec_q, try { scratchpad = sycl::malloc_device(scratchpad_size, exec_q); - potrf_batch_event = oneapi::mkl::lapack::potrf_batch( + potrf_batch_event = mkl_lapack::potrf_batch( exec_q, upper_lower, // An enumeration value of type oneapi::mkl::uplo: // oneapi::mkl::uplo::upper for the upper triangular From cf701c91bf1cc61077b1745206e62bafa1e0bd3c Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 19 Jan 2024 15:50:06 +0100 Subject: [PATCH 26/26] Fix validation check --- dpnp/linalg/dpnp_utils_linalg.py | 2 +- tests/test_random_state.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 11bccfb4740d..40159ac02bce 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -521,7 +521,7 @@ def dpnp_cholesky(a, upper): ) # Set `uplo` value for `potrf` and `potrf_batch` function based on the boolean input `upper`. - # In oneapi::mkl, `uplo` value of 1 is equivalent to oneapi::mkl::uplo::lower + # In oneMKL, `uplo` value of 1 is equivalent to oneapi::mkl::uplo::lower # and `uplo` value of 0 is equivalent to oneapi::mkl::uplo::upper. # However, we adjust this logic based on the array's memory layout. # Note: lower for row-major (which is used here) is upper for column-major layout. diff --git a/tests/test_random_state.py b/tests/test_random_state.py index 4771eadc42ea..70940501d2ee 100644 --- a/tests/test_random_state.py +++ b/tests/test_random_state.py @@ -491,7 +491,7 @@ def test_rng_zero_and_extremes(self): sycl_device = dpctl.SyclQueue().sycl_device if sycl_device.has_aspect_gpu and not sycl_device.has_aspect_fp64: - # TODO: discuss with opneMKL + # TODO: discuss with oneMKL pytest.skip( f"Due to some reason, oneMKL wrongly returns high value instead of low" )