Skip to content

Commit 5f99f4e

Browse files
authored
add support for axes as list in dpnp.ndarray.transpose (#1770)
* support axes as list * fix a bug * update description * fix docstring when bullet is used
1 parent a8bcdaf commit 5f99f4e

File tree

8 files changed

+122
-61
lines changed

8 files changed

+122
-61
lines changed

dpnp/dpnp_array.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,21 +1317,31 @@ def transpose(self, *axes):
13171317
13181318
For full documentation refer to :obj:`numpy.ndarray.transpose`.
13191319
1320+
Parameters
1321+
----------
1322+
axes : None, tuple or list of ints, n ints, optional
1323+
* ``None`` or no argument: reverses the order of the axes.
1324+
* tuple or list of ints: `i` in the `j`-th place in the tuple/list
1325+
means that the array’s `i`-th axis becomes the transposed
1326+
array’s `j`-th axis.
1327+
* n ints: same as an n-tuple/n-list of the same ints (this form is
1328+
intended simply as a “convenience” alternative to the tuple form).
1329+
13201330
Returns
13211331
-------
1322-
y : dpnp.ndarray
1332+
out : dpnp.ndarray
13231333
View of the array with its axes suitably permuted.
13241334
13251335
See Also
13261336
--------
1327-
:obj:`dpnp.transpose` : Equivalent function.
1328-
:obj:`dpnp.ndarray.ndarray.T` : Array property returning the array transposed.
1329-
:obj:`dpnp.ndarray.reshape` : Give a new shape to an array without changing its data.
1337+
:obj:`dpnp.transpose` : Equivalent function.
1338+
:obj:`dpnp.ndarray.ndarray.T` : Array property returning the array transposed.
1339+
:obj:`dpnp.ndarray.reshape` : Give a new shape to an array without changing its data.
13301340
13311341
Examples
13321342
--------
1333-
>>> import dpnp as dp
1334-
>>> a = dp.array([[1, 2], [3, 4]])
1343+
>>> import dpnp as np
1344+
>>> a = np.array([[1, 2], [3, 4]])
13351345
>>> a
13361346
array([[1, 2],
13371347
[3, 4]])
@@ -1342,7 +1352,7 @@ def transpose(self, *axes):
13421352
array([[1, 3],
13431353
[2, 4]])
13441354
1345-
>>> a = dp.array([1, 2, 3, 4])
1355+
>>> a = np.array([1, 2, 3, 4])
13461356
>>> a
13471357
array([1, 2, 3, 4])
13481358
>>> a.transpose()
@@ -1355,7 +1365,7 @@ def transpose(self, *axes):
13551365
return self
13561366

13571367
axes_len = len(axes)
1358-
if axes_len == 1 and isinstance(axes[0], tuple):
1368+
if axes_len == 1 and isinstance(axes[0], (tuple, list)):
13591369
axes = axes[0]
13601370

13611371
res = self.__new__(dpnp_array)

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -587,12 +587,12 @@ def tensordot(a, b, axes=2):
587587
Second input array. Both inputs `a` and `b` can not be scalars
588588
at the same time.
589589
axes : int or (2,) array_like
590-
* integer_like
591-
If an int `N`, sum over the last `N` axes of `a` and the first `N`
592-
axes of `b` in order. The sizes of the corresponding axes must match.
593-
* (2,) array_like
594-
Or, a list of axes to be summed over, first sequence applying to `a`,
595-
second to `b`. Both elements array_like must be of the same length.
590+
* integer_like: If an int `N`, sum over the last `N` axes of `a` and
591+
the first `N` axes of `b` in order. The sizes of the corresponding
592+
axes must match.
593+
* (2,) array_like: A list of axes to be summed over, first sequence
594+
applying to `a`, second to `b`. Both elements array_like must be of
595+
the same length.
596596
597597
Returns
598598
-------

dpnp/dpnp_iface_manipulation.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1861,12 +1861,13 @@ def transpose(a, axes=None):
18611861
----------
18621862
a : {dpnp.ndarray, usm_ndarray}
18631863
Input array.
1864-
axes : tuple or list of ints, optional
1864+
axes : None, tuple or list of ints, optional
18651865
If specified, it must be a tuple or list which contains a permutation
18661866
of [0, 1, ..., N-1] where N is the number of axes of `a`.
18671867
The `i`'th axis of the returned array will correspond to the axis
1868-
numbered ``axes[i]`` of the input. If not specified, defaults to
1869-
``range(a.ndim)[::-1]``, which reverses the order of the axes.
1868+
numbered ``axes[i]`` of the input. If not specified or ``None``,
1869+
defaults to ``range(a.ndim)[::-1]``, which reverses the order of
1870+
the axes.
18701871
18711872
Returns
18721873
-------
@@ -1881,35 +1882,35 @@ def transpose(a, axes=None):
18811882
18821883
Examples
18831884
--------
1884-
>>> import dpnp as dp
1885-
>>> a = dp.array([[1, 2], [3, 4]])
1885+
>>> import dpnp as np
1886+
>>> a = np.array([[1, 2], [3, 4]])
18861887
>>> a
18871888
array([[1, 2],
18881889
[3, 4]])
1889-
>>> dp.transpose(a)
1890+
>>> np.transpose(a)
18901891
array([[1, 3],
18911892
[2, 4]])
18921893
1893-
>>> a = dp.array([1, 2, 3, 4])
1894+
>>> a = np.array([1, 2, 3, 4])
18941895
>>> a
18951896
array([1, 2, 3, 4])
1896-
>>> dp.transpose(a)
1897+
>>> np.transpose(a)
18971898
array([1, 2, 3, 4])
18981899
1899-
>>> a = dp.ones((1, 2, 3))
1900-
>>> dp.transpose(a, (1, 0, 2)).shape
1900+
>>> a = np.ones((1, 2, 3))
1901+
>>> np.transpose(a, (1, 0, 2)).shape
19011902
(2, 1, 3)
19021903
1903-
>>> a = dp.ones((2, 3, 4, 5))
1904-
>>> dp.transpose(a).shape
1904+
>>> a = np.ones((2, 3, 4, 5))
1905+
>>> np.transpose(a).shape
19051906
(5, 4, 3, 2)
19061907
19071908
"""
19081909

19091910
if isinstance(a, dpnp_array):
19101911
array = a
19111912
elif isinstance(a, dpt.usm_ndarray):
1912-
array = dpnp_array._create_from_usm_ndarray(a.get_array())
1913+
array = dpnp_array._create_from_usm_ndarray(a)
19131914
else:
19141915
raise TypeError(
19151916
f"An array must be any of supported type, but got {type(a)}"

dpnp/dpnp_iface_mathematical.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2802,22 +2802,22 @@ def sum(
28022802
Data type of the returned array. If ``None``, the default data
28032803
type is inferred from the "kind" of the input array data type.
28042804
* If `a` has a real-valued floating-point data type,
2805-
the returned array will have the default real-valued
2806-
floating-point data type for the device where input
2807-
array `a` is allocated.
2805+
the returned array will have the default real-valued
2806+
floating-point data type for the device where input
2807+
array `a` is allocated.
28082808
* If `a` has signed integral data type, the returned array
2809-
will have the default signed integral type for the device
2810-
where input array `a` is allocated.
2809+
will have the default signed integral type for the device
2810+
where input array `a` is allocated.
28112811
* If `a` has unsigned integral data type, the returned array
2812-
will have the default unsigned integral type for the device
2813-
where input array `a` is allocated.
2812+
will have the default unsigned integral type for the device
2813+
where input array `a` is allocated.
28142814
* If `a` has a complex-valued floating-point data type,
2815-
the returned array will have the default complex-valued
2816-
floating-pointer data type for the device where input
2817-
array `a` is allocated.
2815+
the returned array will have the default complex-valued
2816+
floating-pointer data type for the device where input
2817+
array `a` is allocated.
28182818
* If `a` has a boolean data type, the returned array will
2819-
have the default signed integral type for the device
2820-
where input array `a` is allocated.
2819+
have the default signed integral type for the device
2820+
where input array `a` is allocated.
28212821
If the data type (either specified or resolved) differs from the
28222822
data type of `a`, the input array elements are cast to the
28232823
specified data type before computing the sum.

dpnp/dpnp_iface_nanfunctions.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -717,22 +717,22 @@ def nansum(
717717
Data type of the returned array. If ``None``, the default data
718718
type is inferred from the "kind" of the input array data type.
719719
* If `a` has a real-valued floating-point data type,
720-
the returned array will have the default real-valued
721-
floating-point data type for the device where input
722-
array `a` is allocated.
720+
the returned array will have the default real-valued
721+
floating-point data type for the device where input
722+
array `a` is allocated.
723723
* If `a` has signed integral data type, the returned array
724-
will have the default signed integral type for the device
725-
where input array `a` is allocated.
724+
will have the default signed integral type for the device
725+
where input array `a` is allocated.
726726
* If `a` has unsigned integral data type, the returned array
727-
will have the default unsigned integral type for the device
728-
where input array `a` is allocated.
727+
will have the default unsigned integral type for the device
728+
where input array `a` is allocated.
729729
* If `a` has a complex-valued floating-point data type,
730-
the returned array will have the default complex-valued
731-
floating-pointer data type for the device where input
732-
array `a` is allocated.
730+
the returned array will have the default complex-valued
731+
floating-pointer data type for the device where input
732+
array `a` is allocated.
733733
* If `a` has a boolean data type, the returned array will
734-
have the default signed integral type for the device
735-
where input array `a` is allocated.
734+
have the default signed integral type for the device
735+
where input array `a` is allocated.
736736
If the data type (either specified or resolved) differs from the
737737
data type of `a`, the input array elements are cast to the
738738
specified data type before computing the sum.

dpnp/dpnp_iface_trigonometric.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,14 +1355,14 @@ def logsumexp(x, axis=None, out=None, dtype=None, keepdims=False):
13551355
Data type of the returned array. If ``None``, the default data
13561356
type is inferred from the "kind" of the input array data type.
13571357
* If `x` has a real-valued floating-point data type,
1358-
the returned array will have the default real-valued
1359-
floating-point data type for the device where input
1360-
array `x` is allocated.
1358+
the returned array will have the default real-valued
1359+
floating-point data type for the device where input
1360+
array `x` is allocated.
13611361
* If `x` has a boolean or integral data type, the returned array
1362-
will have the default floating point data type for the device
1363-
where input array `x` is allocated.
1362+
will have the default floating point data type for the device
1363+
where input array `x` is allocated.
13641364
* If `x` has a complex-valued floating-point data type,
1365-
an error is raised.
1365+
an error is raised.
13661366
If the data type (either specified or resolved) differs from the
13671367
data type of `x`, the input array elements are cast to the
13681368
specified data type before computing the result. Default: ``None``.

tests/test_manipulation.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_unique(array):
115115

116116

117117
class TestTranspose:
118-
@pytest.mark.parametrize("axes", [(0, 1), (1, 0)])
118+
@pytest.mark.parametrize("axes", [(0, 1), (1, 0), [0, 1]])
119119
def test_2d_with_axes(self, axes):
120120
na = numpy.array([[1, 2], [3, 4]])
121121
da = dpnp.array(na)
@@ -124,7 +124,22 @@ def test_2d_with_axes(self, axes):
124124
result = dpnp.transpose(da, axes)
125125
assert_array_equal(expected, result)
126126

127-
@pytest.mark.parametrize("axes", [(1, 0, 2), ((1, 0, 2),)])
127+
# ndarray
128+
expected = na.transpose(axes)
129+
result = da.transpose(axes)
130+
assert_array_equal(expected, result)
131+
132+
@pytest.mark.parametrize(
133+
"axes",
134+
[
135+
(1, 0, 2),
136+
[1, 0, 2],
137+
((1, 0, 2),),
138+
([1, 0, 2],),
139+
[(1, 0, 2)],
140+
[[1, 0, 2]],
141+
],
142+
)
128143
def test_3d_with_packed_axes(self, axes):
129144
na = numpy.ones((1, 2, 3))
130145
da = dpnp.array(na)
@@ -133,10 +148,27 @@ def test_3d_with_packed_axes(self, axes):
133148
result = da.transpose(*axes)
134149
assert_array_equal(expected, result)
135150

151+
# ndarray
152+
expected = na.transpose(*axes)
153+
result = da.transpose(*axes)
154+
assert_array_equal(expected, result)
155+
136156
@pytest.mark.parametrize("shape", [(10,), (2, 4), (5, 3, 7), (3, 8, 4, 1)])
137157
def test_none_axes(self, shape):
138158
na = numpy.ones(shape)
139159
da = dpnp.ones(shape)
140160

161+
assert_array_equal(numpy.transpose(na), dpnp.transpose(da))
162+
assert_array_equal(numpy.transpose(na, None), dpnp.transpose(da, None))
163+
164+
# ndarray
141165
assert_array_equal(na.transpose(), da.transpose())
142166
assert_array_equal(na.transpose(None), da.transpose(None))
167+
168+
def test_ndarray_axes_n_int(self):
169+
na = numpy.ones((1, 2, 3))
170+
da = dpnp.array(na)
171+
172+
expected = na.transpose(1, 0, 2)
173+
result = da.transpose(1, 0, 2)
174+
assert_array_equal(expected, result)

tests/third_party/cupy/manipulation_tests/test_transpose.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,38 @@ def test_moveaxis_invalid2_2(self):
6464
with pytest.raises(numpy.AxisError):
6565
xp.moveaxis(a, [0, -4], [1, 2])
6666

67+
def test_moveaxis_invalid2_3(self):
68+
for xp in (numpy, cupy):
69+
a = testing.shaped_arange((2, 3, 4), xp)
70+
with pytest.raises(numpy.AxisError):
71+
xp.moveaxis(a, -4, 0)
72+
6773
# len(source) != len(destination)
68-
def test_moveaxis_invalid3(self):
74+
def test_moveaxis_invalid3_1(self):
6975
for xp in (numpy, cupy):
7076
a = testing.shaped_arange((2, 3, 4), xp)
7177
with pytest.raises(ValueError):
7278
xp.moveaxis(a, [0, 1, 2], [1, 2])
7379

80+
def test_moveaxis_invalid3_2(self):
81+
for xp in (numpy, cupy):
82+
a = testing.shaped_arange((2, 3, 4), xp)
83+
with pytest.raises(ValueError):
84+
xp.moveaxis(a, 0, [1, 2])
85+
7486
# len(source) != len(destination)
75-
def test_moveaxis_invalid4(self):
87+
def test_moveaxis_invalid4_1(self):
7688
for xp in (numpy, cupy):
7789
a = testing.shaped_arange((2, 3, 4), xp)
7890
with pytest.raises(ValueError):
7991
xp.moveaxis(a, [0, 1], [1, 2, 0])
8092

93+
def test_moveaxis_invalid4_2(self):
94+
for xp in (numpy, cupy):
95+
a = testing.shaped_arange((2, 3, 4), xp)
96+
with pytest.raises(ValueError):
97+
xp.moveaxis(a, [0, 1], 1)
98+
8199
# Use the same axis twice
82100
def test_moveaxis_invalid5_1(self):
83101
for xp in (numpy, cupy):

0 commit comments

Comments
 (0)