diff --git a/dpctl/tensor/libtensor/source/reduction_over_axis.cpp b/dpctl/tensor/libtensor/source/reduction_over_axis.cpp deleted file mode 100644 index 00e4a0a076..0000000000 --- a/dpctl/tensor/libtensor/source/reduction_over_axis.cpp +++ /dev/null @@ -1,514 +0,0 @@ -//===-- ------------ Implementation of _tensor_impl module ----*-C++-*-/===// -// -// Data Parallel Control (dpctl) -// -// Copyright 2020-2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===--------------------------------------------------------------------===// -/// -/// \file -/// This file defines functions of dpctl.tensor._tensor_impl extensions -//===--------------------------------------------------------------------===// - -#include -#include -#include -#include - -#include -#include -#include - -#include "dpctl4pybind11.hpp" -#include "kernels/reductions.hpp" -#include "reduction_over_axis.hpp" -#include "simplify_iteration_space.hpp" -#include "utils/type_dispatch.hpp" - -namespace dpctl -{ -namespace tensor -{ -namespace py_internal -{ - -namespace td_ns = dpctl::tensor::type_dispatch; -// Max -namespace impl -{ - -using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; -static reduction_strided_impl_fn_ptr - max_over_axis_strided_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_strided_impl_fn_ptr - max_over_axis_strided_temps_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -using 
dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; -static reduction_contig_impl_fn_ptr - max_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_contig_impl_fn_ptr - max_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_max_over_axis_dispatch_tables(void) -{ - using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; - using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; - using td_ns::DispatchTableBuilder; - - using dpctl::tensor::kernels::MaxOverAxisAtomicStridedFactory; - DispatchTableBuilder - dtb1; - dtb1.populate_dispatch_table(max_over_axis_strided_atomic_dispatch_table); - - using dpctl::tensor::kernels::MaxOverAxisTempsStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(max_over_axis_strided_temps_dispatch_table); - - using dpctl::tensor::kernels::MaxOverAxis1AtomicContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(max_over_axis1_contig_atomic_dispatch_table); - - using dpctl::tensor::kernels::MaxOverAxis0AtomicContigFactory; - DispatchTableBuilder - dtb4; - dtb4.populate_dispatch_table(max_over_axis0_contig_atomic_dispatch_table); -} - -} // namespace impl - -// Min -namespace impl -{ - -using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; -static reduction_strided_impl_fn_ptr - min_over_axis_strided_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_strided_impl_fn_ptr - min_over_axis_strided_temps_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; -static reduction_contig_impl_fn_ptr - min_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_contig_impl_fn_ptr - min_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_min_over_axis_dispatch_tables(void) -{ - using 
dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; - using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; - using td_ns::DispatchTableBuilder; - - using dpctl::tensor::kernels::MinOverAxisAtomicStridedFactory; - DispatchTableBuilder - dtb1; - dtb1.populate_dispatch_table(min_over_axis_strided_atomic_dispatch_table); - - using dpctl::tensor::kernels::MinOverAxisTempsStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(min_over_axis_strided_temps_dispatch_table); - - using dpctl::tensor::kernels::MinOverAxis1AtomicContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(min_over_axis1_contig_atomic_dispatch_table); - - using dpctl::tensor::kernels::MinOverAxis0AtomicContigFactory; - DispatchTableBuilder - dtb4; - dtb4.populate_dispatch_table(min_over_axis0_contig_atomic_dispatch_table); -} - -} // namespace impl - -// Sum -namespace impl -{ - -using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; -static reduction_strided_impl_fn_ptr - sum_over_axis_strided_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_strided_impl_fn_ptr - sum_over_axis_strided_temps_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; -static reduction_contig_impl_fn_ptr - sum_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_contig_impl_fn_ptr - sum_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_sum_over_axis_dispatch_tables(void) -{ - using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; - using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; - using namespace td_ns; - - using dpctl::tensor::kernels::SumOverAxisAtomicStridedFactory; - DispatchTableBuilder - dtb1; - dtb1.populate_dispatch_table(sum_over_axis_strided_atomic_dispatch_table); - - using dpctl::tensor::kernels::SumOverAxisTempsStridedFactory; - DispatchTableBuilder - 
dtb2; - dtb2.populate_dispatch_table(sum_over_axis_strided_temps_dispatch_table); - - using dpctl::tensor::kernels::SumOverAxis1AtomicContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(sum_over_axis1_contig_atomic_dispatch_table); - - using dpctl::tensor::kernels::SumOverAxis0AtomicContigFactory; - DispatchTableBuilder - dtb4; - dtb4.populate_dispatch_table(sum_over_axis0_contig_atomic_dispatch_table); -} - -} // namespace impl - -// Product -namespace impl -{ - -using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; -static reduction_strided_impl_fn_ptr - prod_over_axis_strided_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_strided_impl_fn_ptr - prod_over_axis_strided_temps_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; -static reduction_contig_impl_fn_ptr - prod_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; -static reduction_contig_impl_fn_ptr - prod_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_prod_over_axis_dispatch_tables(void) -{ - using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; - using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; - using namespace td_ns; - - using dpctl::tensor::kernels::ProductOverAxisAtomicStridedFactory; - DispatchTableBuilder - dtb1; - dtb1.populate_dispatch_table(prod_over_axis_strided_atomic_dispatch_table); - - using dpctl::tensor::kernels::ProductOverAxisTempsStridedFactory; - DispatchTableBuilder - dtb2; - dtb2.populate_dispatch_table(prod_over_axis_strided_temps_dispatch_table); - - using dpctl::tensor::kernels::ProductOverAxis1AtomicContigFactory; - DispatchTableBuilder - dtb3; - dtb3.populate_dispatch_table(prod_over_axis1_contig_atomic_dispatch_table); - - using dpctl::tensor::kernels::ProductOverAxis0AtomicContigFactory; - DispatchTableBuilder - dtb4; - 
dtb4.populate_dispatch_table(prod_over_axis0_contig_atomic_dispatch_table); -} - -} // namespace impl - -// Argmax -namespace impl -{ - -using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; -static search_reduction_strided_impl_fn_ptr - argmax_over_axis_strided_temps_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_argmax_over_axis_dispatch_tables(void) -{ - using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; - using td_ns::DispatchTableBuilder; - - using dpctl::tensor::kernels::ArgmaxOverAxisTempsStridedFactory; - DispatchTableBuilder - dtb1; - dtb1.populate_dispatch_table(argmax_over_axis_strided_temps_dispatch_table); -} - -} // namespace impl - -// Argmin -namespace impl -{ - -using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; -static search_reduction_strided_impl_fn_ptr - argmin_over_axis_strided_temps_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -void populate_argmin_over_axis_dispatch_tables(void) -{ - using dpctl::tensor::kernels::search_reduction_strided_impl_fn_ptr; - using td_ns::DispatchTableBuilder; - - using dpctl::tensor::kernels::ArgminOverAxisTempsStridedFactory; - DispatchTableBuilder - dtb1; - dtb1.populate_dispatch_table(argmin_over_axis_strided_temps_dispatch_table); -} - -} // namespace impl - -namespace py = pybind11; - -void init_reduction_functions(py::module_ m) -{ - using arrayT = dpctl::tensor::usm_ndarray; - using event_vecT = std::vector; - - namespace impl = dpctl::tensor::py_internal::impl; - - using dpctl::tensor::py_internal::py_reduction_dtype_supported; - using dpctl::tensor::py_internal::py_reduction_over_axis; - - using dpctl::tensor::py_internal::check_atomic_support; - using dpctl::tensor::py_internal::fixed_decision; - - // MAX - { - using dpctl::tensor::py_internal::impl:: - populate_max_over_axis_dispatch_tables; - populate_max_over_axis_dispatch_tables(); - using impl::max_over_axis0_contig_atomic_dispatch_table; - using 
impl::max_over_axis1_contig_atomic_dispatch_table; - using impl::max_over_axis_strided_atomic_dispatch_table; - using impl::max_over_axis_strided_temps_dispatch_table; - - const auto &check_atomic_support_size4 = - check_atomic_support; - const auto &check_atomic_support_size8 = - check_atomic_support; - - auto max_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, - const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_reduction_over_axis( - src, trailing_dims_to_reduce, dst, exec_q, depends, - max_over_axis_strided_atomic_dispatch_table, - max_over_axis_strided_temps_dispatch_table, - max_over_axis0_contig_atomic_dispatch_table, - max_over_axis1_contig_atomic_dispatch_table, - check_atomic_support_size4, check_atomic_support_size8); - }; - m.def("_max_over_axis", max_pyapi, "", py::arg("src"), - py::arg("trailing_dims_to_reduce"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - } - - // MIN - { - using dpctl::tensor::py_internal::impl:: - populate_min_over_axis_dispatch_tables; - populate_min_over_axis_dispatch_tables(); - using impl::min_over_axis0_contig_atomic_dispatch_table; - using impl::min_over_axis1_contig_atomic_dispatch_table; - using impl::min_over_axis_strided_atomic_dispatch_table; - using impl::min_over_axis_strided_temps_dispatch_table; - - const auto &check_atomic_support_size4 = - check_atomic_support; - const auto &check_atomic_support_size8 = - check_atomic_support; - - auto min_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, - const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_reduction_over_axis( - src, trailing_dims_to_reduce, dst, exec_q, depends, - min_over_axis_strided_atomic_dispatch_table, - min_over_axis_strided_temps_dispatch_table, - min_over_axis0_contig_atomic_dispatch_table, - min_over_axis1_contig_atomic_dispatch_table, - check_atomic_support_size4, check_atomic_support_size8); - }; - m.def("_min_over_axis", 
min_pyapi, "", py::arg("src"), - py::arg("trailing_dims_to_reduce"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - } - - // SUM - { - using dpctl::tensor::py_internal::impl:: - populate_sum_over_axis_dispatch_tables; - populate_sum_over_axis_dispatch_tables(); - using impl::sum_over_axis0_contig_atomic_dispatch_table; - using impl::sum_over_axis1_contig_atomic_dispatch_table; - using impl::sum_over_axis_strided_atomic_dispatch_table; - using impl::sum_over_axis_strided_temps_dispatch_table; - - const auto &check_atomic_support_size4 = - check_atomic_support; - const auto &check_atomic_support_size8 = - check_atomic_support; - - auto sum_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, - const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_reduction_over_axis( - src, trailing_dims_to_reduce, dst, exec_q, depends, - sum_over_axis_strided_atomic_dispatch_table, - sum_over_axis_strided_temps_dispatch_table, - sum_over_axis0_contig_atomic_dispatch_table, - sum_over_axis1_contig_atomic_dispatch_table, - check_atomic_support_size4, check_atomic_support_size8); - }; - m.def("_sum_over_axis", sum_pyapi, "", py::arg("src"), - py::arg("trailing_dims_to_reduce"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto sum_dtype_supported = - [&](const py::dtype &input_dtype, const py::dtype &output_dtype, - const std::string &dst_usm_type, sycl::queue &q) { - return py_reduction_dtype_supported( - input_dtype, output_dtype, dst_usm_type, q, - sum_over_axis_strided_atomic_dispatch_table, - sum_over_axis_strided_temps_dispatch_table, - check_atomic_support_size4, check_atomic_support_size8); - }; - m.def("_sum_over_axis_dtype_supported", sum_dtype_supported, "", - py::arg("arg_dtype"), py::arg("out_dtype"), - py::arg("dst_usm_type"), py::arg("sycl_queue")); - } - - // PROD - { - using dpctl::tensor::py_internal::impl:: - populate_prod_over_axis_dispatch_tables; - 
populate_prod_over_axis_dispatch_tables(); - using impl::prod_over_axis0_contig_atomic_dispatch_table; - using impl::prod_over_axis1_contig_atomic_dispatch_table; - using impl::prod_over_axis_strided_atomic_dispatch_table; - using impl::prod_over_axis_strided_temps_dispatch_table; - - const auto &check_atomic_support_size4 = - check_atomic_support; - const auto &check_atomic_support_size8 = - check_atomic_support; - - auto prod_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, - const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - return py_reduction_over_axis( - src, trailing_dims_to_reduce, dst, exec_q, depends, - prod_over_axis_strided_atomic_dispatch_table, - prod_over_axis_strided_temps_dispatch_table, - prod_over_axis0_contig_atomic_dispatch_table, - prod_over_axis1_contig_atomic_dispatch_table, - check_atomic_support_size4, check_atomic_support_size8); - }; - m.def("_prod_over_axis", prod_pyapi, "", py::arg("src"), - py::arg("trailing_dims_to_reduce"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto prod_dtype_supported = - [&](const py::dtype &input_dtype, const py::dtype &output_dtype, - const std::string &dst_usm_type, sycl::queue &q) { - return py_reduction_dtype_supported( - input_dtype, output_dtype, dst_usm_type, q, - prod_over_axis_strided_atomic_dispatch_table, - prod_over_axis_strided_temps_dispatch_table, - check_atomic_support_size4, check_atomic_support_size8); - }; - m.def("_prod_over_axis_dtype_supported", prod_dtype_supported, "", - py::arg("arg_dtype"), py::arg("out_dtype"), - py::arg("dst_usm_type"), py::arg("sycl_queue")); - } - - // ARGMAX - { - using dpctl::tensor::py_internal::impl:: - populate_argmax_over_axis_dispatch_tables; - populate_argmax_over_axis_dispatch_tables(); - using impl::argmax_over_axis_strided_temps_dispatch_table; - - auto argmax_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, - const arrayT &dst, sycl::queue &exec_q, - const 
event_vecT &depends = {}) { - using dpctl::tensor::py_internal::py_search_over_axis; - return py_search_over_axis( - src, trailing_dims_to_reduce, dst, exec_q, depends, - argmax_over_axis_strided_temps_dispatch_table); - }; - m.def("_argmax_over_axis", argmax_pyapi, "", py::arg("src"), - py::arg("trailing_dims_to_reduce"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - } - - // ARGMIN - { - using dpctl::tensor::py_internal::impl:: - populate_argmin_over_axis_dispatch_tables; - populate_argmin_over_axis_dispatch_tables(); - using impl::argmin_over_axis_strided_temps_dispatch_table; - - auto argmin_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, - const arrayT &dst, sycl::queue &exec_q, - const event_vecT &depends = {}) { - using dpctl::tensor::py_internal::py_search_over_axis; - return py_search_over_axis( - src, trailing_dims_to_reduce, dst, exec_q, depends, - argmin_over_axis_strided_temps_dispatch_table); - }; - m.def("_argmin_over_axis", argmin_pyapi, "", py::arg("src"), - py::arg("trailing_dims_to_reduce"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - } -} - -} // namespace py_internal -} // namespace tensor -} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/reduction_over_axis.hpp b/dpctl/tensor/libtensor/source/reduction_over_axis.hpp deleted file mode 100644 index e9ccd1d52a..0000000000 --- a/dpctl/tensor/libtensor/source/reduction_over_axis.hpp +++ /dev/null @@ -1,691 +0,0 @@ -//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===// -// -// Data Parallel Control (dpctl) -// -// Copyright 2020-2023 Intel Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines functions of dpctl.tensor._tensor_impl extensions, -/// specifically functions for reductions. -//===----------------------------------------------------------------------===// - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "dpctl4pybind11.hpp" -#include -#include -#include - -#include "kernels/reductions.hpp" -#include "simplify_iteration_space.hpp" -#include "utils/memory_overlap.hpp" -#include "utils/offset_utils.hpp" -#include "utils/type_dispatch.hpp" - -namespace dpctl -{ -namespace tensor -{ -namespace py_internal -{ - -template -bool check_atomic_support(const sycl::queue &exec_q, - sycl::usm::alloc usm_alloc_type) -{ - bool supports_atomics = false; - - const sycl::device &dev = exec_q.get_device(); - - if constexpr (require_atomic64) { - if (!dev.has(sycl::aspect::atomic64)) - return false; - } - - switch (usm_alloc_type) { - case sycl::usm::alloc::shared: - supports_atomics = dev.has(sycl::aspect::usm_atomic_shared_allocations); - break; - case sycl::usm::alloc::host: - supports_atomics = dev.has(sycl::aspect::usm_atomic_host_allocations); - break; - case sycl::usm::alloc::device: - supports_atomics = true; - break; - default: - supports_atomics = false; - } - - return supports_atomics; -} - -template -bool fixed_decision(const sycl::queue &, sycl::usm::alloc) -{ - return return_value; -} - -/* ====================== dtype supported 
======================== */ - -template -bool py_reduction_dtype_supported( - const py::dtype &input_dtype, - const py::dtype &output_dtype, - const std::string &dst_usm_type, - sycl::queue &q, - const fnT &atomic_dispatch_table, - const fnT &temps_dispatch_table, - const CheckAtomicSupportFnT &check_atomic_support_size4, - const CheckAtomicSupportFnT &check_atomic_support_size8) -{ - int arg_tn = - input_dtype.num(); // NumPy type numbers are the same as in dpctl - int out_tn = - output_dtype.num(); // NumPy type numbers are the same as in dpctl - int arg_typeid = -1; - int out_typeid = -1; - - auto array_types = td_ns::usm_ndarray_types(); - - try { - arg_typeid = array_types.typenum_to_lookup_id(arg_tn); - out_typeid = array_types.typenum_to_lookup_id(out_tn); - } catch (const std::exception &e) { - throw py::value_error(e.what()); - } - - if (arg_typeid < 0 || arg_typeid >= td_ns::num_types || out_typeid < 0 || - out_typeid >= td_ns::num_types) - { - throw std::runtime_error("Reduction type support check: lookup failed"); - } - - // remove_all_extents gets underlying type of table - using fn_ptrT = typename std::remove_all_extents::type; - fn_ptrT fn = nullptr; - - sycl::usm::alloc kind = sycl::usm::alloc::unknown; - - if (dst_usm_type == "device") { - kind = sycl::usm::alloc::device; - } - else if (dst_usm_type == "shared") { - kind = sycl::usm::alloc::shared; - } - else if (dst_usm_type == "host") { - kind = sycl::usm::alloc::host; - } - else { - throw py::value_error("Unrecognized `dst_usm_type` argument."); - } - - bool supports_atomics = false; - - switch (output_dtype.itemsize()) { - case sizeof(float): - { - supports_atomics = check_atomic_support_size4(q, kind); - } break; - case sizeof(double): - { - supports_atomics = check_atomic_support_size8(q, kind); - } break; - } - - if (supports_atomics) { - fn = atomic_dispatch_table[arg_typeid][out_typeid]; - } - - if (fn == nullptr) { - // use slower reduction implementation using temporaries - fn = 
temps_dispatch_table[arg_typeid][out_typeid]; - } - - return (fn != nullptr); -} - -/* ==================== Generic reductions ====================== */ - -template -std::pair py_reduction_over_axis( - const dpctl::tensor::usm_ndarray &src, - int trailing_dims_to_reduce, // comp over this many trailing indexes - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends, - const strided_fnT &atomic_dispatch_table, - const strided_fnT &temps_dispatch_table, - const contig_fnT &axis0_dispatch_table, - const contig_fnT &axis1_dispatch_table, - const SupportAtomicFnT &check_atomic_support_size4, - const SupportAtomicFnT &check_atomic_support_size8) -{ - int src_nd = src.get_ndim(); - int iteration_nd = src_nd - trailing_dims_to_reduce; - if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) { - throw py::value_error("Trailing_dim_to_reduce must be positive, but no " - "greater than rank of the array being reduced"); - } - - int dst_nd = dst.get_ndim(); - if (dst_nd != iteration_nd) { - throw py::value_error("Destination array rank does not match input " - "array rank and number of reduced dimensions"); - } - - const py::ssize_t *src_shape_ptr = src.get_shape_raw(); - const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); - - bool same_shapes = true; - for (int i = 0; same_shapes && (i < dst_nd); ++i) { - same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); - } - - if (!same_shapes) { - throw py::value_error("Destination shape does not match unreduced " - "dimensions of the input shape"); - } - - if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { - throw py::value_error( - "Execution queue is not compatible with allocation queues"); - } - - size_t dst_nelems = dst.get_size(); - - size_t reduction_nelems(1); - for (int i = dst_nd; i < src_nd; ++i) { - reduction_nelems *= static_cast(src_shape_ptr[i]); - } - - // check that dst and src do not overlap - auto const &overlap = 
dpctl::tensor::overlap::MemoryOverlap(); - if (overlap(src, dst)) { - throw py::value_error("Arrays index overlapping segments of memory"); - } - - // destination must be ample enough to accommodate all elements - { - auto dst_offsets = dst.get_minmax_offsets(); - size_t range = - static_cast(dst_offsets.second - dst_offsets.first); - if (range + 1 < dst_nelems) { - throw py::value_error( - "Destination array can not accommodate all the " - "elements of source array."); - } - } - - int src_typenum = src.get_typenum(); - int dst_typenum = dst.get_typenum(); - - namespace td_ns = dpctl::tensor::type_dispatch; - const auto &array_types = td_ns::usm_ndarray_types(); - int src_typeid = array_types.typenum_to_lookup_id(src_typenum); - int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); - - int dst_itemsize = dst.get_elemsize(); - bool supports_atomics = false; - - switch (dst_itemsize) { - case sizeof(float): - { - void *data_ptr = dst.get_data(); - const auto &ctx = exec_q.get_context(); - auto usm_type = sycl::get_pointer_type(data_ptr, ctx); - supports_atomics = check_atomic_support_size4(exec_q, usm_type); - } break; - case sizeof(double): - { - void *data_ptr = dst.get_data(); - const auto &ctx = exec_q.get_context(); - auto usm_type = sycl::get_pointer_type(data_ptr, ctx); - - supports_atomics = check_atomic_support_size8(exec_q, usm_type); - } break; - } - - // handle special case when both reduction and iteration are 1D contiguous - // and can be done with atomics - if (supports_atomics) { - bool is_src_c_contig = src.is_c_contiguous(); - bool is_dst_c_contig = dst.is_c_contiguous(); - bool is_src_f_contig = src.is_f_contiguous(); - - if ((is_src_c_contig && is_dst_c_contig) || - (is_src_f_contig && dst_nelems == 1)) - { - auto fn = axis1_dispatch_table[src_typeid][dst_typeid]; - - if (fn != nullptr) { - size_t iter_nelems = dst_nelems; - - constexpr py::ssize_t zero_offset = 0; - - sycl::event reduction_over_axis_contig_ev = - fn(exec_q, iter_nelems, 
reduction_nelems, src.get_data(), - dst.get_data(), - zero_offset, // iteration_src_offset - zero_offset, // iteration_dst_offset - zero_offset, // reduction_src_offset - depends); - - sycl::event keep_args_event = dpctl::utils::keep_args_alive( - exec_q, {src, dst}, {reduction_over_axis_contig_ev}); - - return std::make_pair(keep_args_event, - reduction_over_axis_contig_ev); - } - } - else if (is_src_f_contig && - ((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous())) - { - auto fn = axis0_dispatch_table[src_typeid][dst_typeid]; - if (fn != nullptr) { - size_t iter_nelems = dst_nelems; - - constexpr py::ssize_t zero_offset = 0; - - sycl::event reduction_over_axis_contig_ev = - fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), - dst.get_data(), - zero_offset, // iteration_src_offset - zero_offset, // iteration_dst_offset - zero_offset, // reduction_src_offset - depends); - - sycl::event keep_args_event = dpctl::utils::keep_args_alive( - exec_q, {src, dst}, {reduction_over_axis_contig_ev}); - - return std::make_pair(keep_args_event, - reduction_over_axis_contig_ev); - } - } - } - - using dpctl::tensor::py_internal::simplify_iteration_space; - using dpctl::tensor::py_internal::simplify_iteration_space_1; - - auto const &src_shape_vecs = src.get_shape_vector(); - auto const &src_strides_vecs = src.get_strides_vector(); - auto const &dst_strides_vecs = dst.get_strides_vector(); - - int reduction_nd = trailing_dims_to_reduce; - const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd; - using shT = std::vector; - shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd, - std::end(src_strides_vecs)); - - shT simplified_reduction_shape; - shT simplified_reduction_src_strides; - py::ssize_t reduction_src_offset(0); - - simplify_iteration_space_1( - reduction_nd, reduction_shape_ptr, reduction_src_strides, - // output - simplified_reduction_shape, simplified_reduction_src_strides, - reduction_src_offset); - - const py::ssize_t 
*iteration_shape_ptr = src_shape_ptr; - - shT iteration_src_strides(std::begin(src_strides_vecs), - std::begin(src_strides_vecs) + iteration_nd); - shT const &iteration_dst_strides = dst_strides_vecs; - - shT simplified_iteration_shape; - shT simplified_iteration_src_strides; - shT simplified_iteration_dst_strides; - py::ssize_t iteration_src_offset(0); - py::ssize_t iteration_dst_offset(0); - - if (iteration_nd == 0) { - if (dst_nelems != 1) { - throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); - } - iteration_nd = 1; - simplified_iteration_shape.push_back(1); - simplified_iteration_src_strides.push_back(0); - simplified_iteration_dst_strides.push_back(0); - } - else { - simplify_iteration_space(iteration_nd, iteration_shape_ptr, - iteration_src_strides, iteration_dst_strides, - // output - simplified_iteration_shape, - simplified_iteration_src_strides, - simplified_iteration_dst_strides, - iteration_src_offset, iteration_dst_offset); - } - - if (supports_atomics && (reduction_nd == 1) && (iteration_nd == 1)) { - bool mat_reduce_over_axis1 = false; - bool mat_reduce_over_axis0 = false; - bool array_reduce_all_elems = false; - size_t iter_nelems = dst_nelems; - - if (simplified_reduction_src_strides[0] == 1) { - array_reduce_all_elems = (simplified_iteration_shape[0] == 1); - mat_reduce_over_axis1 = - (simplified_iteration_dst_strides[0] == 1) && - (static_cast(simplified_iteration_src_strides[0]) == - reduction_nelems); - } - else if (static_cast(simplified_reduction_src_strides[0]) == - iter_nelems) - { - mat_reduce_over_axis0 = - (simplified_iteration_dst_strides[0] == 1) && - (simplified_iteration_src_strides[0] == 1); - } - - if (mat_reduce_over_axis1 || array_reduce_all_elems) { - auto fn = axis1_dispatch_table[src_typeid][dst_typeid]; - if (fn != nullptr) { - sycl::event reduction_over_axis1_contig_ev = - fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), - dst.get_data(), iteration_src_offset, - iteration_dst_offset, 
reduction_src_offset, depends); - - sycl::event keep_args_event = dpctl::utils::keep_args_alive( - exec_q, {src, dst}, {reduction_over_axis1_contig_ev}); - - return std::make_pair(keep_args_event, - reduction_over_axis1_contig_ev); - } - } - else if (mat_reduce_over_axis0) { - auto fn = axis0_dispatch_table[src_typeid][dst_typeid]; - if (fn != nullptr) { - sycl::event reduction_over_axis0_contig_ev = - fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), - dst.get_data(), iteration_src_offset, - iteration_dst_offset, reduction_src_offset, depends); - - sycl::event keep_args_event = dpctl::utils::keep_args_alive( - exec_q, {src, dst}, {reduction_over_axis0_contig_ev}); - - return std::make_pair(keep_args_event, - reduction_over_axis0_contig_ev); - } - } - } - - // remove_all_extents gets underlying type of table - using strided_fn_ptr_T = - typename std::remove_all_extents::type; - strided_fn_ptr_T fn = nullptr; - - if (supports_atomics) { - fn = atomic_dispatch_table[src_typeid][dst_typeid]; - } - - if (fn == nullptr) { - // use slower reduction implementation using temporaries - fn = temps_dispatch_table[src_typeid][dst_typeid]; - if (fn == nullptr) { - throw std::runtime_error("Datatypes are not supported"); - } - } - - std::vector host_task_events{}; - - using dpctl::tensor::offset_utils::device_allocate_and_pack; - - const auto &arrays_metainfo_packing_triple_ = - device_allocate_and_pack( - exec_q, host_task_events, - // iteration metadata - simplified_iteration_shape, simplified_iteration_src_strides, - simplified_iteration_dst_strides, - // reduction metadata - simplified_reduction_shape, simplified_reduction_src_strides); - py::ssize_t *temp_allocation_ptr = - std::get<0>(arrays_metainfo_packing_triple_); - if (temp_allocation_ptr == nullptr) { - throw std::runtime_error("Unable to allocate memory on device"); - } - const auto ©_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_); - - py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; - 
py::ssize_t *reduction_shape_stride = - temp_allocation_ptr + 3 * simplified_iteration_shape.size(); - - std::vector all_deps; - all_deps.reserve(depends.size() + 1); - all_deps.resize(depends.size()); - std::copy(depends.begin(), depends.end(), all_deps.begin()); - all_deps.push_back(copy_metadata_ev); - - auto reduction_ev = - fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), dst.get_data(), - iteration_nd, iter_shape_and_strides, iteration_src_offset, - iteration_dst_offset, - reduction_nd, // number dimensions being reduced - reduction_shape_stride, reduction_src_offset, all_deps); - - sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(reduction_ev); - const auto &ctx = exec_q.get_context(); - cgh.host_task([ctx, temp_allocation_ptr] { - sycl::free(temp_allocation_ptr, ctx); - }); - }); - host_task_events.push_back(temp_cleanup_ev); - - sycl::event keep_args_event = - dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); - - return std::make_pair(keep_args_event, reduction_ev); -} - -/* ==================== Search reductions ====================== */ - /* py_search_over_axis: host-side entry point that validates src and dst, builds simplified iteration metadata and compacted reduction metadata, has device_allocate_and_pack place them in a temporary allocation, and calls the implementation function selected by dispatch_table[src_typeid][dst_typeid]. Parameters: src is the input array; trailing_dims_to_reduce is how many trailing src dimensions are collapsed into each dst element; dst receives the result and must have rank src_nd minus trailing_dims_to_reduce with a shape equal to the leading src extents; exec_q is the queue all work is submitted to; depends holds events the selected function must wait on; dispatch_table is indexed by source and destination type lookup ids. Returns a pair (keep_args_event, comp_ev): the first is produced by keep_args_alive over src, dst and the host-task events, the second is the event returned by the selected function. Throws py::value_error for rank, shape, queue-compatibility, overlap or destination-extent violations, and std::runtime_error when the table entry is null, when iteration_nd is 0 but dst holds more than one element, or when the metadata allocation pointer is null. NOTE(review): template parameter lists and other angle-bracketed arguments seem to be missing from this text (fn_tableT is used but its declaration is not visible; std::pair, std::vector and static_cast appear without arguments); restore them from upstream before compiling. */ -template -std::pair py_search_over_axis( - const dpctl::tensor::usm_ndarray &src, - int trailing_dims_to_reduce, // comp over this many trailing indexes - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends, - const fn_tableT &dispatch_table) -{ - /* Split the src rank: the leading iteration_nd dimensions are iterated over, the trailing ones are reduced. */ int src_nd = src.get_ndim(); - int iteration_nd = src_nd - trailing_dims_to_reduce; - if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) { - throw py::value_error("Trailing_dim_to_reduce must be positive, but no " - "greater than rank of the array being reduced"); - } - /* dst must have exactly one dimension per non-reduced src dimension. */ - int dst_nd = dst.get_ndim(); - if (dst_nd != iteration_nd) { - throw py::value_error("Destination array rank does not match input " - "array rank and number of reduced dimensions"); - } - - const py::ssize_t *src_shape_ptr = src.get_shape_raw(); - const py::ssize_t *dst_shape_ptr = dst.get_shape_raw();
- - /* The dst shape must equal the leading dst_nd extents of src; the loop stops at the first mismatch. */ bool same_shapes = true; - for (int i = 0; same_shapes && (i < dst_nd); ++i) { - same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); - } - - if (!same_shapes) { - throw py::value_error("Destination shape does not match unreduced " - "dimensions of the input shape"); - } - - if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { - throw py::value_error( - "Execution queue is not compatible with allocation queues"); - } - - size_t dst_nelems = dst.get_size(); - - /* Number of src elements combined into each dst element: the product of the trailing (reduced) extents. */ size_t reduction_nelems(1); - for (int i = dst_nd; i < src_nd; ++i) { - reduction_nelems *= static_cast(src_shape_ptr[i]); - } - - // check that dst and src do not overlap - auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); - if (overlap(src, dst)) { - throw py::value_error("Arrays index overlapping segments of memory"); - } - - // destination must be ample enough to accommodate all elements - { - auto dst_offsets = dst.get_minmax_offsets(); - size_t range = - static_cast(dst_offsets.second - dst_offsets.first); - if (range + 1 < dst_nelems) { - throw py::value_error( - "Destination array can not accommodate all the " - "elements of source array."); - } - } - - /* Translate the array typenums into the lookup ids used to index dispatch_table. */ int src_typenum = src.get_typenum(); - int dst_typenum = dst.get_typenum(); - - namespace td_ns = dpctl::tensor::type_dispatch; - const auto &array_types = td_ns::usm_ndarray_types(); - int src_typeid = array_types.typenum_to_lookup_id(src_typenum); - int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); - - using dpctl::tensor::py_internal::simplify_iteration_space; - using dpctl::tensor::py_internal::simplify_iteration_space_1; - /* NOTE(review): simplify_iteration_space_1 is not called anywhere in the visible body, and src_shape_vecs below is bound but not read afterwards; both look unused, confirm before removing. */ - auto const &src_shape_vecs = src.get_shape_vector(); - auto const &src_strides_vecs = src.get_strides_vector(); - auto const &dst_strides_vecs = dst.get_strides_vector(); - - /* Reduction metadata: extents and src strides of the trailing dimensions, i.e. everything from index dst_nd onwards. */ int reduction_nd = trailing_dims_to_reduce; - const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd; - using shT = std::vector; - shT reduction_src_strides(std::begin(src_strides_vecs) +
dst_nd, - std::end(src_strides_vecs)); - - shT compact_reduction_shape; - shT compact_reduction_src_strides; - py::ssize_t reduction_src_offset(0); - /* NOTE(review): reduction_src_offset is not passed to compact_iteration_space below and is not assigned again in the visible code, so it is still 0 when handed to fn. */ - compact_iteration_space( - reduction_nd, reduction_shape_ptr, reduction_src_strides, - // output - compact_reduction_shape, compact_reduction_src_strides); - /* Iteration metadata: the leading iteration_nd extents and strides of src, paired with all strides of dst. */ - const py::ssize_t *iteration_shape_ptr = src_shape_ptr; - - shT iteration_src_strides(std::begin(src_strides_vecs), - std::begin(src_strides_vecs) + iteration_nd); - shT const &iteration_dst_strides = dst_strides_vecs; - - shT simplified_iteration_shape; - shT simplified_iteration_src_strides; - shT simplified_iteration_dst_strides; - py::ssize_t iteration_src_offset(0); - py::ssize_t iteration_dst_offset(0); - /* When every src dimension is reduced there is no iteration dimension left; a single iteration of extent 1 with zero strides is substituted, presumably so that fn always receives at least one iteration dimension (confirm against the kernel signature). Otherwise simplify_iteration_space fills the simplified shape, strides and offsets. */ - if (iteration_nd == 0) { - if (dst_nelems != 1) { - throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); - } - iteration_nd = 1; - simplified_iteration_shape.push_back(1); - simplified_iteration_src_strides.push_back(0); - simplified_iteration_dst_strides.push_back(0); - } - else { - simplify_iteration_space(iteration_nd, iteration_shape_ptr, - iteration_src_strides, iteration_dst_strides, - // output - simplified_iteration_shape, - simplified_iteration_src_strides, - simplified_iteration_dst_strides, - iteration_src_offset, iteration_dst_offset); - } - /* Select the implementation for this (source type, destination type) pair; a null entry means the combination is unsupported. */ - auto fn = dispatch_table[src_typeid][dst_typeid]; - if (fn == nullptr) { - throw std::runtime_error("Datatypes are not supported"); - } - - std::vector host_task_events{}; - - using dpctl::tensor::offset_utils::device_allocate_and_pack; - /* Hand the five metadata vectors to device_allocate_and_pack; element 0 of the returned tuple is used as the allocation pointer and element 2 as the event of the metadata copy. NOTE(review): the explicit template argument of device_allocate_and_pack appears to have been stripped from this text. */ - const auto &arrays_metainfo_packing_triple_ = - device_allocate_and_pack( - exec_q, host_task_events, - // iteration metadata - simplified_iteration_shape, simplified_iteration_src_strides, - simplified_iteration_dst_strides, - // reduction metadata - compact_reduction_shape, compact_reduction_src_strides); - py::ssize_t *temp_allocation_ptr = - std::get<0>(arrays_metainfo_packing_triple_); - if (temp_allocation_ptr == nullptr) { - throw
std::runtime_error("Unable to allocate memory on device"); - } - /* NOTE(review): the identifier declared next begins with a copyright sign, yet later statements read copy_metadata_ev; this looks like a mis-decoded ampersand followed by copy_metadata_ev, confirm against upstream. */ const auto ©_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_); - - /* Iteration metadata sits at the start of the allocation. The offset of 3 times the simplified rank assumes the packer stores the three iteration vectors back to back, each of that length, ahead of the reduction vectors; TODO confirm against offset_utils. */ py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; - py::ssize_t *reduction_shape_stride = - temp_allocation_ptr + 3 * simplified_iteration_shape.size(); - - /* Dependencies passed to fn: every caller-supplied event plus the metadata copy event. */ std::vector all_deps; - all_deps.reserve(depends.size() + 1); - all_deps.resize(depends.size()); - std::copy(depends.begin(), depends.end(), all_deps.begin()); - all_deps.push_back(copy_metadata_ev); - - auto comp_ev = fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), - dst.get_data(), iteration_nd, iter_shape_and_strides, - iteration_src_offset, iteration_dst_offset, - reduction_nd, // number dimensions being reduced - reduction_shape_stride, reduction_src_offset, all_deps); - /* Release the metadata allocation from a host task that depends on comp_ev; the context and the pointer are captured by value in that task. NOTE(review): the free is only scheduled after fn returns, so if fn throws the allocation is not released here. */ - sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(comp_ev); - const auto &ctx = exec_q.get_context(); - cgh.host_task([ctx, temp_allocation_ptr] { - sycl::free(temp_allocation_ptr, ctx); - }); - }); - host_task_events.push_back(temp_cleanup_ev); - /* keep_args_alive receives src, dst and all host-task events, including the cleanup task; its event is the first element of the returned pair. */ - sycl::event keep_args_event = - dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); - - return std::make_pair(keep_args_event, comp_ev); -} - /* Declaration only; the definition is not visible in this chunk and is presumably provided by another translation unit. */ -extern void init_reduction_functions(py::module_ m); - -} // namespace py_internal -} // namespace tensor -} // namespace dpctl