ENH: option to export df to Stata dataset with value labels #41042
Changes from all commits: 4db9ccb, 5a3d6d9, bbb43f8, 533a3d5, ed73d69, ae4dca7, 8e57e46, 70dc88b, 2796d1f, 277896a, c31034d, 85374fd, 4ac27db, d2a5584, d30d1fe, 05d4d74, 0311142, 46d3783
pandas/io/stata.py
@@ -18,6 +18,7 @@
 import struct
 import sys
 from typing import (
+    TYPE_CHECKING,
     Any,
     AnyStr,
     Hashable,
@@ -46,6 +47,7 @@
     ensure_object,
     is_categorical_dtype,
     is_datetime64_dtype,
+    is_numeric_dtype,
 )

 from pandas import (
@@ -64,6 +66,9 @@
 from pandas.io.common import get_handle

+if TYPE_CHECKING:
+    from typing import Literal
+
 _version_error = (
     "Version of given Stata file is {version}. pandas supports importing "
     "versions 105, 108, 111 (Stata 7SE), 113 (Stata 8/9), "
@@ -658,24 +663,37 @@ def __init__(self, catarray: Series, encoding: str = "latin-1"):
         self.labname = catarray.name
         self._encoding = encoding
         categories = catarray.cat.categories
-        self.value_labels = list(zip(np.arange(len(categories)), categories))
+        self.value_labels: list[tuple[int | float, str]] = list(
+            zip(np.arange(len(categories)), categories)
+        )
         self.value_labels.sort(key=lambda x: x[0])

+        self._prepare_value_labels()
+
+    def _prepare_value_labels(self):
+        """Encode value labels."""
+
         self.text_len = 0
         self.txt: list[bytes] = []
         self.n = 0
+        # Offsets (length of categories), converted to int32
+        self.off = np.array([])
+        # Values, converted to int32
+        self.val = np.array([])
+        self.len = 0
+
         # Compute lengths and setup lists of offsets and labels
         offsets: list[int] = []
-        values: list[int] = []
+        values: list[int | float] = []
         for vl in self.value_labels:
-            category = vl[1]
+            category: str | bytes = vl[1]
             if not isinstance(category, str):
                 category = str(category)
                 warnings.warn(
-                    value_label_mismatch_doc.format(catarray.name),
+                    value_label_mismatch_doc.format(self.labname),
                     ValueLabelTypeMismatch,
                 )
-            category = category.encode(encoding)
+            category = category.encode(self._encoding)
             offsets.append(self.text_len)
             self.text_len += len(category) + 1  # +1 for the padding
             values.append(vl[0])
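For context, a minimal sketch (not part of the diff) of what this refactor preserves: StataValueLabel still derives its labels from a categorical Series, but the encoding work now lives in _prepare_value_labels() so a subclass can reuse it. The Series name and values below are made up, and StataValueLabel is an internal helper rather than public API.

import pandas as pd
from pandas.io.stata import StataValueLabel  # internal helper, not public API

s = pd.Series(pd.Categorical(["low", "high", "low"]), name="grade")
vl = StataValueLabel(s)
# Codes are paired with the sorted categories, e.g. [(0, 'high'), (1, 'low')]
print(vl.value_labels)
# generate_value_label() still produces the bytes for the .dta value-label table
blob = vl.generate_value_label("<")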
@@ -748,6 +766,38 @@ def generate_value_label(self, byteorder: str) -> bytes:
         return bio.getvalue()


+class StataNonCatValueLabel(StataValueLabel):
+    """
+    Prepare formatted version of value labels
+
+    Parameters
+    ----------
+    labname : str
+        Value label name
+    value_labels: Dictionary
+        Mapping of values to labels
+    encoding : {"latin-1", "utf-8"}
+        Encoding to use for value labels.
+    """
+
+    def __init__(
+        self,
+        labname: str,
+        value_labels: dict[float | int, str],
+        encoding: Literal["latin-1", "utf-8"] = "latin-1",
+    ):
+
+        if encoding not in ("latin-1", "utf-8"):
+            raise ValueError("Only latin-1 and utf-8 are supported.")
+
+        self.labname = labname
+        self._encoding = encoding
+        self.value_labels: list[tuple[int | float, str]] = sorted(
+            value_labels.items(), key=lambda x: x[0]
+        )
+        self._prepare_value_labels()
+
+
 class StataMissingValue:
     """
     An observation's missing value.
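A rough illustration (not part of the diff) of the new subclass: it takes an explicit value-to-label mapping instead of deriving one from a Categorical, then defers to the inherited _prepare_value_labels(). The label name and mapping here are invented.

from pandas.io.stata import StataNonCatValueLabel  # class added by this PR

svl = StataNonCatValueLabel("quality", {3: "good", 1: "poor", 2: "fair"})
# The mapping is sorted by value: [(1, 'poor'), (2, 'fair'), (3, 'good')]
print(svl.value_labels)
# As with the parent class, the encoded label table comes from
# generate_value_label(byteorder)
blob = svl.generate_value_label("<")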
@@ -2175,6 +2225,13 @@ class StataWriter(StataParser):

         .. versionadded:: 1.2.0

+    value_labels : dict of dicts
+        Dictionary containing columns as keys and dictionaries of column value
+        to labels as values. The combined length of all labels for a single
+        variable must be 32,000 characters or smaller.
+
+        .. versionadded:: 1.4.0
+
     Returns
     -------
     writer : StataWriter instance
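To make the new keyword concrete, a short usage sketch (file name, column, and labels are invented; DataFrame.to_stata is expected to forward the same keyword, but the class shown is the one changed in this file):

import pandas as pd
from pandas.io.stata import StataWriter

df = pd.DataFrame({"quality": [1, 2, 3, 1]})
writer = StataWriter(
    "labelled.dta",
    df,
    value_labels={"quality": {1: "poor", 2: "fair", 3: "good"}},
)
writer.write_file()  # "quality" stays numeric but carries value labels in Stata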
@@ -2225,15 +2282,22 @@ def __init__(
         variable_labels: dict[Hashable, str] | None = None,
         compression: CompressionOptions = "infer",
         storage_options: StorageOptions = None,
+        *,
+        value_labels: dict[Hashable, dict[float | int, str]] | None = None,
     ):
         super().__init__()
         self.data = data
         self._convert_dates = {} if convert_dates is None else convert_dates
         self._write_index = write_index
         self._time_stamp = time_stamp
         self._data_label = data_label
         self._variable_labels = variable_labels
+        self._non_cat_value_labels = value_labels
+        self._value_labels: list[StataValueLabel] = []
(Review comment on self._value_labels) why don't you rename this to _cat_value_labels to distinguish and make consistent
(Author reply) _value_labels is a combined list of categorical and non-categorical labels. There isn't a standalone _cat_value_labels because they're added straight to _value_labels after converting categoricals in the dataframe.
+        self._has_value_labels = np.array([], dtype=bool)
         self._compression = compression
         self._output_file: Buffer | None = None
+        self._converted_names: dict[Hashable, str] = {}
         # attach nobs, nvars, data, varlist, typlist
         self._prepare_pandas(data)
         self.storage_options = storage_options
@@ -2243,7 +2307,6 @@ def __init__(
         self._byteorder = _set_endianness(byteorder)
         self._fname = fname
         self.type_converters = {253: np.int32, 252: np.int16, 251: np.int8}
-        self._converted_names: dict[Hashable, str] = {}

     def _write(self, to_write: str) -> None:
         """
@@ -2259,17 +2322,50 @@ def _write_bytes(self, value: bytes) -> None:
         """
         self.handles.handle.write(value)  # type: ignore[arg-type]

+    def _prepare_non_cat_value_labels(
+        self, data: DataFrame
+    ) -> list[StataNonCatValueLabel]:
+        """
+        Check for value labels provided for non-categorical columns. Value
+        labels
+        """
+        non_cat_value_labels: list[StataNonCatValueLabel] = []
+        if self._non_cat_value_labels is None:
+            return non_cat_value_labels
+
+        for labname, labels in self._non_cat_value_labels.items():
+            if labname in self._converted_names:
+                colname = self._converted_names[labname]
+            elif labname in data.columns:
+                colname = str(labname)
+            else:
+                raise KeyError(
+                    f"Can't create value labels for {labname}, it wasn't "
+                    "found in the dataset."
+                )
+
+            if not is_numeric_dtype(data[colname].dtype):
+                # Labels should not be passed explicitly for categorical
+                # columns that will be converted to int
+                raise ValueError(
+                    f"Can't create value labels for {labname}, value labels "
+                    "can only be applied to numeric columns."
+                )
+            svl = StataNonCatValueLabel(colname, labels)
+            non_cat_value_labels.append(svl)
+        return non_cat_value_labels
+
     def _prepare_categoricals(self, data: DataFrame) -> DataFrame:
         """
         Check for categorical columns, retain categorical information for
         Stata file and convert categorical data to int
         """
         is_cat = [is_categorical_dtype(data[col].dtype) for col in data]
-        self._is_col_cat = is_cat
-        self._value_labels: list[StataValueLabel] = []
         if not any(is_cat):
             return data

+        self._has_value_labels |= np.array(is_cat)
+
         get_base_missing_value = StataMissingValue.get_base_missing_value
         data_formatted = []
         for col, col_is_cat in zip(data, is_cat):
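A short sketch (not part of the diff) of the validation behaviour added above, with made-up data; both failing calls raise at construction time because _prepare_pandas runs in __init__:

import pandas as pd
from pandas.io.stata import StataWriter

df = pd.DataFrame({"num": [1, 2], "txt": ["a", "b"]})

# KeyError: the labelled column does not exist in the frame
# StataWriter("out.dta", df, value_labels={"missing": {1: "one"}})

# ValueError: value labels can only be applied to numeric columns
# StataWriter("out.dta", df, value_labels={"txt": {1: "one"}})

# Accepted: "num" is numeric, so a StataNonCatValueLabel is created for it
StataWriter("out.dta", df, value_labels={"num": {1: "one", 2: "two"}})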
@@ -2449,6 +2545,17 @@ def _prepare_pandas(self, data: DataFrame) -> None:
         # Replace NaNs with Stata missing values
         data = self._replace_nans(data)

+        # Set all columns to initially unlabelled
+        self._has_value_labels = np.repeat(False, data.shape[1])
+
+        # Create value labels for non-categorical data
+        non_cat_value_labels = self._prepare_non_cat_value_labels(data)
+
+        non_cat_columns = [svl.labname for svl in non_cat_value_labels]
+        has_non_cat_val_labels = data.columns.isin(non_cat_columns)
+        self._has_value_labels |= has_non_cat_val_labels
+        self._value_labels.extend(non_cat_value_labels)
+
         # Convert categoricals to int data, and strip labels
         data = self._prepare_categoricals(data)
@@ -2688,7 +2795,7 @@ def _write_value_label_names(self) -> None:
         # lbllist, 33*nvar, char array
         for i in range(self.nvar):
             # Use variable name when categorical
-            if self._is_col_cat[i]:
+            if self._has_value_labels[i]:
                 name = self.varlist[i]
                 name = self._null_terminate_str(name)
                 name = _pad_bytes(name[:32], 33)
@@ -3059,6 +3166,13 @@ class StataWriter117(StataWriter):

         .. versionadded:: 1.1.0

+    value_labels : dict of dicts
+        Dictionary containing columns as keys and dictionaries of column value
+        to labels as values. The combined length of all labels for a single
+        variable must be 32,000 characters or smaller.
+
+        .. versionadded:: 1.4.0
+
     Returns
     -------
     writer : StataWriter117 instance
@@ -3112,6 +3226,8 @@ def __init__(
         convert_strl: Sequence[Hashable] | None = None,
         compression: CompressionOptions = "infer",
         storage_options: StorageOptions = None,
+        *,
+        value_labels: dict[Hashable, dict[float | int, str]] | None = None,
     ):
         # Copy to new list since convert_strl might be modified later
         self._convert_strl: list[Hashable] = []
@@ -3127,6 +3243,7 @@
             time_stamp=time_stamp,
             data_label=data_label,
             variable_labels=variable_labels,
+            value_labels=value_labels,
             compression=compression,
             storage_options=storage_options,
         )
@@ -3272,7 +3389,7 @@ def _write_value_label_names(self) -> None:
         for i in range(self.nvar):
             # Use variable name when categorical
             name = ""  # default name
-            if self._is_col_cat[i]:
+            if self._has_value_labels[i]:
                 name = self.varlist[i]
             name = self._null_terminate_str(name)
             encoded_name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1)
@@ -3449,6 +3566,13 @@ class StataWriterUTF8(StataWriter117):

         .. versionadded:: 1.1.0

+    value_labels : dict of dicts
+        Dictionary containing columns as keys and dictionaries of column value
+        to labels as values. The combined length of all labels for a single
+        variable must be 32,000 characters or smaller.
+
+        .. versionadded:: 1.4.0
+
     Returns
     -------
     StataWriterUTF8
@@ -3505,6 +3629,8 @@ def __init__(
         version: int | None = None,
         compression: CompressionOptions = "infer",
         storage_options: StorageOptions = None,
+        *,
+        value_labels: dict[Hashable, dict[float | int, str]] | None = None,
     ):
         if version is None:
             version = 118 if data.shape[1] <= 32767 else 119
@@ -3525,6 +3651,7 @@
             time_stamp=time_stamp,
             data_label=data_label,
             variable_labels=variable_labels,
+            value_labels=value_labels,
             convert_strl=convert_strl,
             compression=compression,
             storage_options=storage_options,
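All three writers now accept the keyword-only value_labels argument and forward it to the base StataWriter. A final sketch (invented data) with the UTF-8 writer, which is the natural choice when labels need characters outside latin-1:

import pandas as pd
from pandas.io.stata import StataWriterUTF8

df = pd.DataFrame({"rating": [1, 2]})
# Greek labels are not representable in latin-1, so the UTF-8 writer is used
StataWriterUTF8(
    "utf8_labels.dta",
    df,
    value_labels={"rating": {1: "κακό", 2: "καλό"}},
).write_file()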