From 34efdbc5df5b26d29231d75fe1a38b8b72997ade Mon Sep 17 00:00:00 2001 From: Kevin Sheppard Date: Wed, 17 Aug 2022 16:59:32 +0100 Subject: [PATCH 1/4] ENH: Synchronize io/stata with pandas master Sychronize and remvoe classes not part of the public API --- pandas-stubs/io/stata.pyi | 92 +++++++++++++++----------------------- tests/test_io.py | 93 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 57 deletions(-) create mode 100644 tests/test_io.py diff --git a/pandas-stubs/io/stata.pyi b/pandas-stubs/io/stata.pyi index 42342928b..2b7de0a6a 100644 --- a/pandas-stubs/io/stata.pyi +++ b/pandas-stubs/io/stata.pyi @@ -6,6 +6,7 @@ from typing import ( Hashable, Literal, Sequence, + overload, ) import numpy as np @@ -22,6 +23,7 @@ from pandas._typing import ( WriteBuffer, ) +@overload def read_stata( path: FilePath | ReadBuffer[bytes], convert_dates: bool = ..., @@ -32,57 +34,46 @@ def read_stata( columns: list[HashableT] | None = ..., order_categoricals: bool = ..., chunksize: int | None = ..., - iterator: bool = ..., + *, + iterator: Literal[True], compression: CompressionOptions = ..., storage_options: StorageOptions = ..., -) -> DataFrame | StataReader: ... - -stata_epoch: datetime.datetime = ... -excessive_string_length_error: str +) -> StataReader: ... +@overload +def read_stata( + path: FilePath | ReadBuffer[bytes], + convert_dates: bool, + convert_categoricals: bool, + index_col: str | None, + convert_missing: bool, + preserve_dtypes: bool, + columns: list[HashableT] | None, + order_categoricals: bool, + chunksize: int | None, + iterator: Literal[True], + compression: CompressionOptions = ..., + storage_options: StorageOptions = ..., +) -> StataReader: ... +@overload +def read_stata( + path: FilePath | ReadBuffer[bytes], + convert_dates: bool = ..., + convert_categoricals: bool = ..., + index_col: str | None = ..., + convert_missing: bool = ..., + preserve_dtypes: bool = ..., + columns: list[HashableT] | None = ..., + order_categoricals: bool = ..., + chunksize: int | None = ..., + iterator: Literal[False] = ..., + compression: CompressionOptions = ..., + storage_options: StorageOptions = ..., +) -> DataFrame: ... class PossiblePrecisionLoss(Warning): ... - -precision_loss_doc: str - class ValueLabelTypeMismatch(Warning): ... - -value_label_mismatch_doc: str - class InvalidColumnName(Warning): ... -invalid_name_doc: str - -class StataValueLabel: - labname: Hashable = ... - value_labels: list[tuple[float, str]] = ... - text_len: int = ... - off: npt.NDArray[np.int32] = ... - val: npt.NDArray[np.int32] = ... - txt: list[bytes] = ... - n: int = ... - len: int = ... - def __init__( - self, catarray: pd.Series, encoding: Literal["latin-1", "utf-8"] = ... - ) -> None: ... - def generate_value_label(self, byteorder: str) -> bytes: ... - -class StataMissingValue: - MISSING_VALUES: dict[float, str] = ... - bases: tuple[int, int, int] = ... - float32_base: bytes = ... - increment: int = ... - int_value: int = ... - float64_base: bytes = ... - BASE_MISSING_VALUES: dict[str, int] = ... - def __init__(self, value: float) -> None: ... - def __eq__(self, other: object) -> bool: ... - @property - def string(self) -> str: ... - @property - def value(self) -> float: ... - @classmethod - def get_base_missing_value(cls, dtype): ... - class StataParser: DTYPE_MAP: dict[int, np.dtype] = ... DTYPE_MAP_XML: dict[int, np.dtype] = ... @@ -160,19 +151,6 @@ class StataWriter(StataParser): ) -> None: ... def write_file(self) -> None: ... -class StataStrLWriter: - df: DataFrame = ... - columns: Sequence[str] = ... - def __init__( - self, - df: DataFrame, - columns: Sequence[str], - version: int = ..., - byteorder: str | None = ..., - ) -> None: ... - def generate_table(self) -> tuple[dict[str, tuple[int, int]], DataFrame]: ... - def generate_blob(self, gso_table: dict[str, tuple[int, int]]) -> bytes: ... - class StataWriter117(StataWriter): def __init__( self, diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 000000000..1837c96dd --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from contextlib import contextmanager +from pathlib import Path +import tempfile +from typing import ( + IO, + Any, +) +import uuid + +import pandas as pd +from pandas import DataFrame +from typing_extensions import assert_type + +from tests import check + +from pandas.io.stata import ( + StataReader, + read_stata, +) + +DF = DataFrame({"a": [1, 2, 3], "b": [0.0, 0.0, 0.0]}) + + +@contextmanager +def ensure_clean(filename=None, return_filelike: bool = False, **kwargs: Any): + """ + Gets a temporary path and agrees to remove on close. + This implementation does not use tempfile.mkstemp to avoid having a file handle. + If the code using the returned path wants to delete the file itself, windows + requires that no program has a file handle to it. + Parameters + ---------- + filename : str (optional) + suffix of the created file. + return_filelike : bool (default False) + if True, returns a file-like which is *always* cleaned. Necessary for + savefig and other functions which want to append extensions. + **kwargs + Additional keywords are passed to open(). + """ + folder = Path(tempfile.gettempdir()) + + if filename is None: + filename = "" + filename = str(uuid.uuid4()) + filename + path = folder / filename + + path.touch() + + handle_or_str: str | IO = str(path) + if return_filelike: + kwargs.setdefault("mode", "w+b") + handle_or_str = open(path, **kwargs) + + try: + yield handle_or_str + finally: + if not isinstance(handle_or_str, str): + handle_or_str.close() + if path.is_file(): + path.unlink() + + +def test_read_stata_df(): + with ensure_clean() as path: + DF.to_stata(path) + check(assert_type(read_stata(path), pd.DataFrame), pd.DataFrame) + + +def test_read_stata_iterator_positional(): + with ensure_clean() as path: + str_path = str(path) + DF.to_stata(str_path) + check( + assert_type( + read_stata( + str_path, False, False, None, False, False, None, False, 2, True + ), + StataReader, + ), + StataReader, + ) + + +def test_read_stata_iterator(): + with ensure_clean() as path: + str_path = str(path) + DF.to_stata(str_path) + check( + assert_type(read_stata(str_path, iterator=True), StataReader), StataReader + ) From 163cebd6e5557516796aed6e294c6bd51613c231 Mon Sep 17 00:00:00 2001 From: Kevin Sheppard Date: Wed, 17 Aug 2022 17:45:56 +0100 Subject: [PATCH 2/4] MAINT: Update frame stata io --- pandas-stubs/core/frame.pyi | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index cea064a8f..fee9bec34 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -48,8 +48,10 @@ from pandas._typing import ( Axes, Axis, AxisType, + CompressionOptions, Dtype, DtypeNp, + FilePath, FilePathOrBuffer, FilePathOrBytesBuffer, GroupByObjectNonScalar, @@ -67,8 +69,10 @@ from pandas._typing import ( Scalar, ScalarT, SeriesAxisType, + StorageOptions, StrLike, T as TType, + WriteBuffer, np_ndarray_bool, np_ndarray_str, num, @@ -236,15 +240,19 @@ class DataFrame(NDFrame, OpsMixin): ) -> np.recarray: ... def to_stata( self, - path: FilePathOrBuffer, - convert_dates: dict | None = ..., + path: FilePath | WriteBuffer[bytes], + convert_dates: dict[Hashable, str] | None = ..., write_index: _bool = ..., - byteorder: _str | Literal["<", ">", "little", "big"] | None = ..., - time_stamp=..., + byteorder: Literal["<", ">", "little", "big"] | None = ..., + time_stamp: _dt.datetime | None = ..., data_label: _str | None = ..., - variable_labels: dict | None = ..., - version: int = ..., - convert_strl: list[_str] | None = ..., + variable_labels: dict[Hashable, str] | None = ..., + version: int | None = ..., + convert_strl: list[HashableT] | None = ..., + compression: CompressionOptions = ..., + storage_options: StorageOptions = ..., + *, + value_labels: dict[Hashable, dict[float, str]] | None = ..., ) -> None: ... def to_feather(self, path: FilePathOrBuffer, **kwargs) -> None: ... @overload From 49734ca2efefe34590eb4637f4ce38c3e1b311e7 Mon Sep 17 00:00:00 2001 From: Kevin Sheppard Date: Thu, 18 Aug 2022 07:56:34 +0100 Subject: [PATCH 3/4] ENH: Add literals Add literals for limited value inputs --- pandas-stubs/_typing.pyi | 17 +++++++++++++++++ pandas-stubs/core/frame.pyi | 7 ++++--- pandas-stubs/io/stata.pyi | 21 +++++++++++---------- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 71bfc1839..1423146b4 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -203,4 +203,21 @@ GroupByObjectNonScalar = Union[ ] GroupByObject = Union[Scalar, GroupByObjectNonScalar] +StataDateFormat = Literal[ + "tc", + "%tc", + "td", + "%td", + "tw", + "%tw", + "tm", + "%tm", + "tq", + "%tq", + "th", + "%th", + "ty", + "%ty", +] + __all__ = ["npt", "type_t"] diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index fee9bec34..d779f44ba 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -69,6 +69,7 @@ from pandas._typing import ( Scalar, ScalarT, SeriesAxisType, + StataDateFormat, StorageOptions, StrLike, T as TType, @@ -241,13 +242,13 @@ class DataFrame(NDFrame, OpsMixin): def to_stata( self, path: FilePath | WriteBuffer[bytes], - convert_dates: dict[Hashable, str] | None = ..., + convert_dates: dict[HashableT, StataDateFormat] | None = ..., write_index: _bool = ..., byteorder: Literal["<", ">", "little", "big"] | None = ..., time_stamp: _dt.datetime | None = ..., data_label: _str | None = ..., - variable_labels: dict[Hashable, str] | None = ..., - version: int | None = ..., + variable_labels: dict[HashableT, str] | None = ..., + version: Literal[114, 117, 118, 119] | None = ..., convert_strl: list[HashableT] | None = ..., compression: CompressionOptions = ..., storage_options: StorageOptions = ..., diff --git a/pandas-stubs/io/stata.pyi b/pandas-stubs/io/stata.pyi index 2b7de0a6a..5bb90b47d 100644 --- a/pandas-stubs/io/stata.pyi +++ b/pandas-stubs/io/stata.pyi @@ -19,6 +19,7 @@ from pandas._typing import ( FilePath, HashableT, ReadBuffer, + StataDateFormat, StorageOptions, WriteBuffer, ) @@ -138,16 +139,16 @@ class StataWriter(StataParser): self, fname: FilePath | WriteBuffer[bytes], data: DataFrame, - convert_dates: dict[Hashable, str] | None = ..., + convert_dates: dict[HashableT, StataDateFormat] | None = ..., write_index: bool = ..., byteorder: str | None = ..., time_stamp: datetime.datetime | None = ..., data_label: str | None = ..., - variable_labels: dict[Hashable, str] | None = ..., + variable_labels: dict[HashableT, str] | None = ..., compression: CompressionOptions = ..., storage_options: StorageOptions = ..., *, - value_labels: dict[Hashable, dict[float, str]] | None = ..., + value_labels: dict[HashableT, dict[float, str]] | None = ..., ) -> None: ... def write_file(self) -> None: ... @@ -156,17 +157,17 @@ class StataWriter117(StataWriter): self, fname: FilePath | WriteBuffer[bytes], data: DataFrame, - convert_dates: dict[Hashable, str] | None = ..., + convert_dates: dict[HashableT, StataDateFormat] | None = ..., write_index: bool = ..., byteorder: str | None = ..., time_stamp: datetime.datetime | None = ..., data_label: str | None = ..., - variable_labels: dict[Hashable, str] | None = ..., + variable_labels: dict[HashableT, str] | None = ..., convert_strl: Sequence[Hashable] | None = ..., compression: CompressionOptions = ..., storage_options: StorageOptions = ..., *, - value_labels: dict[Hashable, dict[float, str]] | None = ..., + value_labels: dict[HashableT, dict[float, str]] | None = ..., ) -> None: ... class StataWriterUTF8(StataWriter117): @@ -174,16 +175,16 @@ class StataWriterUTF8(StataWriter117): self, fname: FilePath | WriteBuffer[bytes], data: DataFrame, - convert_dates: dict[Hashable, str] | None = ..., + convert_dates: dict[HashableT, StataDateFormat] | None = ..., write_index: bool = ..., byteorder: str | None = ..., time_stamp: datetime.datetime | None = ..., data_label: str | None = ..., - variable_labels: dict[Hashable, str] | None = ..., + variable_labels: dict[HashableT, str] | None = ..., convert_strl: Sequence[Hashable] | None = ..., - version: int | None = ..., + version: Literal[118, 119] | None = ..., compression: CompressionOptions = ..., storage_options: StorageOptions = ..., *, - value_labels: dict[Hashable, dict[float, str]] | None = ..., + value_labels: dict[HashableT, dict[float, str]] | None = ..., ) -> None: ... From 52e104dc7375a87e5da72ee524b8019b089bfa7a Mon Sep 17 00:00:00 2001 From: Kevin Sheppard Date: Mon, 22 Aug 2022 16:31:53 +0100 Subject: [PATCH 4/4] CLN: Remove non-public classes --- pandas-stubs/io/stata.pyi | 51 --------------------------------------- tests/test_io.py | 6 ++--- 2 files changed, 2 insertions(+), 55 deletions(-) diff --git a/pandas-stubs/io/stata.pyi b/pandas-stubs/io/stata.pyi index 5bb90b47d..12b8dad74 100644 --- a/pandas-stubs/io/stata.pyi +++ b/pandas-stubs/io/stata.pyi @@ -3,7 +3,6 @@ import datetime from io import BytesIO from types import TracebackType from typing import ( - Hashable, Literal, Sequence, overload, @@ -76,18 +75,6 @@ class ValueLabelTypeMismatch(Warning): ... class InvalidColumnName(Warning): ... class StataParser: - DTYPE_MAP: dict[int, np.dtype] = ... - DTYPE_MAP_XML: dict[int, np.dtype] = ... - TYPE_MAP: list[tuple[int | str, ...]] = ... - TYPE_MAP_XML: dict[int, str] = ... - VALID_RANGE: dict[ - str, - tuple[int, int] | tuple[np.float32, np.float32] | tuple[np.float64, np.float64], - ] = ... - OLD_TYPE_MAPPING: dict[int, int] = ... - MISSING_VALUES: dict[str, int] = ... - NUMPY_TYPE_MAP: dict[str, str] = ... - RESERVED_WORDS: tuple[str, ...] = ... def __init__(self) -> None: ... class StataReader(StataParser, abc.Iterator): @@ -134,7 +121,6 @@ class StataReader(StataParser, abc.Iterator): def value_labels(self) -> dict[str, dict[float, str]]: ... class StataWriter(StataParser): - type_converters: dict[str, type[np.dtype]] = ... def __init__( self, fname: FilePath | WriteBuffer[bytes], @@ -151,40 +137,3 @@ class StataWriter(StataParser): value_labels: dict[HashableT, dict[float, str]] | None = ..., ) -> None: ... def write_file(self) -> None: ... - -class StataWriter117(StataWriter): - def __init__( - self, - fname: FilePath | WriteBuffer[bytes], - data: DataFrame, - convert_dates: dict[HashableT, StataDateFormat] | None = ..., - write_index: bool = ..., - byteorder: str | None = ..., - time_stamp: datetime.datetime | None = ..., - data_label: str | None = ..., - variable_labels: dict[HashableT, str] | None = ..., - convert_strl: Sequence[Hashable] | None = ..., - compression: CompressionOptions = ..., - storage_options: StorageOptions = ..., - *, - value_labels: dict[HashableT, dict[float, str]] | None = ..., - ) -> None: ... - -class StataWriterUTF8(StataWriter117): - def __init__( - self, - fname: FilePath | WriteBuffer[bytes], - data: DataFrame, - convert_dates: dict[HashableT, StataDateFormat] | None = ..., - write_index: bool = ..., - byteorder: str | None = ..., - time_stamp: datetime.datetime | None = ..., - data_label: str | None = ..., - variable_labels: dict[HashableT, str] | None = ..., - convert_strl: Sequence[Hashable] | None = ..., - version: Literal[118, 119] | None = ..., - compression: CompressionOptions = ..., - storage_options: StorageOptions = ..., - *, - value_labels: dict[HashableT, dict[float, str]] | None = ..., - ) -> None: ... diff --git a/tests/test_io.py b/tests/test_io.py index 445f1ed29..e63acffd5 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -2,6 +2,7 @@ from pandas import ( DataFrame, read_clipboard, + read_stata, ) from pandas._testing import ensure_clean import pytest @@ -11,10 +12,7 @@ from pandas.io.clipboard import PyperclipException from pandas.io.parsers import TextFileReader -from pandas.io.stata import ( - StataReader, - read_stata, -) +from pandas.io.stata import StataReader DF = DataFrame({"a": [1, 2, 3], "b": [0.0, 0.0, 0.0]})