Skip to content

Commit a45600d

Browse files
authored
ENH: Synchronize io/stata with pandas master (#202)
* ENH: Synchronize io/stata with pandas master. Synchronize and remove classes not part of the public API. * MAINT: Update frame stata io. * ENH: Add literals. Add literals for limited value inputs. * CLN: Remove non-public classes.
1 parent 1aa40eb commit a45600d

File tree

4 files changed

+104
-117
lines changed

4 files changed

+104
-117
lines changed

pandas-stubs/_typing.pyi

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,23 @@ GroupByObjectNonScalar = Union[
203203
]
204204
GroupByObject = Union[Scalar, GroupByObjectNonScalar]
205205

206+
StataDateFormat = Literal[
207+
"tc",
208+
"%tc",
209+
"td",
210+
"%td",
211+
"tw",
212+
"%tw",
213+
"tm",
214+
"%tm",
215+
"tq",
216+
"%tq",
217+
"th",
218+
"%th",
219+
"ty",
220+
"%ty",
221+
]
222+
206223
FillnaOptions = Literal["backfill", "bfill", "ffill", "pad"]
207224
ReplaceMethod = Literal["pad", "ffill", "bfill"]
208225
SortKind = Literal["quicksort", "mergesort", "heapsort", "stable"]

pandas-stubs/core/frame.pyi

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ from pandas._typing import (
5151
CompressionOptions,
5252
Dtype,
5353
DtypeNp,
54+
FilePath,
5455
FilePathOrBuffer,
5556
FilePathOrBytesBuffer,
5657
FillnaOptions,
@@ -74,9 +75,12 @@ from pandas._typing import (
7475
ScalarT,
7576
SeriesAxisType,
7677
SortKind,
78+
StataDateFormat,
79+
StorageOptions,
7780
StrLike,
7881
T as TType,
7982
TimestampConvention,
83+
WriteBuffer,
8084
np_ndarray_bool,
8185
np_ndarray_str,
8286
num,
@@ -245,15 +249,19 @@ class DataFrame(NDFrame, OpsMixin):
245249
) -> np.recarray: ...
246250
def to_stata(
247251
self,
248-
path: FilePathOrBuffer,
249-
convert_dates: dict | None = ...,
252+
path: FilePath | WriteBuffer[bytes],
253+
convert_dates: dict[HashableT, StataDateFormat] | None = ...,
250254
write_index: _bool = ...,
251255
byteorder: Literal["<", ">", "little", "big"] | None = ...,
252-
time_stamp=...,
256+
time_stamp: _dt.datetime | None = ...,
253257
data_label: _str | None = ...,
254-
variable_labels: dict | None = ...,
255-
version: int = ...,
256-
convert_strl: list[_str] | None = ...,
258+
variable_labels: dict[HashableT, str] | None = ...,
259+
version: Literal[114, 117, 118, 119] | None = ...,
260+
convert_strl: list[HashableT] | None = ...,
261+
compression: CompressionOptions = ...,
262+
storage_options: StorageOptions = ...,
263+
*,
264+
value_labels: dict[Hashable, dict[float, str]] | None = ...,
257265
) -> None: ...
258266
def to_feather(self, path: FilePathOrBuffer, **kwargs) -> None: ...
259267
@overload

pandas-stubs/io/stata.pyi

Lines changed: 39 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ import datetime
33
from io import BytesIO
44
from types import TracebackType
55
from typing import (
6-
Hashable,
76
Literal,
87
Sequence,
8+
overload,
99
)
1010

1111
import numpy as np
@@ -18,10 +18,12 @@ from pandas._typing import (
1818
FilePath,
1919
HashableT,
2020
ReadBuffer,
21+
StataDateFormat,
2122
StorageOptions,
2223
WriteBuffer,
2324
)
2425

26+
@overload
2527
def read_stata(
2628
path: FilePath | ReadBuffer[bytes],
2729
convert_dates: bool = ...,
@@ -32,70 +34,47 @@ def read_stata(
3234
columns: list[HashableT] | None = ...,
3335
order_categoricals: bool = ...,
3436
chunksize: int | None = ...,
35-
iterator: bool = ...,
37+
*,
38+
iterator: Literal[True],
3639
compression: CompressionOptions = ...,
3740
storage_options: StorageOptions = ...,
38-
) -> DataFrame | StataReader: ...
39-
40-
stata_epoch: datetime.datetime = ...
41-
excessive_string_length_error: str
41+
) -> StataReader: ...
42+
@overload
43+
def read_stata(
44+
path: FilePath | ReadBuffer[bytes],
45+
convert_dates: bool,
46+
convert_categoricals: bool,
47+
index_col: str | None,
48+
convert_missing: bool,
49+
preserve_dtypes: bool,
50+
columns: list[HashableT] | None,
51+
order_categoricals: bool,
52+
chunksize: int | None,
53+
iterator: Literal[True],
54+
compression: CompressionOptions = ...,
55+
storage_options: StorageOptions = ...,
56+
) -> StataReader: ...
57+
@overload
58+
def read_stata(
59+
path: FilePath | ReadBuffer[bytes],
60+
convert_dates: bool = ...,
61+
convert_categoricals: bool = ...,
62+
index_col: str | None = ...,
63+
convert_missing: bool = ...,
64+
preserve_dtypes: bool = ...,
65+
columns: list[HashableT] | None = ...,
66+
order_categoricals: bool = ...,
67+
chunksize: int | None = ...,
68+
iterator: Literal[False] = ...,
69+
compression: CompressionOptions = ...,
70+
storage_options: StorageOptions = ...,
71+
) -> DataFrame: ...
4272

4373
class PossiblePrecisionLoss(Warning): ...
44-
45-
precision_loss_doc: str
46-
4774
class ValueLabelTypeMismatch(Warning): ...
48-
49-
value_label_mismatch_doc: str
50-
5175
class InvalidColumnName(Warning): ...
5276

53-
invalid_name_doc: str
54-
55-
class StataValueLabel:
56-
labname: Hashable = ...
57-
value_labels: list[tuple[float, str]] = ...
58-
text_len: int = ...
59-
off: npt.NDArray[np.int32] = ...
60-
val: npt.NDArray[np.int32] = ...
61-
txt: list[bytes] = ...
62-
n: int = ...
63-
len: int = ...
64-
def __init__(
65-
self, catarray: pd.Series, encoding: Literal["latin-1", "utf-8"] = ...
66-
) -> None: ...
67-
def generate_value_label(self, byteorder: str) -> bytes: ...
68-
69-
class StataMissingValue:
70-
MISSING_VALUES: dict[float, str] = ...
71-
bases: tuple[int, int, int] = ...
72-
float32_base: bytes = ...
73-
increment: int = ...
74-
int_value: int = ...
75-
float64_base: bytes = ...
76-
BASE_MISSING_VALUES: dict[str, int] = ...
77-
def __init__(self, value: float) -> None: ...
78-
def __eq__(self, other: object) -> bool: ...
79-
@property
80-
def string(self) -> str: ...
81-
@property
82-
def value(self) -> float: ...
83-
@classmethod
84-
def get_base_missing_value(cls, dtype): ...
85-
8677
class StataParser:
87-
DTYPE_MAP: dict[int, np.dtype] = ...
88-
DTYPE_MAP_XML: dict[int, np.dtype] = ...
89-
TYPE_MAP: list[tuple[int | str, ...]] = ...
90-
TYPE_MAP_XML: dict[int, str] = ...
91-
VALID_RANGE: dict[
92-
str,
93-
tuple[int, int] | tuple[np.float32, np.float32] | tuple[np.float64, np.float64],
94-
] = ...
95-
OLD_TYPE_MAPPING: dict[int, int] = ...
96-
MISSING_VALUES: dict[str, int] = ...
97-
NUMPY_TYPE_MAP: dict[str, str] = ...
98-
RESERVED_WORDS: tuple[str, ...] = ...
9978
def __init__(self) -> None: ...
10079

10180
class StataReader(StataParser, abc.Iterator):
@@ -142,70 +121,19 @@ class StataReader(StataParser, abc.Iterator):
142121
def value_labels(self) -> dict[str, dict[float, str]]: ...
143122

144123
class StataWriter(StataParser):
145-
type_converters: dict[str, type[np.dtype]] = ...
146124
def __init__(
147125
self,
148126
fname: FilePath | WriteBuffer[bytes],
149127
data: DataFrame,
150-
convert_dates: dict[Hashable, str] | None = ...,
128+
convert_dates: dict[HashableT, StataDateFormat] | None = ...,
151129
write_index: bool = ...,
152130
byteorder: str | None = ...,
153131
time_stamp: datetime.datetime | None = ...,
154132
data_label: str | None = ...,
155-
variable_labels: dict[Hashable, str] | None = ...,
133+
variable_labels: dict[HashableT, str] | None = ...,
156134
compression: CompressionOptions = ...,
157135
storage_options: StorageOptions = ...,
158136
*,
159-
value_labels: dict[Hashable, dict[float, str]] | None = ...,
137+
value_labels: dict[HashableT, dict[float, str]] | None = ...,
160138
) -> None: ...
161139
def write_file(self) -> None: ...
162-
163-
class StataStrLWriter:
164-
df: DataFrame = ...
165-
columns: Sequence[str] = ...
166-
def __init__(
167-
self,
168-
df: DataFrame,
169-
columns: Sequence[str],
170-
version: int = ...,
171-
byteorder: str | None = ...,
172-
) -> None: ...
173-
def generate_table(self) -> tuple[dict[str, tuple[int, int]], DataFrame]: ...
174-
def generate_blob(self, gso_table: dict[str, tuple[int, int]]) -> bytes: ...
175-
176-
class StataWriter117(StataWriter):
177-
def __init__(
178-
self,
179-
fname: FilePath | WriteBuffer[bytes],
180-
data: DataFrame,
181-
convert_dates: dict[Hashable, str] | None = ...,
182-
write_index: bool = ...,
183-
byteorder: str | None = ...,
184-
time_stamp: datetime.datetime | None = ...,
185-
data_label: str | None = ...,
186-
variable_labels: dict[Hashable, str] | None = ...,
187-
convert_strl: Sequence[Hashable] | None = ...,
188-
compression: CompressionOptions = ...,
189-
storage_options: StorageOptions = ...,
190-
*,
191-
value_labels: dict[Hashable, dict[float, str]] | None = ...,
192-
) -> None: ...
193-
194-
class StataWriterUTF8(StataWriter117):
195-
def __init__(
196-
self,
197-
fname: FilePath | WriteBuffer[bytes],
198-
data: DataFrame,
199-
convert_dates: dict[Hashable, str] | None = ...,
200-
write_index: bool = ...,
201-
byteorder: str | None = ...,
202-
time_stamp: datetime.datetime | None = ...,
203-
data_label: str | None = ...,
204-
variable_labels: dict[Hashable, str] | None = ...,
205-
convert_strl: Sequence[Hashable] | None = ...,
206-
version: int | None = ...,
207-
compression: CompressionOptions = ...,
208-
storage_options: StorageOptions = ...,
209-
*,
210-
value_labels: dict[Hashable, dict[float, str]] | None = ...,
211-
) -> None: ...

tests/test_io.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,52 @@
1+
import pandas as pd
12
from pandas import (
23
DataFrame,
34
read_clipboard,
5+
read_stata,
46
)
7+
from pandas._testing import ensure_clean
58
import pytest
69
from typing_extensions import assert_type
710

811
from tests import check
912

1013
from pandas.io.clipboard import PyperclipException
1114
from pandas.io.parsers import TextFileReader
15+
from pandas.io.stata import StataReader
1216

1317
DF = DataFrame({"a": [1, 2, 3], "b": [0.0, 0.0, 0.0]})
1418

1519

20+
def test_read_stata_df():
21+
with ensure_clean() as path:
22+
DF.to_stata(path)
23+
check(assert_type(read_stata(path), pd.DataFrame), pd.DataFrame)
24+
25+
26+
def test_read_stata_iterator_positional():
27+
with ensure_clean() as path:
28+
str_path = str(path)
29+
DF.to_stata(str_path)
30+
check(
31+
assert_type(
32+
read_stata(
33+
str_path, False, False, None, False, False, None, False, 2, True
34+
),
35+
StataReader,
36+
),
37+
StataReader,
38+
)
39+
40+
41+
def test_read_stata_iterator():
42+
with ensure_clean() as path:
43+
str_path = str(path)
44+
DF.to_stata(str_path)
45+
check(
46+
assert_type(read_stata(str_path, iterator=True), StataReader), StataReader
47+
)
48+
49+
1650
def test_clipboard():
1751
try:
1852
DF.to_clipboard()

0 commit comments

Comments
 (0)