Skip to content

Commit fd4597d

Browse files
authored
REF: don patch assert_series_equal, assert_equal (#54345)
1 parent 898e5d0 commit fd4597d

File tree

2 files changed

+28
-30
lines changed

2 files changed

+28
-30
lines changed

pandas/tests/extension/test_arrow.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -873,21 +873,6 @@ def test_basic_equals(self, data):
873873
class TestBaseArithmeticOps(base.BaseArithmeticOpsTests):
874874
divmod_exc = NotImplementedError
875875

876-
@classmethod
877-
def assert_equal(cls, left, right, **kwargs):
878-
if isinstance(left, pd.DataFrame):
879-
left_pa_type = left.iloc[:, 0].dtype.pyarrow_dtype
880-
right_pa_type = right.iloc[:, 0].dtype.pyarrow_dtype
881-
else:
882-
left_pa_type = left.dtype.pyarrow_dtype
883-
right_pa_type = right.dtype.pyarrow_dtype
884-
if pa.types.is_decimal(left_pa_type) or pa.types.is_decimal(right_pa_type):
885-
# decimal precision can resize in the result type depending on data
886-
# just compare the float values
887-
left = left.astype("float[pyarrow]")
888-
right = right.astype("float[pyarrow]")
889-
tm.assert_equal(left, right, **kwargs)
890-
891876
def get_op_from_name(self, op_name):
892877
short_opname = op_name.strip("_")
893878
if short_opname == "rtruediv":
@@ -934,6 +919,29 @@ def _patch_combine(self, obj, other, op):
934919
unit = "us"
935920

936921
pa_expected = pa_expected.cast(f"duration[{unit}]")
922+
923+
elif pa.types.is_decimal(pa_expected.type) and pa.types.is_decimal(
924+
original_dtype.pyarrow_dtype
925+
):
926+
# decimal precision can resize in the result type depending on data
927+
# just compare the float values
928+
alt = op(obj, other)
929+
alt_dtype = tm.get_dtype(alt)
930+
assert isinstance(alt_dtype, ArrowDtype)
931+
if op is operator.pow and isinstance(other, Decimal):
932+
# TODO: would it make more sense to retain Decimal here?
933+
alt_dtype = ArrowDtype(pa.float64())
934+
elif (
935+
op is operator.pow
936+
and isinstance(other, pd.Series)
937+
and other.dtype == original_dtype
938+
):
939+
# TODO: would it make more sense to retain Decimal here?
940+
alt_dtype = ArrowDtype(pa.float64())
941+
else:
942+
assert pa.types.is_decimal(alt_dtype.pyarrow_dtype)
943+
return expected.astype(alt_dtype)
944+
937945
else:
938946
pa_expected = pa_expected.cast(original_dtype.pyarrow_dtype)
939947

@@ -1075,6 +1083,7 @@ def test_arith_series_with_scalar(
10751083
or pa.types.is_duration(pa_dtype)
10761084
or pa.types.is_timestamp(pa_dtype)
10771085
or pa.types.is_date(pa_dtype)
1086+
or pa.types.is_decimal(pa_dtype)
10781087
):
10791088
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
10801089
# not upcast
@@ -1107,6 +1116,7 @@ def test_arith_frame_with_scalar(
11071116
or pa.types.is_duration(pa_dtype)
11081117
or pa.types.is_timestamp(pa_dtype)
11091118
or pa.types.is_date(pa_dtype)
1119+
or pa.types.is_decimal(pa_dtype)
11101120
):
11111121
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
11121122
# not upcast
@@ -1160,6 +1170,7 @@ def test_arith_series_with_array(
11601170
or pa.types.is_duration(pa_dtype)
11611171
or pa.types.is_timestamp(pa_dtype)
11621172
or pa.types.is_date(pa_dtype)
1173+
or pa.types.is_decimal(pa_dtype)
11631174
):
11641175
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
11651176
self.check_opname(ser, op_name, other, exc=self.series_array_exc)

pandas/tests/extension/test_numpy.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,7 @@
1919
import pytest
2020

2121
from pandas.core.dtypes.cast import can_hold_element
22-
from pandas.core.dtypes.dtypes import (
23-
ExtensionDtype,
24-
NumpyEADtype,
25-
)
22+
from pandas.core.dtypes.dtypes import NumpyEADtype
2623

2724
import pandas as pd
2825
import pandas._testing as tm
@@ -176,17 +173,7 @@ def skip_numpy_object(dtype, request):
176173

177174

178175
class BaseNumPyTests:
179-
@classmethod
180-
def assert_series_equal(cls, left, right, *args, **kwargs):
181-
# base class tests hard-code expected values with numpy dtypes,
182-
# whereas we generally want the corresponding NumpyEADtype
183-
if (
184-
isinstance(right, pd.Series)
185-
and not isinstance(right.dtype, ExtensionDtype)
186-
and isinstance(left.dtype, NumpyEADtype)
187-
):
188-
right = right.astype(NumpyEADtype(right.dtype))
189-
return tm.assert_series_equal(left, right, *args, **kwargs)
176+
pass
190177

191178

192179
class TestCasting(BaseNumPyTests, base.BaseCastingTests):

0 commit comments

Comments
 (0)