Skip to content

REF: less state in scatterplot #55917

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 80 additions & 45 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import matplotlib as mpl
import numpy as np

from pandas._libs import lib
from pandas.errors import AbstractMethodError
from pandas.util._decorators import cache_readonly
from pandas.util._exceptions import find_stack_level
Expand Down Expand Up @@ -1221,13 +1222,6 @@ def __init__(self, data, x, y, **kwargs) -> None:
if is_integer(y) and not self.data.columns._holds_integer():
y = self.data.columns[y]

# Scatter plot allows to plot objects data
if self._kind == "hexbin":
if len(self.data[x]._get_numeric_data()) == 0:
raise ValueError(self._kind + " requires x column to be numeric")
if len(self.data[y]._get_numeric_data()) == 0:
raise ValueError(self._kind + " requires y column to be numeric")

self.x = x
self.y = y

Expand Down Expand Up @@ -1269,14 +1263,30 @@ class ScatterPlot(PlanePlot):
def _kind(self) -> Literal["scatter"]:
return "scatter"

def __init__(self, data, x, y, s=None, c=None, **kwargs) -> None:
def __init__(
self,
data,
x,
y,
s=None,
c=None,
*,
colorbar: bool | lib.NoDefault = lib.no_default,
norm=None,
**kwargs,
) -> None:
if s is None:
# hide the matplotlib default for size, in case we want to change
# the handling of this argument later
s = 20
elif is_hashable(s) and s in data.columns:
s = data[s]
super().__init__(data, x, y, s=s, **kwargs)
self.s = s

self.colorbar = colorbar
self.norm = norm

super().__init__(data, x, y, **kwargs)
if is_integer(c) and not self.data.columns._holds_integer():
c = self.data.columns[c]
self.c = c
Expand All @@ -1292,6 +1302,44 @@ def _make_plot(self, fig: Figure):
)

color = self.kwds.pop("color", None)
c_values = self._get_c_values(color, color_by_categorical, c_is_column)
norm, cmap = self._get_norm_and_cmap(c_values, color_by_categorical)
cb = self._get_colorbar(c_values, c_is_column)

if self.legend:
label = self.label
else:
label = None
scatter = ax.scatter(
data[x].values,
data[y].values,
c=c_values,
label=label,
cmap=cmap,
norm=norm,
s=self.s,
**self.kwds,
)
if cb:
cbar_label = c if c_is_column else ""
cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label)
if color_by_categorical:
n_cats = len(self.data[c].cat.categories)
cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats))
cbar.ax.set_yticklabels(self.data[c].cat.categories)

if label is not None:
self._append_legend_handles_labels(scatter, label)

errors_x = self._get_errorbars(label=x, index=0, yerr=False)
errors_y = self._get_errorbars(label=y, index=0, xerr=False)
if len(errors_x) > 0 or len(errors_y) > 0:
err_kwds = dict(errors_x, **errors_y)
err_kwds["ecolor"] = scatter.get_facecolor()[0]
ax.errorbar(data[x].values, data[y].values, linestyle="none", **err_kwds)

def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool):
c = self.c
if c is not None and color is not None:
raise TypeError("Specify exactly one of `c` and `color`")
if c is None and color is None:
Expand All @@ -1304,7 +1352,10 @@ def _make_plot(self, fig: Figure):
c_values = self.data[c].values
else:
c_values = c
return c_values

def _get_norm_and_cmap(self, c_values, color_by_categorical: bool):
c = self.c
if self.colormap is not None:
cmap = mpl.colormaps.get_cmap(self.colormap)
# cmap is only used if c_values are integers, otherwise UserWarning.
Expand All @@ -1323,65 +1374,49 @@ def _make_plot(self, fig: Figure):
cmap = colors.ListedColormap([cmap(i) for i in range(cmap.N)])
bounds = np.linspace(0, n_cats, n_cats + 1)
norm = colors.BoundaryNorm(bounds, cmap.N)
# TODO: warn that we are ignoring self.norm if user specified it?
# Doesn't happen in any tests 2023-11-09
else:
norm = self.kwds.pop("norm", None)
norm = self.norm
return norm, cmap

def _get_colorbar(self, c_values, c_is_column: bool) -> bool:
# unless `colorbar` was passed explicitly, plot a colorbar if
# 1. the `c` values are numeric, and
# 2. a colormap is assigned or `c` is a column
plot_colorbar = self.colormap or c_is_column
cb = self.kwds.pop("colorbar", is_numeric_dtype(c_values) and plot_colorbar)

if self.legend and hasattr(self, "label"):
label = self.label
else:
label = None
scatter = ax.scatter(
data[x].values,
data[y].values,
c=c_values,
label=label,
cmap=cmap,
norm=norm,
**self.kwds,
)
if cb:
cbar_label = c if c_is_column else ""
cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label)
if color_by_categorical:
cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats))
cbar.ax.set_yticklabels(self.data[c].cat.categories)

if label is not None:
self._append_legend_handles_labels(scatter, label)
else:
self.legend = False

errors_x = self._get_errorbars(label=x, index=0, yerr=False)
errors_y = self._get_errorbars(label=y, index=0, xerr=False)
if len(errors_x) > 0 or len(errors_y) > 0:
err_kwds = dict(errors_x, **errors_y)
err_kwds["ecolor"] = scatter.get_facecolor()[0]
ax.errorbar(data[x].values, data[y].values, linestyle="none", **err_kwds)
cb = self.colorbar
if cb is lib.no_default:
return is_numeric_dtype(c_values) and plot_colorbar
return cb


class HexBinPlot(PlanePlot):
@property
def _kind(self) -> Literal["hexbin"]:
return "hexbin"

def __init__(self, data, x, y, C=None, **kwargs) -> None:
def __init__(self, data, x, y, C=None, *, colorbar: bool = True, **kwargs) -> None:
super().__init__(data, x, y, **kwargs)
if is_integer(C) and not self.data.columns._holds_integer():
C = self.data.columns[C]
self.C = C

self.colorbar = colorbar

# Scatter plots allow plotting object data, but hexbin requires numeric x and y
if len(self.data[self.x]._get_numeric_data()) == 0:
raise ValueError(self._kind + " requires x column to be numeric")
if len(self.data[self.y]._get_numeric_data()) == 0:
raise ValueError(self._kind + " requires y column to be numeric")

def _make_plot(self, fig: Figure) -> None:
x, y, data, C = self.x, self.y, self.data, self.C
ax = self.axes[0]
# pandas uses colormap, matplotlib uses cmap.
cmap = self.colormap or "BuGn"
cmap = mpl.colormaps.get_cmap(cmap)
cb = self.kwds.pop("colorbar", True)
cb = self.colorbar

if C is None:
c_values = None
Expand Down