Skip to content

Commit d0ac2d7

Browse files
committed
add type hints
1 parent ca27cca commit d0ac2d7

File tree

1 file changed

+37
-15
lines changed

1 file changed

+37
-15
lines changed

pandas/core/dtypes/cast.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
from datetime import date, datetime, timedelta
6-
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type
6+
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Set, Tuple, Type, Union
77

88
import numpy as np
99

@@ -18,7 +18,7 @@
1818
ints_to_pydatetime,
1919
)
2020
from pandas._libs.tslibs.timezones import tz_compare
21-
from pandas._typing import ArrayLike, Dtype, DtypeObj
21+
from pandas._typing import AnyArrayLike, ArrayLike, Dtype, DtypeObj, Scalar
2222
from pandas.util._validators import validate_bool_kwarg
2323

2424
from pandas.core.dtypes.common import (
@@ -118,7 +118,7 @@ def is_nested_object(obj) -> bool:
118118
return False
119119

120120

121-
def maybe_downcast_to_dtype(result, dtype):
121+
def maybe_downcast_to_dtype(result, dtype: Dtype):
122122
"""
123123
try to cast to the specified dtype (e.g. convert back to bool/int
124124
or could be an astype of float64->float32
@@ -186,7 +186,7 @@ def maybe_downcast_to_dtype(result, dtype):
186186
return result
187187

188188

189-
def maybe_downcast_numeric(result, dtype, do_round: bool = False):
189+
def maybe_downcast_numeric(result, dtype: Dtype, do_round: bool = False):
190190
"""
191191
Subset of maybe_downcast_to_dtype restricted to numeric dtypes.
192192
@@ -329,7 +329,9 @@ def maybe_cast_result_dtype(dtype: DtypeObj, how: str) -> DtypeObj:
329329
return dtype
330330

331331

332-
def maybe_cast_to_extension_array(cls: Type["ExtensionArray"], obj, dtype=None):
332+
def maybe_cast_to_extension_array(
333+
cls: Type["ExtensionArray"], obj, dtype: Dtype = None
334+
):
333335
"""
334336
Call to `_from_sequence` that returns the object unchanged on Exception.
335337
@@ -362,7 +364,9 @@ def maybe_cast_to_extension_array(cls: Type["ExtensionArray"], obj, dtype=None):
362364
return result
363365

364366

365-
def maybe_upcast_putmask(result: np.ndarray, mask: np.ndarray, other):
367+
def maybe_upcast_putmask(
368+
result: np.ndarray, mask: np.ndarray, other: Scalar
369+
) -> Tuple[np.ndarray, bool]:
366370
"""
367371
A safe version of putmask that potentially upcasts the result.
368372
@@ -660,7 +664,7 @@ def maybe_promote(dtype, fill_value=np.nan):
660664
return dtype, fill_value
661665

662666

663-
def _ensure_dtype_type(value, dtype):
667+
def _ensure_dtype_type(value, dtype: DtypeObj):
664668
"""
665669
Ensure that the given value is an instance of the given dtype.
666670
@@ -787,7 +791,9 @@ def infer_dtype_from_scalar(val, pandas_dtype: bool = False) -> Tuple[DtypeObj,
787791

788792

789793
# TODO: try to make the Any in the return annotation more specific
790-
def infer_dtype_from_array(arr, pandas_dtype: bool = False) -> Tuple[DtypeObj, Any]:
794+
def infer_dtype_from_array(
795+
arr, pandas_dtype: bool = False
796+
) -> Tuple[DtypeObj, AnyArrayLike]:
791797
"""
792798
Infer the dtype from an array.
793799
@@ -875,7 +881,12 @@ def maybe_infer_dtype_type(element):
875881
return tipo
876882

877883

878-
def maybe_upcast(values, fill_value=np.nan, dtype=None, copy: bool = False):
884+
def maybe_upcast(
885+
values: ArrayLike,
886+
fill_value: Scalar = np.nan,
887+
dtype: Dtype = None,
888+
copy: bool = False,
889+
) -> Tuple[ArrayLike, Scalar]:
879890
"""
880891
Provide explicit type promotion and coercion.
881892
@@ -887,6 +898,13 @@ def maybe_upcast(values, fill_value=np.nan, dtype=None, copy: bool = False):
887898
dtype : if None, then use the dtype of the values, else coerce to this type
888899
copy : bool, default True
889900
If True always make a copy even if no upcast is required.
901+
902+
Returns
903+
-------
904+
values: ndarray or ExtensionArray
905+
the original array, possibly upcast
906+
fill_value:
907+
the fill value, possibly upcast
890908
"""
891909
if not is_scalar(fill_value) and not is_object_dtype(values.dtype):
892910
# We allow arbitrary fill values for object dtype
@@ -907,7 +925,7 @@ def maybe_upcast(values, fill_value=np.nan, dtype=None, copy: bool = False):
907925
return values, fill_value
908926

909927

910-
def invalidate_string_dtypes(dtype_set):
928+
def invalidate_string_dtypes(dtype_set: Set):
911929
"""
912930
Change string like dtypes to object for
913931
``DataFrame.select_dtypes()``.
@@ -929,7 +947,7 @@ def coerce_indexer_dtype(indexer, categories):
929947
return ensure_int64(indexer)
930948

931949

932-
def coerce_to_dtypes(result, dtypes):
950+
def coerce_to_dtypes(result: Sequence[Scalar], dtypes: Sequence[Dtype]) -> List[Scalar]:
933951
"""
934952
given a dtypes and a result set, coerce the result elements to the
935953
dtypes
@@ -959,7 +977,9 @@ def conv(r, dtype):
959977
return [conv(r, dtype) for r, dtype in zip(result, dtypes)]
960978

961979

962-
def astype_nansafe(arr, dtype, copy: bool = True, skipna: bool = False):
980+
def astype_nansafe(
981+
arr, dtype: DtypeObj, copy: bool = True, skipna: bool = False
982+
) -> ArrayLike:
963983
"""
964984
Cast the elements of an array to a given dtype a nan-safe manner.
965985
@@ -1063,7 +1083,9 @@ def astype_nansafe(arr, dtype, copy: bool = True, skipna: bool = False):
10631083
return arr.view(dtype)
10641084

10651085

1066-
def maybe_convert_objects(values: np.ndarray, convert_numeric: bool = True):
1086+
def maybe_convert_objects(
1087+
values: np.ndarray, convert_numeric: bool = True
1088+
) -> Union[np.ndarray, ABCDatetimeIndex]:
10671089
"""
10681090
If we have an object dtype array, try to coerce dates and/or numbers.
10691091
@@ -1184,7 +1206,7 @@ def soft_convert_objects(
11841206

11851207

11861208
def convert_dtypes(
1187-
input_array,
1209+
input_array: AnyArrayLike,
11881210
convert_string: bool = True,
11891211
convert_integer: bool = True,
11901212
convert_boolean: bool = True,
@@ -1250,7 +1272,7 @@ def convert_dtypes(
12501272
return inferred_dtype
12511273

12521274

1253-
def maybe_castable(arr) -> bool:
1275+
def maybe_castable(arr: np.ndarray) -> bool:
12541276
# return False to force a non-fastpath
12551277

12561278
# check datetime64[ns]/timedelta64[ns] are valid

0 commit comments

Comments
 (0)