Skip to content

ENH: option to export df to Stata dataset with value labels #41042

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ Other enhancements
- :meth:`Series.ewm`, :meth:`DataFrame.ewm`, now support a ``method`` argument with a ``'table'`` option that performs the windowing operation over an entire :class:`DataFrame`. See :ref:`Window Overview <window.overview>` for performance and functional benefits (:issue:`42273`)
- :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` now support the argument ``skipna`` (:issue:`34047`)
- :meth:`read_table` now supports the argument ``storage_options`` (:issue:`39167`)
- :meth:`DataFrame.to_stata` and :class:`StataWriter` now accept the keyword-only argument ``value_labels`` to save labels for non-categorical columns (:issue:`41042`)
- Methods that relied on hashmap based algos such as :meth:`DataFrameGroupBy.value_counts`, :meth:`DataFrameGroupBy.count` and :func:`factorize` ignored imaginary component for complex numbers (:issue:`17927`)

.. ---------------------------------------------------------------------------
Expand Down
10 changes: 10 additions & 0 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2382,6 +2382,8 @@ def to_stata(
convert_strl: Sequence[Hashable] | None = None,
compression: CompressionOptions = "infer",
storage_options: StorageOptions = None,
*,
value_labels: dict[Hashable, dict[float | int, str]] | None = None,
) -> None:
"""
Export DataFrame object to Stata dta format.
Expand Down Expand Up @@ -2463,6 +2465,13 @@ def to_stata(

.. versionadded:: 1.2.0

value_labels : dict of dicts
Dictionary containing columns as keys and dictionaries of column value
to labels as values. The combined length of all labels for a single
variable must be 32,000 characters or smaller.

.. versionadded:: 1.4.0

Raises
------
NotImplementedError
Expand Down Expand Up @@ -2524,6 +2533,7 @@ def to_stata(
variable_labels=variable_labels,
compression=compression,
storage_options=storage_options,
value_labels=value_labels,
**kwargs,
)
writer.write_file()
Expand Down
147 changes: 137 additions & 10 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import struct
import sys
from typing import (
TYPE_CHECKING,
Any,
AnyStr,
Hashable,
Expand Down Expand Up @@ -46,6 +47,7 @@
ensure_object,
is_categorical_dtype,
is_datetime64_dtype,
is_numeric_dtype,
)

from pandas import (
Expand All @@ -64,6 +66,9 @@

from pandas.io.common import get_handle

if TYPE_CHECKING:
from typing import Literal

_version_error = (
"Version of given Stata file is {version}. pandas supports importing "
"versions 105, 108, 111 (Stata 7SE), 113 (Stata 8/9), "
Expand Down Expand Up @@ -658,24 +663,37 @@ def __init__(self, catarray: Series, encoding: str = "latin-1"):
self.labname = catarray.name
self._encoding = encoding
categories = catarray.cat.categories
self.value_labels = list(zip(np.arange(len(categories)), categories))
self.value_labels: list[tuple[int | float, str]] = list(
zip(np.arange(len(categories)), categories)
)
self.value_labels.sort(key=lambda x: x[0])

self._prepare_value_labels()

def _prepare_value_labels(self):
"""Encode value labels."""

self.text_len = 0
self.txt: list[bytes] = []
self.n = 0
# Offsets (length of categories), converted to int32
self.off = np.array([])
# Values, converted to int32
self.val = np.array([])
self.len = 0

# Compute lengths and setup lists of offsets and labels
offsets: list[int] = []
values: list[int] = []
values: list[int | float] = []
for vl in self.value_labels:
category = vl[1]
category: str | bytes = vl[1]
if not isinstance(category, str):
category = str(category)
warnings.warn(
value_label_mismatch_doc.format(catarray.name),
value_label_mismatch_doc.format(self.labname),
ValueLabelTypeMismatch,
)
category = category.encode(encoding)
category = category.encode(self._encoding)
offsets.append(self.text_len)
self.text_len += len(category) + 1 # +1 for the padding
values.append(vl[0])
Expand Down Expand Up @@ -748,6 +766,38 @@ def generate_value_label(self, byteorder: str) -> bytes:
return bio.getvalue()


class StataNonCatValueLabel(StataValueLabel):
    """
    Value labels supplied explicitly for a non-categorical column.

    The label pairs come from a user-provided mapping rather than from the
    categories of a categorical Series.

    Parameters
    ----------
    labname : str
        Value label name
    value_labels : dict
        Mapping of values to labels
    encoding : {"latin-1", "utf-8"}
        Encoding to use for value labels.

    Raises
    ------
    ValueError
        If ``encoding`` is neither "latin-1" nor "utf-8".
    """

    def __init__(
        self,
        labname: str,
        value_labels: dict[float | int, str],
        encoding: Literal["latin-1", "utf-8"] = "latin-1",
    ):
        supported_encodings = ("latin-1", "utf-8")
        if encoding not in supported_encodings:
            raise ValueError("Only latin-1 and utf-8 are supported.")

        self.labname = labname
        self._encoding = encoding

        # Store the (value, label) pairs ordered by value.
        pairs: list[tuple[int | float, str]] = list(value_labels.items())
        pairs.sort(key=lambda pair: pair[0])
        self.value_labels = pairs

        # Encode the sorted pairs (sets the offset/value/text attributes).
        self._prepare_value_labels()


class StataMissingValue:
"""
An observation's missing value.
Expand Down Expand Up @@ -2175,6 +2225,13 @@ class StataWriter(StataParser):

.. versionadded:: 1.2.0

value_labels : dict of dicts
Dictionary containing columns as keys and dictionaries of column value
to labels as values. The combined length of all labels for a single
variable must be 32,000 characters or smaller.

.. versionadded:: 1.4.0

Returns
-------
writer : StataWriter instance
Expand Down Expand Up @@ -2225,15 +2282,22 @@ def __init__(
variable_labels: dict[Hashable, str] | None = None,
compression: CompressionOptions = "infer",
storage_options: StorageOptions = None,
*,
value_labels: dict[Hashable, dict[float | int, str]] | None = None,
):
super().__init__()
self.data = data
self._convert_dates = {} if convert_dates is None else convert_dates
self._write_index = write_index
self._time_stamp = time_stamp
self._data_label = data_label
self._variable_labels = variable_labels
self._non_cat_value_labels = value_labels
self._value_labels: list[StataValueLabel] = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why don't you rename this to _cat_value_labels to distinguish and make consistent

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_value_labels is a combined list of categorical and non-categorical labels. there isn't a standalone _cat_value_labels because they're added straight to _value_labels after converting categoricals in the dataframe

self._has_value_labels = np.array([], dtype=bool)
self._compression = compression
self._output_file: Buffer | None = None
self._converted_names: dict[Hashable, str] = {}
# attach nobs, nvars, data, varlist, typlist
self._prepare_pandas(data)
self.storage_options = storage_options
Expand All @@ -2243,7 +2307,6 @@ def __init__(
self._byteorder = _set_endianness(byteorder)
self._fname = fname
self.type_converters = {253: np.int32, 252: np.int16, 251: np.int8}
self._converted_names: dict[Hashable, str] = {}

def _write(self, to_write: str) -> None:
"""
Expand All @@ -2259,17 +2322,50 @@ def _write_bytes(self, value: bytes) -> None:
"""
self.handles.handle.write(value) # type: ignore[arg-type]

def _prepare_non_cat_value_labels(
    self, data: DataFrame
) -> list[StataNonCatValueLabel]:
    """
    Build value labels for user-labelled non-categorical columns.

    Each key of the ``value_labels`` mapping given to the writer is
    resolved to a column name and checked to be numeric before its labels
    are encoded.

    Parameters
    ----------
    data : DataFrame
        Data being written; used to look up and type-check each labelled
        column.

    Returns
    -------
    list of StataNonCatValueLabel
        One entry per labelled column; empty when no value labels were
        supplied.

    Raises
    ------
    KeyError
        If a key is neither a converted column name nor a column of
        ``data``.
    ValueError
        If a labelled column does not have a numeric dtype.
    """
    non_cat_value_labels: list[StataNonCatValueLabel] = []
    if self._non_cat_value_labels is None:
        return non_cat_value_labels

    for labname, labels in self._non_cat_value_labels.items():
        # Prefer the converted name when the original name was recorded in
        # the converted-names mapping.
        if labname in self._converted_names:
            colname = self._converted_names[labname]
        elif labname in data.columns:
            colname = str(labname)
        else:
            raise KeyError(
                f"Can't create value labels for {labname}, it wasn't "
                "found in the dataset."
            )

        if not is_numeric_dtype(data[colname].dtype):
            # Labels should not be passed explicitly for categorical
            # columns that will be converted to int
            raise ValueError(
                f"Can't create value labels for {labname}, value labels "
                "can only be applied to numeric columns."
            )
        # Encode labels with the writer's own encoding instead of the
        # label class default of latin-1.
        svl = StataNonCatValueLabel(colname, labels, self._encoding)
        non_cat_value_labels.append(svl)
    return non_cat_value_labels

def _prepare_categoricals(self, data: DataFrame) -> DataFrame:
"""
Check for categorical columns, retain categorical information for
Stata file and convert categorical data to int
"""
is_cat = [is_categorical_dtype(data[col].dtype) for col in data]
self._is_col_cat = is_cat
self._value_labels: list[StataValueLabel] = []
if not any(is_cat):
return data

self._has_value_labels |= np.array(is_cat)

get_base_missing_value = StataMissingValue.get_base_missing_value
data_formatted = []
for col, col_is_cat in zip(data, is_cat):
Expand Down Expand Up @@ -2449,6 +2545,17 @@ def _prepare_pandas(self, data: DataFrame) -> None:
# Replace NaNs with Stata missing values
data = self._replace_nans(data)

# Set all columns to initially unlabelled
self._has_value_labels = np.repeat(False, data.shape[1])

# Create value labels for non-categorical data
non_cat_value_labels = self._prepare_non_cat_value_labels(data)

non_cat_columns = [svl.labname for svl in non_cat_value_labels]
has_non_cat_val_labels = data.columns.isin(non_cat_columns)
self._has_value_labels |= has_non_cat_val_labels
self._value_labels.extend(non_cat_value_labels)

# Convert categoricals to int data, and strip labels
data = self._prepare_categoricals(data)

Expand Down Expand Up @@ -2688,7 +2795,7 @@ def _write_value_label_names(self) -> None:
# lbllist, 33*nvar, char array
for i in range(self.nvar):
# Use variable name when categorical
if self._is_col_cat[i]:
if self._has_value_labels[i]:
name = self.varlist[i]
name = self._null_terminate_str(name)
name = _pad_bytes(name[:32], 33)
Expand Down Expand Up @@ -3059,6 +3166,13 @@ class StataWriter117(StataWriter):

.. versionadded:: 1.1.0

value_labels : dict of dicts
Dictionary containing columns as keys and dictionaries of column value
to labels as values. The combined length of all labels for a single
variable must be 32,000 characters or smaller.

.. versionadded:: 1.4.0

Returns
-------
writer : StataWriter117 instance
Expand Down Expand Up @@ -3112,6 +3226,8 @@ def __init__(
convert_strl: Sequence[Hashable] | None = None,
compression: CompressionOptions = "infer",
storage_options: StorageOptions = None,
*,
value_labels: dict[Hashable, dict[float | int, str]] | None = None,
):
# Copy to new list since convert_strl might be modified later
self._convert_strl: list[Hashable] = []
Expand All @@ -3127,6 +3243,7 @@ def __init__(
time_stamp=time_stamp,
data_label=data_label,
variable_labels=variable_labels,
value_labels=value_labels,
compression=compression,
storage_options=storage_options,
)
Expand Down Expand Up @@ -3272,7 +3389,7 @@ def _write_value_label_names(self) -> None:
for i in range(self.nvar):
# Use variable name when categorical
name = "" # default name
if self._is_col_cat[i]:
if self._has_value_labels[i]:
name = self.varlist[i]
name = self._null_terminate_str(name)
encoded_name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1)
Expand Down Expand Up @@ -3449,6 +3566,13 @@ class StataWriterUTF8(StataWriter117):

.. versionadded:: 1.1.0

value_labels : dict of dicts
Dictionary containing columns as keys and dictionaries of column value
to labels as values. The combined length of all labels for a single
variable must be 32,000 characters or smaller.

.. versionadded:: 1.4.0

Returns
-------
StataWriterUTF8
Expand Down Expand Up @@ -3505,6 +3629,8 @@ def __init__(
version: int | None = None,
compression: CompressionOptions = "infer",
storage_options: StorageOptions = None,
*,
value_labels: dict[Hashable, dict[float | int, str]] | None = None,
):
if version is None:
version = 118 if data.shape[1] <= 32767 else 119
Expand All @@ -3525,6 +3651,7 @@ def __init__(
time_stamp=time_stamp,
data_label=data_label,
variable_labels=variable_labels,
value_labels=value_labels,
convert_strl=convert_strl,
compression=compression,
storage_options=storage_options,
Expand Down
Loading