diff --git a/pandas/_libs/algos.pyi b/pandas/_libs/algos.pyi
index 60bdb504c545b..0cc9209fbdfc5 100644
--- a/pandas/_libs/algos.pyi
+++ b/pandas/_libs/algos.pyi
@@ -109,6 +109,7 @@ def rank_1d(
     ascending: bool = ...,
     pct: bool = ...,
     na_option=...,
+    mask: npt.NDArray[np.bool_] | None = ...,
 ) -> np.ndarray: ...  # np.ndarray[float64_t, ndim=1]
 def rank_2d(
     in_arr: np.ndarray,  # ndarray[numeric_object_t, ndim=2]
diff --git a/pandas/_libs/algos.pyx b/pandas/_libs/algos.pyx
index 6c28b4f821080..d33eba06988e9 100644
--- a/pandas/_libs/algos.pyx
+++ b/pandas/_libs/algos.pyx
@@ -889,6 +889,7 @@ def rank_1d(
     bint ascending=True,
     bint pct=False,
     na_option="keep",
+    const uint8_t[:] mask=None,
 ):
     """
     Fast NaN-friendly version of ``scipy.stats.rankdata``.
@@ -918,6 +919,8 @@ def rank_1d(
         * keep: leave NA values where they are
        * top: smallest rank if ascending
        * bottom: smallest rank if descending
+    mask : np.ndarray[bool], optional, default None
+        Specify locations to be treated as NA, for e.g. Categorical.
     """
     cdef:
         TiebreakEnumType tiebreak
@@ -927,7 +930,6 @@ def rank_1d(
         float64_t[::1] out
         ndarray[numeric_object_t, ndim=1] masked_vals
         numeric_object_t[:] masked_vals_memview
-        uint8_t[:] mask
         bint keep_na, nans_rank_highest, check_labels, check_mask
         numeric_object_t nan_fill_val
 
@@ -956,6 +958,7 @@ def rank_1d(
         or numeric_object_t is object
         or (numeric_object_t is int64_t and is_datetimelike)
     )
+    check_mask = check_mask or mask is not None
 
     # Copy values into new array in order to fill missing data
     # with mask, without obfuscating location of missing data
@@ -965,7 +968,9 @@ def rank_1d(
     else:
         masked_vals = values.copy()
 
-    if numeric_object_t is object:
+    if mask is not None:
+        pass
+    elif numeric_object_t is object:
         mask = missing.isnaobj(masked_vals)
     elif numeric_object_t is int64_t and is_datetimelike:
         mask = (masked_vals == NPY_NAT).astype(np.uint8)
diff --git a/pandas/_libs/groupby.pyi b/pandas/_libs/groupby.pyi
index 197a8bdc0cd7c..2f0c3980c0c02 100644
--- a/pandas/_libs/groupby.pyi
+++ b/pandas/_libs/groupby.pyi
@@ -128,6 +128,7 @@ def group_rank(
     ascending: bool = ...,
     pct: bool = ...,
     na_option: Literal["keep", "top", "bottom"] = ...,
+    mask: npt.NDArray[np.bool_] | None = ...,
 ) -> None: ...
 def group_max(
     out: np.ndarray,  # groupby_t[:, ::1]
diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx
index 9bc89eef089cd..03f318d08d8cb 100644
--- a/pandas/_libs/groupby.pyx
+++ b/pandas/_libs/groupby.pyx
@@ -1262,6 +1262,7 @@ def group_rank(
     bint ascending=True,
     bint pct=False,
     str na_option="keep",
+    const uint8_t[:, :] mask=None,
 ) -> None:
     """
     Provides the rank of values within each group.
@@ -1294,6 +1295,7 @@ def group_rank(
         * keep: leave NA values where they are
        * top: smallest rank if ascending
        * bottom: smallest rank if descending
+    mask : np.ndarray[bool] or None, default None
 
     Notes
     -----
@@ -1302,10 +1304,16 @@ def group_rank(
     cdef:
         Py_ssize_t i, k, N
         ndarray[float64_t, ndim=1] result
+        const uint8_t[:] sub_mask
 
     N = values.shape[1]
 
     for k in range(N):
+        if mask is None:
+            sub_mask = None
+        else:
+            sub_mask = mask[:, k]
+
         result = rank_1d(
             values=values[:, k],
             labels=labels,
@@ -1313,7 +1321,8 @@ def group_rank(
             ties_method=ties_method,
             ascending=ascending,
             pct=pct,
-            na_option=na_option
+            na_option=na_option,
+            mask=sub_mask,
         )
         for i in range(len(result)):
             # TODO: why can't we do out[:, k] = result?
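For orientation, a minimal sketch (not part of the patch itself) of what the new ``mask`` argument means on the ``rank_1d`` side: positions flagged in an explicit uint8 mask are ranked as missing, just as NaN positions would be. The call below assumes the private ``pandas._libs.algos`` module and the default ``na_option="keep"``; the input data is illustrative only.

import numpy as np
from pandas._libs import algos as libalgos

# Values to rank; the third position is marked missing via `mask` rather
# than by holding NaN (int64 values cannot hold NaN, which is the point).
values = np.array([0, 2, 1, 3], dtype=np.int64)
mask = np.array([0, 0, 1, 0], dtype=np.uint8)

ranks = libalgos.rank_1d(values, mask=mask)
print(ranks)  # expected: [1.0, 2.0, nan, 3.0] with na_option="keep"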
diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py
index 09954bd6be4e4..a769c92e0b542 100644
--- a/pandas/core/groupby/ops.py
+++ b/pandas/core/groupby/ops.py
@@ -46,7 +46,6 @@
     ensure_platform_int,
     is_1d_only_ea_dtype,
     is_bool_dtype,
-    is_categorical_dtype,
     is_complex_dtype,
     is_datetime64_any_dtype,
     is_float_dtype,
@@ -56,12 +55,14 @@
     is_timedelta64_dtype,
     needs_i8_conversion,
 )
+from pandas.core.dtypes.dtypes import CategoricalDtype
 from pandas.core.dtypes.missing import (
     isna,
     maybe_fill,
 )
 
 from pandas.core.arrays import (
+    Categorical,
     DatetimeArray,
     ExtensionArray,
     PeriodArray,
@@ -142,7 +143,15 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
 
     # "group_any" and "group_all" are also support masks, but don't go
     # through WrappedCythonOp
-    _MASKED_CYTHON_FUNCTIONS = {"cummin", "cummax", "min", "max", "last", "first"}
+    _MASKED_CYTHON_FUNCTIONS = {
+        "cummin",
+        "cummax",
+        "min",
+        "max",
+        "last",
+        "first",
+        "rank",
+    }
 
     _cython_arity = {"ohlc": 4}  # OHLC
 
@@ -229,12 +238,17 @@ def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
             # never an invalid op for those dtypes, so return early as fastpath
             return
 
-        if is_categorical_dtype(dtype):
+        if isinstance(dtype, CategoricalDtype):
             # NotImplementedError for methods that can fall back to a
             # non-cython implementation.
             if how in ["add", "prod", "cumsum", "cumprod"]:
                 raise TypeError(f"{dtype} type does not support {how} operations")
-            raise NotImplementedError(f"{dtype} dtype not supported")
+            elif how not in ["rank"]:
+                # only "rank" is implemented in cython
+                raise NotImplementedError(f"{dtype} dtype not supported")
+            elif not dtype.ordered:
+                # TODO: TypeError?
+                raise NotImplementedError(f"{dtype} dtype not supported")
 
         elif is_sparse(dtype):
             # categoricals are only 1d, so we
@@ -332,6 +346,25 @@ def _ea_wrap_cython_operation(
                 **kwargs,
             )
 
+        elif isinstance(values, Categorical) and self.uses_mask():
+            assert self.how == "rank"  # the only one implemented ATM
+            assert values.ordered  # checked earlier
+            mask = values.isna()
+            npvalues = values._ndarray
+
+            res_values = self._cython_op_ndim_compat(
+                npvalues,
+                min_count=min_count,
+                ngroups=ngroups,
+                comp_ids=comp_ids,
+                mask=mask,
+                **kwargs,
+            )
+
+            # If we ever have more than just "rank" here, we'll need to do
+            # `if self.how in self.cast_blocklist` like we do for other dtypes.
+            return res_values
+
         npvalues = self._ea_to_cython_values(values)
 
         res_values = self._cython_op_ndim_compat(
@@ -551,6 +584,9 @@ def _call_cython_op(
         else:
             # TODO: min_count
             if self.uses_mask():
+                if self.how != "rank":
+                    # TODO: should rank take result_mask?
+ kwargs["result_mask"] = result_mask func( out=result, values=values, @@ -558,7 +594,6 @@ def _call_cython_op( ngroups=ngroups, is_datetimelike=is_datetimelike, mask=mask, - result_mask=result_mask, **kwargs, ) else: diff --git a/pandas/tests/groupby/test_rank.py b/pandas/tests/groupby/test_rank.py index 7830c229ece2f..8bbe38d3379ac 100644 --- a/pandas/tests/groupby/test_rank.py +++ b/pandas/tests/groupby/test_rank.py @@ -458,6 +458,8 @@ def test_rank_avg_even_vals(dtype, upper): result = df.groupby("key").rank() exp_df = DataFrame([2.5, 2.5, 2.5, 2.5], columns=["val"]) + if upper: + exp_df = exp_df.astype("Float64") tm.assert_frame_equal(result, exp_df) @@ -663,3 +665,17 @@ def test_non_unique_index(): name="value", ) tm.assert_series_equal(result, expected) + + +def test_rank_categorical(): + cat = pd.Categorical(["a", "a", "b", np.nan, "c", "b"], ordered=True) + cat2 = pd.Categorical([1, 2, 3, np.nan, 4, 5], ordered=True) + + df = DataFrame({"col1": [0, 1, 0, 1, 0, 1], "col2": cat, "col3": cat2}) + + gb = df.groupby("col1") + + res = gb.rank() + + expected = df.astype(object).groupby("col1").rank() + tm.assert_frame_equal(res, expected)