@@ -873,21 +873,6 @@ def test_basic_equals(self, data):
873
873
class TestBaseArithmeticOps (base .BaseArithmeticOpsTests ):
874
874
divmod_exc = NotImplementedError
875
875
876
- @classmethod
877
- def assert_equal (cls , left , right , ** kwargs ):
878
- if isinstance (left , pd .DataFrame ):
879
- left_pa_type = left .iloc [:, 0 ].dtype .pyarrow_dtype
880
- right_pa_type = right .iloc [:, 0 ].dtype .pyarrow_dtype
881
- else :
882
- left_pa_type = left .dtype .pyarrow_dtype
883
- right_pa_type = right .dtype .pyarrow_dtype
884
- if pa .types .is_decimal (left_pa_type ) or pa .types .is_decimal (right_pa_type ):
885
- # decimal precision can resize in the result type depending on data
886
- # just compare the float values
887
- left = left .astype ("float[pyarrow]" )
888
- right = right .astype ("float[pyarrow]" )
889
- tm .assert_equal (left , right , ** kwargs )
890
-
891
876
def get_op_from_name (self , op_name ):
892
877
short_opname = op_name .strip ("_" )
893
878
if short_opname == "rtruediv" :
@@ -934,6 +919,29 @@ def _patch_combine(self, obj, other, op):
934
919
unit = "us"
935
920
936
921
pa_expected = pa_expected .cast (f"duration[{ unit } ]" )
922
+
923
+ elif pa .types .is_decimal (pa_expected .type ) and pa .types .is_decimal (
924
+ original_dtype .pyarrow_dtype
925
+ ):
926
+ # decimal precision can resize in the result type depending on data
927
+ # just compare the float values
928
+ alt = op (obj , other )
929
+ alt_dtype = tm .get_dtype (alt )
930
+ assert isinstance (alt_dtype , ArrowDtype )
931
+ if op is operator .pow and isinstance (other , Decimal ):
932
+ # TODO: would it make more sense to retain Decimal here?
933
+ alt_dtype = ArrowDtype (pa .float64 ())
934
+ elif (
935
+ op is operator .pow
936
+ and isinstance (other , pd .Series )
937
+ and other .dtype == original_dtype
938
+ ):
939
+ # TODO: would it make more sense to retain Decimal here?
940
+ alt_dtype = ArrowDtype (pa .float64 ())
941
+ else :
942
+ assert pa .types .is_decimal (alt_dtype .pyarrow_dtype )
943
+ return expected .astype (alt_dtype )
944
+
937
945
else :
938
946
pa_expected = pa_expected .cast (original_dtype .pyarrow_dtype )
939
947
@@ -1075,6 +1083,7 @@ def test_arith_series_with_scalar(
1075
1083
or pa .types .is_duration (pa_dtype )
1076
1084
or pa .types .is_timestamp (pa_dtype )
1077
1085
or pa .types .is_date (pa_dtype )
1086
+ or pa .types .is_decimal (pa_dtype )
1078
1087
):
1079
1088
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
1080
1089
# not upcast
@@ -1107,6 +1116,7 @@ def test_arith_frame_with_scalar(
1107
1116
or pa .types .is_duration (pa_dtype )
1108
1117
or pa .types .is_timestamp (pa_dtype )
1109
1118
or pa .types .is_date (pa_dtype )
1119
+ or pa .types .is_decimal (pa_dtype )
1110
1120
):
1111
1121
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
1112
1122
# not upcast
@@ -1160,6 +1170,7 @@ def test_arith_series_with_array(
1160
1170
or pa .types .is_duration (pa_dtype )
1161
1171
or pa .types .is_timestamp (pa_dtype )
1162
1172
or pa .types .is_date (pa_dtype )
1173
+ or pa .types .is_decimal (pa_dtype )
1163
1174
):
1164
1175
monkeypatch .setattr (TestBaseArithmeticOps , "_combine" , self ._patch_combine )
1165
1176
self .check_opname (ser , op_name , other , exc = self .series_array_exc )
0 commit comments