Patch summary (reviewer note; this text precedes the first "diff --git" header
and is not part of any hunk):
  * doc/source/whatsnew/v2.1.0.rst: adds a Reshaping entry for merge_asof
    raising KeyError for extension dtypes (issue 52904).
  * pandas/core/arrays/arrow/array.py: adds _maybe_convert_datelike_array,
    _to_datetimearray and _to_timedeltaarray, and makes to_numpy use them for
    pyarrow timestamp/duration types.
  * pandas/core/arrays/datetimelike.py and pandas/core/indexes/base.py: replace
    the inline pyarrow-to-datetime/timedelta conversions with the new helpers.
  * pandas/core/reshape/merge.py: converts ArrowExtensionArray, BaseMaskedArray
    and other ExtensionArray join keys before the "by" parameter handling.
  * pandas/tests/reshape/merge/test_merge_asof.py: adds
    test_merge_asof_extension_dtype for Int64, int64[pyarrow] and
    timestamp[s][pyarrow].
NOTE(review): line breaks, indentation and blank context lines were
reconstructed from a whitespace-collapsed copy so that each hunk matches its
"@@" line counts - verify against the target tree before applying.

diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst
index 7cc2e19f477dd..9dac2fe3c36c6 100644
--- a/doc/source/whatsnew/v2.1.0.rst
+++ b/doc/source/whatsnew/v2.1.0.rst
@@ -441,6 +441,7 @@ Groupby/resample/rolling
 Reshaping
 ^^^^^^^^^
 - Bug in :func:`crosstab` when ``dropna=False`` would not keep ``np.nan`` in the result (:issue:`10772`)
+- Bug in :func:`merge_asof` raising ``KeyError`` for extension dtypes (:issue:`52904`)
 - Bug in :meth:`DataFrame.agg` and :meth:`Series.agg` on non-unique columns would return incorrect type when dist-like argument passed in (:issue:`51099`)
 - Bug in :meth:`DataFrame.idxmin` and :meth:`DataFrame.idxmax`, where the axis dtype would be lost for empty frames (:issue:`53265`)
 - Bug in :meth:`DataFrame.merge` not merging correctly when having ``MultiIndex`` with single level (:issue:`52331`)
diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py
index 5a9e4a97eccea..817d5d0932744 100644
--- a/pandas/core/arrays/arrow/array.py
+++ b/pandas/core/arrays/arrow/array.py
@@ -149,6 +149,8 @@ def floordiv_compat(
     )
 
     from pandas import Series
+    from pandas.core.arrays.datetimes import DatetimeArray
+    from pandas.core.arrays.timedeltas import TimedeltaArray
 
 
 def get_unit_from_pa_dtype(pa_dtype):
@@ -1168,6 +1170,41 @@ def take(
                 indices_array[indices_array < 0] += len(self._pa_array)
             return type(self)(self._pa_array.take(indices_array))
 
+    def _maybe_convert_datelike_array(self):
+        """Maybe convert to a datelike array."""
+        pa_type = self._pa_array.type
+        if pa.types.is_timestamp(pa_type):
+            return self._to_datetimearray()
+        elif pa.types.is_duration(pa_type):
+            return self._to_timedeltaarray()
+        return self
+
+    def _to_datetimearray(self) -> DatetimeArray:
+        """Convert a pyarrow timestamp typed array to a DatetimeArray."""
+        from pandas.core.arrays.datetimes import (
+            DatetimeArray,
+            tz_to_dtype,
+        )
+
+        pa_type = self._pa_array.type
+        assert pa.types.is_timestamp(pa_type)
+        np_dtype = np.dtype(f"M8[{pa_type.unit}]")
+        dtype = tz_to_dtype(pa_type.tz, pa_type.unit)
+        np_array = self._pa_array.to_numpy()
+        np_array = np_array.astype(np_dtype)
+        return DatetimeArray._simple_new(np_array, dtype=dtype)
+
+    def _to_timedeltaarray(self) -> TimedeltaArray:
+        """Convert a pyarrow duration typed array to a TimedeltaArray."""
+        from pandas.core.arrays.timedeltas import TimedeltaArray
+
+        pa_type = self._pa_array.type
+        assert pa.types.is_duration(pa_type)
+        np_dtype = np.dtype(f"m8[{pa_type.unit}]")
+        np_array = self._pa_array.to_numpy()
+        np_array = np_array.astype(np_dtype)
+        return TimedeltaArray._simple_new(np_array, dtype=np_dtype)
+
     @doc(ExtensionArray.to_numpy)
     def to_numpy(
         self,
@@ -1184,33 +1221,12 @@ def to_numpy(
             na_value = self.dtype.na_value
 
         pa_type = self._pa_array.type
-        if pa.types.is_timestamp(pa_type):
-            from pandas.core.arrays.datetimes import (
-                DatetimeArray,
-                tz_to_dtype,
-            )
-
-            np_dtype = np.dtype(f"M8[{pa_type.unit}]")
-            result = self._pa_array.to_numpy()
-            result = result.astype(np_dtype, copy=copy)
+        if pa.types.is_timestamp(pa_type) or pa.types.is_duration(pa_type):
+            result = self._maybe_convert_datelike_array()
             if dtype is None or dtype.kind == "O":
-                dta_dtype = tz_to_dtype(pa_type.tz, pa_type.unit)
-                result = DatetimeArray._simple_new(result, dtype=dta_dtype)
                 result = result.to_numpy(dtype=object, na_value=na_value)
-            elif result.dtype != dtype:
-                result = result.astype(dtype, copy=False)
-            return result
-        elif pa.types.is_duration(pa_type):
-            from pandas.core.arrays.timedeltas import TimedeltaArray
-
-            np_dtype = np.dtype(f"m8[{pa_type.unit}]")
-            result = self._pa_array.to_numpy()
-            result = result.astype(np_dtype, copy=copy)
-            if dtype is None or dtype.kind == "O":
-                result = TimedeltaArray._simple_new(result, dtype=np_dtype)
-                result = result.to_numpy(dtype=object, na_value=na_value)
-            elif result.dtype != dtype:
-                result = result.astype(dtype, copy=False)
+            else:
+                result = result.to_numpy(dtype=dtype)
             return result
         elif pa.types.is_time(pa_type):
             # convert to list of python datetime.time objects before
diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py
index ee5dff24c8bed..13f314ba8381e 100644
--- a/pandas/core/arrays/datetimelike.py
+++ b/pandas/core/arrays/datetimelike.py
@@ -2223,23 +2223,11 @@ def ensure_arraylike_for_datetimelike(data, copy: bool, cls_name: str):
     ):
         data = data.to_numpy("int64", na_value=iNaT)
         copy = False
-    elif isinstance(data, ArrowExtensionArray) and data.dtype.kind == "M":
-        from pandas.core.arrays import DatetimeArray
-        from pandas.core.arrays.datetimes import tz_to_dtype
-
-        pa_type = data._pa_array.type
-        dtype = tz_to_dtype(tz=pa_type.tz, unit=pa_type.unit)
-        data = data.to_numpy(f"M8[{pa_type.unit}]", na_value=iNaT)
-        data = DatetimeArray._simple_new(data, dtype=dtype)
+    elif isinstance(data, ArrowExtensionArray):
+        data = data._maybe_convert_datelike_array()
+        data = data.to_numpy()
         copy = False
-    elif isinstance(data, ArrowExtensionArray) and data.dtype.kind == "m":
-        pa_type = data._pa_array.type
-        dtype = np.dtype(f"m8[{pa_type.unit}]")
-        data = data.to_numpy(dtype, na_value=iNaT)
-        copy = False
-    elif not isinstance(data, (np.ndarray, ExtensionArray)) or isinstance(
-        data, ArrowExtensionArray
-    ):
+    elif not isinstance(data, (np.ndarray, ExtensionArray)):
         # GH#24539 e.g. xarray, dask object
         data = np.asarray(data)
 
diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py
index c9a5fb5c809ed..68e9006e85f7a 100644
--- a/pandas/core/indexes/base.py
+++ b/pandas/core/indexes/base.py
@@ -151,7 +151,6 @@
     Categorical,
     ExtensionArray,
 )
-from pandas.core.arrays.datetimes import tz_to_dtype
 from pandas.core.arrays.string_ import StringArray
 from pandas.core.base import (
     IndexOpsMixin,
@@ -192,11 +191,7 @@
         MultiIndex,
         Series,
     )
-    from pandas.core.arrays import (
-        DatetimeArray,
-        PeriodArray,
-        TimedeltaArray,
-    )
+    from pandas.core.arrays import PeriodArray
 
 __all__ = ["Index"]
 
@@ -845,14 +840,10 @@ def _engine(
 
             pa_type = self._values._pa_array.type
             if pa.types.is_timestamp(pa_type):
-                dtype = tz_to_dtype(pa_type.tz, pa_type.unit)
-                target_values = self._values.astype(dtype)
-                target_values = cast("DatetimeArray", target_values)
+                target_values = self._values._to_datetimearray()
                 return libindex.DatetimeEngine(target_values._ndarray)
             elif pa.types.is_duration(pa_type):
-                dtype = np.dtype(f"m8[{pa_type.unit}]")
-                target_values = self._values.astype(dtype)
-                target_values = cast("TimedeltaArray", target_values)
+                target_values = self._values._to_timedeltaarray()
                 return libindex.TimedeltaEngine(target_values._ndarray)
 
         if isinstance(target_values, ExtensionArray):
@@ -5117,14 +5108,10 @@ def _get_engine_target(self) -> ArrayLike:
 
             pa_type = vals._pa_array.type
             if pa.types.is_timestamp(pa_type):
-                dtype = tz_to_dtype(pa_type.tz, pa_type.unit)
-                vals = vals.astype(dtype)
-                vals = cast("DatetimeArray", vals)
+                vals = vals._to_datetimearray()
                 return vals._ndarray.view("i8")
             elif pa.types.is_duration(pa_type):
-                dtype = np.dtype(f"m8[{pa_type.unit}]")
-                vals = vals.astype(dtype)
-                vals = cast("TimedeltaArray", vals)
+                vals = vals._to_timedeltaarray()
                 return vals._ndarray.view("i8")
         if (
             type(self) is Index
diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py
index b9fed644192e4..65cd7e7983bfe 100644
--- a/pandas/core/reshape/merge.py
+++ b/pandas/core/reshape/merge.py
@@ -2112,6 +2112,12 @@ def injection(obj):
                 raise ValueError(f"Merge keys contain null values on {side} side")
             raise ValueError(f"{side} keys must be sorted")
 
+        if isinstance(left_values, ArrowExtensionArray):
+            left_values = left_values._maybe_convert_datelike_array()
+
+        if isinstance(right_values, ArrowExtensionArray):
+            right_values = right_values._maybe_convert_datelike_array()
+
         # initial type conversion as needed
         if needs_i8_conversion(getattr(left_values, "dtype", None)):
             if tolerance is not None:
@@ -2132,6 +2138,18 @@ def injection(obj):
             left_values = left_values.view("i8")
             right_values = right_values.view("i8")
 
+        if isinstance(left_values, BaseMaskedArray):
+            # we've verified above that no nulls exist
+            left_values = left_values._data
+        elif isinstance(left_values, ExtensionArray):
+            left_values = np.array(left_values)
+
+        if isinstance(right_values, BaseMaskedArray):
+            # we've verified above that no nulls exist
+            right_values = right_values._data
+        elif isinstance(right_values, ExtensionArray):
+            right_values = np.array(right_values)
+
         # a "by" parameter requires special handling
         if self.left_by is not None:
             # remove 'on' parameter from values if one existed
diff --git a/pandas/tests/reshape/merge/test_merge_asof.py b/pandas/tests/reshape/merge/test_merge_asof.py
index 2f0861201155b..d62dc44cda219 100644
--- a/pandas/tests/reshape/merge/test_merge_asof.py
+++ b/pandas/tests/reshape/merge/test_merge_asof.py
@@ -4,6 +4,8 @@
 import pytest
 import pytz
 
+import pandas.util._test_decorators as td
+
 import pandas as pd
 from pandas import (
     Index,
@@ -1589,3 +1591,39 @@ def test_merge_asof_raise_for_duplicate_columns():
 
     with pytest.raises(ValueError, match="column label 'a'"):
         merge_asof(left, right, left_on="left_val", right_on="a")
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        "Int64",
+        pytest.param("int64[pyarrow]", marks=td.skip_if_no("pyarrow")),
+        pytest.param("timestamp[s][pyarrow]", marks=td.skip_if_no("pyarrow")),
+    ],
+)
+def test_merge_asof_extension_dtype(dtype):
+    # GH 52904
+    left = pd.DataFrame(
+        {
+            "join_col": [1, 3, 5],
+            "left_val": [1, 2, 3],
+        }
+    )
+    right = pd.DataFrame(
+        {
+            "join_col": [2, 3, 4],
+            "right_val": [1, 2, 3],
+        }
+    )
+    left = left.astype({"join_col": dtype})
+    right = right.astype({"join_col": dtype})
+    result = merge_asof(left, right, on="join_col")
+    expected = pd.DataFrame(
+        {
+            "join_col": [1, 3, 5],
+            "left_val": [1, 2, 3],
+            "right_val": [np.nan, 2.0, 3.0],
+        }
+    )
+    expected = expected.astype({"join_col": dtype})
+    tm.assert_frame_equal(result, expected)