Skip to content

Commit fb6f704

Browse files
authored
REF: implement BaseOpsUtil._cast_pointwise_result (#54366)
* REF: implement BaseOpsUtil._cast_pointwise_result * REF: use _cast_pointwise_result in arrow tests
1 parent 7cbf949 commit fb6f704

File tree

4 files changed

+75
-87
lines changed

4 files changed

+75
-87
lines changed

pandas/tests/extension/base/ops.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import final
4+
35
import numpy as np
46
import pytest
57

@@ -10,6 +12,15 @@
1012

1113

1214
class BaseOpsUtil(BaseExtensionTests):
15+
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
16+
# In _check_op we check that the result of a pointwise operation
17+
# (found via _combine) matches the result of the vectorized
18+
# operation obj.__op_name__(other).
19+
# In some cases pandas dtype inference on the scalar result may not
20+
# give a matching dtype even if both operations are behaving "correctly".
21+
# In these cases, do extra required casting here.
22+
return pointwise_result
23+
1324
def get_op_from_name(self, op_name: str):
1425
return tm.get_op_from_name(op_name)
1526

@@ -18,6 +29,12 @@ def check_opname(self, ser: pd.Series, op_name: str, other, exc=Exception):
1829

1930
self._check_op(ser, op, other, op_name, exc)
2031

32+
# Subclasses are not expected to need to override _check_op or _combine.
33+
# Ideally any relevant overriding can be done in _cast_pointwise_result,
34+
# get_op_from_name, and the specification of `exc`. If you find a use
35+
# case that still requires overriding _check_op or _combine, please let
36+
# us know at github.com/pandas-dev/pandas/issues
37+
@final
2138
def _combine(self, obj, other, op):
2239
if isinstance(obj, pd.DataFrame):
2340
if len(obj.columns) != 1:
@@ -27,12 +44,15 @@ def _combine(self, obj, other, op):
2744
expected = obj.combine(other, op)
2845
return expected
2946

47+
# see comment on _combine
48+
@final
3049
def _check_op(
3150
self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError
3251
):
3352
if exc is None:
3453
result = op(ser, other)
3554
expected = self._combine(ser, other, op)
55+
expected = self._cast_pointwise_result(op_name, ser, other, expected)
3656
assert isinstance(result, type(ser))
3757
tm.assert_equal(result, expected)
3858
else:

pandas/tests/extension/test_arrow.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -873,11 +873,11 @@ def rtruediv(x, y):
873873

874874
return tm.get_op_from_name(op_name)
875875

876-
def _combine(self, obj, other, op):
876+
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
877877
# BaseOpsUtil._combine can upcast expected dtype
878878
# (because it generates expected on python scalars)
879879
# while ArrowExtensionArray maintains original type
880-
expected = base.BaseArithmeticOpsTests._combine(self, obj, other, op)
880+
expected = pointwise_result
881881

882882
was_frame = False
883883
if isinstance(expected, pd.DataFrame):
@@ -895,7 +895,7 @@ def _combine(self, obj, other, op):
895895
pa.types.is_floating(orig_pa_type)
896896
or (
897897
pa.types.is_integer(orig_pa_type)
898-
and op.__name__ not in ["truediv", "rtruediv"]
898+
and op_name not in ["__truediv__", "__rtruediv__"]
899899
)
900900
or pa.types.is_duration(orig_pa_type)
901901
or pa.types.is_timestamp(orig_pa_type)
@@ -906,7 +906,7 @@ def _combine(self, obj, other, op):
906906
# ArrowExtensionArray does not upcast
907907
return expected
908908
elif not (
909-
(op is operator.floordiv and pa.types.is_integer(orig_pa_type))
909+
(op_name == "__floordiv__" and pa.types.is_integer(orig_pa_type))
910910
or pa.types.is_duration(orig_pa_type)
911911
or pa.types.is_timestamp(orig_pa_type)
912912
or pa.types.is_date(orig_pa_type)
@@ -943,14 +943,14 @@ def _combine(self, obj, other, op):
943943
):
944944
# decimal precision can resize in the result type depending on data
945945
# just compare the float values
946-
alt = op(obj, other)
946+
alt = getattr(obj, op_name)(other)
947947
alt_dtype = tm.get_dtype(alt)
948948
assert isinstance(alt_dtype, ArrowDtype)
949-
if op is operator.pow and isinstance(other, Decimal):
949+
if op_name == "__pow__" and isinstance(other, Decimal):
950950
# TODO: would it make more sense to retain Decimal here?
951951
alt_dtype = ArrowDtype(pa.float64())
952952
elif (
953-
op is operator.pow
953+
op_name == "__pow__"
954954
and isinstance(other, pd.Series)
955955
and other.dtype == original_dtype
956956
):

pandas/tests/extension/test_boolean.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
be added to the array-specific tests in `pandas/tests/arrays/`.
1414
1515
"""
16+
import operator
17+
1618
import numpy as np
1719
import pytest
1820

@@ -23,6 +25,7 @@
2325

2426
import pandas as pd
2527
import pandas._testing as tm
28+
from pandas.core import roperator
2629
from pandas.core.arrays.boolean import BooleanDtype
2730
from pandas.tests.extension import base
2831

@@ -125,41 +128,40 @@ def check_opname(self, s, op_name, other, exc=None):
125128
if op_name.strip("_").lstrip("r") in ["pow", "truediv", "floordiv"]:
126129
# match behavior with non-masked bool dtype
127130
exc = NotImplementedError
131+
elif op_name in self.implements:
132+
# exception message would include "numpy boolean subtract""
133+
exc = TypeError
134+
128135
super().check_opname(s, op_name, other, exc=exc)
129136

130-
def _check_op(self, obj, op, other, op_name, exc=NotImplementedError):
131-
if exc is None:
132-
if op_name in self.implements:
133-
msg = r"numpy boolean subtract"
134-
with pytest.raises(TypeError, match=msg):
135-
op(obj, other)
136-
return
137-
138-
result = op(obj, other)
139-
expected = self._combine(obj, other, op)
140-
141-
if op_name in (
142-
"__floordiv__",
143-
"__rfloordiv__",
144-
"__pow__",
145-
"__rpow__",
146-
"__mod__",
147-
"__rmod__",
148-
):
149-
# combine keeps boolean type
150-
expected = expected.astype("Int8")
151-
elif op_name in ("__truediv__", "__rtruediv__"):
152-
# combine with bools does not generate the correct result
153-
# (numpy behaviour for div is to regard the bools as numeric)
154-
expected = self._combine(obj.astype(float), other, op)
155-
expected = expected.astype("Float64")
156-
if op_name == "__rpow__":
157-
# for rpow, combine does not propagate NaN
158-
expected[result.isna()] = np.nan
159-
tm.assert_equal(result, expected)
160-
else:
161-
with pytest.raises(exc):
162-
op(obj, other)
137+
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
138+
if op_name in (
139+
"__floordiv__",
140+
"__rfloordiv__",
141+
"__pow__",
142+
"__rpow__",
143+
"__mod__",
144+
"__rmod__",
145+
):
146+
# combine keeps boolean type
147+
pointwise_result = pointwise_result.astype("Int8")
148+
149+
elif op_name in ("__truediv__", "__rtruediv__"):
150+
# combine with bools does not generate the correct result
151+
# (numpy behaviour for div is to regard the bools as numeric)
152+
if op_name == "__truediv__":
153+
op = operator.truediv
154+
else:
155+
op = roperator.rtruediv
156+
pointwise_result = self._combine(obj.astype(float), other, op)
157+
pointwise_result = pointwise_result.astype("Float64")
158+
159+
if op_name == "__rpow__":
160+
# for rpow, combine does not propagate NaN
161+
result = getattr(obj, op_name)(other)
162+
pointwise_result[result.isna()] = np.nan
163+
164+
return pointwise_result
163165

164166
@pytest.mark.xfail(
165167
reason="Inconsistency between floordiv and divmod; we raise for floordiv "

pandas/tests/extension/test_masked_numeric.py

Lines changed: 13 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -146,45 +146,20 @@ class TestDtype(base.BaseDtypeTests):
146146

147147

148148
class TestArithmeticOps(base.BaseArithmeticOpsTests):
149-
def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
150-
if exc is None:
151-
sdtype = tm.get_dtype(s)
152-
153-
if hasattr(other, "dtype") and isinstance(other.dtype, np.dtype):
154-
if sdtype.kind == "f":
155-
if other.dtype.kind == "f":
156-
# other is np.float64 and would therefore always result
157-
# in upcasting, so keeping other as same numpy_dtype
158-
other = other.astype(sdtype.numpy_dtype)
159-
160-
else:
161-
# i.e. sdtype.kind in "iu""
162-
if other.dtype.kind in "iu" and sdtype.is_unsigned_integer:
163-
# TODO: comment below is inaccurate; other can be int8
164-
# int16, ...
165-
# and the trouble is that e.g. if s is UInt8 and other
166-
# is int8, then result is UInt16
167-
# other is np.int64 and would therefore always result in
168-
# upcasting, so keeping other as same numpy_dtype
169-
other = other.astype(sdtype.numpy_dtype)
170-
171-
result = op(s, other)
172-
expected = self._combine(s, other, op)
173-
174-
if sdtype.kind in "iu":
175-
if op_name in ("__rtruediv__", "__truediv__", "__div__"):
176-
expected = expected.fillna(np.nan).astype("Float64")
177-
else:
178-
# combine method result in 'biggest' (int64) dtype
179-
expected = expected.astype(sdtype)
149+
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
150+
sdtype = tm.get_dtype(obj)
151+
expected = pointwise_result
152+
153+
if sdtype.kind in "iu":
154+
if op_name in ("__rtruediv__", "__truediv__", "__div__"):
155+
expected = expected.fillna(np.nan).astype("Float64")
180156
else:
181-
# combine method result in 'biggest' (float64) dtype
157+
# combine method result in 'biggest' (int64) dtype
182158
expected = expected.astype(sdtype)
183-
184-
tm.assert_equal(result, expected)
185159
else:
186-
with pytest.raises(exc):
187-
op(s, other)
160+
# combine method result in 'biggest' (float64) dtype
161+
expected = expected.astype(sdtype)
162+
return expected
188163

189164
def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
190165
# overwriting to indicate ops don't raise an error
@@ -195,17 +170,8 @@ def _check_divmod_op(self, ser: pd.Series, op, other, exc=None):
195170

196171

197172
class TestComparisonOps(base.BaseComparisonOpsTests):
198-
def _check_op(
199-
self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError
200-
):
201-
if exc is None:
202-
result = op(ser, other)
203-
# Override to do the astype to boolean
204-
expected = ser.combine(other, op).astype("boolean")
205-
tm.assert_series_equal(result, expected)
206-
else:
207-
with pytest.raises(exc):
208-
op(ser, other)
173+
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
174+
return pointwise_result.astype("boolean")
209175

210176
def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
211177
super().check_opname(ser, op_name, other, exc=None)

0 commit comments

Comments
 (0)