Skip to content

Commit a3d6c36

Browse files
authored
REF: dont pass exception to check_opname (#54365)
* REF: dont pass exception to check_opname * future imports * privatize * REF: update pattern for check_divmod_op * typo fixup * suggested edit * mypy fixup * lint fixup
1 parent 92d1d6a commit a3d6c36

File tree

11 files changed

+98
-130
lines changed

11 files changed

+98
-130
lines changed

pandas/tests/extension/base/ops.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,28 @@
1212

1313

1414
class BaseOpsUtil(BaseExtensionTests):
15+
series_scalar_exc: type[Exception] | None = TypeError
16+
frame_scalar_exc: type[Exception] | None = TypeError
17+
series_array_exc: type[Exception] | None = TypeError
18+
divmod_exc: type[Exception] | None = TypeError
19+
20+
def _get_expected_exception(
21+
self, op_name: str, obj, other
22+
) -> type[Exception] | None:
23+
# Find the Exception, if any we expect to raise calling
24+
# obj.__op_name__(other)
25+
26+
# The self.obj_bar_exc pattern isn't great in part because it can depend
27+
# on op_name or dtypes, but we use it here for backward-compatibility.
28+
if op_name in ["__divmod__", "__rdivmod__"]:
29+
return self.divmod_exc
30+
if isinstance(obj, pd.Series) and isinstance(other, pd.Series):
31+
return self.series_array_exc
32+
elif isinstance(obj, pd.Series):
33+
return self.series_scalar_exc
34+
else:
35+
return self.frame_scalar_exc
36+
1537
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
1638
# In _check_op we check that the result of a pointwise operation
1739
# (found via _combine) matches the result of the vectorized
@@ -24,17 +46,21 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
2446
def get_op_from_name(self, op_name: str):
2547
return tm.get_op_from_name(op_name)
2648

27-
def check_opname(self, ser: pd.Series, op_name: str, other, exc=Exception):
28-
op = self.get_op_from_name(op_name)
29-
30-
self._check_op(ser, op, other, op_name, exc)
31-
32-
# Subclasses are not expected to need to override _check_op or _combine.
49+
# Subclasses are not expected to need to override check_opname, _check_op,
50+
# _check_divmod_op, or _combine.
3351
# Ideally any relevant overriding can be done in _cast_pointwise_result,
3452
# get_op_from_name, and the specification of `exc`. If you find a use
3553
# case that still requires overriding _check_op or _combine, please let
3654
# us know at github.com/pandas-dev/pandas/issues
3755
@final
56+
def check_opname(self, ser: pd.Series, op_name: str, other):
57+
exc = self._get_expected_exception(op_name, ser, other)
58+
op = self.get_op_from_name(op_name)
59+
60+
self._check_op(ser, op, other, op_name, exc)
61+
62+
# see comment on check_opname
63+
@final
3864
def _combine(self, obj, other, op):
3965
if isinstance(obj, pd.DataFrame):
4066
if len(obj.columns) != 1:
@@ -44,11 +70,14 @@ def _combine(self, obj, other, op):
4470
expected = obj.combine(other, op)
4571
return expected
4672

47-
# see comment on _combine
73+
# see comment on check_opname
4874
@final
4975
def _check_op(
5076
self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError
5177
):
78+
# Check that the Series/DataFrame arithmetic/comparison method matches
79+
# the pointwise result from _combine.
80+
5281
if exc is None:
5382
result = op(ser, other)
5483
expected = self._combine(ser, other, op)
@@ -59,8 +88,14 @@ def _check_op(
5988
with pytest.raises(exc):
6089
op(ser, other)
6190

62-
def _check_divmod_op(self, ser: pd.Series, op, other, exc=Exception):
63-
# divmod has multiple return values, so check separately
91+
# see comment on check_opname
92+
@final
93+
def _check_divmod_op(self, ser: pd.Series, op, other):
94+
# check that divmod behavior matches behavior of floordiv+mod
95+
if op is divmod:
96+
exc = self._get_expected_exception("__divmod__", ser, other)
97+
else:
98+
exc = self._get_expected_exception("__rdivmod__", ser, other)
6499
if exc is None:
65100
result_div, result_mod = op(ser, other)
66101
if op is divmod:
@@ -96,26 +131,24 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
96131
# series & scalar
97132
op_name = all_arithmetic_operators
98133
ser = pd.Series(data)
99-
self.check_opname(ser, op_name, ser.iloc[0], exc=self.series_scalar_exc)
134+
self.check_opname(ser, op_name, ser.iloc[0])
100135

101136
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
102137
# frame & scalar
103138
op_name = all_arithmetic_operators
104139
df = pd.DataFrame({"A": data})
105-
self.check_opname(df, op_name, data[0], exc=self.frame_scalar_exc)
140+
self.check_opname(df, op_name, data[0])
106141

107142
def test_arith_series_with_array(self, data, all_arithmetic_operators):
108143
# ndarray & other series
109144
op_name = all_arithmetic_operators
110145
ser = pd.Series(data)
111-
self.check_opname(
112-
ser, op_name, pd.Series([ser.iloc[0]] * len(ser)), exc=self.series_array_exc
113-
)
146+
self.check_opname(ser, op_name, pd.Series([ser.iloc[0]] * len(ser)))
114147

115148
def test_divmod(self, data):
116149
ser = pd.Series(data)
117-
self._check_divmod_op(ser, divmod, 1, exc=self.divmod_exc)
118-
self._check_divmod_op(1, ops.rdivmod, ser, exc=self.divmod_exc)
150+
self._check_divmod_op(ser, divmod, 1)
151+
self._check_divmod_op(1, ops.rdivmod, ser)
119152

120153
def test_divmod_series_array(self, data, data_for_twos):
121154
ser = pd.Series(data)

pandas/tests/extension/decimal/test_decimal.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import decimal
24
import operator
35

@@ -311,8 +313,14 @@ def test_astype_dispatches(frame):
311313

312314

313315
class TestArithmeticOps(base.BaseArithmeticOpsTests):
314-
def check_opname(self, s, op_name, other, exc=None):
315-
super().check_opname(s, op_name, other, exc=None)
316+
series_scalar_exc = None
317+
frame_scalar_exc = None
318+
series_array_exc = None
319+
320+
def _get_expected_exception(
321+
self, op_name: str, obj, other
322+
) -> type[Exception] | None:
323+
return None
316324

317325
def test_arith_series_with_array(self, data, all_arithmetic_operators):
318326
op_name = all_arithmetic_operators
@@ -336,10 +344,6 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators):
336344
context.traps[decimal.DivisionByZero] = divbyzerotrap
337345
context.traps[decimal.InvalidOperation] = invalidoptrap
338346

339-
def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
340-
# We implement divmod
341-
super()._check_divmod_op(s, op, other, exc=None)
342-
343347

344348
class TestComparisonOps(base.BaseComparisonOpsTests):
345349
def test_compare_scalar(self, data, comparison_op):

pandas/tests/extension/json/test_json.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,6 @@ def test_divmod_series_array(self):
323323
# skipping because it is not implemented
324324
super().test_divmod_series_array()
325325

326-
def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
327-
return super()._check_divmod_op(s, op, other, exc=TypeError)
328-
329326

330327
class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests):
331328
pass

pandas/tests/extension/test_arrow.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
classes (if they are relevant for the extension interface for all dtypes), or
1111
be added to the array-specific tests in `pandas/tests/arrays/`.
1212
"""
13+
from __future__ import annotations
14+
1315
from datetime import (
1416
date,
1517
datetime,
@@ -964,16 +966,26 @@ def _is_temporal_supported(self, opname, pa_dtype):
964966
and pa.types.is_temporal(pa_dtype)
965967
)
966968

967-
def _get_scalar_exception(self, opname, pa_dtype):
968-
arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype)
969-
if opname in {
969+
def _get_expected_exception(
970+
self, op_name: str, obj, other
971+
) -> type[Exception] | None:
972+
if op_name in ("__divmod__", "__rdivmod__"):
973+
return self.divmod_exc
974+
975+
dtype = tm.get_dtype(obj)
976+
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no
977+
# attribute "pyarrow_dtype"
978+
pa_dtype = dtype.pyarrow_dtype # type: ignore[union-attr]
979+
980+
arrow_temporal_supported = self._is_temporal_supported(op_name, pa_dtype)
981+
if op_name in {
970982
"__mod__",
971983
"__rmod__",
972984
}:
973985
exc = NotImplementedError
974986
elif arrow_temporal_supported:
975987
exc = None
976-
elif opname in ["__add__", "__radd__"] and (
988+
elif op_name in ["__add__", "__radd__"] and (
977989
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
978990
):
979991
exc = None
@@ -1060,10 +1072,6 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request)
10601072
):
10611073
pytest.skip("Skip testing Python string formatting")
10621074

1063-
self.series_scalar_exc = self._get_scalar_exception(
1064-
all_arithmetic_operators, pa_dtype
1065-
)
1066-
10671075
mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
10681076
if mark is not None:
10691077
request.node.add_marker(mark)
@@ -1078,10 +1086,6 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
10781086
):
10791087
pytest.skip("Skip testing Python string formatting")
10801088

1081-
self.frame_scalar_exc = self._get_scalar_exception(
1082-
all_arithmetic_operators, pa_dtype
1083-
)
1084-
10851089
mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
10861090
if mark is not None:
10871091
request.node.add_marker(mark)
@@ -1091,10 +1095,6 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
10911095
def test_arith_series_with_array(self, data, all_arithmetic_operators, request):
10921096
pa_dtype = data.dtype.pyarrow_dtype
10931097

1094-
self.series_array_exc = self._get_scalar_exception(
1095-
all_arithmetic_operators, pa_dtype
1096-
)
1097-
10981098
if (
10991099
all_arithmetic_operators
11001100
in (
@@ -1124,7 +1124,7 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators, request):
11241124
# since ser.iloc[0] is a python scalar
11251125
other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype))
11261126

1127-
self.check_opname(ser, op_name, other, exc=self.series_array_exc)
1127+
self.check_opname(ser, op_name, other)
11281128

11291129
def test_add_series_with_extension_array(self, data, request):
11301130
pa_dtype = data.dtype.pyarrow_dtype

pandas/tests/extension/test_boolean.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -122,17 +122,14 @@ class TestMissing(base.BaseMissingTests):
122122
class TestArithmeticOps(base.BaseArithmeticOpsTests):
123123
implements = {"__sub__", "__rsub__"}
124124

125-
def check_opname(self, s, op_name, other, exc=None):
126-
# overwriting to indicate ops don't raise an error
127-
exc = None
125+
def _get_expected_exception(self, op_name, obj, other):
128126
if op_name.strip("_").lstrip("r") in ["pow", "truediv", "floordiv"]:
129127
# match behavior with non-masked bool dtype
130-
exc = NotImplementedError
128+
return NotImplementedError
131129
elif op_name in self.implements:
132130
# exception message would include "numpy boolean subtract""
133-
exc = TypeError
134-
135-
super().check_opname(s, op_name, other, exc=exc)
131+
return TypeError
132+
return None
136133

137134
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
138135
if op_name in (
@@ -170,18 +167,9 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
170167
def test_divmod_series_array(self, data, data_for_twos):
171168
super().test_divmod_series_array(data, data_for_twos)
172169

173-
@pytest.mark.xfail(
174-
reason="Inconsistency between floordiv and divmod; we raise for floordiv "
175-
"but not for divmod. This matches what we do for non-masked bool dtype."
176-
)
177-
def test_divmod(self, data):
178-
super().test_divmod(data)
179-
180170

181171
class TestComparisonOps(base.BaseComparisonOpsTests):
182-
def check_opname(self, s, op_name, other, exc=None):
183-
# overwriting to indicate ops don't raise an error
184-
super().check_opname(s, op_name, other, exc=None)
172+
pass
185173

186174

187175
class TestReshaping(base.BaseReshapingTests):

pandas/tests/extension/test_categorical.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,9 +268,6 @@ def test_divmod_series_array(self):
268268
# skipping because it is not implemented
269269
pass
270270

271-
def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
272-
return super()._check_divmod_op(s, op, other, exc=TypeError)
273-
274271

275272
class TestComparisonOps(base.BaseComparisonOpsTests):
276273
def _compare_other(self, s, data, op, other):

pandas/tests/extension/test_datetime.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -130,22 +130,10 @@ class TestInterface(BaseDatetimeTests, base.BaseInterfaceTests):
130130
class TestArithmeticOps(BaseDatetimeTests, base.BaseArithmeticOpsTests):
131131
implements = {"__sub__", "__rsub__"}
132132

133-
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
134-
# frame & scalar
135-
if all_arithmetic_operators in self.implements:
136-
df = pd.DataFrame({"A": data})
137-
self.check_opname(df, all_arithmetic_operators, data[0], exc=None)
138-
else:
139-
# ... but not the rest.
140-
super().test_arith_frame_with_scalar(data, all_arithmetic_operators)
141-
142-
def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
143-
if all_arithmetic_operators in self.implements:
144-
ser = pd.Series(data)
145-
self.check_opname(ser, all_arithmetic_operators, ser.iloc[0], exc=None)
146-
else:
147-
# ... but not the rest.
148-
super().test_arith_series_with_scalar(data, all_arithmetic_operators)
133+
def _get_expected_exception(self, op_name, obj, other):
134+
if op_name in self.implements:
135+
return None
136+
return super()._get_expected_exception(op_name, obj, other)
149137

150138
def test_add_series_with_extension_array(self, data):
151139
# Datetime + Datetime not implemented
@@ -154,14 +142,6 @@ def test_add_series_with_extension_array(self, data):
154142
with pytest.raises(TypeError, match=msg):
155143
ser + data
156144

157-
def test_arith_series_with_array(self, data, all_arithmetic_operators):
158-
if all_arithmetic_operators in self.implements:
159-
ser = pd.Series(data)
160-
self.check_opname(ser, all_arithmetic_operators, ser.iloc[0], exc=None)
161-
else:
162-
# ... but not the rest.
163-
super().test_arith_series_with_scalar(data, all_arithmetic_operators)
164-
165145
def test_divmod_series_array(self):
166146
# GH 23287
167147
# skipping because it is not implemented

pandas/tests/extension/test_masked_numeric.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -163,21 +163,20 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
163163
expected = expected.astype(sdtype)
164164
return expected
165165

166-
def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
167-
# overwriting to indicate ops don't raise an error
168-
super().check_opname(ser, op_name, other, exc=None)
169-
170-
def _check_divmod_op(self, ser: pd.Series, op, other, exc=None):
171-
super()._check_divmod_op(ser, op, other, None)
166+
series_scalar_exc = None
167+
series_array_exc = None
168+
frame_scalar_exc = None
169+
divmod_exc = None
172170

173171

174172
class TestComparisonOps(base.BaseComparisonOpsTests):
173+
series_scalar_exc = None
174+
series_array_exc = None
175+
frame_scalar_exc = None
176+
175177
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
176178
return pointwise_result.astype("boolean")
177179

178-
def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
179-
super().check_opname(ser, op_name, other, exc=None)
180-
181180
def _compare_other(self, ser: pd.Series, data, op, other):
182181
op_name = f"__{op.__name__}__"
183182
self.check_opname(ser, op_name, other)

pandas/tests/extension/test_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def test_divmod(self, data):
281281
@skip_nested
282282
def test_divmod_series_array(self, data):
283283
ser = pd.Series(data)
284-
self._check_divmod_op(ser, divmod, data, exc=None)
284+
self._check_divmod_op(ser, divmod, data)
285285

286286
@skip_nested
287287
def test_arith_series_with_scalar(self, data, all_arithmetic_operators):

0 commit comments

Comments
 (0)