diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx index f0beab7193183..cc7c6e0e49c50 100644 --- a/pandas/_libs/groupby.pyx +++ b/pandas/_libs/groupby.pyx @@ -289,10 +289,8 @@ def group_cumprod( if uses_mask: isna_entry = mask[i, j] - elif int64float_t is float64_t or int64float_t is float32_t: - isna_entry = val != val else: - isna_entry = False + isna_entry = _treat_as_na(val, False) if not isna_entry: isna_prev = accum_mask[lab, j] @@ -737,23 +735,10 @@ def group_sum( for j in range(K): val = values[i, j] - # not nan - # With dt64/td64 values, values have been cast to float64 - # instead if int64 for group_sum, but the logic - # is otherwise the same as in _treat_as_na if uses_mask: isna_entry = mask[i, j] - elif ( - sum_t is float32_t - or sum_t is float64_t - or sum_t is complex64_t - ): - # avoid warnings because of equality comparison - isna_entry = not val == val - elif sum_t is int64_t and is_datetimelike and val == NPY_NAT: - isna_entry = True else: - isna_entry = False + isna_entry = _treat_as_na(val, is_datetimelike) if not isna_entry: nobs[lab, j] += 1 @@ -831,10 +816,8 @@ def group_prod( if uses_mask: isna_entry = mask[i, j] - elif int64float_t is float32_t or int64float_t is float64_t: - isna_entry = not val == val else: - isna_entry = False + isna_entry = _treat_as_na(val, False) if not isna_entry: nobs[lab, j] += 1 @@ -906,7 +889,7 @@ def group_var( if uses_mask: isna_entry = mask[i, j] else: - isna_entry = not val == val + isna_entry = _treat_as_na(val, False) if not isna_entry: nobs[lab, j] += 1 @@ -1008,9 +991,12 @@ def group_mean( if uses_mask: isna_entry = mask[i, j] elif is_datetimelike: + # With group_mean, we cannot just use _treat_as_na bc + # datetimelike dtypes get cast to float64 instead of + # to int64. isna_entry = val == NPY_NAT else: - isna_entry = not val == val + isna_entry = _treat_as_na(val, is_datetimelike) if not isna_entry: nobs[lab, j] += 1 @@ -1086,10 +1072,8 @@ def group_ohlc( if uses_mask: isna_entry = mask[i, 0] - elif int64float_t is float32_t or int64float_t is float64_t: - isna_entry = val != val else: - isna_entry = False + isna_entry = _treat_as_na(val, False) if isna_entry: continue @@ -1231,15 +1215,26 @@ def group_quantile( # group_nth, group_last, group_rank # ---------------------------------------------------------------------- -cdef bint _treat_as_na(numeric_object_t val, bint is_datetimelike) nogil: - if numeric_object_t is object: +ctypedef fused numeric_object_complex_t: + numeric_object_t + complex64_t + complex128_t + + +cdef bint _treat_as_na(numeric_object_complex_t val, bint is_datetimelike) nogil: + if numeric_object_complex_t is object: # Should never be used, but we need to avoid the `val != val` below # or else cython will raise about gil acquisition. raise NotImplementedError - elif numeric_object_t is int64_t: + elif numeric_object_complex_t is int64_t: return is_datetimelike and val == NPY_NAT - elif numeric_object_t is float32_t or numeric_object_t is float64_t: + elif ( + numeric_object_complex_t is float32_t + or numeric_object_complex_t is float64_t + or numeric_object_complex_t is complex64_t + or numeric_object_complex_t is complex128_t + ): return val != val else: # non-datetimelike integer