Skip to content

TYP: plotting, make weights kwd explicit #55877

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, data, return_type: str = "axes", **kwargs) -> None:
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
@classmethod
def _plot( # type: ignore[override]
cls, ax, y, column_num=None, return_type: str = "axes", **kwds
cls, ax: Axes, y, column_num=None, return_type: str = "axes", **kwds
):
if y.ndim == 2:
y = [remove_na_arraylike(v) for v in y]
Expand Down
5 changes: 4 additions & 1 deletion pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from typing import (
TYPE_CHECKING,
Any,
Literal,
final,
)
Expand Down Expand Up @@ -998,7 +999,9 @@ def on_right(self, i: int):
return self.data.columns[i] in self.secondary_y

@final
def _apply_style_colors(self, colors, kwds, col_num, label: str):
def _apply_style_colors(
self, colors, kwds: dict[str, Any], col_num: int, label: str
):
"""
Manage style and color based on column number and its label.
Returns tuple of appropriate style and kwds which "color" may be added.
Expand Down
59 changes: 35 additions & 24 deletions pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from typing import (
TYPE_CHECKING,
Any,
Literal,
final,
)

import numpy as np
Expand Down Expand Up @@ -58,13 +60,15 @@ def __init__(
bottom: int | np.ndarray = 0,
*,
range=None,
weights=None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not available publicly correct i.e. doesn't need a docstring update?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be passed by the user; im not inclined to update the docstring since i dont know what is accepted as it is passed through to matplotlib

**kwargs,
) -> None:
if is_list_like(bottom):
bottom = np.array(bottom)
self.bottom = bottom

self._bin_range = range
self.weights = weights

self.xlabel = kwargs.get("xlabel")
self.ylabel = kwargs.get("ylabel")
Expand Down Expand Up @@ -96,7 +100,7 @@ def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray:
@classmethod
def _plot( # type: ignore[override]
cls,
ax,
ax: Axes,
y,
style=None,
bottom: int | np.ndarray = 0,
Expand Down Expand Up @@ -140,7 +144,7 @@ def _make_plot(self, fig: Figure) -> None:
if style is not None:
kwds["style"] = style

kwds = self._make_plot_keywords(kwds, y)
self._make_plot_keywords(kwds, y)

# the bins is multi-dimension array now and each plot need only 1-d and
# when by is applied, label should be columns that are grouped
Expand All @@ -149,21 +153,8 @@ def _make_plot(self, fig: Figure) -> None:
kwds["label"] = self.columns
kwds.pop("color")

# We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
# and each sub-array (10,) will be called in each iteration. If users only
# provide 1D array, we assume the same weights is used for all iterations
weights = kwds.get("weights", None)
if weights is not None:
if np.ndim(weights) != 1 and np.shape(weights)[-1] != 1:
try:
weights = weights[:, i]
except IndexError as err:
raise ValueError(
"weights must have the same shape as data, "
"or be a single column"
) from err
weights = weights[~isna(y)]
kwds["weights"] = weights
if self.weights is not None:
kwds["weights"] = self._get_column_weights(self.weights, i, y)

y = reformat_hist_y_given_by(y, self.by)

Expand All @@ -175,12 +166,29 @@ def _make_plot(self, fig: Figure) -> None:

self._append_legend_handles_labels(artists[0], label)

def _make_plot_keywords(self, kwds, y):
def _make_plot_keywords(self, kwds: dict[str, Any], y) -> None:
"""merge BoxPlot/KdePlot properties to passed kwds"""
# y is required for KdePlot
kwds["bottom"] = self.bottom
kwds["bins"] = self.bins
return kwds

@final
@staticmethod
def _get_column_weights(weights, i: int, y):
# We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
# and each sub-array (10,) will be called in each iteration. If users only
# provide 1D array, we assume the same weights is used for all iterations
if weights is not None:
if np.ndim(weights) != 1 and np.shape(weights)[-1] != 1:
try:
weights = weights[:, i]
except IndexError as err:
raise ValueError(
"weights must have the same shape as data, "
"or be a single column"
) from err
weights = weights[~isna(y)]
return weights

def _post_plot_logic(self, ax: Axes, data) -> None:
if self.orientation == "horizontal":
Expand All @@ -207,11 +215,14 @@ def _kind(self) -> Literal["kde"]:
def orientation(self) -> Literal["vertical"]:
return "vertical"

def __init__(self, data, bw_method=None, ind=None, **kwargs) -> None:
def __init__(
self, data, bw_method=None, ind=None, *, weights=None, **kwargs
) -> None:
# Do not call LinePlot.__init__ which may fill nan
MPLPlot.__init__(self, data, **kwargs) # pylint: disable=non-parent-init-called
self.bw_method = bw_method
self.ind = ind
self.weights = weights

@staticmethod
def _get_ind(y, ind):
Expand All @@ -233,9 +244,10 @@ def _get_ind(y, ind):
return ind

@classmethod
def _plot(
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
def _plot( # type: ignore[override]
cls,
ax,
ax: Axes,
y,
style=None,
bw_method=None,
Expand All @@ -253,10 +265,9 @@ def _plot(
lines = MPLPlot._plot(ax, ind, y, style=style, **kwds)
return lines

def _make_plot_keywords(self, kwds, y):
def _make_plot_keywords(self, kwds: dict[str, Any], y) -> None:
kwds["bw_method"] = self.bw_method
kwds["ind"] = self._get_ind(y, ind=self.ind)
return kwds

def _post_plot_logic(self, ax, data) -> None:
ax.set_ylabel("Density")
Expand Down