Skip to content

Commit 5f672dc

Browse files
authored
REF: update decimal tests to TestExtension (#54455)
1 parent bbe11b2 commit 5f672dc

File tree

1 file changed

+71
-131
lines changed

1 file changed

+71
-131
lines changed

pandas/tests/extension/decimal/test_decimal.py

Lines changed: 71 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -65,35 +65,71 @@ def data_for_grouping():
6565
return DecimalArray([b, b, na, na, a, a, b, c])
6666

6767

68-
class TestDtype(base.BaseDtypeTests):
69-
pass
68+
class TestDecimalArray(base.ExtensionTests):
69+
def _get_expected_exception(
70+
self, op_name: str, obj, other
71+
) -> type[Exception] | None:
72+
return None
7073

74+
def _supports_reduction(self, obj, op_name: str) -> bool:
75+
return True
7176

72-
class TestInterface(base.BaseInterfaceTests):
73-
pass
77+
def check_reduce(self, s, op_name, skipna):
78+
if op_name == "count":
79+
return super().check_reduce(s, op_name, skipna)
80+
else:
81+
result = getattr(s, op_name)(skipna=skipna)
82+
expected = getattr(np.asarray(s), op_name)()
83+
tm.assert_almost_equal(result, expected)
7484

85+
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request):
86+
if all_numeric_reductions in ["kurt", "skew", "sem", "median"]:
87+
mark = pytest.mark.xfail(raises=NotImplementedError)
88+
request.node.add_marker(mark)
89+
super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)
7590

76-
class TestConstructors(base.BaseConstructorsTests):
77-
pass
91+
def test_reduce_frame(self, data, all_numeric_reductions, skipna, request):
92+
op_name = all_numeric_reductions
93+
if op_name in ["skew", "median"]:
94+
mark = pytest.mark.xfail(raises=NotImplementedError)
95+
request.node.add_marker(mark)
7896

97+
return super().test_reduce_frame(data, all_numeric_reductions, skipna)
7998

80-
class TestReshaping(base.BaseReshapingTests):
81-
pass
99+
def test_compare_scalar(self, data, comparison_op):
100+
ser = pd.Series(data)
101+
self._compare_other(ser, data, comparison_op, 0.5)
82102

103+
def test_compare_array(self, data, comparison_op):
104+
ser = pd.Series(data)
83105

84-
class TestGetitem(base.BaseGetitemTests):
85-
def test_take_na_value_other_decimal(self):
86-
arr = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
87-
result = arr.take([0, -1], allow_fill=True, fill_value=decimal.Decimal("-1.0"))
88-
expected = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("-1.0")])
89-
tm.assert_extension_array_equal(result, expected)
106+
alter = np.random.default_rng(2).choice([-1, 0, 1], len(data))
107+
# Randomly double, halve or keep same value
108+
other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) for i in alter]
109+
self._compare_other(ser, data, comparison_op, other)
90110

111+
def test_arith_series_with_array(self, data, all_arithmetic_operators):
112+
op_name = all_arithmetic_operators
113+
ser = pd.Series(data)
114+
115+
context = decimal.getcontext()
116+
divbyzerotrap = context.traps[decimal.DivisionByZero]
117+
invalidoptrap = context.traps[decimal.InvalidOperation]
118+
context.traps[decimal.DivisionByZero] = 0
119+
context.traps[decimal.InvalidOperation] = 0
91120

92-
class TestIndex(base.BaseIndexTests):
93-
pass
121+
# Decimal supports ops with int, but not float
122+
other = pd.Series([int(d * 100) for d in data])
123+
self.check_opname(ser, op_name, other)
124+
125+
if "mod" not in op_name:
126+
self.check_opname(ser, op_name, ser * 2)
94127

128+
self.check_opname(ser, op_name, 0)
129+
self.check_opname(ser, op_name, 5)
130+
context.traps[decimal.DivisionByZero] = divbyzerotrap
131+
context.traps[decimal.InvalidOperation] = invalidoptrap
95132

96-
class TestMissing(base.BaseMissingTests):
97133
def test_fillna_frame(self, data_missing):
98134
msg = "ExtensionArray.fillna added a 'copy' keyword"
99135
with tm.assert_produces_warning(
@@ -141,59 +177,6 @@ def test_fillna_series_method(self, data_missing, fillna_method):
141177
):
142178
super().test_fillna_series_method(data_missing, fillna_method)
143179

144-
145-
class Reduce:
146-
def _supports_reduction(self, obj, op_name: str) -> bool:
147-
return True
148-
149-
def check_reduce(self, s, op_name, skipna):
150-
if op_name == "count":
151-
return super().check_reduce(s, op_name, skipna)
152-
else:
153-
result = getattr(s, op_name)(skipna=skipna)
154-
expected = getattr(np.asarray(s), op_name)()
155-
tm.assert_almost_equal(result, expected)
156-
157-
def test_reduction_without_keepdims(self):
158-
# GH52788
159-
# test _reduce without keepdims
160-
161-
class DecimalArray2(DecimalArray):
162-
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
163-
# no keepdims in signature
164-
return super()._reduce(name, skipna=skipna)
165-
166-
arr = DecimalArray2([decimal.Decimal(2) for _ in range(100)])
167-
168-
ser = pd.Series(arr)
169-
result = ser.agg("sum")
170-
expected = decimal.Decimal(200)
171-
assert result == expected
172-
173-
df = pd.DataFrame({"a": arr, "b": arr})
174-
with tm.assert_produces_warning(FutureWarning):
175-
result = df.agg("sum")
176-
expected = pd.Series({"a": 200, "b": 200}, dtype=object)
177-
tm.assert_series_equal(result, expected)
178-
179-
180-
class TestReduce(Reduce, base.BaseReduceTests):
181-
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request):
182-
if all_numeric_reductions in ["kurt", "skew", "sem", "median"]:
183-
mark = pytest.mark.xfail(raises=NotImplementedError)
184-
request.node.add_marker(mark)
185-
super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)
186-
187-
def test_reduce_frame(self, data, all_numeric_reductions, skipna, request):
188-
op_name = all_numeric_reductions
189-
if op_name in ["skew", "median"]:
190-
mark = pytest.mark.xfail(raises=NotImplementedError)
191-
request.node.add_marker(mark)
192-
193-
return super().test_reduce_frame(data, all_numeric_reductions, skipna)
194-
195-
196-
class TestMethods(base.BaseMethodsTests):
197180
def test_fillna_copy_frame(self, data_missing, using_copy_on_write):
198181
warn = FutureWarning if not using_copy_on_write else None
199182
msg = "ExtensionArray.fillna added a 'copy' keyword"
@@ -226,27 +209,31 @@ def test_value_counts(self, all_data, dropna, request):
226209

227210
tm.assert_series_equal(result, expected)
228211

229-
230-
class TestCasting(base.BaseCastingTests):
231-
pass
232-
233-
234-
class TestGroupby(base.BaseGroupbyTests):
235-
pass
236-
237-
238-
class TestSetitem(base.BaseSetitemTests):
239-
pass
240-
241-
242-
class TestPrinting(base.BasePrintingTests):
243212
def test_series_repr(self, data):
244213
# Overriding this base test to explicitly test that
245214
# the custom _formatter is used
246215
ser = pd.Series(data)
247216
assert data.dtype.name in repr(ser)
248217
assert "Decimal: " in repr(ser)
249218

219+
@pytest.mark.xfail(
220+
reason="Looks like the test (incorrectly) implicitly assumes int/bool dtype"
221+
)
222+
def test_invert(self, data):
223+
super().test_invert(data)
224+
225+
@pytest.mark.xfail(reason="Inconsistent array-vs-scalar behavior")
226+
@pytest.mark.parametrize("ufunc", [np.positive, np.negative, np.abs])
227+
def test_unary_ufunc_dunder_equivalence(self, data, ufunc):
228+
super().test_unary_ufunc_dunder_equivalence(data, ufunc)
229+
230+
231+
def test_take_na_value_other_decimal():
232+
arr = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
233+
result = arr.take([0, -1], allow_fill=True, fill_value=decimal.Decimal("-1.0"))
234+
expected = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("-1.0")])
235+
tm.assert_extension_array_equal(result, expected)
236+
250237

251238
def test_series_constructor_coerce_data_to_extension_dtype():
252239
dtype = DecimalDtype()
@@ -305,53 +292,6 @@ def test_astype_dispatches(frame):
305292
assert result.dtype.context.prec == ctx.prec
306293

307294

308-
class TestArithmeticOps(base.BaseArithmeticOpsTests):
309-
series_scalar_exc = None
310-
frame_scalar_exc = None
311-
series_array_exc = None
312-
313-
def _get_expected_exception(
314-
self, op_name: str, obj, other
315-
) -> type[Exception] | None:
316-
return None
317-
318-
def test_arith_series_with_array(self, data, all_arithmetic_operators):
319-
op_name = all_arithmetic_operators
320-
s = pd.Series(data)
321-
322-
context = decimal.getcontext()
323-
divbyzerotrap = context.traps[decimal.DivisionByZero]
324-
invalidoptrap = context.traps[decimal.InvalidOperation]
325-
context.traps[decimal.DivisionByZero] = 0
326-
context.traps[decimal.InvalidOperation] = 0
327-
328-
# Decimal supports ops with int, but not float
329-
other = pd.Series([int(d * 100) for d in data])
330-
self.check_opname(s, op_name, other)
331-
332-
if "mod" not in op_name:
333-
self.check_opname(s, op_name, s * 2)
334-
335-
self.check_opname(s, op_name, 0)
336-
self.check_opname(s, op_name, 5)
337-
context.traps[decimal.DivisionByZero] = divbyzerotrap
338-
context.traps[decimal.InvalidOperation] = invalidoptrap
339-
340-
341-
class TestComparisonOps(base.BaseComparisonOpsTests):
342-
def test_compare_scalar(self, data, comparison_op):
343-
s = pd.Series(data)
344-
self._compare_other(s, data, comparison_op, 0.5)
345-
346-
def test_compare_array(self, data, comparison_op):
347-
s = pd.Series(data)
348-
349-
alter = np.random.default_rng(2).choice([-1, 0, 1], len(data))
350-
# Randomly double, halve or keep same value
351-
other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) for i in alter]
352-
self._compare_other(s, data, comparison_op, other)
353-
354-
355295
class DecimalArrayWithoutFromSequence(DecimalArray):
356296
"""Helper class for testing error handling in _from_sequence."""
357297

0 commit comments

Comments
 (0)