Skip to content

Commit 61bd98a

Browse files
Terji Petersen authored and Terji Petersen committed
fix test_stata.py
1 parent 3ee39c7 commit 61bd98a

File tree

1 file changed: +112 -50 lines changed

pandas/tests/io/test_stata.py

Lines changed: 112 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -287,7 +287,10 @@ def test_read_write_dta5(self):
287287
with tm.ensure_clean() as path:
288288
original.to_stata(path, convert_dates=None)
289289
written_and_read_again = self.read_dta(path)
290-
tm.assert_frame_equal(written_and_read_again.set_index("index"), original)
290+
291+
expected = original.copy()
292+
expected.index = expected.index.astype(np.int32)
293+
tm.assert_frame_equal(written_and_read_again.set_index("index"), expected)
291294

292295
def test_write_dta6(self, datapath):
293296
original = self.read_csv(datapath("io", "data", "stata", "stata3.csv"))
@@ -380,7 +383,10 @@ def test_read_write_dta11(self):
380383
original.to_stata(path, convert_dates=None)
381384

382385
written_and_read_again = self.read_dta(path)
383-
tm.assert_frame_equal(written_and_read_again.set_index("index"), formatted)
386+
387+
expected = formatted.copy()
388+
expected.index = expected.index.astype(np.int32)
389+
tm.assert_frame_equal(written_and_read_again.set_index("index"), expected)
384390

385391
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
386392
def test_read_write_dta12(self, version):
@@ -417,7 +423,10 @@ def test_read_write_dta12(self, version):
417423
assert len(w) == 1
418424

419425
written_and_read_again = self.read_dta(path)
420-
tm.assert_frame_equal(written_and_read_again.set_index("index"), formatted)
426+
427+
expected = formatted.copy()
428+
expected.index = expected.index.astype(np.int32)
429+
tm.assert_frame_equal(written_and_read_again.set_index("index"), expected)
421430

422431
def test_read_write_dta13(self):
423432
s1 = Series(2**9, dtype=np.int16)
@@ -432,7 +441,10 @@ def test_read_write_dta13(self):
432441
with tm.ensure_clean() as path:
433442
original.to_stata(path)
434443
written_and_read_again = self.read_dta(path)
435-
tm.assert_frame_equal(written_and_read_again.set_index("index"), formatted)
444+
445+
expected = formatted.copy()
446+
expected.index = expected.index.astype(np.int32)
447+
tm.assert_frame_equal(written_and_read_again.set_index("index"), expected)
436448

437449
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
438450
@pytest.mark.parametrize(
@@ -455,7 +467,10 @@ def test_read_write_reread_dta14(self, file, parsed_114, version, datapath):
455467
with tm.ensure_clean() as path:
456468
parsed_114.to_stata(path, convert_dates={"date_td": "td"}, version=version)
457469
written_and_read_again = self.read_dta(path)
458-
tm.assert_frame_equal(written_and_read_again.set_index("index"), parsed_114)
470+
471+
expected = parsed_114.copy()
472+
expected.index = expected.index.astype(np.int32)
473+
tm.assert_frame_equal(written_and_read_again.set_index("index"), expected)
459474

460475
@pytest.mark.parametrize(
461476
"file", ["stata6_113", "stata6_114", "stata6_115", "stata6_117"]
@@ -510,11 +525,15 @@ def test_numeric_column_names(self):
510525
original.to_stata(path)
511526

512527
written_and_read_again = self.read_dta(path)
513-
written_and_read_again = written_and_read_again.set_index("index")
514-
columns = list(written_and_read_again.columns)
515-
convert_col_name = lambda x: int(x[1])
516-
written_and_read_again.columns = map(convert_col_name, columns)
517-
tm.assert_frame_equal(original, written_and_read_again)
528+
529+
written_and_read_again = written_and_read_again.set_index("index")
530+
columns = list(written_and_read_again.columns)
531+
convert_col_name = lambda x: int(x[1])
532+
written_and_read_again.columns = map(convert_col_name, columns)
533+
534+
expected = original.copy()
535+
expected.index = expected.index.astype(np.int32)
536+
tm.assert_frame_equal(expected, written_and_read_again)
518537

519538
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
520539
def test_nan_to_missing_value(self, version):
@@ -524,11 +543,15 @@ def test_nan_to_missing_value(self, version):
524543
s2[1::2] = np.nan
525544
original = DataFrame({"s1": s1, "s2": s2})
526545
original.index.name = "index"
546+
527547
with tm.ensure_clean() as path:
528548
original.to_stata(path, version=version)
529549
written_and_read_again = self.read_dta(path)
530-
written_and_read_again = written_and_read_again.set_index("index")
531-
tm.assert_frame_equal(written_and_read_again, original)
550+
551+
written_and_read_again = written_and_read_again.set_index("index")
552+
expected = original.copy()
553+
expected.index = expected.index.astype(np.int32)
554+
tm.assert_frame_equal(written_and_read_again, expected)
532555

533556
def test_no_index(self):
534557
columns = ["x", "y"]
@@ -548,7 +571,10 @@ def test_string_no_dates(self):
548571
with tm.ensure_clean() as path:
549572
original.to_stata(path)
550573
written_and_read_again = self.read_dta(path)
551-
tm.assert_frame_equal(written_and_read_again.set_index("index"), original)
574+
575+
expected = original.copy()
576+
expected.index = expected.index.astype(np.int32)
577+
tm.assert_frame_equal(written_and_read_again.set_index("index"), expected)
552578

553579
def test_large_value_conversion(self):
554580
s0 = Series([1, 99], dtype=np.int8)
@@ -562,11 +588,13 @@ def test_large_value_conversion(self):
562588
original.to_stata(path)
563589

564590
written_and_read_again = self.read_dta(path)
565-
modified = original.copy()
566-
modified["s1"] = Series(modified["s1"], dtype=np.int16)
567-
modified["s2"] = Series(modified["s2"], dtype=np.int32)
568-
modified["s3"] = Series(modified["s3"], dtype=np.float64)
569-
tm.assert_frame_equal(written_and_read_again.set_index("index"), modified)
591+
592+
modified = original.copy()
593+
modified["s1"] = Series(modified["s1"], dtype=np.int16)
594+
modified["s2"] = Series(modified["s2"], dtype=np.int32)
595+
modified["s3"] = Series(modified["s3"], dtype=np.float64)
596+
modified.index = original.index.astype(np.int32)
597+
tm.assert_frame_equal(written_and_read_again.set_index("index"), modified)
570598

571599
def test_dates_invalid_column(self):
572600
original = DataFrame([datetime(2006, 11, 19, 23, 13, 20)])
@@ -576,9 +604,11 @@ def test_dates_invalid_column(self):
576604
original.to_stata(path, convert_dates={0: "tc"})
577605

578606
written_and_read_again = self.read_dta(path)
579-
modified = original.copy()
580-
modified.columns = ["_0"]
581-
tm.assert_frame_equal(written_and_read_again.set_index("index"), modified)
607+
608+
modified = original.copy()
609+
modified.columns = ["_0"]
610+
modified.index = original.index.astype(np.int32)
611+
tm.assert_frame_equal(written_and_read_again.set_index("index"), modified)
582612

583613
def test_105(self, datapath):
584614
# Data obtained from:
@@ -619,21 +649,32 @@ def test_date_export_formats(self):
619649
datetime(2006, 1, 1),
620650
] # Year
621651

622-
expected = DataFrame([expected_values], columns=columns)
623-
expected.index.name = "index"
652+
expected = DataFrame(
653+
[expected_values],
654+
index=pd.Index([0], dtype=np.int32, name="index"),
655+
columns=columns,
656+
)
657+
624658
with tm.ensure_clean() as path:
625659
original.to_stata(path, convert_dates=conversions)
626660
written_and_read_again = self.read_dta(path)
627-
tm.assert_frame_equal(written_and_read_again.set_index("index"), expected)
661+
662+
tm.assert_frame_equal(written_and_read_again.set_index("index"), expected)
628663

629664
def test_write_missing_strings(self):
630665
original = DataFrame([["1"], [None]], columns=["foo"])
631-
expected = DataFrame([["1"], [""]], columns=["foo"])
632-
expected.index.name = "index"
666+
667+
expected = DataFrame(
668+
[["1"], [""]],
669+
index=pd.Index([0, 1], dtype=np.int32, name="index"),
670+
columns=["foo"],
671+
)
672+
633673
with tm.ensure_clean() as path:
634674
original.to_stata(path)
635675
written_and_read_again = self.read_dta(path)
636-
tm.assert_frame_equal(written_and_read_again.set_index("index"), expected)
676+
677+
tm.assert_frame_equal(written_and_read_again.set_index("index"), expected)
637678

638679
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
639680
@pytest.mark.parametrize("byteorder", [">", "<"])
@@ -651,6 +692,7 @@ def test_bool_uint(self, byteorder, version):
651692
)
652693
original.index.name = "index"
653694
expected = original.copy()
695+
expected.index = original.index.astype(np.int32)
654696
expected_types = (
655697
np.int8,
656698
np.int8,
@@ -666,8 +708,9 @@ def test_bool_uint(self, byteorder, version):
666708
with tm.ensure_clean() as path:
667709
original.to_stata(path, byteorder=byteorder, version=version)
668710
written_and_read_again = self.read_dta(path)
669-
written_and_read_again = written_and_read_again.set_index("index")
670-
tm.assert_frame_equal(written_and_read_again, expected)
711+
712+
written_and_read_again = written_and_read_again.set_index("index")
713+
tm.assert_frame_equal(written_and_read_again, expected)
671714

672715
def test_variable_labels(self, datapath):
673716
with StataReader(datapath("io", "data", "stata", "stata7_115.dta")) as rdr:
@@ -818,11 +861,12 @@ def test_big_dates(self, datapath):
818861
expected.index.name = "index"
819862
expected.to_stata(path, convert_dates=date_conversion)
820863
written_and_read_again = self.read_dta(path)
821-
tm.assert_frame_equal(
822-
written_and_read_again.set_index("index"),
823-
expected,
824-
check_datetimelike_compat=True,
825-
)
864+
865+
tm.assert_frame_equal(
866+
written_and_read_again.set_index("index"),
867+
expected.set_index(expected.index.astype(np.int32)),
868+
check_datetimelike_compat=True,
869+
)
826870

827871
def test_dtype_conversion(self, datapath):
828872
expected = self.read_csv(datapath("io", "data", "stata", "stata6.csv"))
@@ -936,7 +980,7 @@ def test_categorical_writing(self, version):
936980
original = pd.concat(
937981
[original[col].astype("category") for col in original], axis=1
938982
)
939-
expected.index.name = "index"
983+
expected.index = expected.index.set_names("index").astype(np.int32)
940984

941985
expected["incompletely_labeled"] = expected["incompletely_labeled"].apply(str)
942986
expected["unlabeled"] = expected["unlabeled"].apply(str)
@@ -955,8 +999,9 @@ def test_categorical_writing(self, version):
955999
with tm.ensure_clean() as path:
9561000
original.to_stata(path, version=version)
9571001
written_and_read_again = self.read_dta(path)
958-
res = written_and_read_again.set_index("index")
959-
tm.assert_frame_equal(res, expected)
1002+
1003+
res = written_and_read_again.set_index("index")
1004+
tm.assert_frame_equal(res, expected)
9601005

9611006
def test_categorical_warnings_and_errors(self):
9621007
# Warning for non-string labels
@@ -1000,15 +1045,17 @@ def test_categorical_with_stata_missing_values(self, version):
10001045
with tm.ensure_clean() as path:
10011046
original.to_stata(path, version=version)
10021047
written_and_read_again = self.read_dta(path)
1003-
res = written_and_read_again.set_index("index")
10041048

1005-
expected = original.copy()
1006-
for col in expected:
1007-
cat = expected[col]._values
1008-
new_cats = cat.remove_unused_categories().categories
1009-
cat = cat.set_categories(new_cats, ordered=True)
1010-
expected[col] = cat
1011-
tm.assert_frame_equal(res, expected)
1049+
res = written_and_read_again.set_index("index")
1050+
1051+
expected = original.copy()
1052+
for col in expected:
1053+
cat = expected[col]._values
1054+
new_cats = cat.remove_unused_categories().categories
1055+
cat = cat.set_categories(new_cats, ordered=True)
1056+
expected[col] = cat
1057+
expected.index = expected.index.astype(np.int32)
1058+
tm.assert_frame_equal(res, expected)
10121059

10131060
@pytest.mark.parametrize("file", ["stata10_115", "stata10_117"])
10141061
def test_categorical_order(self, file, datapath):
@@ -1456,8 +1503,11 @@ def test_out_of_range_float(self):
14561503
with tm.ensure_clean() as path:
14571504
original.to_stata(path)
14581505
reread = read_stata(path)
1459-
original["ColumnTooBig"] = original["ColumnTooBig"].astype(np.float64)
1460-
tm.assert_frame_equal(original, reread.set_index("index"))
1506+
1507+
original["ColumnTooBig"] = original["ColumnTooBig"].astype(np.float64)
1508+
expected = original.copy()
1509+
expected.index = expected.index.astype(np.int32)
1510+
tm.assert_frame_equal(reread.set_index("index"), expected)
14611511

14621512
@pytest.mark.parametrize("infval", [np.inf, -np.inf])
14631513
def test_inf(self, infval):
@@ -1885,7 +1935,10 @@ def test_compression(compression, version, use_dict, infer):
18851935
elif compression is None:
18861936
fp = path
18871937
reread = read_stata(fp, index_col="index")
1888-
tm.assert_frame_equal(reread, df)
1938+
1939+
expected = df.copy()
1940+
expected.index = expected.index.astype(np.int32)
1941+
tm.assert_frame_equal(reread, expected)
18891942

18901943

18911944
@pytest.mark.parametrize("method", ["zip", "infer"])
@@ -1906,20 +1959,29 @@ def test_compression_dict(method, file_ext):
19061959
else:
19071960
fp = path
19081961
reread = read_stata(fp, index_col="index")
1909-
tm.assert_frame_equal(reread, df)
1962+
1963+
expected = df.copy()
1964+
expected.index = expected.index.astype(np.int32)
1965+
tm.assert_frame_equal(reread, expected)
19101966

19111967

19121968
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
19131969
def test_chunked_categorical(version):
19141970
df = DataFrame({"cats": Series(["a", "b", "a", "b", "c"], dtype="category")})
19151971
df.index.name = "index"
1972+
1973+
expected = df.copy()
1974+
expected.index = expected.index.astype(np.int32)
1975+
19161976
with tm.ensure_clean() as path:
19171977
df.to_stata(path, version=version)
19181978
with StataReader(path, chunksize=2, order_categoricals=False) as reader:
19191979
for i, block in enumerate(reader):
19201980
block = block.set_index("index")
19211981
assert "cats" in block
1922-
tm.assert_series_equal(block.cats, df.cats.iloc[2 * i : 2 * (i + 1)])
1982+
tm.assert_series_equal(
1983+
block.cats, expected.cats.iloc[2 * i : 2 * (i + 1)]
1984+
)
19231985

19241986

19251987
def test_chunked_categorical_partial(datapath):

0 commit comments

Comments
 (0)