From 9d637317ebc7cb1698af40d29a13996d3fefa52e Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 30 Mar 2023 17:37:51 -0700 Subject: [PATCH] TYP: concat_compat --- pandas/core/algorithms.py | 6 ++++- pandas/core/dtypes/cast.py | 3 ++- pandas/core/dtypes/concat.py | 46 ++++++++++++++++++++++++++---------- 3 files changed, 41 insertions(+), 14 deletions(-) diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 5cddf3c2c865b..12e09795f2304 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -1666,7 +1666,11 @@ def union_with_duplicates( lvals = lvals._values if isinstance(rvals, ABCIndex): rvals = rvals._values - unique_vals = unique(concat_compat([lvals, rvals])) + # error: List item 0 has incompatible type "Union[ExtensionArray, + # ndarray[Any, Any], Index]"; expected "Union[ExtensionArray, + # ndarray[Any, Any]]" + combined = concat_compat([lvals, rvals]) # type: ignore[list-item] + unique_vals = unique(combined) unique_vals = ensure_wrapped_if_datetimelike(unique_vals) repeats = final_count.reindex(unique_vals).values return np.repeat(unique_vals, repeats) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 9a74da33db531..b758c6cfc3c13 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -10,6 +10,7 @@ TYPE_CHECKING, Any, Literal, + Sequence, Sized, TypeVar, cast, @@ -1317,7 +1318,7 @@ def find_result_type(left: ArrayLike, right: Any) -> DtypeObj: def common_dtype_categorical_compat( - objs: list[Index | ArrayLike], dtype: DtypeObj + objs: Sequence[Index | ArrayLike], dtype: DtypeObj ) -> DtypeObj: """ Update the result of find_common_type to account for NAs in a Categorical. diff --git a/pandas/core/dtypes/concat.py b/pandas/core/dtypes/concat.py index 9917af9da7665..1e4aeb50a348e 100644 --- a/pandas/core/dtypes/concat.py +++ b/pandas/core/dtypes/concat.py @@ -3,7 +3,11 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Sequence, + cast, +) import numpy as np @@ -24,12 +28,20 @@ ) if TYPE_CHECKING: - from pandas._typing import AxisInt + from pandas._typing import ( + ArrayLike, + AxisInt, + ) - from pandas.core.arrays import Categorical + from pandas.core.arrays import ( + Categorical, + ExtensionArray, + ) -def concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False): +def concat_compat( + to_concat: Sequence[ArrayLike], axis: AxisInt = 0, ea_compat_axis: bool = False +) -> ArrayLike: """ provide concatenation of an array of arrays each of which is a single 'normalized' dtypes (in that for example, if it's object, then it is a @@ -38,7 +50,7 @@ def concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False): Parameters ---------- - to_concat : array of arrays + to_concat : sequence of arrays axis : axis to provide concatenation ea_compat_axis : bool, default False For ExtensionArray compat, behave as if axis == 1 when determining @@ -93,10 +105,12 @@ def is_nonempty(x) -> bool: if isinstance(to_concat[0], ABCExtensionArray): # TODO: what about EA-backed Index? + to_concat_eas = cast("Sequence[ExtensionArray]", to_concat) cls = type(to_concat[0]) - return cls._concat_same_type(to_concat) + return cls._concat_same_type(to_concat_eas) else: - return np.concatenate(to_concat) + to_concat_arrs = cast("Sequence[np.ndarray]", to_concat) + return np.concatenate(to_concat_arrs) elif all_empty: # we have all empties, but may need to coerce the result dtype to @@ -111,7 +125,10 @@ def is_nonempty(x) -> bool: to_concat = [x.astype("object") for x in to_concat] kinds = {"o"} - result = np.concatenate(to_concat, axis=axis) + # error: Argument 1 to "concatenate" has incompatible type + # "Sequence[Union[ExtensionArray, ndarray[Any, Any]]]"; expected + # "Union[_SupportsArray[dtype[Any]], _NestedSequence[_SupportsArray[dtype[Any]]]]" + result: np.ndarray = np.concatenate(to_concat, axis=axis) # type: ignore[arg-type] if "b" in kinds and result.dtype.kind in ["i", "u", "f"]: # GH#39817 cast to object instead of casting bools to numeric result = result.astype(object, copy=False) @@ -283,21 +300,21 @@ def _maybe_unwrap(x): return Categorical(new_codes, categories=categories, ordered=ordered, fastpath=True) -def _concatenate_2d(to_concat, axis: AxisInt): +def _concatenate_2d(to_concat: Sequence[np.ndarray], axis: AxisInt) -> np.ndarray: # coerce to 2d if needed & concatenate if axis == 1: to_concat = [np.atleast_2d(x) for x in to_concat] return np.concatenate(to_concat, axis=axis) -def _concat_datetime(to_concat, axis: AxisInt = 0): +def _concat_datetime(to_concat: Sequence[ArrayLike], axis: AxisInt = 0) -> ArrayLike: """ provide concatenation of an datetimelike array of arrays each of which is a single M8[ns], datetime64[ns, tz] or m8[ns] dtype Parameters ---------- - to_concat : array of arrays + to_concat : sequence of arrays axis : axis to provide concatenation Returns @@ -316,5 +333,10 @@ def _concat_datetime(to_concat, axis: AxisInt = 0): # in Timestamp/Timedelta return _concatenate_2d([x.astype(object) for x in to_concat], axis=axis) - result = type(to_concat[0])._concat_same_type(to_concat, axis=axis) + # error: Unexpected keyword argument "axis" for "_concat_same_type" of + # "ExtensionArray" + to_concat_eas = cast("list[ExtensionArray]", to_concat) + result = type(to_concat_eas[0])._concat_same_type( # type: ignore[call-arg] + to_concat_eas, axis=axis + ) return result