Skip to content

Commit ca4705a

Browse files
committed
add type hints
1 parent 601eff1 commit ca4705a

File tree

1 file changed

+37
-15
lines changed

1 file changed

+37
-15
lines changed

pandas/core/dtypes/cast.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
from datetime import date, datetime, timedelta
6-
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type
6+
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Set, Tuple, Type, Union
77

88
import numpy as np
99

@@ -18,7 +18,7 @@
1818
ints_to_pydatetime,
1919
)
2020
from pandas._libs.tslibs.timezones import tz_compare
21-
from pandas._typing import ArrayLike, Dtype, DtypeObj
21+
from pandas._typing import AnyArrayLike, ArrayLike, Dtype, DtypeObj, Scalar
2222
from pandas.util._validators import validate_bool_kwarg
2323

2424
from pandas.core.dtypes.common import (
@@ -118,7 +118,7 @@ def is_nested_object(obj) -> bool:
118118
return False
119119

120120

121-
def maybe_downcast_to_dtype(result, dtype):
121+
def maybe_downcast_to_dtype(result, dtype: Dtype):
122122
"""
123123
try to cast to the specified dtype (e.g. convert back to bool/int
124124
or could be an astype of float64->float32
@@ -186,7 +186,7 @@ def maybe_downcast_to_dtype(result, dtype):
186186
return result
187187

188188

189-
def maybe_downcast_numeric(result, dtype, do_round: bool = False):
189+
def maybe_downcast_numeric(result, dtype: Dtype, do_round: bool = False):
190190
"""
191191
Subset of maybe_downcast_to_dtype restricted to numeric dtypes.
192192
@@ -329,7 +329,9 @@ def maybe_cast_result_dtype(dtype: DtypeObj, how: str) -> DtypeObj:
329329
return dtype
330330

331331

332-
def maybe_cast_to_extension_array(cls: Type["ExtensionArray"], obj, dtype=None):
332+
def maybe_cast_to_extension_array(
333+
cls: Type["ExtensionArray"], obj, dtype: Dtype = None
334+
):
333335
"""
334336
Call to `_from_sequence` that returns the object unchanged on Exception.
335337
@@ -362,7 +364,9 @@ def maybe_cast_to_extension_array(cls: Type["ExtensionArray"], obj, dtype=None):
362364
return result
363365

364366

365-
def maybe_upcast_putmask(result: np.ndarray, mask: np.ndarray, other):
367+
def maybe_upcast_putmask(
368+
result: np.ndarray, mask: np.ndarray, other: Scalar
369+
) -> Tuple[np.ndarray, bool]:
366370
"""
367371
A safe version of putmask that potentially upcasts the result.
368372
@@ -657,7 +661,7 @@ def maybe_promote(dtype, fill_value=np.nan):
657661
return dtype, fill_value
658662

659663

660-
def _ensure_dtype_type(value, dtype):
664+
def _ensure_dtype_type(value, dtype: DtypeObj):
661665
"""
662666
Ensure that the given value is an instance of the given dtype.
663667
@@ -784,7 +788,9 @@ def infer_dtype_from_scalar(val, pandas_dtype: bool = False) -> Tuple[DtypeObj,
784788

785789

786790
# TODO: try to make the Any in the return annotation more specific
787-
def infer_dtype_from_array(arr, pandas_dtype: bool = False) -> Tuple[DtypeObj, Any]:
791+
def infer_dtype_from_array(
792+
arr, pandas_dtype: bool = False
793+
) -> Tuple[DtypeObj, AnyArrayLike]:
788794
"""
789795
Infer the dtype from an array.
790796
@@ -872,7 +878,12 @@ def maybe_infer_dtype_type(element):
872878
return tipo
873879

874880

875-
def maybe_upcast(values, fill_value=np.nan, dtype=None, copy: bool = False):
881+
def maybe_upcast(
882+
values: ArrayLike,
883+
fill_value: Scalar = np.nan,
884+
dtype: Dtype = None,
885+
copy: bool = False,
886+
) -> Tuple[ArrayLike, Scalar]:
876887
"""
877888
Provide explicit type promotion and coercion.
878889
@@ -884,6 +895,13 @@ def maybe_upcast(values, fill_value=np.nan, dtype=None, copy: bool = False):
884895
dtype : if None, then use the dtype of the values, else coerce to this type
885896
copy : bool, default True
886897
If True always make a copy even if no upcast is required.
898+
899+
Returns
900+
-------
901+
values: ndarray or ExtensionArray
902+
the original array, possibly upcast
903+
fill_value:
904+
the fill value, possibly upcast
887905
"""
888906
if not is_scalar(fill_value) and not is_object_dtype(values.dtype):
889907
# We allow arbitrary fill values for object dtype
@@ -904,7 +922,7 @@ def maybe_upcast(values, fill_value=np.nan, dtype=None, copy: bool = False):
904922
return values, fill_value
905923

906924

907-
def invalidate_string_dtypes(dtype_set):
925+
def invalidate_string_dtypes(dtype_set: Set):
908926
"""
909927
Change string like dtypes to object for
910928
``DataFrame.select_dtypes()``.
@@ -926,7 +944,7 @@ def coerce_indexer_dtype(indexer, categories):
926944
return ensure_int64(indexer)
927945

928946

929-
def coerce_to_dtypes(result, dtypes):
947+
def coerce_to_dtypes(result: Sequence[Scalar], dtypes: Sequence[Dtype]) -> List[Scalar]:
930948
"""
931949
given a dtypes and a result set, coerce the result elements to the
932950
dtypes
@@ -956,7 +974,9 @@ def conv(r, dtype):
956974
return [conv(r, dtype) for r, dtype in zip(result, dtypes)]
957975

958976

959-
def astype_nansafe(arr, dtype, copy: bool = True, skipna: bool = False):
977+
def astype_nansafe(
978+
arr, dtype: DtypeObj, copy: bool = True, skipna: bool = False
979+
) -> ArrayLike:
960980
"""
961981
Cast the elements of an array to a given dtype a nan-safe manner.
962982
@@ -1060,7 +1080,9 @@ def astype_nansafe(arr, dtype, copy: bool = True, skipna: bool = False):
10601080
return arr.view(dtype)
10611081

10621082

1063-
def maybe_convert_objects(values: np.ndarray, convert_numeric: bool = True):
1083+
def maybe_convert_objects(
1084+
values: np.ndarray, convert_numeric: bool = True
1085+
) -> Union[np.ndarray, ABCDatetimeIndex]:
10641086
"""
10651087
If we have an object dtype array, try to coerce dates and/or numbers.
10661088
@@ -1181,7 +1203,7 @@ def soft_convert_objects(
11811203

11821204

11831205
def convert_dtypes(
1184-
input_array,
1206+
input_array: AnyArrayLike,
11851207
convert_string: bool = True,
11861208
convert_integer: bool = True,
11871209
convert_boolean: bool = True,
@@ -1247,7 +1269,7 @@ def convert_dtypes(
12471269
return inferred_dtype
12481270

12491271

1250-
def maybe_castable(arr) -> bool:
1272+
def maybe_castable(arr: np.ndarray) -> bool:
12511273
# return False to force a non-fastpath
12521274

12531275
# check datetime64[ns]/timedelta64[ns] are valid

0 commit comments

Comments
 (0)