diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index cf12759c051fc..65494c7a789e2 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -446,7 +446,7 @@ ExtensionType Changes - Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`) - :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`) - :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`) -- +- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185`). .. _whatsnew_0240.api.incompatibilities: diff --git a/pandas/core/arrays/integer.py b/pandas/core/arrays/integer.py index c126117060c3d..eef6a756e2bc9 100644 --- a/pandas/core/arrays/integer.py +++ b/pandas/core/arrays/integer.py @@ -8,6 +8,7 @@ from pandas.compat import u, range from pandas.compat import set_function_name +from pandas.core.dtypes.cast import astype_nansafe from pandas.core.dtypes.generic import ABCSeries, ABCIndexClass from pandas.core.dtypes.common import ( is_integer, is_scalar, is_float, @@ -391,7 +392,7 @@ def astype(self, dtype, copy=True): # coerce data = self._coerce_to_ndarray() - return data.astype(dtype=dtype, copy=False) + return astype_nansafe(data, dtype, copy=None) @property def _ndarray_values(self): diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 3971e90e64a14..410e061c895db 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -647,7 +647,16 @@ def conv(r, dtype): def astype_nansafe(arr, dtype, copy=True): """ return a view if copy is False, but - need to be very careful as the result shape could change! """ + need to be very careful as the result shape could change! 
+ + Parameters + ---------- + arr : ndarray + dtype : np.dtype + copy : bool, default True + If False, a view will be attempted but may fail, if + e.g. the itemsizes don't align. + """ # dispatch on extension dtype if needed if is_extension_array_dtype(dtype): @@ -733,8 +742,10 @@ def astype_nansafe(arr, dtype, copy=True): FutureWarning, stacklevel=5) dtype = np.dtype(dtype.name + "[ns]") - if copy: + if copy or is_object_dtype(arr) or is_object_dtype(dtype): + # Explicit copy, or required since NumPy can't view from / to object. return arr.astype(dtype, copy=True) + return arr.view(dtype) diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index f0635014b166b..0bfc7650a24aa 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -637,22 +637,25 @@ def _astype(self, dtype, copy=False, errors='raise', values=None, # force the copy here if values is None: - if issubclass(dtype.type, - (compat.text_type, compat.string_types)): + if self.is_extension: + values = self.values.astype(dtype) + else: + if issubclass(dtype.type, + (compat.text_type, compat.string_types)): - # use native type formatting for datetime/tz/timedelta - if self.is_datelike: - values = self.to_native_types() + # use native type formatting for datetime/tz/timedelta + if self.is_datelike: + values = self.to_native_types() - # astype formatting - else: - values = self.get_values() + # astype formatting + else: + values = self.get_values() - else: - values = self.get_values(dtype=dtype) + else: + values = self.get_values(dtype=dtype) - # _astype_nansafe works fine with 1-d only - values = astype_nansafe(values.ravel(), dtype, copy=True) + # _astype_nansafe works fine with 1-d only + values = astype_nansafe(values.ravel(), dtype, copy=True) # TODO(extension) # should we make this attribute? 
diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 108b8874b3ac5..f3475dead2418 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -15,6 +15,17 @@ class DecimalDtype(ExtensionDtype): name = 'decimal' na_value = decimal.Decimal('NaN') + def __init__(self, context=None): + self.context = context or decimal.getcontext() + + def __eq__(self, other): + if isinstance(other, type(self)): + return self.context == other.context + return super(DecimalDtype, self).__eq__(other) + + def __repr__(self): + return 'DecimalDtype(context={})'.format(self.context) + @classmethod def construct_array_type(cls): """Return the array type associated with this dtype @@ -35,13 +46,12 @@ def construct_from_string(cls, string): class DecimalArray(ExtensionArray, ExtensionScalarOpsMixin): - dtype = DecimalDtype() - def __init__(self, values, dtype=None, copy=False): + def __init__(self, values, dtype=None, copy=False, context=None): for val in values: - if not isinstance(val, self.dtype.type): + if not isinstance(val, decimal.Decimal): raise TypeError("All values must be of type " + - str(self.dtype.type)) + str(decimal.Decimal)) values = np.asarray(values, dtype=object) self._data = values @@ -51,6 +61,11 @@ def __init__(self, values, dtype=None, copy=False): # those aliases are currently not working due to assumptions # in internal code (GH-20735) # self._values = self.values = self.data + self._dtype = DecimalDtype(context) + + @property + def dtype(self): + return self._dtype @classmethod def _from_sequence(cls, scalars, dtype=None, copy=False): @@ -82,6 +97,11 @@ def copy(self, deep=False): return type(self)(self._data.copy()) return type(self)(self) + def astype(self, dtype, copy=True): + if isinstance(dtype, type(self.dtype)): + return type(self)(self._data, context=dtype.context) + return super(DecimalArray, self).astype(dtype, copy) + def __setitem__(self, key, value): if 
pd.api.types.is_list_like(value): value = [decimal.Decimal(v) for v in value] diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index bc7237f263b1d..04e855242b5e6 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -205,6 +205,27 @@ def test_dataframe_constructor_with_dtype(): tm.assert_frame_equal(result, expected) +@pytest.mark.parametrize("frame", [True, False]) +def test_astype_dispatches(frame): + # This is a dtype-specific test that ensures Series[decimal].astype + # gets all the way through to ExtensionArray.astype + # Designing a reliable smoke test that works for arbitrary data types + # is difficult. + data = pd.Series(DecimalArray([decimal.Decimal(2)]), name='a') + ctx = decimal.Context() + ctx.prec = 5 + + if frame: + data = data.to_frame() + + result = data.astype(DecimalDtype(ctx)) + + if frame: + result = result['a'] + + assert result.dtype.context.prec == ctx.prec + + class TestArithmeticOps(BaseDecimal, base.BaseArithmeticOpsTests): def check_opname(self, s, op_name, other, exc=None): diff --git a/pandas/tests/extension/integer/test_integer.py b/pandas/tests/extension/integer/test_integer.py index 5e0f5bf0a5dcf..55451da8f7eed 100644 --- a/pandas/tests/extension/integer/test_integer.py +++ b/pandas/tests/extension/integer/test_integer.py @@ -697,6 +697,15 @@ def test_cross_type_arithmetic(): tm.assert_series_equal(result, expected) +def test_astype_nansafe(): + # https://github.com/pandas-dev/pandas/pull/22343 + arr = IntegerArray([np.nan, 1, 2], dtype="Int8") + + with tm.assert_raises_regex( + ValueError, 'cannot convert float NaN to integer'): + arr.astype('uint32') + + # TODO(jreback) - these need testing / are broken # shift