From 06d5e9a5e824ba1a8d01f44c4aa0d3bc13b8c5ce Mon Sep 17 00:00:00 2001 From: Brock Date: Sat, 5 Aug 2023 14:41:43 -0700 Subject: [PATCH] REF: update extension test patterns --- pandas/tests/extension/base/reduce.py | 10 +++++++- .../tests/extension/decimal/test_decimal.py | 23 +++++++++---------- pandas/tests/extension/test_interval.py | 23 ++++--------------- 3 files changed, 24 insertions(+), 32 deletions(-) diff --git a/pandas/tests/extension/base/reduce.py b/pandas/tests/extension/base/reduce.py index 3d2870191ff6b..dec1b27bac094 100644 --- a/pandas/tests/extension/base/reduce.py +++ b/pandas/tests/extension/base/reduce.py @@ -23,7 +23,15 @@ def check_reduce(self, s, op_name, skipna): # that the results match. Override if you need to cast to something # other than float64. res_op = getattr(s, op_name) - exp_op = getattr(s.astype("float64"), op_name) + + try: + alt = s.astype("float64") + except TypeError: + # e.g. Interval can't cast, so let's cast to object and do + # the reduction pointwise + alt = s.astype(object) + + exp_op = getattr(alt, op_name) if op_name == "count": result = res_op() expected = exp_op() diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index ed64a9939a203..8f97b307187bd 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -152,14 +152,8 @@ def _supports_reduction(self, obj, op_name: str) -> bool: return True def check_reduce(self, s, op_name, skipna): - if op_name in ["median", "skew", "kurt", "sem"]: - msg = r"decimal does not support the .* operation" - with pytest.raises(NotImplementedError, match=msg): - getattr(s, op_name)(skipna=skipna) - elif op_name == "count": - result = getattr(s, op_name)() - expected = len(s) - s.isna().sum() - tm.assert_almost_equal(result, expected) + if op_name == "count": + return super().check_reduce(s, op_name, skipna) else: result = getattr(s, op_name)(skipna=skipna) expected = getattr(np.asarray(s), op_name)() @@ -189,12 +183,17 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs): class TestReduce(Reduce, base.BaseReduceTests): - @pytest.mark.parametrize("skipna", [True, False]) - def test_reduce_frame(self, data, all_numeric_reductions, skipna): + def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request): + if all_numeric_reductions in ["kurt", "skew", "sem", "median"]: + mark = pytest.mark.xfail(raises=NotImplementedError) + request.node.add_marker(mark) + super().test_reduce_series_numeric(data, all_numeric_reductions, skipna) + + def test_reduce_frame(self, data, all_numeric_reductions, skipna, request): op_name = all_numeric_reductions if op_name in ["skew", "median"]: - assert not hasattr(data, op_name) - pytest.skip(f"{op_name} not an array method") + mark = pytest.mark.xfail(raises=NotImplementedError) + request.node.add_marker(mark) return super().test_reduce_frame(data, all_numeric_reductions, skipna) diff --git a/pandas/tests/extension/test_interval.py b/pandas/tests/extension/test_interval.py index d565f14fe199d..5957701b86977 100644 --- a/pandas/tests/extension/test_interval.py +++ b/pandas/tests/extension/test_interval.py @@ -18,10 +18,7 @@ from pandas.core.dtypes.dtypes import IntervalDtype -from pandas import ( - Interval, - Series, -) +from pandas import Interval from pandas.core.arrays import IntervalArray from pandas.tests.extension import base @@ -106,18 +103,8 @@ class TestInterface(BaseInterval, base.BaseInterfaceTests): class TestReduce(base.BaseReduceTests): - @pytest.mark.parametrize("skipna", [True, False]) - def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna): - op_name = all_numeric_reductions - ser = Series(data) - - if op_name in ["min", "max"]: - # IntervalArray *does* implement these - assert getattr(ser, op_name)(skipna=skipna) in data - assert getattr(data, op_name)(skipna=skipna) in data - return - - super().test_reduce_series_numeric(data, all_numeric_reductions, skipna) + def _supports_reduction(self, obj, op_name: str) -> bool: + return op_name in ["min", "max"] class TestMethods(BaseInterval, base.BaseMethodsTests): @@ -145,9 +132,7 @@ class TestSetitem(BaseInterval, base.BaseSetitemTests): class TestPrinting(BaseInterval, base.BasePrintingTests): - @pytest.mark.xfail(reason="Interval has custom repr") - def test_array_repr(self, data, size): - super().test_array_repr() + pass class TestParsing(BaseInterval, base.BaseParsingTests):