diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 6264a0637..6cd3fff6c 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -58,6 +58,7 @@ from pandas._typing import ( FilePathOrBuffer, FilePathOrBytesBuffer, GroupByObjectNonScalar, + HashableT, IgnoreRaise, IndexingInt, IndexLabel, @@ -587,7 +588,7 @@ class DataFrame(NDFrame, OpsMixin): def set_index( self, keys: Union[ - Label, Series, Index, np.ndarray, Iterator[Hashable], List[Hashable] + Label, Series, Index, np.ndarray, Iterator[HashableT], List[HashableT] ], drop: _bool = ..., append: _bool = ..., @@ -599,7 +600,7 @@ class DataFrame(NDFrame, OpsMixin): def set_index( self, keys: Union[ - Label, Series, Index, np.ndarray, Iterator[Hashable], List[Hashable] + Label, Series, Index, np.ndarray, Iterator[HashableT], List[HashableT] ], drop: _bool = ..., append: _bool = ..., @@ -611,7 +612,7 @@ class DataFrame(NDFrame, OpsMixin): def set_index( self, keys: Union[ - Label, Series, Index, np.ndarray, Iterator[Hashable], List[Hashable] + Label, Series, Index, np.ndarray, Iterator[HashableT], List[HashableT] ], drop: _bool = ..., append: _bool = ..., @@ -622,7 +623,7 @@ class DataFrame(NDFrame, OpsMixin): def set_index( self, keys: Union[ - Label, Series, Index, np.ndarray, Iterator[Hashable], List[Hashable] + Label, Series, Index, np.ndarray, Iterator[HashableT], List[HashableT] ], drop: _bool = ..., append: _bool = ..., diff --git a/tests/test_frame.py b/tests/test_frame.py index 523bb7c1e..bf3a87fed 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1,4 +1,6 @@ # flake8: noqa: F841 +from __future__ import annotations + import datetime import io from pathlib import Path @@ -1305,3 +1307,43 @@ def test_groupby_result() -> None: for kk, g in df.groupby("a"): pass + + +def test_setitem_list(): + # GH 153 + lst1: list[str] = ["a", "b", "c"] + lst2: list[int] = [1, 2, 3] + lst3: list[float] = [4.0, 5.0, 6.0] + lst4: list[tuple[str, int]] = [("a", 1), ("b", 2), ("c", 3)] + lst5: list[complex] = [0 + 1j, 0 + 2j, 0 + 3j] + + columns: list[Hashable] = [ + "a", + "b", + "c", + 1, + 2, + 3, + 4.0, + 5.0, + 6.0, + ("a", 1), + ("b", 2), + ("c", 3), + 0 + 1j, + 0 + 2j, + 0 + 3j, + ] + + df = pd.DataFrame(np.empty((3, 15)), columns=columns) + + check(assert_type(df.set_index(lst1), pd.DataFrame), pd.DataFrame) + check(assert_type(df.set_index(lst2), pd.DataFrame), pd.DataFrame) + check(assert_type(df.set_index(lst3), pd.DataFrame), pd.DataFrame) + check(assert_type(df.set_index(lst4), pd.DataFrame), pd.DataFrame) + check(assert_type(df.set_index(lst5), pd.DataFrame), pd.DataFrame) + + iter1: Iterator[str] = (v for v in lst1) + iter2: Iterator[tuple[str, int]] = (v for v in lst4) + check(assert_type(df.set_index(iter1), pd.DataFrame), pd.DataFrame) + check(assert_type(df.set_index(iter2), pd.DataFrame), pd.DataFrame)