Skip to content

Commit d9c7b50

Browse files
authored
REF: less state in scatterplot (#55917)
* REF: less state in scatterplot
* REF: helper
* REF: helper
* less state
* REF: less state
1 parent d650212 commit d9c7b50

File tree

1 file changed: +80 -45 lines changed

pandas/plotting/_matplotlib/core.py

Lines changed: 80 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
import matplotlib as mpl
2323
import numpy as np
2424

25+
from pandas._libs import lib
2526
from pandas.errors import AbstractMethodError
2627
from pandas.util._decorators import cache_readonly
2728
from pandas.util._exceptions import find_stack_level
@@ -1221,13 +1222,6 @@ def __init__(self, data, x, y, **kwargs) -> None:
12211222
if is_integer(y) and not self.data.columns._holds_integer():
12221223
y = self.data.columns[y]
12231224

1224-
# Scatter plot allows to plot objects data
1225-
if self._kind == "hexbin":
1226-
if len(self.data[x]._get_numeric_data()) == 0:
1227-
raise ValueError(self._kind + " requires x column to be numeric")
1228-
if len(self.data[y]._get_numeric_data()) == 0:
1229-
raise ValueError(self._kind + " requires y column to be numeric")
1230-
12311225
self.x = x
12321226
self.y = y
12331227

@@ -1269,14 +1263,30 @@ class ScatterPlot(PlanePlot):
12691263
def _kind(self) -> Literal["scatter"]:
12701264
return "scatter"
12711265

1272-
def __init__(self, data, x, y, s=None, c=None, **kwargs) -> None:
1266+
def __init__(
1267+
self,
1268+
data,
1269+
x,
1270+
y,
1271+
s=None,
1272+
c=None,
1273+
*,
1274+
colorbar: bool | lib.NoDefault = lib.no_default,
1275+
norm=None,
1276+
**kwargs,
1277+
) -> None:
12731278
if s is None:
12741279
# hide the matplotlib default for size, in case we want to change
12751280
# the handling of this argument later
12761281
s = 20
12771282
elif is_hashable(s) and s in data.columns:
12781283
s = data[s]
1279-
super().__init__(data, x, y, s=s, **kwargs)
1284+
self.s = s
1285+
1286+
self.colorbar = colorbar
1287+
self.norm = norm
1288+
1289+
super().__init__(data, x, y, **kwargs)
12801290
if is_integer(c) and not self.data.columns._holds_integer():
12811291
c = self.data.columns[c]
12821292
self.c = c
@@ -1292,6 +1302,44 @@ def _make_plot(self, fig: Figure):
12921302
)
12931303

12941304
color = self.kwds.pop("color", None)
1305+
c_values = self._get_c_values(color, color_by_categorical, c_is_column)
1306+
norm, cmap = self._get_norm_and_cmap(c_values, color_by_categorical)
1307+
cb = self._get_colorbar(c_values, c_is_column)
1308+
1309+
if self.legend:
1310+
label = self.label
1311+
else:
1312+
label = None
1313+
scatter = ax.scatter(
1314+
data[x].values,
1315+
data[y].values,
1316+
c=c_values,
1317+
label=label,
1318+
cmap=cmap,
1319+
norm=norm,
1320+
s=self.s,
1321+
**self.kwds,
1322+
)
1323+
if cb:
1324+
cbar_label = c if c_is_column else ""
1325+
cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label)
1326+
if color_by_categorical:
1327+
n_cats = len(self.data[c].cat.categories)
1328+
cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats))
1329+
cbar.ax.set_yticklabels(self.data[c].cat.categories)
1330+
1331+
if label is not None:
1332+
self._append_legend_handles_labels(scatter, label)
1333+
1334+
errors_x = self._get_errorbars(label=x, index=0, yerr=False)
1335+
errors_y = self._get_errorbars(label=y, index=0, xerr=False)
1336+
if len(errors_x) > 0 or len(errors_y) > 0:
1337+
err_kwds = dict(errors_x, **errors_y)
1338+
err_kwds["ecolor"] = scatter.get_facecolor()[0]
1339+
ax.errorbar(data[x].values, data[y].values, linestyle="none", **err_kwds)
1340+
1341+
def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool):
1342+
c = self.c
12951343
if c is not None and color is not None:
12961344
raise TypeError("Specify exactly one of `c` and `color`")
12971345
if c is None and color is None:
@@ -1304,7 +1352,10 @@ def _make_plot(self, fig: Figure):
13041352
c_values = self.data[c].values
13051353
else:
13061354
c_values = c
1355+
return c_values
13071356

1357+
def _get_norm_and_cmap(self, c_values, color_by_categorical: bool):
1358+
c = self.c
13081359
if self.colormap is not None:
13091360
cmap = mpl.colormaps.get_cmap(self.colormap)
13101361
# cmap is only used if c_values are integers, otherwise UserWarning.
@@ -1323,65 +1374,49 @@ def _make_plot(self, fig: Figure):
13231374
cmap = colors.ListedColormap([cmap(i) for i in range(cmap.N)])
13241375
bounds = np.linspace(0, n_cats, n_cats + 1)
13251376
norm = colors.BoundaryNorm(bounds, cmap.N)
1377+
# TODO: warn that we are ignoring self.norm if user specified it?
1378+
# Doesn't happen in any tests 2023-11-09
13261379
else:
1327-
norm = self.kwds.pop("norm", None)
1380+
norm = self.norm
1381+
return norm, cmap
1382+
1383+
def _get_colorbar(self, c_values, c_is_column: bool) -> bool:
13281384
# plot colorbar if
13291385
# 1. colormap is assigned, and
13301386
# 2.`c` is a column containing only numeric values
13311387
plot_colorbar = self.colormap or c_is_column
1332-
cb = self.kwds.pop("colorbar", is_numeric_dtype(c_values) and plot_colorbar)
1333-
1334-
if self.legend and hasattr(self, "label"):
1335-
label = self.label
1336-
else:
1337-
label = None
1338-
scatter = ax.scatter(
1339-
data[x].values,
1340-
data[y].values,
1341-
c=c_values,
1342-
label=label,
1343-
cmap=cmap,
1344-
norm=norm,
1345-
**self.kwds,
1346-
)
1347-
if cb:
1348-
cbar_label = c if c_is_column else ""
1349-
cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label)
1350-
if color_by_categorical:
1351-
cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats))
1352-
cbar.ax.set_yticklabels(self.data[c].cat.categories)
1353-
1354-
if label is not None:
1355-
self._append_legend_handles_labels(scatter, label)
1356-
else:
1357-
self.legend = False
1358-
1359-
errors_x = self._get_errorbars(label=x, index=0, yerr=False)
1360-
errors_y = self._get_errorbars(label=y, index=0, xerr=False)
1361-
if len(errors_x) > 0 or len(errors_y) > 0:
1362-
err_kwds = dict(errors_x, **errors_y)
1363-
err_kwds["ecolor"] = scatter.get_facecolor()[0]
1364-
ax.errorbar(data[x].values, data[y].values, linestyle="none", **err_kwds)
1388+
cb = self.colorbar
1389+
if cb is lib.no_default:
1390+
return is_numeric_dtype(c_values) and plot_colorbar
1391+
return cb
13651392

13661393

13671394
class HexBinPlot(PlanePlot):
13681395
@property
13691396
def _kind(self) -> Literal["hexbin"]:
13701397
return "hexbin"
13711398

1372-
def __init__(self, data, x, y, C=None, **kwargs) -> None:
1399+
def __init__(self, data, x, y, C=None, *, colorbar: bool = True, **kwargs) -> None:
13731400
super().__init__(data, x, y, **kwargs)
13741401
if is_integer(C) and not self.data.columns._holds_integer():
13751402
C = self.data.columns[C]
13761403
self.C = C
13771404

1405+
self.colorbar = colorbar
1406+
1407+
# Scatter plot allows to plot objects data
1408+
if len(self.data[self.x]._get_numeric_data()) == 0:
1409+
raise ValueError(self._kind + " requires x column to be numeric")
1410+
if len(self.data[self.y]._get_numeric_data()) == 0:
1411+
raise ValueError(self._kind + " requires y column to be numeric")
1412+
13781413
def _make_plot(self, fig: Figure) -> None:
13791414
x, y, data, C = self.x, self.y, self.data, self.C
13801415
ax = self.axes[0]
13811416
# pandas uses colormap, matplotlib uses cmap.
13821417
cmap = self.colormap or "BuGn"
13831418
cmap = mpl.colormaps.get_cmap(cmap)
1384-
cb = self.kwds.pop("colorbar", True)
1419+
cb = self.colorbar
13851420

13861421
if C is None:
13871422
c_values = None

0 commit comments

Comments
 (0)