diff --git a/pandas/tests/extension/base/ops.py b/pandas/tests/extension/base/ops.py index bc2048f9c31bb..aafb1900a4236 100644 --- a/pandas/tests/extension/base/ops.py +++ b/pandas/tests/extension/base/ops.py @@ -12,6 +12,28 @@ class BaseOpsUtil(BaseExtensionTests): + series_scalar_exc: type[Exception] | None = TypeError + frame_scalar_exc: type[Exception] | None = TypeError + series_array_exc: type[Exception] | None = TypeError + divmod_exc: type[Exception] | None = TypeError + + def _get_expected_exception( + self, op_name: str, obj, other + ) -> type[Exception] | None: + # Find the Exception, if any we expect to raise calling + # obj.__op_name__(other) + + # The self.obj_bar_exc pattern isn't great in part because it can depend + # on op_name or dtypes, but we use it here for backward-compatibility. + if op_name in ["__divmod__", "__rdivmod__"]: + return self.divmod_exc + if isinstance(obj, pd.Series) and isinstance(other, pd.Series): + return self.series_array_exc + elif isinstance(obj, pd.Series): + return self.series_scalar_exc + else: + return self.frame_scalar_exc + def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): # In _check_op we check that the result of a pointwise operation # (found via _combine) matches the result of the vectorized @@ -24,17 +46,21 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): def get_op_from_name(self, op_name: str): return tm.get_op_from_name(op_name) - def check_opname(self, ser: pd.Series, op_name: str, other, exc=Exception): - op = self.get_op_from_name(op_name) - - self._check_op(ser, op, other, op_name, exc) - - # Subclasses are not expected to need to override _check_op or _combine. + # Subclasses are not expected to need to override check_opname, _check_op, + # _check_divmod_op, or _combine. # Ideally any relevant overriding can be done in _cast_pointwise_result, # get_op_from_name, and the specification of `exc`. If you find a use # case that still requires overriding _check_op or _combine, please let # us know at github.com/pandas-dev/pandas/issues @final + def check_opname(self, ser: pd.Series, op_name: str, other): + exc = self._get_expected_exception(op_name, ser, other) + op = self.get_op_from_name(op_name) + + self._check_op(ser, op, other, op_name, exc) + + # see comment on check_opname + @final def _combine(self, obj, other, op): if isinstance(obj, pd.DataFrame): if len(obj.columns) != 1: @@ -44,11 +70,14 @@ def _combine(self, obj, other, op): expected = obj.combine(other, op) return expected - # see comment on _combine + # see comment on check_opname @final def _check_op( self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError ): + # Check that the Series/DataFrame arithmetic/comparison method matches + # the pointwise result from _combine. + if exc is None: result = op(ser, other) expected = self._combine(ser, other, op) @@ -59,8 +88,14 @@ def _check_op( with pytest.raises(exc): op(ser, other) - def _check_divmod_op(self, ser: pd.Series, op, other, exc=Exception): - # divmod has multiple return values, so check separately + # see comment on check_opname + @final + def _check_divmod_op(self, ser: pd.Series, op, other): + # check that divmod behavior matches behavior of floordiv+mod + if op is divmod: + exc = self._get_expected_exception("__divmod__", ser, other) + else: + exc = self._get_expected_exception("__rdivmod__", ser, other) if exc is None: result_div, result_mod = op(ser, other) if op is divmod: @@ -96,26 +131,24 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators): # series & scalar op_name = all_arithmetic_operators ser = pd.Series(data) - self.check_opname(ser, op_name, ser.iloc[0], exc=self.series_scalar_exc) + self.check_opname(ser, op_name, ser.iloc[0]) def test_arith_frame_with_scalar(self, data, all_arithmetic_operators): # frame & scalar op_name = all_arithmetic_operators df = pd.DataFrame({"A": data}) - self.check_opname(df, op_name, data[0], exc=self.frame_scalar_exc) + self.check_opname(df, op_name, data[0]) def test_arith_series_with_array(self, data, all_arithmetic_operators): # ndarray & other series op_name = all_arithmetic_operators ser = pd.Series(data) - self.check_opname( - ser, op_name, pd.Series([ser.iloc[0]] * len(ser)), exc=self.series_array_exc - ) + self.check_opname(ser, op_name, pd.Series([ser.iloc[0]] * len(ser))) def test_divmod(self, data): ser = pd.Series(data) - self._check_divmod_op(ser, divmod, 1, exc=self.divmod_exc) - self._check_divmod_op(1, ops.rdivmod, ser, exc=self.divmod_exc) + self._check_divmod_op(ser, divmod, 1) + self._check_divmod_op(1, ops.rdivmod, ser) def test_divmod_series_array(self, data, data_for_twos): ser = pd.Series(data) diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 944ed0dbff66e..5ee41387e1809 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import decimal import operator @@ -311,8 +313,14 @@ def test_astype_dispatches(frame): class TestArithmeticOps(base.BaseArithmeticOpsTests): - def check_opname(self, s, op_name, other, exc=None): - super().check_opname(s, op_name, other, exc=None) + series_scalar_exc = None + frame_scalar_exc = None + series_array_exc = None + + def _get_expected_exception( + self, op_name: str, obj, other + ) -> type[Exception] | None: + return None def test_arith_series_with_array(self, data, all_arithmetic_operators): op_name = all_arithmetic_operators @@ -336,10 +344,6 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators): context.traps[decimal.DivisionByZero] = divbyzerotrap context.traps[decimal.InvalidOperation] = invalidoptrap - def _check_divmod_op(self, s, op, other, exc=NotImplementedError): - # We implement divmod - super()._check_divmod_op(s, op, other, exc=None) - class TestComparisonOps(base.BaseComparisonOpsTests): def test_compare_scalar(self, data, comparison_op): diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index 8a571d9295e1f..0c9abd45a51a5 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -323,9 +323,6 @@ def test_divmod_series_array(self): # skipping because it is not implemented super().test_divmod_series_array() - def _check_divmod_op(self, s, op, other, exc=NotImplementedError): - return super()._check_divmod_op(s, op, other, exc=TypeError) - class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests): pass diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 655ca9cc39c58..2438626cf0347 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -10,6 +10,8 @@ classes (if they are relevant for the extension interface for all dtypes), or be added to the array-specific tests in `pandas/tests/arrays/`. """ +from __future__ import annotations + from datetime import ( date, datetime, @@ -964,16 +966,26 @@ def _is_temporal_supported(self, opname, pa_dtype): and pa.types.is_temporal(pa_dtype) ) - def _get_scalar_exception(self, opname, pa_dtype): - arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype) - if opname in { + def _get_expected_exception( + self, op_name: str, obj, other + ) -> type[Exception] | None: + if op_name in ("__divmod__", "__rdivmod__"): + return self.divmod_exc + + dtype = tm.get_dtype(obj) + # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no + # attribute "pyarrow_dtype" + pa_dtype = dtype.pyarrow_dtype # type: ignore[union-attr] + + arrow_temporal_supported = self._is_temporal_supported(op_name, pa_dtype) + if op_name in { "__mod__", "__rmod__", }: exc = NotImplementedError elif arrow_temporal_supported: exc = None - elif opname in ["__add__", "__radd__"] and ( + elif op_name in ["__add__", "__radd__"] and ( pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) ): exc = None @@ -1060,10 +1072,6 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request) ): pytest.skip("Skip testing Python string formatting") - self.series_scalar_exc = self._get_scalar_exception( - all_arithmetic_operators, pa_dtype - ) - mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) if mark is not None: request.node.add_marker(mark) @@ -1078,10 +1086,6 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request): ): pytest.skip("Skip testing Python string formatting") - self.frame_scalar_exc = self._get_scalar_exception( - all_arithmetic_operators, pa_dtype - ) - mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) if mark is not None: request.node.add_marker(mark) @@ -1091,10 +1095,6 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request): def test_arith_series_with_array(self, data, all_arithmetic_operators, request): pa_dtype = data.dtype.pyarrow_dtype - self.series_array_exc = self._get_scalar_exception( - all_arithmetic_operators, pa_dtype - ) - if ( all_arithmetic_operators in ( @@ -1124,7 +1124,7 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators, request): # since ser.iloc[0] is a python scalar other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype)) - self.check_opname(ser, op_name, other, exc=self.series_array_exc) + self.check_opname(ser, op_name, other) def test_add_series_with_extension_array(self, data, request): pa_dtype = data.dtype.pyarrow_dtype diff --git a/pandas/tests/extension/test_boolean.py b/pandas/tests/extension/test_boolean.py index e5f6da5371742..508e2da214336 100644 --- a/pandas/tests/extension/test_boolean.py +++ b/pandas/tests/extension/test_boolean.py @@ -122,17 +122,14 @@ class TestMissing(base.BaseMissingTests): class TestArithmeticOps(base.BaseArithmeticOpsTests): implements = {"__sub__", "__rsub__"} - def check_opname(self, s, op_name, other, exc=None): - # overwriting to indicate ops don't raise an error - exc = None + def _get_expected_exception(self, op_name, obj, other): if op_name.strip("_").lstrip("r") in ["pow", "truediv", "floordiv"]: # match behavior with non-masked bool dtype - exc = NotImplementedError + return NotImplementedError elif op_name in self.implements: # exception message would include "numpy boolean subtract"" - exc = TypeError - - super().check_opname(s, op_name, other, exc=exc) + return TypeError + return None def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): if op_name in ( @@ -170,18 +167,9 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): def test_divmod_series_array(self, data, data_for_twos): super().test_divmod_series_array(data, data_for_twos) - @pytest.mark.xfail( - reason="Inconsistency between floordiv and divmod; we raise for floordiv " - "but not for divmod. This matches what we do for non-masked bool dtype." - ) - def test_divmod(self, data): - super().test_divmod(data) - class TestComparisonOps(base.BaseComparisonOpsTests): - def check_opname(self, s, op_name, other, exc=None): - # overwriting to indicate ops don't raise an error - super().check_opname(s, op_name, other, exc=None) + pass class TestReshaping(base.BaseReshapingTests): diff --git a/pandas/tests/extension/test_categorical.py b/pandas/tests/extension/test_categorical.py index fc4dfe3af3bca..5de1debb21d93 100644 --- a/pandas/tests/extension/test_categorical.py +++ b/pandas/tests/extension/test_categorical.py @@ -268,9 +268,6 @@ def test_divmod_series_array(self): # skipping because it is not implemented pass - def _check_divmod_op(self, s, op, other, exc=NotImplementedError): - return super()._check_divmod_op(s, op, other, exc=TypeError) - class TestComparisonOps(base.BaseComparisonOpsTests): def _compare_other(self, s, data, op, other): diff --git a/pandas/tests/extension/test_datetime.py b/pandas/tests/extension/test_datetime.py index d8adc4c8c91a5..ab21f768e6521 100644 --- a/pandas/tests/extension/test_datetime.py +++ b/pandas/tests/extension/test_datetime.py @@ -130,22 +130,10 @@ class TestInterface(BaseDatetimeTests, base.BaseInterfaceTests): class TestArithmeticOps(BaseDatetimeTests, base.BaseArithmeticOpsTests): implements = {"__sub__", "__rsub__"} - def test_arith_frame_with_scalar(self, data, all_arithmetic_operators): - # frame & scalar - if all_arithmetic_operators in self.implements: - df = pd.DataFrame({"A": data}) - self.check_opname(df, all_arithmetic_operators, data[0], exc=None) - else: - # ... but not the rest. - super().test_arith_frame_with_scalar(data, all_arithmetic_operators) - - def test_arith_series_with_scalar(self, data, all_arithmetic_operators): - if all_arithmetic_operators in self.implements: - ser = pd.Series(data) - self.check_opname(ser, all_arithmetic_operators, ser.iloc[0], exc=None) - else: - # ... but not the rest. - super().test_arith_series_with_scalar(data, all_arithmetic_operators) + def _get_expected_exception(self, op_name, obj, other): + if op_name in self.implements: + return None + return super()._get_expected_exception(op_name, obj, other) def test_add_series_with_extension_array(self, data): # Datetime + Datetime not implemented @@ -154,14 +142,6 @@ def test_add_series_with_extension_array(self, data): with pytest.raises(TypeError, match=msg): ser + data - def test_arith_series_with_array(self, data, all_arithmetic_operators): - if all_arithmetic_operators in self.implements: - ser = pd.Series(data) - self.check_opname(ser, all_arithmetic_operators, ser.iloc[0], exc=None) - else: - # ... but not the rest. - super().test_arith_series_with_scalar(data, all_arithmetic_operators) - def test_divmod_series_array(self): # GH 23287 # skipping because it is not implemented diff --git a/pandas/tests/extension/test_masked_numeric.py b/pandas/tests/extension/test_masked_numeric.py index b171797dd6359..ce41c08cafbd6 100644 --- a/pandas/tests/extension/test_masked_numeric.py +++ b/pandas/tests/extension/test_masked_numeric.py @@ -163,21 +163,20 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): expected = expected.astype(sdtype) return expected - def check_opname(self, ser: pd.Series, op_name: str, other, exc=None): - # overwriting to indicate ops don't raise an error - super().check_opname(ser, op_name, other, exc=None) - - def _check_divmod_op(self, ser: pd.Series, op, other, exc=None): - super()._check_divmod_op(ser, op, other, None) + series_scalar_exc = None + series_array_exc = None + frame_scalar_exc = None + divmod_exc = None class TestComparisonOps(base.BaseComparisonOpsTests): + series_scalar_exc = None + series_array_exc = None + frame_scalar_exc = None + def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): return pointwise_result.astype("boolean") - def check_opname(self, ser: pd.Series, op_name: str, other, exc=None): - super().check_opname(ser, op_name, other, exc=None) - def _compare_other(self, ser: pd.Series, data, op, other): op_name = f"__{op.__name__}__" self.check_opname(ser, op_name, other) diff --git a/pandas/tests/extension/test_numpy.py b/pandas/tests/extension/test_numpy.py index db191954c8d59..13645065bce14 100644 --- a/pandas/tests/extension/test_numpy.py +++ b/pandas/tests/extension/test_numpy.py @@ -281,7 +281,7 @@ def test_divmod(self, data): @skip_nested def test_divmod_series_array(self, data): ser = pd.Series(data) - self._check_divmod_op(ser, divmod, data, exc=None) + self._check_divmod_op(ser, divmod, data) @skip_nested def test_arith_series_with_scalar(self, data, all_arithmetic_operators): diff --git a/pandas/tests/extension/test_period.py b/pandas/tests/extension/test_period.py index bc0872d359d47..7b6bc98ee8c05 100644 --- a/pandas/tests/extension/test_period.py +++ b/pandas/tests/extension/test_period.py @@ -116,36 +116,10 @@ class TestInterface(BasePeriodTests, base.BaseInterfaceTests): class TestArithmeticOps(BasePeriodTests, base.BaseArithmeticOpsTests): - implements = {"__sub__", "__rsub__"} - - def test_arith_frame_with_scalar(self, data, all_arithmetic_operators): - # frame & scalar - if all_arithmetic_operators in self.implements: - df = pd.DataFrame({"A": data}) - self.check_opname(df, all_arithmetic_operators, data[0], exc=None) - else: - # ... but not the rest. - super().test_arith_frame_with_scalar(data, all_arithmetic_operators) - - def test_arith_series_with_scalar(self, data, all_arithmetic_operators): - # we implement substitution... - if all_arithmetic_operators in self.implements: - s = pd.Series(data) - self.check_opname(s, all_arithmetic_operators, s.iloc[0], exc=None) - else: - # ... but not the rest. - super().test_arith_series_with_scalar(data, all_arithmetic_operators) - - def test_arith_series_with_array(self, data, all_arithmetic_operators): - if all_arithmetic_operators in self.implements: - s = pd.Series(data) - self.check_opname(s, all_arithmetic_operators, s.iloc[0], exc=None) - else: - # ... but not the rest. - super().test_arith_series_with_scalar(data, all_arithmetic_operators) - - def _check_divmod_op(self, s, op, other, exc=NotImplementedError): - super()._check_divmod_op(s, op, other, exc=TypeError) + def _get_expected_exception(self, op_name, obj, other): + if op_name in ("__sub__", "__rsub__"): + return None + return super()._get_expected_exception(op_name, obj, other) def test_add_series_with_extension_array(self, data): # we don't implement + for Period diff --git a/pandas/tests/extension/test_sparse.py b/pandas/tests/extension/test_sparse.py index 77898abb70f4f..a39133c784380 100644 --- a/pandas/tests/extension/test_sparse.py +++ b/pandas/tests/extension/test_sparse.py @@ -415,10 +415,6 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request): request.node.add_marker(mark) super().test_arith_frame_with_scalar(data, all_arithmetic_operators) - def _check_divmod_op(self, ser, op, other, exc=NotImplementedError): - # We implement divmod - super()._check_divmod_op(ser, op, other, exc=None) - class TestComparisonOps(BaseSparseTests): def _compare_other(self, data_for_compare: SparseArray, comparison_op, other):