diff --git a/pandas/core/array_algos/quantile.py b/pandas/core/array_algos/quantile.py
index d3d9cb1b29b9a..4ad7a1549e207 100644
--- a/pandas/core/array_algos/quantile.py
+++ b/pandas/core/array_algos/quantile.py
@@ -1,7 +1,11 @@
 from __future__ import annotations
 
+from functools import partial
+from typing import Literal
+
 import numpy as np
 
+from pandas._libs import groupby as libgroupby
 from pandas._typing import (
     ArrayLike,
     Scalar,
@@ -15,6 +19,94 @@
 )
 
 
+def groupby_quantile_ndim_compat(
+    *,
+    qs: npt.NDArray[np.float64],
+    interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
+    ngroups: int,
+    ids: npt.NDArray[np.intp],
+    labels_for_lexsort: npt.NDArray[np.intp],
+    npy_vals: np.ndarray,
+    mask: npt.NDArray[np.bool_],
+    result_mask: npt.NDArray[np.bool_] | None,
+) -> np.ndarray:
+    """
+    Compatibility layer to handle either 1D arrays or 2D ndarrays in
+    GroupBy.quantile. Located here to be available to
+    ExtensionArray._groupby_quantile for dispatching after casting to numpy.
+
+    Parameters
+    ----------
+    qs : np.ndarray[float64]
+        Values between 0 and 1 providing the quantile(s) to compute.
+    interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'}
+        Method to use when the desired quantile falls between two points.
+    ngroups : int
+        The number of groupby groups.
+    ids : np.ndarray[intp]
+        Group labels.
+    labels_for_lexsort : np.ndarray[intp]
+        Group labels, but with -1s moved to the end to sort NAs last.
+    npy_vals : np.ndarray
+        The values for which we are computing quantiles.
+    mask : np.ndarray[bool]
+        Locations to treat as NA.
+    result_mask : np.ndarray[bool] or None
+        If present, set to True for locations that should be treated as missing
+        a result. Modified in-place.
+
+    Returns
+    -------
+    np.ndarray
+    """
+    nqs = len(qs)
+
+    ncols = 1
+    if npy_vals.ndim == 2:
+        ncols = npy_vals.shape[0]
+        shaped_labels = np.broadcast_to(
+            labels_for_lexsort, (ncols, len(labels_for_lexsort))
+        )
+    else:
+        shaped_labels = labels_for_lexsort
+
+    npy_out = np.empty((ncols, ngroups, nqs), dtype=np.float64)
+
+    # Get an index of values sorted by values and then labels
+    order = (npy_vals, shaped_labels)
+    sort_arr = np.lexsort(order).astype(np.intp, copy=False)
+
+    func = partial(
+        libgroupby.group_quantile, labels=ids, qs=qs, interpolation=interpolation
+    )
+
+    if npy_vals.ndim == 1:
+        func(
+            npy_out[0],
+            values=npy_vals,
+            mask=mask.view(np.uint8),
+            sort_indexer=sort_arr,
+            result_mask=result_mask,
+        )
+    else:
+        # if we ever did get here with non-None result_mask, we'd pass result_mask[i]
+        assert result_mask is None
+        for i in range(ncols):
+            func(
+                npy_out[i],
+                values=npy_vals[i],
+                mask=mask[i].view(np.uint8),
+                sort_indexer=sort_arr[i],
+            )
+
+    if npy_vals.ndim == 1:
+        npy_out = npy_out.reshape(ngroups * nqs)
+    else:
+        npy_out = npy_out.reshape(ncols, ngroups * nqs)
+
+    return npy_out
+
+
 def quantile_compat(
     values: ArrayLike, qs: npt.NDArray[np.float64], interpolation: str
 ) -> ArrayLike:
diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py
index c261a41e1e77e..de8eae511696b 100644
--- a/pandas/core/arrays/base.py
+++ b/pandas/core/arrays/base.py
@@ -30,6 +30,7 @@
     AstypeArg,
     AxisInt,
     Dtype,
+    DtypeObj,
     FillnaOptions,
     PositionalIndexer,
     ScalarIndexer,
@@ -55,9 +56,13 @@
 
 from pandas.core.dtypes.cast import maybe_cast_to_extension_array
 from pandas.core.dtypes.common import (
+    is_bool_dtype,
     is_datetime64_dtype,
     is_dtype_equal,
+    is_float_dtype,
+    is_integer_dtype,
     is_list_like,
+    is_object_dtype,
     is_scalar,
     is_timedelta64_dtype,
     pandas_dtype,
@@ -82,7 +87,10 @@
     rank,
     unique,
 )
-from pandas.core.array_algos.quantile import quantile_with_mask
+from pandas.core.array_algos.quantile import (
+    groupby_quantile_ndim_compat,
+    quantile_with_mask,
+)
 from pandas.core.sorting import (
     nargminmax,
     nargsort,
@@ -1688,6 +1696,61 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
 
         return arraylike.default_array_ufunc(self, ufunc, method, *inputs, **kwargs)
 
+    def _groupby_quantile(
+        self,
+        *,
+        qs: npt.NDArray[np.float64],
+        interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
+        ngroups: int,
+        ids: npt.NDArray[np.intp],
+        labels_for_lexsort: npt.NDArray[np.intp],
+    ):
+
+        mask = self.isna()
+
+        inference: DtypeObj | None = None
+
+        # TODO: 2023-01-26 we only have tests for the dt64/td64 cases here
+        if is_object_dtype(self.dtype):
+            raise TypeError("'quantile' cannot be performed against 'object' dtypes!")
+        elif is_integer_dtype(self.dtype):
+            npy_vals = self.to_numpy(dtype=float, na_value=np.nan)
+            inference = np.dtype(np.int64)
+        elif is_bool_dtype(self.dtype):
+            npy_vals = self.to_numpy(dtype=float, na_value=np.nan)
+        elif is_datetime64_dtype(self.dtype):
+            inference = self.dtype
+            npy_vals = np.asarray(self).astype(float)
+        elif is_timedelta64_dtype(self.dtype):
+            inference = self.dtype
+            npy_vals = np.asarray(self).astype(float)
+        elif is_float_dtype(self):
+            inference = np.dtype(np.float64)
+            npy_vals = self.to_numpy(dtype=float, na_value=np.nan)
+        else:
+            npy_vals = np.asarray(self)
+
+        npy_out = groupby_quantile_ndim_compat(
+            qs=qs,
+            interpolation=interpolation,
+            ngroups=ngroups,
+            ids=ids,
+            labels_for_lexsort=labels_for_lexsort,
+            npy_vals=npy_vals,
+            mask=np.asarray(mask),
+            result_mask=None,
+        )
+
+        if inference is not None:
+            # Check for edge case
+            if not (
+                is_integer_dtype(inference) and interpolation in {"linear", "midpoint"}
+            ):
+                assert isinstance(inference, np.dtype)  # for mypy
+                return npy_out.astype(inference)
+
+        return npy_out
+
 
 class ExtensionArraySupportsAnyAll(ExtensionArray):
     def any(self, *, skipna: bool = True) -> bool:
diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py
index 8324d4b2618f1..b13c9418264a3 100644
--- a/pandas/core/arrays/masked.py
+++ b/pandas/core/arrays/masked.py
@@ -77,7 +77,10 @@
     masked_accumulations,
     masked_reductions,
 )
-from pandas.core.array_algos.quantile import quantile_with_mask
+from pandas.core.array_algos.quantile import (
+    groupby_quantile_ndim_compat,
+    quantile_with_mask,
+)
 from pandas.core.arraylike import OpsMixin
 from pandas.core.arrays import ExtensionArray
 from pandas.core.construction import ensure_wrapped_if_datetimelike
@@ -1383,3 +1386,45 @@ def _accumulate(
 
         data, mask = op(data, mask, skipna=skipna, **kwargs)
         return type(self)(data, mask, copy=False)
+
+    # ------------------------------------------------------------------
+
+    @doc(ExtensionArray._groupby_quantile)
+    def _groupby_quantile(
+        self,
+        *,
+        qs: npt.NDArray[np.float64],
+        interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
+        ngroups: int,
+        ids: npt.NDArray[np.intp],
+        labels_for_lexsort: npt.NDArray[np.intp],
+    ):
+
+        nqs = len(qs)
+
+        mask = self._mask
+        result_mask = np.zeros((ngroups, nqs), dtype=np.bool_)
+
+        npy_vals = self.to_numpy(dtype=float, na_value=np.nan)
+
+        npy_out = groupby_quantile_ndim_compat(
+            qs=qs,
+            interpolation=interpolation,
+            ngroups=ngroups,
+            ids=ids,
+            labels_for_lexsort=labels_for_lexsort,
+            npy_vals=npy_vals,
+            mask=mask,
+            result_mask=result_mask,
+        )
+        result_mask = result_mask.reshape(ngroups * nqs)
+
+        if interpolation in {"linear", "midpoint"} and not is_float_dtype(self.dtype):
+            from pandas.core.arrays import FloatingArray
+
+            return FloatingArray(npy_out, result_mask)
+        else:
+            return type(self)(
+                npy_out.astype(self.dtype.numpy_dtype),
+                result_mask,
+            )
diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index a54c524094b23..392b94950079f 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -47,7 +47,6 @@ class providing the base-class of operations.
     ArrayLike,
     Axis,
     AxisInt,
-    DtypeObj,
     FillnaOptions,
     IndexLabel,
     NDFrameT,
@@ -71,16 +70,12 @@ class providing the base-class of operations.
 
 from pandas.core.dtypes.cast import ensure_dtype_can_hold_na
 from pandas.core.dtypes.common import (
-    is_bool_dtype,
-    is_datetime64_dtype,
-    is_float_dtype,
     is_hashable,
     is_integer,
     is_integer_dtype,
     is_numeric_dtype,
     is_object_dtype,
     is_scalar,
-    is_timedelta64_dtype,
 )
 from pandas.core.dtypes.missing import (
     isna,
@@ -93,6 +88,7 @@ class providing the base-class of operations.
     sample,
 )
 from pandas.core._numba import executor
+from pandas.core.array_algos.quantile import groupby_quantile_ndim_compat
 from pandas.core.arrays import (
     BaseMaskedArray,
     BooleanArray,
@@ -3002,7 +2998,9 @@ def _nth(
     def quantile(
         self,
         q: float | AnyArrayLike = 0.5,
-        interpolation: str = "linear",
+        interpolation: Literal[
+            "linear", "lower", "higher", "nearest", "midpoint"
+        ] = "linear",
         numeric_only: bool = False,
     ):
         """
@@ -3047,73 +3045,6 @@ def quantile(
         b    3.0
         """
 
-        def pre_processor(vals: ArrayLike) -> tuple[np.ndarray, DtypeObj | None]:
-            if is_object_dtype(vals):
-                raise TypeError(
-                    "'quantile' cannot be performed against 'object' dtypes!"
-                )
-
-            inference: DtypeObj | None = None
-            if isinstance(vals, BaseMaskedArray) and is_numeric_dtype(vals.dtype):
-                out = vals.to_numpy(dtype=float, na_value=np.nan)
-                inference = vals.dtype
-            elif is_integer_dtype(vals.dtype):
-                if isinstance(vals, ExtensionArray):
-                    out = vals.to_numpy(dtype=float, na_value=np.nan)
-                else:
-                    out = vals
-                inference = np.dtype(np.int64)
-            elif is_bool_dtype(vals.dtype) and isinstance(vals, ExtensionArray):
-                out = vals.to_numpy(dtype=float, na_value=np.nan)
-            elif is_datetime64_dtype(vals.dtype):
-                inference = vals.dtype
-                out = np.asarray(vals).astype(float)
-            elif is_timedelta64_dtype(vals.dtype):
-                inference = vals.dtype
-                out = np.asarray(vals).astype(float)
-            elif isinstance(vals, ExtensionArray) and is_float_dtype(vals):
-                inference = np.dtype(np.float64)
-                out = vals.to_numpy(dtype=float, na_value=np.nan)
-            else:
-                out = np.asarray(vals)
-
-            return out, inference
-
-        def post_processor(
-            vals: np.ndarray,
-            inference: DtypeObj | None,
-            result_mask: np.ndarray | None,
-            orig_vals: ArrayLike,
-        ) -> ArrayLike:
-            if inference:
-                # Check for edge case
-                if isinstance(orig_vals, BaseMaskedArray):
-                    assert result_mask is not None  # for mypy
-
-                    if interpolation in {"linear", "midpoint"} and not is_float_dtype(
-                        orig_vals
-                    ):
-                        return FloatingArray(vals, result_mask)
-                    else:
-                        # Item "ExtensionDtype" of "Union[ExtensionDtype, str,
-                        # dtype[Any], Type[object]]" has no attribute "numpy_dtype"
-                        # [union-attr]
-                        return type(orig_vals)(
-                            vals.astype(
-                                inference.numpy_dtype  # type: ignore[union-attr]
-                            ),
-                            result_mask,
-                        )
-
-                elif not (
-                    is_integer_dtype(inference)
-                    and interpolation in {"linear", "midpoint"}
-                ):
-                    assert isinstance(inference, np.dtype)  # for mypy
-                    return vals.astype(inference)
-
-            return vals
-
         orig_scalar = is_scalar(q)
         if orig_scalar:
             # error: Incompatible types in assignment (expression has type "List[
@@ -3124,11 +3055,6 @@ def post_processor(
 
         qs = np.array(q, dtype=np.float64)
         ids, _, ngroups = self.grouper.group_info
-        nqs = len(qs)
-
-        func = partial(
-            libgroupby.group_quantile, labels=ids, qs=qs, interpolation=interpolation
-        )
 
         # Put '-1' (NaN) labels as the last group so it does not interfere
         # with the calculations. Note: length check avoids failure on empty
@@ -3137,51 +3063,37 @@ def post_processor(
         labels_for_lexsort = np.where(ids == -1, na_label_for_sorting, ids)
 
         def blk_func(values: ArrayLike) -> ArrayLike:
-            orig_vals = values
-            if isinstance(values, BaseMaskedArray):
-                mask = values._mask
-                result_mask = np.zeros((ngroups, nqs), dtype=np.bool_)
-            else:
-                mask = isna(values)
-                result_mask = None
-
-            vals, inference = pre_processor(values)
-
-            ncols = 1
-            if vals.ndim == 2:
-                ncols = vals.shape[0]
-                shaped_labels = np.broadcast_to(
-                    labels_for_lexsort, (ncols, len(labels_for_lexsort))
+            if isinstance(values, ExtensionArray):
+                return values._groupby_quantile(
+                    qs=qs,
+                    interpolation=interpolation,
+                    ngroups=ngroups,
+                    ids=ids,
+                    labels_for_lexsort=labels_for_lexsort,
                 )
-            else:
-                shaped_labels = labels_for_lexsort
-
-            out = np.empty((ncols, ngroups, nqs), dtype=np.float64)
 
-            # Get an index of values sorted by values and then labels
-            order = (vals, shaped_labels)
-            sort_arr = np.lexsort(order).astype(np.intp, copy=False)
-
-            if vals.ndim == 1:
-                # Ea is always 1d
-                func(
-                    out[0],
-                    values=vals,
-                    mask=mask,
-                    sort_indexer=sort_arr,
-                    result_mask=result_mask,
+            if is_object_dtype(values.dtype):
+                raise TypeError(
+                    "'quantile' cannot be performed against 'object' dtypes!"
                 )
-            else:
-                for i in range(ncols):
-                    func(out[i], values=vals[i], mask=mask[i], sort_indexer=sort_arr[i])
 
-            if vals.ndim == 1:
-                out = out.ravel("K")
-                if result_mask is not None:
-                    result_mask = result_mask.ravel("K")
-            else:
-                out = out.reshape(ncols, ngroups * nqs)
-            return post_processor(out, inference, result_mask, orig_vals)
+            npy_out = groupby_quantile_ndim_compat(
+                qs=qs,
+                interpolation=interpolation,
+                ngroups=ngroups,
+                ids=ids,
+                labels_for_lexsort=labels_for_lexsort,
+                npy_vals=values,
+                mask=isna(values),
+                result_mask=None,
+            )
+
+            if is_integer_dtype(values.dtype) and interpolation not in {
+                "linear",
+                "midpoint",
+            }:
+                return npy_out.astype(np.dtype(np.int64))
+            return npy_out
 
         data = self._get_data_to_aggregate(numeric_only=numeric_only, name="quantile")
         res_mgr = data.grouped_reduce(blk_func)
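
For orientation, here is a minimal usage sketch of the two paths this patch reorganizes; the frame and column names below are illustrative and not taken from the PR. A numpy-backed column is handled inline in blk_func via groupby_quantile_ndim_compat, while a nullable (masked) column dispatches to BaseMaskedArray._groupby_quantile and comes back as a masked array.

import numpy as np
import pandas as pd

# Illustrative data only: one numpy-backed int64 column and one nullable Int64 column.
df = pd.DataFrame(
    {
        "key": ["a", "a", "b", "b"],
        "np_ints": np.array([1, 2, 3, 4], dtype=np.int64),
        "nullable": pd.array([1, None, 3, 4], dtype="Int64"),
    }
)

# "np_ints" stays on the ndarray path inside blk_func; "nullable" goes through
# BaseMaskedArray._groupby_quantile and, with linear interpolation, is returned
# as the nullable Float64 dtype rather than plain float64.
result = df.groupby("key").quantile(q=[0.25, 0.75], interpolation="linear")
print(result.dtypes)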