Skip to content

add support for axes as list in dpnp.ndarray.transpose #1770

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,21 +1317,31 @@ def transpose(self, *axes):

For full documentation refer to :obj:`numpy.ndarray.transpose`.

Parameters
----------
axes : None, tuple or list of ints, n ints, optional
* ``None`` or no argument: reverses the order of the axes.
* tuple or list of ints: `i` in the `j`-th place in the tuple/list
means that the array’s `i`-th axis becomes the transposed
array’s `j`-th axis.
* n ints: same as an n-tuple/n-list of the same ints (this form is
intended simply as a “convenience” alternative to the tuple form).

Returns
-------
y : dpnp.ndarray
out : dpnp.ndarray
View of the array with its axes suitably permuted.

See Also
--------
:obj:`dpnp.transpose` : Equivalent function.
:obj:`dpnp.ndarray.ndarray.T` : Array property returning the array transposed.
:obj:`dpnp.ndarray.reshape` : Give a new shape to an array without changing its data.
:obj:`dpnp.transpose` : Equivalent function.
:obj:`dpnp.ndarray.ndarray.T` : Array property returning the array transposed.
:obj:`dpnp.ndarray.reshape` : Give a new shape to an array without changing its data.

Examples
--------
>>> import dpnp as dp
>>> a = dp.array([[1, 2], [3, 4]])
>>> import dpnp as np
>>> a = np.array([[1, 2], [3, 4]])
>>> a
array([[1, 2],
[3, 4]])
Expand All @@ -1342,7 +1352,7 @@ def transpose(self, *axes):
array([[1, 3],
[2, 4]])

>>> a = dp.array([1, 2, 3, 4])
>>> a = np.array([1, 2, 3, 4])
>>> a
array([1, 2, 3, 4])
>>> a.transpose()
Expand All @@ -1355,7 +1365,7 @@ def transpose(self, *axes):
return self

axes_len = len(axes)
if axes_len == 1 and isinstance(axes[0], tuple):
if axes_len == 1 and isinstance(axes[0], (tuple, list)):
axes = axes[0]

res = self.__new__(dpnp_array)
Expand Down
12 changes: 6 additions & 6 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,12 +587,12 @@ def tensordot(a, b, axes=2):
Second input array. Both inputs `a` and `b` can not be scalars
at the same time.
axes : int or (2,) array_like
* integer_like
If an int `N`, sum over the last `N` axes of `a` and the first `N`
axes of `b` in order. The sizes of the corresponding axes must match.
* (2,) array_like
Or, a list of axes to be summed over, first sequence applying to `a`,
second to `b`. Both elements array_like must be of the same length.
* integer_like: If an int `N`, sum over the last `N` axes of `a` and
the first `N` axes of `b` in order. The sizes of the corresponding
axes must match.
* (2,) array_like: A list of axes to be summed over, first sequence
applying to `a`, second to `b`. Both elements array_like must be of
the same length.

Returns
-------
Expand Down
27 changes: 14 additions & 13 deletions dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1861,12 +1861,13 @@ def transpose(a, axes=None):
----------
a : {dpnp.ndarray, usm_ndarray}
Input array.
axes : tuple or list of ints, optional
axes : None, tuple or list of ints, optional
If specified, it must be a tuple or list which contains a permutation
of [0, 1, ..., N-1] where N is the number of axes of `a`.
The `i`'th axis of the returned array will correspond to the axis
numbered ``axes[i]`` of the input. If not specified, defaults to
``range(a.ndim)[::-1]``, which reverses the order of the axes.
numbered ``axes[i]`` of the input. If not specified or ``None``,
defaults to ``range(a.ndim)[::-1]``, which reverses the order of
the axes.

Returns
-------
Expand All @@ -1881,35 +1882,35 @@ def transpose(a, axes=None):

Examples
--------
>>> import dpnp as dp
>>> a = dp.array([[1, 2], [3, 4]])
>>> import dpnp as np
>>> a = np.array([[1, 2], [3, 4]])
>>> a
array([[1, 2],
[3, 4]])
>>> dp.transpose(a)
>>> np.transpose(a)
array([[1, 3],
[2, 4]])

>>> a = dp.array([1, 2, 3, 4])
>>> a = np.array([1, 2, 3, 4])
>>> a
array([1, 2, 3, 4])
>>> dp.transpose(a)
>>> np.transpose(a)
array([1, 2, 3, 4])

>>> a = dp.ones((1, 2, 3))
>>> dp.transpose(a, (1, 0, 2)).shape
>>> a = np.ones((1, 2, 3))
>>> np.transpose(a, (1, 0, 2)).shape
(2, 1, 3)

>>> a = dp.ones((2, 3, 4, 5))
>>> dp.transpose(a).shape
>>> a = np.ones((2, 3, 4, 5))
>>> np.transpose(a).shape
(5, 4, 3, 2)

"""

if isinstance(a, dpnp_array):
array = a
elif isinstance(a, dpt.usm_ndarray):
array = dpnp_array._create_from_usm_ndarray(a.get_array())
array = dpnp_array._create_from_usm_ndarray(a)
else:
raise TypeError(
f"An array must be any of supported type, but got {type(a)}"
Expand Down
24 changes: 12 additions & 12 deletions dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2802,22 +2802,22 @@ def sum(
Data type of the returned array. If ``None``, the default data
type is inferred from the "kind" of the input array data type.
* If `a` has a real-valued floating-point data type,
the returned array will have the default real-valued
floating-point data type for the device where input
array `a` is allocated.
the returned array will have the default real-valued
floating-point data type for the device where input
array `a` is allocated.
* If `a` has signed integral data type, the returned array
will have the default signed integral type for the device
where input array `a` is allocated.
will have the default signed integral type for the device
where input array `a` is allocated.
* If `a` has unsigned integral data type, the returned array
will have the default unsigned integral type for the device
where input array `a` is allocated.
will have the default unsigned integral type for the device
where input array `a` is allocated.
* If `a` has a complex-valued floating-point data type,
the returned array will have the default complex-valued
floating-pointer data type for the device where input
array `a` is allocated.
the returned array will have the default complex-valued
floating-pointer data type for the device where input
array `a` is allocated.
* If `a` has a boolean data type, the returned array will
have the default signed integral type for the device
where input array `a` is allocated.
have the default signed integral type for the device
where input array `a` is allocated.
If the data type (either specified or resolved) differs from the
data type of `a`, the input array elements are cast to the
specified data type before computing the sum.
Expand Down
24 changes: 12 additions & 12 deletions dpnp/dpnp_iface_nanfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,22 +717,22 @@ def nansum(
Data type of the returned array. If ``None``, the default data
type is inferred from the "kind" of the input array data type.
* If `a` has a real-valued floating-point data type,
the returned array will have the default real-valued
floating-point data type for the device where input
array `a` is allocated.
the returned array will have the default real-valued
floating-point data type for the device where input
array `a` is allocated.
* If `a` has signed integral data type, the returned array
will have the default signed integral type for the device
where input array `a` is allocated.
will have the default signed integral type for the device
where input array `a` is allocated.
* If `a` has unsigned integral data type, the returned array
will have the default unsigned integral type for the device
where input array `a` is allocated.
will have the default unsigned integral type for the device
where input array `a` is allocated.
* If `a` has a complex-valued floating-point data type,
the returned array will have the default complex-valued
floating-pointer data type for the device where input
array `a` is allocated.
the returned array will have the default complex-valued
floating-pointer data type for the device where input
array `a` is allocated.
* If `a` has a boolean data type, the returned array will
have the default signed integral type for the device
where input array `a` is allocated.
have the default signed integral type for the device
where input array `a` is allocated.
If the data type (either specified or resolved) differs from the
data type of `a`, the input array elements are cast to the
specified data type before computing the sum.
Expand Down
12 changes: 6 additions & 6 deletions dpnp/dpnp_iface_trigonometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,14 +1355,14 @@ def logsumexp(x, axis=None, out=None, dtype=None, keepdims=False):
Data type of the returned array. If ``None``, the default data
type is inferred from the "kind" of the input array data type.
* If `x` has a real-valued floating-point data type,
the returned array will have the default real-valued
floating-point data type for the device where input
array `x` is allocated.
the returned array will have the default real-valued
floating-point data type for the device where input
array `x` is allocated.
* If `x` has a boolean or integral data type, the returned array
will have the default floating point data type for the device
where input array `x` is allocated.
will have the default floating point data type for the device
where input array `x` is allocated.
* If `x` has a complex-valued floating-point data type,
an error is raised.
an error is raised.
If the data type (either specified or resolved) differs from the
data type of `x`, the input array elements are cast to the
specified data type before computing the result. Default: ``None``.
Expand Down
36 changes: 34 additions & 2 deletions tests/test_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_unique(array):


class TestTranspose:
@pytest.mark.parametrize("axes", [(0, 1), (1, 0)])
@pytest.mark.parametrize("axes", [(0, 1), (1, 0), [0, 1]])
def test_2d_with_axes(self, axes):
na = numpy.array([[1, 2], [3, 4]])
da = dpnp.array(na)
Expand All @@ -124,7 +124,22 @@ def test_2d_with_axes(self, axes):
result = dpnp.transpose(da, axes)
assert_array_equal(expected, result)

@pytest.mark.parametrize("axes", [(1, 0, 2), ((1, 0, 2),)])
# ndarray
expected = na.transpose(axes)
result = da.transpose(axes)
assert_array_equal(expected, result)

@pytest.mark.parametrize(
"axes",
[
(1, 0, 2),
[1, 0, 2],
((1, 0, 2),),
([1, 0, 2],),
[(1, 0, 2)],
[[1, 0, 2]],
],
)
def test_3d_with_packed_axes(self, axes):
na = numpy.ones((1, 2, 3))
da = dpnp.array(na)
Expand All @@ -133,10 +148,27 @@ def test_3d_with_packed_axes(self, axes):
result = da.transpose(*axes)
assert_array_equal(expected, result)

# ndarray
expected = na.transpose(*axes)
result = da.transpose(*axes)
assert_array_equal(expected, result)

@pytest.mark.parametrize("shape", [(10,), (2, 4), (5, 3, 7), (3, 8, 4, 1)])
def test_none_axes(self, shape):
na = numpy.ones(shape)
da = dpnp.ones(shape)

assert_array_equal(numpy.transpose(na), dpnp.transpose(da))
assert_array_equal(numpy.transpose(na, None), dpnp.transpose(da, None))

# ndarray
assert_array_equal(na.transpose(), da.transpose())
assert_array_equal(na.transpose(None), da.transpose(None))

def test_ndarray_axes_n_int(self):
na = numpy.ones((1, 2, 3))
da = dpnp.array(na)

expected = na.transpose(1, 0, 2)
result = da.transpose(1, 0, 2)
assert_array_equal(expected, result)
22 changes: 20 additions & 2 deletions tests/third_party/cupy/manipulation_tests/test_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,38 @@ def test_moveaxis_invalid2_2(self):
with pytest.raises(numpy.AxisError):
xp.moveaxis(a, [0, -4], [1, 2])

def test_moveaxis_invalid2_3(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(numpy.AxisError):
xp.moveaxis(a, -4, 0)

# len(source) != len(destination)
def test_moveaxis_invalid3(self):
def test_moveaxis_invalid3_1(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(ValueError):
xp.moveaxis(a, [0, 1, 2], [1, 2])

def test_moveaxis_invalid3_2(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(ValueError):
xp.moveaxis(a, 0, [1, 2])

# len(source) != len(destination)
def test_moveaxis_invalid4(self):
def test_moveaxis_invalid4_1(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(ValueError):
xp.moveaxis(a, [0, 1], [1, 2, 0])

def test_moveaxis_invalid4_2(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(ValueError):
xp.moveaxis(a, [0, 1], 1)

# Use the same axis twice
def test_moveaxis_invalid5_1(self):
for xp in (numpy, cupy):
Expand Down