Skip to content

Commit d1c0b51

Browse files
committed
ENH: Add option to use nullable dtypes in read_csv
1 parent 73d15a7 commit d1c0b51

File tree

3 files changed

+60
-12
lines changed

3 files changed

+60
-12
lines changed

pandas/_libs/parsers.pyx

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ cdef class TextReader:
342342
object index_col
343343
object skiprows
344344
object dtype
345+
bint use_nullable_dtypes
345346
object usecols
346347
set unnamed_cols # set[str]
347348

@@ -380,7 +381,8 @@ cdef class TextReader:
380381
bint mangle_dupe_cols=True,
381382
float_precision=None,
382383
bint skip_blank_lines=True,
383-
encoding_errors=b"strict"):
384+
encoding_errors=b"strict",
385+
use_nullable_dtypes=False):
384386

385387
# set encoding for native Python and C library
386388
if isinstance(encoding_errors, str):
@@ -505,6 +507,7 @@ cdef class TextReader:
505507
# - DtypeObj
506508
# - dict[Any, DtypeObj]
507509
self.dtype = dtype
510+
self.use_nullable_dtypes = use_nullable_dtypes
508511

509512
# XXX
510513
self.noconvert = set()
@@ -1053,8 +1056,8 @@ cdef class TextReader:
10531056
self._free_na_set(na_hashset)
10541057

10551058
# don't try to upcast EAs
1056-
if na_count > 0 and not is_extension_array_dtype(col_dtype):
1057-
col_res = _maybe_upcast(col_res)
1059+
if na_count > 0 and not is_extension_array_dtype(col_dtype) or self.use_nullable_dtypes:
1060+
col_res = _maybe_upcast(col_res, use_nullable_dtypes=self.use_nullable_dtypes)
10581061

10591062
if col_res is None:
10601063
raise ParserError(f'Unable to parse column {i}')

pandas/io/parsers/base_parser.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
is_dict_like,
5151
is_dtype_equal,
5252
is_extension_array_dtype,
53+
is_float_dtype,
5354
is_integer,
5455
is_integer_dtype,
5556
is_list_like,
@@ -61,8 +62,14 @@
6162
from pandas.core.dtypes.dtypes import CategoricalDtype
6263
from pandas.core.dtypes.missing import isna
6364

65+
from pandas import StringDtype
6466
from pandas.core import algorithms
65-
from pandas.core.arrays import Categorical
67+
from pandas.core.arrays import (
68+
BooleanArray,
69+
Categorical,
70+
FloatingArray,
71+
IntegerArray,
72+
)
6673
from pandas.core.indexes.api import (
6774
Index,
6875
MultiIndex,
@@ -110,6 +117,7 @@ def __init__(self, kwds) -> None:
110117

111118
self.dtype = copy(kwds.get("dtype", None))
112119
self.converters = kwds.get("converters")
120+
self.use_nullable_dtypes = kwds.get("use_nullable_dtypes")
113121

114122
self.true_values = kwds.get("true_values")
115123
self.false_values = kwds.get("false_values")
@@ -589,10 +597,7 @@ def _convert_to_ndarrays(
589597
)
590598

591599
# type specified in dtype param or cast_type is an EA
592-
if cast_type and (
593-
not is_dtype_equal(cvals, cast_type)
594-
or is_extension_array_dtype(cast_type)
595-
):
600+
if cast_type and (not is_dtype_equal(cvals, cast_type) or is_ea):
596601
if not is_ea and na_count > 0:
597602
try:
598603
if is_bool_dtype(cast_type):
@@ -710,14 +715,36 @@ def _infer_types(self, values, na_values, try_num_bool: bool = True):
710715
if try_num_bool and is_object_dtype(values.dtype):
711716
# exclude e.g DatetimeIndex here
712717
try:
713-
result, _ = lib.maybe_convert_numeric(values, na_values, False)
718+
result, result_mask = lib.maybe_convert_numeric(
719+
values,
720+
na_values,
721+
False,
722+
convert_to_masked_nullable=self.use_nullable_dtypes,
723+
)
714724
except (ValueError, TypeError):
715725
# e.g. encountering datetime string gets ValueError
716726
# TypeError can be raised in floatify
717-
result = values
718-
na_count = parsers.sanitize_objects(result, na_values)
727+
na_count = parsers.sanitize_objects(values, na_values)
728+
729+
if self.use_nullable_dtypes:
730+
result = StringDtype().construct_array_type()._from_sequence(values)
731+
else:
732+
result = values
719733
else:
720-
na_count = isna(result).sum()
734+
if self.use_nullable_dtypes:
735+
if result_mask is None:
736+
result_mask = np.zeros(result.shape, dtype="bool")
737+
738+
if is_integer_dtype(result):
739+
result = IntegerArray(result, result_mask)
740+
elif is_bool_dtype(result):
741+
result = BooleanArray(result, result_mask)
742+
elif is_float_dtype(result):
743+
result = FloatingArray(result, result_mask)
744+
745+
na_count = result_mask.sum()
746+
else:
747+
na_count = isna(result).sum()
721748
else:
722749
result = values
723750
if values.dtype == np.object_:
@@ -1146,6 +1173,7 @@ def converter(*date_cols):
11461173
"on_bad_lines": ParserBase.BadLineHandleMethod.ERROR,
11471174
"error_bad_lines": None,
11481175
"warn_bad_lines": None,
1176+
"use_nullable_dtypes": False,
11491177
}
11501178

11511179

pandas/io/parsers/readers.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,13 @@
427427
428428
.. versionadded:: 1.2
429429
430+
use_nullable_dtypes: bool = False
431+
Whether or not to use nullable dtypes as default when reading data. If
432+
set to True, nullable dtypes are used for all dtypes that have a nullable
433+
implementation, even if no nulls are present.
434+
435+
.. versionadded:: 2.0
436+
430437
Returns
431438
-------
432439
DataFrame or TextFileReader
@@ -669,6 +676,7 @@ def read_csv(
669676
memory_map: bool = ...,
670677
float_precision: Literal["high", "legacy"] | None = ...,
671678
storage_options: StorageOptions = ...,
679+
use_nullable_dtypes: bool = ...,
672680
) -> TextFileReader:
673681
...
674682

@@ -729,6 +737,7 @@ def read_csv(
729737
memory_map: bool = ...,
730738
float_precision: Literal["high", "legacy"] | None = ...,
731739
storage_options: StorageOptions = ...,
740+
use_nullable_dtypes: bool = ...,
732741
) -> TextFileReader:
733742
...
734743

@@ -789,6 +798,7 @@ def read_csv(
789798
memory_map: bool = ...,
790799
float_precision: Literal["high", "legacy"] | None = ...,
791800
storage_options: StorageOptions = ...,
801+
use_nullable_dtypes: bool = ...,
792802
) -> DataFrame:
793803
...
794804

@@ -849,6 +859,7 @@ def read_csv(
849859
memory_map: bool = ...,
850860
float_precision: Literal["high", "legacy"] | None = ...,
851861
storage_options: StorageOptions = ...,
862+
use_nullable_dtypes: bool = ...,
852863
) -> DataFrame | TextFileReader:
853864
...
854865

@@ -928,6 +939,7 @@ def read_csv(
928939
memory_map: bool = False,
929940
float_precision: Literal["high", "legacy"] | None = None,
930941
storage_options: StorageOptions = None,
942+
use_nullable_dtypes: bool = False,
931943
) -> DataFrame | TextFileReader:
932944
# locals() should never be modified
933945
kwds = locals().copy()
@@ -1008,6 +1020,7 @@ def read_table(
10081020
memory_map: bool = ...,
10091021
float_precision: str | None = ...,
10101022
storage_options: StorageOptions = ...,
1023+
use_nullable_dtypes: bool = ...,
10111024
) -> TextFileReader:
10121025
...
10131026

@@ -1068,6 +1081,7 @@ def read_table(
10681081
memory_map: bool = ...,
10691082
float_precision: str | None = ...,
10701083
storage_options: StorageOptions = ...,
1084+
use_nullable_dtypes: bool = ...,
10711085
) -> TextFileReader:
10721086
...
10731087

@@ -1128,6 +1142,7 @@ def read_table(
11281142
memory_map: bool = ...,
11291143
float_precision: str | None = ...,
11301144
storage_options: StorageOptions = ...,
1145+
use_nullable_dtypes: bool = ...,
11311146
) -> DataFrame:
11321147
...
11331148

@@ -1188,6 +1203,7 @@ def read_table(
11881203
memory_map: bool = ...,
11891204
float_precision: str | None = ...,
11901205
storage_options: StorageOptions = ...,
1206+
use_nullable_dtypes: bool = ...,
11911207
) -> DataFrame | TextFileReader:
11921208
...
11931209

@@ -1267,6 +1283,7 @@ def read_table(
12671283
memory_map: bool = False,
12681284
float_precision: str | None = None,
12691285
storage_options: StorageOptions = None,
1286+
use_nullable_dtypes: bool = False,
12701287
) -> DataFrame | TextFileReader:
12711288
# locals() should never be modified
12721289
kwds = locals().copy()

0 commit comments

Comments (0)