Skip to content

REF: Introduce enum class InferredType #52517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 61 additions & 36 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1380,9 +1380,34 @@ cdef object _try_infer_map(object dtype):
return None


def infer_dtype(value: object, skipna: bool = True) -> str:
"""
Return a string label of the type of a scalar or list-like of values.
class InferredType(Enum):
EMPTY ="empty"
STRING = "string"
BYTES = "bytes"
FLOATING = "floating"
INFERRED = "inferred"
INTEGER = "integer"
INTEGER_NA = "integer-na"
MIXED_INTEGER = "mixed-integer"
MIXED_INTEGER_FLOAT = "mixed-integer-float"
DECIMAL = "decimal"
COMPLEX = "complex"
CATEGORICAL = "categorical"
BOOLEAN = "boolean"
DATETIME64 = "datetime64"
DATETIME = "datetime"
DATE = "date"
TIMEDELTA64 = "timedelta64"
TIMEDELTA = "timedelta"
TIME = "time"
PERIOD = "period"
MIXED = "mixed"
UKNOWN_ARRAY = "unknown-array"


def infer_dtype(value: object, skipna: bool = True) -> InferredType:
"""
Return a InferredType label of the type of a scalar or list-like of values.

Parameters
----------
Expand All @@ -1392,7 +1417,7 @@ def infer_dtype(value: object, skipna: bool = True) -> str:

Returns
-------
str
InferredType
Describing the common type of the input data.
Results can include:

Expand Down Expand Up @@ -1494,30 +1519,30 @@ def infer_dtype(value: object, skipna: bool = True) -> str:

if util.is_array(value):
values = value
elif hasattr(type(value), "inferred_type") and skipna is False:
# Index, use the cached attribute if possible, populate the cache otherwise
return value.inferred_type
#elif hasattr(type(value), "inferred_type") and skipna is False:
# # Index, use the cached attribute if possible, populate the cache otherwise
# return value.inferred_type
elif hasattr(value, "dtype"):
inferred = _try_infer_map(value.dtype)
if inferred is not None:
return inferred
return InferredType.INFERRED
elif not cnp.PyArray_DescrCheck(value.dtype):
return "unknown-array"
return InferredType.UKNOWN_ARRAY
# Unwrap Series/Index
values = np.asarray(value)
else:
if not isinstance(value, list):
value = list(value)
if not value:
return "empty"
return InferredType.EMPTY

from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
values = construct_1d_object_array_from_listlike(value)

inferred = _try_infer_map(values.dtype)
if inferred is not None:
# Anything other than object-dtype should return here.
return inferred
return InferredType.INFERRED

if values.descr.type_num != NPY_OBJECT:
# i.e. values.dtype != np.object_
Expand All @@ -1526,7 +1551,7 @@ def infer_dtype(value: object, skipna: bool = True) -> str:

n = cnp.PyArray_SIZE(values)
if n == 0:
return "empty"
return InferredType.EMPTY

# Iterate until we find our first valid value. We will use this
# value to decide which of the is_foo_array functions to call.
Expand All @@ -1549,92 +1574,92 @@ def infer_dtype(value: object, skipna: bool = True) -> str:

# if all values are nan/NaT
if seen_val is False and seen_pdnat is True:
return "datetime"
return InferredType.DATETIME
# float/object nan is handled in latter logic
if seen_val is False and skipna:
return "empty"
return InferredType.EMPTY

if util.is_datetime64_object(val):
if is_datetime64_array(values, skipna=skipna):
return "datetime64"
return InferredType.DATETIME64

elif is_timedelta(val):
if is_timedelta_or_timedelta64_array(values, skipna=skipna):
return "timedelta"
return InferredType.TIMEDELTA

elif util.is_integer_object(val):
# ordering matters here; this check must come after the is_timedelta
# check otherwise numpy timedelta64 objects would come through here

if is_integer_array(values, skipna=skipna):
return "integer"
return InferredType.INTEGER
elif is_integer_float_array(values, skipna=skipna):
if is_integer_na_array(values, skipna=skipna):
return "integer-na"
return InferredType.INTEGER_NA
else:
return "mixed-integer-float"
return "mixed-integer"
return InferredType.MIXED_INTEGER_FLOAT
return InferredType.MIXED_INTEGER

elif PyDateTime_Check(val):
if is_datetime_array(values, skipna=skipna):
return "datetime"
return InferredType.DATETIME
elif is_date_array(values, skipna=skipna):
return "date"
return InferredType.DATE

elif PyDate_Check(val):
if is_date_array(values, skipna=skipna):
return "date"
return InferredType.DATE

elif PyTime_Check(val):
if is_time_array(values, skipna=skipna):
return "time"
return InferredType.TIME

elif is_decimal(val):
if is_decimal_array(values, skipna=skipna):
return "decimal"
return InferredType.DECIMAL

elif util.is_complex_object(val):
if is_complex_array(values):
return "complex"
return InferredType.COMPLEX

elif util.is_float_object(val):
if is_float_array(values):
return "floating"
return InferredType.FLOATING
elif is_integer_float_array(values, skipna=skipna):
if is_integer_na_array(values, skipna=skipna):
return "integer-na"
return InferredType.INTEGER_NA
else:
return "mixed-integer-float"
return InferredType.MIXED_INTEGER_FLOAT

elif util.is_bool_object(val):
if is_bool_array(values, skipna=skipna):
return "boolean"
return InferredType.BOOLEAN

elif isinstance(val, str):
if is_string_array(values, skipna=skipna):
return "string"
return InferredType.STRING

elif isinstance(val, bytes):
if is_bytes_array(values, skipna=skipna):
return "bytes"
return InferredType.BYTES

elif is_period_object(val):
if is_period_array(values, skipna=skipna):
return "period"
return InferredType.PERIOD

elif is_interval(val):
if is_interval_array(values):
return "interval"
return InferredType.INTERVAL

cnp.PyArray_ITER_RESET(it)
for i in range(n):
val = PyArray_GETITEM(values, PyArray_ITER_DATA(it))
PyArray_ITER_NEXT(it)

if util.is_integer_object(val):
return "mixed-integer"
return InferredType.MIXED_INTEGER

return "mixed"
return InferredType.MIXED


cdef bint is_timedelta(object o):
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _ensure_arraylike(values) -> ArrayLike:
"""
if not is_array_like(values):
inferred = lib.infer_dtype(values, skipna=False)
if inferred in ["mixed", "string", "mixed-integer"]:
if inferred in [InferredType.MIXED, InferredType.STRING, InferredType.MIXED_INTEGER]:
# "mixed-integer" to ensure we do not cast ["ss", 42] to str GH#22160
if isinstance(values, tuple):
values = list(values)
Expand Down Expand Up @@ -1535,7 +1535,7 @@ def safe_sort(

if (
not isinstance(values.dtype, ExtensionDtype)
and lib.infer_dtype(values, skipna=False) == "mixed-integer"
and lib.infer_dtype(values, skipna=False) == InferredType.MIXED_INTEGER
):
ordered = _sort_mixed(values)
else:
Expand Down
16 changes: 8 additions & 8 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,20 +747,20 @@ def isin(self, values) -> npt.NDArray[np.bool_]:

if not isinstance(values, type(self)):
inferable = [
"timedelta",
"timedelta64",
"datetime",
"datetime64",
"date",
"period",
InferredType.TIMEDELTA,
InferredType.TIMEDELTA64,
InferredType.DATETIME,
InferredType.DATETIME64,
InferredType.DATE,
InferredType.PERIOD,
]
if values.dtype == object:
inferred = lib.infer_dtype(values, skipna=False)
if inferred not in inferable:
if inferred == "string":
if inferred == InferredType.STRING:
pass

elif "mixed" in inferred:
elif InferredType.MIXED in inferred:
return isin(self.astype(object), values)
else:
return np.zeros(self.shape, dtype=bool)
Expand Down
16 changes: 8 additions & 8 deletions pandas/core/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,41 +328,41 @@ def array(

if dtype is None:
inferred_dtype = lib.infer_dtype(data, skipna=True)
if inferred_dtype == "period":
if inferred_dtype == InferredType.PERIOD:
period_data = cast(Union[Sequence[Optional[Period]], AnyArrayLike], data)
return PeriodArray._from_sequence(period_data, copy=copy)

elif inferred_dtype == "interval":
elif inferred_dtype == InferredType.INTERVAL:
return IntervalArray(data, copy=copy)

elif inferred_dtype.startswith("datetime"):
elif inferred_dtype in [InferredType.DATETIME, InferredType.DATETIME64]:
# datetime, datetime64
try:
return DatetimeArray._from_sequence(data, copy=copy)
except ValueError:
# Mixture of timezones, fall back to PandasArray
pass

elif inferred_dtype.startswith("timedelta"):
elif inferred_dtype in [InferredType.TIMEDELTA, InferredType.TIMEDELTA64]:
# timedelta, timedelta64
return TimedeltaArray._from_sequence(data, copy=copy)

elif inferred_dtype == "string":
elif inferred_dtype == InferredType.STRING:
# StringArray/ArrowStringArray depending on pd.options.mode.string_storage
return StringDtype().construct_array_type()._from_sequence(data, copy=copy)

elif inferred_dtype == "integer":
elif inferred_dtype == InferredType.INTEGER:
return IntegerArray._from_sequence(data, copy=copy)

elif (
inferred_dtype in ("floating", "mixed-integer-float")
inferred_dtype in [InferredType.FLOATING, InferredType.MIXED_INTEGER_FLOAT]
and getattr(data, "dtype", None) != np.float16
):
# GH#44715 Exclude np.float16 bc FloatingArray does not support it;
# we will fall back to PandasArray.
return FloatingArray._from_sequence(data, copy=copy)

elif inferred_dtype == "boolean":
elif inferred_dtype == InferredType.BOOLEAN:
return BooleanArray._from_sequence(data, copy=copy)

# Pandas overrides NumPy for
Expand Down
2 changes: 1 addition & 1 deletion pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2659,7 +2659,7 @@ def _encode_strings(self) -> None:
dtype = column.dtype
if dtype.type is np.object_:
inferred_dtype = infer_dtype(column, skipna=True)
if not ((inferred_dtype == "string") or len(column) == 0):
if not ((inferred_dtype == InferredType.STRING) or len(column) == 0):
col = column.name
raise ValueError(
f"""\
Expand Down
2 changes: 1 addition & 1 deletion pandas/plotting/_matplotlib/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def _convert_1d(values, units, axis):
return values.asfreq(axis.freq).asi8
elif isinstance(values, Index):
return values.map(lambda x: get_datevalue(x, axis.freq))
elif lib.infer_dtype(values, skipna=False) == "period":
elif lib.infer_dtype(values, skipna=False) == InferredType.PERIOD:
# https://github.com/pandas-dev/pandas/issues/24304
# convert ndarray[period] -> PeriodIndex
return PeriodIndex(values, freq=axis.freq).asi8
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/extension/base/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,6 @@ def test_get_common_dtype(self, dtype):
def test_infer_dtype(self, data, data_missing, skipna):
# only testing that this works without raising an error
res = infer_dtype(data, skipna=skipna)
assert isinstance(res, str)
assert isinstance(res, InferredType)
res = infer_dtype(data_missing, skipna=skipna)
assert isinstance(res, str)
assert isinstance(res, InferredType)