Skip to content

REF: avoid statefulness in 'color' args #55904

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 36 additions & 42 deletions pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from matplotlib.artist import setp
import numpy as np

from pandas._libs import lib
from pandas.util._decorators import cache_readonly
from pandas.util._exceptions import find_stack_level

Expand Down Expand Up @@ -113,26 +114,26 @@ def _plot( # type: ignore[override]
else:
return ax, bp

def _validate_color_args(self):
if "color" in self.kwds:
if self.colormap is not None:
warnings.warn(
"'color' and 'colormap' cannot be used "
"simultaneously. Using 'color'",
stacklevel=find_stack_level(),
)
self.color = self.kwds.pop("color")
def _validate_color_args(self, color, colormap):
if color is lib.no_default:
return None

if isinstance(self.color, dict):
valid_keys = ["boxes", "whiskers", "medians", "caps"]
for key in self.color:
if key not in valid_keys:
raise ValueError(
f"color dict contains invalid key '{key}'. "
f"The key must be either {valid_keys}"
)
else:
self.color = None
if colormap is not None:
warnings.warn(
"'color' and 'colormap' cannot be used "
"simultaneously. Using 'color'",
stacklevel=find_stack_level(),
)

if isinstance(color, dict):
valid_keys = ["boxes", "whiskers", "medians", "caps"]
for key in color:
if key not in valid_keys:
raise ValueError(
f"color dict contains invalid key '{key}'. "
f"The key must be either {valid_keys}"
)
return color

@cache_readonly
def _color_attrs(self):
Expand Down Expand Up @@ -182,16 +183,8 @@ def maybe_color_bp(self, bp) -> None:
medians = self.color or self._medians_c
caps = self.color or self._caps_c

# GH 30346, when users specifying those arguments explicitly, our defaults
# for these four kwargs should be overridden; if not, use Pandas settings
if not self.kwds.get("boxprops"):
setp(bp["boxes"], color=boxes, alpha=1)
if not self.kwds.get("whiskerprops"):
setp(bp["whiskers"], color=whiskers, alpha=1)
if not self.kwds.get("medianprops"):
setp(bp["medians"], color=medians, alpha=1)
if not self.kwds.get("capprops"):
setp(bp["caps"], color=caps, alpha=1)
color_tup = (boxes, whiskers, medians, caps)
maybe_color_bp(bp, color_tup=color_tup, **self.kwds)

def _make_plot(self, fig: Figure) -> None:
if self.subplots:
Expand Down Expand Up @@ -276,6 +269,19 @@ def result(self):
return self._return_obj


def maybe_color_bp(bp, color_tup, **kwds) -> None:
# GH#30346, when users specifying those arguments explicitly, our defaults
# for these four kwargs should be overridden; if not, use Pandas settings
if not kwds.get("boxprops"):
setp(bp["boxes"], color=color_tup[0], alpha=1)
if not kwds.get("whiskerprops"):
setp(bp["whiskers"], color=color_tup[1], alpha=1)
if not kwds.get("medianprops"):
setp(bp["medians"], color=color_tup[2], alpha=1)
if not kwds.get("capprops"):
setp(bp["caps"], color=color_tup[3], alpha=1)


def _grouped_plot_by_column(
plotf,
data,
Expand Down Expand Up @@ -389,18 +395,6 @@ def _get_colors():

return result

def maybe_color_bp(bp, **kwds) -> None:
# GH 30346, when users specifying those arguments explicitly, our defaults
# for these four kwargs should be overridden; if not, use Pandas settings
if not kwds.get("boxprops"):
setp(bp["boxes"], color=colors[0], alpha=1)
if not kwds.get("whiskerprops"):
setp(bp["whiskers"], color=colors[1], alpha=1)
if not kwds.get("medianprops"):
setp(bp["medians"], color=colors[2], alpha=1)
if not kwds.get("capprops"):
setp(bp["caps"], color=colors[3], alpha=1)

def plot_group(keys, values, ax: Axes, **kwds):
# GH 45465: xlabel/ylabel need to be popped out before plotting happens
xlabel, ylabel = kwds.pop("xlabel", None), kwds.pop("ylabel", None)
Expand All @@ -419,7 +413,7 @@ def plot_group(keys, values, ax: Axes, **kwds):
_set_ticklabels(
ax=ax, labels=keys, is_vertical=kwds.get("vert", True), rotation=rot
)
maybe_color_bp(bp, **kwds)
maybe_color_bp(bp, color_tup=colors, **kwds)

# Return axes in multiplot case, maybe revisit later # 985
if return_type == "dict":
Expand Down
57 changes: 32 additions & 25 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,10 @@ def __init__(

self.kwds = kwds

self._validate_color_args()
color = kwds.pop("color", lib.no_default)
self.color = self._validate_color_args(color, self.colormap)
assert "color" not in self.kwds

self.data = self._ensure_frame(self.data)

@final
Expand Down Expand Up @@ -396,34 +399,31 @@ def _validate_subplots_kwarg(
out.append((idx_loc,))
return out

def _validate_color_args(self):
if (
"color" in self.kwds
and self.nseries == 1
and self.kwds["color"] is not None
and not is_list_like(self.kwds["color"])
):
def _validate_color_args(self, color, colormap):
if color is lib.no_default:
# It was not provided by the user
if "colors" in self.kwds and colormap is not None:
warnings.warn(
"'color' and 'colormap' cannot be used simultaneously. "
"Using 'color'",
stacklevel=find_stack_level(),
)
return None
if self.nseries == 1 and color is not None and not is_list_like(color):
# support series.plot(color='green')
self.kwds["color"] = [self.kwds["color"]]
color = [color]

if (
"color" in self.kwds
and isinstance(self.kwds["color"], tuple)
and self.nseries == 1
and len(self.kwds["color"]) in (3, 4)
):
if isinstance(color, tuple) and self.nseries == 1 and len(color) in (3, 4):
# support RGB and RGBA tuples in series plot
self.kwds["color"] = [self.kwds["color"]]
color = [color]

if (
"color" in self.kwds or "colors" in self.kwds
) and self.colormap is not None:
if colormap is not None:
warnings.warn(
"'color' and 'colormap' cannot be used simultaneously. Using 'color'",
stacklevel=find_stack_level(),
)

if "color" in self.kwds and self.style is not None:
if self.style is not None:
if is_list_like(self.style):
styles = self.style
else:
Expand All @@ -436,6 +436,7 @@ def _validate_color_args(self):
"'color' keyword argument. Please use one or the "
"other or pass 'style' without a color symbol"
)
return color

@final
@staticmethod
Expand Down Expand Up @@ -1058,11 +1059,14 @@ def _get_colors(
):
if num_colors is None:
num_colors = self.nseries

if color_kwds == "color":
color = self.color
else:
color = self.kwds.get(color_kwds)
return get_standard_colors(
num_colors=num_colors,
colormap=self.colormap,
color=self.kwds.get(color_kwds),
color=color,
)

# TODO: tighter typing for first return?
Expand Down Expand Up @@ -1302,7 +1306,7 @@ def _make_plot(self, fig: Figure):
self.data[c].dtype, CategoricalDtype
)

color = self.kwds.pop("color", None)
color = self.color
c_values = self._get_c_values(color, color_by_categorical, c_is_column)
norm, cmap = self._get_norm_and_cmap(c_values, color_by_categorical)
cb = self._get_colorbar(c_values, c_is_column)
Expand Down Expand Up @@ -1487,6 +1491,8 @@ def _make_plot(self, fig: Figure) -> None:
for i, (label, y) in enumerate(it):
ax = self._get_ax(i)
kwds = self.kwds.copy()
if self.color is not None:
kwds["color"] = self.color
style, kwds = self._apply_style_colors(
colors,
kwds,
Expand Down Expand Up @@ -1998,8 +2004,9 @@ def __init__(self, data, kind=None, **kwargs) -> None:
self.logx = False
self.loglog = False

def _validate_color_args(self) -> None:
pass
def _validate_color_args(self, color, colormap) -> None:
# TODO: warn if color is passed and ignored?
return None

def _make_plot(self, fig: Figure) -> None:
colors = self._get_colors(num_colors=len(self.data), color_kwds="colors")
Expand Down
2 changes: 2 additions & 0 deletions pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def _make_plot(self, fig: Figure) -> None:
ax = self._get_ax(i)

kwds = self.kwds.copy()
if self.color is not None:
kwds["color"] = self.color

label = pprint_thing(label)
label = self._mark_right_label(label, index=i)
Expand Down