Skip to content

Commit 2987585

Browse files
authored
implement dpnp.apply_over_axes (#2174)
* implement dpnp.apply_over_axes * fix issue with a test
1 parent 264c6d8 commit 2987585

File tree

4 files changed

+125
-3
lines changed

4 files changed

+125
-3
lines changed

dpnp/dpnp_iface_functional.py

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,14 @@
3838

3939

4040
import numpy
41-
from dpctl.tensor._numpy_helper import normalize_axis_index
41+
from dpctl.tensor._numpy_helper import (
42+
normalize_axis_index,
43+
normalize_axis_tuple,
44+
)
4245

4346
import dpnp
4447

45-
__all__ = ["apply_along_axis"]
48+
__all__ = ["apply_along_axis", "apply_over_axes"]
4649

4750

4851
def apply_along_axis(func1d, axis, arr, *args, **kwargs):
@@ -185,3 +188,83 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
185188
buff = dpnp.moveaxis(buff, -1, axis)
186189

187190
return buff
191+
192+
193+
def apply_over_axes(func, a, axes):
194+
"""
195+
Apply a function repeatedly over multiple axes.
196+
197+
`func` is called as ``res = func(a, axis)``, where `axis` is the first
198+
element of `axes`. The result `res` of the function call must have
199+
either the same dimensions as `a` or one less dimension. If `res`
200+
has one less dimension than `a`, a dimension is inserted before
201+
`axis`. The call to `func` is then repeated for each axis in `axes`,
202+
with `res` as the first argument.
203+
204+
For full documentation refer to :obj:`numpy.apply_over_axes`.
205+
206+
Parameters
207+
----------
208+
func : function
209+
This function must take two arguments, ``func(a, axis)``.
210+
a : {dpnp.ndarray, usm_ndarray}
211+
Input array.
212+
axes : {int, sequence of ints}
213+
Axes over which `func` is applied.
214+
215+
Returns
216+
-------
217+
out : dpnp.ndarray
218+
The output array. The number of dimensions is the same as `a`,
219+
but the shape can be different. This depends on whether `func`
220+
changes the shape of its output with respect to its input.
221+
222+
See Also
223+
--------
224+
:obj:`dpnp.apply_along_axis` : Apply a function to 1-D slices of an array
225+
along the given axis.
226+
227+
Examples
228+
--------
229+
>>> import dpnp as np
230+
>>> a = np.arange(24).reshape(2, 3, 4)
231+
>>> a
232+
array([[[ 0, 1, 2, 3],
233+
[ 4, 5, 6, 7],
234+
[ 8, 9, 10, 11]],
235+
[[12, 13, 14, 15],
236+
[16, 17, 18, 19],
237+
[20, 21, 22, 23]]])
238+
239+
Sum over axes 0 and 2. The result has same number of dimensions
240+
as the original array:
241+
242+
>>> np.apply_over_axes(np.sum, a, [0, 2])
243+
array([[[ 60],
244+
[ 92],
245+
[124]]])
246+
247+
Tuple axis arguments to ufuncs are equivalent:
248+
249+
>>> np.sum(a, axis=(0, 2), keepdims=True)
250+
array([[[ 60],
251+
[ 92],
252+
[124]]])
253+
254+
"""
255+
256+
dpnp.check_supported_arrays_type(a)
257+
if isinstance(axes, int):
258+
axes = (axes,)
259+
axes = normalize_axis_tuple(axes, a.ndim)
260+
261+
for axis in axes:
262+
res = func(a, axis)
263+
if res.ndim != a.ndim:
264+
res = dpnp.expand_dims(res, axis)
265+
if res.ndim != a.ndim:
266+
raise ValueError(
267+
"function is not returning an array of the correct shape"
268+
)
269+
a = res
270+
return res

dpnp/tests/test_functional.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy
22
import pytest
3-
from numpy.testing import assert_array_equal
3+
from numpy.testing import assert_array_equal, assert_raises
44

55
import dpnp
66

@@ -46,3 +46,22 @@ def test_args(self, dtype):
4646
# positional args: axis, dtype, out, keepdims
4747
result = dpnp.apply_along_axis(dpnp.mean, 0, ia, 0, dtype, None, True)
4848
assert_array_equal(result, expected)
49+
50+
51+
class TestApplyOverAxes:
52+
@pytest.mark.parametrize("func", ["sum", "cumsum"])
53+
@pytest.mark.parametrize("axes", [1, [0, 2], (-1, -2)])
54+
def test_basic(self, func, axes):
55+
a = numpy.arange(24).reshape(2, 3, 4)
56+
ia = dpnp.array(a)
57+
58+
expected = numpy.apply_over_axes(getattr(numpy, func), a, axes)
59+
result = dpnp.apply_over_axes(getattr(dpnp, func), ia, axes)
60+
assert_array_equal(result, expected)
61+
62+
def test_custom_func(self):
63+
def custom_func(x, axis):
64+
return dpnp.expand_dims(dpnp.expand_dims(x, axis), axis)
65+
66+
ia = dpnp.arange(24).reshape(2, 3, 4)
67+
assert_raises(ValueError, dpnp.apply_over_axes, custom_func, ia, 1)

dpnp/tests/test_sycl_queue.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2215,6 +2215,18 @@ def test_apply_along_axis(device):
22152215
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
22162216

22172217

2218+
@pytest.mark.parametrize(
2219+
"device",
2220+
valid_devices,
2221+
ids=[device.filter_string for device in valid_devices],
2222+
)
2223+
def test_apply_over_axes(device):
2224+
x = dpnp.arange(18, device=device).reshape(2, 3, 3)
2225+
result = dpnp.apply_over_axes(dpnp.sum, x, [0, 1])
2226+
2227+
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
2228+
2229+
22182230
@pytest.mark.parametrize(
22192231
"device_x",
22202232
valid_devices,

dpnp/tests/test_usm_type.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,14 @@ def test_apply_along_axis(usm_type):
784784
assert x.usm_type == y.usm_type
785785

786786

787+
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
788+
def test_apply_over_axes(usm_type):
789+
x = dp.arange(18, usm_type=usm_type).reshape(2, 3, 3)
790+
y = dp.apply_over_axes(dp.sum, x, [0, 1])
791+
792+
assert x.usm_type == y.usm_type
793+
794+
787795
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
788796
def test_broadcast_to(usm_type):
789797
x = dp.ones(7, usm_type=usm_type)

0 commit comments

Comments
 (0)