Add dpnp.linalg.solve() function #1598
Merged
73 commits
2613795
Correct return of object type at zero copy
vlad-perevezentsev efdcf2a
Add tests for gh-1570
vlad-perevezentsev 6500a6c
Add dpnp.linalg.solve() function
vlad-perevezentsev 51663d0
Check validity of input array shapes
vlad-perevezentsev 00b7041
Add logic for a.ndim > 2
vlad-perevezentsev e4e15ad
Raise value_error if coeff_matrix_nd != 2 in gesv
vlad-perevezentsev e762c66
Add cupy tests for dpnp.linalg.solve()
vlad-perevezentsev a0f76d5
Add LinAlgError exception and extend error handling for mkl::lapack::…
vlad-perevezentsev 89f47c7
Update test_solve
vlad-perevezentsev 159c460
Add test_solve to test scope
vlad-perevezentsev 2ad3aa8
Fix getting nrhs to avoid CPU falling tests
vlad-perevezentsev ac02cf2
Add test_solve to test_sycl_queue
vlad-perevezentsev 035d983
Add more tests for solve()
vlad-perevezentsev f80ba4b
Register a LinAlgError in dpnp.linalg submodule
vlad-perevezentsev 9a3db74
Raise dpnp.linalg.LinAlgError in solve()
vlad-perevezentsev a8b4fec
Small changes to the docstrings
vlad-perevezentsev 40e74ae
Merge master into impl_solve
vlad-perevezentsev 3f88dc5
Simplify ThresholdType determination
vlad-perevezentsev 22e8734
Small changes to the docstrings
vlad-perevezentsev 79c60df
Merge master into impl_solve_1
vlad-perevezentsev 87cda0b
Remove if op_count due to unreachable
vlad-perevezentsev 76e035d
Improve test coverage
vlad-perevezentsev 1bfc81a
Impl dtype dispatching with linalg_common_type for dpnp.linalg.solve
vlad-perevezentsev c7b284b
Add a new test_solve_diff_type
vlad-perevezentsev 353b756
Merge master into impl_solve_1
vlad-perevezentsev e5c7626
Merge master into impl_solve_1
vlad-perevezentsev e99d37c
Add a common_helpers.hpp file
vlad-perevezentsev ec1e966
Use bool flag for sycl exception
vlad-perevezentsev 6885aa3
Refactor memory management for ipiv in gesv_impl
vlad-perevezentsev 56c920f
Rename linalg_common_type to _common_type and change the number of ty…
vlad-perevezentsev 07b1418
Address the remarks
vlad-perevezentsev 711a62f
Merge master into impl_solve_1
vlad-perevezentsev 0336c00
Remove the use of prod to get 3d array and rename op_count to batch_size
vlad-perevezentsev 31f6f10
gesv returns pair of events and uses dpctl.utils.keep_args_alive
vlad-perevezentsev d3717a6
Return the use prod to get 3d arrays
vlad-perevezentsev 8e19740
Merge master into impl_solve_1
vlad-perevezentsev 515df4e
Add test_solve_singular_matrix in TestSolve
vlad-perevezentsev f992b0b
Adress the remarks
vlad-perevezentsev cd21c7f
Add res_usm_type variavble and new tests in test_usm_type for dpnp.li…
vlad-perevezentsev 5781f0c
Add skipif for test_solve_singular_matrix on cpu
vlad-perevezentsev 0f87471
Merge master into impl_solve_1
vlad-perevezentsev b4eb1ad
Merge branch 'master' into impl_solve_1
antonwolfy cec8154
Modify _common_inexact_type and add a description for it
vlad-perevezentsev c3e5a0f
A small update of the desctiption of dpnp.linalg.solve() func
vlad-perevezentsev eb2dd4c
Use device param for default_float_type in _common_type
vlad-perevezentsev 4780597
Simplify getting 3d array in dpnp_solve
vlad-perevezentsev f1b6a81
Remove unnecessary copying to F order after invoking gesv
vlad-perevezentsev b8f4cb9
Use get_usm_allocations instead of get_execution_queue
vlad-perevezentsev 00ebef1
Move copying just after the memory allocation
vlad-perevezentsev 22d4d6f
Add additional checks to gesv implementation
vlad-perevezentsev 04f8f41
Add validation functions for array types and dimensions for linalg funcs
vlad-perevezentsev 366be7a
Update test_solve_diff_type in test_linalg.py
vlad-perevezentsev 08ac7fe
Address the remarks
vlad-perevezentsev eb6a840
Small update
vlad-perevezentsev 61a6073
qwe
vlad-perevezentsev 6a25e69
Merge origin/master into impl_solve_1
vlad-perevezentsev 4ba0c7f
Merge master into impl_solve_1
vlad-perevezentsev eba811a
Rename assert funcs and make them external in dpnp_utils_linalg
vlad-perevezentsev 3a9d459
Use assert_dtype_allclose for test_solve in test_sycl_queue
vlad-perevezentsev 0de6968
Remove an unnecessary file
vlad-perevezentsev df72c77
Fix validation for CI
vlad-perevezentsev 4991a81
Remove eqec_q check that will never happen
vlad-perevezentsev 965b89a
Merge master into impl_solve_1
vlad-perevezentsev 9b5a5a5
Set usm_type for out_v
vlad-perevezentsev 3c7ad07
Update test_solve_singular_empty
vlad-perevezentsev e76b278
Merge master into impl_solve_1
vlad-perevezentsev 3001872
Skip test_solve_singular_empty
vlad-perevezentsev 5c3693b
Merge master into impl_solve_1
vlad-perevezentsev 37f5400
Merge master into impl_solve_1
vlad-perevezentsev 5392a45
Fix validation fell
vlad-perevezentsev 6bea640
Merge master into impl_solve_1
vlad-perevezentsev c105570
A small update
vlad-perevezentsev 3165397
Merge master into impl_solve_1
vlad-perevezentsev
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
//***************************************************************************** | ||
// Copyright (c) 2023, Intel Corporation | ||
// All rights reserved. | ||
// | ||
// Redistribution and use in source and binary forms, with or without | ||
// modification, are permitted provided that the following conditions are met: | ||
// - Redistributions of source code must retain the above copyright notice, | ||
// this list of conditions and the following disclaimer. | ||
// - Redistributions in binary form must reproduce the above copyright notice, | ||
// this list of conditions and the following disclaimer in the documentation | ||
// and/or other materials provided with the distribution. | ||
// | ||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF | ||
// THE POSSIBILITY OF SUCH DAMAGE. | ||
//***************************************************************************** | ||
|
||
#pragma once | ||
#include <complex> | ||
#include <cstring> | ||
#include <stdexcept> | ||
|
||
namespace dpnp | ||
{ | ||
namespace backend | ||
{ | ||
namespace ext | ||
{ | ||
namespace lapack | ||
{ | ||
namespace helper | ||
{ | ||
template <typename T> | ||
struct value_type_of | ||
{ | ||
using type = T; | ||
}; | ||
|
||
template <typename T> | ||
struct value_type_of<std::complex<T>> | ||
{ | ||
using type = T; | ||
}; | ||
} // namespace helper | ||
} // namespace lapack | ||
} // namespace ext | ||
} // namespace backend | ||
} // namespace dpnp |
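For reviewers unfamiliar with the trait, the following is a minimal standalone sketch of how `value_type_of` maps a real or complex element type to its underlying real type, which is what `gesv_impl` relies on when it picks `ThresholdType`. The `sketch` namespace, the `value_type_of_t` alias, and the `singularity_threshold` helper are hypothetical names introduced only for this illustration; they are not part of the PR.

```cpp
#include <complex>
#include <limits>
#include <type_traits>

// Hypothetical namespace; mirrors the trait from common_helpers.hpp.
namespace sketch
{
// Primary template: a real type is its own value type.
template <typename T>
struct value_type_of
{
    using type = T;
};

// Specialization: std::complex<T> unwraps to the real type T.
template <typename T>
struct value_type_of<std::complex<T>>
{
    using type = T;
};

// Hypothetical convenience alias (not in the PR).
template <typename T>
using value_type_of_t = typename value_type_of<T>::type;

// Hypothetical helper: the tolerance used for the singularity check,
// expressed in the real type that underlies T (100 * machine epsilon).
template <typename T>
constexpr value_type_of_t<T> singularity_threshold()
{
    return std::numeric_limits<value_type_of_t<T>>::epsilon() * 100;
}
} // namespace sketch

// Real types map to themselves, complex types map to their component type.
static_assert(std::is_same_v<sketch::value_type_of_t<float>, float>);
static_assert(std::is_same_v<sketch::value_type_of_t<double>, double>);
static_assert(
    std::is_same_v<sketch::value_type_of_t<std::complex<float>>, float>);
static_assert(
    std::is_same_v<sketch::value_type_of_t<std::complex<double>>, double>);
```

The point of the indirection is that `std::numeric_limits<std::complex<T>>` is not specialized, so taking `epsilon()` of a complex type directly would silently yield a value-initialized result rather than a usable tolerance.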
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,297 @@ | ||
//***************************************************************************** | ||
// Copyright (c) 2023, Intel Corporation | ||
// All rights reserved. | ||
// | ||
// Redistribution and use in source and binary forms, with or without | ||
// modification, are permitted provided that the following conditions are met: | ||
// - Redistributions of source code must retain the above copyright notice, | ||
// this list of conditions and the following disclaimer. | ||
// - Redistributions in binary form must reproduce the above copyright notice, | ||
// this list of conditions and the following disclaimer in the documentation | ||
// and/or other materials provided with the distribution. | ||
// | ||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF | ||
// THE POSSIBILITY OF SUCH DAMAGE. | ||
//***************************************************************************** | ||
|
||
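#include <limits> | ||
#include <sstream> | ||
#include <stdexcept> | ||
#include <vector> | ||
|
||
#include <pybind11/pybind11.h> | ||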
|
||
// dpctl tensor headers | ||
#include "utils/memory_overlap.hpp" | ||
#include "utils/type_utils.hpp" | ||
|
||
#include "common_helpers.hpp" | ||
#include "gesv.hpp" | ||
#include "linalg_exceptions.hpp" | ||
#include "types_matrix.hpp" | ||
|
||
#include "dpnp_utils.hpp" | ||
|
||
namespace dpnp | ||
{ | ||
namespace backend | ||
{ | ||
namespace ext | ||
{ | ||
namespace lapack | ||
{ | ||
namespace mkl_lapack = oneapi::mkl::lapack; | ||
namespace py = pybind11; | ||
namespace type_utils = dpctl::tensor::type_utils; | ||
|
||
typedef sycl::event (*gesv_impl_fn_ptr_t)(sycl::queue, | ||
const std::int64_t, | ||
const std::int64_t, | ||
char *, | ||
std::int64_t, | ||
char *, | ||
std::int64_t, | ||
std::vector<sycl::event> &, | ||
const std::vector<sycl::event> &); | ||
|
||
static gesv_impl_fn_ptr_t gesv_dispatch_vector[dpctl_td_ns::num_types]; | ||
|
||
template <typename T> | ||
static sycl::event gesv_impl(sycl::queue exec_q, | ||
const std::int64_t n, | ||
const std::int64_t nrhs, | ||
char *in_a, | ||
std::int64_t lda, | ||
char *in_b, | ||
std::int64_t ldb, | ||
std::vector<sycl::event> &host_task_events, | ||
const std::vector<sycl::event> &depends) | ||
{ | ||
type_utils::validate_type_for_device<T>(exec_q); | ||
|
||
T *a = reinterpret_cast<T *>(in_a); | ||
T *b = reinterpret_cast<T *>(in_b); | ||
|
||
const std::int64_t scratchpad_size = | ||
mkl_lapack::gesv_scratchpad_size<T>(exec_q, n, nrhs, lda, ldb); | ||
T *scratchpad = nullptr; | ||
|
||
std::int64_t *ipiv = nullptr; | ||
|
||
std::stringstream error_msg; | ||
std::int64_t info = 0; | ||
bool sycl_exception_caught = false; | ||
|
||
sycl::event gesv_event; | ||
try { | ||
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q); | ||
ipiv = sycl::malloc_device<std::int64_t>(n, exec_q); | ||
|
||
gesv_event = mkl_lapack::gesv( | ||
exec_q, | ||
n, // The order of the matrix A (0 ≤ n). | ||
nrhs, // The number of right-hand sides B (0 ≤ nrhs). | ||
a, // Pointer to the square coefficient matrix A (n x n). | ||
lda, // The leading dimension of a, must be at least max(1, n). | ||
ipiv, // The pivot indices that define the permutation matrix P; | ||
// row i of the matrix was interchanged with row ipiv(i), | ||
// must be at least max(1, n). | ||
b, // Pointer to the right hand side matrix B (n x nrhs). | ||
ldb, // The leading dimension of b, must be at least max(1, n). | ||
scratchpad, // Pointer to scratchpad memory to be used by MKL | ||
// routine for storing intermediate results. | ||
scratchpad_size, depends); | ||
} catch (mkl_lapack::exception const &e) { | ||
info = e.info(); | ||
|
||
if (info < 0) { | ||
error_msg << "Parameter number " << -info | ||
<< " had an illegal value."; | ||
} | ||
else if (info == scratchpad_size && e.detail() != 0) { | ||
// Checked before the generic `info > 0` case, which would otherwise | ||
// shadow this branch for any positive scratchpad size. | ||
error_msg | ||
<< "Insufficient scratchpad size. Required size is at least " | ||
<< e.detail(); | ||
} | ||
else if (info > 0) { | ||
// U(info, info) is exactly zero or close to it: copy that diagonal | ||
// element of the column-major factor back to the host to check it. | ||
T host_U; | ||
exec_q.memcpy(&host_U, &a[(info - 1) * lda + info - 1], sizeof(T)) | ||
.wait(); | ||
|
||
using ThresholdType = typename helper::value_type_of<T>::type; | ||
|
||
const auto threshold = | ||
std::numeric_limits<ThresholdType>::epsilon() * 100; | ||
if (std::abs(host_U) < threshold) { | ||
// Release both device allocations before throwing. | ||
sycl::free(scratchpad, exec_q); | ||
sycl::free(ipiv, exec_q); | ||
throw LinAlgError("The input coefficient matrix is singular."); | ||
} | ||
else { | ||
error_msg << "Unexpected MKL exception caught during gesv() " | ||
"call:\nreason: " | ||
<< e.what() << "\ninfo: " << e.info(); | ||
} | ||
} | ||
else { | ||
error_msg << "Unexpected MKL exception caught during gesv() " | ||
"call:\nreason: " | ||
<< e.what() << "\ninfo: " << e.info(); | ||
} | ||
} catch (sycl::exception const &e) { | ||
error_msg << "Unexpected SYCL exception caught during gesv() call:\n" | ||
<< e.what(); | ||
sycl_exception_caught = true; | ||
} | ||
|
||
if (info != 0 || sycl_exception_caught) // an unexpected error occurs | ||
{ | ||
if (scratchpad != nullptr) { | ||
sycl::free(scratchpad, exec_q); | ||
} | ||
if (ipiv != nullptr) { | ||
sycl::free(ipiv, exec_q); | ||
} | ||
throw std::runtime_error(error_msg.str()); | ||
} | ||
|
||
sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { | ||
cgh.depends_on(gesv_event); | ||
auto ctx = exec_q.get_context(); | ||
cgh.host_task([ctx, scratchpad, ipiv]() { | ||
sycl::free(scratchpad, ctx); | ||
sycl::free(ipiv, ctx); | ||
}); | ||
}); | ||
host_task_events.push_back(clean_up_event); | ||
|
||
return gesv_event; | ||
} | ||
|
||
std::pair<sycl::event, sycl::event> | ||
gesv(sycl::queue exec_q, | ||
dpctl::tensor::usm_ndarray coeff_matrix, | ||
dpctl::tensor::usm_ndarray dependent_vals, | ||
const std::vector<sycl::event> &depends) | ||
{ | ||
const int coeff_matrix_nd = coeff_matrix.get_ndim(); | ||
const int dependent_vals_nd = dependent_vals.get_ndim(); | ||
|
||
if (coeff_matrix_nd != 2) { | ||
throw py::value_error("The coefficient matrix has ndim=" + | ||
std::to_string(coeff_matrix_nd) + | ||
", but a 2-dimensional array is expected."); | ||
} | ||
|
||
if (dependent_vals_nd > 2) { | ||
throw py::value_error( | ||
"The dependent values array has ndim=" + | ||
std::to_string(dependent_vals_nd) + | ||
", but a 1-dimensional or a 2-dimensional array is expected."); | ||
} | ||
|
||
const py::ssize_t *coeff_matrix_shape = coeff_matrix.get_shape_raw(); | ||
const py::ssize_t *dependent_vals_shape = dependent_vals.get_shape_raw(); | ||
|
||
if (coeff_matrix_shape[0] != coeff_matrix_shape[1]) { | ||
throw py::value_error("The coefficient matrix must be square," | ||
" but got a shape of (" + | ||
std::to_string(coeff_matrix_shape[0]) + ", " + | ||
std::to_string(coeff_matrix_shape[1]) + ")."); | ||
} | ||
|
||
// check compatibility of execution queue and allocation queue | ||
if (!dpctl::utils::queues_are_compatible(exec_q, | ||
{coeff_matrix, dependent_vals})) | ||
{ | ||
throw py::value_error( | ||
"Execution queue is not compatible with allocation queues"); | ||
} | ||
|
||
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); | ||
if (overlap(coeff_matrix, dependent_vals)) { | ||
throw py::value_error( | ||
"The arrays of coefficients and dependent variables " | ||
"are overlapping segments of memory"); | ||
} | ||
|
||
bool is_coeff_matrix_f_contig = coeff_matrix.is_f_contiguous(); | ||
|
||
if (!is_coeff_matrix_f_contig) { | ||
throw py::value_error("The coefficient matrix " | ||
"must be F-contiguous"); | ||
} | ||
|
||
bool is_dependent_vals_f_contig = dependent_vals.is_f_contiguous(); | ||
if (!is_dependent_vals_f_contig) { | ||
throw py::value_error("The array of dependent variables " | ||
"must be F-contiguous"); | ||
} | ||
|
||
auto array_types = dpctl_td_ns::usm_ndarray_types(); | ||
int coeff_matrix_type_id = | ||
array_types.typenum_to_lookup_id(coeff_matrix.get_typenum()); | ||
int dependent_vals_type_id = | ||
array_types.typenum_to_lookup_id(dependent_vals.get_typenum()); | ||
|
||
if (coeff_matrix_type_id != dependent_vals_type_id) { | ||
throw py::value_error("The types of the coefficient matrix and " | ||
"dependent variables are mismatched"); | ||
} | ||
|
||
gesv_impl_fn_ptr_t gesv_fn = gesv_dispatch_vector[coeff_matrix_type_id]; | ||
if (gesv_fn == nullptr) { | ||
throw py::value_error( | ||
"No gesv implementation defined for the provided type " | ||
"of the coefficient matrix."); | ||
} | ||
|
||
char *coeff_matrix_data = coeff_matrix.get_data(); | ||
char *dependent_vals_data = dependent_vals.get_data(); | ||
|
||
const std::int64_t n = coeff_matrix_shape[0]; | ||
const std::int64_t m = dependent_vals_shape[0]; | ||
const std::int64_t nrhs = | ||
(dependent_vals_nd > 1) ? dependent_vals_shape[1] : 1; | ||
|
||
const std::int64_t lda = std::max<size_t>(1UL, n); | ||
const std::int64_t ldb = std::max<size_t>(1UL, m); | ||
|
||
std::vector<sycl::event> host_task_events; | ||
|
||
sycl::event gesv_ev = | ||
gesv_fn(exec_q, n, nrhs, coeff_matrix_data, lda, dependent_vals_data, | ||
ldb, host_task_events, depends); | ||
|
||
sycl::event args_ev = dpctl::utils::keep_args_alive( | ||
exec_q, {coeff_matrix, dependent_vals}, host_task_events); | ||
|
||
return std::make_pair(args_ev, gesv_ev); | ||
} | ||
|
||
template <typename fnT, typename T> | ||
struct GesvContigFactory | ||
{ | ||
fnT get() | ||
{ | ||
if constexpr (types::GesvTypePairSupportFactory<T>::is_defined) { | ||
return gesv_impl<T>; | ||
} | ||
else { | ||
return nullptr; | ||
} | ||
} | ||
}; | ||
|
||
void init_gesv_dispatch_vector(void) | ||
{ | ||
dpctl_td_ns::DispatchVectorBuilder<gesv_impl_fn_ptr_t, GesvContigFactory, | ||
dpctl_td_ns::num_types> | ||
contig; | ||
contig.populate_dispatch_vector(gesv_dispatch_vector); | ||
} | ||
} // namespace lapack | ||
} // namespace ext | ||
} // namespace backend | ||
} // namespace dpnp |
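The singularity check in `gesv_impl` reads one diagonal element of the factored matrix using column-major indexing, `a[(info - 1) * lda + info - 1]`, and compares its magnitude with `100 * epsilon`. The host-only sketch below reproduces just that arithmetic on a `std::vector`, without SYCL or oneMKL, so the indexing can be checked in isolation. The names `sketch`, `real_of`, and `pivot_is_negligible` are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical namespace; host-only stand-in for the check in gesv_impl.
namespace sketch
{
// Maps T or std::complex<T> to the underlying real type.
template <typename T>
struct real_of
{
    using type = T;
};

template <typename T>
struct real_of<std::complex<T>>
{
    using type = T;
};

// Returns true when the diagonal element U(info, info) (1-based `info`) of a
// column-major matrix with leading dimension `lda` has a magnitude below
// 100 * machine epsilon of the underlying real type.
template <typename T>
bool pivot_is_negligible(const std::vector<T> &a,
                         std::int64_t lda,
                         std::int64_t info)
{
    using Real = typename real_of<T>::type;

    // Column-major layout: element (row, col) lives at col * lda + row.
    const std::int64_t k = info - 1;
    const T u_kk = a[static_cast<std::size_t>(k * lda + k)];

    const Real threshold = std::numeric_limits<Real>::epsilon() * 100;
    return std::abs(u_kk) < threshold;
}

// Small self-check on a 2 x 2 upper-triangular factor [[2, 1], [0, 0]].
inline void demo()
{
    // Column-major storage: first column {2, 0}, second column {1, 0}.
    const std::vector<double> u = {2.0, 0.0, 1.0, 0.0};
    assert(!pivot_is_negligible(u, 2, 1)); // U(1, 1) = 2
    assert(pivot_is_negligible(u, 2, 2));  // U(2, 2) = 0
}
} // namespace sketch
```

In the actual extension the element lives in device memory, so it has to be copied to the host first (the `exec_q.memcpy(...).wait()` call) before a comparison like this one can run.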