diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index b6c13c287d5f9..f782ec1fddc92 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -377,6 +377,7 @@ Missing MultiIndex ^^^^^^^^^^ - Bug in :meth:`MultiIndex.set_levels` not preserving dtypes for :class:`Categorical` (:issue:`52125`) +- Bug in displaying a :class:`MultiIndex` with a long element (:issue:`52960`) I/O ^^^ @@ -386,7 +387,7 @@ I/O - Bug in :func:`read_hdf` not properly closing store after a ``IndexError`` is raised (:issue:`52781`) - Bug in :func:`read_html`, style elements were read into DataFrames (:issue:`52197`) - Bug in :func:`read_html`, tail texts were removed together with elements containing ``display:none`` style (:issue:`51629`) -- Bug in displaying a :class:`MultiIndex` with a long element (:issue:`52960`) +- Bug when writing and reading empty Stata dta files where dtype information was lost (:issue:`46240`) Period ^^^^^^ diff --git a/pandas/io/stata.py b/pandas/io/stata.py index aed28efecb696..fbadda0a4128f 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -608,9 +608,10 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame: # Replace with NumPy-compatible column data[col] = data[col].astype(data[col].dtype.numpy_dtype) dtype = data[col].dtype + empty_df = data.shape[0] == 0 for c_data in conversion_data: if dtype == c_data[0]: - if data[col].max() <= np.iinfo(c_data[1]).max: + if empty_df or data[col].max() <= np.iinfo(c_data[1]).max: dtype = c_data[1] else: dtype = c_data[2] @@ -621,14 +622,17 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame: data[col] = data[col].astype(dtype) # Check values and upcast if necessary - if dtype == np.int8: + + if dtype == np.int8 and not empty_df: if data[col].max() > 100 or data[col].min() < -127: data[col] = data[col].astype(np.int16) - elif dtype == np.int16: + elif dtype == np.int16 and not empty_df: if data[col].max() > 32740 or data[col].min() < -32767: data[col] = data[col].astype(np.int32) elif dtype == np.int64: - if data[col].max() <= 2147483620 and data[col].min() >= -2147483647: + if empty_df or ( + data[col].max() <= 2147483620 and data[col].min() >= -2147483647 + ): data[col] = data[col].astype(np.int32) else: data[col] = data[col].astype(np.float64) @@ -1700,13 +1704,6 @@ def read( order_categoricals: bool | None = None, ) -> DataFrame: self._ensure_open() - # Handle empty file or chunk. If reading incrementally raise - # StopIteration. If reading the whole thing return an empty - # data frame. - if (self._nobs == 0) and (nrows is None): - self._can_read_value_labels = True - self._data_read = True - return DataFrame(columns=self._varlist) # Handle options if convert_dates is None: @@ -1723,10 +1720,26 @@ def read( order_categoricals = self._order_categoricals if index_col is None: index_col = self._index_col - if nrows is None: nrows = self._nobs + # Handle empty file or chunk. If reading incrementally raise + # StopIteration. If reading the whole thing return an empty + # data frame. + if (self._nobs == 0) and nrows == 0: + self._can_read_value_labels = True + self._data_read = True + data = DataFrame(columns=self._varlist) + # Apply dtypes correctly + for i, col in enumerate(data.columns): + dt = self._dtyplist[i] + if isinstance(dt, np.dtype): + if dt.char != "S": + data[col] = data[col].astype(dt) + if columns is not None: + data = self._do_select_columns(data, columns) + return data + if (self._format_version >= 117) and (not self._value_labels_read): self._can_read_value_labels = True self._read_strls() diff --git a/pandas/tests/io/test_stata.py b/pandas/tests/io/test_stata.py index 05c397c4ea4f1..753f0341af7fc 100644 --- a/pandas/tests/io/test_stata.py +++ b/pandas/tests/io/test_stata.py @@ -71,6 +71,41 @@ def test_read_empty_dta(self, version): empty_ds2 = read_stata(path) tm.assert_frame_equal(empty_ds, empty_ds2) + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_read_empty_dta_with_dtypes(self, version): + # GH 46240 + # Fixing above bug revealed that types are not correctly preserved when + # writing empty DataFrames + empty_df_typed = DataFrame( + { + "i8": np.array([0], dtype=np.int8), + "i16": np.array([0], dtype=np.int16), + "i32": np.array([0], dtype=np.int32), + "i64": np.array([0], dtype=np.int64), + "u8": np.array([0], dtype=np.uint8), + "u16": np.array([0], dtype=np.uint16), + "u32": np.array([0], dtype=np.uint32), + "u64": np.array([0], dtype=np.uint64), + "f32": np.array([0], dtype=np.float32), + "f64": np.array([0], dtype=np.float64), + } + ) + expected = empty_df_typed.copy() + # No uint# support. Downcast since values in range for int# + expected["u8"] = expected["u8"].astype(np.int8) + expected["u16"] = expected["u16"].astype(np.int16) + expected["u32"] = expected["u32"].astype(np.int32) + # No int64 supported at all. Downcast since values in range for int32 + expected["u64"] = expected["u64"].astype(np.int32) + expected["i64"] = expected["i64"].astype(np.int32) + + # GH 7369, make sure can read a 0-obs dta file + with tm.ensure_clean() as path: + empty_df_typed.to_stata(path, write_index=False, version=version) + empty_reread = read_stata(path) + tm.assert_frame_equal(expected, empty_reread) + tm.assert_series_equal(expected.dtypes, empty_reread.dtypes) + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) def test_read_index_col_none(self, version): df = DataFrame({"a": range(5), "b": ["b1", "b2", "b3", "b4", "b5"]}) @@ -2274,3 +2309,21 @@ def test_nullable_support(dtype, version): tm.assert_series_equal(df.a, reread.a) tm.assert_series_equal(reread.b, expected_b) tm.assert_series_equal(reread.c, expected_c) + + +def test_empty_frame(): + # GH 46240 + # create an empty DataFrame with int64 and float64 dtypes + df = DataFrame(data={"a": range(3), "b": [1.0, 2.0, 3.0]}).head(0) + with tm.ensure_clean() as path: + df.to_stata(path, write_index=False, version=117) + # Read entire dataframe + df2 = read_stata(path) + assert "b" in df2 + # Dtypes don't match since no support for int32 + dtypes = Series({"a": np.dtype("int32"), "b": np.dtype("float64")}) + tm.assert_series_equal(df2.dtypes, dtypes) + # read one column of empty .dta file + df3 = read_stata(path, columns=["a"]) + assert "b" not in df3 + tm.assert_series_equal(df3.dtypes, dtypes.loc[["a"]])