diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 1f8a2648..fb2c1c9b 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -1300,7 +1300,8 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): def unstack( self, level: Level = ..., - fill_value: int | _str | dict | None = ..., + fill_value: Scalar | None = ..., + sort: _bool = ..., ) -> Self | Series: ... def melt( self, diff --git a/tests/test_frame.py b/tests/test_frame.py index 7d4890ee..9b666bbc 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -4339,3 +4339,64 @@ def test_df_loc_dict() -> None: df.iloc[0] = {"X": 0} check(assert_type(df, pd.DataFrame), pd.DataFrame) + + +def test_unstack() -> None: + """Test different types of argument for `fill_value` in DataFrame.unstack.""" + df = pd.DataFrame( + [ + ["a", "b", pd.Timestamp(2021, 3, 2)], + ["a", "a", pd.Timestamp(2023, 4, 2)], + ["b", "b", pd.Timestamp(2024, 3, 2)], + ] + ).set_index([0, 1]) + df_sr = pd.DataFrame( + [ + ["a", "b", "abc"], + ["a", "a", "def"], + ["b", "b", "ghi"], + ] + ).set_index([0, 1]) + df_flt = pd.DataFrame( + [ + ["a", "b", 1], + ["a", "a", 12], + ["b", "b", 14], + ] + ).set_index([0, 1]) + + check(assert_type(df.unstack(0), pd.DataFrame | pd.Series), pd.DataFrame) + check( + assert_type( + df.unstack(1, fill_value=pd.Timestamp(2023, 4, 5)), pd.DataFrame | pd.Series + ), + pd.DataFrame, + ) + check( + assert_type(df_flt.unstack(1, fill_value=0.0), pd.DataFrame | pd.Series), + pd.DataFrame, + ) + check( + assert_type(df_flt.unstack(1, fill_value=1), pd.DataFrame | pd.Series), + pd.DataFrame, + ) + check( + assert_type(df_sr.unstack(1, fill_value="string"), pd.DataFrame | pd.Series), + pd.DataFrame, + ) + check( + assert_type(df.unstack(0, sort=False), pd.DataFrame | pd.Series), pd.DataFrame + ) + check( + assert_type( + df.unstack(1, fill_value=pd.Timestamp(2023, 4, 5), sort=True), + pd.DataFrame | pd.Series, + ), + pd.DataFrame, + ) + check( + assert_type( + df_flt.unstack(1, fill_value=0.0, sort=False), pd.DataFrame | pd.Series + ), + pd.DataFrame, + )