Skip to content

Commit c550aa8

Browse files
authored
ENH: Improve io/stata.py (#159)
* ENH: Improve typing in io/stata.py Improve types matching pandas main * ENH: Improve io/stata.py typing accuracy Match with pandas and improve type info
1 parent 9072b15 commit c550aa8

File tree

2 files changed

+134
-88
lines changed

2 files changed

+134
-88
lines changed

pandas-stubs/_typing.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ from typing import (
2727
)
2828

2929
import numpy as np
30-
from numpy import typing as npt
30+
import numpy.typing
3131
from pandas.core.arrays import ExtensionArray
3232
from pandas.core.frame import DataFrame
3333
from pandas.core.generic import NDFrame
@@ -42,6 +42,8 @@ from pandas._libs.tslibs import (
4242

4343
from pandas.core.dtypes.dtypes import ExtensionDtype
4444

45+
npt = numpy.typing
46+
4547
ArrayLike = Union[ExtensionArray, np.ndarray]
4648
AnyArrayLike = Union[Index, Series, np.ndarray]
4749
PythonScalar = Union[str, int, float, bool, complex]

pandas-stubs/io/stata.pyi

Lines changed: 131 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,43 @@
11
from collections import abc
22
import datetime
3+
from io import BytesIO
4+
from types import TracebackType
35
from typing import (
4-
Dict,
56
Hashable,
6-
List,
7-
Optional,
7+
Literal,
88
Sequence,
99
)
1010

11+
import numpy as np
12+
import numpy.typing as npt
13+
import pandas as pd
1114
from pandas.core.frame import DataFrame
1215

13-
from pandas._typing import FilePathOrBuffer
16+
from pandas._typing import (
17+
CompressionOptions,
18+
FilePath,
19+
HashableT,
20+
ReadBuffer,
21+
StorageOptions,
22+
WriteBuffer,
23+
)
1424

1525
def read_stata(
16-
path: FilePathOrBuffer,
26+
path: FilePath | ReadBuffer[bytes],
1727
convert_dates: bool = ...,
1828
convert_categoricals: bool = ...,
19-
index_col: Optional[str] = ...,
29+
index_col: str | None = ...,
2030
convert_missing: bool = ...,
2131
preserve_dtypes: bool = ...,
22-
columns: Optional[List[str]] = ...,
32+
columns: list[HashableT] | None = ...,
2333
order_categoricals: bool = ...,
24-
chunksize: Optional[int] = ...,
34+
chunksize: int | None = ...,
2535
iterator: bool = ...,
26-
) -> DataFrame: ...
36+
compression: CompressionOptions = ...,
37+
storage_options: StorageOptions = ...,
38+
) -> DataFrame | StataReader: ...
2739

28-
stata_epoch = ...
40+
stata_epoch: datetime.datetime = ...
2941
excessive_string_length_error: str
3042

3143
class PossiblePrecisionLoss(Warning): ...
@@ -41,127 +53,159 @@ class InvalidColumnName(Warning): ...
4153
invalid_name_doc: str
4254

4355
class StataValueLabel:
44-
labname = ...
45-
value_labels = ...
46-
text_len = ...
47-
off = ...
48-
val = ...
49-
txt = ...
56+
labname: Hashable = ...
57+
value_labels: list[tuple[int | float, str]] = ...
58+
text_len: int = ...
59+
off: npt.NDArray[np.int32] = ...
60+
val: npt.NDArray[np.int32] = ...
61+
txt: list[bytes] = ...
5062
n: int = ...
51-
len = ...
52-
def __init__(self, catarray, encoding: str = ...): ...
53-
def generate_value_label(self, byteorder): ...
63+
len: int = ...
64+
def __init__(
65+
self, catarray: pd.Series, encoding: Literal["latin-1", "utf-8"] = ...
66+
) -> None: ...
67+
def generate_value_label(self, byteorder: str) -> bytes: ...
5468

5569
class StataMissingValue:
56-
MISSING_VALUES = ...
57-
bases = ...
70+
MISSING_VALUES: dict[float, str] = ...
71+
bases: tuple[int, int, int] = ...
5872
float32_base: bytes = ...
59-
increment = ...
60-
value = ...
61-
int_value = ...
73+
increment: int = ...
74+
int_value: int = ...
6275
float64_base: bytes = ...
63-
BASE_MISSING_VALUES = ...
64-
def __init__(self, value) -> None: ...
65-
string = ...
66-
def __eq__(self, other) -> bool: ...
76+
BASE_MISSING_VALUES: dict[str, int] = ...
77+
def __init__(self, value: int | float) -> None: ...
78+
def __eq__(self, other: object) -> bool: ...
79+
@property
80+
def string(self) -> str: ...
81+
@property
82+
def value(self) -> int | float: ...
6783
@classmethod
6884
def get_base_missing_value(cls, dtype): ...
6985

7086
class StataParser:
71-
DTYPE_MAP = ...
72-
DTYPE_MAP_XML = ...
73-
TYPE_MAP = ...
74-
TYPE_MAP_XML = ...
75-
VALID_RANGE = ...
76-
OLD_TYPE_MAPPING = ...
77-
MISSING_VALUES = ...
78-
NUMPY_TYPE_MAP = ...
79-
RESERVED_WORDS = ...
87+
DTYPE_MAP: dict[int, np.dtype] = ...
88+
DTYPE_MAP_XML: dict[int, np.dtype] = ...
89+
TYPE_MAP: list[tuple[int | str, ...]] = ...
90+
TYPE_MAP_XML: dict[int, str] = ...
91+
VALID_RANGE: dict[
92+
str,
93+
tuple[int, int] | tuple[np.float32, np.float32] | tuple[np.float64, np.float64],
94+
] = ...
95+
OLD_TYPE_MAPPING: dict[int, int] = ...
96+
MISSING_VALUES: dict[str, int] = ...
97+
NUMPY_TYPE_MAP: dict[str, str] = ...
98+
RESERVED_WORDS: tuple[str, ...] = ...
8099
def __init__(self) -> None: ...
81100

82101
class StataReader(StataParser, abc.Iterator):
83-
col_sizes = ...
84-
path_or_buf = ...
102+
col_sizes: list[int] = ...
103+
path_or_buf: BytesIO = ...
85104
def __init__(
86105
self,
87-
path_or_buf,
106+
path_or_buf: FilePath | ReadBuffer[bytes],
88107
convert_dates: bool = ...,
89108
convert_categoricals: bool = ...,
90-
index_col=...,
109+
index_col: str | None = ...,
91110
convert_missing: bool = ...,
92111
preserve_dtypes: bool = ...,
93-
columns=...,
112+
columns: Sequence[str] | None = ...,
94113
order_categoricals: bool = ...,
95-
chunksize=...,
114+
chunksize: int | None = ...,
115+
compression: CompressionOptions = ...,
116+
storage_options: StorageOptions = ...,
117+
) -> None: ...
118+
def __enter__(self) -> StataReader: ...
119+
def __exit__(
120+
self,
121+
exc_type: type[BaseException] | None,
122+
exc_value: BaseException | None,
123+
traceback: TracebackType | None,
96124
) -> None: ...
97-
def __enter__(self): ...
98-
def __exit__(self, exc_type, exc_value, traceback) -> None: ...
99125
def close(self) -> None: ...
100-
def __next__(self): ...
101-
def get_chunk(self, size=...): ...
126+
def __next__(self) -> DataFrame: ...
127+
def get_chunk(self, size: int | None = ...) -> DataFrame: ...
102128
def read(
103129
self,
104-
nrows=...,
105-
convert_dates=...,
106-
convert_categoricals=...,
107-
index_col=...,
108-
convert_missing=...,
109-
preserve_dtypes=...,
110-
columns=...,
111-
order_categoricals=...,
130+
nrows: int | None = ...,
131+
convert_dates: bool | None = ...,
132+
convert_categoricals: bool | None = ...,
133+
index_col: str | None = ...,
134+
convert_missing: bool | None = ...,
135+
preserve_dtypes: bool | None = ...,
136+
columns: list[str] | None = ...,
137+
order_categoricals: bool | None = ...,
112138
): ...
113139
@property
114-
def data_label(self): ...
115-
def variable_labels(self): ...
116-
def value_labels(self): ...
140+
def data_label(self) -> str: ...
141+
def variable_labels(self) -> dict[str, str]: ...
142+
def value_labels(self) -> dict[str, dict[int | float, str]]: ...
117143

118144
class StataWriter(StataParser):
119-
type_converters = ...
145+
type_converters: dict[str, type[np.dtype]] = ...
120146
def __init__(
121147
self,
122-
fname,
123-
data,
124-
convert_dates=...,
148+
fname: FilePath | WriteBuffer[bytes],
149+
data: DataFrame,
150+
convert_dates: dict[Hashable, str] | None = ...,
125151
write_index: bool = ...,
126-
byteorder=...,
127-
time_stamp=...,
128-
data_label=...,
129-
variable_labels=...,
152+
byteorder: str | None = ...,
153+
time_stamp: datetime.datetime | None = ...,
154+
data_label: str | None = ...,
155+
variable_labels: dict[Hashable, str] | None = ...,
156+
compression: CompressionOptions = ...,
157+
storage_options: StorageOptions = ...,
158+
*,
159+
value_labels: dict[Hashable, dict[float | int, str]] | None = ...,
130160
) -> None: ...
131161
def write_file(self) -> None: ...
132162

133163
class StataStrLWriter:
134-
df = ...
135-
columns = ...
136-
def __init__(self, df, columns, version: int = ..., byteorder=...) -> None: ...
137-
def generate_table(self): ...
138-
def generate_blob(self, gso_table): ...
164+
df: DataFrame = ...
165+
columns: Sequence[str] = ...
166+
def __init__(
167+
self,
168+
df: DataFrame,
169+
columns: Sequence[str],
170+
version: int = ...,
171+
byteorder: str | None = ...,
172+
) -> None: ...
173+
def generate_table(self) -> tuple[dict[str, tuple[int, int]], DataFrame]: ...
174+
def generate_blob(self, gso_table: dict[str, tuple[int, int]]) -> bytes: ...
139175

140176
class StataWriter117(StataWriter):
141177
def __init__(
142178
self,
143-
fname,
144-
data,
145-
convert_dates=...,
179+
fname: FilePath | WriteBuffer[bytes],
180+
data: DataFrame,
181+
convert_dates: dict[Hashable, str] | None = ...,
146182
write_index: bool = ...,
147-
byteorder=...,
148-
time_stamp=...,
149-
data_label=...,
150-
variable_labels=...,
151-
convert_strl=...,
183+
byteorder: str | None = ...,
184+
time_stamp: datetime.datetime | None = ...,
185+
data_label: str | None = ...,
186+
variable_labels: dict[Hashable, str] | None = ...,
187+
convert_strl: Sequence[Hashable] | None = ...,
188+
compression: CompressionOptions = ...,
189+
storage_options: StorageOptions = ...,
190+
*,
191+
value_labels: dict[Hashable, dict[float | int, str]] | None = ...,
152192
) -> None: ...
153193

154194
class StataWriterUTF8(StataWriter117):
155195
def __init__(
156196
self,
157-
fname: FilePathOrBuffer,
197+
fname: FilePath | WriteBuffer[bytes],
158198
data: DataFrame,
159-
convert_dates: Optional[Dict[Hashable, str]] = ...,
199+
convert_dates: dict[Hashable, str] | None = ...,
160200
write_index: bool = ...,
161-
byteorder: Optional[str] = ...,
162-
time_stamp: Optional[datetime.datetime] = ...,
163-
data_label: Optional[str] = ...,
164-
variable_labels: Optional[Dict[Hashable, str]] = ...,
165-
convert_strl: Optional[Sequence[Hashable]] = ...,
166-
version: Optional[int] = ...,
201+
byteorder: str | None = ...,
202+
time_stamp: datetime.datetime | None = ...,
203+
data_label: str | None = ...,
204+
variable_labels: dict[Hashable, str] | None = ...,
205+
convert_strl: Sequence[Hashable] | None = ...,
206+
version: int | None = ...,
207+
compression: CompressionOptions = ...,
208+
storage_options: StorageOptions = ...,
209+
*,
210+
value_labels: dict[Hashable, dict[float | int, str]] | None = ...,
167211
) -> None: ...

0 commit comments

Comments
 (0)