Skip to content

REF: use _treat_as_na more #51067

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 31, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 24 additions & 29 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,8 @@ def group_cumprod(

if uses_mask:
isna_entry = mask[i, j]
elif int64float_t is float64_t or int64float_t is float32_t:
isna_entry = val != val
else:
isna_entry = False
isna_entry = _treat_as_na(val, False)

if not isna_entry:
isna_prev = accum_mask[lab, j]
Expand Down Expand Up @@ -737,23 +735,10 @@ def group_sum(
for j in range(K):
val = values[i, j]

# not nan
# With dt64/td64 values, values have been cast to float64
# instead of int64 for group_sum, but the logic
# is otherwise the same as in _treat_as_na
if uses_mask:
isna_entry = mask[i, j]
elif (
sum_t is float32_t
or sum_t is float64_t
or sum_t is complex64_t
):
# avoid warnings because of equality comparison
isna_entry = not val == val
elif sum_t is int64_t and is_datetimelike and val == NPY_NAT:
isna_entry = True
else:
isna_entry = False
isna_entry = _treat_as_na(val, is_datetimelike)

if not isna_entry:
nobs[lab, j] += 1
Expand Down Expand Up @@ -831,10 +816,8 @@ def group_prod(

if uses_mask:
isna_entry = mask[i, j]
elif int64float_t is float32_t or int64float_t is float64_t:
isna_entry = not val == val
else:
isna_entry = False
isna_entry = _treat_as_na(val, False)

if not isna_entry:
nobs[lab, j] += 1
Expand Down Expand Up @@ -906,7 +889,7 @@ def group_var(
if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = not val == val
isna_entry = _treat_as_na(val, False)

if not isna_entry:
nobs[lab, j] += 1
Expand Down Expand Up @@ -1008,9 +991,12 @@ def group_mean(
if uses_mask:
isna_entry = mask[i, j]
elif is_datetimelike:
# With group_mean, we cannot just use _treat_as_na because
# datetimelike dtypes get cast to float64 instead of
# to int64.
isna_entry = val == NPY_NAT
else:
isna_entry = not val == val
isna_entry = _treat_as_na(val, is_datetimelike)

if not isna_entry:
nobs[lab, j] += 1
Expand Down Expand Up @@ -1086,10 +1072,8 @@ def group_ohlc(

if uses_mask:
isna_entry = mask[i, 0]
elif int64float_t is float32_t or int64float_t is float64_t:
isna_entry = val != val
else:
isna_entry = False
isna_entry = _treat_as_na(val, False)

if isna_entry:
continue
Expand Down Expand Up @@ -1231,15 +1215,26 @@ def group_quantile(
# group_nth, group_last, group_rank
# ----------------------------------------------------------------------

cdef bint _treat_as_na(numeric_object_t val, bint is_datetimelike) nogil:
if numeric_object_t is object:
ctypedef fused numeric_object_complex_t:
numeric_object_t
complex64_t
complex128_t


cdef bint _treat_as_na(numeric_object_complex_t val, bint is_datetimelike) nogil:
if numeric_object_complex_t is object:
# Should never be used, but we need to avoid the `val != val` below
# or else Cython will raise an error about GIL acquisition.
raise NotImplementedError

elif numeric_object_t is int64_t:
elif numeric_object_complex_t is int64_t:
return is_datetimelike and val == NPY_NAT
elif numeric_object_t is float32_t or numeric_object_t is float64_t:
elif (
numeric_object_complex_t is float32_t
or numeric_object_complex_t is float64_t
or numeric_object_complex_t is complex64_t
or numeric_object_complex_t is complex128_t
):
return val != val
else:
# non-datetimelike integer
Expand Down