diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index fee8334940a16..9166b28af4e78 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -101,6 +101,7 @@ Other enhancements - :meth:`Series.ewm`, :meth:`DataFrame.ewm`, now support a ``method`` argument with a ``'table'`` option that performs the windowing operation over an entire :class:`DataFrame`. See :ref:`Window Overview ` for performance and functional benefits (:issue:`42273`) - :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` now support the argument ``skipna`` (:issue:`34047`) - :meth:`read_table` now supports the argument ``storage_options`` (:issue:`39167`) +- :meth:`DataFrame.to_stata` and :meth:`StataWriter` now accept the keyword only argument ``value_labels`` to save labels for non-categorical columns - Methods that relied on hashmap based algos such as :meth:`DataFrameGroupBy.value_counts`, :meth:`DataFrameGroupBy.count` and :func:`factorize` ignored imaginary component for complex numbers (:issue:`17927`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/frame.py b/pandas/core/frame.py index db12129a15ef9..c1a3f4d9298b4 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -2382,6 +2382,8 @@ def to_stata( convert_strl: Sequence[Hashable] | None = None, compression: CompressionOptions = "infer", storage_options: StorageOptions = None, + *, + value_labels: dict[Hashable, dict[float | int, str]] | None = None, ) -> None: """ Export DataFrame object to Stata dta format. @@ -2463,6 +2465,13 @@ def to_stata( .. versionadded:: 1.2.0 + value_labels : dict of dicts + Dictionary containing columns as keys and dictionaries of column value + to labels as values. Labels for a single variable must be 32,000 + characters or smaller. + + .. 
versionadded:: 1.4.0 + Raises ------ NotImplementedError @@ -2524,6 +2533,7 @@ def to_stata( variable_labels=variable_labels, compression=compression, storage_options=storage_options, + value_labels=value_labels, **kwargs, ) writer.write_file() diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 1deaa634ce3ae..11b9e8f7009c4 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -18,6 +18,7 @@ import struct import sys from typing import ( + TYPE_CHECKING, Any, AnyStr, Hashable, @@ -46,6 +47,7 @@ ensure_object, is_categorical_dtype, is_datetime64_dtype, + is_numeric_dtype, ) from pandas import ( @@ -64,6 +66,9 @@ from pandas.io.common import get_handle +if TYPE_CHECKING: + from typing import Literal + _version_error = ( "Version of given Stata file is {version}. pandas supports importing " "versions 105, 108, 111 (Stata 7SE), 113 (Stata 8/9), " @@ -658,24 +663,37 @@ def __init__(self, catarray: Series, encoding: str = "latin-1"): self.labname = catarray.name self._encoding = encoding categories = catarray.cat.categories - self.value_labels = list(zip(np.arange(len(categories)), categories)) + self.value_labels: list[tuple[int | float, str]] = list( + zip(np.arange(len(categories)), categories) + ) self.value_labels.sort(key=lambda x: x[0]) + + self._prepare_value_labels() + + def _prepare_value_labels(self): + """Encode value labels.""" + self.text_len = 0 self.txt: list[bytes] = [] self.n = 0 + # Offsets (length of categories), converted to int32 + self.off = np.array([]) + # Values, converted to int32 + self.val = np.array([]) + self.len = 0 # Compute lengths and setup lists of offsets and labels offsets: list[int] = [] - values: list[int] = [] + values: list[int | float] = [] for vl in self.value_labels: - category = vl[1] + category: str | bytes = vl[1] if not isinstance(category, str): category = str(category) warnings.warn( - value_label_mismatch_doc.format(catarray.name), + value_label_mismatch_doc.format(self.labname), ValueLabelTypeMismatch, ) - 
class StataNonCatValueLabel(StataValueLabel):
    """
    Formatted value labels for a column that is not categorical.

    The labels come from an explicit value -> label mapping instead of the
    categories of a categorical Series, so ``StataValueLabel.__init__`` is
    not invoked; the attributes it would set are assigned here directly.

    Parameters
    ----------
    labname : str
        Value label name
    value_labels : dict
        Mapping of values to labels
    encoding : {"latin-1", "utf-8"}
        Encoding to use for value labels.

    Raises
    ------
    ValueError
        If ``encoding`` is neither "latin-1" nor "utf-8".
    """

    def __init__(
        self,
        labname: str,
        value_labels: dict[float | int, str],
        encoding: Literal["latin-1", "utf-8"] = "latin-1",
    ):
        supported_encodings = ("latin-1", "utf-8")
        if encoding not in supported_encodings:
            raise ValueError("Only latin-1 and utf-8 are supported.")

        self.labname = labname
        self._encoding = encoding

        # Keep the (value, label) pairs ordered by value.
        pairs: list[tuple[int | float, str]] = list(value_labels.items())
        pairs.sort(key=lambda pair: pair[0])
        self.value_labels = pairs

        # Encode the labels with the shared helper on the parent class.
        self._prepare_value_labels()
versionadded:: 1.4.0 + Returns ------- writer : StataWriter instance @@ -2225,15 +2282,22 @@ def __init__( variable_labels: dict[Hashable, str] | None = None, compression: CompressionOptions = "infer", storage_options: StorageOptions = None, + *, + value_labels: dict[Hashable, dict[float | int, str]] | None = None, ): super().__init__() + self.data = data self._convert_dates = {} if convert_dates is None else convert_dates self._write_index = write_index self._time_stamp = time_stamp self._data_label = data_label self._variable_labels = variable_labels + self._non_cat_value_labels = value_labels + self._value_labels: list[StataValueLabel] = [] + self._has_value_labels = np.array([], dtype=bool) self._compression = compression self._output_file: Buffer | None = None + self._converted_names: dict[Hashable, str] = {} # attach nobs, nvars, data, varlist, typlist self._prepare_pandas(data) self.storage_options = storage_options @@ -2243,7 +2307,6 @@ def __init__( self._byteorder = _set_endianness(byteorder) self._fname = fname self.type_converters = {253: np.int32, 252: np.int16, 251: np.int8} - self._converted_names: dict[Hashable, str] = {} def _write(self, to_write: str) -> None: """ @@ -2259,17 +2322,50 @@ def _write_bytes(self, value: bytes) -> None: """ self.handles.handle.write(value) # type: ignore[arg-type] + def _prepare_non_cat_value_labels( + self, data: DataFrame + ) -> list[StataNonCatValueLabel]: + """ + Check for value labels provided for non-categorical columns. Value + labels + """ + non_cat_value_labels: list[StataNonCatValueLabel] = [] + if self._non_cat_value_labels is None: + return non_cat_value_labels + + for labname, labels in self._non_cat_value_labels.items(): + if labname in self._converted_names: + colname = self._converted_names[labname] + elif labname in data.columns: + colname = str(labname) + else: + raise KeyError( + f"Can't create value labels for {labname}, it wasn't " + "found in the dataset." 
+ ) + + if not is_numeric_dtype(data[colname].dtype): + # Labels should not be passed explicitly for categorical + # columns that will be converted to int + raise ValueError( + f"Can't create value labels for {labname}, value labels " + "can only be applied to numeric columns." + ) + svl = StataNonCatValueLabel(colname, labels) + non_cat_value_labels.append(svl) + return non_cat_value_labels + def _prepare_categoricals(self, data: DataFrame) -> DataFrame: """ Check for categorical columns, retain categorical information for Stata file and convert categorical data to int """ is_cat = [is_categorical_dtype(data[col].dtype) for col in data] - self._is_col_cat = is_cat - self._value_labels: list[StataValueLabel] = [] if not any(is_cat): return data + self._has_value_labels |= np.array(is_cat) + get_base_missing_value = StataMissingValue.get_base_missing_value data_formatted = [] for col, col_is_cat in zip(data, is_cat): @@ -2449,6 +2545,17 @@ def _prepare_pandas(self, data: DataFrame) -> None: # Replace NaNs with Stata missing values data = self._replace_nans(data) + # Set all columns to initially unlabelled + self._has_value_labels = np.repeat(False, data.shape[1]) + + # Create value labels for non-categorical data + non_cat_value_labels = self._prepare_non_cat_value_labels(data) + + non_cat_columns = [svl.labname for svl in non_cat_value_labels] + has_non_cat_val_labels = data.columns.isin(non_cat_columns) + self._has_value_labels |= has_non_cat_val_labels + self._value_labels.extend(non_cat_value_labels) + # Convert categoricals to int data, and strip labels data = self._prepare_categoricals(data) @@ -2688,7 +2795,7 @@ def _write_value_label_names(self) -> None: # lbllist, 33*nvar, char array for i in range(self.nvar): # Use variable name when categorical - if self._is_col_cat[i]: + if self._has_value_labels[i]: name = self.varlist[i] name = self._null_terminate_str(name) name = _pad_bytes(name[:32], 33) @@ -3059,6 +3166,13 @@ class StataWriter117(StataWriter): .. 
versionadded:: 1.1.0 + value_labels : dict of dicts + Dictionary containing columns as keys and dictionaries of column value + to labels as values. The combined length of all labels for a single + variable must be 32,000 characters or smaller. + + .. versionadded:: 1.4.0 + Returns ------- writer : StataWriter117 instance @@ -3112,6 +3226,8 @@ def __init__( convert_strl: Sequence[Hashable] | None = None, compression: CompressionOptions = "infer", storage_options: StorageOptions = None, + *, + value_labels: dict[Hashable, dict[float | int, str]] | None = None, ): # Copy to new list since convert_strl might be modified later self._convert_strl: list[Hashable] = [] @@ -3127,6 +3243,7 @@ def __init__( time_stamp=time_stamp, data_label=data_label, variable_labels=variable_labels, + value_labels=value_labels, compression=compression, storage_options=storage_options, ) @@ -3272,7 +3389,7 @@ def _write_value_label_names(self) -> None: for i in range(self.nvar): # Use variable name when categorical name = "" # default name - if self._is_col_cat[i]: + if self._has_value_labels[i]: name = self.varlist[i] name = self._null_terminate_str(name) encoded_name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1) @@ -3449,6 +3566,13 @@ class StataWriterUTF8(StataWriter117): .. versionadded:: 1.1.0 + value_labels : dict of dicts + Dictionary containing columns as keys and dictionaries of column value + to labels as values. The combined length of all labels for a single + variable must be 32,000 characters or smaller. + + .. 
def test_non_categorical_value_labels():
    # Labels given for numeric columns are written next to the labels that
    # are generated automatically for categorical columns.
    data = DataFrame(
        {
            "fully_labelled": [1, 2, 3, 3, 1],
            "partially_labelled": [1.0, 2.0, np.nan, 9.0, np.nan],
            "Y": [7, 7, 9, 8, 10],
            "Z": pd.Categorical(["j", "k", "l", "k", "j"]),
        }
    )

    with tm.ensure_clean() as path:
        value_labels = {
            "fully_labelled": {1: "one", 2: "two", 3: "three"},
            "partially_labelled": {1.0: "one", 2.0: "two"},
        }
        expected = {**value_labels, "Z": {0: "j", 1: "k", 2: "l"}}

        writer = StataWriter(path, data, value_labels=value_labels)
        writer.write_file()

        # Close the reader before ensure_clean removes the file
        with StataReader(path) as reader:
            reader_value_labels = reader.value_labels()
        assert reader_value_labels == expected

        # Labels for a column that does not exist
        value_labels = {"notY": {7: "label1", 8: "label2"}}
        msg = "Can't create value labels for notY, it wasn't found in the dataset."
        with pytest.raises(KeyError, match=msg):
            StataWriter(path, data, value_labels=value_labels)

        # Labels for a column that is not numeric
        value_labels = {"Z": {1: "a", 2: "k", 3: "j", 4: "i"}}
        msg = (
            "Can't create value labels for Z, value labels "
            "can only be applied to numeric columns."
        )
        with pytest.raises(ValueError, match=msg):
            StataWriter(path, data, value_labels=value_labels)


def test_non_categorical_value_label_name_conversion():
    # Check conversion of invalid variable names
    data = DataFrame(
        {
            "invalid~!": [1, 1, 2, 3, 5, 8],  # Only alphanumeric and _
            "6_invalid": [1, 1, 2, 3, 5, 8],  # Must start with letter or _
            "invalid_name_longer_than_32_characters": [8, 8, 9, 9, 8, 8],  # Too long
            "aggregate": [2, 5, 5, 6, 6, 9],  # Reserved words
            (1, 2): [1, 2, 3, 4, 5, 6],  # Hashable non-string
        }
    )

    value_labels = {
        "invalid~!": {1: "label1", 2: "label2"},
        "6_invalid": {1: "label1", 2: "label2"},
        "invalid_name_longer_than_32_characters": {8: "eight", 9: "nine"},
        "aggregate": {5: "five"},
        (1, 2): {3: "three"},
    }

    # Labels must be stored under the converted (valid) variable names
    expected = {
        "invalid__": {1: "label1", 2: "label2"},
        "_6_invalid": {1: "label1", 2: "label2"},
        "invalid_name_longer_than_32_char": {8: "eight", 9: "nine"},
        "_aggregate": {5: "five"},
        "_1__2_": {3: "three"},
    }

    with tm.ensure_clean() as path:
        with tm.assert_produces_warning(InvalidColumnName):
            data.to_stata(path, value_labels=value_labels)

        # Close the reader before ensure_clean removes the file
        with StataReader(path) as reader:
            reader_value_labels = reader.value_labels()
        assert reader_value_labels == expected


def test_non_categorical_value_label_convert_categoricals_error():
    # Mapping more than one value to the same label is valid for Stata
    # labels, but can't be read with convert_categoricals=True
    value_labels = {
        "repeated_labels": {10: "Ten", 20: "More than ten", 40: "More than ten"}
    }

    data = DataFrame(
        {
            "repeated_labels": [10, 10, 20, 20, 40, 40],
        }
    )

    with tm.ensure_clean() as path:
        data.to_stata(path, value_labels=value_labels)

        # Close the reader before ensure_clean removes the file
        with StataReader(path, convert_categoricals=False) as reader:
            reader_value_labels = reader.value_labels()
        assert reader_value_labels == value_labels

        col = "repeated_labels"
        repeats = "-" * 80 + "\n" + "\n".join(["More than ten"])

        msg = f"""
Value labels for column {col} are not unique. These cannot be converted to
pandas categoricals.

Either read the file with `convert_categoricals` set to False or use the
low level interface in `StataReader` to separately read the values and the
value_labels.

The repeated labels are:
{repeats}
"""
        with pytest.raises(ValueError, match=msg):
            read_stata(path, convert_categoricals=True)