Skip to content

BUG: NDArrayBackedExtensionArray.transpose, copy #44974

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions pandas/_libs/arrays.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ cimport cython
import numpy as np

cimport numpy as cnp
from cpython cimport PyErr_Clear
from numpy cimport ndarray

cnp.import_array()
Expand Down Expand Up @@ -131,9 +132,20 @@ cdef class NDArrayBacked:
def nbytes(self) -> int:
return self._ndarray.nbytes

def copy(self):
# NPY_ANYORDER -> same order as self._ndarray
res_values = cnp.PyArray_NewCopy(self._ndarray, cnp.NPY_ANYORDER)
def copy(self, order="C"):
cdef:
cnp.NPY_ORDER order_code
int success

success = cnp.PyArray_OrderConverter(order, &order_code)
if not success:
# clear exception so that we don't get a SystemError
PyErr_Clear()
# same message used by numpy
msg = f"order must be one of 'C', 'F', 'A', or 'K' (got '{order}')"
raise ValueError(msg)

res_values = cnp.PyArray_NewCopy(self._ndarray, order_code)
return self._from_backing_data(res_values)

def delete(self, loc, axis=0):
Expand Down Expand Up @@ -165,3 +177,7 @@ cdef class NDArrayBacked:
def T(self):
res_values = self._ndarray.T
return self._from_backing_data(res_values)

def transpose(self, *axes):
res_values = self._ndarray.transpose(*axes)
return self._from_backing_data(res_values)
5 changes: 3 additions & 2 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,8 +507,9 @@ def _concat_same_type(
new_obj._freq = new_freq
return new_obj

def copy(self: DatetimeLikeArrayT) -> DatetimeLikeArrayT:
new_obj = super().copy()
def copy(self: DatetimeLikeArrayT, order="C") -> DatetimeLikeArrayT:
# error: Unexpected keyword argument "order" for "copy"
new_obj = super().copy(order=order) # type: ignore[call-arg]
new_obj._freq = self.freq
return new_obj

Expand Down
5 changes: 4 additions & 1 deletion pandas/tests/extension/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ class TestMyDtype(BaseDtypeTests):
"""
from pandas.tests.extension.base.casting import BaseCastingTests # noqa
from pandas.tests.extension.base.constructors import BaseConstructorsTests # noqa
from pandas.tests.extension.base.dim2 import Dim2CompatTests # noqa
from pandas.tests.extension.base.dim2 import ( # noqa
Dim2CompatTests,
NDArrayBacked2DTests,
)
from pandas.tests.extension.base.dtype import BaseDtypeTests # noqa
from pandas.tests.extension.base.getitem import BaseGetitemTests # noqa
from pandas.tests.extension.base.groupby import BaseGroupbyTests # noqa
Expand Down
55 changes: 55 additions & 0 deletions pandas/tests/extension/base/dim2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@


class Dim2CompatTests(BaseExtensionTests):
def test_transpose(self, data):
arr2d = data.repeat(2).reshape(-1, 2)
shape = arr2d.shape
assert shape[0] != shape[-1] # otherwise the rest of the test is useless

assert arr2d.T.shape == shape[::-1]

def test_frame_from_2d_array(self, data):
arr2d = data.repeat(2).reshape(-1, 2)

Expand Down Expand Up @@ -244,3 +251,51 @@ def test_reductions_2d_axis1(self, data, method):
expected_scalar = getattr(data, method)()
res = result[0]
assert is_matching_na(res, expected_scalar) or res == expected_scalar


class NDArrayBacked2DTests(Dim2CompatTests):
# More specific tests for NDArrayBackedExtensionArray subclasses

def test_copy_order(self, data):
# We should be matching numpy semantics for the "order" keyword in 'copy'
arr2d = data.repeat(2).reshape(-1, 2)
assert arr2d._ndarray.flags["C_CONTIGUOUS"]

res = arr2d.copy()
assert res._ndarray.flags["C_CONTIGUOUS"]

res = arr2d[::2, ::2].copy()
assert res._ndarray.flags["C_CONTIGUOUS"]

res = arr2d.copy("F")
assert not res._ndarray.flags["C_CONTIGUOUS"]
assert res._ndarray.flags["F_CONTIGUOUS"]

res = arr2d.copy("K")
assert res._ndarray.flags["C_CONTIGUOUS"]

res = arr2d.T.copy("K")
assert not res._ndarray.flags["C_CONTIGUOUS"]
assert res._ndarray.flags["F_CONTIGUOUS"]

# order not accepted by numpy
msg = r"order must be one of 'C', 'F', 'A', or 'K' \(got 'Q'\)"
with pytest.raises(ValueError, match=msg):
arr2d.copy("Q")

# neither contiguity
arr_nc = arr2d[::2]
assert not arr_nc._ndarray.flags["C_CONTIGUOUS"]
assert not arr_nc._ndarray.flags["F_CONTIGUOUS"]

assert arr_nc.copy()._ndarray.flags["C_CONTIGUOUS"]
assert not arr_nc.copy()._ndarray.flags["F_CONTIGUOUS"]

assert arr_nc.copy("C")._ndarray.flags["C_CONTIGUOUS"]
assert not arr_nc.copy("C")._ndarray.flags["F_CONTIGUOUS"]

assert not arr_nc.copy("F")._ndarray.flags["C_CONTIGUOUS"]
assert arr_nc.copy("F")._ndarray.flags["F_CONTIGUOUS"]

assert arr_nc.copy("K")._ndarray.flags["C_CONTIGUOUS"]
assert not arr_nc.copy("K")._ndarray.flags["F_CONTIGUOUS"]
2 changes: 1 addition & 1 deletion pandas/tests/extension/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ class TestParsing(base.BaseParsingTests):
pass


class Test2DCompat(base.Dim2CompatTests):
class Test2DCompat(base.NDArrayBacked2DTests):
def test_repr_2d(self, data):
# Categorical __repr__ doesn't include "Categorical", so we need
# to special-case
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/extension/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,5 +188,5 @@ class TestPrinting(BaseDatetimeTests, base.BasePrintingTests):
pass


class Test2DCompat(BaseDatetimeTests, base.Dim2CompatTests):
class Test2DCompat(BaseDatetimeTests, base.NDArrayBacked2DTests):
pass
2 changes: 1 addition & 1 deletion pandas/tests/extension/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,5 +457,5 @@ class TestParsing(BaseNumPyTests, base.BaseParsingTests):
pass


class Test2DCompat(BaseNumPyTests, base.Dim2CompatTests):
class Test2DCompat(BaseNumPyTests, base.NDArrayBacked2DTests):
pass
2 changes: 1 addition & 1 deletion pandas/tests/extension/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,5 +183,5 @@ def test_EA_types(self, engine, data):
super().test_EA_types(engine, data)


class Test2DCompat(BasePeriodTests, base.Dim2CompatTests):
class Test2DCompat(BasePeriodTests, base.NDArrayBacked2DTests):
pass