diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 935d43470..8c8fee8b4 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -148,7 +148,7 @@ ComplexDtypeArg: TypeAlias = ( ) TimedeltaDtypeArg: TypeAlias = Literal["timedelta64[ns]"] TimestampDtypeArg: TypeAlias = Literal["datetime64[ns]"] -CategoryDtypeArg: TypeAlias = Literal["category"] +CategoryDtypeArg: TypeAlias = CategoricalDtype | Literal["category"] AstypeArg: TypeAlias = ( BooleanDtypeArg @@ -159,7 +159,7 @@ AstypeArg: TypeAlias = ( | ComplexDtypeArg | TimedeltaDtypeArg | TimestampDtypeArg - | CategoricalDtype + | CategoryDtypeArg | ExtensionDtype ) # DtypeArg specifies all allowable dtypes in a functions its dtype argument diff --git a/tests/test_frame.py b/tests/test_frame.py index 9be9b77e5..254cfc514 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -2442,3 +2442,5 @@ def test_astype() -> None: assert_type(ab.astype({"col1": "int32", "col2": str}), "pd.DataFrame"), pd.DataFrame, ) + check(assert_type(s.astype(pd.CategoricalDtype()), "pd.DataFrame"), pd.DataFrame) + check(assert_type(s.astype("category"), "pd.DataFrame"), pd.DataFrame) # GH 559 diff --git a/tests/test_series.py b/tests/test_series.py index 169fddb5d..63c2cd011 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1591,3 +1591,15 @@ def test_updated_astype() -> None: pd.Series, Decimal, ) + + # Categorical + check( + assert_type(s.astype(pd.CategoricalDtype()), "pd.Series[Any]"), + pd.Series, + np.integer, + ) + check( + assert_type(s.astype("category"), "pd.Series[Any]"), + pd.Series, + np.integer, + )