Skip to content

Commit d7a7eca

Browse files
committed
Finish implementation
1 parent d1c0b51 commit d7a7eca

File tree

4 files changed

+75
-15
lines changed

4 files changed

+75
-15
lines changed

doc/source/whatsnew/v1.6.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Other enhancements
3232
- :meth:`.DataFrameGroupBy.quantile` and :meth:`.SeriesGroupBy.quantile` now preserve nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)
3333
- :meth:`Series.add_suffix`, :meth:`DataFrame.add_suffix`, :meth:`Series.add_prefix` and :meth:`DataFrame.add_prefix` support an ``axis`` argument. If ``axis`` is set, the default behaviour of which axis to consider can be overwritten (:issue:`47819`)
3434
- :func:`assert_frame_equal` now shows the first element where the DataFrames differ, analogously to ``pytest``'s output (:issue:`47910`)
35+
- Added new argument ``use_nullable_dtypes`` to :func:`read_csv` to enable automatic conversion to nullable dtypes (:issue:`36712`)
3536
- Added ``index`` parameter to :meth:`DataFrame.to_dict` (:issue:`46398`)
3637
- Added metadata propagation for binary operators on :class:`DataFrame` (:issue:`28283`)
3738
- :class:`.CategoricalConversionWarning`, :class:`.InvalidComparison`, :class:`.InvalidVersion`, :class:`.LossySetitemError`, and :class:`.NoBufferPresent` are now exposed in ``pandas.errors`` (:issue:`27656`)

pandas/_libs/parsers.pyx

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,7 @@ cdef class TextReader:
936936
bint na_filter = 0
937937
int64_t num_cols
938938
dict result
939+
bint use_nullable_dtypes
939940

940941
start = self.parser_start
941942

@@ -1056,8 +1057,15 @@ cdef class TextReader:
10561057
self._free_na_set(na_hashset)
10571058

10581059
# don't try to upcast EAs
1059-
if na_count > 0 and not is_extension_array_dtype(col_dtype) or self.use_nullable_dtypes:
1060-
col_res = _maybe_upcast(col_res, use_nullable_dtypes=self.use_nullable_dtypes)
1060+
print(col_dtype)
1061+
if (
1062+
na_count > 0 and not is_extension_array_dtype(col_dtype)
1063+
or self.use_nullable_dtypes
1064+
):
1065+
use_nullable_dtypes = self.use_nullable_dtypes and col_dtype is None
1066+
col_res = _maybe_upcast(
1067+
col_res, use_nullable_dtypes=use_nullable_dtypes
1068+
)
10611069

10621070
if col_res is None:
10631071
raise ParserError(f'Unable to parse column {i}')

pandas/io/parsers/base_parser.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def __init__(self, kwds) -> None:
117117

118118
self.dtype = copy(kwds.get("dtype", None))
119119
self.converters = kwds.get("converters")
120-
self.use_nullable_dtypes = kwds.get("use_nullable_dtypes")
120+
self.use_nullable_dtypes = kwds.get("use_nullable_dtypes", False)
121121

122122
self.true_values = kwds.get("true_values")
123123
self.false_values = kwds.get("false_values")
@@ -516,7 +516,7 @@ def _agg_index(self, index, try_parse_dates: bool = True) -> Index:
516516
)
517517

518518
arr, _ = self._infer_types(
519-
arr, col_na_values | col_na_fvalues, try_num_bool
519+
arr, col_na_values | col_na_fvalues, cast_type, try_num_bool
520520
)
521521
arrays.append(arr)
522522

@@ -582,7 +582,10 @@ def _convert_to_ndarrays(
582582
values = lib.map_infer_mask(values, conv_f, mask)
583583

584584
cvals, na_count = self._infer_types(
585-
values, set(col_na_values) | col_na_fvalues, try_num_bool=False
585+
values,
586+
set(col_na_values) | col_na_fvalues,
587+
cast_type,
588+
try_num_bool=False,
586589
)
587590
else:
588591
is_ea = is_extension_array_dtype(cast_type)
@@ -593,7 +596,7 @@ def _convert_to_ndarrays(
593596

594597
# general type inference and conversion
595598
cvals, na_count = self._infer_types(
596-
values, set(col_na_values) | col_na_fvalues, try_num_bool
599+
values, set(col_na_values) | col_na_fvalues, cast_type, try_num_bool
597600
)
598601

599602
# type specified in dtype param or cast_type is an EA
@@ -684,14 +687,15 @@ def _set(x) -> int:
684687

685688
return noconvert_columns
686689

687-
def _infer_types(self, values, na_values, try_num_bool: bool = True):
690+
def _infer_types(self, values, na_values, cast_type, try_num_bool: bool = True):
688691
"""
689692
Infer types of values, possibly casting
690693
691694
Parameters
692695
----------
693696
values : ndarray
694697
na_values : set
698+
cast_type: Specifies if we want to cast explicitly
695699
try_num_bool : bool, default try
696700
try to cast values to numeric (first preference) or boolean
697701
@@ -712,26 +716,24 @@ def _infer_types(self, values, na_values, try_num_bool: bool = True):
712716
np.putmask(values, mask, np.nan)
713717
return values, na_count
714718

719+
use_nullable_dtypes = self.use_nullable_dtypes and cast_type is None
720+
715721
if try_num_bool and is_object_dtype(values.dtype):
716722
# exclude e.g DatetimeIndex here
717723
try:
718724
result, result_mask = lib.maybe_convert_numeric(
719725
values,
720726
na_values,
721727
False,
722-
convert_to_masked_nullable=self.use_nullable_dtypes,
728+
convert_to_masked_nullable=use_nullable_dtypes,
723729
)
724730
except (ValueError, TypeError):
725731
# e.g. encountering datetime string gets ValueError
726732
# TypeError can be raised in floatify
727733
na_count = parsers.sanitize_objects(values, na_values)
728-
729-
if self.use_nullable_dtypes:
730-
result = StringDtype().construct_array_type()._from_sequence(values)
731-
else:
732-
result = values
734+
result = values
733735
else:
734-
if self.use_nullable_dtypes:
736+
if use_nullable_dtypes:
735737
if result_mask is None:
736738
result_mask = np.zeros(result.shape, dtype="bool")
737739

@@ -751,11 +753,18 @@ def _infer_types(self, values, na_values, try_num_bool: bool = True):
751753
na_count = parsers.sanitize_objects(values, na_values)
752754

753755
if result.dtype == np.object_ and try_num_bool:
754-
result, _ = libops.maybe_convert_bool(
756+
result, mask = libops.maybe_convert_bool(
755757
np.asarray(values),
756758
true_values=self.true_values,
757759
false_values=self.false_values,
760+
convert_to_masked_nullable=use_nullable_dtypes,
758761
)
762+
if result.dtype == np.bool_ and use_nullable_dtypes:
763+
if mask is None:
764+
mask = np.zeros(result.shape, dtype=np.bool_)
765+
result = BooleanArray(result, mask)
766+
elif result.dtype == np.object_ and use_nullable_dtypes:
767+
result = StringDtype().construct_array_type()._from_sequence(values)
759768

760769
return result, na_count
761770

pandas/tests/io/parser/dtypes/test_dtypes_basic.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,3 +385,45 @@ def test_dtypes_defaultdict_invalid(all_parsers):
385385
parser = all_parsers
386386
with pytest.raises(TypeError, match="not understood"):
387387
parser.read_csv(StringIO(data), dtype=dtype)
388+
389+
390+
def test_use_nullabla_dtypes(all_parsers):
391+
# GH#36712
392+
393+
parser = all_parsers
394+
395+
data = """a,b,c,d,e,f,g,h,i
396+
1,2.5,True,a,,,,,12-31-2019
397+
3,4.5,False,b,6,7.5,True,a,12-31-2019
398+
"""
399+
result = parser.read_csv(
400+
StringIO(data), use_nullable_dtypes=True, parse_dates=["i"]
401+
)
402+
expected = DataFrame(
403+
{
404+
"a": pd.Series([1, 3], dtype="Int64"),
405+
"b": pd.Series([2.5, 4.5], dtype="Float64"),
406+
"c": pd.Series([True, False], dtype="boolean"),
407+
"d": pd.Series(["a", "b"], dtype="string"),
408+
"e": pd.Series([pd.NA, 6], dtype="Int64"),
409+
"f": pd.Series([pd.NA, 7.5], dtype="Float64"),
410+
"g": pd.Series([pd.NA, True], dtype="boolean"),
411+
"h": pd.Series([pd.NA, "a"], dtype="string"),
412+
"i": pd.Series([Timestamp("2019-12-31")] * 2),
413+
}
414+
)
415+
tm.assert_frame_equal(result, expected)
416+
417+
418+
def test_use_nullabla_dtypes_and_dtype(all_parsers):
419+
# GH#36712
420+
421+
parser = all_parsers
422+
423+
data = """a,b
424+
1,2.5
425+
,
426+
"""
427+
result = parser.read_csv(StringIO(data), use_nullable_dtypes=True, dtype="float64")
428+
expected = DataFrame({"a": [1.0, np.nan], "b": [2.5, np.nan]})
429+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)