From c85c82e31115cb82ed36c0f63a1f9847404d9770 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?=
Date: Sat, 3 Feb 2024 14:03:17 -0500
Subject: [PATCH 1/2] TYP: misc IO return types

---
 pandas/_typing.py                     |  3 +++
 pandas/io/excel/_openpyxl.py          |  3 ++-
 pandas/io/formats/excel.py            |  4 ++-
 pandas/io/formats/format.py           |  6 +++--
 pandas/io/formats/info.py             |  8 +++---
 pandas/io/formats/style.py            |  2 +-
 pandas/io/formats/style_render.py     |  4 ++-
 pandas/io/html.py                     |  2 +-
 pandas/io/json/_normalize.py          | 27 ++++++++++++++++++--
 pandas/io/parquet.py                  |  2 +-
 pandas/io/parsers/base_parser.py      | 36 +++++++++++++++------------
 pandas/io/parsers/c_parser_wrapper.py | 14 ++++++++---
 pandas/io/parsers/python_parser.py    |  1 +
 pandas/io/parsers/readers.py          |  2 +-
 pandas/io/sas/sas7bdat.py             |  5 ++--
 pandas/io/sql.py                      | 15 +++++------
 pandas/io/stata.py                    |  2 +-
 17 files changed, 91 insertions(+), 45 deletions(-)

diff --git a/pandas/_typing.py b/pandas/_typing.py
index c704516f74300..1fec41463904c 100644
--- a/pandas/_typing.py
+++ b/pandas/_typing.py
@@ -529,3 +529,6 @@ def closed(self) -> bool:
     Callable[[HashableT], bool],
     None,
 ]
+
+# maintain the sub-type of any hashable sequence
+SequenceT = TypeVar("SequenceT", bound=Sequence[Hashable])
diff --git a/pandas/io/excel/_openpyxl.py b/pandas/io/excel/_openpyxl.py
index c546443868a62..218a592c22b4a 100644
--- a/pandas/io/excel/_openpyxl.py
+++ b/pandas/io/excel/_openpyxl.py
@@ -26,6 +26,7 @@
 if TYPE_CHECKING:
     from openpyxl import Workbook
     from openpyxl.descriptors.serialisable import Serialisable
+    from openpyxl.styles import Fill
 
     from pandas._typing import (
         ExcelWriterIfSheetExists,
@@ -244,7 +245,7 @@ def _convert_to_stop(cls, stop_seq):
         return map(cls._convert_to_color, stop_seq)
 
     @classmethod
-    def _convert_to_fill(cls, fill_dict: dict[str, Any]):
+    def _convert_to_fill(cls, fill_dict: dict[str, Any]) -> Fill:
         """
         Convert ``fill_dict`` to an openpyxl v2 Fill object.
diff --git a/pandas/io/formats/excel.py b/pandas/io/formats/excel.py
index 0c3a53eb1cfea..892f69e76359b 100644
--- a/pandas/io/formats/excel.py
+++ b/pandas/io/formats/excel.py
@@ -284,7 +284,9 @@ def build_border(
             for side in ["top", "right", "bottom", "left"]
         }
 
-    def _border_style(self, style: str | None, width: str | None, color: str | None):
+    def _border_style(
+        self, style: str | None, width: str | None, color: str | None
+    ) -> str | None:
         # convert styles and widths to openxml, one of:
         #   'dashDot'
         #   'dashDotDot'
diff --git a/pandas/io/formats/format.py b/pandas/io/formats/format.py
index 00c7526edfa48..65124f97459cd 100644
--- a/pandas/io/formats/format.py
+++ b/pandas/io/formats/format.py
@@ -1346,7 +1346,9 @@ def get_result_as_array(self) -> np.ndarray:
         the parameters given at initialisation, as a numpy array
         """
 
-        def format_with_na_rep(values: ArrayLike, formatter: Callable, na_rep: str):
+        def format_with_na_rep(
+            values: ArrayLike, formatter: Callable, na_rep: str
+        ) -> np.ndarray:
             mask = isna(values)
             formatted = np.array(
                 [
@@ -1358,7 +1360,7 @@ def format_with_na_rep(values: ArrayLike, formatter: Callable, na_rep: str):
 
         def format_complex_with_na_rep(
             values: ArrayLike, formatter: Callable, na_rep: str
-        ):
+        ) -> np.ndarray:
             real_values = np.real(values).ravel()  # type: ignore[arg-type]
             imag_values = np.imag(values).ravel()  # type: ignore[arg-type]
             real_mask, imag_mask = isna(real_values), isna(imag_values)
diff --git a/pandas/io/formats/info.py b/pandas/io/formats/info.py
index 552affbd053f2..2d28b032ca49d 100644
--- a/pandas/io/formats/info.py
+++ b/pandas/io/formats/info.py
@@ -392,7 +392,7 @@ def dtype_counts(self) -> Mapping[str, int]:
 
     @property
     @abstractmethod
-    def non_null_counts(self) -> Sequence[int]:
+    def non_null_counts(self) -> list[int] | Series:
         """Sequence of non-null counts for all columns or column (if series)."""
 
     @property
@@ -486,7 +486,7 @@ def col_count(self) -> int:
         return len(self.ids)
 
     @property
-    def non_null_counts(self) -> Sequence[int]:
+    def non_null_counts(self) -> Series:
         """Sequence of non-null counts for all columns or column (if series)."""
         return self.data.count()
 
@@ -546,7 +546,7 @@ def render(
         printer.to_buffer(buf)
 
     @property
-    def non_null_counts(self) -> Sequence[int]:
+    def non_null_counts(self) -> list[int]:
         return [self.data.count()]
 
     @property
@@ -750,7 +750,7 @@ def memory_usage_string(self) -> str:
         return self.info.memory_usage_string
 
     @property
-    def non_null_counts(self) -> Sequence[int]:
+    def non_null_counts(self) -> list[int] | Series:
         return self.info.non_null_counts
 
     def add_object_type_line(self) -> None:
diff --git a/pandas/io/formats/style.py b/pandas/io/formats/style.py
index c85c6c3ef0ff7..0e67949709a22 100644
--- a/pandas/io/formats/style.py
+++ b/pandas/io/formats/style.py
@@ -3827,7 +3827,7 @@ def _background_gradient(
     vmax: float | None = None,
     gmap: Sequence | np.ndarray | DataFrame | Series | None = None,
     text_only: bool = False,
-):
+) -> list[str] | DataFrame:
     """
     Color background in a range according to the data or a gradient map
     """
diff --git a/pandas/io/formats/style_render.py b/pandas/io/formats/style_render.py
index 80df46bf2336a..4ba094ec614d0 100644
--- a/pandas/io/formats/style_render.py
+++ b/pandas/io/formats/style_render.py
@@ -2030,7 +2030,9 @@ def _class_styles(self):
             }
         ]
 
-    def _pseudo_css(self, uuid: str, name: str, row: int, col: int, text: str):
+    def _pseudo_css(
+        self, uuid: str, name: str, row: int, col: int, text: str
+    ) -> list[CSSDict]:
        """
        For every table data-cell that has a valid tooltip (not None, NaN or
        empty string) must create two pseudo CSS entries for the specific
diff --git a/pandas/io/html.py b/pandas/io/html.py
index 0f3704b698915..302f901aa0d16 100644
--- a/pandas/io/html.py
+++ b/pandas/io/html.py
@@ -469,7 +469,7 @@ def row_is_all_th(row):
 
     def _expand_colspan_rowspan(
         self, rows, section: Literal["header", "footer", "body"]
-    ):
+    ) -> list[list]:
         """
         Given a list of <tr>s, return a list of text rows.
 
diff --git a/pandas/io/json/_normalize.py b/pandas/io/json/_normalize.py
index 39fbce0b6901c..49f95430d9bb9 100644
--- a/pandas/io/json/_normalize.py
+++ b/pandas/io/json/_normalize.py
@@ -11,6 +11,7 @@
     TYPE_CHECKING,
     Any,
     DefaultDict,
+    overload,
 )
 
 import numpy as np
@@ -42,13 +43,35 @@ def convert_to_line_delimits(s: str) -> str:
     return convert_json_to_lines(s)
 
 
+@overload
 def nested_to_record(
-    ds,
+    ds: dict,
+    prefix: str = ...,
+    sep: str = ...,
+    level: int = ...,
+    max_level: int | None = ...,
+) -> dict[str, Any]:
+    ...
+
+
+@overload
+def nested_to_record(
+    ds: list[dict],
+    prefix: str = ...,
+    sep: str = ...,
+    level: int = ...,
+    max_level: int | None = ...,
+) -> list[dict[str, Any]]:
+    ...
+
+
+def nested_to_record(
+    ds: dict | list[dict],
     prefix: str = "",
     sep: str = ".",
     level: int = 0,
     max_level: int | None = None,
-):
+) -> dict[str, Any] | list[dict[str, Any]]:
     """
     A simplified json_normalize
 
diff --git a/pandas/io/parquet.py b/pandas/io/parquet.py
index e04a3acc829f1..a6d58d6cffb10 100644
--- a/pandas/io/parquet.py
+++ b/pandas/io/parquet.py
@@ -150,7 +150,7 @@ def validate_dataframe(df: DataFrame) -> None:
         if not isinstance(df, DataFrame):
             raise ValueError("to_parquet only supports IO with DataFrames")
 
-    def write(self, df: DataFrame, path, compression, **kwargs):
+    def write(self, df: DataFrame, path, compression, **kwargs) -> None:
         raise AbstractMethodError(self)
 
     def read(self, path, columns=None, **kwargs) -> DataFrame:
diff --git a/pandas/io/parsers/base_parser.py b/pandas/io/parsers/base_parser.py
index 09f0f2af8e5c6..3aef0692d5f59 100644
--- a/pandas/io/parsers/base_parser.py
+++ b/pandas/io/parsers/base_parser.py
@@ -84,7 +84,6 @@
 
 if TYPE_CHECKING:
     from collections.abc import (
-        Hashable,
         Iterable,
         Mapping,
         Sequence,
@@ -94,7 +93,10 @@
         ArrayLike,
         DtypeArg,
         DtypeObj,
+        Hashable,
+        HashableT,
         Scalar,
+        SequenceT,
     )
 
 
@@ -350,13 +352,13 @@ def extract(r):
     @final
     def _maybe_make_multi_index_columns(
         self,
-        columns: Sequence[Hashable],
+        columns: SequenceT,
         col_names: Sequence[Hashable] | None = None,
-    ) -> Sequence[Hashable] | MultiIndex:
+    ) -> SequenceT | MultiIndex:
         # possibly create a column mi here
         if is_potential_multi_index(columns):
-            list_columns = cast(list[tuple], columns)
-            return MultiIndex.from_tuples(list_columns, names=col_names)
+            columns_mi = cast("Sequence[tuple[Hashable, ...]]", columns)
+            return MultiIndex.from_tuples(columns_mi, names=col_names)
         return columns
 
     @final
@@ -520,7 +522,7 @@ def _convert_to_ndarrays(
         verbose: bool = False,
         converters=None,
         dtypes=None,
-    ):
+    ) -> dict[Any, np.ndarray]:
         result = {}
         for c, values in dct.items():
             conv_f = None if converters is None else converters.get(c, None)
@@ -923,23 +925,23 @@ def _check_data_length(
     @overload
     def _evaluate_usecols(
         self,
-        usecols: set[int] | Callable[[Hashable], object],
-        names: Sequence[Hashable],
+        usecols: Callable[[Hashable], object],
+        names: Iterable[Hashable],
     ) -> set[int]:
         ...
 
     @overload
     def _evaluate_usecols(
-        self, usecols: set[str], names: Sequence[Hashable]
-    ) -> set[str]:
+        self, usecols: SequenceT, names: Iterable[Hashable]
+    ) -> SequenceT:
         ...
 
     @final
     def _evaluate_usecols(
         self,
-        usecols: Callable[[Hashable], object] | set[str] | set[int],
-        names: Sequence[Hashable],
-    ) -> set[str] | set[int]:
+        usecols: Callable[[Hashable], object] | SequenceT,
+        names: Iterable[Hashable],
+    ) -> SequenceT | set[int]:
         """
         Check whether or not the 'usecols' parameter
         is a callable.  If so, enumerates the 'names'
@@ -952,7 +954,7 @@ def _evaluate_usecols(
         return usecols
 
     @final
-    def _validate_usecols_names(self, usecols, names: Sequence):
+    def _validate_usecols_names(self, usecols: SequenceT, names: Sequence) -> SequenceT:
         """
         Validates that all usecols are present in a given
         list of names. If not, raise a ValueError that
@@ -1072,7 +1074,9 @@ def _clean_index_names(self, columns, index_col) -> tuple[list | None, list, lis
         return index_names, columns, index_col
 
     @final
-    def _get_empty_meta(self, columns, dtype: DtypeArg | None = None):
+    def _get_empty_meta(
+        self, columns: Sequence[HashableT], dtype: DtypeArg | None = None
+    ) -> tuple[Index, list[HashableT], dict[HashableT, Series]]:
         columns = list(columns)
 
         index_col = self.index_col
@@ -1275,7 +1279,7 @@ def _process_date_conversion(
     columns,
     keep_date_col: bool = False,
     dtype_backend=lib.no_default,
-):
+) -> tuple[dict, list]:
     def _isindex(colspec):
         return (isinstance(index_col, list) and colspec in index_col) or (
             isinstance(index_names, list) and colspec in index_names
diff --git a/pandas/io/parsers/c_parser_wrapper.py b/pandas/io/parsers/c_parser_wrapper.py
index 0cd788c5e5739..47666da67170c 100644
--- a/pandas/io/parsers/c_parser_wrapper.py
+++ b/pandas/io/parsers/c_parser_wrapper.py
@@ -42,9 +42,11 @@
 
     from pandas._typing import (
         ArrayLike,
+        AnyArrayLike,
         DtypeArg,
         DtypeObj,
         ReadCsvBuffer,
+        SequenceT,
     )
 
     from pandas import (
@@ -225,7 +227,7 @@ def read(
     ) -> tuple[
         Index | MultiIndex | None,
         Sequence[Hashable] | MultiIndex,
-        Mapping[Hashable, ArrayLike],
+        Mapping[Hashable, AnyArrayLike],
     ]:
         index: Index | MultiIndex | None
         column_names: Sequence[Hashable] | MultiIndex
@@ -248,7 +250,11 @@ def read(
                 names,
                 dtype=self.dtype,
             )
-            columns = self._maybe_make_multi_index_columns(columns, self.col_names)
+            # error: Incompatible types in assignment (expression has type
+            # "list[Hashable] | MultiIndex", variable has type "list[Hashable]")
+            columns = self._maybe_make_multi_index_columns(  # type: ignore[assignment]
+                columns, self.col_names
+            )
 
             if self.usecols is not None:
                 columns = self._filter_usecols(columns)
@@ -334,11 +340,11 @@ def read(
 
         return index, column_names, date_data
 
-    def _filter_usecols(self, names: Sequence[Hashable]) -> Sequence[Hashable]:
+    def _filter_usecols(self, names: SequenceT) -> SequenceT | list[Hashable]:
         # hackish
         usecols = self._evaluate_usecols(self.usecols, names)
         if usecols is not None and len(names) != len(usecols):
-            names = [
+            return [
                 name for i, name in enumerate(names) if i in usecols or name in usecols
             ]
         return names
diff --git a/pandas/io/parsers/python_parser.py b/pandas/io/parsers/python_parser.py
index e830db559c5a5..dbda47172f6ac 100644
--- a/pandas/io/parsers/python_parser.py
+++ b/pandas/io/parsers/python_parser.py
@@ -266,6 +266,7 @@ def read(
             # done with first read, next time raise StopIteration
             self._first_chunk = False
 
+            index: Index | None
             columns: Sequence[Hashable] = list(self.orig_names)
             if not len(content):  # pragma: no cover
                 # DataFrame with the right metadata, even though it's length 0
diff --git a/pandas/io/parsers/readers.py b/pandas/io/parsers/readers.py
index d35c153459bf8..71e1a31759a0c 100644
--- a/pandas/io/parsers/readers.py
+++ b/pandas/io/parsers/readers.py
@@ -2101,7 +2101,7 @@ def _floatify_na_values(na_values):
     return result
 
 
-def _stringify_na_values(na_values, floatify: bool):
+def _stringify_na_values(na_values, floatify: bool) -> set[str | float]:
    """return a stringified and numeric for these values"""
    result: list[str | float] = []
    for x in na_values:
diff --git a/pandas/io/sas/sas7bdat.py b/pandas/io/sas/sas7bdat.py
index 895079bc15588..275fad2a565bf 100644
--- a/pandas/io/sas/sas7bdat.py
+++ b/pandas/io/sas/sas7bdat.py
@@ -54,6 +54,7 @@
     from pandas._typing import (
         CompressionOptions,
         FilePath,
+        NaTType,
         ReadBuffer,
     )
 
@@ -62,7 +63,7 @@
 _sas_origin = Timestamp("1960-01-01")
 
 
-def _parse_datetime(sas_datetime: float, unit: str):
+def _parse_datetime(sas_datetime: float, unit: str) -> datetime | NaTType:
     if isna(sas_datetime):
         return pd.NaT
 
@@ -326,7 +327,7 @@ def __next__(self) -> DataFrame:
         return da
 
     # Read a single float of the given width (4 or 8).
-    def _read_float(self, offset: int, width: int):
+    def _read_float(self, offset: int, width: int) -> float:
         assert self._cached_page is not None
         if width == 4:
             return read_float_with_byteswap(
diff --git a/pandas/io/sql.py b/pandas/io/sql.py
index d9a5e6dfd0cf8..4e0ddd0f56ba8 100644
--- a/pandas/io/sql.py
+++ b/pandas/io/sql.py
@@ -67,6 +67,7 @@
 
 if TYPE_CHECKING:
     from collections.abc import (
+        Generator,
         Iterator,
         Mapping,
     )
@@ -136,7 +137,7 @@ def _handle_date_column(
         return to_datetime(col, errors="coerce", format=format, utc=utc)
 
 
-def _parse_date_columns(data_frame, parse_dates):
+def _parse_date_columns(data_frame: DataFrame, parse_dates) -> DataFrame:
     """
     Force non-datetime columns to be read as such.
     Supports both string formatted and integer timestamp columns.
@@ -199,7 +200,7 @@ def _wrap_result(
     parse_dates=None,
     dtype: DtypeArg | None = None,
     dtype_backend: DtypeBackend | Literal["numpy"] = "numpy",
-):
+) -> DataFrame:
     """Wrap result set of a SQLAlchemy query in a DataFrame."""
     frame = _convert_arrays_to_dataframe(data, columns, coerce_float, dtype_backend)
 
@@ -1153,7 +1154,7 @@ def _query_iterator(
         coerce_float: bool = True,
         parse_dates=None,
         dtype_backend: DtypeBackend | Literal["numpy"] = "numpy",
-    ):
+    ) -> Generator[DataFrame, None, None]:
         """Return generator through chunked result set."""
         has_read_data = False
         with exit_stack:
@@ -1765,7 +1766,7 @@ def _query_iterator(
         parse_dates=None,
         dtype: DtypeArg | None = None,
         dtype_backend: DtypeBackend | Literal["numpy"] = "numpy",
-    ):
+    ) -> Generator[DataFrame, None, None]:
         """Return generator through chunked result set"""
         has_read_data = False
         with exit_stack:
@@ -2466,7 +2467,7 @@ def _create_sql_schema(
 }
 
 
-def _get_unicode_name(name: object):
+def _get_unicode_name(name: object) -> str:
     try:
         uname = str(name).encode("utf-8", "strict").decode("utf-8")
     except UnicodeError as err:
@@ -2474,7 +2475,7 @@ def _get_unicode_name(name: object):
     return uname
 
 
-def _get_valid_sqlite_name(name: object):
+def _get_valid_sqlite_name(name: object) -> str:
     # See https://stackoverflow.com/questions/6514274/how-do-you-escape-strings\
     # -for-sqlite-table-column-names-in-python
     # Ensure the string can be encoded as UTF-8.
@@ -2712,7 +2713,7 @@ def _query_iterator(
         parse_dates=None,
         dtype: DtypeArg | None = None,
         dtype_backend: DtypeBackend | Literal["numpy"] = "numpy",
-    ):
+    ) -> Generator[DataFrame, None, None]:
         """Return generator through chunked result set"""
         has_read_data = False
         while True:
diff --git a/pandas/io/stata.py b/pandas/io/stata.py
index a2c15938c04bf..447c97d078e02 100644
--- a/pandas/io/stata.py
+++ b/pandas/io/stata.py
@@ -373,7 +373,7 @@ def _datetime_to_stata_elapsed_vec(dates: Series, fmt: str) -> Series:
 
     def parse_dates_safe(
         dates: Series, delta: bool = False, year: bool = False, days: bool = False
-    ):
+    ) -> DataFrame:
         d = {}
         if lib.is_np_dtype(dates.dtype, "M"):
             if delta:

From eae53d0d4de71bfef715ff76e4c74b679e716a95 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?=
Date: Sat, 3 Feb 2024 14:19:00 -0500
Subject: [PATCH 2/2] isort

---
 pandas/io/parsers/c_parser_wrapper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pandas/io/parsers/c_parser_wrapper.py b/pandas/io/parsers/c_parser_wrapper.py
index 47666da67170c..f24d7a628998e 100644
--- a/pandas/io/parsers/c_parser_wrapper.py
+++ b/pandas/io/parsers/c_parser_wrapper.py
@@ -41,8 +41,8 @@
     )
 
     from pandas._typing import (
-        ArrayLike,
         AnyArrayLike,
+        ArrayLike,
         DtypeArg,
         DtypeObj,
         ReadCsvBuffer,
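
Note on the SequenceT pattern introduced in pandas/_typing.py above: a TypeVar bound to Sequence[Hashable] lets helpers such as _maybe_make_multi_index_columns and _evaluate_usecols hand back the same concrete sequence type they were given, instead of an erased Sequence[Hashable]. The following is a minimal, self-contained sketch of that behaviour only; the keep_or_tuple function and the example labels are illustrative and are not part of the patch.

from collections.abc import Hashable, Sequence
from typing import TypeVar

# Bound TypeVar: the concrete sub-type of the hashable sequence is preserved.
SequenceT = TypeVar("SequenceT", bound=Sequence[Hashable])


def keep_or_tuple(columns: SequenceT) -> SequenceT | tuple[Hashable, ...]:
    # Either return the input unchanged (keeping its exact type for the
    # type checker) or "upgrade" it to a different container, mirroring how
    # _maybe_make_multi_index_columns returns SequenceT | MultiIndex.
    if len(columns) > 2:
        return tuple(columns)
    return columns


cols: list[str] = ["a", "b"]
kept = keep_or_tuple(cols)  # a checker infers list[str] | tuple[Hashable, ...]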
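
Note on the @overload pair added to nested_to_record above: the two overloads are what let a type checker keep the dict-in/dict-out and list-in/list-out relationship while the single implementation accepts both shapes. Below is a stripped-down sketch of the same overload pattern; the flatten_one_level function is illustrative only and is far simpler than the real nested_to_record.

from typing import Any, overload


@overload
def flatten_one_level(ds: dict, sep: str = ...) -> dict[str, Any]: ...


@overload
def flatten_one_level(ds: list[dict], sep: str = ...) -> list[dict[str, Any]]: ...


def flatten_one_level(
    ds: dict | list[dict], sep: str = "."
) -> dict[str, Any] | list[dict[str, Any]]:
    # Flatten nested dicts one level deep; the return shape (single mapping
    # or list of mappings) matches whatever shape the caller passed in.
    if isinstance(ds, dict):
        flat: dict[str, Any] = {}
        for key, value in ds.items():
            if isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    flat[f"{key}{sep}{sub_key}"] = sub_value
            else:
                flat[key] = value
        return flat
    return [flatten_one_level(record, sep) for record in ds]


flatten_one_level({"a": {"b": 1}})              # checker sees dict[str, Any]
flatten_one_level([{"a": {"b": 1}}, {"c": 2}])  # checker sees list[dict[str, Any]]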