Skip to content

Commit 9d5a6c7

Browse files
committed
WIP: Make Python engine support EA types when reading CSVs
The C engine is the real WIP.
1 parent 5aea205 commit 9d5a6c7

File tree

8 files changed

+99
-10
lines changed

8 files changed

+99
-10
lines changed

pandas/_libs/parsers.pyx

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,7 +1231,11 @@ cdef class TextReader:
12311231

12321232
if result is not None and dtype != 'int64':
12331233
if is_extension_array_dtype(dtype):
1234-
result = result.astype(dtype.numpy_dtype)
1234+
try:
1235+
result = dtype.construct_array_type()._from_sequence(
1236+
result, dtype=dtype)
1237+
except Exception as e:
1238+
raise
12351239
else:
12361240
result = result.astype(dtype)
12371241

@@ -1243,7 +1247,11 @@ cdef class TextReader:
12431247

12441248
if result is not None and dtype != 'float64':
12451249
if is_extension_array_dtype(dtype):
1246-
result = result.astype(dtype.numpy_dtype)
1250+
try:
1251+
result = dtype.construct_array_type()._from_sequence(
1252+
result)
1253+
except Exception as e:
1254+
raise
12471255
else:
12481256
result = result.astype(dtype)
12491257
return result, na_count

pandas/core/arrays/base.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,27 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
123123
"""
124124
raise AbstractMethodError(cls)
125125

126+
@classmethod
127+
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
128+
"""Construct a new ExtensionArray from a sequence of strings.
129+
130+
Parameters
131+
----------
132+
strings : Sequence
133+
Each element is a string that will be parsed into an instance of
134+
the scalar type for this array, ``cls.dtype.type``.
135+
dtype : dtype, optional
136+
Construct for this particular dtype. This should be a Dtype
137+
compatible with the ExtensionArray.
138+
copy : boolean, default False
139+
If True, copy the underlying data.
140+
141+
Returns
142+
-------
143+
ExtensionArray
144+
"""
145+
raise AbstractMethodError(cls)
146+
126147
@classmethod
127148
def _from_factorized(cls, values, original):
128149
"""

pandas/core/arrays/integer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def coerce_to_array(values, dtype, mask=None, copy=False):
154154
dtype = dtype.lower()
155155
if not issubclass(type(dtype), _IntegerDtype):
156156
try:
157-
dtype = _dtypes[str(np.dtype(dtype))]
157+
dtype = _dtypes[str(np.dtype(dtype.name.lower()))]
158158
except KeyError:
159159
raise ValueError("invalid dtype specified {}".format(dtype))
160160

@@ -261,6 +261,10 @@ def __init__(self, values, mask, copy=False):
261261
def _from_sequence(cls, scalars, dtype=None, copy=False):
262262
return integer_array(scalars, dtype=dtype, copy=copy)
263263

264+
@classmethod
265+
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
266+
return cls._from_sequence([int(x) for x in strings], dtype, copy)
267+
264268
@classmethod
265269
def _from_factorized(cls, values, original):
266270
return integer_array(values, dtype=original.dtype)

pandas/core/dtypes/cast.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,8 +615,22 @@ def astype_nansafe(arr, dtype, copy=True, skipna=False):
615615

616616
# dispatch on extension dtype if needed
617617
if is_extension_array_dtype(dtype):
618-
return dtype.construct_array_type()._from_sequence(
619-
arr, dtype=dtype, copy=copy)
618+
if is_object_dtype(arr):
619+
try:
620+
return dtype.construct_array_type()._from_sequence_of_strings(
621+
arr, dtype=dtype, copy=copy)
622+
except AttributeError:
623+
dtype = pandas_dtype(dtype)
624+
return dtype.construct_array_type()._from_sequence_of_strings(
625+
arr, dtype=dtype, copy=copy)
626+
else:
627+
try:
628+
return dtype.construct_array_type()._from_sequence(
629+
arr, dtype=dtype, copy=copy)
630+
except AttributeError:
631+
dtype = pandas_dtype(dtype)
632+
return dtype.construct_array_type()._from_sequence(
633+
arr, dtype=dtype, copy=copy)
620634

621635
if not isinstance(dtype, np.dtype):
622636
dtype = pandas_dtype(dtype)

pandas/core/dtypes/common.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,7 +1795,10 @@ def _get_dtype(arr_or_dtype):
17951795
if isinstance(arr_or_dtype, np.dtype):
17961796
return arr_or_dtype
17971797
elif isinstance(arr_or_dtype, type):
1798-
return np.dtype(arr_or_dtype)
1798+
try:
1799+
return pandas_dtype(arr_or_dtype)
1800+
except TypeError:
1801+
return np.dtype(arr_or_dtype)
17991802
elif isinstance(arr_or_dtype, ExtensionDtype):
18001803
return arr_or_dtype
18011804
elif isinstance(arr_or_dtype, DatetimeTZDtype):
@@ -1813,6 +1816,11 @@ def _get_dtype(arr_or_dtype):
18131816
return PeriodDtype.construct_from_string(arr_or_dtype)
18141817
elif is_interval_dtype(arr_or_dtype):
18151818
return IntervalDtype.construct_from_string(arr_or_dtype)
1819+
else:
1820+
try:
1821+
return pandas_dtype(arr_or_dtype)
1822+
except TypeError:
1823+
pass
18161824
elif isinstance(arr_or_dtype, (ABCCategorical, ABCCategoricalIndex,
18171825
ABCSparseArray, ABCSparseSeries)):
18181826
return arr_or_dtype.dtype
@@ -1843,7 +1851,15 @@ def _get_dtype_type(arr_or_dtype):
18431851
if isinstance(arr_or_dtype, np.dtype):
18441852
return arr_or_dtype.type
18451853
elif isinstance(arr_or_dtype, type):
1846-
return np.dtype(arr_or_dtype).type
1854+
try:
1855+
dtype = pandas_dtype(arr_or_dtype)
1856+
try:
1857+
return dtype.type
1858+
except AttributeError:
1859+
raise TypeError
1860+
except TypeError:
1861+
return np.dtype(arr_or_dtype).type
1862+
18471863
elif isinstance(arr_or_dtype, CategoricalDtype):
18481864
return CategoricalDtypeType
18491865
elif isinstance(arr_or_dtype, DatetimeTZDtype):

pandas/io/parsers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from pandas.core.dtypes.common import (
3030
ensure_object, is_bool_dtype, is_categorical_dtype, is_dtype_equal,
3131
is_float, is_integer, is_integer_dtype, is_list_like, is_object_dtype,
32-
is_scalar, is_string_dtype)
32+
is_scalar, is_string_dtype, is_extension_array_dtype)
3333
from pandas.core.dtypes.dtypes import CategoricalDtype
3434
from pandas.core.dtypes.missing import isna
3535

@@ -1660,15 +1660,17 @@ def _convert_to_ndarrays(self, dct, na_values, na_fvalues, verbose=False,
16601660
try_num_bool=False)
16611661
else:
16621662
# skip inference if specified dtype is object
1663-
try_num_bool = not (cast_type and is_string_dtype(cast_type))
1663+
try_num_bool = not (cast_type and (is_string_dtype(cast_type)
1664+
or is_extension_array_dtype(cast_type)))
16641665

16651666
# general type inference and conversion
16661667
cvals, na_count = self._infer_types(
16671668
values, set(col_na_values) | col_na_fvalues,
16681669
try_num_bool)
16691670

16701671
# type specified in dtype param
1671-
if cast_type and not is_dtype_equal(cvals, cast_type):
1672+
if cast_type and not (is_dtype_equal(cvals, cast_type)
1673+
or is_extension_array_dtype(cast_type)):
16721674
try:
16731675
if (is_bool_dtype(cast_type) and
16741676
not is_categorical_dtype(cast_type)

pandas/tests/extension/base/io.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pandas as pd
2+
from pandas.compat import StringIO
3+
from pandas.core.arrays.integer import Int64Dtype
4+
from .base import BaseExtensionTests
5+
6+
7+
class ExtensionParsingTests(BaseExtensionTests):
8+
def test_EA_types(self):
9+
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int64'),
10+
'A': [1, 2, 1]})
11+
data = df.to_csv(index=False)
12+
result = pd.read_csv(StringIO(data), dtype={'Int': Int64Dtype})
13+
assert result is not None
14+
15+
df = pd.DataFrame({'Int': pd.Series([1, 2, 3], dtype='Int8'),
16+
'A': [1, 2, 1]})
17+
data = df.to_csv(index=False)
18+
result = pd.read_csv(StringIO(data), dtype={'Int': 'Int8'})
19+
assert result is not None

pandas/tests/extension/decimal/array.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ def dtype(self):
7373
def _from_sequence(cls, scalars, dtype=None, copy=False):
7474
return cls(scalars)
7575

76+
@classmethod
77+
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
78+
return cls._from_sequence([decimal.Decimal(x) for x in strings],
79+
dtype, copy)
80+
7681
@classmethod
7782
def _from_factorized(cls, values, original):
7883
return cls(values)

0 commit comments

Comments
 (0)