From a7ba8f6e7ee1861238e386860d8d56ed0560c1ba Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Mon, 13 Aug 2018 16:06:23 -0500 Subject: [PATCH 1/9] API: dispatch to EA.astype Closes #21185 --- doc/source/whatsnew/v0.24.0.txt | 2 +- pandas/core/arrays/integer.py | 3 +- pandas/core/dtypes/cast.py | 23 +++++++++++++-- pandas/core/internals/blocks.py | 27 ++++++++++-------- pandas/tests/extension/decimal/array.py | 28 ++++++++++++++++--- .../tests/extension/decimal/test_decimal.py | 18 ++++++++++++ 6 files changed, 81 insertions(+), 20 deletions(-) diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index 3ebdf853a9c64..b877076a327df 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -446,7 +446,7 @@ ExtensionType Changes - Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`) - :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`) - :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`) -- +- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185:`). .. _whatsnew_0240.api.incompatibilities: diff --git a/pandas/core/arrays/integer.py b/pandas/core/arrays/integer.py index c126117060c3d..eef6a756e2bc9 100644 --- a/pandas/core/arrays/integer.py +++ b/pandas/core/arrays/integer.py @@ -8,6 +8,7 @@ from pandas.compat import u, range from pandas.compat import set_function_name +from pandas.core.dtypes.cast import astype_nansafe from pandas.core.dtypes.generic import ABCSeries, ABCIndexClass from pandas.core.dtypes.common import ( is_integer, is_scalar, is_float, @@ -391,7 +392,7 @@ def astype(self, dtype, copy=True): # coerce data = self._coerce_to_ndarray() - return data.astype(dtype=dtype, copy=False) + return astype_nansafe(data, dtype, copy=None) @property def _ndarray_values(self): diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 3971e90e64a14..cf89c2be2fe98 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -647,7 +647,17 @@ def conv(r, dtype): def astype_nansafe(arr, dtype, copy=True): """ return a view if copy is False, but - need to be very careful as the result shape could change! """ + need to be very careful as the result shape could change! + + Parameters + ---------- + arr : ndarray + dtype : np.dtype + copy : bool or None, default True + Whether to copy during the `.astype` (True) or + just return a view (False). Passing `copy=None` will + attempt to return a view, but will copy if necessary. + """ # dispatch on extension dtype if needed if is_extension_array_dtype(dtype): @@ -735,7 +745,16 @@ def astype_nansafe(arr, dtype, copy=True): if copy: return arr.astype(dtype, copy=True) - return arr.view(dtype) + else: + try: + return arr.view(dtype) + except TypeError: + if copy is None: + # allowed to copy if necessary (e.g. object) + return arr.astype(dtype, copy=True) + else: + raise + def maybe_convert_objects(values, convert_dates=True, convert_numeric=True, diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index f0635014b166b..0bfc7650a24aa 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -637,22 +637,25 @@ def _astype(self, dtype, copy=False, errors='raise', values=None, # force the copy here if values is None: - if issubclass(dtype.type, - (compat.text_type, compat.string_types)): + if self.is_extension: + values = self.values.astype(dtype) + else: + if issubclass(dtype.type, + (compat.text_type, compat.string_types)): - # use native type formatting for datetime/tz/timedelta - if self.is_datelike: - values = self.to_native_types() + # use native type formatting for datetime/tz/timedelta + if self.is_datelike: + values = self.to_native_types() - # astype formatting - else: - values = self.get_values() + # astype formatting + else: + values = self.get_values() - else: - values = self.get_values(dtype=dtype) + else: + values = self.get_values(dtype=dtype) - # _astype_nansafe works fine with 1-d only - values = astype_nansafe(values.ravel(), dtype, copy=True) + # _astype_nansafe works fine with 1-d only + values = astype_nansafe(values.ravel(), dtype, copy=True) # TODO(extension) # should we make this attribute? diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 108b8874b3ac5..c8daa05041231 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -15,6 +15,17 @@ class DecimalDtype(ExtensionDtype): name = 'decimal' na_value = decimal.Decimal('NaN') + def __init__(self, context=None): + self.context = context or decimal.getcontext() + + def __eq__(self, other): + if isinstance(other, type(self)): + return self.context == other.context + return super(DecimalDtype, self).__eq__(other) + + def __repr__(self): + return 'DecimalDtype(context={})'.format(self.context) + @classmethod def construct_array_type(cls): """Return the array type associated with this dtype @@ -35,13 +46,12 @@ def construct_from_string(cls, string): class DecimalArray(ExtensionArray, ExtensionScalarOpsMixin): - dtype = DecimalDtype() - def __init__(self, values, dtype=None, copy=False): + def __init__(self, values, dtype=None, copy=False, context=None): for val in values: - if not isinstance(val, self.dtype.type): + if not isinstance(val, decimal.Decimal): raise TypeError("All values must be of type " + - str(self.dtype.type)) + str(decimal.Decimal)) values = np.asarray(values, dtype=object) self._data = values @@ -51,6 +61,11 @@ def __init__(self, values, dtype=None, copy=False): # those aliases are currently not working due to assumptions # in internal code (GH-20735) # self._values = self.values = self.data + self._dtype = DecimalDtype(context) + + @property + def dtype(self): + return self._dtype @classmethod def _from_sequence(cls, scalars, dtype=None, copy=False): @@ -82,6 +97,11 @@ def copy(self, deep=False): return type(self)(self._data.copy()) return type(self)(self) + def astype(self, dtype, copy=True): + if isinstance(dtype, type(self.dtype)): + return type(self)(self._data, context=dtype.context) + return super().astype(dtype, copy) + def __setitem__(self, key, value): if pd.api.types.is_list_like(value): value = [decimal.Decimal(v) for v in value] diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index bc7237f263b1d..92905a07dad2a 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -205,6 +205,24 @@ def test_dataframe_constructor_with_dtype(): tm.assert_frame_equal(result, expected) +@pytest.mark.parametrize("frame", [True, False]) +def test_astype_dispatches(frame): + data = pd.Series(DecimalArray([decimal.Decimal(2)]), name='a') + ctx = decimal.Context() + ctx.prec = 5 + + if frame: + data = data.to_frame() + + result = data.astype(DecimalDtype(ctx)) + + if frame: + result = result['a'] + + assert result.dtype.context.prec == ctx.prec + + + class TestArithmeticOps(BaseDecimal, base.BaseArithmeticOpsTests): def check_opname(self, s, op_name, other, exc=None): From 6eeec11f73cd253f67f9015456cbd7b99a74fe05 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Tue, 14 Aug 2018 11:15:47 -0500 Subject: [PATCH 2/9] py2 compat --- pandas/tests/extension/decimal/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index c8daa05041231..f3475dead2418 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -100,7 +100,7 @@ def copy(self, deep=False): def astype(self, dtype, copy=True): if isinstance(dtype, type(self.dtype)): return type(self)(self._data, context=dtype.context) - return super().astype(dtype, copy) + return super(DecimalArray, self).astype(dtype, copy) def __setitem__(self, key, value): if pd.api.types.is_list_like(value): From f1b860fcdb2078c2034b8bf0b67d17a643399fd1 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Tue, 14 Aug 2018 15:16:56 -0500 Subject: [PATCH 3/9] explainers --- pandas/tests/extension/decimal/test_decimal.py | 4 ++++ pandas/tests/extension/integer/test_integer.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 92905a07dad2a..85f01354a1d55 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -207,6 +207,10 @@ def test_dataframe_constructor_with_dtype(): @pytest.mark.parametrize("frame", [True, False]) def test_astype_dispatches(frame): + # This is a dtype-specific test that ensures Series[decimal].astype + # gets all the way through to ExtensionArray.astype + # Designing a reliable smoke test that works for arbitrary data types + # is difficult. data = pd.Series(DecimalArray([decimal.Decimal(2)]), name='a') ctx = decimal.Context() ctx.prec = 5 diff --git a/pandas/tests/extension/integer/test_integer.py b/pandas/tests/extension/integer/test_integer.py index 5e0f5bf0a5dcf..a71528d17524a 100644 --- a/pandas/tests/extension/integer/test_integer.py +++ b/pandas/tests/extension/integer/test_integer.py @@ -567,6 +567,14 @@ def test_astype(self, all_data): expected = pd.Series(np.asarray(mixed)) tm.assert_series_equal(result, expected) + def test_astype_nansafe(self): + # https://github.com/pandas-dev/pandas/pull/22343 + arr = IntegerArray([np.nan, 1, 2], dtype="Int8") + + with tm.assert_raises_regex( + ValueError, 'cannot convert float NaN to integer'): + arr.astype('uint32') + @pytest.mark.parametrize('dtype', [Int8Dtype(), 'Int8']) def test_astype_specific_casting(self, dtype): s = pd.Series([1, 2, 3], dtype='Int64') From 5c442755bf5a6199996f004de5bd8805f0ab899a Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Tue, 14 Aug 2018 15:17:51 -0500 Subject: [PATCH 4/9] linting --- pandas/tests/extension/decimal/test_decimal.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 85f01354a1d55..04e855242b5e6 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -226,7 +226,6 @@ def test_astype_dispatches(frame): assert result.dtype.context.prec == ctx.prec - class TestArithmeticOps(BaseDecimal, base.BaseArithmeticOpsTests): def check_opname(self, s, op_name, other, exc=None): From de1fb5bbe48e623262b08b923f66d5f5cf7fc970 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Tue, 14 Aug 2018 16:40:58 -0500 Subject: [PATCH 5/9] lint --- pandas/core/dtypes/cast.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index cf89c2be2fe98..c73522589d2ba 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -756,7 +756,6 @@ def astype_nansafe(arr, dtype, copy=True): raise - def maybe_convert_objects(values, convert_dates=True, convert_numeric=True, convert_timedeltas=True, copy=True): """ if we have an object dtype, try to coerce dates and/or numbers """ From f1476358ce3d52cc47520c868a74c4248ba647b8 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 15 Aug 2018 08:44:19 -0500 Subject: [PATCH 6/9] try removing --- pandas/core/dtypes/cast.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index c73522589d2ba..99f1bdeb0b737 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -653,10 +653,6 @@ def astype_nansafe(arr, dtype, copy=True): ---------- arr : ndarray dtype : np.dtype - copy : bool or None, default True - Whether to copy during the `.astype` (True) or - just return a view (False). Passing `copy=None` will - attempt to return a view, but will copy if necessary. """ # dispatch on extension dtype if needed @@ -745,15 +741,7 @@ def astype_nansafe(arr, dtype, copy=True): if copy: return arr.astype(dtype, copy=True) - else: - try: - return arr.view(dtype) - except TypeError: - if copy is None: - # allowed to copy if necessary (e.g. object) - return arr.astype(dtype, copy=True) - else: - raise + return arr.view(dtype) def maybe_convert_objects(values, convert_dates=True, convert_numeric=True, From 767e3eefa8ce03341118ab0b1bf864a4ed62ac95 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 15 Aug 2018 09:02:15 -0500 Subject: [PATCH 7/9] Simpler than catching --- pandas/core/dtypes/cast.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 99f1bdeb0b737..4af47cad7761a 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -739,8 +739,10 @@ def astype_nansafe(arr, dtype, copy=True): FutureWarning, stacklevel=5) dtype = np.dtype(dtype.name + "[ns]") - if copy: + if copy or is_object_dtype(arr) or is_object_dtype(dtype): + # Explicit copy, or required since NumPy can't view from / to object. return arr.astype(dtype, copy=True) + return arr.view(dtype) From 5602330124dc3596032b05d1d423fd0c132bc72c Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 15 Aug 2018 09:02:59 -0500 Subject: [PATCH 8/9] Docstring --- pandas/core/dtypes/cast.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 4af47cad7761a..4e1c41760c2cd 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -653,6 +653,7 @@ def astype_nansafe(arr, dtype, copy=True): ---------- arr : ndarray dtype : np.dtype + copy : bool, default True """ # dispatch on extension dtype if needed From 2606d02bb738fa914d48d5b460725b506d8e0a0a Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 16 Aug 2018 07:46:56 -0500 Subject: [PATCH 9/9] move test --- pandas/core/dtypes/cast.py | 2 ++ pandas/tests/extension/integer/test_integer.py | 17 +++++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 4e1c41760c2cd..410e061c895db 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -654,6 +654,8 @@ def astype_nansafe(arr, dtype, copy=True): arr : ndarray dtype : np.dtype copy : bool, default True + If False, a view will be attempted but may fail, if + e.g. the itemsizes don't align. """ # dispatch on extension dtype if needed diff --git a/pandas/tests/extension/integer/test_integer.py b/pandas/tests/extension/integer/test_integer.py index a71528d17524a..55451da8f7eed 100644 --- a/pandas/tests/extension/integer/test_integer.py +++ b/pandas/tests/extension/integer/test_integer.py @@ -567,14 +567,6 @@ def test_astype(self, all_data): expected = pd.Series(np.asarray(mixed)) tm.assert_series_equal(result, expected) - def test_astype_nansafe(self): - # https://github.com/pandas-dev/pandas/pull/22343 - arr = IntegerArray([np.nan, 1, 2], dtype="Int8") - - with tm.assert_raises_regex( - ValueError, 'cannot convert float NaN to integer'): - arr.astype('uint32') - @pytest.mark.parametrize('dtype', [Int8Dtype(), 'Int8']) def test_astype_specific_casting(self, dtype): s = pd.Series([1, 2, 3], dtype='Int64') @@ -705,6 +697,15 @@ def test_cross_type_arithmetic(): tm.assert_series_equal(result, expected) +def test_astype_nansafe(): + # https://github.com/pandas-dev/pandas/pull/22343 + arr = IntegerArray([np.nan, 1, 2], dtype="Int8") + + with tm.assert_raises_regex( + ValueError, 'cannot convert float NaN to integer'): + arr.astype('uint32') + + # TODO(jreback) - these need testing / are broken # shift