Skip to content

Commit 0a8ae44

Browse files
authored
BUG: NDArrayBackedExtensionArray.transpose, copy (#44974)
1 parent a3c0e7b commit 0a8ae44

File tree

8 files changed

+85
-10
lines changed

8 files changed

+85
-10
lines changed

pandas/_libs/arrays.pyx

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ cimport cython
66
import numpy as np
77

88
cimport numpy as cnp
9+
from cpython cimport PyErr_Clear
910
from numpy cimport ndarray
1011

1112
cnp.import_array()
@@ -131,9 +132,20 @@ cdef class NDArrayBacked:
131132
def nbytes(self) -> int:
132133
return self._ndarray.nbytes
133134

134-
def copy(self):
135-
# NPY_ANYORDER -> same order as self._ndarray
136-
res_values = cnp.PyArray_NewCopy(self._ndarray, cnp.NPY_ANYORDER)
135+
def copy(self, order="C"):
136+
cdef:
137+
cnp.NPY_ORDER order_code
138+
int success
139+
140+
success = cnp.PyArray_OrderConverter(order, &order_code)
141+
if not success:
142+
# clear exception so that we don't get a SystemError
143+
PyErr_Clear()
144+
# same message used by numpy
145+
msg = f"order must be one of 'C', 'F', 'A', or 'K' (got '{order}')"
146+
raise ValueError(msg)
147+
148+
res_values = cnp.PyArray_NewCopy(self._ndarray, order_code)
137149
return self._from_backing_data(res_values)
138150

139151
def delete(self, loc, axis=0):
@@ -165,3 +177,7 @@ cdef class NDArrayBacked:
165177
def T(self):
166178
res_values = self._ndarray.T
167179
return self._from_backing_data(res_values)
180+
181+
def transpose(self, *axes):
182+
res_values = self._ndarray.transpose(*axes)
183+
return self._from_backing_data(res_values)

pandas/core/arrays/datetimelike.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,8 +507,9 @@ def _concat_same_type(
507507
new_obj._freq = new_freq
508508
return new_obj
509509

510-
def copy(self: DatetimeLikeArrayT) -> DatetimeLikeArrayT:
511-
new_obj = super().copy()
510+
def copy(self: DatetimeLikeArrayT, order="C") -> DatetimeLikeArrayT:
511+
# error: Unexpected keyword argument "order" for "copy"
512+
new_obj = super().copy(order=order) # type: ignore[call-arg]
512513
new_obj._freq = self.freq
513514
return new_obj
514515

pandas/tests/extension/base/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ class TestMyDtype(BaseDtypeTests):
4343
"""
4444
from pandas.tests.extension.base.casting import BaseCastingTests # noqa
4545
from pandas.tests.extension.base.constructors import BaseConstructorsTests # noqa
46-
from pandas.tests.extension.base.dim2 import Dim2CompatTests # noqa
46+
from pandas.tests.extension.base.dim2 import ( # noqa
47+
Dim2CompatTests,
48+
NDArrayBacked2DTests,
49+
)
4750
from pandas.tests.extension.base.dtype import BaseDtypeTests # noqa
4851
from pandas.tests.extension.base.getitem import BaseGetitemTests # noqa
4952
from pandas.tests.extension.base.groupby import BaseGroupbyTests # noqa

pandas/tests/extension/base/dim2.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212

1313

1414
class Dim2CompatTests(BaseExtensionTests):
15+
def test_transpose(self, data):
16+
arr2d = data.repeat(2).reshape(-1, 2)
17+
shape = arr2d.shape
18+
assert shape[0] != shape[-1] # otherwise the rest of the test is useless
19+
20+
assert arr2d.T.shape == shape[::-1]
21+
1522
def test_frame_from_2d_array(self, data):
1623
arr2d = data.repeat(2).reshape(-1, 2)
1724

@@ -244,3 +251,51 @@ def test_reductions_2d_axis1(self, data, method):
244251
expected_scalar = getattr(data, method)()
245252
res = result[0]
246253
assert is_matching_na(res, expected_scalar) or res == expected_scalar
254+
255+
256+
class NDArrayBacked2DTests(Dim2CompatTests):
257+
# More specific tests for NDArrayBackedExtensionArray subclasses
258+
259+
def test_copy_order(self, data):
260+
# We should be matching numpy semantics for the "order" keyword in 'copy'
261+
arr2d = data.repeat(2).reshape(-1, 2)
262+
assert arr2d._ndarray.flags["C_CONTIGUOUS"]
263+
264+
res = arr2d.copy()
265+
assert res._ndarray.flags["C_CONTIGUOUS"]
266+
267+
res = arr2d[::2, ::2].copy()
268+
assert res._ndarray.flags["C_CONTIGUOUS"]
269+
270+
res = arr2d.copy("F")
271+
assert not res._ndarray.flags["C_CONTIGUOUS"]
272+
assert res._ndarray.flags["F_CONTIGUOUS"]
273+
274+
res = arr2d.copy("K")
275+
assert res._ndarray.flags["C_CONTIGUOUS"]
276+
277+
res = arr2d.T.copy("K")
278+
assert not res._ndarray.flags["C_CONTIGUOUS"]
279+
assert res._ndarray.flags["F_CONTIGUOUS"]
280+
281+
# order not accepted by numpy
282+
msg = r"order must be one of 'C', 'F', 'A', or 'K' \(got 'Q'\)"
283+
with pytest.raises(ValueError, match=msg):
284+
arr2d.copy("Q")
285+
286+
# neither contiguity
287+
arr_nc = arr2d[::2]
288+
assert not arr_nc._ndarray.flags["C_CONTIGUOUS"]
289+
assert not arr_nc._ndarray.flags["F_CONTIGUOUS"]
290+
291+
assert arr_nc.copy()._ndarray.flags["C_CONTIGUOUS"]
292+
assert not arr_nc.copy()._ndarray.flags["F_CONTIGUOUS"]
293+
294+
assert arr_nc.copy("C")._ndarray.flags["C_CONTIGUOUS"]
295+
assert not arr_nc.copy("C")._ndarray.flags["F_CONTIGUOUS"]
296+
297+
assert not arr_nc.copy("F")._ndarray.flags["C_CONTIGUOUS"]
298+
assert arr_nc.copy("F")._ndarray.flags["F_CONTIGUOUS"]
299+
300+
assert arr_nc.copy("K")._ndarray.flags["C_CONTIGUOUS"]
301+
assert not arr_nc.copy("K")._ndarray.flags["F_CONTIGUOUS"]

pandas/tests/extension/test_categorical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ class TestParsing(base.BaseParsingTests):
305305
pass
306306

307307

308-
class Test2DCompat(base.Dim2CompatTests):
308+
class Test2DCompat(base.NDArrayBacked2DTests):
309309
def test_repr_2d(self, data):
310310
# Categorical __repr__ doesn't include "Categorical", so we need
311311
# to special-case

pandas/tests/extension/test_datetime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,5 +188,5 @@ class TestPrinting(BaseDatetimeTests, base.BasePrintingTests):
188188
pass
189189

190190

191-
class Test2DCompat(BaseDatetimeTests, base.Dim2CompatTests):
191+
class Test2DCompat(BaseDatetimeTests, base.NDArrayBacked2DTests):
192192
pass

pandas/tests/extension/test_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,5 +457,5 @@ class TestParsing(BaseNumPyTests, base.BaseParsingTests):
457457
pass
458458

459459

460-
class Test2DCompat(BaseNumPyTests, base.Dim2CompatTests):
460+
class Test2DCompat(BaseNumPyTests, base.NDArrayBacked2DTests):
461461
pass

pandas/tests/extension/test_period.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,5 +183,5 @@ def test_EA_types(self, engine, data):
183183
super().test_EA_types(engine, data)
184184

185185

186-
class Test2DCompat(BasePeriodTests, base.Dim2CompatTests):
186+
class Test2DCompat(BasePeriodTests, base.NDArrayBacked2DTests):
187187
pass

0 commit comments

Comments
 (0)