Skip to content

PERF: stronger typing in libgroupby #43103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def group_median_float64(ndarray[float64_t, ndim=2] out,
@cython.wraparound(False)
def group_cumprod_float64(float64_t[:, ::1] out,
const float64_t[:, :] values,
const intp_t[:] labels,
const intp_t[::1] labels,
int ngroups,
bint is_datetimelike,
bint skipna=True) -> None:
Expand Down Expand Up @@ -202,7 +202,7 @@ def group_cumprod_float64(float64_t[:, ::1] out,
@cython.wraparound(False)
def group_cumsum(numeric[:, ::1] out,
ndarray[numeric, ndim=2] values,
const intp_t[:] labels,
const intp_t[::1] labels,
int ngroups,
is_datetimelike,
bint skipna=True) -> None:
Expand Down Expand Up @@ -269,7 +269,7 @@ def group_cumsum(numeric[:, ::1] out,

@cython.boundscheck(False)
@cython.wraparound(False)
def group_shift_indexer(int64_t[::1] out, const intp_t[:] labels,
def group_shift_indexer(int64_t[::1] out, const intp_t[::1] labels,
int ngroups, int periods) -> None:
cdef:
Py_ssize_t N, i, j, ii, lab
Expand Down Expand Up @@ -390,7 +390,7 @@ def group_fillna_indexer(ndarray[int64_t] out, ndarray[intp_t] labels,
@cython.wraparound(False)
def group_any_all(int8_t[::1] out,
const int8_t[::1] values,
const intp_t[:] labels,
const intp_t[::1] labels,
const uint8_t[::1] mask,
str val_test,
bint skipna,
Expand Down Expand Up @@ -482,7 +482,7 @@ ctypedef fused add_t:
def group_add(add_t[:, ::1] out,
int64_t[::1] counts,
ndarray[add_t, ndim=2] values,
const intp_t[:] labels,
const intp_t[::1] labels,
Py_ssize_t min_count=0) -> None:
"""
Only aggregates on axis=0 using Kahan summation
Expand Down Expand Up @@ -565,7 +565,7 @@ def group_add(add_t[:, ::1] out,
def group_prod(floating[:, ::1] out,
int64_t[::1] counts,
ndarray[floating, ndim=2] values,
const intp_t[:] labels,
const intp_t[::1] labels,
Py_ssize_t min_count=0) -> None:
"""
Only aggregates on axis=0
Expand Down Expand Up @@ -614,7 +614,7 @@ def group_prod(floating[:, ::1] out,
def group_var(floating[:, ::1] out,
int64_t[::1] counts,
ndarray[floating, ndim=2] values,
const intp_t[:] labels,
const intp_t[::1] labels,
Py_ssize_t min_count=-1,
int64_t ddof=1) -> None:
cdef:
Expand Down Expand Up @@ -720,7 +720,7 @@ def group_mean(floating[:, ::1] out,
def group_ohlc(floating[:, ::1] out,
int64_t[::1] counts,
ndarray[floating, ndim=2] values,
const intp_t[:] labels,
const intp_t[::1] labels,
Py_ssize_t min_count=-1) -> None:
"""
Only aggregates on axis=0
Expand Down Expand Up @@ -910,7 +910,7 @@ cdef inline bint _treat_as_na(rank_t val, bint is_datetimelike) nogil:
def group_last(rank_t[:, ::1] out,
int64_t[::1] counts,
ndarray[rank_t, ndim=2] values,
const intp_t[:] labels,
const intp_t[::1] labels,
Py_ssize_t min_count=-1) -> None:
"""
Only aggregates on axis=0
Expand Down Expand Up @@ -1002,7 +1002,7 @@ def group_last(rank_t[:, ::1] out,
def group_nth(rank_t[:, ::1] out,
int64_t[::1] counts,
ndarray[rank_t, ndim=2] values,
const intp_t[:] labels,
const intp_t[::1] labels,
int64_t min_count=-1,
int64_t rank=1,
) -> None:
Expand Down Expand Up @@ -1095,7 +1095,7 @@ def group_nth(rank_t[:, ::1] out,
@cython.wraparound(False)
def group_rank(float64_t[:, ::1] out,
ndarray[rank_t, ndim=2] values,
const intp_t[:] labels,
const intp_t[::1] labels,
int ngroups,
bint is_datetimelike, str ties_method="average",
bint ascending=True, bint pct=False, str na_option="keep") -> None:
Expand Down Expand Up @@ -1173,7 +1173,7 @@ ctypedef fused groupby_t:
cdef group_min_max(groupby_t[:, ::1] out,
int64_t[::1] counts,
ndarray[groupby_t, ndim=2] values,
const intp_t[:] labels,
const intp_t[::1] labels,
Py_ssize_t min_count=-1,
bint is_datetimelike=False,
bint compute_max=True):
Expand Down Expand Up @@ -1274,7 +1274,7 @@ cdef group_min_max(groupby_t[:, ::1] out,
def group_max(groupby_t[:, ::1] out,
int64_t[::1] counts,
ndarray[groupby_t, ndim=2] values,
const intp_t[:] labels,
const intp_t[::1] labels,
Py_ssize_t min_count=-1,
bint is_datetimelike=False) -> None:
"""See group_min_max.__doc__"""
Expand All @@ -1294,7 +1294,7 @@ def group_max(groupby_t[:, ::1] out,
def group_min(groupby_t[:, ::1] out,
int64_t[::1] counts,
ndarray[groupby_t, ndim=2] values,
const intp_t[:] labels,
const intp_t[::1] labels,
Py_ssize_t min_count=-1,
bint is_datetimelike=False) -> None:
"""See group_min_max.__doc__"""
Expand All @@ -1314,7 +1314,7 @@ def group_min(groupby_t[:, ::1] out,
cdef group_cummin_max(groupby_t[:, ::1] out,
ndarray[groupby_t, ndim=2] values,
uint8_t[:, ::1] mask,
const intp_t[:] labels,
const intp_t[::1] labels,
int ngroups,
bint is_datetimelike,
bint skipna,
Expand Down Expand Up @@ -1368,7 +1368,7 @@ cdef group_cummin_max(groupby_t[:, ::1] out,
@cython.wraparound(False)
cdef cummin_max(groupby_t[:, ::1] out,
ndarray[groupby_t, ndim=2] values,
const intp_t[:] labels,
const intp_t[::1] labels,
groupby_t[:, ::1] accum,
bint skipna,
bint is_datetimelike,
Expand Down Expand Up @@ -1428,7 +1428,7 @@ cdef cummin_max(groupby_t[:, ::1] out,
cdef masked_cummin_max(groupby_t[:, ::1] out,
ndarray[groupby_t, ndim=2] values,
uint8_t[:, ::1] mask,
const intp_t[:] labels,
const intp_t[::1] labels,
groupby_t[:, ::1] accum,
bint skipna,
bint compute_max):
Expand Down Expand Up @@ -1471,7 +1471,7 @@ cdef masked_cummin_max(groupby_t[:, ::1] out,
@cython.wraparound(False)
def group_cummin(groupby_t[:, ::1] out,
ndarray[groupby_t, ndim=2] values,
const intp_t[:] labels,
const intp_t[::1] labels,
int ngroups,
bint is_datetimelike,
uint8_t[:, ::1] mask=None,
Expand All @@ -1493,7 +1493,7 @@ def group_cummin(groupby_t[:, ::1] out,
@cython.wraparound(False)
def group_cummax(groupby_t[:, ::1] out,
ndarray[groupby_t, ndim=2] values,
const intp_t[:] labels,
const intp_t[::1] labels,
int ngroups,
bint is_datetimelike,
uint8_t[:, ::1] mask=None,
Expand Down
16 changes: 8 additions & 8 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1557,7 +1557,7 @@ def result_to_bool(
return result.astype(inference, copy=False)

return self._get_cythonized_result(
"group_any_all",
libgroupby.group_any_all,
aggregate=True,
numeric_only=False,
cython_dtype=np.dtype(np.int8),
Expand Down Expand Up @@ -1733,7 +1733,7 @@ def std(self, ddof: int = 1):
Standard deviation of values within each group.
"""
return self._get_cythonized_result(
"group_var",
libgroupby.group_var,
aggregate=True,
needs_counts=True,
needs_values=True,
Expand Down Expand Up @@ -2149,7 +2149,7 @@ def _fill(self, direction: Literal["ffill", "bfill"], limit=None):
limit = -1

return self._get_cythonized_result(
"group_fillna_indexer",
libgroupby.group_fillna_indexer,
numeric_only=False,
needs_mask=True,
cython_dtype=np.dtype(np.int64),
Expand Down Expand Up @@ -2465,7 +2465,7 @@ def post_processor(vals: np.ndarray, inference: type | None) -> np.ndarray:

if is_scalar(q):
return self._get_cythonized_result(
"group_quantile",
libgroupby.group_quantile,
aggregate=True,
numeric_only=False,
needs_values=True,
Expand All @@ -2479,7 +2479,7 @@ def post_processor(vals: np.ndarray, inference: type | None) -> np.ndarray:
else:
results = [
self._get_cythonized_result(
"group_quantile",
libgroupby.group_quantile,
aggregate=True,
needs_values=True,
needs_mask=True,
Expand Down Expand Up @@ -2817,7 +2817,7 @@ def cummax(self, axis=0, **kwargs):
@final
def _get_cythonized_result(
self,
how: str,
base_func: Callable,
cython_dtype: np.dtype,
aggregate: bool = False,
numeric_only: bool | lib.NoDefault = lib.no_default,
Expand All @@ -2839,7 +2839,7 @@ def _get_cythonized_result(

Parameters
----------
how : str, Cythonized function name to be called
base_func : callable, Cythonized function to be called
cython_dtype : np.dtype
Type of the array that will be modified by the Cython call.
aggregate : bool, default False
Expand Down Expand Up @@ -2910,7 +2910,7 @@ def _get_cythonized_result(
ids, _, ngroups = grouper.group_info
output: dict[base.OutputKey, ArrayLike] = {}

base_func = getattr(libgroupby, how)
how = base_func.__name__
base_func = partial(base_func, labels=ids)
if needs_ngroups:
base_func = partial(base_func, ngroups=ngroups)
Expand Down