Skip to content

Commit 252a97a

Browse files
authored
PERF: stronger typing in libgroupby (#43103)
1 parent 774c2a9 commit 252a97a

File tree

2 files changed: 27 additions (+27) and 27 deletions (-27).

2 files changed

+27
-27
lines changed

pandas/_libs/groupby.pyx

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -145,7 +145,7 @@ def group_median_float64(ndarray[float64_t, ndim=2] out,
145145
@cython.wraparound(False)
146146
def group_cumprod_float64(float64_t[:, ::1] out,
147147
const float64_t[:, :] values,
148-
const intp_t[:] labels,
148+
const intp_t[::1] labels,
149149
int ngroups,
150150
bint is_datetimelike,
151151
bint skipna=True) -> None:
@@ -202,7 +202,7 @@ def group_cumprod_float64(float64_t[:, ::1] out,
202202
@cython.wraparound(False)
203203
def group_cumsum(numeric[:, ::1] out,
204204
ndarray[numeric, ndim=2] values,
205-
const intp_t[:] labels,
205+
const intp_t[::1] labels,
206206
int ngroups,
207207
is_datetimelike,
208208
bint skipna=True) -> None:
@@ -269,7 +269,7 @@ def group_cumsum(numeric[:, ::1] out,
269269

270270
@cython.boundscheck(False)
271271
@cython.wraparound(False)
272-
def group_shift_indexer(int64_t[::1] out, const intp_t[:] labels,
272+
def group_shift_indexer(int64_t[::1] out, const intp_t[::1] labels,
273273
int ngroups, int periods) -> None:
274274
cdef:
275275
Py_ssize_t N, i, j, ii, lab
@@ -390,7 +390,7 @@ def group_fillna_indexer(ndarray[int64_t] out, ndarray[intp_t] labels,
390390
@cython.wraparound(False)
391391
def group_any_all(int8_t[::1] out,
392392
const int8_t[::1] values,
393-
const intp_t[:] labels,
393+
const intp_t[::1] labels,
394394
const uint8_t[::1] mask,
395395
str val_test,
396396
bint skipna,
@@ -482,7 +482,7 @@ ctypedef fused add_t:
482482
def group_add(add_t[:, ::1] out,
483483
int64_t[::1] counts,
484484
ndarray[add_t, ndim=2] values,
485-
const intp_t[:] labels,
485+
const intp_t[::1] labels,
486486
Py_ssize_t min_count=0) -> None:
487487
"""
488488
Only aggregates on axis=0 using Kahan summation
@@ -565,7 +565,7 @@ def group_add(add_t[:, ::1] out,
565565
def group_prod(floating[:, ::1] out,
566566
int64_t[::1] counts,
567567
ndarray[floating, ndim=2] values,
568-
const intp_t[:] labels,
568+
const intp_t[::1] labels,
569569
Py_ssize_t min_count=0) -> None:
570570
"""
571571
Only aggregates on axis=0
@@ -614,7 +614,7 @@ def group_prod(floating[:, ::1] out,
614614
def group_var(floating[:, ::1] out,
615615
int64_t[::1] counts,
616616
ndarray[floating, ndim=2] values,
617-
const intp_t[:] labels,
617+
const intp_t[::1] labels,
618618
Py_ssize_t min_count=-1,
619619
int64_t ddof=1) -> None:
620620
cdef:
@@ -720,7 +720,7 @@ def group_mean(floating[:, ::1] out,
720720
def group_ohlc(floating[:, ::1] out,
721721
int64_t[::1] counts,
722722
ndarray[floating, ndim=2] values,
723-
const intp_t[:] labels,
723+
const intp_t[::1] labels,
724724
Py_ssize_t min_count=-1) -> None:
725725
"""
726726
Only aggregates on axis=0
@@ -910,7 +910,7 @@ cdef inline bint _treat_as_na(rank_t val, bint is_datetimelike) nogil:
910910
def group_last(rank_t[:, ::1] out,
911911
int64_t[::1] counts,
912912
ndarray[rank_t, ndim=2] values,
913-
const intp_t[:] labels,
913+
const intp_t[::1] labels,
914914
Py_ssize_t min_count=-1) -> None:
915915
"""
916916
Only aggregates on axis=0
@@ -1002,7 +1002,7 @@ def group_last(rank_t[:, ::1] out,
10021002
def group_nth(rank_t[:, ::1] out,
10031003
int64_t[::1] counts,
10041004
ndarray[rank_t, ndim=2] values,
1005-
const intp_t[:] labels,
1005+
const intp_t[::1] labels,
10061006
int64_t min_count=-1,
10071007
int64_t rank=1,
10081008
) -> None:
@@ -1095,7 +1095,7 @@ def group_nth(rank_t[:, ::1] out,
10951095
@cython.wraparound(False)
10961096
def group_rank(float64_t[:, ::1] out,
10971097
ndarray[rank_t, ndim=2] values,
1098-
const intp_t[:] labels,
1098+
const intp_t[::1] labels,
10991099
int ngroups,
11001100
bint is_datetimelike, str ties_method="average",
11011101
bint ascending=True, bint pct=False, str na_option="keep") -> None:
@@ -1173,7 +1173,7 @@ ctypedef fused groupby_t:
11731173
cdef group_min_max(groupby_t[:, ::1] out,
11741174
int64_t[::1] counts,
11751175
ndarray[groupby_t, ndim=2] values,
1176-
const intp_t[:] labels,
1176+
const intp_t[::1] labels,
11771177
Py_ssize_t min_count=-1,
11781178
bint is_datetimelike=False,
11791179
bint compute_max=True):
@@ -1274,7 +1274,7 @@ cdef group_min_max(groupby_t[:, ::1] out,
12741274
def group_max(groupby_t[:, ::1] out,
12751275
int64_t[::1] counts,
12761276
ndarray[groupby_t, ndim=2] values,
1277-
const intp_t[:] labels,
1277+
const intp_t[::1] labels,
12781278
Py_ssize_t min_count=-1,
12791279
bint is_datetimelike=False) -> None:
12801280
"""See group_min_max.__doc__"""
@@ -1294,7 +1294,7 @@ def group_max(groupby_t[:, ::1] out,
12941294
def group_min(groupby_t[:, ::1] out,
12951295
int64_t[::1] counts,
12961296
ndarray[groupby_t, ndim=2] values,
1297-
const intp_t[:] labels,
1297+
const intp_t[::1] labels,
12981298
Py_ssize_t min_count=-1,
12991299
bint is_datetimelike=False) -> None:
13001300
"""See group_min_max.__doc__"""
@@ -1314,7 +1314,7 @@ def group_min(groupby_t[:, ::1] out,
13141314
cdef group_cummin_max(groupby_t[:, ::1] out,
13151315
ndarray[groupby_t, ndim=2] values,
13161316
uint8_t[:, ::1] mask,
1317-
const intp_t[:] labels,
1317+
const intp_t[::1] labels,
13181318
int ngroups,
13191319
bint is_datetimelike,
13201320
bint skipna,
@@ -1368,7 +1368,7 @@ cdef group_cummin_max(groupby_t[:, ::1] out,
13681368
@cython.wraparound(False)
13691369
cdef cummin_max(groupby_t[:, ::1] out,
13701370
ndarray[groupby_t, ndim=2] values,
1371-
const intp_t[:] labels,
1371+
const intp_t[::1] labels,
13721372
groupby_t[:, ::1] accum,
13731373
bint skipna,
13741374
bint is_datetimelike,
@@ -1428,7 +1428,7 @@ cdef cummin_max(groupby_t[:, ::1] out,
14281428
cdef masked_cummin_max(groupby_t[:, ::1] out,
14291429
ndarray[groupby_t, ndim=2] values,
14301430
uint8_t[:, ::1] mask,
1431-
const intp_t[:] labels,
1431+
const intp_t[::1] labels,
14321432
groupby_t[:, ::1] accum,
14331433
bint skipna,
14341434
bint compute_max):
@@ -1471,7 +1471,7 @@ cdef masked_cummin_max(groupby_t[:, ::1] out,
14711471
@cython.wraparound(False)
14721472
def group_cummin(groupby_t[:, ::1] out,
14731473
ndarray[groupby_t, ndim=2] values,
1474-
const intp_t[:] labels,
1474+
const intp_t[::1] labels,
14751475
int ngroups,
14761476
bint is_datetimelike,
14771477
uint8_t[:, ::1] mask=None,
@@ -1493,7 +1493,7 @@ def group_cummin(groupby_t[:, ::1] out,
14931493
@cython.wraparound(False)
14941494
def group_cummax(groupby_t[:, ::1] out,
14951495
ndarray[groupby_t, ndim=2] values,
1496-
const intp_t[:] labels,
1496+
const intp_t[::1] labels,
14971497
int ngroups,
14981498
bint is_datetimelike,
14991499
uint8_t[:, ::1] mask=None,

pandas/core/groupby/groupby.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1557,7 +1557,7 @@ def result_to_bool(
15571557
return result.astype(inference, copy=False)
15581558

15591559
return self._get_cythonized_result(
1560-
"group_any_all",
1560+
libgroupby.group_any_all,
15611561
aggregate=True,
15621562
numeric_only=False,
15631563
cython_dtype=np.dtype(np.int8),
@@ -1733,7 +1733,7 @@ def std(self, ddof: int = 1):
17331733
Standard deviation of values within each group.
17341734
"""
17351735
return self._get_cythonized_result(
1736-
"group_var",
1736+
libgroupby.group_var,
17371737
aggregate=True,
17381738
needs_counts=True,
17391739
needs_values=True,
@@ -2149,7 +2149,7 @@ def _fill(self, direction: Literal["ffill", "bfill"], limit=None):
21492149
limit = -1
21502150

21512151
return self._get_cythonized_result(
2152-
"group_fillna_indexer",
2152+
libgroupby.group_fillna_indexer,
21532153
numeric_only=False,
21542154
needs_mask=True,
21552155
cython_dtype=np.dtype(np.int64),
@@ -2465,7 +2465,7 @@ def post_processor(vals: np.ndarray, inference: type | None) -> np.ndarray:
24652465

24662466
if is_scalar(q):
24672467
return self._get_cythonized_result(
2468-
"group_quantile",
2468+
libgroupby.group_quantile,
24692469
aggregate=True,
24702470
numeric_only=False,
24712471
needs_values=True,
@@ -2479,7 +2479,7 @@ def post_processor(vals: np.ndarray, inference: type | None) -> np.ndarray:
24792479
else:
24802480
results = [
24812481
self._get_cythonized_result(
2482-
"group_quantile",
2482+
libgroupby.group_quantile,
24832483
aggregate=True,
24842484
needs_values=True,
24852485
needs_mask=True,
@@ -2817,7 +2817,7 @@ def cummax(self, axis=0, **kwargs):
28172817
@final
28182818
def _get_cythonized_result(
28192819
self,
2820-
how: str,
2820+
base_func: Callable,
28212821
cython_dtype: np.dtype,
28222822
aggregate: bool = False,
28232823
numeric_only: bool | lib.NoDefault = lib.no_default,
@@ -2839,7 +2839,7 @@ def _get_cythonized_result(
28392839
28402840
Parameters
28412841
----------
2842-
how : str, Cythonized function name to be called
2842+
base_func : callable, Cythonized function to be called
28432843
cython_dtype : np.dtype
28442844
Type of the array that will be modified by the Cython call.
28452845
aggregate : bool, default False
@@ -2910,7 +2910,7 @@ def _get_cythonized_result(
29102910
ids, _, ngroups = grouper.group_info
29112911
output: dict[base.OutputKey, ArrayLike] = {}
29122912

2913-
base_func = getattr(libgroupby, how)
2913+
how = base_func.__name__
29142914
base_func = partial(base_func, labels=ids)
29152915
if needs_ngroups:
29162916
base_func = partial(base_func, ngroups=ngroups)

0 commit comments

Comments
 (0)