Skip to content

Commit 9ff406b

Browse files
committed
address comments
1 parent 03276b7 commit 9ff406b

File tree

5 files changed

+226
-29
lines changed

5 files changed

+226
-29
lines changed

dpnp/dpnp_array.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -942,7 +942,11 @@ def max(
942942
initial=None,
943943
where=True,
944944
):
945-
"""Return the maximum along an axis."""
945+
"""
946+
Return the maximum along an axis.
947+
948+
Refer to :obj:`dpnp.max` for full documentation.
949+
"""
946950

947951
return dpnp.max(self, axis, out, keepdims, initial, where)
948952

@@ -959,7 +963,11 @@ def min(
959963
initial=None,
960964
where=True,
961965
):
962-
"""Return the minimum along a given axis."""
966+
"""
967+
Return the minimum along a given axis.
968+
969+
Refer to :obj:`dpnp.min` for full documentation.
970+
"""
963971

964972
return dpnp.min(self, axis, out, keepdims, initial, where)
965973

dpnp/dpnp_iface_statistics.py

Lines changed: 96 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,8 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True):
354354
"""
355355
Return the maximum of an array or maximum along an axis.
356356
357+
For full documentation refer to :obj:`numpy.max`.
358+
357359
Returns
358360
-------
359361
out : dpnp.ndarray
@@ -396,19 +398,56 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True):
396398
397399
"""
398400

399-
if out is not None:
400-
pass
401-
elif initial is not None:
402-
pass
401+
if initial is not None:
402+
raise NotImplementedError(
403+
"initial keyword arguemnts is only supported by its default value."
404+
)
403405
elif where is not True:
404-
pass
406+
raise NotImplementedError(
407+
"where keyword arguemnts is only supported by its default values."
408+
)
405409
else:
406410
dpt_array = dpnp.get_usm_ndarray(a)
407-
return dpnp_array._create_from_usm_ndarray(
408-
dpt.max(dpt_array, axis=axis, keepdims=keepdims)
409-
)
410-
411-
return call_origin(numpy.max, a, axis, out, keepdims, initial, where)
411+
if dpt_array.size == 0:
412+
# TODO: get rid of this if condition when dpctl supports it
413+
for i in range(a.ndim):
414+
if a.shape[i] == 0:
415+
if i not in axis:
416+
indices = [i for i in range(a.ndim) if i not in axis]
417+
res_shape = tuple([a.shape[i] for i in indices])
418+
result = dpnp.empty(res_shape, dtype=a.dtype)
419+
else:
420+
raise ValueError(
421+
"reduction does not support zero-size arrays"
422+
)
423+
else:
424+
result = dpnp_array._create_from_usm_ndarray(
425+
dpt.max(dpt_array, axis=axis, keepdims=keepdims)
426+
)
427+
if out is None:
428+
return result
429+
else:
430+
if out.shape != result.shape:
431+
raise ValueError(
432+
f"Output array of shape {result.shape} is needed, got {out.shape}."
433+
)
434+
elif out.dtype != result.dtype:
435+
raise ValueError(
436+
f"Output array of type {result.dtype} is needed, got {out.dtype}."
437+
)
438+
elif not isinstance(out, dpnp_array):
439+
if isinstance(out, dpt.usm_ndarray):
440+
out = dpnp.array(out)
441+
else:
442+
raise ValueError(
443+
"An array must be any of supported type, but got {}".format(
444+
type(out)
445+
)
446+
)
447+
448+
dpnp.copyto(out, result)
449+
450+
return out
412451

413452

414453
def mean(x, /, *, axis=None, dtype=None, keepdims=False, out=None, where=True):
@@ -606,19 +645,56 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True):
606645
607646
"""
608647

609-
if out is not None:
610-
pass
611-
elif initial is not None:
612-
pass
648+
if initial is not None:
649+
raise NotImplementedError(
650+
"initial keyword arguemnts is only supported by its default value."
651+
)
613652
elif where is not True:
614-
pass
653+
raise NotImplementedError(
654+
"where keyword arguemnts is only supported by its default values."
655+
)
615656
else:
616657
dpt_array = dpnp.get_usm_ndarray(a)
617-
return dpnp_array._create_from_usm_ndarray(
618-
dpt.min(dpt_array, axis=axis, keepdims=keepdims)
619-
)
620-
621-
return call_origin(numpy.min, a, axis, out, keepdims, initial, where)
658+
if dpt_array.size == 0:
659+
# TODO: get rid of this if condition when dpctl supports it
660+
for i in range(a.ndim):
661+
if a.shape[i] == 0:
662+
if i not in axis:
663+
indices = [i for i in range(a.ndim) if i not in axis]
664+
res_shape = tuple([a.shape[i] for i in indices])
665+
result = dpnp.empty(res_shape, dtype=a.dtype)
666+
else:
667+
raise ValueError(
668+
"reduction does not support zero-size arrays"
669+
)
670+
else:
671+
result = dpnp_array._create_from_usm_ndarray(
672+
dpt.min(dpt_array, axis=axis, keepdims=keepdims)
673+
)
674+
if out is None:
675+
return result
676+
else:
677+
if out.shape != result.shape:
678+
raise ValueError(
679+
f"Output array of shape {result.shape} is needed, got {out.shape}."
680+
)
681+
elif out.dtype != result.dtype:
682+
raise ValueError(
683+
f"Output array of type {result.dtype} is needed, got {out.dtype}."
684+
)
685+
elif not isinstance(out, dpnp_array):
686+
if isinstance(out, dpt.usm_ndarray):
687+
out = dpnp.array(out)
688+
else:
689+
raise ValueError(
690+
"An array must be any of supported type, but got {}".format(
691+
type(out)
692+
)
693+
)
694+
695+
dpnp.copyto(out, result)
696+
697+
return out
622698

623699

624700
def nanvar(x1, axis=None, dtype=None, out=None, ddof=0, keepdims=False):

tests/test_amin_amax.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .helper import get_all_dtypes
88

99

10-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
10+
@pytest.mark.parametrize("dtype", get_all_dtypes())
1111
def test_amax(dtype):
1212
a = numpy.array(
1313
[
@@ -25,7 +25,7 @@ def test_amax(dtype):
2525
assert_allclose(expected, result)
2626

2727

28-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
28+
@pytest.mark.parametrize("dtype", get_all_dtypes())
2929
def test_amin(dtype):
3030
a = numpy.array(
3131
[
@@ -55,8 +55,7 @@ def _get_min_max_input(type, shape):
5555
return a.reshape(shape)
5656

5757

58-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
59-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
58+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
6059
@pytest.mark.parametrize(
6160
"shape", [(4,), (2, 3), (4, 5, 6)], ids=["(4,)", "(2,3)", "(4,5,6)"]
6261
)
@@ -74,8 +73,7 @@ def test_amax_diff_shape(dtype, shape):
7473
numpy.testing.assert_array_equal(dpnp_res, np_res)
7574

7675

77-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
78-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True, no_complex=True))
76+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
7977
@pytest.mark.parametrize(
8078
"shape", [(4,), (2, 3), (4, 5, 6)], ids=["(4,)", "(2,3)", "(4,5,6)"]
8179
)

tests/test_statistics.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_median(dtype, size):
2323

2424
@pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2, (1, 2), (0, -2)])
2525
@pytest.mark.parametrize("keepdims", [False, True])
26-
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True, no_bool=True))
26+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
2727
def test_max_min(axis, keepdims, dtype):
2828
a = numpy.arange(768, dtype=dtype).reshape((4, 4, 6, 8))
2929
ia = dpnp.array(a)
@@ -41,6 +41,62 @@ def test_max_min(axis, keepdims, dtype):
4141
assert_allclose(dpnp_res, np_res)
4242

4343

44+
@pytest.mark.parametrize("axis", [None, 0, 1, -1])
45+
@pytest.mark.parametrize("keepdims", [False, True])
46+
def test_max_min_bool(axis, keepdims):
47+
a = numpy.arange(2, dtype=dpnp.bool)
48+
a = numpy.tile(a, (2, 2))
49+
ia = dpnp.array(a)
50+
51+
np_res = numpy.max(a, axis=axis, keepdims=keepdims)
52+
dpnp_res = dpnp.max(ia, axis=axis, keepdims=keepdims)
53+
54+
assert dpnp_res.shape == np_res.shape
55+
assert_allclose(dpnp_res, np_res)
56+
57+
np_res = numpy.min(a, axis=axis, keepdims=keepdims)
58+
dpnp_res = dpnp.min(ia, axis=axis, keepdims=keepdims)
59+
60+
assert dpnp_res.shape == np_res.shape
61+
assert_allclose(dpnp_res, np_res)
62+
63+
64+
@pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2, (1, 2), (0, -2)])
65+
@pytest.mark.parametrize("keepdims", [False, True])
66+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
67+
def test_max_min_out(axis, keepdims, dtype):
68+
a = numpy.arange(768, dtype=dtype).reshape((4, 4, 6, 8))
69+
ia = dpnp.array(a)
70+
71+
np_res = numpy.max(a, axis=axis, keepdims=keepdims)
72+
dpnp_res = dpnp.array(numpy.empty_like(np_res))
73+
dpnp.max(ia, axis=axis, keepdims=keepdims, out=dpnp_res)
74+
75+
assert dpnp_res.shape == np_res.shape
76+
assert_allclose(dpnp_res, np_res)
77+
78+
np_res = numpy.min(a, axis=axis, keepdims=keepdims)
79+
dpnp_res = dpnp.array(numpy.empty_like(np_res))
80+
dpnp.min(ia, axis=axis, keepdims=keepdims, out=dpnp_res)
81+
82+
assert dpnp_res.shape == np_res.shape
83+
assert_allclose(dpnp_res, np_res)
84+
85+
86+
def test_max_min_NotImplemented():
87+
ia = dpnp.arange(5)
88+
89+
with pytest.raises(NotImplementedError):
90+
dpnp.max(ia, where=False)
91+
with pytest.raises(NotImplementedError):
92+
dpnp.max(ia, initial=6)
93+
94+
with pytest.raises(NotImplementedError):
95+
dpnp.min(ia, where=False)
96+
with pytest.raises(NotImplementedError):
97+
dpnp.max(ia, initial=6)
98+
99+
44100
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
45101
@pytest.mark.parametrize(
46102
"array",

tests/third_party/cupy/core_tests/test_ndarray_reduction.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,65 @@ def test_ptp_nan_imag(self, xp, dtype):
215215
return a.ptp()
216216

217217

218+
@testing.parameterize(
219+
*testing.product(
220+
{
221+
# TODO(leofang): make a @testing.for_all_axes decorator
222+
"shape_and_axis": [
223+
((), None),
224+
((0,), (0,)),
225+
((0, 2), (0,)),
226+
((0, 2), (1,)),
227+
((0, 2), (0, 1)),
228+
((2, 0), (0,)),
229+
((2, 0), (1,)),
230+
((2, 0), (0, 1)),
231+
((0, 2, 3), (0,)),
232+
((0, 2, 3), (1,)),
233+
((0, 2, 3), (2,)),
234+
((0, 2, 3), (0, 1)),
235+
((0, 2, 3), (1, 2)),
236+
((0, 2, 3), (0, 2)),
237+
((0, 2, 3), (0, 1, 2)),
238+
((2, 0, 3), (0,)),
239+
((2, 0, 3), (1,)),
240+
((2, 0, 3), (2,)),
241+
((2, 0, 3), (0, 1)),
242+
((2, 0, 3), (1, 2)),
243+
((2, 0, 3), (0, 2)),
244+
((2, 0, 3), (0, 1, 2)),
245+
((2, 3, 0), (0,)),
246+
((2, 3, 0), (1,)),
247+
((2, 3, 0), (2,)),
248+
((2, 3, 0), (0, 1)),
249+
((2, 3, 0), (1, 2)),
250+
((2, 3, 0), (0, 2)),
251+
((2, 3, 0), (0, 1, 2)),
252+
],
253+
"order": ("C", "F"),
254+
"func": ("min", "max"),
255+
}
256+
)
257+
)
258+
class TestArrayReductionZeroSize:
259+
@testing.numpy_cupy_allclose(
260+
contiguous_check=False, accept_error=ValueError
261+
)
262+
def test_zero_size(self, xp):
263+
shape, axis = self.shape_and_axis
264+
# NumPy only supports axis being an int
265+
if self.func in ("argmax", "argmin"):
266+
if axis is not None and len(axis) == 1:
267+
axis = axis[0]
268+
else:
269+
pytest.skip(
270+
f"NumPy does not support axis={axis} for {self.func}"
271+
)
272+
# dtype is irrelevant here, just pick one
273+
a = testing.shaped_random(shape, xp, xp.float32, order=self.order)
274+
return getattr(a, self.func)(axis=axis)
275+
276+
218277
# This class compares CUB results against NumPy's
219278
@testing.parameterize(
220279
*testing.product(

0 commit comments

Comments
 (0)