Skip to content

Commit a7add8e

Browse files
authored
Redesign dpnp.put_along_axis and dpnp.take_along_axis thorough existing calls (#1636)
* Redesigned `put_along_axis` and `take_along_axis` thorugh existing calls * Simplified check for * Move check of array type in dpnp.prod after the TODO comment
1 parent 2c8cbb5 commit a7add8e

14 files changed

+424
-321
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -231,21 +231,19 @@ enum class DPNPFuncName : size_t
231231
DPNP_FN_PTP, /**< Used in numpy.ptp() impl */
232232
DPNP_FN_PUT, /**< Used in numpy.put() impl */
233233
DPNP_FN_PUT_ALONG_AXIS, /**< Used in numpy.put_along_axis() impl */
234-
DPNP_FN_PUT_ALONG_AXIS_EXT, /**< Used in numpy.put_along_axis() impl,
235-
requires extra parameters */
236-
DPNP_FN_QR, /**< Used in numpy.linalg.qr() impl */
237-
DPNP_FN_QR_EXT, /**< Used in numpy.linalg.qr() impl, requires extra
238-
parameters */
239-
DPNP_FN_RADIANS, /**< Used in numpy.radians() impl */
240-
DPNP_FN_RADIANS_EXT, /**< Used in numpy.radians() impl, requires extra
241-
parameters */
242-
DPNP_FN_REMAINDER, /**< Used in numpy.remainder() impl */
243-
DPNP_FN_RECIP, /**< Used in numpy.recip() impl */
244-
DPNP_FN_RECIP_EXT, /**< Used in numpy.recip() impl, requires extra
245-
parameters */
246-
DPNP_FN_REPEAT, /**< Used in numpy.repeat() impl */
247-
DPNP_FN_RIGHT_SHIFT, /**< Used in numpy.right_shift() impl */
248-
DPNP_FN_RNG_BETA, /**< Used in numpy.random.beta() impl */
234+
DPNP_FN_QR, /**< Used in numpy.linalg.qr() impl */
235+
DPNP_FN_QR_EXT, /**< Used in numpy.linalg.qr() impl, requires extra
236+
parameters */
237+
DPNP_FN_RADIANS, /**< Used in numpy.radians() impl */
238+
DPNP_FN_RADIANS_EXT, /**< Used in numpy.radians() impl, requires extra
239+
parameters */
240+
DPNP_FN_REMAINDER, /**< Used in numpy.remainder() impl */
241+
DPNP_FN_RECIP, /**< Used in numpy.recip() impl */
242+
DPNP_FN_RECIP_EXT, /**< Used in numpy.recip() impl, requires extra
243+
parameters */
244+
DPNP_FN_REPEAT, /**< Used in numpy.repeat() impl */
245+
DPNP_FN_RIGHT_SHIFT, /**< Used in numpy.right_shift() impl */
246+
DPNP_FN_RNG_BETA, /**< Used in numpy.random.beta() impl */
249247
DPNP_FN_RNG_BETA_EXT, /**< Used in numpy.random.beta() impl, requires extra
250248
parameters */
251249
DPNP_FN_RNG_BINOMIAL, /**< Used in numpy.random.binomial() impl */

dpnp/backend/kernels/dpnp_krnl_indexing.cpp

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -796,19 +796,6 @@ void (*dpnp_put_along_axis_default_c)(void *,
796796
size_t) =
797797
dpnp_put_along_axis_c<_DataType>;
798798

799-
template <typename _DataType>
800-
DPCTLSyclEventRef (*dpnp_put_along_axis_ext_c)(DPCTLSyclQueueRef,
801-
void *,
802-
long *,
803-
void *,
804-
size_t,
805-
const shape_elem_type *,
806-
size_t,
807-
size_t,
808-
size_t,
809-
const DPCTLEventVectorRef) =
810-
dpnp_put_along_axis_c<_DataType>;
811-
812799
template <typename _DataType, typename _IndecesType>
813800
class dpnp_take_c_kernel;
814801

@@ -1005,15 +992,6 @@ void func_map_init_indexing_func(func_map_t &fmap)
1005992
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS][eft_DBL][eft_DBL] = {
1006993
eft_DBL, (void *)dpnp_put_along_axis_default_c<double>};
1007994

1008-
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS_EXT][eft_INT][eft_INT] = {
1009-
eft_INT, (void *)dpnp_put_along_axis_ext_c<int32_t>};
1010-
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS_EXT][eft_LNG][eft_LNG] = {
1011-
eft_LNG, (void *)dpnp_put_along_axis_ext_c<int64_t>};
1012-
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS_EXT][eft_FLT][eft_FLT] = {
1013-
eft_FLT, (void *)dpnp_put_along_axis_ext_c<float>};
1014-
fmap[DPNPFuncName::DPNP_FN_PUT_ALONG_AXIS_EXT][eft_DBL][eft_DBL] = {
1015-
eft_DBL, (void *)dpnp_put_along_axis_ext_c<double>};
1016-
1017995
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_BLN][eft_INT] = {
1018996
eft_BLN, (void *)dpnp_take_default_c<bool, int32_t>};
1019997
fmap[DPNPFuncName::DPNP_FN_TAKE][eft_INT][eft_INT] = {

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
156156
DPNP_FN_RNG_POISSON_EXT
157157
DPNP_FN_RNG_POWER
158158
DPNP_FN_RNG_POWER_EXT
159-
DPNP_FN_PUT_ALONG_AXIS
160-
DPNP_FN_PUT_ALONG_AXIS_EXT
161159
DPNP_FN_RNG_RAYLEIGH
162160
DPNP_FN_RNG_RAYLEIGH_EXT
163161
DPNP_FN_RNG_SHUFFLE

dpnp/dpnp_algo/dpnp_algo_indexing.pxi

Lines changed: 0 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,8 @@ __all__ += [
4141
"dpnp_diagonal",
4242
"dpnp_fill_diagonal",
4343
"dpnp_indices",
44-
"dpnp_put_along_axis",
4544
"dpnp_putmask",
4645
"dpnp_select",
47-
"dpnp_take_along_axis",
4846
"dpnp_tril_indices",
4947
"dpnp_tril_indices_from",
5048
"dpnp_triu_indices",
@@ -69,16 +67,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_2in_1out_func_ptr_t_)(c_dpct
6967
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_2in_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef,
7068
void *, void * , shape_elem_type * , const size_t,
7169
const c_dpctl.DPCTLEventVectorRef)
72-
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_3in_with_axis_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef,
73-
void * ,
74-
void * ,
75-
void * ,
76-
const size_t,
77-
shape_elem_type * ,
78-
const size_t,
79-
const size_t,
80-
const size_t,
81-
const c_dpctl.DPCTLEventVectorRef)
8270

8371

8472
cpdef utils.dpnp_descriptor dpnp_choose(utils.dpnp_descriptor x1, list choices1):
@@ -283,35 +271,6 @@ cpdef object dpnp_indices(dimensions):
283271
return dpnp_result
284272

285273

286-
cpdef dpnp_put_along_axis(dpnp_descriptor arr, dpnp_descriptor indices, dpnp_descriptor values, int axis):
287-
cdef shape_type_c arr_shape = arr.shape
288-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)
289-
290-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_PUT_ALONG_AXIS_EXT, param1_type, param1_type)
291-
292-
utils.get_common_usm_allocation(arr, indices) # check USM allocation is common
293-
_, _, result_sycl_queue = utils.get_common_usm_allocation(arr, values)
294-
295-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
296-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
297-
298-
cdef custom_indexing_3in_with_axis_func_ptr_t func = <custom_indexing_3in_with_axis_func_ptr_t > kernel_data.ptr
299-
300-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
301-
arr.get_data(),
302-
indices.get_data(),
303-
values.get_data(),
304-
axis,
305-
arr_shape.data(),
306-
arr.ndim,
307-
indices.size,
308-
values.size,
309-
NULL) # dep_events_ref
310-
311-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
312-
c_dpctl.DPCTLEvent_Delete(event_ref)
313-
314-
315274
cpdef dpnp_putmask(utils.dpnp_descriptor arr, utils.dpnp_descriptor mask, utils.dpnp_descriptor values):
316275
cdef int values_size = values.size
317276

@@ -341,94 +300,6 @@ cpdef utils.dpnp_descriptor dpnp_select(list condlist, list choicelist, default)
341300
return res_array
342301

343302

344-
cpdef object dpnp_take_along_axis(object arr, object indices, int axis):
345-
cdef long size_arr = arr.size
346-
cdef shape_type_c shape_arr = arr.shape
347-
cdef shape_type_c output_shape
348-
cdef long size_indices = indices.size
349-
res_type = arr.dtype
350-
351-
if axis != arr.ndim - 1:
352-
res_shape_list = list(shape_arr)
353-
res_shape_list[axis] = 1
354-
res_shape = tuple(res_shape_list)
355-
356-
output_shape = (0,) * (len(shape_arr) - 1)
357-
ind = 0
358-
for id, shape_axis in enumerate(shape_arr):
359-
if id != axis:
360-
output_shape[ind] = shape_axis
361-
ind += 1
362-
363-
prod = 1
364-
for i in range(len(output_shape)):
365-
if output_shape[i] != 0:
366-
prod *= output_shape[i]
367-
368-
result_array = dpnp.empty((prod, ), dtype=res_type)
369-
ind_array = [None] * prod
370-
arr_shape_offsets = [None] * len(shape_arr)
371-
acc = 1
372-
373-
for i in range(len(shape_arr)):
374-
ind = len(shape_arr) - 1 - i
375-
arr_shape_offsets[ind] = acc
376-
acc *= shape_arr[ind]
377-
378-
output_shape_offsets = [None] * len(shape_arr)
379-
acc = 1
380-
381-
for i in range(len(output_shape)):
382-
ind = len(output_shape) - 1 - i
383-
output_shape_offsets[ind] = acc
384-
acc *= output_shape[ind]
385-
result_offsets = arr_shape_offsets[:] # need copy. not a reference
386-
result_offsets[axis] = 0
387-
388-
for source_idx in range(size_arr):
389-
390-
# reconstruct x,y,z from linear source_idx
391-
xyz = []
392-
remainder = source_idx
393-
for i in arr_shape_offsets:
394-
quotient, remainder = divmod(remainder, i)
395-
xyz.append(quotient)
396-
397-
# extract result axis
398-
result_axis = []
399-
for idx, offset in enumerate(xyz):
400-
if idx != axis:
401-
result_axis.append(offset)
402-
403-
# Construct result offset
404-
result_offset = 0
405-
for i, result_axis_val in enumerate(result_axis):
406-
result_offset += (output_shape_offsets[i] * result_axis_val)
407-
408-
arr_elem = arr.item(source_idx)
409-
if ind_array[result_offset] is None:
410-
ind_array[result_offset] = 0
411-
else:
412-
ind_array[result_offset] += 1
413-
414-
if ind_array[result_offset] % size_indices == indices.item(result_offset % size_indices):
415-
result_array[result_offset] = arr_elem
416-
417-
dpnp_result_array = dpnp.reshape(result_array, res_shape)
418-
return dpnp_result_array
419-
420-
else:
421-
result_array = utils_py.create_output_descriptor_py(shape_arr, res_type, None).get_pyobj()
422-
423-
result_array_flatiter = result_array.flat
424-
425-
for i in range(size_arr):
426-
ind = size_indices * (i // size_indices) + indices.item(i % size_indices)
427-
result_array_flatiter[i] = arr.item(ind)
428-
429-
return result_array
430-
431-
432303
cpdef tuple dpnp_tril_indices(n, k=0, m=None):
433304
array1 = []
434305
array2 = []

dpnp/dpnp_iface.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
"array_equal",
5959
"asnumpy",
6060
"astype",
61+
"check_supported_arrays_type",
6162
"convert_single_elem_array_to_scalar",
6263
"default_float_type",
6364
"dpnp_queue_initialize",
@@ -203,6 +204,42 @@ def astype(x1, dtype, order="K", casting="unsafe", copy=True):
203204
return dpnp_array._create_from_usm_ndarray(array_obj)
204205

205206

207+
def check_supported_arrays_type(*arrays, scalar_type=False):
208+
"""
209+
Return ``True`` if each array has either type of scalar,
210+
:class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
211+
But if any array has unsupported type, ``TypeError`` will be raised.
212+
213+
Parameters
214+
----------
215+
arrays : {dpnp_array, usm_ndarray}
216+
Input arrays to check for supported types.
217+
scalar_type : {bool}, optional
218+
A scalar type is also considered as supported if flag is True.
219+
220+
Returns
221+
-------
222+
out : bool
223+
``True`` if each type of input `arrays` is supported type,
224+
``False`` otherwise.
225+
226+
Raises
227+
------
228+
TypeError
229+
If any input array from `arrays` is of unsupported array type.
230+
231+
"""
232+
233+
for a in arrays:
234+
if scalar_type and dpnp.isscalar(a) or is_supported_array_type(a):
235+
continue
236+
237+
raise TypeError(
238+
"An array must be any of supported type, but got {}".format(type(a))
239+
)
240+
return True
241+
242+
206243
def convert_single_elem_array_to_scalar(obj, keepdims=False):
207244
"""Convert array with single element to scalar."""
208245

0 commit comments

Comments
 (0)