From 89c97cfbb6e6b4a999dad62167ec2c4a6a12e986 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 26 Jan 2023 13:09:48 -0800 Subject: [PATCH 1/9] REF: implemeent EA.groupby_quantile --- pandas/core/arrays/base.py | 148 +++++++++++++++++++++++++++++++++ pandas/core/groupby/groupby.py | 9 ++ 2 files changed, 157 insertions(+) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index c261a41e1e77e..bc1f2cbb1189a 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -1688,6 +1688,154 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): return arraylike.default_array_ufunc(self, ufunc, method, *inputs, **kwargs) + def groupby_quantile( + self, + *, + qs: npt.NDArray[np.float64], + interpolation: str, + ngroups: int, + ids: npt.NDArray[np.intp], + labels_for_lexsort: npt.NDArray[np.intp], + ): + from functools import partial + + from pandas._libs import groupby as libgroupby + + from pandas.core.dtypes.common import ( + is_bool_dtype, + is_float_dtype, + is_integer_dtype, + is_numeric_dtype, + is_object_dtype, + ) + + from pandas.core.arrays import ( + BaseMaskedArray, + FloatingArray, + ) + + nqs = len(qs) + + func = partial( + libgroupby.group_quantile, labels=ids, qs=qs, interpolation=interpolation + ) + + def pre_processor(vals: ArrayLike) -> tuple[np.ndarray, Dtype | None]: + if is_object_dtype(vals): + raise TypeError( + "'quantile' cannot be performed against 'object' dtypes!" + ) + + inference: Dtype | None = None + if isinstance(vals, BaseMaskedArray) and is_numeric_dtype(vals.dtype): + out = vals.to_numpy(dtype=float, na_value=np.nan) + inference = vals.dtype + elif is_integer_dtype(vals.dtype): + if isinstance(vals, ExtensionArray): + out = vals.to_numpy(dtype=float, na_value=np.nan) + else: + out = vals + inference = np.dtype(np.int64) + elif is_bool_dtype(vals.dtype) and isinstance(vals, ExtensionArray): + out = vals.to_numpy(dtype=float, na_value=np.nan) + elif is_datetime64_dtype(vals.dtype): + inference = vals.dtype + out = np.asarray(vals).astype(float) + elif is_timedelta64_dtype(vals.dtype): + inference = vals.dtype + out = np.asarray(vals).astype(float) + elif isinstance(vals, ExtensionArray) and is_float_dtype(vals): + inference = np.dtype(np.float64) + out = vals.to_numpy(dtype=float, na_value=np.nan) + else: + out = np.asarray(vals) + + return out, inference + + def post_processor( + vals: np.ndarray, + inference: Dtype | None, + result_mask: np.ndarray | None, + orig_vals: ArrayLike, + ) -> ArrayLike: + if inference: + # Check for edge case + if isinstance(orig_vals, BaseMaskedArray): + assert result_mask is not None # for mypy + + if interpolation in {"linear", "midpoint"} and not is_float_dtype( + orig_vals + ): + return FloatingArray(vals, result_mask) + else: + # Item "ExtensionDtype" of "Union[ExtensionDtype, str, + # dtype[Any], Type[object]]" has no attribute "numpy_dtype" + # [union-attr] + return type(orig_vals)( + vals.astype( + inference.numpy_dtype # type: ignore[union-attr] + ), + result_mask, + ) + + elif not ( + is_integer_dtype(inference) + and interpolation in {"linear", "midpoint"} + ): + assert isinstance(inference, np.dtype) # for mypy + return vals.astype(inference) + + return vals + + def blk_func(values: ArrayLike) -> ArrayLike: + orig_vals = values + if isinstance(values, BaseMaskedArray): + mask = values._mask + result_mask = np.zeros((ngroups, nqs), dtype=np.bool_) + else: + mask = isna(values) + result_mask = None + + vals, inference = pre_processor(values) + + ncols = 1 + if vals.ndim == 2: + ncols = vals.shape[0] + shaped_labels = np.broadcast_to( + labels_for_lexsort, (ncols, len(labels_for_lexsort)) + ) + else: + shaped_labels = labels_for_lexsort + + out = np.empty((ncols, ngroups, nqs), dtype=np.float64) + + # Get an index of values sorted by values and then labels + order = (vals, shaped_labels) + sort_arr = np.lexsort(order).astype(np.intp, copy=False) + + if vals.ndim == 1: + # Ea is always 1d + func( + out[0], + values=vals, + mask=mask, + sort_indexer=sort_arr, + result_mask=result_mask, + ) + else: + for i in range(ncols): + func(out[i], values=vals[i], mask=mask[i], sort_indexer=sort_arr[i]) + + if vals.ndim == 1: + out = out.ravel("K") + if result_mask is not None: + result_mask = result_mask.ravel("K") + else: + out = out.reshape(ncols, ngroups * nqs) + return post_processor(out, inference, result_mask, orig_vals) + + return blk_func(self) + class ExtensionArraySupportsAnyAll(ExtensionArray): def any(self, *, skipna: bool = True) -> bool: diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index d129e8b37a350..ced2f28c3b135 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -3259,6 +3259,15 @@ def post_processor( labels_for_lexsort = np.where(ids == -1, na_label_for_sorting, ids) def blk_func(values: ArrayLike) -> ArrayLike: + if isinstance(values, ExtensionArray): + return values.groupby_quantile( + qs=qs, + interpolation=interpolation, + ngroups=ngroups, + ids=ids, + labels_for_lexsort=labels_for_lexsort, + ) + orig_vals = values if isinstance(values, BaseMaskedArray): mask = values._mask From df77e668590481d6037d1d564c930b5b74d40ade Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 26 Jan 2023 13:20:39 -0800 Subject: [PATCH 2/9] REF: simplify groupby_quantile --- pandas/core/arrays/base.py | 224 +++++++++++++++++-------------------- 1 file changed, 103 insertions(+), 121 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index bc1f2cbb1189a..7760b674c842c 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -8,6 +8,7 @@ """ from __future__ import annotations +from functools import partial import operator from typing import ( TYPE_CHECKING, @@ -24,12 +25,16 @@ import numpy as np -from pandas._libs import lib +from pandas._libs import ( + groupby as libgroupby, + lib, +) from pandas._typing import ( ArrayLike, AstypeArg, AxisInt, Dtype, + DtypeObj, FillnaOptions, PositionalIndexer, ScalarIndexer, @@ -55,9 +60,14 @@ from pandas.core.dtypes.cast import maybe_cast_to_extension_array from pandas.core.dtypes.common import ( + is_bool_dtype, is_datetime64_dtype, is_dtype_equal, + is_float_dtype, + is_integer_dtype, is_list_like, + is_numeric_dtype, + is_object_dtype, is_scalar, is_timedelta64_dtype, pandas_dtype, @@ -1697,17 +1707,6 @@ def groupby_quantile( ids: npt.NDArray[np.intp], labels_for_lexsort: npt.NDArray[np.intp], ): - from functools import partial - - from pandas._libs import groupby as libgroupby - - from pandas.core.dtypes.common import ( - is_bool_dtype, - is_float_dtype, - is_integer_dtype, - is_numeric_dtype, - is_object_dtype, - ) from pandas.core.arrays import ( BaseMaskedArray, @@ -1716,125 +1715,108 @@ def groupby_quantile( nqs = len(qs) - func = partial( - libgroupby.group_quantile, labels=ids, qs=qs, interpolation=interpolation - ) - - def pre_processor(vals: ArrayLike) -> tuple[np.ndarray, Dtype | None]: - if is_object_dtype(vals): - raise TypeError( - "'quantile' cannot be performed against 'object' dtypes!" - ) - - inference: Dtype | None = None - if isinstance(vals, BaseMaskedArray) and is_numeric_dtype(vals.dtype): - out = vals.to_numpy(dtype=float, na_value=np.nan) - inference = vals.dtype - elif is_integer_dtype(vals.dtype): - if isinstance(vals, ExtensionArray): - out = vals.to_numpy(dtype=float, na_value=np.nan) - else: - out = vals - inference = np.dtype(np.int64) - elif is_bool_dtype(vals.dtype) and isinstance(vals, ExtensionArray): - out = vals.to_numpy(dtype=float, na_value=np.nan) - elif is_datetime64_dtype(vals.dtype): - inference = vals.dtype - out = np.asarray(vals).astype(float) - elif is_timedelta64_dtype(vals.dtype): - inference = vals.dtype - out = np.asarray(vals).astype(float) - elif isinstance(vals, ExtensionArray) and is_float_dtype(vals): - inference = np.dtype(np.float64) - out = vals.to_numpy(dtype=float, na_value=np.nan) - else: - out = np.asarray(vals) - - return out, inference - - def post_processor( - vals: np.ndarray, - inference: Dtype | None, - result_mask: np.ndarray | None, - orig_vals: ArrayLike, - ) -> ArrayLike: - if inference: - # Check for edge case - if isinstance(orig_vals, BaseMaskedArray): - assert result_mask is not None # for mypy - - if interpolation in {"linear", "midpoint"} and not is_float_dtype( - orig_vals - ): - return FloatingArray(vals, result_mask) - else: - # Item "ExtensionDtype" of "Union[ExtensionDtype, str, - # dtype[Any], Type[object]]" has no attribute "numpy_dtype" - # [union-attr] - return type(orig_vals)( - vals.astype( - inference.numpy_dtype # type: ignore[union-attr] - ), - result_mask, - ) - - elif not ( - is_integer_dtype(inference) - and interpolation in {"linear", "midpoint"} - ): - assert isinstance(inference, np.dtype) # for mypy - return vals.astype(inference) - - return vals - - def blk_func(values: ArrayLike) -> ArrayLike: - orig_vals = values - if isinstance(values, BaseMaskedArray): - mask = values._mask - result_mask = np.zeros((ngroups, nqs), dtype=np.bool_) + if isinstance(self, BaseMaskedArray): + mask = self._mask + result_mask = np.zeros((ngroups, nqs), dtype=np.bool_) + else: + mask = self.isna() + result_mask = None + + inference: DtypeObj | None = None + + if is_object_dtype(self.dtype): + raise TypeError("'quantile' cannot be performed against 'object' dtypes!") + elif isinstance(self, BaseMaskedArray) and is_numeric_dtype(self.dtype): + npy_vals = self.to_numpy(dtype=float, na_value=np.nan) + inference = self.dtype + elif is_integer_dtype(self.dtype): + if isinstance(self, ExtensionArray): + npy_vals = self.to_numpy(dtype=float, na_value=np.nan) else: - mask = isna(values) - result_mask = None + npy_vals = self + inference = np.dtype(np.int64) + elif is_bool_dtype(self.dtype) and isinstance(self, ExtensionArray): + npy_vals = self.to_numpy(dtype=float, na_value=np.nan) + elif is_datetime64_dtype(self.dtype): + inference = self.dtype + npy_vals = np.asarray(self).astype(float) + elif is_timedelta64_dtype(self.dtype): + inference = self.dtype + npy_vals = np.asarray(self).astype(float) + elif isinstance(self, ExtensionArray) and is_float_dtype(self): + inference = np.dtype(np.float64) + npy_vals = self.to_numpy(dtype=float, na_value=np.nan) + else: + npy_vals = np.asarray(self) - vals, inference = pre_processor(values) + ncols = 1 + if npy_vals.ndim == 2: + ncols = npy_vals.shape[0] + shaped_labels = np.broadcast_to( + labels_for_lexsort, (ncols, len(labels_for_lexsort)) + ) + else: + shaped_labels = labels_for_lexsort - ncols = 1 - if vals.ndim == 2: - ncols = vals.shape[0] - shaped_labels = np.broadcast_to( - labels_for_lexsort, (ncols, len(labels_for_lexsort)) - ) - else: - shaped_labels = labels_for_lexsort + npy_out = np.empty((ncols, ngroups, nqs), dtype=np.float64) - out = np.empty((ncols, ngroups, nqs), dtype=np.float64) + # Get an index of values sorted by values and then labels + order = (npy_vals, shaped_labels) + sort_arr = np.lexsort(order).astype(np.intp, copy=False) - # Get an index of values sorted by values and then labels - order = (vals, shaped_labels) - sort_arr = np.lexsort(order).astype(np.intp, copy=False) + func = partial( + libgroupby.group_quantile, labels=ids, qs=qs, interpolation=interpolation + ) - if vals.ndim == 1: - # Ea is always 1d + if npy_vals.ndim == 1: + func( + npy_out[0], + values=npy_vals, + mask=mask, + sort_indexer=sort_arr, + result_mask=result_mask, + ) + else: + for i in range(ncols): func( - out[0], - values=vals, - mask=mask, - sort_indexer=sort_arr, - result_mask=result_mask, + npy_out[i], + values=npy_vals[i], + mask=mask[i], + sort_indexer=sort_arr[i], ) - else: - for i in range(ncols): - func(out[i], values=vals[i], mask=mask[i], sort_indexer=sort_arr[i]) - if vals.ndim == 1: - out = out.ravel("K") - if result_mask is not None: - result_mask = result_mask.ravel("K") - else: - out = out.reshape(ncols, ngroups * nqs) - return post_processor(out, inference, result_mask, orig_vals) + if npy_vals.ndim == 1: + npy_out = npy_out.ravel("K") + if result_mask is not None: + result_mask = result_mask.ravel("K") + else: + npy_out = npy_out.reshape(ncols, ngroups * nqs) + + if inference is not None: + # Check for edge case + if isinstance(self, BaseMaskedArray): + assert result_mask is not None # for mypy - return blk_func(self) + if interpolation in {"linear", "midpoint"} and not is_float_dtype(self): + return FloatingArray(npy_out, result_mask) + else: + # Item "ExtensionDtype" of "Union[ExtensionDtype, str, + # dtype[Any], Type[object]]" has no attribute "numpy_dtype" + # [union-attr] + return type(self)( + npy_out.astype( + inference.numpy_dtype # type: ignore[union-attr] + ), + result_mask, + ) + + elif not ( + is_integer_dtype(inference) and interpolation in {"linear", "midpoint"} + ): + assert isinstance(inference, np.dtype) # for mypy + return npy_out.astype(inference) + + return npy_out class ExtensionArraySupportsAnyAll(ExtensionArray): From a9ee791209c4d667af61e3ccc94f2814a8051748 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 26 Jan 2023 13:45:13 -0800 Subject: [PATCH 3/9] REF: separate out BaseMaskedArray.groupby_quantile --- pandas/core/arrays/base.py | 48 +++------------ pandas/core/arrays/masked.py | 55 +++++++++++++++++ pandas/core/groupby/groupby.py | 107 +++++++-------------------------- 3 files changed, 84 insertions(+), 126 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 7760b674c842c..ac71e9ba7ebd6 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -66,7 +66,6 @@ is_float_dtype, is_integer_dtype, is_list_like, - is_numeric_dtype, is_object_dtype, is_scalar, is_timedelta64_dtype, @@ -1708,34 +1707,19 @@ def groupby_quantile( labels_for_lexsort: npt.NDArray[np.intp], ): - from pandas.core.arrays import ( - BaseMaskedArray, - FloatingArray, - ) - nqs = len(qs) - if isinstance(self, BaseMaskedArray): - mask = self._mask - result_mask = np.zeros((ngroups, nqs), dtype=np.bool_) - else: - mask = self.isna() - result_mask = None + mask = self.isna() inference: DtypeObj | None = None + # TODO: 2023-01-26 we only have tests for the dt64/td64 cases here if is_object_dtype(self.dtype): raise TypeError("'quantile' cannot be performed against 'object' dtypes!") - elif isinstance(self, BaseMaskedArray) and is_numeric_dtype(self.dtype): - npy_vals = self.to_numpy(dtype=float, na_value=np.nan) - inference = self.dtype elif is_integer_dtype(self.dtype): - if isinstance(self, ExtensionArray): - npy_vals = self.to_numpy(dtype=float, na_value=np.nan) - else: - npy_vals = self + npy_vals = self.to_numpy(dtype=float, na_value=np.nan) inference = np.dtype(np.int64) - elif is_bool_dtype(self.dtype) and isinstance(self, ExtensionArray): + elif is_bool_dtype(self.dtype): npy_vals = self.to_numpy(dtype=float, na_value=np.nan) elif is_datetime64_dtype(self.dtype): inference = self.dtype @@ -1743,7 +1727,7 @@ def groupby_quantile( elif is_timedelta64_dtype(self.dtype): inference = self.dtype npy_vals = np.asarray(self).astype(float) - elif isinstance(self, ExtensionArray) and is_float_dtype(self): + elif is_float_dtype(self): inference = np.dtype(np.float64) npy_vals = self.to_numpy(dtype=float, na_value=np.nan) else: @@ -1774,7 +1758,7 @@ def groupby_quantile( values=npy_vals, mask=mask, sort_indexer=sort_arr, - result_mask=result_mask, + result_mask=None, ) else: for i in range(ncols): @@ -1787,30 +1771,12 @@ def groupby_quantile( if npy_vals.ndim == 1: npy_out = npy_out.ravel("K") - if result_mask is not None: - result_mask = result_mask.ravel("K") else: npy_out = npy_out.reshape(ncols, ngroups * nqs) if inference is not None: # Check for edge case - if isinstance(self, BaseMaskedArray): - assert result_mask is not None # for mypy - - if interpolation in {"linear", "midpoint"} and not is_float_dtype(self): - return FloatingArray(npy_out, result_mask) - else: - # Item "ExtensionDtype" of "Union[ExtensionDtype, str, - # dtype[Any], Type[object]]" has no attribute "numpy_dtype" - # [union-attr] - return type(self)( - npy_out.astype( - inference.numpy_dtype # type: ignore[union-attr] - ), - result_mask, - ) - - elif not ( + if not ( is_integer_dtype(inference) and interpolation in {"linear", "midpoint"} ): assert isinstance(inference, np.dtype) # for mypy diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index d45fe05d52937..870f5ca65c687 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -14,6 +15,7 @@ import numpy as np from pandas._libs import ( + groupby as libgroupby, lib, missing as libmissing, ) @@ -1383,3 +1385,56 @@ def _accumulate( data, mask = op(data, mask, skipna=skipna, **kwargs) return type(self)(data, mask, copy=False) + + # ------------------------------------------------------------------ + + def groupby_quantile( + self, + *, + qs: npt.NDArray[np.float64], + interpolation: str, + ngroups: int, + ids: npt.NDArray[np.intp], + labels_for_lexsort: npt.NDArray[np.intp], + ): + + nqs = len(qs) + + mask = self._mask + result_mask = np.zeros((ngroups, nqs), dtype=np.bool_) + + npy_vals = self.to_numpy(dtype=float, na_value=np.nan) + + ncols = 1 + shaped_labels = labels_for_lexsort + + npy_out = np.empty((ncols, ngroups, nqs), dtype=np.float64) + + # Get an index of values sorted by values and then labels + order = (npy_vals, shaped_labels) + sort_arr = np.lexsort(order).astype(np.intp, copy=False) + + func = partial( + libgroupby.group_quantile, labels=ids, qs=qs, interpolation=interpolation + ) + + func( + npy_out[0], + values=npy_vals, + mask=mask, + sort_indexer=sort_arr, + result_mask=result_mask, + ) + + npy_out = npy_out.ravel("K") + result_mask = result_mask.ravel("K") + + if interpolation in {"linear", "midpoint"} and not is_float_dtype(self.dtype): + from pandas.core.arrays import FloatingArray + + return FloatingArray(npy_out, result_mask) + else: + return type(self)( + npy_out.astype(self.dtype.numpy_dtype), + result_mask, + ) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index ced2f28c3b135..73c7e24b737d0 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -49,7 +49,6 @@ class providing the base-class of operations. ArrayLike, Axis, AxisInt, - Dtype, FillnaOptions, IndexLabel, NDFrameT, @@ -73,15 +72,11 @@ class providing the base-class of operations. from pandas.core.dtypes.cast import ensure_dtype_can_hold_na from pandas.core.dtypes.common import ( - is_bool_dtype, - is_datetime64_dtype, - is_float_dtype, is_integer, is_integer_dtype, is_numeric_dtype, is_object_dtype, is_scalar, - is_timedelta64_dtype, ) from pandas.core.dtypes.missing import ( isna, @@ -3169,73 +3164,6 @@ def quantile( f"numeric_only={numeric_only} and dtype {self.obj.dtype}" ) - def pre_processor(vals: ArrayLike) -> tuple[np.ndarray, Dtype | None]: - if is_object_dtype(vals): - raise TypeError( - "'quantile' cannot be performed against 'object' dtypes!" - ) - - inference: Dtype | None = None - if isinstance(vals, BaseMaskedArray) and is_numeric_dtype(vals.dtype): - out = vals.to_numpy(dtype=float, na_value=np.nan) - inference = vals.dtype - elif is_integer_dtype(vals.dtype): - if isinstance(vals, ExtensionArray): - out = vals.to_numpy(dtype=float, na_value=np.nan) - else: - out = vals - inference = np.dtype(np.int64) - elif is_bool_dtype(vals.dtype) and isinstance(vals, ExtensionArray): - out = vals.to_numpy(dtype=float, na_value=np.nan) - elif is_datetime64_dtype(vals.dtype): - inference = vals.dtype - out = np.asarray(vals).astype(float) - elif is_timedelta64_dtype(vals.dtype): - inference = vals.dtype - out = np.asarray(vals).astype(float) - elif isinstance(vals, ExtensionArray) and is_float_dtype(vals): - inference = np.dtype(np.float64) - out = vals.to_numpy(dtype=float, na_value=np.nan) - else: - out = np.asarray(vals) - - return out, inference - - def post_processor( - vals: np.ndarray, - inference: Dtype | None, - result_mask: np.ndarray | None, - orig_vals: ArrayLike, - ) -> ArrayLike: - if inference: - # Check for edge case - if isinstance(orig_vals, BaseMaskedArray): - assert result_mask is not None # for mypy - - if interpolation in {"linear", "midpoint"} and not is_float_dtype( - orig_vals - ): - return FloatingArray(vals, result_mask) - else: - # Item "ExtensionDtype" of "Union[ExtensionDtype, str, - # dtype[Any], Type[object]]" has no attribute "numpy_dtype" - # [union-attr] - return type(orig_vals)( - vals.astype( - inference.numpy_dtype # type: ignore[union-attr] - ), - result_mask, - ) - - elif not ( - is_integer_dtype(inference) - and interpolation in {"linear", "midpoint"} - ): - assert isinstance(inference, np.dtype) # for mypy - return vals.astype(inference) - - return vals - orig_scalar = is_scalar(q) if orig_scalar: # error: Incompatible types in assignment (expression has type "List[ @@ -3268,15 +3196,17 @@ def blk_func(values: ArrayLike) -> ArrayLike: labels_for_lexsort=labels_for_lexsort, ) - orig_vals = values - if isinstance(values, BaseMaskedArray): - mask = values._mask - result_mask = np.zeros((ngroups, nqs), dtype=np.bool_) - else: - mask = isna(values) - result_mask = None + if is_object_dtype(values.dtype): + raise TypeError( + "'quantile' cannot be performed against 'object' dtypes!" + ) - vals, inference = pre_processor(values) + inference: np.dtype | None = None + if is_integer_dtype(values.dtype): + vals = values + inference = np.dtype(np.int64) + else: + vals = np.asarray(values) ncols = 1 if vals.ndim == 2: @@ -3293,14 +3223,15 @@ def blk_func(values: ArrayLike) -> ArrayLike: order = (vals, shaped_labels) sort_arr = np.lexsort(order).astype(np.intp, copy=False) + mask = isna(values) + if vals.ndim == 1: - # Ea is always 1d func( out[0], values=vals, mask=mask, sort_indexer=sort_arr, - result_mask=result_mask, + result_mask=None, ) else: for i in range(ncols): @@ -3308,11 +3239,17 @@ def blk_func(values: ArrayLike) -> ArrayLike: if vals.ndim == 1: out = out.ravel("K") - if result_mask is not None: - result_mask = result_mask.ravel("K") else: out = out.reshape(ncols, ngroups * nqs) - return post_processor(out, inference, result_mask, orig_vals) + + if inference: + # Check for edge case + if not ( + is_integer_dtype(inference) + and interpolation in {"linear", "midpoint"} + ): + return out.astype(inference) + return out obj = self._obj_with_exclusions is_ser = obj.ndim == 1 From c44120845be37d0891a8a4019f20d6ac29fd83e4 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 26 Jan 2023 15:03:34 -0800 Subject: [PATCH 4/9] REF: implement groupby_quantile_ndim_compat --- pandas/core/array_algos/quantile.py | 91 +++++++++++++++++++++++++++++ pandas/core/arrays/base.py | 61 +++++-------------- pandas/core/arrays/masked.py | 36 +++++------- pandas/core/groupby/groupby.py | 70 ++++++---------------- 4 files changed, 135 insertions(+), 123 deletions(-) diff --git a/pandas/core/array_algos/quantile.py b/pandas/core/array_algos/quantile.py index d3d9cb1b29b9a..7c29950e4e633 100644 --- a/pandas/core/array_algos/quantile.py +++ b/pandas/core/array_algos/quantile.py @@ -1,7 +1,10 @@ from __future__ import annotations +from functools import partial + import numpy as np +from pandas._libs import groupby as libgroupby from pandas._typing import ( ArrayLike, Scalar, @@ -15,6 +18,94 @@ ) +def groupby_quantile_ndim_compat( + *, + qs: npt.NDArray[np.float64], + interpolation: str, + ngroups: int, + ids: npt.NDArray[np.intp], + labels_for_lexsort: npt.NDArray[np.intp], + npy_vals: np.ndarray, + mask: npt.NDArray[np.bool_], + result_mask: npt.NDArray[np.bool_] | None, +) -> np.ndarray: + """ + Compatibility layer to handle either 1D arrays or 2D ndarrays in + GroupBy.quantile. Located here to be available to + ExtensionArray.groupby_quantile for dispatching after casting to numpy. + + Parameters + ---------- + qs : np.ndarray[float64] + Values between 0 and 1 providing the quantile(s) to compute. + interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'} + Method to use when the desired quantile falls between two points. + ngroups : int + The number of groupby groups. + ids : np.ndarray[intp] + Group labels. + labels_for_lexsort : np.ndarray[intp] + Group labels, but with -1s moved moved to the end to sort NAs last. + npy_vals : np.ndarray + The values for which we are computing quantiles. + mask : np.ndarray[bool] + Locations to treat as NA. + result_mask : np.ndarray[bool] or None + If present, set to True for locations that should be treated as missing + a result. Modified in-place. + + Returns + ------- + np.ndarray + """ + nqs = len(qs) + + ncols = 1 + if npy_vals.ndim == 2: + ncols = npy_vals.shape[0] + shaped_labels = np.broadcast_to( + labels_for_lexsort, (ncols, len(labels_for_lexsort)) + ) + else: + shaped_labels = labels_for_lexsort + + npy_out = np.empty((ncols, ngroups, nqs), dtype=np.float64) + + # Get an index of values sorted by values and then labels + order = (npy_vals, shaped_labels) + sort_arr = np.lexsort(order).astype(np.intp, copy=False) + + func = partial( + libgroupby.group_quantile, labels=ids, qs=qs, interpolation=interpolation + ) + + if npy_vals.ndim == 1: + func( + npy_out[0], + values=npy_vals, + mask=mask, + sort_indexer=sort_arr, + result_mask=result_mask, + ) + else: + # if we ever did get here with non-None result_mask, we'd pass result_mask[i] + assert result_mask is None + for i in range(ncols): + func( + npy_out[i], + values=npy_vals[i], + mask=mask[i], + sort_indexer=sort_arr[i], + ) + + if npy_vals.ndim == 1: + npy_out = npy_out.reshape(ngroups * nqs) + else: + npy_out = npy_out.reshape(ncols, ngroups * nqs) + + return npy_out + + def quantile_compat( values: ArrayLike, qs: npt.NDArray[np.float64], interpolation: str ) -> ArrayLike: diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index ac71e9ba7ebd6..91df7045fb748 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -8,7 +8,6 @@ """ from __future__ import annotations -from functools import partial import operator from typing import ( TYPE_CHECKING, @@ -25,10 +24,7 @@ import numpy as np -from pandas._libs import ( - groupby as libgroupby, - lib, -) +from pandas._libs import lib from pandas._typing import ( ArrayLike, AstypeArg, @@ -91,7 +87,10 @@ rank, unique, ) -from pandas.core.array_algos.quantile import quantile_with_mask +from pandas.core.array_algos.quantile import ( + groupby_quantile_ndim_compat, + quantile_with_mask, +) from pandas.core.sorting import ( nargminmax, nargsort, @@ -1707,8 +1706,6 @@ def groupby_quantile( labels_for_lexsort: npt.NDArray[np.intp], ): - nqs = len(qs) - mask = self.isna() inference: DtypeObj | None = None @@ -1733,47 +1730,17 @@ def groupby_quantile( else: npy_vals = np.asarray(self) - ncols = 1 - if npy_vals.ndim == 2: - ncols = npy_vals.shape[0] - shaped_labels = np.broadcast_to( - labels_for_lexsort, (ncols, len(labels_for_lexsort)) - ) - else: - shaped_labels = labels_for_lexsort - - npy_out = np.empty((ncols, ngroups, nqs), dtype=np.float64) - - # Get an index of values sorted by values and then labels - order = (npy_vals, shaped_labels) - sort_arr = np.lexsort(order).astype(np.intp, copy=False) - - func = partial( - libgroupby.group_quantile, labels=ids, qs=qs, interpolation=interpolation + npy_out = groupby_quantile_ndim_compat( + qs=qs, + interpolation=interpolation, + ngroups=ngroups, + ids=ids, + labels_for_lexsort=labels_for_lexsort, + npy_vals=npy_vals, + mask=np.asarray(mask), + result_mask=None, ) - if npy_vals.ndim == 1: - func( - npy_out[0], - values=npy_vals, - mask=mask, - sort_indexer=sort_arr, - result_mask=None, - ) - else: - for i in range(ncols): - func( - npy_out[i], - values=npy_vals[i], - mask=mask[i], - sort_indexer=sort_arr[i], - ) - - if npy_vals.ndim == 1: - npy_out = npy_out.ravel("K") - else: - npy_out = npy_out.reshape(ncols, ngroups * nqs) - if inference is not None: # Check for edge case if not ( diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index 870f5ca65c687..fc341f8b3d781 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -1,6 +1,5 @@ from __future__ import annotations -from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -15,7 +14,6 @@ import numpy as np from pandas._libs import ( - groupby as libgroupby, lib, missing as libmissing, ) @@ -79,7 +77,10 @@ masked_accumulations, masked_reductions, ) -from pandas.core.array_algos.quantile import quantile_with_mask +from pandas.core.array_algos.quantile import ( + groupby_quantile_ndim_compat, + quantile_with_mask, +) from pandas.core.arraylike import OpsMixin from pandas.core.arrays import ExtensionArray from pandas.core.construction import ensure_wrapped_if_datetimelike @@ -1388,6 +1389,7 @@ def _accumulate( # ------------------------------------------------------------------ + @doc(ExtensionArray.groupby_quantile) def groupby_quantile( self, *, @@ -1405,29 +1407,17 @@ def groupby_quantile( npy_vals = self.to_numpy(dtype=float, na_value=np.nan) - ncols = 1 - shaped_labels = labels_for_lexsort - - npy_out = np.empty((ncols, ngroups, nqs), dtype=np.float64) - - # Get an index of values sorted by values and then labels - order = (npy_vals, shaped_labels) - sort_arr = np.lexsort(order).astype(np.intp, copy=False) - - func = partial( - libgroupby.group_quantile, labels=ids, qs=qs, interpolation=interpolation - ) - - func( - npy_out[0], - values=npy_vals, + npy_out = groupby_quantile_ndim_compat( + qs=qs, + interpolation=interpolation, + ngroups=ngroups, + ids=ids, + labels_for_lexsort=labels_for_lexsort, + npy_vals=npy_vals, mask=mask, - sort_indexer=sort_arr, result_mask=result_mask, ) - - npy_out = npy_out.ravel("K") - result_mask = result_mask.ravel("K") + result_mask = result_mask.reshape(ngroups * nqs) if interpolation in {"linear", "midpoint"} and not is_float_dtype(self.dtype): from pandas.core.arrays import FloatingArray diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 73c7e24b737d0..843fcce0406bd 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -89,6 +89,7 @@ class providing the base-class of operations. sample, ) from pandas.core._numba import executor +from pandas.core.array_algos.quantile import groupby_quantile_ndim_compat from pandas.core.arrays import ( BaseMaskedArray, BooleanArray, @@ -3174,11 +3175,6 @@ def quantile( qs = np.array(q, dtype=np.float64) ids, _, ngroups = self.grouper.group_info - nqs = len(qs) - - func = partial( - libgroupby.group_quantile, labels=ids, qs=qs, interpolation=interpolation - ) # Put '-1' (NaN) labels as the last group so it does not interfere # with the calculations. Note: length check avoids failure on empty @@ -3201,55 +3197,23 @@ def blk_func(values: ArrayLike) -> ArrayLike: "'quantile' cannot be performed against 'object' dtypes!" ) - inference: np.dtype | None = None - if is_integer_dtype(values.dtype): - vals = values - inference = np.dtype(np.int64) - else: - vals = np.asarray(values) - - ncols = 1 - if vals.ndim == 2: - ncols = vals.shape[0] - shaped_labels = np.broadcast_to( - labels_for_lexsort, (ncols, len(labels_for_lexsort)) - ) - else: - shaped_labels = labels_for_lexsort - - out = np.empty((ncols, ngroups, nqs), dtype=np.float64) - - # Get an index of values sorted by values and then labels - order = (vals, shaped_labels) - sort_arr = np.lexsort(order).astype(np.intp, copy=False) - - mask = isna(values) - - if vals.ndim == 1: - func( - out[0], - values=vals, - mask=mask, - sort_indexer=sort_arr, - result_mask=None, - ) - else: - for i in range(ncols): - func(out[i], values=vals[i], mask=mask[i], sort_indexer=sort_arr[i]) + npy_out = groupby_quantile_ndim_compat( + qs=qs, + interpolation=interpolation, + ngroups=ngroups, + ids=ids, + labels_for_lexsort=labels_for_lexsort, + npy_vals=values, + mask=isna(values), + result_mask=None, + ) - if vals.ndim == 1: - out = out.ravel("K") - else: - out = out.reshape(ncols, ngroups * nqs) - - if inference: - # Check for edge case - if not ( - is_integer_dtype(inference) - and interpolation in {"linear", "midpoint"} - ): - return out.astype(inference) - return out + if is_integer_dtype(values.dtype) and interpolation not in { + "linear", + "midpoint", + }: + return npy_out.astype(np.dtype(np.int64)) + return npy_out obj = self._obj_with_exclusions is_ser = obj.ndim == 1 From d9f053f1c79573df683e42ef8feb5b9638233914 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 26 Jan 2023 18:19:15 -0800 Subject: [PATCH 5/9] lint fixup --- pandas/core/array_algos/quantile.py | 3 ++- pandas/core/arrays/base.py | 2 +- pandas/core/arrays/masked.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pandas/core/array_algos/quantile.py b/pandas/core/array_algos/quantile.py index 7c29950e4e633..ab821993ca179 100644 --- a/pandas/core/array_algos/quantile.py +++ b/pandas/core/array_algos/quantile.py @@ -1,6 +1,7 @@ from __future__ import annotations from functools import partial +from typing import Literal import numpy as np @@ -21,7 +22,7 @@ def groupby_quantile_ndim_compat( *, qs: npt.NDArray[np.float64], - interpolation: str, + interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"], ngroups: int, ids: npt.NDArray[np.intp], labels_for_lexsort: npt.NDArray[np.intp], diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 91df7045fb748..91bd879256cf6 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -1700,7 +1700,7 @@ def groupby_quantile( self, *, qs: npt.NDArray[np.float64], - interpolation: str, + interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"], ngroups: int, ids: npt.NDArray[np.intp], labels_for_lexsort: npt.NDArray[np.intp], diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index fc341f8b3d781..df2ea58e7c00a 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -1394,7 +1394,7 @@ def groupby_quantile( self, *, qs: npt.NDArray[np.float64], - interpolation: str, + interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"], ngroups: int, ids: npt.NDArray[np.intp], labels_for_lexsort: npt.NDArray[np.intp], From 8c540fa1c53b11dd898ccd031a47693d10ca8082 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 26 Jan 2023 18:19:55 -0800 Subject: [PATCH 6/9] lint fixup --- pandas/core/array_algos/quantile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/array_algos/quantile.py b/pandas/core/array_algos/quantile.py index ab821993ca179..9a57d257dbbcb 100644 --- a/pandas/core/array_algos/quantile.py +++ b/pandas/core/array_algos/quantile.py @@ -84,7 +84,7 @@ def groupby_quantile_ndim_compat( func( npy_out[0], values=npy_vals, - mask=mask, + mask=mask.view(np.uint8), sort_indexer=sort_arr, result_mask=result_mask, ) @@ -95,7 +95,7 @@ def groupby_quantile_ndim_compat( func( npy_out[i], values=npy_vals[i], - mask=mask[i], + mask=mask[i].view(np.uint8), sort_indexer=sort_arr[i], ) From fcc0bcb3d250bf803e467be0d87715aa57fa0bb7 Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 27 Jan 2023 07:51:54 -0800 Subject: [PATCH 7/9] mypy fixup --- pandas/core/groupby/groupby.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 843fcce0406bd..7036d09ee2bd1 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -3115,7 +3115,9 @@ def _nth( def quantile( self, q: float | AnyArrayLike = 0.5, - interpolation: str = "linear", + interpolation: Literal[ + "linear", "lower", "higher", "nearest", "midpoint" + ] = "linear", numeric_only: bool = False, ): """ From ae135d6e9e18cbf81eabb2954a5722d8d9d83f04 Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 7 Feb 2023 13:38:26 -0800 Subject: [PATCH 8/9] REF: privatize groupby_quantile --- pandas/core/array_algos/quantile.py | 2 +- pandas/core/arrays/base.py | 2 +- pandas/core/arrays/masked.py | 2 +- pandas/core/groupby/groupby.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pandas/core/array_algos/quantile.py b/pandas/core/array_algos/quantile.py index 9a57d257dbbcb..4ad7a1549e207 100644 --- a/pandas/core/array_algos/quantile.py +++ b/pandas/core/array_algos/quantile.py @@ -33,7 +33,7 @@ def groupby_quantile_ndim_compat( """ Compatibility layer to handle either 1D arrays or 2D ndarrays in GroupBy.quantile. Located here to be available to - ExtensionArray.groupby_quantile for dispatching after casting to numpy. + ExtensionArray._groupby_quantile for dispatching after casting to numpy. Parameters ---------- diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 91bd879256cf6..de8eae511696b 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -1696,7 +1696,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): return arraylike.default_array_ufunc(self, ufunc, method, *inputs, **kwargs) - def groupby_quantile( + def _groupby_quantile( self, *, qs: npt.NDArray[np.float64], diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index 2749a411a9905..ce9c239e688a6 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -1390,7 +1390,7 @@ def _accumulate( # ------------------------------------------------------------------ @doc(ExtensionArray.groupby_quantile) - def groupby_quantile( + def _groupby_quantile( self, *, qs: npt.NDArray[np.float64], diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 39b25b0575e55..392b94950079f 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -3064,7 +3064,7 @@ def quantile( def blk_func(values: ArrayLike) -> ArrayLike: if isinstance(values, ExtensionArray): - return values.groupby_quantile( + return values._groupby_quantile( qs=qs, interpolation=interpolation, ngroups=ngroups, From f15222043b07a1867a4fdfbfadb0c6237ea4367d Mon Sep 17 00:00:00 2001 From: Brock Date: Tue, 7 Feb 2023 13:59:17 -0800 Subject: [PATCH 9/9] typo fixup --- pandas/core/arrays/masked.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index ce9c239e688a6..b13c9418264a3 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -1389,7 +1389,7 @@ def _accumulate( # ------------------------------------------------------------------ - @doc(ExtensionArray.groupby_quantile) + @doc(ExtensionArray._groupby_quantile) def _groupby_quantile( self, *,