Skip to content

Commit 0e2c938

Browse files
committed
BUG: Error in to_stata when DataFrame contains non-string column names
to_stata does not work correctly when used with non-string names. Since Stata requires string names, the proposed fix attempts to rename columns using the string representation of the column name used. The main method that reformats column names was refactored to handle this case. Patch includes additional fixes for detecting invalid names. Patch includes some minor documentation fixes.
1 parent e19b2eb commit 0e2c938

File tree

3 files changed

+115
-54
lines changed

3 files changed

+115
-54
lines changed

doc/source/release.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ Bug Fixes
233233
- Bug in popping from a Series (:issue:`6600`)
234234
- Bug in ``iloc`` indexing when positional indexer matched Int64Index of corresponding axis no reordering happened (:issue:`6612`)
235235
- Bug in ``fillna`` with ``limit`` and ``value`` specified
236+
- Bug in ``DataFrame.to_stata`` when columns have non-string names (:issue:`4558`)
236237

237238
pandas 0.13.1
238239
-------------

pandas/io/stata.py

Lines changed: 93 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pandas.core.categorical import Categorical
2121
import datetime
2222
from pandas import compat
23-
from pandas.compat import long, lrange, lmap, lzip
23+
from pandas.compat import long, lrange, lmap, lzip, text_type, string_types
2424
from pandas import isnull
2525
from pandas.io.common import get_filepath_or_buffer
2626
from pandas.tslib import NaT
@@ -191,6 +191,21 @@ class PossiblePrecisionLoss(Warning):
191191
"""
192192

193193

194+
class InvalidColumnName(Warning):
195+
pass
196+
197+
198+
invalid_name_doc = """
199+
Not all pandas column names were valid Stata variable names.
200+
The following replacements have been made:
201+
202+
{0}
203+
204+
If this is not what you expect, please make sure you have Stata-compliant
205+
column names in your DataFrame (strings only, max 32 characters, only alphanumerics and
206+
underscores, no Stata reserved words)
207+
"""
208+
194209
def _cast_to_stata_types(data):
195210
"""Checks the dtypes of the columns of a pandas DataFrame for
196211
compatibility with the data types and ranges supported by Stata, and
@@ -942,7 +957,7 @@ def _maybe_convert_to_int_keys(convert_dates, varlist):
942957
else:
943958
if not isinstance(key, int):
944959
raise ValueError(
945-
"convery_dates key is not in varlist and is not an int"
960+
"convert_dates key is not in varlist and is not an int"
946961
)
947962
new_dict.update({key: convert_dates[key]})
948963
return new_dict
@@ -1092,6 +1107,78 @@ def _write(self, to_write):
10921107
else:
10931108
self._file.write(to_write)
10941109

1110+
1111+
def _check_column_names(self, data):
1112+
"""Checks column names to ensure that they are valid Stata column names.
1113+
This includes checks for:
1114+
* Non-string names
1115+
* Stata keywords
1116+
* Variables that start with numbers
1117+
* Variables with names that are too long
1118+
1119+
When an illegal variable name is detected, it is converted, and if dates
1120+
are exported, the variable name is propogated to the date conversion
1121+
dictionary
1122+
"""
1123+
converted_names = []
1124+
columns = list(data.columns)
1125+
original_columns = columns[:]
1126+
1127+
duplicate_var_id = 0
1128+
for j, name in enumerate(columns):
1129+
orig_name = name
1130+
if not isinstance(name, string_types):
1131+
name = text_type(name)
1132+
1133+
for c in name:
1134+
if (c < 'A' or c > 'Z') and (c < 'a' or c > 'z') and \
1135+
(c < '0' or c > '9') and c != '_':
1136+
name = name.replace(c, '_')
1137+
1138+
# Variable name must not be a reserved word
1139+
if name in self.RESERVED_WORDS:
1140+
name = '_' + name
1141+
1142+
# Variable name may not start with a number
1143+
if name[0] >= '0' and name[0] <= '9':
1144+
name = '_' + name
1145+
1146+
name = name[:min(len(name), 32)]
1147+
1148+
if not name == orig_name:
1149+
# check for duplicates
1150+
while columns.count(name) > 0:
1151+
# prepend ascending number to avoid duplicates
1152+
name = '_' + str(duplicate_var_id) + name
1153+
name = name[:min(len(name), 32)]
1154+
duplicate_var_id += 1
1155+
1156+
# need to possibly encode the orig name if its unicode
1157+
try:
1158+
orig_name = orig_name.encode('utf-8')
1159+
except:
1160+
pass
1161+
converted_names.append('{0} -> {1}'.format(orig_name, name))
1162+
1163+
columns[j] = name
1164+
1165+
data.columns = columns
1166+
1167+
# Check date conversion, and fix key if needed
1168+
if self._convert_dates:
1169+
for c, o in zip(columns, original_columns):
1170+
if c != o:
1171+
self._convert_dates[c] = self._convert_dates[o]
1172+
del self._convert_dates[o]
1173+
1174+
if converted_names:
1175+
import warnings
1176+
1177+
ws = invalid_name_doc.format('\n '.join(converted_names))
1178+
warnings.warn(ws, InvalidColumnName)
1179+
1180+
return data
1181+
10951182
def _prepare_pandas(self, data):
10961183
#NOTE: we might need a different API / class for pandas objects so
10971184
# we can set different semantics - handle this with a PR to pandas.io
@@ -1108,6 +1195,8 @@ def __iter__(self):
11081195
data = data.reset_index()
11091196
# Check columns for compatibility with stata
11101197
data = _cast_to_stata_types(data)
1198+
# Ensure column names are strings
1199+
data = self._check_column_names(data)
11111200
self.datarows = DataFrameRowIter(data)
11121201
self.nobs, self.nvar = data.shape
11131202
self.data = data
@@ -1181,58 +1270,13 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
11811270
for typ in self.typlist:
11821271
self._write(typ)
11831272

1184-
# varlist, length 33*nvar, char array, null terminated
1185-
converted_names = []
1186-
duplicate_var_id = 0
1187-
for j, name in enumerate(self.varlist):
1188-
orig_name = name
1189-
# Replaces all characters disallowed in .dta format by their integral representation.
1190-
for c in name:
1191-
if (c < 'A' or c > 'Z') and (c < 'a' or c > 'z') and (c < '0' or c > '9') and c != '_':
1192-
name = name.replace(c, '_')
1193-
# Variable name must not be a reserved word
1194-
if name in self.RESERVED_WORDS:
1195-
name = '_' + name
1196-
# Variable name may not start with a number
1197-
if name[0] > '0' and name[0] < '9':
1198-
name = '_' + name
1199-
1200-
name = name[:min(len(name), 32)]
1201-
1202-
if not name == orig_name:
1203-
# check for duplicates
1204-
while self.varlist.count(name) > 0:
1205-
# prepend ascending number to avoid duplicates
1206-
name = '_' + str(duplicate_var_id) + name
1207-
name = name[:min(len(name), 32)]
1208-
duplicate_var_id += 1
1209-
1210-
# need to possibly encode the orig name if its unicode
1211-
try:
1212-
orig_name = orig_name.encode('utf-8')
1213-
except:
1214-
pass
1215-
1216-
converted_names.append('{0} -> {1}'.format(orig_name, name))
1217-
self.varlist[j] = name
1218-
1273+
# varlist names are checked by _check_column_names
1274+
# varlist, requires null terminated
12191275
for name in self.varlist:
12201276
name = self._null_terminate(name, True)
12211277
name = _pad_bytes(name[:32], 33)
12221278
self._write(name)
12231279

1224-
if converted_names:
1225-
from warnings import warn
1226-
warn("""Not all pandas column names were valid Stata variable names.
1227-
Made the following replacements:
1228-
1229-
{0}
1230-
1231-
If this is not what you expect, please make sure you have Stata-compliant
1232-
column names in your DataFrame (max 32 characters, only alphanumerics and
1233-
underscores)/
1234-
""".format('\n '.join(converted_names)))
1235-
12361280
# srtlist, 2*(nvar+1), int array, encoded by byteorder
12371281
srtlist = _pad_bytes("", (2*(nvar+1)))
12381282
self._write(srtlist)

pandas/io/tests/test_stata.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import pandas as pd
1414
from pandas.core.frame import DataFrame, Series
1515
from pandas.io.parsers import read_csv
16-
from pandas.io.stata import read_stata, StataReader
16+
from pandas.io.stata import read_stata, StataReader, InvalidColumnName
1717
import pandas.util.testing as tm
1818
from pandas.util.misc import is_little_endian
1919
from pandas import compat
@@ -332,10 +332,10 @@ def test_read_write_dta12(self):
332332
tm.assert_frame_equal(written_and_read_again.set_index('index'), formatted)
333333

334334
def test_read_write_dta13(self):
335-
s1 = Series(2**9,dtype=np.int16)
336-
s2 = Series(2**17,dtype=np.int32)
337-
s3 = Series(2**33,dtype=np.int64)
338-
original = DataFrame({'int16':s1,'int32':s2,'int64':s3})
335+
s1 = Series(2**9, dtype=np.int16)
336+
s2 = Series(2**17, dtype=np.int32)
337+
s3 = Series(2**33, dtype=np.int64)
338+
original = DataFrame({'int16': s1, 'int32': s2, 'int64': s3})
339339
original.index.name = 'index'
340340

341341
formatted = original
@@ -398,6 +398,22 @@ def test_timestamp_and_label(self):
398398
assert parsed_time_stamp == time_stamp
399399
assert reader.data_label == data_label
400400

401+
def test_numeric_column_names(self):
402+
original = DataFrame(np.reshape(np.arange(25.0), (5, 5)))
403+
original.index.name = 'index'
404+
with tm.ensure_clean() as path:
405+
# should get a warning for that format.
406+
with warnings.catch_warnings(record=True) as w:
407+
tm.assert_produces_warning(original.to_stata(path), InvalidColumnName)
408+
# should produce a single warning
409+
np.testing.assert_equal(len(w), 1)
410+
411+
written_and_read_again = self.read_dta(path)
412+
written_and_read_again = written_and_read_again.set_index('index')
413+
columns = list(written_and_read_again.columns)
414+
convert_col_name = lambda x: int(x[1])
415+
written_and_read_again.columns = map(convert_col_name, columns)
416+
tm.assert_frame_equal(original, written_and_read_again)
401417

402418

403419
if __name__ == '__main__':

0 commit comments

Comments
 (0)