Skip to content

Commit 7b0835e

Browse files
committed
REF: Introduce enum class InferredType
1 parent fe415f5 commit 7b0835e

File tree

7 files changed

+83
-58
lines changed

7 files changed

+83
-58
lines changed

pandas/_libs/lib.pyx

Lines changed: 61 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1380,9 +1380,34 @@ cdef object _try_infer_map(object dtype):
13801380
return None
13811381

13821382

1383-
def infer_dtype(value: object, skipna: bool = True) -> str:
1384-
"""
1385-
Return a string label of the type of a scalar or list-like of values.
1383+
class InferredType(str, Enum):
    """
    Closed set of labels describing the inferred type of a scalar or
    list-like of values.

    Each member's value is the string label it stands for. The class mixes
    in ``str`` so that members compare equal to, and behave like, their
    plain-string labels: ``InferredType.STRING == "string"``,
    ``member.startswith("datetime")`` and substring checks such as
    ``InferredType.MIXED in InferredType.MIXED_INTEGER`` all work as they
    would on the label itself.

    A label can be converted to its member by value lookup, e.g.
    ``InferredType("integer")``.
    """

    EMPTY = "empty"
    STRING = "string"
    BYTES = "bytes"
    FLOATING = "floating"
    # NOTE(review): "inferred" is not a dtype label; presumably a placeholder
    # member -- confirm whether it should exist at all.
    INFERRED = "inferred"
    INTEGER = "integer"
    INTEGER_NA = "integer-na"
    MIXED_INTEGER = "mixed-integer"
    MIXED_INTEGER_FLOAT = "mixed-integer-float"
    DECIMAL = "decimal"
    COMPLEX = "complex"
    CATEGORICAL = "categorical"
    BOOLEAN = "boolean"
    DATETIME64 = "datetime64"
    DATETIME = "datetime"
    DATE = "date"
    TIMEDELTA64 = "timedelta64"
    TIMEDELTA = "timedelta"
    TIME = "time"
    PERIOD = "period"
    INTERVAL = "interval"
    MIXED = "mixed"
    UNKNOWN_ARRAY = "unknown-array"
    # Misspelled name kept as an Enum alias (same value => same member) so
    # references to ``InferredType.UKNOWN_ARRAY`` keep resolving.
    UKNOWN_ARRAY = "unknown-array"
1406+
1407+
1408+
def infer_dtype(value: object, skipna: bool = True) -> InferredType:
1409+
"""
1410+
Return an InferredType label of the type of a scalar or list-like of values.
13861411

13871412
Parameters
13881413
----------
@@ -1392,7 +1417,7 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
13921417

13931418
Returns
13941419
-------
1395-
str
1420+
InferredType
13961421
Describing the common type of the input data.
13971422
Results can include:
13981423

@@ -1494,30 +1519,30 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
14941519

14951520
if util.is_array(value):
14961521
values = value
1497-
elif hasattr(type(value), "inferred_type") and skipna is False:
1498-
# Index, use the cached attribute if possible, populate the cache otherwise
1499-
return value.inferred_type
1522+
#elif hasattr(type(value), "inferred_type") and skipna is False:
1523+
# # Index, use the cached attribute if possible, populate the cache otherwise
1524+
# return value.inferred_type
15001525
elif hasattr(value, "dtype"):
15011526
inferred = _try_infer_map(value.dtype)
15021527
if inferred is not None:
1503-
return inferred
1528+
return InferredType.INFERRED
15041529
elif not cnp.PyArray_DescrCheck(value.dtype):
1505-
return "unknown-array"
1530+
return InferredType.UKNOWN_ARRAY
15061531
# Unwrap Series/Index
15071532
values = np.asarray(value)
15081533
else:
15091534
if not isinstance(value, list):
15101535
value = list(value)
15111536
if not value:
1512-
return "empty"
1537+
return InferredType.EMPTY
15131538

15141539
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
15151540
values = construct_1d_object_array_from_listlike(value)
15161541

15171542
inferred = _try_infer_map(values.dtype)
15181543
if inferred is not None:
15191544
# Anything other than object-dtype should return here.
1520-
return inferred
1545+
return InferredType.INFERRED
15211546

15221547
if values.descr.type_num != NPY_OBJECT:
15231548
# i.e. values.dtype != np.object_
@@ -1526,7 +1551,7 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
15261551

15271552
n = cnp.PyArray_SIZE(values)
15281553
if n == 0:
1529-
return "empty"
1554+
return InferredType.EMPTY
15301555

15311556
# Iterate until we find our first valid value. We will use this
15321557
# value to decide which of the is_foo_array functions to call.
@@ -1549,92 +1574,92 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
15491574

15501575
# if all values are nan/NaT
15511576
if seen_val is False and seen_pdnat is True:
1552-
return "datetime"
1577+
return InferredType.DATETIME
15531578
# float/object nan is handled in latter logic
15541579
if seen_val is False and skipna:
1555-
return "empty"
1580+
return InferredType.EMPTY
15561581

15571582
if util.is_datetime64_object(val):
15581583
if is_datetime64_array(values, skipna=skipna):
1559-
return "datetime64"
1584+
return InferredType.DATETIME64
15601585

15611586
elif is_timedelta(val):
15621587
if is_timedelta_or_timedelta64_array(values, skipna=skipna):
1563-
return "timedelta"
1588+
return InferredType.TIMEDELTA
15641589

15651590
elif util.is_integer_object(val):
15661591
# ordering matters here; this check must come after the is_timedelta
15671592
# check otherwise numpy timedelta64 objects would come through here
15681593

15691594
if is_integer_array(values, skipna=skipna):
1570-
return "integer"
1595+
return InferredType.INTEGER
15711596
elif is_integer_float_array(values, skipna=skipna):
15721597
if is_integer_na_array(values, skipna=skipna):
1573-
return "integer-na"
1598+
return InferredType.INTEGER_NA
15741599
else:
1575-
return "mixed-integer-float"
1576-
return "mixed-integer"
1600+
return InferredType.MIXED_INTEGER_FLOAT
1601+
return InferredType.MIXED_INTEGER
15771602

15781603
elif PyDateTime_Check(val):
15791604
if is_datetime_array(values, skipna=skipna):
1580-
return "datetime"
1605+
return InferredType.DATETIME
15811606
elif is_date_array(values, skipna=skipna):
1582-
return "date"
1607+
return InferredType.DATE
15831608

15841609
elif PyDate_Check(val):
15851610
if is_date_array(values, skipna=skipna):
1586-
return "date"
1611+
return InferredType.DATE
15871612

15881613
elif PyTime_Check(val):
15891614
if is_time_array(values, skipna=skipna):
1590-
return "time"
1615+
return InferredType.TIME
15911616

15921617
elif is_decimal(val):
15931618
if is_decimal_array(values, skipna=skipna):
1594-
return "decimal"
1619+
return InferredType.DECIMAL
15951620

15961621
elif util.is_complex_object(val):
15971622
if is_complex_array(values):
1598-
return "complex"
1623+
return InferredType.COMPLEX
15991624

16001625
elif util.is_float_object(val):
16011626
if is_float_array(values):
1602-
return "floating"
1627+
return InferredType.FLOATING
16031628
elif is_integer_float_array(values, skipna=skipna):
16041629
if is_integer_na_array(values, skipna=skipna):
1605-
return "integer-na"
1630+
return InferredType.INTEGER_NA
16061631
else:
1607-
return "mixed-integer-float"
1632+
return InferredType.MIXED_INTEGER_FLOAT
16081633

16091634
elif util.is_bool_object(val):
16101635
if is_bool_array(values, skipna=skipna):
1611-
return "boolean"
1636+
return InferredType.BOOLEAN
16121637

16131638
elif isinstance(val, str):
16141639
if is_string_array(values, skipna=skipna):
1615-
return "string"
1640+
return InferredType.STRING
16161641

16171642
elif isinstance(val, bytes):
16181643
if is_bytes_array(values, skipna=skipna):
1619-
return "bytes"
1644+
return InferredType.BYTES
16201645

16211646
elif is_period_object(val):
16221647
if is_period_array(values, skipna=skipna):
1623-
return "period"
1648+
return InferredType.PERIOD
16241649

16251650
elif is_interval(val):
16261651
if is_interval_array(values):
1627-
return "interval"
1652+
return InferredType.INTERVAL
16281653

16291654
cnp.PyArray_ITER_RESET(it)
16301655
for i in range(n):
16311656
val = PyArray_GETITEM(values, PyArray_ITER_DATA(it))
16321657
PyArray_ITER_NEXT(it)
16331658

16341659
if util.is_integer_object(val):
1635-
return "mixed-integer"
1660+
return InferredType.MIXED_INTEGER
16361661

1637-
return "mixed"
1662+
return InferredType.MIXED
16381663

16391664

16401665
cdef bint is_timedelta(object o):

pandas/core/algorithms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def _ensure_arraylike(values) -> ArrayLike:
220220
"""
221221
if not is_array_like(values):
222222
inferred = lib.infer_dtype(values, skipna=False)
223-
if inferred in ["mixed", "string", "mixed-integer"]:
223+
if inferred in [InferredType.MIXED, InferredType.STRING, InferredType.MIXED_INTEGER]:
224224
# "mixed-integer" to ensure we do not cast ["ss", 42] to str GH#22160
225225
if isinstance(values, tuple):
226226
values = list(values)
@@ -1535,7 +1535,7 @@ def safe_sort(
15351535

15361536
if (
15371537
not isinstance(values.dtype, ExtensionDtype)
1538-
and lib.infer_dtype(values, skipna=False) == "mixed-integer"
1538+
and lib.infer_dtype(values, skipna=False) == InferredType.MIXED_INTEGER
15391539
):
15401540
ordered = _sort_mixed(values)
15411541
else:

pandas/core/arrays/datetimelike.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -747,20 +747,20 @@ def isin(self, values) -> npt.NDArray[np.bool_]:
747747

748748
if not isinstance(values, type(self)):
749749
inferable = [
750-
"timedelta",
751-
"timedelta64",
752-
"datetime",
753-
"datetime64",
754-
"date",
755-
"period",
750+
InferredType.TIMEDELTA,
751+
InferredType.TIMEDELTA64,
752+
InferredType.DATETIME,
753+
InferredType.DATETIME64,
754+
InferredType.DATE,
755+
InferredType.PERIOD,
756756
]
757757
if values.dtype == object:
758758
inferred = lib.infer_dtype(values, skipna=False)
759759
if inferred not in inferable:
760-
if inferred == "string":
760+
if inferred == InferredType.STRING:
761761
pass
762762

763-
elif "mixed" in inferred:
763+
elif InferredType.MIXED in inferred:
764764
return isin(self.astype(object), values)
765765
else:
766766
return np.zeros(self.shape, dtype=bool)

pandas/core/construction.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -328,41 +328,41 @@ def array(
328328

329329
if dtype is None:
330330
inferred_dtype = lib.infer_dtype(data, skipna=True)
331-
if inferred_dtype == "period":
331+
if inferred_dtype == InferredType.PERIOD:
332332
period_data = cast(Union[Sequence[Optional[Period]], AnyArrayLike], data)
333333
return PeriodArray._from_sequence(period_data, copy=copy)
334334

335-
elif inferred_dtype == "interval":
335+
elif inferred_dtype == InferredType.INTERVAL:
336336
return IntervalArray(data, copy=copy)
337337

338-
elif inferred_dtype.startswith("datetime"):
338+
elif inferred_dtype in [InferredType.DATETIME, InferredType.DATETIME64]:
339339
# datetime, datetime64
340340
try:
341341
return DatetimeArray._from_sequence(data, copy=copy)
342342
except ValueError:
343343
# Mixture of timezones, fall back to PandasArray
344344
pass
345345

346-
elif inferred_dtype.startswith("timedelta"):
346+
elif inferred_dtype in [InferredType.TIMEDELTA, InferredType.TIMEDELTA64]:
347347
# timedelta, timedelta64
348348
return TimedeltaArray._from_sequence(data, copy=copy)
349349

350-
elif inferred_dtype == "string":
350+
elif inferred_dtype == InferredType.STRING:
351351
# StringArray/ArrowStringArray depending on pd.options.mode.string_storage
352352
return StringDtype().construct_array_type()._from_sequence(data, copy=copy)
353353

354-
elif inferred_dtype == "integer":
354+
elif inferred_dtype == InferredType.INTEGER:
355355
return IntegerArray._from_sequence(data, copy=copy)
356356

357357
elif (
358-
inferred_dtype in ("floating", "mixed-integer-float")
358+
inferred_dtype in [InferredType.FLOATING, InferredType.MIXED_INTEGER_FLOAT]
359359
and getattr(data, "dtype", None) != np.float16
360360
):
361361
# GH#44715 Exclude np.float16 bc FloatingArray does not support it;
362362
# we will fall back to PandasArray.
363363
return FloatingArray._from_sequence(data, copy=copy)
364364

365-
elif inferred_dtype == "boolean":
365+
elif inferred_dtype == InferredType.BOOLEAN:
366366
return BooleanArray._from_sequence(data, copy=copy)
367367

368368
# Pandas overrides NumPy for

pandas/io/stata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2659,7 +2659,7 @@ def _encode_strings(self) -> None:
26592659
dtype = column.dtype
26602660
if dtype.type is np.object_:
26612661
inferred_dtype = infer_dtype(column, skipna=True)
2662-
if not ((inferred_dtype == "string") or len(column) == 0):
2662+
if not ((inferred_dtype == InferredType.STRING) or len(column) == 0):
26632663
col = column.name
26642664
raise ValueError(
26652665
f"""\

pandas/plotting/_matplotlib/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def _convert_1d(values, units, axis):
249249
return values.asfreq(axis.freq).asi8
250250
elif isinstance(values, Index):
251251
return values.map(lambda x: get_datevalue(x, axis.freq))
252-
elif lib.infer_dtype(values, skipna=False) == "period":
252+
elif lib.infer_dtype(values, skipna=False) == InferredType.PERIOD:
253253
# https://github.com/pandas-dev/pandas/issues/24304
254254
# convert ndarray[period] -> PeriodIndex
255255
return PeriodIndex(values, freq=axis.freq).asi8

pandas/tests/extension/base/dtype.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,6 @@ def test_get_common_dtype(self, dtype):
120120
def test_infer_dtype(self, data, data_missing, skipna):
121121
# only testing that this works without raising an error
122122
res = infer_dtype(data, skipna=skipna)
123-
assert isinstance(res, str)
123+
assert isinstance(res, InferredType)
124124
res = infer_dtype(data_missing, skipna=skipna)
125-
assert isinstance(res, str)
125+
assert isinstance(res, InferredType)

0 commit comments

Comments
 (0)