Skip to content

Commit 1608284

Browse files
committed
support axes as list
1 parent a134bdc commit 1608284

File tree

3 files changed

+43
-5
lines changed

3 files changed

+43
-5
lines changed

dpnp/dpnp_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1355,7 +1355,7 @@ def transpose(self, *axes):
13551355
return self
13561356

13571357
axes_len = len(axes)
1358-
if axes_len == 1 and isinstance(axes[0], tuple):
1358+
if axes_len == 1 and isinstance(axes[0], (tuple, list)):
13591359
axes = axes[0]
13601360

13611361
res = self.__new__(dpnp_array)

tests/test_manipulation.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_unique(array):
115115

116116

117117
class TestTranspose:
118-
@pytest.mark.parametrize("axes", [(0, 1), (1, 0)])
118+
@pytest.mark.parametrize("axes", [(0, 1), (1, 0), [0, 1]])
119119
def test_2d_with_axes(self, axes):
120120
na = numpy.array([[1, 2], [3, 4]])
121121
da = dpnp.array(na)
@@ -124,7 +124,22 @@ def test_2d_with_axes(self, axes):
124124
result = dpnp.transpose(da, axes)
125125
assert_array_equal(expected, result)
126126

127-
@pytest.mark.parametrize("axes", [(1, 0, 2), ((1, 0, 2),)])
127+
# ndarray
128+
expected = na.transpose(axes)
129+
result = da.transpose(axes)
130+
assert_array_equal(expected, result)
131+
132+
@pytest.mark.parametrize(
133+
"axes",
134+
[
135+
(1, 0, 2),
136+
[1, 0, 2],
137+
((1, 0, 2),),
138+
([1, 0, 2],),
139+
[(1, 0, 2)],
140+
[[1, 0, 2]],
141+
],
142+
)
128143
def test_3d_with_packed_axes(self, axes):
129144
na = numpy.ones((1, 2, 3))
130145
da = dpnp.array(na)
@@ -133,6 +148,11 @@ def test_3d_with_packed_axes(self, axes):
133148
result = da.transpose(*axes)
134149
assert_array_equal(expected, result)
135150

151+
# ndarray
152+
expected = na.transpose(*axes)
153+
result = da.transpose(*axes)
154+
assert_array_equal(expected, result)
155+
136156
@pytest.mark.parametrize("shape", [(10,), (2, 4), (5, 3, 7), (3, 8, 4, 1)])
137157
def test_none_axes(self, shape):
138158
na = numpy.ones(shape)

tests/third_party/cupy/manipulation_tests/test_transpose.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,38 @@ def test_moveaxis_invalid2_2(self):
6464
with pytest.raises(numpy.AxisError):
6565
xp.moveaxis(a, [0, -4], [1, 2])
6666

67+
def test_moveaxis_invalid2_3(self):
68+
for xp in (numpy, cupy):
69+
a = testing.shaped_arange((2, 3, 4), xp)
70+
with pytest.raises(numpy.AxisError):
71+
xp.moveaxis(a, -4, 0)
72+
6773
# len(source) != len(destination)
68-
def test_moveaxis_invalid3(self):
74+
def test_moveaxis_invalid3_1(self):
6975
for xp in (numpy, cupy):
7076
a = testing.shaped_arange((2, 3, 4), xp)
7177
with pytest.raises(ValueError):
7278
xp.moveaxis(a, [0, 1, 2], [1, 2])
7379

80+
def test_moveaxis_invalid3_2(self):
81+
for xp in (numpy, cupy):
82+
a = testing.shaped_arange((2, 3, 4), xp)
83+
with pytest.raises(ValueError):
84+
xp.moveaxis(a, 0, [1, 2])
85+
7486
# len(source) != len(destination)
75-
def test_moveaxis_invalid4(self):
87+
def test_moveaxis_invalid4_1(self):
7688
for xp in (numpy, cupy):
7789
a = testing.shaped_arange((2, 3, 4), xp)
7890
with pytest.raises(ValueError):
7991
xp.moveaxis(a, [0, 1], [1, 2, 0])
8092

93+
def test_moveaxis_invalid4_2(self):
94+
for xp in (numpy, cupy):
95+
a = testing.shaped_arange((2, 3, 4), xp)
96+
with pytest.raises(ValueError):
97+
xp.moveaxis(a, [0, 1], 1)
98+
8199
# Use the same axis twice
82100
def test_moveaxis_invalid5_1(self):
83101
for xp in (numpy, cupy):

0 commit comments

Comments
 (0)