diff --git a/dpctl/tensor/_clip.py b/dpctl/tensor/_clip.py index c8eea607f9..f53c888e5f 100644 --- a/dpctl/tensor/_clip.py +++ b/dpctl/tensor/_clip.py @@ -206,7 +206,6 @@ def _clip_none(x, val, out, order, _binary_fn): ) _manager.add_event_pair(ht_copy_out_ev, copy_ev) out = orig_out - ht_binary_ev.wait() return out else: if order == "K": diff --git a/dpctl/tensor/_statistical_functions.py b/dpctl/tensor/_statistical_functions.py index 2779d3b0fc..457a8c1a38 100644 --- a/dpctl/tensor/_statistical_functions.py +++ b/dpctl/tensor/_statistical_functions.py @@ -93,16 +93,13 @@ def _var_impl(x, axis, correction, keepdims): ) # divide in-place to get mean mean_ary_shape = mean_ary.shape - nelems_ary = dpt.asarray( - nelems, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q - ) - if nelems_ary.shape != mean_ary_shape: - nelems_ary = dpt.broadcast_to(nelems_ary, mean_ary_shape) + dep_evs = _manager.submitted_events - ht_e2, d_e1 = tei._divide_inplace( - lhs=mean_ary, rhs=nelems_ary, sycl_queue=q, depends=dep_evs + ht_e2, d_e1 = tei._divide_by_scalar( + src=mean_ary, scalar=nelems, dst=mean_ary, sycl_queue=q, depends=dep_evs ) _manager.add_event_pair(ht_e2, d_e1) + # subtract mean from original array to get deviations dev_ary = dpt.empty_like(buf) if mean_ary_shape != buf.shape: @@ -146,15 +143,9 @@ def _var_impl(x, axis, correction, keepdims): div = max(nelems - correction, 0) if not div: div = dpt.nan - div_ary = dpt.asarray( - div, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q - ) - # divide in-place again - if div_ary.shape != res_shape: - div_ary = dpt.broadcast_to(div_ary, res.shape) dep_evs = _manager.submitted_events - ht_e7, d_e2 = tei._divide_inplace( - lhs=res, rhs=div_ary, sycl_queue=q, depends=dep_evs + ht_e7, d_e2 = tei._divide_by_scalar( + src=res, scalar=div, dst=res, sycl_queue=q, depends=dep_evs ) _manager.add_event_pair(ht_e7, d_e2) return res, [d_e2] @@ -259,17 +250,9 @@ def mean(x, axis=None, keepdims=False): inv_perm = sorted(range(nd), key=lambda d: perm[d]) res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm) - res_shape = res.shape - # in-place divide - den_dt = dpt.finfo(res_dt).dtype if res_dt.kind == "c" else res_dt - nelems_arr = dpt.asarray( - nelems, dtype=den_dt, usm_type=res_usm_type, sycl_queue=q - ) - if nelems_arr.shape != res_shape: - nelems_arr = dpt.broadcast_to(nelems_arr, res_shape) dep_evs = _manager.submitted_events - ht_e2, div_e = tei._divide_inplace( - lhs=res, rhs=nelems_arr, sycl_queue=q, depends=dep_evs + ht_e2, div_e = tei._divide_by_scalar( + src=res, scalar=nelems, dst=res, sycl_queue=q, depends=dep_evs ) _manager.add_event_pair(ht_e2, div_e) return res diff --git a/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp b/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp index ffb2afc3ea..dd7168beb1 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions/true_divide.cpp @@ -24,14 +24,21 @@ //===----------------------------------------------------------------------===// #include "dpctl4pybind11.hpp" +#include +#include #include #include #include #include +#include #include #include "elementwise_functions.hpp" +#include "simplify_iteration_space.hpp" #include "true_divide.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/output_validation.hpp" #include "utils/type_dispatch.hpp" #include "kernels/elementwise_functions/common.hpp" @@ -165,6 +172,247 @@ void populate_true_divide_dispatch_tables(void) dtb9.populate_dispatch_table(true_divide_inplace_row_matrix_dispatch_table); }; +template class divide_by_scalar_krn; + +typedef sycl::event (*divide_by_scalar_fn_ptr_t)( + sycl::queue &, + size_t, + int, + const ssize_t *, + const char *, + py::ssize_t, + const char *, + char *, + py::ssize_t, + const std::vector &); + +template +sycl::event divide_by_scalar(sycl::queue &exec_q, + size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg_p, + py::ssize_t arg_offset, + const char *scalar_ptr, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends = {}) +{ + const scalarT sc_v = *reinterpret_cast(scalar_ptr); + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using BinOpT = + dpctl::tensor::kernels::true_divide::TrueDivideFunctor; + + auto op = BinOpT(); + + using IndexerT = + typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + const IndexerT two_offsets_indexer{nd, arg_offset, res_offset, + shape_and_strides}; + + const T *arg_tp = reinterpret_cast(arg_p); + T *res_tp = reinterpret_cast(res_p); + + cgh.parallel_for>( + {nelems}, [=](sycl::id<1> id) { + const auto &two_offsets_ = + two_offsets_indexer(static_cast(id.get(0))); + + const auto &arg_i = two_offsets_.get_first_offset(); + const auto &res_i = two_offsets_.get_second_offset(); + res_tp[res_i] = op(arg_tp[arg_i], sc_v); + }); + }); + return comp_ev; +} + +std::pair +py_divide_by_scalar(const dpctl::tensor::usm_ndarray &src, + double scalar, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends = {}) +{ + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + auto array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + if (src_typeid != dst_typeid) { + throw py::value_error( + "Destination array has unexpected elemental data type."); + } + + // check that queues are compatible + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + // check shapes, broadcasting is assumed done by caller + // check that dimensions are the same + int dst_nd = dst.get_ndim(); + if (dst_nd != src.get_ndim()) { + throw py::value_error("Array dimensions are not the same."); + } + + // check that shapes are the same + const py::ssize_t *src_shape = src.get_shape_raw(); + const py::ssize_t *dst_shape = dst.get_shape_raw(); + bool shapes_equal(true); + size_t src_nelems(1); + + for (int i = 0; i < dst_nd; ++i) { + src_nelems *= static_cast(src_shape[i]); + shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]); + } + if (!shapes_equal) { + throw py::value_error("Array shapes are not the same."); + } + + // if nelems is zero, return + if (src_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems); + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + auto const &same_logical_tensors = + dpctl::tensor::overlap::SameLogicalTensors(); + if ((overlap(src, dst) && !same_logical_tensors(src, dst))) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + const char *src_data = src.get_data(); + char *dst_data = dst.get_data(); + + constexpr int float16_typeid = static_cast(td_ns::typenum_t::HALF); + constexpr int float32_typeid = static_cast(td_ns::typenum_t::FLOAT); + constexpr int float64_typeid = static_cast(td_ns::typenum_t::DOUBLE); + constexpr int complex64_typeid = static_cast(td_ns::typenum_t::CFLOAT); + constexpr int complex128_typeid = + static_cast(td_ns::typenum_t::CDOUBLE); + + // statically pre-allocated memory for scalar + alignas(double) char scalar_alloc[sizeof(double)] = {0}; + + divide_by_scalar_fn_ptr_t fn; + // placement new into stack memory means no call to delete is necessary + switch (src_typeid) { + case float16_typeid: + { + fn = divide_by_scalar; + std::ignore = + new (scalar_alloc) sycl::half(static_cast(scalar)); + break; + } + case float32_typeid: + { + fn = divide_by_scalar; + std::ignore = new (scalar_alloc) float(scalar); + break; + } + case float64_typeid: + { + fn = divide_by_scalar; + std::ignore = new (scalar_alloc) double(scalar); + break; + } + case complex64_typeid: + { + fn = divide_by_scalar, float>; + std::ignore = new (scalar_alloc) float(scalar); + break; + } + case complex128_typeid: + { + fn = divide_by_scalar, double>; + std::ignore = new (scalar_alloc) double(scalar); + break; + } + default: + throw std::runtime_error("Implementation is missing for typeid=" + + std::to_string(src_typeid)); + } + + // simplify strides + auto const &src_strides = src.get_strides_vector(); + auto const &dst_strides = dst.get_strides_vector(); + + using shT = std::vector; + shT simplified_shape; + shT simplified_src_strides; + shT simplified_dst_strides; + py::ssize_t src_offset(0); + py::ssize_t dst_offset(0); + + int nd = dst_nd; + const py::ssize_t *shape = src_shape; + + std::vector host_tasks{}; + dpctl::tensor::py_internal::simplify_iteration_space( + nd, shape, src_strides, dst_strides, + // outputs + simplified_shape, simplified_src_strides, simplified_dst_strides, + src_offset, dst_offset); + + if (nd == 0) { + // handle 0d array as 1d array with 1 element + constexpr py::ssize_t one{1}; + simplified_shape.push_back(one); + simplified_src_strides.push_back(one); + simplified_dst_strides.push_back(one); + src_offset = 0; + dst_offset = 0; + } + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + const auto &ptr_sz_event_triple_ = device_allocate_and_pack( + exec_q, host_tasks, simplified_shape, simplified_src_strides, + simplified_dst_strides); + + py::ssize_t *shape_strides = std::get<0>(ptr_sz_event_triple_); + const sycl::event ©_metadata_ev = std::get<2>(ptr_sz_event_triple_); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + if (shape_strides == nullptr) { + throw std::runtime_error("Unable to allocate device memory"); + } + + sycl::event div_ev = + fn(exec_q, src_nelems, nd, shape_strides, src_data, src_offset, + scalar_alloc, dst_data, dst_offset, all_deps); + + // async free of shape_strides temporary + auto ctx = exec_q.get_context(); + + sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(div_ev); + using dpctl::tensor::alloc_utils::sycl_free_noexcept; + cgh.host_task( + [ctx, shape_strides]() { sycl_free_noexcept(shape_strides, ctx); }); + }); + + host_tasks.push_back(tmp_cleanup_ev); + + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_tasks), div_ev); +} + } // namespace impl void init_divide(py::module_ m) @@ -233,6 +481,11 @@ void init_divide(py::module_ m) m.def("_divide_inplace", divide_inplace_pyapi, "", py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), py::arg("depends") = py::list()); + + using impl::py_divide_by_scalar; + m.def("_divide_by_scalar", &py_divide_by_scalar, "", py::arg("src"), + py::arg("scalar"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); } } diff --git a/dpctl/tests/elementwise/test_divide.py b/dpctl/tests/elementwise/test_divide.py index 610d0ccf31..d6b7d15201 100644 --- a/dpctl/tests/elementwise/test_divide.py +++ b/dpctl/tests/elementwise/test_divide.py @@ -21,8 +21,10 @@ import dpctl import dpctl.tensor as dpt +from dpctl.tensor._tensor_elementwise_impl import _divide_by_scalar from dpctl.tensor._type_utils import _can_cast from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported +from dpctl.utils import SequentialOrderManager from .utils import ( _all_dtypes, @@ -271,3 +273,26 @@ def test_divide_gh_1711(): assert isinstance(res, dpt.usm_ndarray) assert res.dtype.kind == "f" assert dpt.allclose(res, dpt.asarray(3, dtype="i4") / -2) + + +# don't test for overflowing double as Python won't cast +# an Python integer of that size to a Python float +@pytest.mark.parametrize("fp_dt", [dpt.float16, dpt.float32]) +def test_divide_by_scalar_overflow(fp_dt): + q = get_queue_or_skip() + skip_if_dtype_not_supported(fp_dt, q) + + x = dpt.ones(10, dtype=fp_dt, sycl_queue=q) + out = dpt.empty_like(x) + + max_exp = np.finfo(fp_dt).maxexp + sca = 2**max_exp + + _manager = SequentialOrderManager[q] + dep_evs = _manager.submitted_events + _, ev = _divide_by_scalar( + src=x, scalar=sca, dst=out, sycl_queue=q, depends=dep_evs + ) + ev.wait() + + assert dpt.all(out == 0)