From c3699d092f24d6b9ecc1ddaea452df209ea6d26a Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 24 May 2021 11:02:50 -0700 Subject: [PATCH 1/5] REF: avoid special-case in EA.astype --- pandas/core/arrays/base.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 7dddb9f3d6f25..44d345b705bb8 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -515,7 +515,7 @@ def nbytes(self) -> int: # Additional Methods # ------------------------------------------------------------------------ - def astype(self, dtype, copy=True): + def astype(self, dtype: Dtype, copy: bool = True): """ Cast to a NumPy array with 'dtype'. @@ -530,10 +530,8 @@ def astype(self, dtype, copy=True): Returns ------- - array : ndarray - NumPy ndarray with 'dtype' for its dtype. + np.ndarray or ExtensionArray """ - from pandas.core.arrays.string_ import StringDtype dtype = pandas_dtype(dtype) if is_dtype_equal(dtype, self.dtype): @@ -542,10 +540,10 @@ def astype(self, dtype, copy=True): else: return self.copy() - # FIXME: Really hard-code here? - if isinstance(dtype, StringDtype): - # allow conversion to StringArrays - return dtype.construct_array_type()._from_sequence(self, copy=False) + if isinstance(dtype, ExtensionDtype): + # allow conversion to e.g. StringArrays + cls = dtype.construct_array_type() + return cls._from_sequence(self, dtype=dtype, copy=copy) return np.array(self, dtype=dtype, copy=copy) From ef8e571b5afe44e5e5797e36a74a5639fd1d057b Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 24 May 2021 13:25:53 -0700 Subject: [PATCH 2/5] REF: standardize usage in Categorical.astype --- pandas/core/arrays/categorical.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index cb8a08f5668ac..d32386d1f42c1 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -101,7 +101,6 @@ ) import pandas.core.common as com from pandas.core.construction import ( - array as pd_array, extract_array, sanitize_array, ) @@ -501,19 +500,18 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike: """ dtype = pandas_dtype(dtype) if self.dtype is dtype: - result = self.copy() if copy else self + return self.copy() if copy else self elif is_categorical_dtype(dtype): dtype = cast(Union[str, CategoricalDtype], dtype) # GH 10696/18593/18630 dtype = self.dtype.update_dtype(dtype) - self = self.copy() if copy else self - result = self._set_dtype(dtype) + obj = self.copy() if copy else self + return obj._set_dtype(dtype) - # TODO: consolidate with ndarray case? elif isinstance(dtype, ExtensionDtype): - result = pd_array(self, dtype=dtype, copy=copy) + return super().astype(dtype, copy=copy) elif is_integer_dtype(dtype) and self.isna().any(): raise ValueError("Cannot convert float NaN to integer") From 690d443ff76fe6ef8ff5942fcf5522349e7ac02d Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 24 May 2021 13:27:04 -0700 Subject: [PATCH 3/5] REF: standardize astype --- pandas/core/arrays/masked.py | 3 +-- pandas/core/arrays/string_.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index 11f9f645920ec..e2d6c1821abb0 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -320,8 +320,7 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike: return cls(data, mask, copy=False) if isinstance(dtype, ExtensionDtype): - eacls = dtype.construct_array_type() - return eacls._from_sequence(self, dtype=dtype, copy=copy) + return super().astype(dtype=dtype, copy=copy) raise NotImplementedError("subclass must implement astype to np.dtype") diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 74ca5130ca322..03af604ad8d10 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -331,8 +331,8 @@ def astype(self, dtype, copy=True): values = arr.astype(dtype.numpy_dtype) return FloatingArray(values, mask, copy=False) elif isinstance(dtype, ExtensionDtype): - cls = dtype.construct_array_type() - return cls._from_sequence(self, dtype=dtype, copy=copy) + return super().astype(dtype=dtype, copy=copy) + elif np.issubdtype(dtype, np.floating): arr = self._ndarray.copy() mask = self.isna() From 99f4577e632f1d913dc9dea2e4b8c70b72f3d571 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 24 May 2021 15:55:24 -0700 Subject: [PATCH 4/5] REF: standardize Index subclass astypes --- pandas/core/arrays/datetimelike.py | 16 ++++++---------- pandas/core/arrays/interval.py | 15 +++++++-------- pandas/core/dtypes/cast.py | 4 ++-- 3 files changed, 15 insertions(+), 20 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index ff46715d0a527..2d94943d6b0d9 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -69,7 +69,6 @@ is_datetime64tz_dtype, is_datetime_or_timedelta_dtype, is_dtype_equal, - is_extension_array_dtype, is_float_dtype, is_integer_dtype, is_list_like, @@ -82,6 +81,7 @@ ) from pandas.core.dtypes.dtypes import ( DatetimeTZDtype, + ExtensionDtype, PeriodDtype, ) from pandas.core.dtypes.missing import ( @@ -381,14 +381,13 @@ def astype(self, dtype, copy: bool = True): # 3. DatetimeArray.astype handles datetime -> period dtype = pandas_dtype(dtype) + if isinstance(dtype, ExtensionDtype): + return super().astype(dtype=dtype, copy=copy) + if is_object_dtype(dtype): return self._box_values(self.asi8.ravel()).reshape(self.shape) - elif is_string_dtype(dtype) and not is_categorical_dtype(dtype): - if is_extension_array_dtype(dtype): - arr_cls = dtype.construct_array_type() - return arr_cls._from_sequence(self, dtype=dtype, copy=copy) - else: - return self._format_native_types() + elif is_string_dtype(dtype): + return self._format_native_types() elif is_integer_dtype(dtype): # we deliberately ignore int32 vs. int64 here. # See https://github.com/pandas-dev/pandas/issues/24381 for more. @@ -418,9 +417,6 @@ def astype(self, dtype, copy: bool = True): # and conversions for any datetimelike to float msg = f"Cannot cast {type(self).__name__} to dtype {dtype}" raise TypeError(msg) - elif is_categorical_dtype(dtype): - arr_cls = dtype.construct_array_type() - return arr_cls(self, dtype=dtype) else: return np.asarray(self, dtype=dtype) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index a99bf245a6073..2db2fef6a466a 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -49,7 +49,10 @@ needs_i8_conversion, pandas_dtype, ) -from pandas.core.dtypes.dtypes import IntervalDtype +from pandas.core.dtypes.dtypes import ( + ExtensionDtype, + IntervalDtype, +) from pandas.core.dtypes.generic import ( ABCDataFrame, ABCDatetimeIndex, @@ -71,7 +74,6 @@ ExtensionArray, _extension_array_shared_docs, ) -from pandas.core.arrays.categorical import Categorical import pandas.core.common as com from pandas.core.construction import ( array as pd_array, @@ -828,7 +830,6 @@ def astype(self, dtype, copy: bool = True): ExtensionArray or NumPy ndarray with 'dtype' for its dtype. """ from pandas import Index - from pandas.core.arrays.string_ import StringDtype if dtype is not None: dtype = pandas_dtype(dtype) @@ -849,13 +850,11 @@ def astype(self, dtype, copy: bool = True): ) raise TypeError(msg) from err return self._shallow_copy(new_left, new_right) - elif is_categorical_dtype(dtype): - return Categorical(np.asarray(self), dtype=dtype) - elif isinstance(dtype, StringDtype): - return dtype.construct_array_type()._from_sequence(self, copy=False) - # TODO: This try/except will be repeated. try: + if isinstance(dtype, ExtensionDtype): + cls = dtype.construct_array_type() + return cls._from_sequence(self, dtype=dtype, copy=copy) return np.asarray(self).astype(dtype, copy=copy) except (TypeError, ValueError) as err: msg = f"Cannot cast {type(self).__name__} to dtype {dtype}" diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index e3616bc857140..888985ffac575 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -1126,7 +1126,7 @@ def astype_nansafe( if issubclass(dtype.type, str): return lib.ensure_string_array(arr, skipna=skipna, convert_na_value=False) - elif is_datetime64_dtype(arr): + elif is_datetime64_dtype(arr.dtype): if dtype == np.int64: warnings.warn( f"casting {arr.dtype} values to int64 with .astype(...) " @@ -1146,7 +1146,7 @@ def astype_nansafe( raise TypeError(f"cannot astype a datetimelike from [{arr.dtype}] to [{dtype}]") - elif is_timedelta64_dtype(arr): + elif is_timedelta64_dtype(arr.dtype): if dtype == np.int64: warnings.warn( f"casting {arr.dtype} values to int64 with .astype(...) " From ea77795282b915da3cf3d1cb58b3bb73b2930f3d Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 25 May 2021 07:08:54 -0700 Subject: [PATCH 5/5] revert annotation --- pandas/core/arrays/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 44d345b705bb8..abed92207b632 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -515,7 +515,7 @@ def nbytes(self) -> int: # Additional Methods # ------------------------------------------------------------------------ - def astype(self, dtype: Dtype, copy: bool = True): + def astype(self, dtype, copy=True): """ Cast to a NumPy array with 'dtype'.