Skip to content

CLN: group_quantile #43489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -800,9 +800,12 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz, k, nqs
Py_ssize_t grp_start=0, idx=0
intp_t lab
uint8_t interp
InterpolationEnumType interp
float64_t q_val, q_idx, frac, val, next_val
ndarray[int64_t] counts, non_na_counts, sort_arr
int64_t[::1] counts, non_na_counts
intp_t[::1] sort_arr
ndarray[intp_t] labels_for_lexsort
intp_t na_label_for_sorting = 0

assert values.shape[0] == N

Expand All @@ -825,27 +828,28 @@ def group_quantile(ndarray[float64_t, ndim=2] out,
ngroups = len(out)
counts = np.zeros(ngroups, dtype=np.int64)
non_na_counts = np.zeros(ngroups, dtype=np.int64)
labels_for_lexsort = labels.copy()

# Put '-1' (NaN) labels as the last group so it does not interfere
# with the calculations.
if N > 0:
na_label_for_sorting = labels.max() + 1

# First figure out the size of every group
with nogil:
for i in range(N):
lab = labels[i]
if lab == -1: # NA group label
if lab == -1:
labels_for_lexsort[i] = na_label_for_sorting
continue

counts[lab] += 1
if not mask[i]:
non_na_counts[lab] += 1

# Get an index of values sorted by labels and then values
if labels.any():
# Put '-1' (NaN) labels as the last group so it does not interfere
# with the calculations.
labels_for_lexsort = np.where(labels == -1, labels.max() + 1, labels)
else:
labels_for_lexsort = labels
Copy link
Member Author

@mzeitlin11 mzeitlin11 Sep 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This labels.any() conditional is strange - special casing the all 0 label case (which seems rare and specific) when we really care about -1 labels. Regardless, since we already check for -1 in loop above, no need to check with any and where here

# Get an index of values sorted by values and then labels
order = (values, labels_for_lexsort)
sort_arr = np.lexsort(order).astype(np.int64, copy=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW the big perf gain comes from moving this call out of cython and not repeating it for each column

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, makes sense. If you're planning on doing that, I'll just close this for now and collect the other cleanups elsewhere

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was actually suggesting you do that in this PR

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sounds good :) I'll take a look at that

sort_arr = np.lexsort(order).astype(np.intp, copy=False)

with nogil:
for i in range(ngroups):
Expand Down Expand Up @@ -1420,7 +1424,7 @@ cdef cummin_max(groupby_t[:, ::1] out,
cdef:
Py_ssize_t i, j, N, K
groupby_t val, mval, na_val
uint8_t[:, ::1] seen_na
uint8_t[:, ::1] seen_na = None
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this because previously seen_na was treated as potentially uninitialized, leading to an emission of extra code whenever it was accessed, for example
if (unlikely(!__pyx_v_seen_na.memview)) { __Pyx_RaiseUnboundMemoryviewSliceNogil("seen_na"); __PYX_ERR(0, 1462, __pyx_L7_error)

intp_t lab
bint na_possible

Expand Down