-
-
Notifications
You must be signed in to change notification settings - Fork 18.6k
CLN: group_quantile #43489
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CLN: group_quantile #43489
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -800,9 +800,12 @@ def group_quantile(ndarray[float64_t, ndim=2] out, | |
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz, k, nqs | ||
Py_ssize_t grp_start=0, idx=0 | ||
intp_t lab | ||
uint8_t interp | ||
InterpolationEnumType interp | ||
float64_t q_val, q_idx, frac, val, next_val | ||
ndarray[int64_t] counts, non_na_counts, sort_arr | ||
int64_t[::1] counts, non_na_counts | ||
intp_t[::1] sort_arr | ||
ndarray[intp_t] labels_for_lexsort | ||
intp_t na_label_for_sorting = 0 | ||
|
||
assert values.shape[0] == N | ||
|
||
|
@@ -825,27 +828,28 @@ def group_quantile(ndarray[float64_t, ndim=2] out, | |
ngroups = len(out) | ||
counts = np.zeros(ngroups, dtype=np.int64) | ||
non_na_counts = np.zeros(ngroups, dtype=np.int64) | ||
labels_for_lexsort = labels.copy() | ||
|
||
# Put '-1' (NaN) labels as the last group so it does not interfere | ||
# with the calculations. | ||
if N > 0: | ||
na_label_for_sorting = labels.max() + 1 | ||
|
||
# First figure out the size of every group | ||
with nogil: | ||
for i in range(N): | ||
lab = labels[i] | ||
if lab == -1: # NA group label | ||
if lab == -1: | ||
labels_for_lexsort[i] = na_label_for_sorting | ||
continue | ||
|
||
counts[lab] += 1 | ||
if not mask[i]: | ||
non_na_counts[lab] += 1 | ||
|
||
# Get an index of values sorted by labels and then values | ||
if labels.any(): | ||
# Put '-1' (NaN) labels as the last group so it does not interfere | ||
# with the calculations. | ||
labels_for_lexsort = np.where(labels == -1, labels.max() + 1, labels) | ||
else: | ||
labels_for_lexsort = labels | ||
# Get an index of values sorted by values and then labels | ||
order = (values, labels_for_lexsort) | ||
sort_arr = np.lexsort(order).astype(np.int64, copy=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. FWIW the big perf gain comes from moving this call out of Cython and not repeating it for each column. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ok, makes sense. If you're planning on doing that, I'll just close this for now and collect the other cleanups elsewhere. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I was actually suggesting you do that in this PR. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, sounds good :) I'll take a look at that. |
||
sort_arr = np.lexsort(order).astype(np.intp, copy=False) | ||
|
||
with nogil: | ||
for i in range(ngroups): | ||
|
@@ -1420,7 +1424,7 @@ cdef cummin_max(groupby_t[:, ::1] out, | |
cdef: | ||
Py_ssize_t i, j, N, K | ||
groupby_t val, mval, na_val | ||
uint8_t[:, ::1] seen_na | ||
uint8_t[:, ::1] seen_na = None | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added this because previously |
||
intp_t lab | ||
bint na_possible | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This `labels.any()` conditional is strange - it special-cases the all-0 label case (which seems rare and specific) when we really care about -1 labels. Regardless, since we already check for -1 in the loop above, there is no need to check with `any` and `where` here.