Skip to content

Commit cfc9b40

Browse files
Merge pull request #1167 from IntelPython/fix-take
Fix take
2 parents 51369c3 + d839ea1 commit cfc9b40

9 files changed

+121
-38
lines changed

dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,9 @@ DPCTLSyclEventRef dpnp_ptp_c(DPCTLSyclQueueRef q_ref,
493493
(void)dep_event_vec_ref;
494494

495495
DPCTLSyclEventRef event_ref = nullptr;
496+
DPCTLSyclEventRef e1_ref = nullptr;
497+
DPCTLSyclEventRef e2_ref = nullptr;
498+
DPCTLSyclEventRef e3_ref = nullptr;
496499

497500
if ((input1_in == nullptr) || (result1_out == nullptr))
498501
{
@@ -514,29 +517,36 @@ DPCTLSyclEventRef dpnp_ptp_c(DPCTLSyclQueueRef q_ref,
514517
_DataType* min_arr = reinterpret_cast<_DataType*>(sycl::malloc_shared(result_size * sizeof(_DataType), q));
515518
_DataType* max_arr = reinterpret_cast<_DataType*>(sycl::malloc_shared(result_size * sizeof(_DataType), q));
516519

517-
dpnp_min_c<_DataType>(arr, min_arr, result_size, input_shape, input_ndim, axis, naxis);
518-
dpnp_max_c<_DataType>(arr, max_arr, result_size, input_shape, input_ndim, axis, naxis);
520+
e1_ref = dpnp_min_c<_DataType>(q_ref, arr, min_arr, result_size, input_shape, input_ndim, axis, naxis, NULL);
521+
e2_ref = dpnp_max_c<_DataType>(q_ref, arr, max_arr, result_size, input_shape, input_ndim, axis, naxis, NULL);
519522

520523
shape_elem_type* _strides =
521524
reinterpret_cast<shape_elem_type*>(sycl::malloc_shared(result_ndim * sizeof(shape_elem_type), q));
522525
get_shape_offsets_inkernel(result_shape, result_ndim, _strides);
523526

524-
dpnp_subtract_c<_DataType, _DataType, _DataType>(result,
525-
result_size,
526-
result_ndim,
527-
result_shape,
528-
result_strides,
529-
max_arr,
530-
result_size,
531-
result_ndim,
532-
result_shape,
533-
_strides,
534-
min_arr,
535-
result_size,
536-
result_ndim,
537-
result_shape,
538-
_strides,
539-
NULL);
527+
e3_ref = dpnp_subtract_c<_DataType, _DataType, _DataType>(q_ref, result,
528+
result_size,
529+
result_ndim,
530+
result_shape,
531+
result_strides,
532+
max_arr,
533+
result_size,
534+
result_ndim,
535+
result_shape,
536+
_strides,
537+
min_arr,
538+
result_size,
539+
result_ndim,
540+
result_shape,
541+
_strides,
542+
NULL, NULL);
543+
544+
DPCTLEvent_Wait(e1_ref);
545+
DPCTLEvent_Wait(e2_ref);
546+
DPCTLEvent_Wait(e3_ref);
547+
DPCTLEvent_Delete(e1_ref);
548+
DPCTLEvent_Delete(e2_ref);
549+
DPCTLEvent_Delete(e3_ref);
540550

541551
sycl::free(min_arr, q);
542552
sycl::free(max_arr, q);
@@ -576,6 +586,7 @@ void dpnp_ptp_c(void* result1_out,
576586
naxis,
577587
dep_event_vec_ref);
578588
DPCTLEvent_WaitAndThrow(event_ref);
589+
DPCTLEvent_Delete(event_ref);
579590
}
580591

581592
template <typename _DataType>

dpnp/backend/kernels/dpnp_krnl_bitwise.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,16 +148,16 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
148148
\
149149
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref)); \
150150
\
151-
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, input1_in, input1_size); \
152-
DPNPC_ptr_adapter<shape_elem_type> input1_shape_ptr(q_ref, input1_shape, input1_ndim, true); \
153-
DPNPC_ptr_adapter<shape_elem_type> input1_strides_ptr(q_ref, input1_strides, input1_ndim, true); \
151+
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, input1_in, input1_size); \
152+
DPNPC_ptr_adapter<shape_elem_type> input1_shape_ptr(q_ref, input1_shape, input1_ndim, true); \
153+
DPNPC_ptr_adapter<shape_elem_type> input1_strides_ptr(q_ref, input1_strides, input1_ndim, true); \
154154
\
155-
DPNPC_ptr_adapter<_DataType> input2_ptr(q_ref, input2_in, input2_size); \
156-
DPNPC_ptr_adapter<shape_elem_type> input2_shape_ptr(q_ref, input2_shape, input2_ndim, true); \
157-
DPNPC_ptr_adapter<shape_elem_type> input2_strides_ptr(q_ref, input2_strides, input2_ndim, true); \
155+
DPNPC_ptr_adapter<_DataType> input2_ptr(q_ref, input2_in, input2_size); \
156+
DPNPC_ptr_adapter<shape_elem_type> input2_shape_ptr(q_ref, input2_shape, input2_ndim, true); \
157+
DPNPC_ptr_adapter<shape_elem_type> input2_strides_ptr(q_ref, input2_strides, input2_ndim, true); \
158158
\
159-
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref, result_out, result_size, false, true); \
160-
DPNPC_ptr_adapter<shape_elem_type> result_strides_ptr(q_ref, result_strides, result_ndim); \
159+
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref, result_out, result_size, false, true); \
160+
DPNPC_ptr_adapter<shape_elem_type> result_strides_ptr(q_ref, result_strides, result_ndim); \
161161
\
162162
_DataType* input1_data = input1_ptr.get_ptr(); \
163163
shape_elem_type* input1_shape_data = input1_shape_ptr.get_ptr(); \
@@ -226,6 +226,14 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
226226
}; \
227227
event = q.submit(kernel_func); \
228228
} \
229+
input1_ptr.depends_on(event); \
230+
input1_shape_ptr.depends_on(event); \
231+
input1_strides_ptr.depends_on(event); \
232+
input2_ptr.depends_on(event); \
233+
input2_shape_ptr.depends_on(event); \
234+
input2_strides_ptr.depends_on(event); \
235+
result_ptr.depends_on(event); \
236+
result_strides_ptr.depends_on(event); \
229237
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event); \
230238
\
231239
return DPCTLEvent_Copy(event_ref); \

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,12 @@
143143
} \
144144
} \
145145
\
146+
input1_ptr.depends_on(event); \
147+
input1_shape_ptr.depends_on(event); \
148+
input1_strides_ptr.depends_on(event); \
149+
result_ptr.depends_on(event); \
150+
result_strides_ptr.depends_on(event); \
151+
\
146152
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event); \
147153
\
148154
return DPCTLEvent_Copy(event_ref); \
@@ -644,6 +650,12 @@ static void func_map_init_elemwise_1arg_2type(func_map_t& fmap)
644650
} \
645651
} \
646652
\
653+
input1_ptr.depends_on(event); \
654+
input1_shape_ptr.depends_on(event); \
655+
input1_strides_ptr.depends_on(event); \
656+
result_ptr.depends_on(event); \
657+
result_strides_ptr.depends_on(event); \
658+
\
647659
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event); \
648660
\
649661
return DPCTLEvent_Copy(event_ref); \
@@ -998,6 +1010,17 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
9981010
event = q.submit(kernel_func); \
9991011
} \
10001012
} \
1013+
\
1014+
input1_ptr.depends_on(event); \
1015+
input1_shape_ptr.depends_on(event); \
1016+
input1_strides_ptr.depends_on(event); \
1017+
input2_ptr.depends_on(event); \
1018+
input2_shape_ptr.depends_on(event); \
1019+
input2_strides_ptr.depends_on(event); \
1020+
result_ptr.depends_on(event); \
1021+
result_shape_ptr.depends_on(event); \
1022+
result_strides_ptr.depends_on(event); \
1023+
\
10011024
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event); \
10021025
\
10031026
return DPCTLEvent_Copy(event_ref); \

dpnp/backend/kernels/dpnp_krnl_indexing.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -901,7 +901,7 @@ DPCTLSyclEventRef dpnp_take_c(DPCTLSyclQueueRef q_ref,
901901
DPCTLSyclEventRef event_ref = nullptr;
902902
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
903903

904-
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, array1_size, true);
904+
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, array1_size);
905905
DPNPC_ptr_adapter<_IndecesType> input2_ptr(q_ref, indices1, size);
906906
_DataType* array_1 = input1_ptr.get_ptr();
907907
_IndecesType* indices = input2_ptr.get_ptr();

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ DPCTLSyclEventRef dpnp_elemwise_absolute_c(DPCTLSyclQueueRef q_ref,
170170
event = q.submit(kernel_func);
171171
}
172172

173+
input1_ptr.depends_on(event);
174+
result1_ptr.depends_on(event);
173175
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
174176

175177
return DPCTLEvent_Copy(event_ref);
@@ -483,6 +485,8 @@ DPCTLSyclEventRef dpnp_ediff1d_c(DPCTLSyclQueueRef q_ref,
483485
};
484486
event = q.submit(kernel_func);
485487

488+
input1_ptr.depends_on(event);
489+
result_ptr.depends_on(event);
486490
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
487491

488492
return DPCTLEvent_Copy(event_ref);
@@ -676,6 +680,7 @@ void dpnp_floor_divide_c(void* result_out,
676680
where,
677681
dep_event_vec_ref);
678682
DPCTLEvent_WaitAndThrow(event_ref);
683+
DPCTLEvent_Delete(event_ref);
679684
}
680685

681686
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
@@ -770,6 +775,7 @@ void dpnp_modf_c(void* array1_in, void* result1_out, void* result2_out, size_t s
770775
size,
771776
dep_event_vec_ref);
772777
DPCTLEvent_WaitAndThrow(event_ref);
778+
DPCTLEvent_Delete(event_ref);
773779
}
774780

775781
template <typename _DataType_input, typename _DataType_output>
@@ -911,6 +917,7 @@ void dpnp_remainder_c(void* result_out,
911917
where,
912918
dep_event_vec_ref);
913919
DPCTLEvent_WaitAndThrow(event_ref);
920+
DPCTLEvent_Delete(event_ref);
914921
}
915922

916923
template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
@@ -1041,6 +1048,7 @@ void dpnp_trapz_c(
10411048
array2_size,
10421049
dep_event_vec_ref);
10431050
DPCTLEvent_WaitAndThrow(event_ref);
1051+
DPCTLEvent_Delete(event_ref);
10441052
}
10451053

10461054
template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output>

dpnp/backend/kernels/dpnp_krnl_reduction.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ void dpnp_sum_c(void* result_out,
162162
where,
163163
dep_event_vec_ref);
164164
DPCTLEvent_WaitAndThrow(event_ref);
165+
DPCTLEvent_Delete(event_ref);
165166
}
166167

167168
template <typename _DataType_output, typename _DataType_input>
@@ -278,6 +279,7 @@ void dpnp_prod_c(void* result_out,
278279
where,
279280
dep_event_vec_ref);
280281
DPCTLEvent_WaitAndThrow(event_ref);
282+
DPCTLEvent_Delete(event_ref);
281283
}
282284

283285
template <typename _DataType_output, typename _DataType_input>

dpnp/backend/kernels/dpnp_krnl_sorting.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ void dpnp_argsort_c(void* array1_in, void* result1, size_t size)
9191
size,
9292
dep_event_vec_ref);
9393
DPCTLEvent_WaitAndThrow(event_ref);
94+
DPCTLEvent_Delete(event_ref);
9495
}
9596

9697
template <typename _DataType, typename _idx_DataType>
@@ -242,6 +243,7 @@ void dpnp_partition_c(
242243
ndim,
243244
dep_event_vec_ref);
244245
DPCTLEvent_WaitAndThrow(event_ref);
246+
DPCTLEvent_Delete(event_ref);
245247
}
246248

247249
template <typename _DataType>
@@ -394,6 +396,7 @@ void dpnp_searchsorted_c(
394396
v_size,
395397
dep_event_vec_ref);
396398
DPCTLEvent_WaitAndThrow(event_ref);
399+
DPCTLEvent_Delete(event_ref);
397400
}
398401

399402
template <typename _DataType, typename _IndexingType>
@@ -459,6 +462,7 @@ void dpnp_sort_c(void* array1_in, void* result1, size_t size)
459462
size,
460463
dep_event_vec_ref);
461464
DPCTLEvent_WaitAndThrow(event_ref);
465+
DPCTLEvent_Delete(event_ref);
462466
}
463467

464468
template <typename _DataType>

0 commit comments

Comments
 (0)