Skip to content

Commit ae6a335

Browse files
authored
TST: fix Decimal constructor xfail (#54338)
* TST: fix Decimal constructor xfail * mypy fixup
1 parent 55ec5e7 commit ae6a335

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

pandas/tests/extension/decimal/array.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pandas.core.dtypes.common import (
1313
is_dtype_equal,
1414
is_float,
15+
is_integer,
1516
pandas_dtype,
1617
)
1718

@@ -71,11 +72,14 @@ class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray):
7172

7273
def __init__(self, values, dtype=None, copy=False, context=None) -> None:
7374
for i, val in enumerate(values):
74-
if is_float(val):
75+
if is_float(val) or is_integer(val):
7576
if np.isnan(val):
7677
values[i] = DecimalDtype.na_value
7778
else:
78-
values[i] = DecimalDtype.type(val)
79+
# error: Argument 1 has incompatible type "float | int |
80+
# integer[Any]"; expected "Decimal | float | str | tuple[int,
81+
# Sequence[int], int]"
82+
values[i] = DecimalDtype.type(val) # type: ignore[arg-type]
7983
elif not isinstance(val, decimal.Decimal):
8084
raise TypeError("All values must be of type " + str(decimal.Decimal))
8185
values = np.asarray(values, dtype=object)

pandas/tests/extension/decimal/test_decimal.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -267,20 +267,16 @@ def test_series_repr(self, data):
267267
assert "Decimal: " in repr(ser)
268268

269269

270-
@pytest.mark.xfail(
271-
reason=(
272-
"DecimalArray constructor raises bc _from_sequence wants Decimals, not ints."
273-
"Easy to fix, just need to do it."
274-
),
275-
raises=TypeError,
276-
)
277-
def test_series_constructor_coerce_data_to_extension_dtype_raises():
278-
xpr = (
279-
"Cannot cast data to extension dtype 'decimal'. Pass the "
280-
"extension array directly."
270+
def test_series_constructor_coerce_data_to_extension_dtype():
271+
dtype = DecimalDtype()
272+
ser = pd.Series([0, 1, 2], dtype=dtype)
273+
274+
arr = DecimalArray(
275+
[decimal.Decimal(0), decimal.Decimal(1), decimal.Decimal(2)],
276+
dtype=dtype,
281277
)
282-
with pytest.raises(ValueError, match=xpr):
283-
pd.Series([0, 1, 2], dtype=DecimalDtype())
278+
exp = pd.Series(arr)
279+
tm.assert_series_equal(ser, exp)
284280

285281

286282
def test_series_constructor_with_dtype():

0 commit comments

Comments
 (0)