22
22
import matplotlib as mpl
23
23
import numpy as np
24
24
25
+ from pandas ._libs import lib
25
26
from pandas .errors import AbstractMethodError
26
27
from pandas .util ._decorators import cache_readonly
27
28
from pandas .util ._exceptions import find_stack_level
@@ -1221,13 +1222,6 @@ def __init__(self, data, x, y, **kwargs) -> None:
1221
1222
if is_integer (y ) and not self .data .columns ._holds_integer ():
1222
1223
y = self .data .columns [y ]
1223
1224
1224
- # Scatter plot allows to plot objects data
1225
- if self ._kind == "hexbin" :
1226
- if len (self .data [x ]._get_numeric_data ()) == 0 :
1227
- raise ValueError (self ._kind + " requires x column to be numeric" )
1228
- if len (self .data [y ]._get_numeric_data ()) == 0 :
1229
- raise ValueError (self ._kind + " requires y column to be numeric" )
1230
-
1231
1225
self .x = x
1232
1226
self .y = y
1233
1227
@@ -1269,14 +1263,30 @@ class ScatterPlot(PlanePlot):
1269
1263
def _kind (self ) -> Literal ["scatter" ]:
1270
1264
return "scatter"
1271
1265
1272
- def __init__ (self , data , x , y , s = None , c = None , ** kwargs ) -> None :
1266
+ def __init__ (
1267
+ self ,
1268
+ data ,
1269
+ x ,
1270
+ y ,
1271
+ s = None ,
1272
+ c = None ,
1273
+ * ,
1274
+ colorbar : bool | lib .NoDefault = lib .no_default ,
1275
+ norm = None ,
1276
+ ** kwargs ,
1277
+ ) -> None :
1273
1278
if s is None :
1274
1279
# hide the matplotlib default for size, in case we want to change
1275
1280
# the handling of this argument later
1276
1281
s = 20
1277
1282
elif is_hashable (s ) and s in data .columns :
1278
1283
s = data [s ]
1279
- super ().__init__ (data , x , y , s = s , ** kwargs )
1284
+ self .s = s
1285
+
1286
+ self .colorbar = colorbar
1287
+ self .norm = norm
1288
+
1289
+ super ().__init__ (data , x , y , ** kwargs )
1280
1290
if is_integer (c ) and not self .data .columns ._holds_integer ():
1281
1291
c = self .data .columns [c ]
1282
1292
self .c = c
@@ -1292,6 +1302,44 @@ def _make_plot(self, fig: Figure):
1292
1302
)
1293
1303
1294
1304
color = self .kwds .pop ("color" , None )
1305
+ c_values = self ._get_c_values (color , color_by_categorical , c_is_column )
1306
+ norm , cmap = self ._get_norm_and_cmap (c_values , color_by_categorical )
1307
+ cb = self ._get_colorbar (c_values , c_is_column )
1308
+
1309
+ if self .legend :
1310
+ label = self .label
1311
+ else :
1312
+ label = None
1313
+ scatter = ax .scatter (
1314
+ data [x ].values ,
1315
+ data [y ].values ,
1316
+ c = c_values ,
1317
+ label = label ,
1318
+ cmap = cmap ,
1319
+ norm = norm ,
1320
+ s = self .s ,
1321
+ ** self .kwds ,
1322
+ )
1323
+ if cb :
1324
+ cbar_label = c if c_is_column else ""
1325
+ cbar = self ._plot_colorbar (ax , fig = fig , label = cbar_label )
1326
+ if color_by_categorical :
1327
+ n_cats = len (self .data [c ].cat .categories )
1328
+ cbar .set_ticks (np .linspace (0.5 , n_cats - 0.5 , n_cats ))
1329
+ cbar .ax .set_yticklabels (self .data [c ].cat .categories )
1330
+
1331
+ if label is not None :
1332
+ self ._append_legend_handles_labels (scatter , label )
1333
+
1334
+ errors_x = self ._get_errorbars (label = x , index = 0 , yerr = False )
1335
+ errors_y = self ._get_errorbars (label = y , index = 0 , xerr = False )
1336
+ if len (errors_x ) > 0 or len (errors_y ) > 0 :
1337
+ err_kwds = dict (errors_x , ** errors_y )
1338
+ err_kwds ["ecolor" ] = scatter .get_facecolor ()[0 ]
1339
+ ax .errorbar (data [x ].values , data [y ].values , linestyle = "none" , ** err_kwds )
1340
+
1341
+ def _get_c_values (self , color , color_by_categorical : bool , c_is_column : bool ):
1342
+ c = self .c
1295
1343
if c is not None and color is not None :
1296
1344
raise TypeError ("Specify exactly one of `c` and `color`" )
1297
1345
if c is None and color is None :
@@ -1304,7 +1352,10 @@ def _make_plot(self, fig: Figure):
1304
1352
c_values = self .data [c ].values
1305
1353
else :
1306
1354
c_values = c
1355
+ return c_values
1307
1356
1357
+ def _get_norm_and_cmap (self , c_values , color_by_categorical : bool ):
1358
+ c = self .c
1308
1359
if self .colormap is not None :
1309
1360
cmap = mpl .colormaps .get_cmap (self .colormap )
1310
1361
# cmap is only used if c_values are integers, otherwise UserWarning.
@@ -1323,65 +1374,49 @@ def _make_plot(self, fig: Figure):
1323
1374
cmap = colors .ListedColormap ([cmap (i ) for i in range (cmap .N )])
1324
1375
bounds = np .linspace (0 , n_cats , n_cats + 1 )
1325
1376
norm = colors .BoundaryNorm (bounds , cmap .N )
1377
+ # TODO: warn that we are ignoring self.norm if user specified it?
1378
+ # Doesn't happen in any tests 2023-11-09
1326
1379
else :
1327
- norm = self .kwds .pop ("norm" , None )
1380
+ norm = self .norm
1381
+ return norm , cmap
1382
+
1383
+ def _get_colorbar (self , c_values , c_is_column : bool ) -> bool :
1328
1384
# plot colorbar if
1329
1385
# 1. colormap is assigned, and
1330
1386
# 2.`c` is a column containing only numeric values
1331
1387
plot_colorbar = self .colormap or c_is_column
1332
- cb = self .kwds .pop ("colorbar" , is_numeric_dtype (c_values ) and plot_colorbar )
1333
-
1334
- if self .legend and hasattr (self , "label" ):
1335
- label = self .label
1336
- else :
1337
- label = None
1338
- scatter = ax .scatter (
1339
- data [x ].values ,
1340
- data [y ].values ,
1341
- c = c_values ,
1342
- label = label ,
1343
- cmap = cmap ,
1344
- norm = norm ,
1345
- ** self .kwds ,
1346
- )
1347
- if cb :
1348
- cbar_label = c if c_is_column else ""
1349
- cbar = self ._plot_colorbar (ax , fig = fig , label = cbar_label )
1350
- if color_by_categorical :
1351
- cbar .set_ticks (np .linspace (0.5 , n_cats - 0.5 , n_cats ))
1352
- cbar .ax .set_yticklabels (self .data [c ].cat .categories )
1353
-
1354
- if label is not None :
1355
- self ._append_legend_handles_labels (scatter , label )
1356
- else :
1357
- self .legend = False
1358
-
1359
- errors_x = self ._get_errorbars (label = x , index = 0 , yerr = False )
1360
- errors_y = self ._get_errorbars (label = y , index = 0 , xerr = False )
1361
- if len (errors_x ) > 0 or len (errors_y ) > 0 :
1362
- err_kwds = dict (errors_x , ** errors_y )
1363
- err_kwds ["ecolor" ] = scatter .get_facecolor ()[0 ]
1364
- ax .errorbar (data [x ].values , data [y ].values , linestyle = "none" , ** err_kwds )
1388
+ cb = self .colorbar
1389
+ if cb is lib .no_default :
1390
+ return is_numeric_dtype (c_values ) and plot_colorbar
1391
+ return cb
1365
1392
1366
1393
1367
1394
class HexBinPlot (PlanePlot ):
1368
1395
@property
1369
1396
def _kind (self ) -> Literal ["hexbin" ]:
1370
1397
return "hexbin"
1371
1398
1372
- def __init__ (self , data , x , y , C = None , ** kwargs ) -> None :
1399
+ def __init__ (self , data , x , y , C = None , * , colorbar : bool = True , * *kwargs ) -> None :
1373
1400
super ().__init__ (data , x , y , ** kwargs )
1374
1401
if is_integer (C ) and not self .data .columns ._holds_integer ():
1375
1402
C = self .data .columns [C ]
1376
1403
self .C = C
1377
1404
1405
+ self .colorbar = colorbar
1406
+
1407
+ # Scatter plot allows to plot objects data
1408
+ if len (self .data [self .x ]._get_numeric_data ()) == 0 :
1409
+ raise ValueError (self ._kind + " requires x column to be numeric" )
1410
+ if len (self .data [self .y ]._get_numeric_data ()) == 0 :
1411
+ raise ValueError (self ._kind + " requires y column to be numeric" )
1412
+
1378
1413
def _make_plot (self , fig : Figure ) -> None :
1379
1414
x , y , data , C = self .x , self .y , self .data , self .C
1380
1415
ax = self .axes [0 ]
1381
1416
# pandas uses colormap, matplotlib uses cmap.
1382
1417
cmap = self .colormap or "BuGn"
1383
1418
cmap = mpl .colormaps .get_cmap (cmap )
1384
- cb = self .kwds . pop ( " colorbar" , True )
1419
+ cb = self .colorbar
1385
1420
1386
1421
if C is None :
1387
1422
c_values = None
0 commit comments