diff --git a/pandas/core/missing.py b/pandas/core/missing.py index 0b77a6d821c6d..d1597b23cf577 100644 --- a/pandas/core/missing.py +++ b/pandas/core/missing.py @@ -3,7 +3,10 @@ """ from __future__ import annotations -from functools import partial +from functools import ( + partial, + wraps, +) from typing import ( TYPE_CHECKING, Any, @@ -11,6 +14,7 @@ Optional, Set, Union, + cast, ) import numpy as np @@ -22,15 +26,13 @@ from pandas._typing import ( ArrayLike, Axis, - DtypeObj, + F, ) from pandas.compat._optional import import_optional_dependency from pandas.core.dtypes.cast import infer_dtype_from from pandas.core.dtypes.common import ( - ensure_float64, is_array_like, - is_integer_dtype, is_numeric_v_string_like, needs_i8_conversion, ) @@ -674,54 +676,53 @@ def interpolate_2d( return result -def _cast_values_for_fillna(values, dtype: DtypeObj, has_mask: bool): - """ - Cast values to a dtype that algos.pad and algos.backfill can handle. - """ - # TODO: for int-dtypes we make a copy, but for everything else this - # alters the values in-place. Is this intentional? +def _fillna_prep(values, mask=None): + # boilerplate for _pad_1d, _backfill_1d, _pad_2d, _backfill_2d - if needs_i8_conversion(dtype): - values = values.view(np.int64) + if mask is None: + mask = isna(values) - elif is_integer_dtype(values) and not has_mask: - # NB: this check needs to come after the datetime64 check above - # has_mask check to avoid casting i8 values that have already - # been cast from PeriodDtype - values = ensure_float64(values) + mask = mask.view(np.uint8) + return mask - return values +def _datetimelike_compat(func: F) -> F: + """ + Wrapper to handle datetime64 and timedelta64 dtypes. + """ -def _fillna_prep(values, mask=None): - # boilerplate for _pad_1d, _backfill_1d, _pad_2d, _backfill_2d - dtype = values.dtype + @wraps(func) + def new_func(values, limit=None, mask=None): + if needs_i8_conversion(values.dtype): + if mask is None: + # This needs to occur before casting to int64 + mask = isna(values) - has_mask = mask is not None - if not has_mask: - # This needs to occur before datetime/timedeltas are cast to int64 - mask = isna(values) + result = func(values.view("i8"), limit=limit, mask=mask) + return result.view(values.dtype) - values = _cast_values_for_fillna(values, dtype, has_mask) + return func(values, limit=limit, mask=mask) - mask = mask.view(np.uint8) - return values, mask + return cast(F, new_func) +@_datetimelike_compat def _pad_1d(values, limit=None, mask=None): - values, mask = _fillna_prep(values, mask) + mask = _fillna_prep(values, mask) algos.pad_inplace(values, mask, limit=limit) return values +@_datetimelike_compat def _backfill_1d(values, limit=None, mask=None): - values, mask = _fillna_prep(values, mask) + mask = _fillna_prep(values, mask) algos.backfill_inplace(values, mask, limit=limit) return values +@_datetimelike_compat def _pad_2d(values, limit=None, mask=None): - values, mask = _fillna_prep(values, mask) + mask = _fillna_prep(values, mask) if np.all(values.shape): algos.pad_2d_inplace(values, mask, limit=limit) @@ -731,8 +732,9 @@ def _pad_2d(values, limit=None, mask=None): return values +@_datetimelike_compat def _backfill_2d(values, limit=None, mask=None): - values, mask = _fillna_prep(values, mask) + mask = _fillna_prep(values, mask) if np.all(values.shape): algos.backfill_2d_inplace(values, mask, limit=limit)