diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index 922dcd7e74aa0..f4caafb3a9fe7 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -1414,10 +1414,12 @@ def infer_dtype(value: object, skipna: bool = True) -> str: return "time" elif is_decimal(val): - return "decimal" + if is_decimal_array(values): + return "decimal" elif is_complex(val): - return "complex" + if is_complex_array(values): + return "complex" elif util.is_float_object(val): if is_float_array(values): @@ -1702,6 +1704,34 @@ cpdef bint is_float_array(ndarray values): return validator.validate(values) +cdef class ComplexValidator(Validator): + cdef inline bint is_value_typed(self, object value) except -1: + return ( + util.is_complex_object(value) + or (util.is_float_object(value) and is_nan(value)) + ) + + cdef inline bint is_array_typed(self) except -1: + return issubclass(self.dtype.type, np.complexfloating) + + +cdef bint is_complex_array(ndarray values): + cdef: + ComplexValidator validator = ComplexValidator(len(values), values.dtype) + return validator.validate(values) + + +cdef class DecimalValidator(Validator): + cdef inline bint is_value_typed(self, object value) except -1: + return is_decimal(value) + + +cdef bint is_decimal_array(ndarray values): + cdef: + DecimalValidator validator = DecimalValidator(len(values), values.dtype) + return validator.validate(values) + + cdef class StringValidator(Validator): cdef inline bint is_value_typed(self, object value) except -1: return isinstance(value, str) @@ -2546,8 +2576,6 @@ def fast_multiget(dict mapping, ndarray keys, default=np.nan): # kludge, for Series return np.empty(0, dtype='f8') - keys = getattr(keys, 'values', keys) - for i in range(n): val = keys[i] if val in mapping: diff --git a/pandas/tests/dtypes/test_inference.py b/pandas/tests/dtypes/test_inference.py index c6c54ccb357d5..7fa83eeac8400 100644 --- a/pandas/tests/dtypes/test_inference.py +++ b/pandas/tests/dtypes/test_inference.py @@ -709,6 +709,9 @@ def test_decimals(self): result = lib.infer_dtype(arr, skipna=True) assert result == "mixed" + result = lib.infer_dtype(arr[::-1], skipna=True) + assert result == "mixed" + arr = np.array([Decimal(1), Decimal("NaN"), Decimal(3)]) result = lib.infer_dtype(arr, skipna=True) assert result == "decimal" @@ -729,6 +732,9 @@ def test_complex(self, skipna): result = lib.infer_dtype(arr, skipna=skipna) assert result == "mixed" + result = lib.infer_dtype(arr[::-1], skipna=skipna) + assert result == "mixed" + # gets cast to complex on array construction arr = np.array([1, np.nan, 1 + 1j]) result = lib.infer_dtype(arr, skipna=skipna)