@@ -65,35 +65,71 @@ def data_for_grouping():
65
65
return DecimalArray ([b , b , na , na , a , a , b , c ])
66
66
67
67
68
- class TestDtype (base .BaseDtypeTests ):
69
- pass
68
+ class TestDecimalArray (base .ExtensionTests ):
69
+ def _get_expected_exception (
70
+ self , op_name : str , obj , other
71
+ ) -> type [Exception ] | None :
72
+ return None
70
73
74
+ def _supports_reduction (self , obj , op_name : str ) -> bool :
75
+ return True
71
76
72
- class TestInterface (base .BaseInterfaceTests ):
73
- pass
77
+ def check_reduce (self , s , op_name , skipna ):
78
+ if op_name == "count" :
79
+ return super ().check_reduce (s , op_name , skipna )
80
+ else :
81
+ result = getattr (s , op_name )(skipna = skipna )
82
+ expected = getattr (np .asarray (s ), op_name )()
83
+ tm .assert_almost_equal (result , expected )
74
84
85
+ def test_reduce_series_numeric (self , data , all_numeric_reductions , skipna , request ):
86
+ if all_numeric_reductions in ["kurt" , "skew" , "sem" , "median" ]:
87
+ mark = pytest .mark .xfail (raises = NotImplementedError )
88
+ request .node .add_marker (mark )
89
+ super ().test_reduce_series_numeric (data , all_numeric_reductions , skipna )
75
90
76
- class TestConstructors (base .BaseConstructorsTests ):
77
- pass
91
+ def test_reduce_frame (self , data , all_numeric_reductions , skipna , request ):
92
+ op_name = all_numeric_reductions
93
+ if op_name in ["skew" , "median" ]:
94
+ mark = pytest .mark .xfail (raises = NotImplementedError )
95
+ request .node .add_marker (mark )
78
96
97
+ return super ().test_reduce_frame (data , all_numeric_reductions , skipna )
79
98
80
- class TestReshaping (base .BaseReshapingTests ):
81
- pass
99
+ def test_compare_scalar (self , data , comparison_op ):
100
+ ser = pd .Series (data )
101
+ self ._compare_other (ser , data , comparison_op , 0.5 )
82
102
103
+ def test_compare_array (self , data , comparison_op ):
104
+ ser = pd .Series (data )
83
105
84
- class TestGetitem (base .BaseGetitemTests ):
85
- def test_take_na_value_other_decimal (self ):
86
- arr = DecimalArray ([decimal .Decimal ("1.0" ), decimal .Decimal ("2.0" )])
87
- result = arr .take ([0 , - 1 ], allow_fill = True , fill_value = decimal .Decimal ("-1.0" ))
88
- expected = DecimalArray ([decimal .Decimal ("1.0" ), decimal .Decimal ("-1.0" )])
89
- tm .assert_extension_array_equal (result , expected )
106
+ alter = np .random .default_rng (2 ).choice ([- 1 , 0 , 1 ], len (data ))
107
+ # Randomly double, halve or keep same value
108
+ other = pd .Series (data ) * [decimal .Decimal (pow (2.0 , i )) for i in alter ]
109
+ self ._compare_other (ser , data , comparison_op , other )
90
110
111
+ def test_arith_series_with_array (self , data , all_arithmetic_operators ):
112
+ op_name = all_arithmetic_operators
113
+ ser = pd .Series (data )
114
+
115
+ context = decimal .getcontext ()
116
+ divbyzerotrap = context .traps [decimal .DivisionByZero ]
117
+ invalidoptrap = context .traps [decimal .InvalidOperation ]
118
+ context .traps [decimal .DivisionByZero ] = 0
119
+ context .traps [decimal .InvalidOperation ] = 0
91
120
92
- class TestIndex (base .BaseIndexTests ):
93
- pass
121
+ # Decimal supports ops with int, but not float
122
+ other = pd .Series ([int (d * 100 ) for d in data ])
123
+ self .check_opname (ser , op_name , other )
124
+
125
+ if "mod" not in op_name :
126
+ self .check_opname (ser , op_name , ser * 2 )
94
127
128
+ self .check_opname (ser , op_name , 0 )
129
+ self .check_opname (ser , op_name , 5 )
130
+ context .traps [decimal .DivisionByZero ] = divbyzerotrap
131
+ context .traps [decimal .InvalidOperation ] = invalidoptrap
95
132
96
- class TestMissing (base .BaseMissingTests ):
97
133
def test_fillna_frame (self , data_missing ):
98
134
msg = "ExtensionArray.fillna added a 'copy' keyword"
99
135
with tm .assert_produces_warning (
@@ -141,59 +177,6 @@ def test_fillna_series_method(self, data_missing, fillna_method):
141
177
):
142
178
super ().test_fillna_series_method (data_missing , fillna_method )
143
179
144
-
145
- class Reduce :
146
- def _supports_reduction (self , obj , op_name : str ) -> bool :
147
- return True
148
-
149
- def check_reduce (self , s , op_name , skipna ):
150
- if op_name == "count" :
151
- return super ().check_reduce (s , op_name , skipna )
152
- else :
153
- result = getattr (s , op_name )(skipna = skipna )
154
- expected = getattr (np .asarray (s ), op_name )()
155
- tm .assert_almost_equal (result , expected )
156
-
157
- def test_reduction_without_keepdims (self ):
158
- # GH52788
159
- # test _reduce without keepdims
160
-
161
- class DecimalArray2 (DecimalArray ):
162
- def _reduce (self , name : str , * , skipna : bool = True , ** kwargs ):
163
- # no keepdims in signature
164
- return super ()._reduce (name , skipna = skipna )
165
-
166
- arr = DecimalArray2 ([decimal .Decimal (2 ) for _ in range (100 )])
167
-
168
- ser = pd .Series (arr )
169
- result = ser .agg ("sum" )
170
- expected = decimal .Decimal (200 )
171
- assert result == expected
172
-
173
- df = pd .DataFrame ({"a" : arr , "b" : arr })
174
- with tm .assert_produces_warning (FutureWarning ):
175
- result = df .agg ("sum" )
176
- expected = pd .Series ({"a" : 200 , "b" : 200 }, dtype = object )
177
- tm .assert_series_equal (result , expected )
178
-
179
-
180
- class TestReduce (Reduce , base .BaseReduceTests ):
181
- def test_reduce_series_numeric (self , data , all_numeric_reductions , skipna , request ):
182
- if all_numeric_reductions in ["kurt" , "skew" , "sem" , "median" ]:
183
- mark = pytest .mark .xfail (raises = NotImplementedError )
184
- request .node .add_marker (mark )
185
- super ().test_reduce_series_numeric (data , all_numeric_reductions , skipna )
186
-
187
- def test_reduce_frame (self , data , all_numeric_reductions , skipna , request ):
188
- op_name = all_numeric_reductions
189
- if op_name in ["skew" , "median" ]:
190
- mark = pytest .mark .xfail (raises = NotImplementedError )
191
- request .node .add_marker (mark )
192
-
193
- return super ().test_reduce_frame (data , all_numeric_reductions , skipna )
194
-
195
-
196
- class TestMethods (base .BaseMethodsTests ):
197
180
def test_fillna_copy_frame (self , data_missing , using_copy_on_write ):
198
181
warn = FutureWarning if not using_copy_on_write else None
199
182
msg = "ExtensionArray.fillna added a 'copy' keyword"
@@ -226,27 +209,31 @@ def test_value_counts(self, all_data, dropna, request):
226
209
227
210
tm .assert_series_equal (result , expected )
228
211
229
-
230
- class TestCasting (base .BaseCastingTests ):
231
- pass
232
-
233
-
234
- class TestGroupby (base .BaseGroupbyTests ):
235
- pass
236
-
237
-
238
- class TestSetitem (base .BaseSetitemTests ):
239
- pass
240
-
241
-
242
- class TestPrinting (base .BasePrintingTests ):
243
212
def test_series_repr (self , data ):
244
213
# Overriding this base test to explicitly test that
245
214
# the custom _formatter is used
246
215
ser = pd .Series (data )
247
216
assert data .dtype .name in repr (ser )
248
217
assert "Decimal: " in repr (ser )
249
218
219
+ @pytest .mark .xfail (
220
+ reason = "Looks like the test (incorrectly) implicitly assumes int/bool dtype"
221
+ )
222
+ def test_invert (self , data ):
223
+ super ().test_invert (data )
224
+
225
+ @pytest .mark .xfail (reason = "Inconsistent array-vs-scalar behavior" )
226
+ @pytest .mark .parametrize ("ufunc" , [np .positive , np .negative , np .abs ])
227
+ def test_unary_ufunc_dunder_equivalence (self , data , ufunc ):
228
+ super ().test_unary_ufunc_dunder_equivalence (data , ufunc )
229
+
230
+
231
+ def test_take_na_value_other_decimal ():
232
+ arr = DecimalArray ([decimal .Decimal ("1.0" ), decimal .Decimal ("2.0" )])
233
+ result = arr .take ([0 , - 1 ], allow_fill = True , fill_value = decimal .Decimal ("-1.0" ))
234
+ expected = DecimalArray ([decimal .Decimal ("1.0" ), decimal .Decimal ("-1.0" )])
235
+ tm .assert_extension_array_equal (result , expected )
236
+
250
237
251
238
def test_series_constructor_coerce_data_to_extension_dtype ():
252
239
dtype = DecimalDtype ()
@@ -305,53 +292,6 @@ def test_astype_dispatches(frame):
305
292
assert result .dtype .context .prec == ctx .prec
306
293
307
294
308
- class TestArithmeticOps (base .BaseArithmeticOpsTests ):
309
- series_scalar_exc = None
310
- frame_scalar_exc = None
311
- series_array_exc = None
312
-
313
- def _get_expected_exception (
314
- self , op_name : str , obj , other
315
- ) -> type [Exception ] | None :
316
- return None
317
-
318
- def test_arith_series_with_array (self , data , all_arithmetic_operators ):
319
- op_name = all_arithmetic_operators
320
- s = pd .Series (data )
321
-
322
- context = decimal .getcontext ()
323
- divbyzerotrap = context .traps [decimal .DivisionByZero ]
324
- invalidoptrap = context .traps [decimal .InvalidOperation ]
325
- context .traps [decimal .DivisionByZero ] = 0
326
- context .traps [decimal .InvalidOperation ] = 0
327
-
328
- # Decimal supports ops with int, but not float
329
- other = pd .Series ([int (d * 100 ) for d in data ])
330
- self .check_opname (s , op_name , other )
331
-
332
- if "mod" not in op_name :
333
- self .check_opname (s , op_name , s * 2 )
334
-
335
- self .check_opname (s , op_name , 0 )
336
- self .check_opname (s , op_name , 5 )
337
- context .traps [decimal .DivisionByZero ] = divbyzerotrap
338
- context .traps [decimal .InvalidOperation ] = invalidoptrap
339
-
340
-
341
- class TestComparisonOps (base .BaseComparisonOpsTests ):
342
- def test_compare_scalar (self , data , comparison_op ):
343
- s = pd .Series (data )
344
- self ._compare_other (s , data , comparison_op , 0.5 )
345
-
346
- def test_compare_array (self , data , comparison_op ):
347
- s = pd .Series (data )
348
-
349
- alter = np .random .default_rng (2 ).choice ([- 1 , 0 , 1 ], len (data ))
350
- # Randomly double, halve or keep same value
351
- other = pd .Series (data ) * [decimal .Decimal (pow (2.0 , i )) for i in alter ]
352
- self ._compare_other (s , data , comparison_op , other )
353
-
354
-
355
295
class DecimalArrayWithoutFromSequence (DecimalArray ):
356
296
"""Helper class for testing error handling in _from_sequence."""
357
297
0 commit comments