diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index d59220a1f97f8..31b61906cac53 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -22,6 +22,7 @@ import matplotlib as mpl import numpy as np +from pandas._libs import lib from pandas.errors import AbstractMethodError from pandas.util._decorators import cache_readonly from pandas.util._exceptions import find_stack_level @@ -1221,13 +1222,6 @@ def __init__(self, data, x, y, **kwargs) -> None: if is_integer(y) and not self.data.columns._holds_integer(): y = self.data.columns[y] - # Scatter plot allows to plot objects data - if self._kind == "hexbin": - if len(self.data[x]._get_numeric_data()) == 0: - raise ValueError(self._kind + " requires x column to be numeric") - if len(self.data[y]._get_numeric_data()) == 0: - raise ValueError(self._kind + " requires y column to be numeric") - self.x = x self.y = y @@ -1269,14 +1263,30 @@ class ScatterPlot(PlanePlot): def _kind(self) -> Literal["scatter"]: return "scatter" - def __init__(self, data, x, y, s=None, c=None, **kwargs) -> None: + def __init__( + self, + data, + x, + y, + s=None, + c=None, + *, + colorbar: bool | lib.NoDefault = lib.no_default, + norm=None, + **kwargs, + ) -> None: if s is None: # hide the matplotlib default for size, in case we want to change # the handling of this argument later s = 20 elif is_hashable(s) and s in data.columns: s = data[s] - super().__init__(data, x, y, s=s, **kwargs) + self.s = s + + self.colorbar = colorbar + self.norm = norm + + super().__init__(data, x, y, **kwargs) if is_integer(c) and not self.data.columns._holds_integer(): c = self.data.columns[c] self.c = c @@ -1292,6 +1302,44 @@ def _make_plot(self, fig: Figure): ) color = self.kwds.pop("color", None) + c_values = self._get_c_values(color, color_by_categorical, c_is_column) + norm, cmap = self._get_norm_and_cmap(c_values, color_by_categorical) + cb = self._get_colorbar(c_values, c_is_column) + + if self.legend: + label = self.label + else: + label = None + scatter = ax.scatter( + data[x].values, + data[y].values, + c=c_values, + label=label, + cmap=cmap, + norm=norm, + s=self.s, + **self.kwds, + ) + if cb: + cbar_label = c if c_is_column else "" + cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label) + if color_by_categorical: + n_cats = len(self.data[c].cat.categories) + cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats)) + cbar.ax.set_yticklabels(self.data[c].cat.categories) + + if label is not None: + self._append_legend_handles_labels(scatter, label) + + errors_x = self._get_errorbars(label=x, index=0, yerr=False) + errors_y = self._get_errorbars(label=y, index=0, xerr=False) + if len(errors_x) > 0 or len(errors_y) > 0: + err_kwds = dict(errors_x, **errors_y) + err_kwds["ecolor"] = scatter.get_facecolor()[0] + ax.errorbar(data[x].values, data[y].values, linestyle="none", **err_kwds) + + def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool): + c = self.c if c is not None and color is not None: raise TypeError("Specify exactly one of `c` and `color`") if c is None and color is None: @@ -1304,7 +1352,10 @@ def _make_plot(self, fig: Figure): c_values = self.data[c].values else: c_values = c + return c_values + def _get_norm_and_cmap(self, c_values, color_by_categorical: bool): + c = self.c if self.colormap is not None: cmap = mpl.colormaps.get_cmap(self.colormap) # cmap is only used if c_values are integers, otherwise UserWarning. @@ -1323,45 +1374,21 @@ def _make_plot(self, fig: Figure): cmap = colors.ListedColormap([cmap(i) for i in range(cmap.N)]) bounds = np.linspace(0, n_cats, n_cats + 1) norm = colors.BoundaryNorm(bounds, cmap.N) + # TODO: warn that we are ignoring self.norm if user specified it? + # Doesn't happen in any tests 2023-11-09 else: - norm = self.kwds.pop("norm", None) + norm = self.norm + return norm, cmap + + def _get_colorbar(self, c_values, c_is_column: bool) -> bool: # plot colorbar if # 1. colormap is assigned, and # 2.`c` is a column containing only numeric values plot_colorbar = self.colormap or c_is_column - cb = self.kwds.pop("colorbar", is_numeric_dtype(c_values) and plot_colorbar) - - if self.legend and hasattr(self, "label"): - label = self.label - else: - label = None - scatter = ax.scatter( - data[x].values, - data[y].values, - c=c_values, - label=label, - cmap=cmap, - norm=norm, - **self.kwds, - ) - if cb: - cbar_label = c if c_is_column else "" - cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label) - if color_by_categorical: - cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats)) - cbar.ax.set_yticklabels(self.data[c].cat.categories) - - if label is not None: - self._append_legend_handles_labels(scatter, label) - else: - self.legend = False - - errors_x = self._get_errorbars(label=x, index=0, yerr=False) - errors_y = self._get_errorbars(label=y, index=0, xerr=False) - if len(errors_x) > 0 or len(errors_y) > 0: - err_kwds = dict(errors_x, **errors_y) - err_kwds["ecolor"] = scatter.get_facecolor()[0] - ax.errorbar(data[x].values, data[y].values, linestyle="none", **err_kwds) + cb = self.colorbar + if cb is lib.no_default: + return is_numeric_dtype(c_values) and plot_colorbar + return cb class HexBinPlot(PlanePlot): @@ -1369,19 +1396,27 @@ class HexBinPlot(PlanePlot): def _kind(self) -> Literal["hexbin"]: return "hexbin" - def __init__(self, data, x, y, C=None, **kwargs) -> None: + def __init__(self, data, x, y, C=None, *, colorbar: bool = True, **kwargs) -> None: super().__init__(data, x, y, **kwargs) if is_integer(C) and not self.data.columns._holds_integer(): C = self.data.columns[C] self.C = C + self.colorbar = colorbar + + # Scatter plot allows to plot objects data + if len(self.data[self.x]._get_numeric_data()) == 0: + raise ValueError(self._kind + " requires x column to be numeric") + if len(self.data[self.y]._get_numeric_data()) == 0: + raise ValueError(self._kind + " requires y column to be numeric") + def _make_plot(self, fig: Figure) -> None: x, y, data, C = self.x, self.y, self.data, self.C ax = self.axes[0] # pandas uses colormap, matplotlib uses cmap. cmap = self.colormap or "BuGn" cmap = mpl.colormaps.get_cmap(cmap) - cb = self.kwds.pop("colorbar", True) + cb = self.colorbar if C is None: c_values = None