REF: de-duplicate groupby_helper code #28934

Merged: 7 commits, Oct 16, 2019

3 changes: 2 additions & 1 deletion pandas/_libs/groupby.pyx
@@ -372,7 +372,8 @@ def group_any_all(uint8_t[:] out,
const uint8_t[:] mask,
object val_test,
bint skipna):
"""Aggregated boolean values to show truthfulness of group elements
"""
Aggregated boolean values to show truthfulness of group elements.

Parameters
----------
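
The hunk above only reflows the docstring, but for context, here is a minimal, hedged usage sketch of the GroupBy.any/GroupBy.all behaviour that group_any_all backs (example data is illustrative, not from the PR):

import pandas as pd

df = pd.DataFrame({"key": ["a", "a", "b"], "flag": [True, False, True]})
print(df.groupby("key")["flag"].any())  # a: True, b: True
print(df.groupby("key")["flag"].all())  # a: False, b: True
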
139 changes: 52 additions & 87 deletions pandas/_libs/groupby_helper.pxi.in
@@ -20,6 +20,18 @@ ctypedef fused rank_t:
object


cdef inline bint _treat_as_na(rank_t val, bint is_datetimelike) nogil:
Member: Is this only applicable in groupby or should it go in util?

Member Author: Only here. As I mention in comments below, I'm not wild about the fact that is_datetimelike is effectively hard-coded to True in several places. Will try to see if that causes problems in follow-up(s).

if rank_t is object:
Member: Somewhat counter-intuitive that this works; out of curiosity, what was the complaint?

Member Author: AFAICT the val != val check requires calling val.__ne__, which requires the GIL.

Contributor: Shouldn't need the GIL if these are C-level objects (e.g. ints or floats).

Member Author: Right, we just need to exclude the object case from getting to the val == val step; that's all this check is doing.

# Should never be used, but we need to avoid the `val != val` below
# or else cython will raise about gil acquisition.
Contributor: This is pretty odd; why don't you return an enum here (true, false, raise) and handle it in the code appropriately?

Member Author: The vicissitudes of Cython fused types.

raise NotImplementedError

elif rank_t is int64_t:
return is_datetimelike and val == NPY_NAT
else:
return val != val
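
To make the NA convention this helper centralizes concrete, here is a minimal pure-Python sketch of the non-object branches (the wrapper and its names are illustrative only, not part of the diff). The last call also shows why the hard-coded is_datetimelike=True noted in the review matters for plain int64 data.

import numpy as np

NPY_NAT = np.iinfo(np.int64).min  # iNaT, the int64 sentinel used for datetime64/timedelta64 NA

def treat_as_na(val, is_datetimelike):
    # Mirror of the Cython helper's int64 and float branches: integers are NA
    # only when the data is datetime-like and the value equals the NaT sentinel;
    # floats are NA when they fail self-equality (i.e. NaN).
    if isinstance(val, (int, np.integer)):
        return is_datetimelike and val == NPY_NAT
    return val != val

print(treat_as_na(np.nan, False))    # True: NaN != NaN
print(treat_as_na(NPY_NAT, True))    # True: iNaT in datetime-like data
print(treat_as_na(NPY_NAT, False))   # False: the same integer in plain int64 data is a real value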


@cython.wraparound(False)
@cython.boundscheck(False)
def group_last(rank_t[:, :] out,
@@ -61,24 +73,16 @@ def group_last(rank_t[:, :] out,
for j in range(K):
val = values[i, j]

# not nan
if rank_t is int64_t:
# need a special notna check
if val != NPY_NAT:
nobs[lab, j] += 1
resx[lab, j] = val
else:
if val == val:
nobs[lab, j] += 1
resx[lab, j] = val
if val == val:
Member: Is this totally equivalent to what's in place? Looks like we are losing the NPY_NAT check?

Member Author: This is in the rank_t is object case, where we can't use _treat_as_na (see the comment just below).

# NB: use _treat_as_na here once
# conditional-nogil is available.
nobs[lab, j] += 1
resx[lab, j] = val

for i in range(ncounts):
for j in range(K):
if nobs[i, j] == 0:
if rank_t is int64_t:
out[i, j] = NPY_NAT
else:
out[i, j] = NAN
out[i, j] = NAN
else:
out[i, j] = resx[i, j]
else:
@@ -92,16 +96,10 @@
for j in range(K):
val = values[i, j]

# not nan
if rank_t is int64_t:
# need a special notna check
if val != NPY_NAT:
nobs[lab, j] += 1
resx[lab, j] = val
else:
if val == val:
nobs[lab, j] += 1
resx[lab, j] = val
if not _treat_as_na(val, True):
# TODO: Sure we always want is_datetimelike=True?
nobs[lab, j] += 1
resx[lab, j] = val

for i in range(ncounts):
for j in range(K):
@@ -113,6 +111,7 @@ def group_last(rank_t[:, :] out,
break
else:
out[i, j] = NAN

else:
out[i, j] = resx[i, j]

@@ -121,7 +120,6 @@
# block.
raise RuntimeError("empty group with uint64_t")


group_last_float64 = group_last["float64_t"]
group_last_float32 = group_last["float32_t"]
group_last_int64 = group_last["int64_t"]
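
For reference (a usage sketch, not part of the diff): GroupBy.last returns the last non-NA value per group, which is exactly the nobs/resx bookkeeping the loops above implement.

import numpy as np
import pandas as pd

df = pd.DataFrame({"key": ["a", "a", "b"], "val": [1.0, np.nan, np.nan]})
print(df.groupby("key")["val"].last())
# a    1.0   <- trailing NaN in group "a" is skipped (nobs not incremented for it)
# b    NaN   <- a group with no observed values falls back to NaN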
@@ -169,8 +167,9 @@ def group_nth(rank_t[:, :] out,
for j in range(K):
val = values[i, j]

# not nan
if val == val:
# NB: use _treat_as_na here once
# conditional-nogil is available.
nobs[lab, j] += 1
if nobs[lab, j] == rank:
resx[lab, j] = val
@@ -193,18 +192,11 @@
for j in range(K):
val = values[i, j]

# not nan
if rank_t is int64_t:
# need a special notna check
if val != NPY_NAT:
nobs[lab, j] += 1
if nobs[lab, j] == rank:
resx[lab, j] = val
else:
if val == val:
nobs[lab, j] += 1
if nobs[lab, j] == rank:
resx[lab, j] = val
if not _treat_as_na(val, True):
# TODO: Sure we always want is_datetimelike=True?
nobs[lab, j] += 1
if nobs[lab, j] == rank:
resx[lab, j] = val

for i in range(ncounts):
for j in range(K):
@@ -487,17 +479,11 @@ def group_max(groupby_t[:, :] out,
for j in range(K):
val = values[i, j]

# not nan
if groupby_t is int64_t:
if val != nan_val:
nobs[lab, j] += 1
if val > maxx[lab, j]:
maxx[lab, j] = val
else:
if val == val:
nobs[lab, j] += 1
if val > maxx[lab, j]:
maxx[lab, j] = val
if not _treat_as_na(val, True):
# TODO: Sure we always want is_datetimelike=True?
nobs[lab, j] += 1
if val > maxx[lab, j]:
maxx[lab, j] = val

for i in range(ncounts):
for j in range(K):
@@ -563,17 +549,11 @@ def group_min(groupby_t[:, :] out,
for j in range(K):
val = values[i, j]

# not nan
if groupby_t is int64_t:
if val != nan_val:
nobs[lab, j] += 1
if val < minx[lab, j]:
minx[lab, j] = val
else:
if val == val:
nobs[lab, j] += 1
if val < minx[lab, j]:
minx[lab, j] = val
if not _treat_as_na(val, True):
# TODO: Sure we always want is_datetimelike=True?
nobs[lab, j] += 1
if val < minx[lab, j]:
minx[lab, j] = val

for i in range(ncounts):
for j in range(K):
@@ -643,21 +623,13 @@ def group_cummin(groupby_t[:, :] out,
for j in range(K):
val = values[i, j]

# val = nan
if groupby_t is int64_t:
if is_datetimelike and val == NPY_NAT:
out[i, j] = NPY_NAT
else:
mval = accum[lab, j]
if val < mval:
accum[lab, j] = mval = val
out[i, j] = mval
if _treat_as_na(val, is_datetimelike):
out[i, j] = val
else:
if val == val:
mval = accum[lab, j]
if val < mval:
accum[lab, j] = mval = val
out[i, j] = mval
mval = accum[lab, j]
if val < mval:
accum[lab, j] = mval = val
out[i, j] = mval


@cython.boundscheck(False)
@@ -712,17 +684,10 @@ def group_cummax(groupby_t[:, :] out,
for j in range(K):
val = values[i, j]

if groupby_t is int64_t:
if is_datetimelike and val == NPY_NAT:
out[i, j] = NPY_NAT
else:
mval = accum[lab, j]
if val > mval:
accum[lab, j] = mval = val
out[i, j] = mval
if _treat_as_na(val, is_datetimelike):
out[i, j] = val
else:
if val == val:
mval = accum[lab, j]
if val > mval:
accum[lab, j] = mval = val
out[i, j] = mval
mval = accum[lab, j]
if val > mval:
accum[lab, j] = mval = val
out[i, j] = mval
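
Closing illustration (a usage sketch of the behaviour the cummin/cummax kernels implement, not text from the PR): when _treat_as_na is True, the NA value is written through at its own row, but the running accumulator is left untouched, so later rows continue from the previous extreme.

import numpy as np
import pandas as pd

df = pd.DataFrame({"key": ["a", "a", "a"], "val": [3.0, np.nan, 1.0]})
print(df.groupby("key")["val"].cummin())
# 0    3.0
# 1    NaN   <- NA propagates at this row only; accum still holds 3.0
# 2    1.0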