Skip to content

REF: update to use extension test patterns #54436

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion pandas/tests/extension/base/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,15 @@ def check_reduce(self, s, op_name, skipna):
# that the results match. Override if you need to cast to something
# other than float64.
res_op = getattr(s, op_name)
exp_op = getattr(s.astype("float64"), op_name)

try:
alt = s.astype("float64")
except TypeError:
# e.g. Interval can't cast, so let's cast to object and do
# the reduction pointwise
alt = s.astype(object)

exp_op = getattr(alt, op_name)
if op_name == "count":
result = res_op()
expected = exp_op()
Expand Down
23 changes: 11 additions & 12 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,8 @@ def _supports_reduction(self, obj, op_name: str) -> bool:
return True

def check_reduce(self, s, op_name, skipna):
if op_name in ["median", "skew", "kurt", "sem"]:
msg = r"decimal does not support the .* operation"
with pytest.raises(NotImplementedError, match=msg):
getattr(s, op_name)(skipna=skipna)
elif op_name == "count":
result = getattr(s, op_name)()
expected = len(s) - s.isna().sum()
tm.assert_almost_equal(result, expected)
if op_name == "count":
return super().check_reduce(s, op_name, skipna)
else:
result = getattr(s, op_name)(skipna=skipna)
expected = getattr(np.asarray(s), op_name)()
Expand Down Expand Up @@ -189,12 +183,17 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):


class TestReduce(Reduce, base.BaseReduceTests):
@pytest.mark.parametrize("skipna", [True, False])
def test_reduce_frame(self, data, all_numeric_reductions, skipna):
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request):
if all_numeric_reductions in ["kurt", "skew", "sem", "median"]:
mark = pytest.mark.xfail(raises=NotImplementedError)
request.node.add_marker(mark)
super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)

def test_reduce_frame(self, data, all_numeric_reductions, skipna, request):
op_name = all_numeric_reductions
if op_name in ["skew", "median"]:
assert not hasattr(data, op_name)
pytest.skip(f"{op_name} not an array method")
mark = pytest.mark.xfail(raises=NotImplementedError)
request.node.add_marker(mark)

return super().test_reduce_frame(data, all_numeric_reductions, skipna)

Expand Down
23 changes: 4 additions & 19 deletions pandas/tests/extension/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@

from pandas.core.dtypes.dtypes import IntervalDtype

from pandas import (
Interval,
Series,
)
from pandas import Interval
from pandas.core.arrays import IntervalArray
from pandas.tests.extension import base

Expand Down Expand Up @@ -106,18 +103,8 @@ class TestInterface(BaseInterval, base.BaseInterfaceTests):


class TestReduce(base.BaseReduceTests):
@pytest.mark.parametrize("skipna", [True, False])
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):
op_name = all_numeric_reductions
ser = Series(data)

if op_name in ["min", "max"]:
# IntervalArray *does* implement these
assert getattr(ser, op_name)(skipna=skipna) in data
assert getattr(data, op_name)(skipna=skipna) in data
return

super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)
def _supports_reduction(self, obj, op_name: str) -> bool:
return op_name in ["min", "max"]


class TestMethods(BaseInterval, base.BaseMethodsTests):
Expand Down Expand Up @@ -145,9 +132,7 @@ class TestSetitem(BaseInterval, base.BaseSetitemTests):


class TestPrinting(BaseInterval, base.BasePrintingTests):
@pytest.mark.xfail(reason="Interval has custom repr")
def test_array_repr(self, data, size):
super().test_array_repr()
pass


class TestParsing(BaseInterval, base.BaseParsingTests):
Expand Down