Skip to content

Commit bbe11b2

Browse files
authored
REF: implement TestExtension (#54432)
1 parent 809f371 commit bbe11b2

19 files changed

+79
-103
lines changed

pandas/tests/extension/base/__init__.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,33 +34,61 @@ class TestMyDtype(BaseDtypeTests):
3434
wherever the test requires it. You're free to implement additional tests.
3535
3636
"""
37-
from pandas.tests.extension.base.accumulate import BaseAccumulateTests # noqa: F401
38-
from pandas.tests.extension.base.casting import BaseCastingTests # noqa: F401
39-
from pandas.tests.extension.base.constructors import BaseConstructorsTests # noqa: F401
37+
from pandas.tests.extension.base.accumulate import BaseAccumulateTests
38+
from pandas.tests.extension.base.casting import BaseCastingTests
39+
from pandas.tests.extension.base.constructors import BaseConstructorsTests
4040
from pandas.tests.extension.base.dim2 import ( # noqa: F401
4141
Dim2CompatTests,
4242
NDArrayBacked2DTests,
4343
)
44-
from pandas.tests.extension.base.dtype import BaseDtypeTests # noqa: F401
45-
from pandas.tests.extension.base.getitem import BaseGetitemTests # noqa: F401
46-
from pandas.tests.extension.base.groupby import BaseGroupbyTests # noqa: F401
47-
from pandas.tests.extension.base.index import BaseIndexTests # noqa: F401
48-
from pandas.tests.extension.base.interface import BaseInterfaceTests # noqa: F401
49-
from pandas.tests.extension.base.io import BaseParsingTests # noqa: F401
50-
from pandas.tests.extension.base.methods import BaseMethodsTests # noqa: F401
51-
from pandas.tests.extension.base.missing import BaseMissingTests # noqa: F401
44+
from pandas.tests.extension.base.dtype import BaseDtypeTests
45+
from pandas.tests.extension.base.getitem import BaseGetitemTests
46+
from pandas.tests.extension.base.groupby import BaseGroupbyTests
47+
from pandas.tests.extension.base.index import BaseIndexTests
48+
from pandas.tests.extension.base.interface import BaseInterfaceTests
49+
from pandas.tests.extension.base.io import BaseParsingTests
50+
from pandas.tests.extension.base.methods import BaseMethodsTests
51+
from pandas.tests.extension.base.missing import BaseMissingTests
5252
from pandas.tests.extension.base.ops import ( # noqa: F401
5353
BaseArithmeticOpsTests,
5454
BaseComparisonOpsTests,
5555
BaseOpsUtil,
5656
BaseUnaryOpsTests,
5757
)
58-
from pandas.tests.extension.base.printing import BasePrintingTests # noqa: F401
58+
from pandas.tests.extension.base.printing import BasePrintingTests
5959
from pandas.tests.extension.base.reduce import ( # noqa: F401
6060
BaseBooleanReduceTests,
6161
BaseNoReduceTests,
6262
BaseNumericReduceTests,
6363
BaseReduceTests,
6464
)
65-
from pandas.tests.extension.base.reshaping import BaseReshapingTests # noqa: F401
66-
from pandas.tests.extension.base.setitem import BaseSetitemTests # noqa: F401
65+
from pandas.tests.extension.base.reshaping import BaseReshapingTests
66+
from pandas.tests.extension.base.setitem import BaseSetitemTests
67+
68+
69+
# One test class that you can inherit as an alternative to inheriting all the
70+
# test classes above.
71+
# Note 1) this excludes Dim2CompatTests and NDArrayBacked2DTests.
72+
# Note 2) this uses BaseReduceTests and and _not_ BaseBooleanReduceTests,
73+
# BaseNoReduceTests, or BaseNumericReduceTests
74+
class ExtensionTests(
75+
BaseAccumulateTests,
76+
BaseCastingTests,
77+
BaseConstructorsTests,
78+
BaseDtypeTests,
79+
BaseGetitemTests,
80+
BaseGroupbyTests,
81+
BaseIndexTests,
82+
BaseInterfaceTests,
83+
BaseParsingTests,
84+
BaseMethodsTests,
85+
BaseMissingTests,
86+
BaseArithmeticOpsTests,
87+
BaseComparisonOpsTests,
88+
BaseUnaryOpsTests,
89+
BasePrintingTests,
90+
BaseReduceTests,
91+
BaseReshapingTests,
92+
BaseSetitemTests,
93+
):
94+
pass

pandas/tests/extension/base/accumulate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
import pandas as pd
44
import pandas._testing as tm
5-
from pandas.tests.extension.base.base import BaseExtensionTests
65

76

8-
class BaseAccumulateTests(BaseExtensionTests):
7+
class BaseAccumulateTests:
98
"""
109
Accumulation specific tests. Generally these only
1110
make sense for numeric/boolean operations.

pandas/tests/extension/base/casting.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
import pandas as pd
77
import pandas._testing as tm
88
from pandas.core.internals.blocks import NumpyBlock
9-
from pandas.tests.extension.base.base import BaseExtensionTests
109

1110

12-
class BaseCastingTests(BaseExtensionTests):
11+
class BaseCastingTests:
1312
"""Casting to and from ExtensionDtypes"""
1413

1514
def test_astype_object_series(self, all_data):

pandas/tests/extension/base/constructors.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
import pandas._testing as tm
66
from pandas.api.extensions import ExtensionArray
77
from pandas.core.internals.blocks import EABackedBlock
8-
from pandas.tests.extension.base.base import BaseExtensionTests
98

109

11-
class BaseConstructorsTests(BaseExtensionTests):
10+
class BaseConstructorsTests:
1211
def test_from_sequence_from_cls(self, data):
1312
result = type(data)._from_sequence(data, dtype=data.dtype)
1413
tm.assert_extension_array_equal(result, data)

pandas/tests/extension/base/dim2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@
1414
import pandas as pd
1515
import pandas._testing as tm
1616
from pandas.core.arrays.integer import NUMPY_INT_TO_DTYPE
17-
from pandas.tests.extension.base.base import BaseExtensionTests
1817

1918

20-
class Dim2CompatTests(BaseExtensionTests):
19+
class Dim2CompatTests:
2120
# Note: these are ONLY for ExtensionArray subclasses that support 2D arrays.
2221
# i.e. not for pyarrow-backed EAs.
2322

pandas/tests/extension/base/dtype.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
is_object_dtype,
99
is_string_dtype,
1010
)
11-
from pandas.tests.extension.base.base import BaseExtensionTests
1211

1312

14-
class BaseDtypeTests(BaseExtensionTests):
13+
class BaseDtypeTests:
1514
"""Base class for ExtensionDtype classes"""
1615

1716
def test_name(self, dtype):

pandas/tests/extension/base/getitem.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33

44
import pandas as pd
55
import pandas._testing as tm
6-
from pandas.tests.extension.base.base import BaseExtensionTests
76

87

9-
class BaseGetitemTests(BaseExtensionTests):
8+
class BaseGetitemTests:
109
"""Tests for ExtensionArray.__getitem__."""
1110

1211
def test_iloc_series(self, data):

pandas/tests/extension/base/groupby.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@
1111

1212
import pandas as pd
1313
import pandas._testing as tm
14-
from pandas.tests.extension.base.base import BaseExtensionTests
1514

1615

17-
class BaseGroupbyTests(BaseExtensionTests):
16+
class BaseGroupbyTests:
1817
"""Groupby-specific tests."""
1918

2019
def test_grouping_grouper(self, data_for_grouping):

pandas/tests/extension/base/index.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
Tests for Indexes backed by arbitrary ExtensionArrays.
33
"""
44
import pandas as pd
5-
from pandas.tests.extension.base.base import BaseExtensionTests
65

76

8-
class BaseIndexTests(BaseExtensionTests):
7+
class BaseIndexTests:
98
"""Tests for Index object backed by an ExtensionArray"""
109

1110
def test_index_from_array(self, data):

pandas/tests/extension/base/interface.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66

77
import pandas as pd
88
import pandas._testing as tm
9-
from pandas.tests.extension.base.base import BaseExtensionTests
109

1110

12-
class BaseInterfaceTests(BaseExtensionTests):
11+
class BaseInterfaceTests:
1312
"""Tests that the basic interface is satisfied."""
1413

1514
# ------------------------------------------------------------------------

pandas/tests/extension/base/io.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55

66
import pandas as pd
77
import pandas._testing as tm
8-
from pandas.tests.extension.base.base import BaseExtensionTests
98

109

11-
class BaseParsingTests(BaseExtensionTests):
10+
class BaseParsingTests:
1211
@pytest.mark.parametrize("engine", ["c", "python"])
1312
def test_EA_types(self, engine, data):
1413
df = pd.DataFrame({"with_dtype": pd.Series(data, dtype=str(data.dtype))})

pandas/tests/extension/base/methods.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212
import pandas as pd
1313
import pandas._testing as tm
1414
from pandas.core.sorting import nargsort
15-
from pandas.tests.extension.base.base import BaseExtensionTests
1615

1716

18-
class BaseMethodsTests(BaseExtensionTests):
17+
class BaseMethodsTests:
1918
"""Various Series and DataFrame methods."""
2019

2120
def test_hash_pandas_object(self, data):

pandas/tests/extension/base/missing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33

44
import pandas as pd
55
import pandas._testing as tm
6-
from pandas.tests.extension.base.base import BaseExtensionTests
76

87

9-
class BaseMissingTests(BaseExtensionTests):
8+
class BaseMissingTests:
109
def test_isna(self, data_missing):
1110
expected = np.array([True, False])
1211

pandas/tests/extension/base/ops.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
import pandas as pd
99
import pandas._testing as tm
1010
from pandas.core import ops
11-
from pandas.tests.extension.base.base import BaseExtensionTests
1211

1312

14-
class BaseOpsUtil(BaseExtensionTests):
13+
class BaseOpsUtil:
1514
series_scalar_exc: type[Exception] | None = TypeError
1615
frame_scalar_exc: type[Exception] | None = TypeError
1716
series_array_exc: type[Exception] | None = TypeError

pandas/tests/extension/base/printing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
import pytest
44

55
import pandas as pd
6-
from pandas.tests.extension.base.base import BaseExtensionTests
76

87

9-
class BasePrintingTests(BaseExtensionTests):
8+
class BasePrintingTests:
109
"""Tests checking the formatting of your EA when printed."""
1110

1211
@pytest.mark.parametrize("size", ["big", "small"])

pandas/tests/extension/base/reduce.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
import pandas as pd
66
import pandas._testing as tm
77
from pandas.api.types import is_numeric_dtype
8-
from pandas.tests.extension.base.base import BaseExtensionTests
98

109

11-
class BaseReduceTests(BaseExtensionTests):
10+
class BaseReduceTests:
1211
"""
1312
Reduction specific tests. Generally these only
1413
make sense for numeric/boolean operations.

pandas/tests/extension/base/reshaping.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
import pandas._testing as tm
88
from pandas.api.extensions import ExtensionArray
99
from pandas.core.internals.blocks import EABackedBlock
10-
from pandas.tests.extension.base.base import BaseExtensionTests
1110

1211

13-
class BaseReshapingTests(BaseExtensionTests):
12+
class BaseReshapingTests:
1413
"""Tests for reshaping and concatenation."""
1514

1615
@pytest.mark.parametrize("in_frame", [True, False])

pandas/tests/extension/base/setitem.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33

44
import pandas as pd
55
import pandas._testing as tm
6-
from pandas.tests.extension.base.base import BaseExtensionTests
76

87

9-
class BaseSetitemTests(BaseExtensionTests):
8+
class BaseSetitemTests:
109
@pytest.fixture(
1110
params=[
1211
lambda x: x.index,

pandas/tests/extension/test_interval.py

Lines changed: 20 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def data_missing():
4747
return IntervalArray.from_tuples([None, (0, 1)])
4848

4949

50+
@pytest.fixture
51+
def data_for_twos():
52+
pytest.skip("Not a numeric dtype")
53+
54+
5055
@pytest.fixture
5156
def data_for_sorting():
5257
return IntervalArray.from_tuples([(1, 2), (2, 3), (0, 1)])
@@ -65,74 +70,34 @@ def data_for_grouping():
6570
return IntervalArray.from_tuples([b, b, None, None, a, a, b, c])
6671

6772

68-
class BaseInterval:
69-
pass
70-
71-
72-
class TestDtype(BaseInterval, base.BaseDtypeTests):
73-
pass
74-
75-
76-
class TestCasting(BaseInterval, base.BaseCastingTests):
77-
pass
78-
79-
80-
class TestConstructors(BaseInterval, base.BaseConstructorsTests):
81-
pass
82-
83-
84-
class TestGetitem(BaseInterval, base.BaseGetitemTests):
85-
pass
86-
87-
88-
class TestIndex(base.BaseIndexTests):
89-
pass
90-
91-
92-
class TestGrouping(BaseInterval, base.BaseGroupbyTests):
93-
pass
94-
95-
96-
class TestInterface(BaseInterval, base.BaseInterfaceTests):
97-
pass
73+
class TestIntervalArray(base.ExtensionTests):
74+
divmod_exc = TypeError
9875

99-
100-
class TestReduce(base.BaseReduceTests):
10176
def _supports_reduction(self, obj, op_name: str) -> bool:
10277
return op_name in ["min", "max"]
10378

104-
105-
class TestMethods(BaseInterval, base.BaseMethodsTests):
10679
@pytest.mark.xfail(
10780
reason="Raises with incorrect message bc it disallows *all* listlikes "
10881
"instead of just wrong-length listlikes"
10982
)
11083
def test_fillna_length_mismatch(self, data_missing):
11184
super().test_fillna_length_mismatch(data_missing)
11285

113-
114-
class TestMissing(BaseInterval, base.BaseMissingTests):
115-
def test_fillna_non_scalar_raises(self, data_missing):
116-
msg = "can only insert Interval objects and NA into an IntervalArray"
117-
with pytest.raises(TypeError, match=msg):
118-
data_missing.fillna([1, 1])
119-
120-
121-
class TestReshaping(BaseInterval, base.BaseReshapingTests):
122-
pass
123-
124-
125-
class TestSetitem(BaseInterval, base.BaseSetitemTests):
126-
pass
127-
128-
129-
class TestPrinting(BaseInterval, base.BasePrintingTests):
130-
pass
131-
132-
133-
class TestParsing(BaseInterval, base.BaseParsingTests):
13486
@pytest.mark.parametrize("engine", ["c", "python"])
13587
def test_EA_types(self, engine, data):
13688
expected_msg = r".*must implement _from_sequence_of_strings.*"
13789
with pytest.raises(NotImplementedError, match=expected_msg):
13890
super().test_EA_types(engine, data)
91+
92+
@pytest.mark.xfail(
93+
reason="Looks like the test (incorrectly) implicitly assumes int/bool dtype"
94+
)
95+
def test_invert(self, data):
96+
super().test_invert(data)
97+
98+
99+
# TODO: either belongs in tests.arrays.interval or move into base tests.
100+
def test_fillna_non_scalar_raises(data_missing):
101+
msg = "can only insert Interval objects and NA into an IntervalArray"
102+
with pytest.raises(TypeError, match=msg):
103+
data_missing.fillna([1, 1])

0 commit comments

Comments
 (0)