Skip to content

Commit 74715ae

Browse files
committed
Align with changes in dpctl::tensor::offset_utils::device_allocate_and_pack
1 parent 23c4907 commit 74715ae

File tree

1 file changed

+19
-44
lines changed

1 file changed

+19
-44
lines changed

dpnp/backend/extensions/elementwise_functions/elementwise_functions.hpp

Lines changed: 19 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -218,28 +218,21 @@ std::pair<sycl::event, sycl::event>
218218
std::vector<sycl::event> host_tasks{};
219219
host_tasks.reserve(2);
220220

221-
const auto &ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
221+
auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
222222
q, host_tasks, simplified_shape, simplified_src_strides,
223223
simplified_dst_strides);
224-
py::ssize_t *shape_strides = std::get<0>(ptr_size_event_triple_);
225-
const sycl::event &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
226-
227-
if (shape_strides == nullptr) {
228-
throw std::runtime_error("Device memory allocation failed");
229-
}
224+
auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
225+
const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
226+
const py::ssize_t *shape_strides = shape_strides_owner.get();
230227

231228
sycl::event strided_fn_ev =
232229
strided_fn(q, src_nelems, nd, shape_strides, src_data, src_offset,
233230
dst_data, dst_offset, depends, {copy_shape_ev});
234231

235232
// async free of shape_strides temporary
236-
auto ctx = q.get_context();
237-
sycl::event tmp_cleanup_ev = q.submit([&](sycl::handler &cgh) {
238-
cgh.depends_on(strided_fn_ev);
239-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
240-
cgh.host_task(
241-
[ctx, shape_strides]() { sycl_free_noexcept(shape_strides, ctx); });
242-
});
233+
sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
234+
q, {strided_fn_ev}, shape_strides_owner);
235+
243236
host_tasks.push_back(tmp_cleanup_ev);
244237

245238
return std::make_pair(
@@ -543,30 +536,21 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
543536
}
544537

545538
using dpctl::tensor::offset_utils::device_allocate_and_pack;
546-
const auto &ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
539+
auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
547540
exec_q, host_tasks, simplified_shape, simplified_src1_strides,
548541
simplified_src2_strides, simplified_dst_strides);
542+
auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
543+
auto &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);
549544

550-
py::ssize_t *shape_strides = std::get<0>(ptr_sz_event_triple_);
551-
const sycl::event &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);
552-
553-
if (shape_strides == nullptr) {
554-
throw std::runtime_error("Unable to allocate device memory");
555-
}
545+
const py::ssize_t *shape_strides = shape_strides_owner.get();
556546

557547
sycl::event strided_fn_ev = strided_fn(
558548
exec_q, src_nelems, nd, shape_strides, src1_data, src1_offset,
559549
src2_data, src2_offset, dst_data, dst_offset, depends, {copy_shape_ev});
560550

561551
// async free of shape_strides temporary
562-
auto ctx = exec_q.get_context();
563-
564-
sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
565-
cgh.depends_on(strided_fn_ev);
566-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
567-
cgh.host_task(
568-
[ctx, shape_strides]() { sycl_free_noexcept(shape_strides, ctx); });
569-
});
552+
sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
553+
exec_q, {strided_fn_ev}, shape_strides_owner);
570554

571555
host_tasks.push_back(tmp_cleanup_ev);
572556

@@ -796,30 +780,21 @@ std::pair<sycl::event, sycl::event>
796780
}
797781

798782
using dpctl::tensor::offset_utils::device_allocate_and_pack;
799-
const auto &ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
783+
auto ptr_sz_event_triple_ = device_allocate_and_pack<py::ssize_t>(
800784
exec_q, host_tasks, simplified_shape, simplified_rhs_strides,
801785
simplified_lhs_strides);
786+
auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_));
787+
auto copy_shape_ev = std::get<2>(ptr_sz_event_triple_);
802788

803-
py::ssize_t *shape_strides = std::get<0>(ptr_sz_event_triple_);
804-
const sycl::event &copy_shape_ev = std::get<2>(ptr_sz_event_triple_);
805-
806-
if (shape_strides == nullptr) {
807-
throw std::runtime_error("Unable to allocate device memory");
808-
}
789+
const py::ssize_t *shape_strides = shape_strides_owner.get();
809790

810791
sycl::event strided_fn_ev =
811792
strided_fn(exec_q, rhs_nelems, nd, shape_strides, rhs_data, rhs_offset,
812793
lhs_data, lhs_offset, depends, {copy_shape_ev});
813794

814795
// async free of shape_strides temporary
815-
auto ctx = exec_q.get_context();
816-
817-
sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
818-
cgh.depends_on(strided_fn_ev);
819-
using dpctl::tensor::alloc_utils::sycl_free_noexcept;
820-
cgh.host_task(
821-
[ctx, shape_strides]() { sycl_free_noexcept(shape_strides, ctx); });
822-
});
796+
sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
797+
exec_q, {strided_fn_ev}, shape_strides_owner);
823798

824799
host_tasks.push_back(tmp_cleanup_ev);
825800

0 commit comments

Comments
 (0)