Implement of dpnp.linalg.slogdet() #1607
Merged
66 commits
308b8cf Add a new impl of dpnp.linalg._lu_factor (vlad-perevezentsev)
1119341 Get dev_info_array after calling getrf (vlad-perevezentsev)
71c2e96 Add an extra dev_info return to _lu_factor (vlad-perevezentsev)
b24c8c5 qwe (vlad-perevezentsev)
500df36 Add a logic for a.ndim > 2 in _lu_factor (vlad-perevezentsev)
b35e282 Add an implementation of dpnp.linalg.slogdet (vlad-perevezentsev)
301fc2a Add a new test_norms.py file in cupy tests (vlad-perevezentsev)
9aba014 Expand test scope in public CI (vlad-perevezentsev)
19d909c Merge master into impl_lu_factor (vlad-perevezentsev)
2419eba Merge master into impl_lu_factor (vlad-perevezentsev)
a8789e4 A small update _lu_factor func (vlad-perevezentsev)
59642f6 Remove w/a for dpnp.count_nonzero in slogdet (vlad-perevezentsev)
db45555 getrf returns pair of events and uses dpctl.utils.keep_args_alive (vlad-perevezentsev)
0350d86 Update dpnp.linalg.det using slogdet (vlad-perevezentsev)
6c20deb Add new cupy tests for dpnp.linalg.det (vlad-perevezentsev)
b17802d Merge master into impl_lu_factor (vlad-perevezentsev)
3359c05 Add ipiv_vecs and dev_info_vecs in _lu_factor for the batch case (vlad-perevezentsev)
1a91385 Skip test_det on CPU due to bug in MKL (vlad-perevezentsev)
b860790 Small update of cupy tests in test_norms.py (vlad-perevezentsev)
37d476b Add support of complex dtype for dpnp.diagonal and update test_diagonal (vlad-perevezentsev)
2e8b4fe lu_factor func returns the result of LU decomposition as c-contiguous… (vlad-perevezentsev)
3849d92 Add getrf_batch MKL extension (vlad-perevezentsev)
cd282a3 Update docstring for slogdet (vlad-perevezentsev)
771605f Add more tests (vlad-perevezentsev)
35de575 Merge master into impl_lu_factor (vlad-perevezentsev)
cde627d Remove accidentally added file (vlad-perevezentsev)
694870b Modify sign parameter calculation (vlad-perevezentsev)
d85d00d Remove the old backend implementation of dpnp_det (vlad-perevezentsev)
c4b9992 qwe (vlad-perevezentsev)
2ad0bc4 Merge master into impl_lu_factor (vlad-perevezentsev)
78b98e7 Keep lexographical order (vlad-perevezentsev)
46a9965 Add dpnp_slogdet to dpnp_utils_linalg (vlad-perevezentsev)
da16383 Move _lu_factor above (vlad-perevezentsev)
4e0e183 A minor update (vlad-perevezentsev)
80d8188 A minor changes for _lu_factor (vlad-perevezentsev)
628dd90 Remove trash files (vlad-perevezentsev)
a8db460 Use getrf_batch only on CPU (vlad-perevezentsev)
e3cd5c4 Update tests for dpnp.linalg.slogdet (vlad-perevezentsev)
40c7a29 Merge master into impl_lu_factor (vlad-perevezentsev)
4679637 Merge master into impl_lu_factor (vlad-perevezentsev)
99f3618 Address remarks (vlad-perevezentsev)
1f9b6fa Add _real_type func (vlad-perevezentsev)
7e12063 Add test_det in test_usm_type (vlad-perevezentsev)
0e82258 Add more checks in getrf and getf_batch functions (vlad-perevezentsev)
3a6e5ce Improve error handler in getrf_impl (vlad-perevezentsev)
6700fce Improve error handler in getrf_batch_impl (vlad-perevezentsev)
644485b dev_info is allocated as zeros (vlad-perevezentsev)
5015e15 Remove skipif for singular tests (vlad-perevezentsev)
68c436e Implement _lu_factor logic with dev_info as a python list (vlad-perevezentsev)
030a083 Update getrf_rf error handler with mkl_lapack::batch_error (vlad-perevezentsev)
579b4e5 Remove passing n parameter to _getrf (vlad-perevezentsev)
acd04b7 Add a new test_slogdet_singular_matrix_3D test (vlad-perevezentsev)
0fe0bf1 Merge remote-tracking branch 'origin/master' into impl_lu_factor (vlad-perevezentsev)
f0cc4d3 Update tests for dpnp.linalg.det (vlad-perevezentsev)
8896aab Merge master into impl_lu_factor (vlad-perevezentsev)
c9b7c3b Use is_exception_caught flag in getrf and getrf_batch error handler (vlad-perevezentsev)
a3873cd Update gesv error handler (vlad-perevezentsev)
9652797 Reshape results after calling getrf_batch (vlad-perevezentsev)
cef4690 Add a new dpnp.linalg.det impl and refresh dpnp_utils_linalg (vlad-perevezentsev)
67681bf Remove Limitations from dpnp_det and dpnp_slogdet docstings (vlad-perevezentsev)
ed71f6b Address remarks (vlad-perevezentsev)
e8f5fbd Merge master into impl_lu_factor (vlad-perevezentsev)
75fa23a Remove det_dtype variable and use the abs val of diag for det (vlad-perevezentsev)
e9bdcd6 Expand cupy tests for dpnp.linalg.det() (vlad-perevezentsev)
549b5da Update TestDet and TestSlogdet (vlad-perevezentsev)
26277ba Merge master into impl_lu_factor (vlad-perevezentsev)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
//***************************************************************************** | ||
// Copyright (c) 2023, Intel Corporation | ||
// All rights reserved. | ||
// | ||
// Redistribution and use in source and binary forms, with or without | ||
// modification, are permitted provided that the following conditions are met: | ||
// - Redistributions of source code must retain the above copyright notice, | ||
// this list of conditions and the following disclaimer. | ||
// - Redistributions in binary form must reproduce the above copyright notice, | ||
// this list of conditions and the following disclaimer in the documentation | ||
// and/or other materials provided with the distribution. | ||
// | ||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF | ||
// THE POSSIBILITY OF SUCH DAMAGE. | ||
//***************************************************************************** | ||
|
||
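#include <algorithm> | ||
#include <sstream> | ||
#include <stdexcept> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include <pybind11/pybind11.h> | ||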
|
||
// dpctl tensor headers | ||
#include "utils/memory_overlap.hpp" | ||
#include "utils/type_utils.hpp" | ||
|
||
#include "getrf.hpp" | ||
#include "types_matrix.hpp" | ||
|
||
#include "dpnp_utils.hpp" | ||
|
||
namespace dpnp | ||
{ | ||
namespace backend | ||
{ | ||
namespace ext | ||
{ | ||
namespace lapack | ||
{ | ||
namespace mkl_lapack = oneapi::mkl::lapack; | ||
namespace py = pybind11; | ||
namespace type_utils = dpctl::tensor::type_utils; | ||
|
||
typedef sycl::event (*getrf_impl_fn_ptr_t)(sycl::queue, | ||
const std::int64_t, | ||
char *, | ||
std::int64_t, | ||
std::int64_t *, | ||
py::list, | ||
std::vector<sycl::event> &, | ||
const std::vector<sycl::event> &); | ||
|
||
static getrf_impl_fn_ptr_t getrf_dispatch_vector[dpctl_td_ns::num_types]; | ||
|
||
template <typename T> | ||
static sycl::event getrf_impl(sycl::queue exec_q, | ||
const std::int64_t n, | ||
char *in_a, | ||
std::int64_t lda, | ||
std::int64_t *ipiv, | ||
py::list dev_info, | ||
std::vector<sycl::event> &host_task_events, | ||
const std::vector<sycl::event> &depends) | ||
{ | ||
type_utils::validate_type_for_device<T>(exec_q); | ||
|
||
T *a = reinterpret_cast<T *>(in_a); | ||
|
||
const std::int64_t scratchpad_size = | ||
mkl_lapack::getrf_scratchpad_size<T>(exec_q, n, n, lda); | ||
T *scratchpad = nullptr; | ||
|
||
std::stringstream error_msg; | ||
std::int64_t info = 0; | ||
bool is_exception_caught = false; | ||
|
||
sycl::event getrf_event; | ||
try { | ||
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q); | ||
|
||
getrf_event = mkl_lapack::getrf( | ||
exec_q, | ||
n, // The number of rows in the square matrix A (0 ≤ n). | ||
// It must be a non-negative integer. | ||
n, // The number of columns in the square matrix A (0 ≤ n). | ||
// It must be a non-negative integer. | ||
a, // Pointer to the square matrix A (n x n). | ||
lda, // The leading dimension of matrix A. | ||
// It must be at least max(1, n). | ||
ipiv, // Pointer to the output array of pivot indices. | ||
scratchpad, // Pointer to scratchpad memory to be used by MKL | ||
// routine for storing intermediate results. | ||
scratchpad_size, depends); | ||
} catch (mkl_lapack::exception const &e) { | ||
is_exception_caught = true; | ||
info = e.info(); | ||
|
||
if (info < 0) { | ||
error_msg << "Parameter number " << -info | ||
<< " had an illegal value."; | ||
} | ||
else if (info == scratchpad_size && e.detail() != 0) { | ||
error_msg | ||
<< "Insufficient scratchpad size. Required size is at least " | ||
<< e.detail(); | ||
} | ||
else if (info > 0) { | ||
// Store the positive 'info' value in the first element of | ||
// 'dev_info'. This indicates that the factorization has been | ||
// completed, but the factor U (upper triangular matrix) is exactly | ||
// singular. The 'info' value here is the index of the first zero | ||
// element in the diagonal of U. | ||
is_exception_caught = false; | ||
dev_info[0] = info; | ||
} | ||
else { | ||
error_msg << "Unexpected MKL exception caught during getrf() " | ||
"call:\nreason: " | ||
<< e.what() << "\ninfo: " << e.info(); | ||
} | ||
} catch (sycl::exception const &e) { | ||
is_exception_caught = true; | ||
error_msg << "Unexpected SYCL exception caught during getrf() call:\n" | ||
<< e.what(); | ||
} | ||
|
||
if (is_exception_caught) // an unexpected error occurred | ||
{ | ||
if (scratchpad != nullptr) { | ||
sycl::free(scratchpad, exec_q); | ||
} | ||
|
||
throw std::runtime_error(error_msg.str()); | ||
} | ||
|
||
// Free the scratchpad from a host task once the getrf event has completed. | ||
sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { | ||
cgh.depends_on(getrf_event); | ||
auto ctx = exec_q.get_context(); | ||
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); | ||
}); | ||
host_task_events.push_back(clean_up_event); | ||
return getrf_event; | ||
} | ||
|
||
std::pair<sycl::event, sycl::event> | ||
getrf(sycl::queue exec_q, | ||
dpctl::tensor::usm_ndarray a_array, | ||
dpctl::tensor::usm_ndarray ipiv_array, | ||
py::list dev_info, | ||
const std::vector<sycl::event> &depends) | ||
{ | ||
const int a_array_nd = a_array.get_ndim(); | ||
const int ipiv_array_nd = ipiv_array.get_ndim(); | ||
|
||
if (a_array_nd != 2) { | ||
throw py::value_error( | ||
"The input array has ndim=" + std::to_string(a_array_nd) + | ||
", but a 2-dimensional array is expected."); | ||
} | ||
|
||
if (ipiv_array_nd != 1) { | ||
throw py::value_error("The array of pivot indices has ndim=" + | ||
std::to_string(ipiv_array_nd) + | ||
", but a 1-dimensional array is expected."); | ||
} | ||
|
||
// check compatibility of execution queue and allocation queue | ||
if (!dpctl::utils::queues_are_compatible(exec_q, {a_array, ipiv_array})) { | ||
throw py::value_error( | ||
"Execution queue is not compatible with allocation queues"); | ||
} | ||
|
||
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); | ||
if (overlap(a_array, ipiv_array)) { | ||
throw py::value_error("The input array and the array of pivot indices " | ||
"are overlapping segments of memory"); | ||
} | ||
|
||
bool is_a_array_c_contig = a_array.is_c_contiguous(); | ||
bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous(); | ||
if (!is_a_array_c_contig) { | ||
throw py::value_error("The input array " | ||
"must be C-contiguous"); | ||
} | ||
if (!is_ipiv_array_c_contig) { | ||
throw py::value_error("The array of pivot indices " | ||
"must be C-contiguous"); | ||
} | ||
|
||
auto array_types = dpctl_td_ns::usm_ndarray_types(); | ||
int a_array_type_id = | ||
array_types.typenum_to_lookup_id(a_array.get_typenum()); | ||
|
||
getrf_impl_fn_ptr_t getrf_fn = getrf_dispatch_vector[a_array_type_id]; | ||
if (getrf_fn == nullptr) { | ||
throw py::value_error( | ||
"No getrf implementation defined for the provided type " | ||
"of the input matrix."); | ||
} | ||
|
||
auto ipiv_types = dpctl_td_ns::usm_ndarray_types(); | ||
int ipiv_array_type_id = | ||
ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum()); | ||
|
||
if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) { | ||
throw py::value_error("The type of 'ipiv_array' must be int64."); | ||
} | ||
|
||
const std::int64_t n = a_array.get_shape_raw()[0]; | ||
|
||
char *a_array_data = a_array.get_data(); | ||
const std::int64_t lda = std::max<size_t>(1UL, n); | ||
|
||
char *ipiv_array_data = ipiv_array.get_data(); | ||
std::int64_t *d_ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data); | ||
|
||
std::vector<sycl::event> host_task_events; | ||
sycl::event getrf_ev = getrf_fn(exec_q, n, a_array_data, lda, d_ipiv, | ||
dev_info, host_task_events, depends); | ||
|
||
// Keep the arrays alive until the queued host tasks have finished. | ||
sycl::event args_ev = dpctl::utils::keep_args_alive( | ||
exec_q, {a_array, ipiv_array}, host_task_events); | ||
|
||
return std::make_pair(args_ev, getrf_ev); | ||
} | ||
|
||
template <typename fnT, typename T> | ||
struct GetrfContigFactory | ||
{ | ||
fnT get() | ||
{ | ||
if constexpr (types::GetrfTypePairSupportFactory<T>::is_defined) { | ||
return getrf_impl<T>; | ||
} | ||
else { | ||
return nullptr; | ||
} | ||
} | ||
}; | ||
|
||
void init_getrf_dispatch_vector(void) | ||
{ | ||
dpctl_td_ns::DispatchVectorBuilder<getrf_impl_fn_ptr_t, GetrfContigFactory, | ||
dpctl_td_ns::num_types> | ||
contig; | ||
contig.populate_dispatch_vector(getrf_dispatch_vector); | ||
} | ||
} // namespace lapack | ||
} // namespace ext | ||
} // namespace backend | ||
} // namespace dpnp |
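For readers following how the `ipiv` and `dev_info` outputs of `getrf` can feed into `slogdet`, here is a minimal host-side sketch. It is not the dpnp implementation: `SlogdetResult` and `slogdet_from_lu` are hypothetical names, only real `double` matrices are handled, and pivot indices are assumed to be 1-based as in LAPACK. A positive `info` is treated as "U is exactly singular", matching the comment in the error handler above.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical result type: sign in {-1, 0, +1} and log(|det|).
struct SlogdetResult
{
    double sign;
    double logabsdet;
};

// lu:   n x n LU factors (only the diagonal is read, so the storage order
//       does not matter for this sketch).
// ipiv: n pivot indices, assumed 1-based as in LAPACK getrf.
// info: value reported by the factorization; > 0 means a zero on U's diagonal.
SlogdetResult slogdet_from_lu(const std::vector<double> &lu,
                              const std::vector<std::int64_t> &ipiv,
                              std::int64_t n,
                              std::int64_t info)
{
    assert(n >= 0);
    assert(lu.size() == static_cast<std::size_t>(n * n));
    assert(ipiv.size() == static_cast<std::size_t>(n));

    if (info > 0) {
        // Exactly singular factor U: the determinant is zero.
        return {0.0, -std::numeric_limits<double>::infinity()};
    }

    double sign = 1.0;
    double logabsdet = 0.0;
    for (std::int64_t i = 0; i < n; ++i) {
        // A pivot entry that differs from its own (1-based) row index
        // records one row interchange, which flips the sign.
        if (ipiv[i] != i + 1) {
            sign = -sign;
        }
        const double d = lu[i * n + i];
        if (d < 0.0) {
            sign = -sign;
        }
        logabsdet += std::log(std::fabs(d));
    }
    return {sign, logabsdet};
}
```

Summing logarithms of the diagonal instead of multiplying the entries is what lets `slogdet` avoid overflow and underflow for large matrices; the determinant itself can be recovered as `sign * exp(logabsdet)` when it is representable.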
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
//***************************************************************************** | ||
// Copyright (c) 2023, Intel Corporation | ||
// All rights reserved. | ||
// | ||
// Redistribution and use in source and binary forms, with or without | ||
// modification, are permitted provided that the following conditions are met: | ||
// - Redistributions of source code must retain the above copyright notice, | ||
// this list of conditions and the following disclaimer. | ||
// - Redistributions in binary form must reproduce the above copyright notice, | ||
// this list of conditions and the following disclaimer in the documentation | ||
// and/or other materials provided with the distribution. | ||
// | ||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF | ||
// THE POSSIBILITY OF SUCH DAMAGE. | ||
//***************************************************************************** | ||
|
||
#pragma once | ||
|
||
#include <CL/sycl.hpp> | ||
#include <oneapi/mkl.hpp> | ||
|
||
#include <dpctl4pybind11.hpp> | ||
|
||
namespace dpnp | ||
{ | ||
namespace backend | ||
{ | ||
namespace ext | ||
{ | ||
namespace lapack | ||
{ | ||
extern std::pair<sycl::event, sycl::event> | ||
getrf(sycl::queue exec_q, | ||
dpctl::tensor::usm_ndarray a_array, | ||
dpctl::tensor::usm_ndarray ipiv_array, | ||
py::list dev_info, | ||
const std::vector<sycl::event> &depends = {}); | ||
|
||
extern std::pair<sycl::event, sycl::event> | ||
getrf_batch(sycl::queue exec_q, | ||
dpctl::tensor::usm_ndarray a_array, | ||
dpctl::tensor::usm_ndarray ipiv_array, | ||
py::list dev_info, | ||
std::int64_t n, | ||
std::int64_t stride_a, | ||
std::int64_t stride_ipiv, | ||
std::int64_t batch_size, | ||
const std::vector<sycl::event> &depends = {}); | ||
|
||
extern void init_getrf_dispatch_vector(void); | ||
extern void init_getrf_batch_dispatch_vector(void); | ||
} // namespace lapack | ||
} // namespace ext | ||
} // namespace backend | ||
} // namespace dpnp |
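The `stride_a`, `stride_ipiv` and `batch_size` parameters of `getrf_batch` describe a strided batch layout. As a rough illustration, assuming strides are counted in elements between the starts of consecutive batch items (the convention of the oneMKL strided batch API), the data for item `i` can be located as sketched below. `BatchItem`, `batch_item` and `packed_stride_a` are hypothetical helpers, and the densely packed strides (`n * n` and `n`) are an assumption, not something this diff confirms.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical view of one matrix and its pivot vector inside batch buffers.
struct BatchItem
{
    double *a;
    std::int64_t *ipiv;
};

// Stride, in elements, between consecutive n x n matrices when the batch is
// densely packed (assumption for this sketch).
inline std::int64_t packed_stride_a(std::int64_t n)
{
    return n * n;
}

// Locate batch item i, with strides measured in elements.
inline BatchItem batch_item(double *a,
                            std::int64_t *ipiv,
                            std::int64_t stride_a,
                            std::int64_t stride_ipiv,
                            std::int64_t batch_size,
                            std::int64_t i)
{
    assert(i >= 0 && i < batch_size);
    return {a + i * stride_a, ipiv + i * stride_ipiv};
}
```

With a packed batch of two 3 x 3 matrices, item 1 starts 9 elements into the matrix buffer and 3 elements into the pivot buffer.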