Skip to content

Commit 470e945

Browse files
authored
BUG: merge_asof raising KeyError for extension dtypes (#53458)
* fix merge_asof raising KeyError for extension dtypes
* reuse new methods elsewhere
1 parent d9067ba commit 470e945

File tree

6 files changed

+107
-59
lines changed

6 files changed

+107
-59
lines changed

doc/source/whatsnew/v2.1.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,7 @@ Groupby/resample/rolling
441441
Reshaping
442442
^^^^^^^^^
443443
- Bug in :func:`crosstab` when ``dropna=False`` would not keep ``np.nan`` in the result (:issue:`10772`)
444+
- Bug in :func:`merge_asof` raising ``KeyError`` for extension dtypes (:issue:`52904`)
444445
- Bug in :meth:`DataFrame.agg` and :meth:`Series.agg` on non-unique columns would return incorrect type when dist-like argument passed in (:issue:`51099`)
445446
- Bug in :meth:`DataFrame.idxmin` and :meth:`DataFrame.idxmax`, where the axis dtype would be lost for empty frames (:issue:`53265`)
446447
- Bug in :meth:`DataFrame.merge` not merging correctly when having ``MultiIndex`` with single level (:issue:`52331`)

pandas/core/arrays/arrow/array.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ def floordiv_compat(
149149
)
150150

151151
from pandas import Series
152+
from pandas.core.arrays.datetimes import DatetimeArray
153+
from pandas.core.arrays.timedeltas import TimedeltaArray
152154

153155

154156
def get_unit_from_pa_dtype(pa_dtype):
@@ -1168,6 +1170,41 @@ def take(
11681170
indices_array[indices_array < 0] += len(self._pa_array)
11691171
return type(self)(self._pa_array.take(indices_array))
11701172

1173+
def _maybe_convert_datelike_array(self):
1174+
"""Maybe convert to a datelike array."""
1175+
pa_type = self._pa_array.type
1176+
if pa.types.is_timestamp(pa_type):
1177+
return self._to_datetimearray()
1178+
elif pa.types.is_duration(pa_type):
1179+
return self._to_timedeltaarray()
1180+
return self
1181+
1182+
def _to_datetimearray(self) -> DatetimeArray:
1183+
"""Convert a pyarrow timestamp typed array to a DatetimeArray."""
1184+
from pandas.core.arrays.datetimes import (
1185+
DatetimeArray,
1186+
tz_to_dtype,
1187+
)
1188+
1189+
pa_type = self._pa_array.type
1190+
assert pa.types.is_timestamp(pa_type)
1191+
np_dtype = np.dtype(f"M8[{pa_type.unit}]")
1192+
dtype = tz_to_dtype(pa_type.tz, pa_type.unit)
1193+
np_array = self._pa_array.to_numpy()
1194+
np_array = np_array.astype(np_dtype)
1195+
return DatetimeArray._simple_new(np_array, dtype=dtype)
1196+
1197+
def _to_timedeltaarray(self) -> TimedeltaArray:
1198+
"""Convert a pyarrow duration typed array to a TimedeltaArray."""
1199+
from pandas.core.arrays.timedeltas import TimedeltaArray
1200+
1201+
pa_type = self._pa_array.type
1202+
assert pa.types.is_duration(pa_type)
1203+
np_dtype = np.dtype(f"m8[{pa_type.unit}]")
1204+
np_array = self._pa_array.to_numpy()
1205+
np_array = np_array.astype(np_dtype)
1206+
return TimedeltaArray._simple_new(np_array, dtype=np_dtype)
1207+
11711208
@doc(ExtensionArray.to_numpy)
11721209
def to_numpy(
11731210
self,
@@ -1184,33 +1221,12 @@ def to_numpy(
11841221
na_value = self.dtype.na_value
11851222

11861223
pa_type = self._pa_array.type
1187-
if pa.types.is_timestamp(pa_type):
1188-
from pandas.core.arrays.datetimes import (
1189-
DatetimeArray,
1190-
tz_to_dtype,
1191-
)
1192-
1193-
np_dtype = np.dtype(f"M8[{pa_type.unit}]")
1194-
result = self._pa_array.to_numpy()
1195-
result = result.astype(np_dtype, copy=copy)
1224+
if pa.types.is_timestamp(pa_type) or pa.types.is_duration(pa_type):
1225+
result = self._maybe_convert_datelike_array()
11961226
if dtype is None or dtype.kind == "O":
1197-
dta_dtype = tz_to_dtype(pa_type.tz, pa_type.unit)
1198-
result = DatetimeArray._simple_new(result, dtype=dta_dtype)
11991227
result = result.to_numpy(dtype=object, na_value=na_value)
1200-
elif result.dtype != dtype:
1201-
result = result.astype(dtype, copy=False)
1202-
return result
1203-
elif pa.types.is_duration(pa_type):
1204-
from pandas.core.arrays.timedeltas import TimedeltaArray
1205-
1206-
np_dtype = np.dtype(f"m8[{pa_type.unit}]")
1207-
result = self._pa_array.to_numpy()
1208-
result = result.astype(np_dtype, copy=copy)
1209-
if dtype is None or dtype.kind == "O":
1210-
result = TimedeltaArray._simple_new(result, dtype=np_dtype)
1211-
result = result.to_numpy(dtype=object, na_value=na_value)
1212-
elif result.dtype != dtype:
1213-
result = result.astype(dtype, copy=False)
1228+
else:
1229+
result = result.to_numpy(dtype=dtype)
12141230
return result
12151231
elif pa.types.is_time(pa_type):
12161232
# convert to list of python datetime.time objects before

pandas/core/arrays/datetimelike.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2223,23 +2223,11 @@ def ensure_arraylike_for_datetimelike(data, copy: bool, cls_name: str):
22232223
):
22242224
data = data.to_numpy("int64", na_value=iNaT)
22252225
copy = False
2226-
elif isinstance(data, ArrowExtensionArray) and data.dtype.kind == "M":
2227-
from pandas.core.arrays import DatetimeArray
2228-
from pandas.core.arrays.datetimes import tz_to_dtype
2229-
2230-
pa_type = data._pa_array.type
2231-
dtype = tz_to_dtype(tz=pa_type.tz, unit=pa_type.unit)
2232-
data = data.to_numpy(f"M8[{pa_type.unit}]", na_value=iNaT)
2233-
data = DatetimeArray._simple_new(data, dtype=dtype)
2226+
elif isinstance(data, ArrowExtensionArray):
2227+
data = data._maybe_convert_datelike_array()
2228+
data = data.to_numpy()
22342229
copy = False
2235-
elif isinstance(data, ArrowExtensionArray) and data.dtype.kind == "m":
2236-
pa_type = data._pa_array.type
2237-
dtype = np.dtype(f"m8[{pa_type.unit}]")
2238-
data = data.to_numpy(dtype, na_value=iNaT)
2239-
copy = False
2240-
elif not isinstance(data, (np.ndarray, ExtensionArray)) or isinstance(
2241-
data, ArrowExtensionArray
2242-
):
2230+
elif not isinstance(data, (np.ndarray, ExtensionArray)):
22432231
# GH#24539 e.g. xarray, dask object
22442232
data = np.asarray(data)
22452233

pandas/core/indexes/base.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,6 @@
151151
Categorical,
152152
ExtensionArray,
153153
)
154-
from pandas.core.arrays.datetimes import tz_to_dtype
155154
from pandas.core.arrays.string_ import StringArray
156155
from pandas.core.base import (
157156
IndexOpsMixin,
@@ -192,11 +191,7 @@
192191
MultiIndex,
193192
Series,
194193
)
195-
from pandas.core.arrays import (
196-
DatetimeArray,
197-
PeriodArray,
198-
TimedeltaArray,
199-
)
194+
from pandas.core.arrays import PeriodArray
200195

201196
__all__ = ["Index"]
202197

@@ -845,14 +840,10 @@ def _engine(
845840

846841
pa_type = self._values._pa_array.type
847842
if pa.types.is_timestamp(pa_type):
848-
dtype = tz_to_dtype(pa_type.tz, pa_type.unit)
849-
target_values = self._values.astype(dtype)
850-
target_values = cast("DatetimeArray", target_values)
843+
target_values = self._values._to_datetimearray()
851844
return libindex.DatetimeEngine(target_values._ndarray)
852845
elif pa.types.is_duration(pa_type):
853-
dtype = np.dtype(f"m8[{pa_type.unit}]")
854-
target_values = self._values.astype(dtype)
855-
target_values = cast("TimedeltaArray", target_values)
846+
target_values = self._values._to_timedeltaarray()
856847
return libindex.TimedeltaEngine(target_values._ndarray)
857848

858849
if isinstance(target_values, ExtensionArray):
@@ -5117,14 +5108,10 @@ def _get_engine_target(self) -> ArrayLike:
51175108

51185109
pa_type = vals._pa_array.type
51195110
if pa.types.is_timestamp(pa_type):
5120-
dtype = tz_to_dtype(pa_type.tz, pa_type.unit)
5121-
vals = vals.astype(dtype)
5122-
vals = cast("DatetimeArray", vals)
5111+
vals = vals._to_datetimearray()
51235112
return vals._ndarray.view("i8")
51245113
elif pa.types.is_duration(pa_type):
5125-
dtype = np.dtype(f"m8[{pa_type.unit}]")
5126-
vals = vals.astype(dtype)
5127-
vals = cast("TimedeltaArray", vals)
5114+
vals = vals._to_timedeltaarray()
51285115
return vals._ndarray.view("i8")
51295116
if (
51305117
type(self) is Index

pandas/core/reshape/merge.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2112,6 +2112,12 @@ def injection(obj):
21122112
raise ValueError(f"Merge keys contain null values on {side} side")
21132113
raise ValueError(f"{side} keys must be sorted")
21142114

2115+
if isinstance(left_values, ArrowExtensionArray):
2116+
left_values = left_values._maybe_convert_datelike_array()
2117+
2118+
if isinstance(right_values, ArrowExtensionArray):
2119+
right_values = right_values._maybe_convert_datelike_array()
2120+
21152121
# initial type conversion as needed
21162122
if needs_i8_conversion(getattr(left_values, "dtype", None)):
21172123
if tolerance is not None:
@@ -2132,6 +2138,18 @@ def injection(obj):
21322138
left_values = left_values.view("i8")
21332139
right_values = right_values.view("i8")
21342140

2141+
if isinstance(left_values, BaseMaskedArray):
2142+
# we've verified above that no nulls exist
2143+
left_values = left_values._data
2144+
elif isinstance(left_values, ExtensionArray):
2145+
left_values = np.array(left_values)
2146+
2147+
if isinstance(right_values, BaseMaskedArray):
2148+
# we've verified above that no nulls exist
2149+
right_values = right_values._data
2150+
elif isinstance(right_values, ExtensionArray):
2151+
right_values = np.array(right_values)
2152+
21352153
# a "by" parameter requires special handling
21362154
if self.left_by is not None:
21372155
# remove 'on' parameter from values if one existed

pandas/tests/reshape/merge/test_merge_asof.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import pytest
55
import pytz
66

7+
import pandas.util._test_decorators as td
8+
79
import pandas as pd
810
from pandas import (
911
Index,
@@ -1589,3 +1591,39 @@ def test_merge_asof_raise_for_duplicate_columns():
15891591

15901592
with pytest.raises(ValueError, match="column label 'a'"):
15911593
merge_asof(left, right, left_on="left_val", right_on="a")
1594+
1595+
1596+
@pytest.mark.parametrize(
1597+
"dtype",
1598+
[
1599+
"Int64",
1600+
pytest.param("int64[pyarrow]", marks=td.skip_if_no("pyarrow")),
1601+
pytest.param("timestamp[s][pyarrow]", marks=td.skip_if_no("pyarrow")),
1602+
],
1603+
)
1604+
def test_merge_asof_extension_dtype(dtype):
1605+
# GH 52904
1606+
left = pd.DataFrame(
1607+
{
1608+
"join_col": [1, 3, 5],
1609+
"left_val": [1, 2, 3],
1610+
}
1611+
)
1612+
right = pd.DataFrame(
1613+
{
1614+
"join_col": [2, 3, 4],
1615+
"right_val": [1, 2, 3],
1616+
}
1617+
)
1618+
left = left.astype({"join_col": dtype})
1619+
right = right.astype({"join_col": dtype})
1620+
result = merge_asof(left, right, on="join_col")
1621+
expected = pd.DataFrame(
1622+
{
1623+
"join_col": [1, 3, 5],
1624+
"left_val": [1, 2, 3],
1625+
"right_val": [np.nan, 2.0, 3.0],
1626+
}
1627+
)
1628+
expected = expected.astype({"join_col": dtype})
1629+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments (0)