diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx
index 3069bbbf34bb7..c9994812462b1 100644
--- a/pandas/_libs/groupby.pyx
+++ b/pandas/_libs/groupby.pyx
@@ -372,7 +372,8 @@ def group_any_all(uint8_t[:] out,
                   const uint8_t[:] mask,
                   object val_test,
                   bint skipna):
-    """Aggregated boolean values to show truthfulness of group elements
+    """
+    Aggregated boolean values to show truthfulness of group elements.

     Parameters
     ----------
diff --git a/pandas/_libs/groupby_helper.pxi.in b/pandas/_libs/groupby_helper.pxi.in
index f052feea0bbf3..c837c6c5c6519 100644
--- a/pandas/_libs/groupby_helper.pxi.in
+++ b/pandas/_libs/groupby_helper.pxi.in
@@ -20,6 +20,18 @@ ctypedef fused rank_t:
     object


+cdef inline bint _treat_as_na(rank_t val, bint is_datetimelike) nogil:
+    if rank_t is object:
+        # Should never be used, but we need to avoid the `val != val` below
+        # or else cython will raise about gil acquisition.
+        raise NotImplementedError
+
+    elif rank_t is int64_t:
+        return is_datetimelike and val == NPY_NAT
+    else:
+        return val != val
+
+
 @cython.wraparound(False)
 @cython.boundscheck(False)
 def group_last(rank_t[:, :] out,
@@ -61,24 +73,16 @@ def group_last(rank_t[:, :] out,
             for j in range(K):
                 val = values[i, j]

-                # not nan
-                if rank_t is int64_t:
-                    # need a special notna check
-                    if val != NPY_NAT:
-                        nobs[lab, j] += 1
-                        resx[lab, j] = val
-                else:
-                    if val == val:
-                        nobs[lab, j] += 1
-                        resx[lab, j] = val
+                if val == val:
+                    # NB: use _treat_as_na here once
+                    # conditional-nogil is available.
+                    nobs[lab, j] += 1
+                    resx[lab, j] = val

         for i in range(ncounts):
             for j in range(K):
                 if nobs[i, j] == 0:
-                    if rank_t is int64_t:
-                        out[i, j] = NPY_NAT
-                    else:
-                        out[i, j] = NAN
+                    out[i, j] = NAN
                 else:
                     out[i, j] = resx[i, j]
     else:
@@ -92,16 +96,10 @@ def group_last(rank_t[:, :] out,
                 for j in range(K):
                     val = values[i, j]

-                    # not nan
-                    if rank_t is int64_t:
-                        # need a special notna check
-                        if val != NPY_NAT:
-                            nobs[lab, j] += 1
-                            resx[lab, j] = val
-                    else:
-                        if val == val:
-                            nobs[lab, j] += 1
-                            resx[lab, j] = val
+                    if not _treat_as_na(val, True):
+                        # TODO: Sure we always want is_datetimelike=True?
+                        nobs[lab, j] += 1
+                        resx[lab, j] = val

             for i in range(ncounts):
                 for j in range(K):
@@ -113,6 +111,7 @@ def group_last(rank_t[:, :] out,
                             break
                         else:
                             out[i, j] = NAN
+
                     else:
                         out[i, j] = resx[i, j]

@@ -121,7 +120,6 @@ def group_last(rank_t[:, :] out,
             # block.
             raise RuntimeError("empty group with uint64_t")

-
 group_last_float64 = group_last["float64_t"]
 group_last_float32 = group_last["float32_t"]
 group_last_int64 = group_last["int64_t"]
@@ -169,8 +167,9 @@ def group_nth(rank_t[:, :] out,
             for j in range(K):
                 val = values[i, j]

-                # not nan
                 if val == val:
+                    # NB: use _treat_as_na here once
+                    # conditional-nogil is available.
                     nobs[lab, j] += 1
                     if nobs[lab, j] == rank:
                         resx[lab, j] = val
@@ -193,18 +192,11 @@ def group_nth(rank_t[:, :] out,
                 for j in range(K):
                     val = values[i, j]

-                    # not nan
-                    if rank_t is int64_t:
-                        # need a special notna check
-                        if val != NPY_NAT:
-                            nobs[lab, j] += 1
-                            if nobs[lab, j] == rank:
-                                resx[lab, j] = val
-                    else:
-                        if val == val:
-                            nobs[lab, j] += 1
-                            if nobs[lab, j] == rank:
-                                resx[lab, j] = val
+                    if not _treat_as_na(val, True):
+                        # TODO: Sure we always want is_datetimelike=True?
+                        nobs[lab, j] += 1
+                        if nobs[lab, j] == rank:
+                            resx[lab, j] = val

             for i in range(ncounts):
                 for j in range(K):
@@ -487,17 +479,11 @@ def group_max(groupby_t[:, :] out,
             for j in range(K):
                 val = values[i, j]

-                # not nan
-                if groupby_t is int64_t:
-                    if val != nan_val:
-                        nobs[lab, j] += 1
-                        if val > maxx[lab, j]:
-                            maxx[lab, j] = val
-                else:
-                    if val == val:
-                        nobs[lab, j] += 1
-                        if val > maxx[lab, j]:
-                            maxx[lab, j] = val
+                if not _treat_as_na(val, True):
+                    # TODO: Sure we always want is_datetimelike=True?
+                    nobs[lab, j] += 1
+                    if val > maxx[lab, j]:
+                        maxx[lab, j] = val

         for i in range(ncounts):
             for j in range(K):
@@ -563,17 +549,11 @@ def group_min(groupby_t[:, :] out,
             for j in range(K):
                 val = values[i, j]

-                # not nan
-                if groupby_t is int64_t:
-                    if val != nan_val:
-                        nobs[lab, j] += 1
-                        if val < minx[lab, j]:
-                            minx[lab, j] = val
-                else:
-                    if val == val:
-                        nobs[lab, j] += 1
-                        if val < minx[lab, j]:
-                            minx[lab, j] = val
+                if not _treat_as_na(val, True):
+                    # TODO: Sure we always want is_datetimelike=True?
+                    nobs[lab, j] += 1
+                    if val < minx[lab, j]:
+                        minx[lab, j] = val

         for i in range(ncounts):
             for j in range(K):
@@ -643,21 +623,13 @@ def group_cummin(groupby_t[:, :] out,
             for j in range(K):
                 val = values[i, j]

-                # val = nan
-                if groupby_t is int64_t:
-                    if is_datetimelike and val == NPY_NAT:
-                        out[i, j] = NPY_NAT
-                    else:
-                        mval = accum[lab, j]
-                        if val < mval:
-                            accum[lab, j] = mval = val
-                        out[i, j] = mval
+                if _treat_as_na(val, is_datetimelike):
+                    out[i, j] = val
                 else:
-                    if val == val:
-                        mval = accum[lab, j]
-                        if val < mval:
-                            accum[lab, j] = mval = val
-                        out[i, j] = mval
+                    mval = accum[lab, j]
+                    if val < mval:
+                        accum[lab, j] = mval = val
+                    out[i, j] = mval


 @cython.boundscheck(False)
@@ -712,17 +684,10 @@ def group_cummax(groupby_t[:, :] out,
             for j in range(K):
                 val = values[i, j]

-                if groupby_t is int64_t:
-                    if is_datetimelike and val == NPY_NAT:
-                        out[i, j] = NPY_NAT
-                    else:
-                        mval = accum[lab, j]
-                        if val > mval:
-                            accum[lab, j] = mval = val
-                        out[i, j] = mval
+                if _treat_as_na(val, is_datetimelike):
+                    out[i, j] = val
                 else:
-                    if val == val:
-                        mval = accum[lab, j]
-                        if val > mval:
-                            accum[lab, j] = mval = val
-                        out[i, j] = mval
+                    mval = accum[lab, j]
+                    if val > mval:
+                        accum[lab, j] = mval = val
+                    out[i, j] = mval
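For context, a rough plain-Python sketch of the NA rule that the new `_treat_as_na` helper centralizes. It is illustrative only (the real helper dispatches on the Cython fused type); `NPY_NAT` below stands in for NumPy's NaT sentinel, which is the minimum int64 value.

    import numpy as np

    # Stand-in for the NPY_NAT sentinel used to encode NaT in datetimelike int64 data.
    NPY_NAT = np.iinfo(np.int64).min

    def treat_as_na(val, is_datetimelike):
        # Pure-Python mirror of the helper: an int64 value is NA only when it
        # holds the NaT sentinel *and* the data is datetimelike; a float is NA
        # when it is NaN (val != val).
        if isinstance(val, (int, np.integer)):
            return is_datetimelike and val == NPY_NAT
        return val != val

    assert treat_as_na(float("nan"), False)    # NaN is always NA
    assert treat_as_na(NPY_NAT, True)          # NaT sentinel is NA for datetimelike data
    assert not treat_as_na(NPY_NAT, False)     # ...but a valid value for plain int64 data

The object-dtype paths in the diff keep the inline `val == val` check for now; per the NB comments, they can switch to `_treat_as_na` once conditional-nogil is available in Cython.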