diff --git a/pandas-stubs/io/formats/style.pyi b/pandas-stubs/io/formats/style.pyi index aea4da64..cf0ec7f6 100644 --- a/pandas-stubs/io/formats/style.pyi +++ b/pandas-stubs/io/formats/style.pyi @@ -53,6 +53,11 @@ class _DataFrameFunc(Protocol): self, series: DataFrame, /, *args: Any, **kwargs: Any ) -> npt.NDArray | DataFrame: ... +class _MapCallable(Protocol): + def __call__( + self, first_arg: Scalar, /, *args: Any, **kwargs: Any + ) -> str | None: ... + class Styler(StylerRenderer): def __init__( self, @@ -71,6 +76,19 @@ class Styler(StylerRenderer): formatter: ExtFormatter | None = ..., ) -> None: ... def concat(self, other: Styler) -> Styler: ... + @overload + def map( + self, + func: Callable[[Scalar], str | None], + subset: Subset | None = ..., + ) -> Styler: ... + @overload + def map( + self, + func: _MapCallable, + subset: Subset | None = ..., + **kwargs: Any, + ) -> Styler: ... def set_tooltips( self, ttips: DataFrame, diff --git a/tests/test_styler.py b/tests/test_styler.py index 3293a173..0efff33c 100644 --- a/tests/test_styler.py +++ b/tests/test_styler.py @@ -30,7 +30,6 @@ DF = DataFrame({"a": [1, 2, 3], "b": [3.14, 2.72, 1.61]}) - PWD = pathlib.Path(os.path.split(os.path.abspath(__file__))[0]) if TYPE_CHECKING: @@ -233,3 +232,30 @@ def test_styler_columns_and_index() -> None: styler = DF.style check(assert_type(styler.columns, Index), Index) check(assert_type(styler.index, Index), Index) + + +def test_styler_map() -> None: + """Test type returned with Styler.map GH1226.""" + df = DataFrame(data={"col1": [1, -2], "col2": [-3, 4]}) + check( + assert_type( + df.style.map( + lambda v: "color: red;" if isinstance(v, float) and v < 0 else None + ), + Styler, + ), + Styler, + ) + + def color_negative(v: Scalar, /, color: str) -> str | None: + return f"color: {color};" if isinstance(v, float) and v < 0 else None + + df = DataFrame(np.random.randn(5, 2), columns=["A", "B"]) + + check( + assert_type( + df.style.map(color_negative, color="red"), + Styler, + ), + Styler, + )