@@ -1026,18 +1026,15 @@ sycl::event dot_product_tree_impl(sycl::queue &exec_q,
         (reduction_groups + preferred_reductions_per_wi * wg - 1) /
         (preferred_reductions_per_wi * wg);

-    resTy *partially_reduced_tmp = sycl::malloc_device<resTy>(
-        batches * (reduction_groups + second_iter_reduction_groups_),
-        exec_q);
-    resTy *partially_reduced_tmp2 = nullptr;
+    // returns unique_ptr
+    auto partially_reduced_tmp_owner =
+        dpctl::tensor::alloc_utils::smart_malloc_device<resTy>(
+            batches * (reduction_groups + second_iter_reduction_groups_),
+            exec_q);

-    if (partially_reduced_tmp == nullptr) {
-        throw std::runtime_error("Unable to allocate device_memory");
-    }
-    else {
-        partially_reduced_tmp2 =
-            partially_reduced_tmp + reduction_groups * batches;
-    }
+    resTy *partially_reduced_tmp = partially_reduced_tmp_owner.get();
+    resTy *partially_reduced_tmp2 =
+        partially_reduced_tmp + reduction_groups * batches;

     sycl::event first_reduction_ev;
     {
@@ -1152,16 +1149,10 @@ sycl::event dot_product_tree_impl(sycl::queue &exec_q,
             remaining_reduction_nelems, reductions_per_wi, reduction_groups,
             in_out_iter_indexer, reduction_indexer, {dependent_ev});

+    // transfer ownership of USM allocation to host_task
     sycl::event cleanup_host_task_event =
-        exec_q.submit([&](sycl::handler &cgh) {
-            cgh.depends_on(final_reduction_ev);
-            const sycl::context &ctx = exec_q.get_context();
-
-            using dpctl::tensor::alloc_utils::sycl_free_noexcept;
-            cgh.host_task([ctx, partially_reduced_tmp] {
-                sycl_free_noexcept(partially_reduced_tmp, ctx);
-            });
-        });
+        dpctl::tensor::alloc_utils::async_smart_free(
+            exec_q, {final_reduction_ev}, partially_reduced_tmp_owner);

     return cleanup_host_task_event;
 }
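
The two hunks above replace a raw sycl::malloc_device call plus a hand-written cleanup host_task with RAII helpers from dpctl::tensor::alloc_utils. A minimal sketch of what such helpers can look like, assuming only standard SYCL 2020 USM APIs (the signatures below are illustrative stand-ins, not dpctl's actual sycl_alloc_utils.hpp implementation):

// Illustrative sketch only -- names and signatures are assumptions,
// not dpctl's actual alloc_utils API.
#include <cstddef>
#include <memory>
#include <new>
#include <vector>
#include <sycl/sycl.hpp>

// Deleter that returns a USM allocation to its SYCL context.
struct usm_deleter
{
    sycl::context ctx;
    void operator()(void *ptr) const noexcept
    {
        if (ptr) {
            sycl::free(ptr, ctx);
        }
    }
};

// Allocate device USM and wrap it in a unique_ptr, so any early return
// or exception between allocation and cleanup frees it automatically.
template <typename T>
std::unique_ptr<T, usm_deleter> smart_malloc_device(std::size_t count,
                                                    sycl::queue &q)
{
    T *ptr = sycl::malloc_device<T>(count, q);
    if (ptr == nullptr) {
        throw std::bad_alloc();
    }
    return std::unique_ptr<T, usm_deleter>(ptr, usm_deleter{q.get_context()});
}

// Submit a host_task that frees the allocation once `deps` complete,
// then release() the unique_ptr so it no longer owns the memory.
template <typename T>
sycl::event async_smart_free(sycl::queue &q,
                             const std::vector<sycl::event> &deps,
                             std::unique_ptr<T, usm_deleter> &owner)
{
    T *ptr = owner.get();
    const sycl::context ctx = q.get_context();
    sycl::event ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(deps);
        cgh.host_task([ptr, ctx]() { sycl::free(ptr, ctx); });
    });
    owner.release(); // ownership now lives in the host_task
    return ev;
}

Under this reading, the null check and throw that the diff deletes move into the allocator, and the unique_ptr guarantees the temporary is freed even if a later submit throws, which the raw-pointer version did not.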
@@ -1282,18 +1273,15 @@ dot_product_contig_tree_impl(sycl::queue &exec_q,
         (reduction_groups + preferred_reductions_per_wi * wg - 1) /
         (preferred_reductions_per_wi * wg);

-    resTy *partially_reduced_tmp = sycl::malloc_device<resTy>(
-        batches * (reduction_groups + second_iter_reduction_groups_),
-        exec_q);
-    resTy *partially_reduced_tmp2 = nullptr;
-
-    if (partially_reduced_tmp == nullptr) {
-        throw std::runtime_error("Unable to allocate device_memory");
-    }
-    else {
-        partially_reduced_tmp2 =
-            partially_reduced_tmp + reduction_groups * batches;
-    }
+    // unique_ptr that owns temporary allocation for partial reductions
+    auto partially_reduced_tmp_owner =
+        dpctl::tensor::alloc_utils::smart_malloc_device<resTy>(
+            batches * (reduction_groups + second_iter_reduction_groups_),
+            exec_q);
+    // get raw pointers
+    resTy *partially_reduced_tmp = partially_reduced_tmp_owner.get();
+    resTy *partially_reduced_tmp2 =
+        partially_reduced_tmp + reduction_groups * batches;

     sycl::event first_reduction_ev;
     {
@@ -1401,15 +1389,8 @@ dot_product_contig_tree_impl(sycl::queue &exec_q,
             in_out_iter_indexer, reduction_indexer, {dependent_ev});

     sycl::event cleanup_host_task_event =
-        exec_q.submit([&](sycl::handler &cgh) {
-            cgh.depends_on(final_reduction_ev);
-            const sycl::context &ctx = exec_q.get_context();
-
-            using dpctl::tensor::alloc_utils::sycl_free_noexcept;
-            cgh.host_task([ctx, partially_reduced_tmp] {
-                sycl_free_noexcept(partially_reduced_tmp, ctx);
-            });
-        });
+        dpctl::tensor::alloc_utils::async_smart_free(
+            exec_q, {final_reduction_ev}, partially_reduced_tmp_owner);

     return cleanup_host_task_event;
 }
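
The contig variant receives the identical transformation. As a usage illustration of the sketched helpers above (again an assumption-level sketch, not dpctl code), the allocation stays alive exactly until the dependent work finishes:

// Hypothetical usage of the sketched helpers above.
int main()
{
    sycl::queue q;
    auto tmp_owner = smart_malloc_device<float>(1024, q);
    float *tmp = tmp_owner.get();

    // Work that consumes the temporary allocation.
    sycl::event fill_ev = q.fill(tmp, 0.0f, 1024);

    // The host_task frees tmp only after fill_ev completes; tmp_owner
    // has released ownership, so scope exit does not double-free.
    sycl::event cleanup_ev = async_smart_free(q, {fill_ev}, tmp_owner);
    cleanup_ev.wait();
    return 0;
}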