Skip to content

BUG: to_stata not handling ea dtypes correctly #56771

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,7 @@ I/O
- Bug in :func:`read_json` not handling dtype conversion properly if ``infer_string`` is set (:issue:`56195`)
- Bug in :meth:`DataFrame.to_excel`, with ``OdsWriter`` (``ods`` files) writing Boolean/string value (:issue:`54994`)
- Bug in :meth:`DataFrame.to_hdf` and :func:`read_hdf` with ``datetime64`` dtypes with non-nanosecond resolution failing to round-trip correctly (:issue:`55622`)
- Bug in :meth:`DataFrame.to_stata` raising for extension dtypes (:issue:`54671`)
- Bug in :meth:`~pandas.read_excel` with ``engine="odf"`` (``ods`` files) when a string cell contains an annotation (:issue:`55200`)
- Bug in :meth:`~pandas.read_excel` with an ODS file without cached formatted cell for float values (:issue:`55219`)
- Bug where :meth:`DataFrame.to_json` would raise an ``OverflowError`` instead of a ``TypeError`` with unsupported NumPy types (:issue:`55403`)
Expand Down
23 changes: 14 additions & 9 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@
)
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.common import (
ensure_object,
is_numeric_dtype,
is_string_dtype,
)
from pandas.core.dtypes.dtypes import CategoricalDtype

Expand All @@ -62,8 +64,6 @@
to_datetime,
to_timedelta,
)
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.integer import IntegerDtype
from pandas.core.frame import DataFrame
from pandas.core.indexes.base import Index
from pandas.core.indexes.range import RangeIndex
Expand Down Expand Up @@ -591,17 +591,22 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:

for col in data:
# Cast from unsupported types to supported types
is_nullable_int = isinstance(data[col].dtype, (IntegerDtype, BooleanDtype))
is_nullable_int = (
isinstance(data[col].dtype, ExtensionDtype)
and data[col].dtype.kind in "iub"
)
# We need to find orig_missing before altering data below
orig_missing = data[col].isna()
if is_nullable_int:
missing_loc = data[col].isna()
if missing_loc.any():
# Replace with always safe value
fv = 0 if isinstance(data[col].dtype, IntegerDtype) else False
data.loc[missing_loc, col] = fv
fv = 0 if data[col].dtype.kind in "iu" else False
# Replace with NumPy-compatible column
data[col] = data[col].astype(data[col].dtype.numpy_dtype)
data[col] = data[col].fillna(fv).astype(data[col].dtype.numpy_dtype)
elif isinstance(data[col].dtype, ExtensionDtype):
if getattr(data[col].dtype, "numpy_dtype", None) is not None:
data[col] = data[col].astype(data[col].dtype.numpy_dtype)
elif is_string_dtype(data[col].dtype):
data[col] = data[col].astype("object")

dtype = data[col].dtype
empty_df = data.shape[0] == 0
for c_data in conversion_data:
Expand Down
37 changes: 37 additions & 0 deletions pandas/tests/io/test_stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

import pandas as pd
from pandas import CategoricalDtype
import pandas._testing as tm
Expand Down Expand Up @@ -1919,6 +1921,41 @@ def test_writer_118_exceptions(self):
with pytest.raises(ValueError, match="You must use version 119"):
StataWriterUTF8(path, df, version=118)

@pytest.mark.parametrize(
"dtype_backend",
["numpy_nullable", pytest.param("pyarrow", marks=td.skip_if_no("pyarrow"))],
)
def test_read_write_ea_dtypes(self, dtype_backend):
df = DataFrame(
{
"a": [1, 2, None],
"b": ["a", "b", "c"],
"c": [True, False, None],
"d": [1.5, 2.5, 3.5],
"e": pd.date_range("2020-12-31", periods=3, freq="D"),
},
index=pd.Index([0, 1, 2], name="index"),
)
df = df.convert_dtypes(dtype_backend=dtype_backend)
df.to_stata("test_stata.dta", version=118)

with tm.ensure_clean() as path:
df.to_stata(path)
written_and_read_again = self.read_dta(path)

expected = DataFrame(
{
"a": [1, 2, np.nan],
"b": ["a", "b", "c"],
"c": [1.0, 0, np.nan],
"d": [1.5, 2.5, 3.5],
"e": pd.date_range("2020-12-31", periods=3, freq="D"),
},
index=pd.Index([0, 1, 2], name="index", dtype=np.int32),
)

tm.assert_frame_equal(written_and_read_again.set_index("index"), expected)


@pytest.mark.parametrize("version", [105, 108, 111, 113, 114])
def test_backward_compat(version, datapath):
Expand Down