diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index 5bf99301d9261..6e18f28b12e67 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -1380,9 +1380,34 @@ cdef object _try_infer_map(object dtype): return None -def infer_dtype(value: object, skipna: bool = True) -> str: - """ - Return a string label of the type of a scalar or list-like of values. +class InferredType(Enum): + EMPTY ="empty" + STRING = "string" + BYTES = "bytes" + FLOATING = "floating" + INFERRED = "inferred" + INTEGER = "integer" + INTEGER_NA = "integer-na" + MIXED_INTEGER = "mixed-integer" + MIXED_INTEGER_FLOAT = "mixed-integer-float" + DECIMAL = "decimal" + COMPLEX = "complex" + CATEGORICAL = "categorical" + BOOLEAN = "boolean" + DATETIME64 = "datetime64" + DATETIME = "datetime" + DATE = "date" + TIMEDELTA64 = "timedelta64" + TIMEDELTA = "timedelta" + TIME = "time" + PERIOD = "period" + MIXED = "mixed" + UKNOWN_ARRAY = "unknown-array" + + +def infer_dtype(value: object, skipna: bool = True) -> InferredType: + """ + Return a InferredType label of the type of a scalar or list-like of values. Parameters ---------- @@ -1392,7 +1417,7 @@ def infer_dtype(value: object, skipna: bool = True) -> str: Returns ------- - str + InferredType Describing the common type of the input data. Results can include: @@ -1494,22 +1519,22 @@ def infer_dtype(value: object, skipna: bool = True) -> str: if util.is_array(value): values = value - elif hasattr(type(value), "inferred_type") and skipna is False: - # Index, use the cached attribute if possible, populate the cache otherwise - return value.inferred_type + #elif hasattr(type(value), "inferred_type") and skipna is False: + # # Index, use the cached attribute if possible, populate the cache otherwise + # return value.inferred_type elif hasattr(value, "dtype"): inferred = _try_infer_map(value.dtype) if inferred is not None: - return inferred + return InferredType.INFERRED elif not cnp.PyArray_DescrCheck(value.dtype): - return "unknown-array" + return InferredType.UKNOWN_ARRAY # Unwrap Series/Index values = np.asarray(value) else: if not isinstance(value, list): value = list(value) if not value: - return "empty" + return InferredType.EMPTY from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike values = construct_1d_object_array_from_listlike(value) @@ -1517,7 +1542,7 @@ def infer_dtype(value: object, skipna: bool = True) -> str: inferred = _try_infer_map(values.dtype) if inferred is not None: # Anything other than object-dtype should return here. - return inferred + return InferredType.INFERRED if values.descr.type_num != NPY_OBJECT: # i.e. values.dtype != np.object_ @@ -1526,7 +1551,7 @@ def infer_dtype(value: object, skipna: bool = True) -> str: n = cnp.PyArray_SIZE(values) if n == 0: - return "empty" + return InferredType.EMPTY # Iterate until we find our first valid value. We will use this # value to decide which of the is_foo_array functions to call. @@ -1549,82 +1574,82 @@ def infer_dtype(value: object, skipna: bool = True) -> str: # if all values are nan/NaT if seen_val is False and seen_pdnat is True: - return "datetime" + return InferredType.DATETIME # float/object nan is handled in latter logic if seen_val is False and skipna: - return "empty" + return InferredType.EMPTY if util.is_datetime64_object(val): if is_datetime64_array(values, skipna=skipna): - return "datetime64" + return InferredType.DATETIME64 elif is_timedelta(val): if is_timedelta_or_timedelta64_array(values, skipna=skipna): - return "timedelta" + return InferredType.TIMEDELTA elif util.is_integer_object(val): # ordering matters here; this check must come after the is_timedelta # check otherwise numpy timedelta64 objects would come through here if is_integer_array(values, skipna=skipna): - return "integer" + return InferredType.INTEGER elif is_integer_float_array(values, skipna=skipna): if is_integer_na_array(values, skipna=skipna): - return "integer-na" + return InferredType.INTEGER_NA else: - return "mixed-integer-float" - return "mixed-integer" + return InferredType.MIXED_INTEGER_FLOAT + return InferredType.MIXED_INTEGER elif PyDateTime_Check(val): if is_datetime_array(values, skipna=skipna): - return "datetime" + return InferredType.DATETIME elif is_date_array(values, skipna=skipna): - return "date" + return InferredType.DATE elif PyDate_Check(val): if is_date_array(values, skipna=skipna): - return "date" + return InferredType.DATE elif PyTime_Check(val): if is_time_array(values, skipna=skipna): - return "time" + return InferredType.TIME elif is_decimal(val): if is_decimal_array(values, skipna=skipna): - return "decimal" + return InferredType.DECIMAL elif util.is_complex_object(val): if is_complex_array(values): - return "complex" + return InferredType.COMPLEX elif util.is_float_object(val): if is_float_array(values): - return "floating" + return InferredType.FLOATING elif is_integer_float_array(values, skipna=skipna): if is_integer_na_array(values, skipna=skipna): - return "integer-na" + return InferredType.INTEGER_NA else: - return "mixed-integer-float" + return InferredType.MIXED_INTEGER_FLOAT elif util.is_bool_object(val): if is_bool_array(values, skipna=skipna): - return "boolean" + return InferredType.BOOLEAN elif isinstance(val, str): if is_string_array(values, skipna=skipna): - return "string" + return InferredType.STRING elif isinstance(val, bytes): if is_bytes_array(values, skipna=skipna): - return "bytes" + return InferredType.BYTES elif is_period_object(val): if is_period_array(values, skipna=skipna): - return "period" + return InferredType.PERIOD elif is_interval(val): if is_interval_array(values): - return "interval" + return InferredType.INTERVAL cnp.PyArray_ITER_RESET(it) for i in range(n): @@ -1632,9 +1657,9 @@ def infer_dtype(value: object, skipna: bool = True) -> str: PyArray_ITER_NEXT(it) if util.is_integer_object(val): - return "mixed-integer" + return InferredType.MIXED_INTEGER - return "mixed" + return InferredType.MIXED cdef bint is_timedelta(object o): diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 08a3a1fb70bac..be2b4d0bbc79e 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -220,7 +220,7 @@ def _ensure_arraylike(values) -> ArrayLike: """ if not is_array_like(values): inferred = lib.infer_dtype(values, skipna=False) - if inferred in ["mixed", "string", "mixed-integer"]: + if inferred in [InferredType.MIXED, InferredType.STRING, InferredType.MIXED_INTEGER]: # "mixed-integer" to ensure we do not cast ["ss", 42] to str GH#22160 if isinstance(values, tuple): values = list(values) @@ -1535,7 +1535,7 @@ def safe_sort( if ( not isinstance(values.dtype, ExtensionDtype) - and lib.infer_dtype(values, skipna=False) == "mixed-integer" + and lib.infer_dtype(values, skipna=False) == InferredType.MIXED_INTEGER ): ordered = _sort_mixed(values) else: diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index b01b8e91a2cc7..534bd887ea8a2 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -747,20 +747,20 @@ def isin(self, values) -> npt.NDArray[np.bool_]: if not isinstance(values, type(self)): inferable = [ - "timedelta", - "timedelta64", - "datetime", - "datetime64", - "date", - "period", + InferredType.TIMEDELTA, + InferredType.TIMEDELTA64, + InferredType.DATETIME, + InferredType.DATETIME64, + InferredType.DATE, + InferredType.PERIOD, ] if values.dtype == object: inferred = lib.infer_dtype(values, skipna=False) if inferred not in inferable: - if inferred == "string": + if inferred == InferredType.STRING: pass - elif "mixed" in inferred: + elif InferredType.MIXED in inferred: return isin(self.astype(object), values) else: return np.zeros(self.shape, dtype=bool) diff --git a/pandas/core/construction.py b/pandas/core/construction.py index 2208ae07fe30f..8109a25bf0f2c 100644 --- a/pandas/core/construction.py +++ b/pandas/core/construction.py @@ -328,14 +328,14 @@ def array( if dtype is None: inferred_dtype = lib.infer_dtype(data, skipna=True) - if inferred_dtype == "period": + if inferred_dtype == InferredType.PERIOD: period_data = cast(Union[Sequence[Optional[Period]], AnyArrayLike], data) return PeriodArray._from_sequence(period_data, copy=copy) - elif inferred_dtype == "interval": + elif inferred_dtype == InferredType.INTERVAL: return IntervalArray(data, copy=copy) - elif inferred_dtype.startswith("datetime"): + elif inferred_dtype in [InferredType.DATETIME, InferredType.DATETIME64]: # datetime, datetime64 try: return DatetimeArray._from_sequence(data, copy=copy) @@ -343,26 +343,26 @@ def array( # Mixture of timezones, fall back to PandasArray pass - elif inferred_dtype.startswith("timedelta"): + elif inferred_dtype in [InferredType.TIMEDELTA, InferredType.TIMEDELTA64]: # timedelta, timedelta64 return TimedeltaArray._from_sequence(data, copy=copy) - elif inferred_dtype == "string": + elif inferred_dtype == InferredType.STRING: # StringArray/ArrowStringArray depending on pd.options.mode.string_storage return StringDtype().construct_array_type()._from_sequence(data, copy=copy) - elif inferred_dtype == "integer": + elif inferred_dtype == InferredType.INTEGER: return IntegerArray._from_sequence(data, copy=copy) elif ( - inferred_dtype in ("floating", "mixed-integer-float") + inferred_dtype in [InferredType.FLOATING, InferredType.MIXED_INTEGER_FLOAT] and getattr(data, "dtype", None) != np.float16 ): # GH#44715 Exclude np.float16 bc FloatingArray does not support it; # we will fall back to PandasArray. return FloatingArray._from_sequence(data, copy=copy) - elif inferred_dtype == "boolean": + elif inferred_dtype == InferredType.BOOLEAN: return BooleanArray._from_sequence(data, copy=copy) # Pandas overrides NumPy for diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 0eb4a42060416..924489bfa20c9 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -2659,7 +2659,7 @@ def _encode_strings(self) -> None: dtype = column.dtype if dtype.type is np.object_: inferred_dtype = infer_dtype(column, skipna=True) - if not ((inferred_dtype == "string") or len(column) == 0): + if not ((inferred_dtype == InferredType.STRING) or len(column) == 0): col = column.name raise ValueError( f"""\ diff --git a/pandas/plotting/_matplotlib/converter.py b/pandas/plotting/_matplotlib/converter.py index 9b0fe99e2d61e..67f6e82ff64b4 100644 --- a/pandas/plotting/_matplotlib/converter.py +++ b/pandas/plotting/_matplotlib/converter.py @@ -249,7 +249,7 @@ def _convert_1d(values, units, axis): return values.asfreq(axis.freq).asi8 elif isinstance(values, Index): return values.map(lambda x: get_datevalue(x, axis.freq)) - elif lib.infer_dtype(values, skipna=False) == "period": + elif lib.infer_dtype(values, skipna=False) == InferredType.PERIOD: # https://github.com/pandas-dev/pandas/issues/24304 # convert ndarray[period] -> PeriodIndex return PeriodIndex(values, freq=axis.freq).asi8 diff --git a/pandas/tests/extension/base/dtype.py b/pandas/tests/extension/base/dtype.py index 392a75f8a69a7..b121e98bfc870 100644 --- a/pandas/tests/extension/base/dtype.py +++ b/pandas/tests/extension/base/dtype.py @@ -120,6 +120,6 @@ def test_get_common_dtype(self, dtype): def test_infer_dtype(self, data, data_missing, skipna): # only testing that this works without raising an error res = infer_dtype(data, skipna=skipna) - assert isinstance(res, str) + assert isinstance(res, InferredType) res = infer_dtype(data_missing, skipna=skipna) - assert isinstance(res, str) + assert isinstance(res, InferredType)