From 27c2e029fdbf6c62bc78dc2c2d724ead35f9fd20 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Sun, 9 Jun 2024 16:13:22 -0700 Subject: [PATCH 1/3] Avoid ravel in plotting --- pandas/plotting/_matplotlib/boxplot.py | 9 ++------- pandas/plotting/_matplotlib/core.py | 2 +- pandas/plotting/_matplotlib/hist.py | 14 +++++--------- pandas/plotting/_matplotlib/tools.py | 18 +++++++++++------- 4 files changed, 19 insertions(+), 24 deletions(-) diff --git a/pandas/plotting/_matplotlib/boxplot.py b/pandas/plotting/_matplotlib/boxplot.py index 11c0ba01fff64..6bb10068bee38 100644 --- a/pandas/plotting/_matplotlib/boxplot.py +++ b/pandas/plotting/_matplotlib/boxplot.py @@ -311,8 +311,6 @@ def _grouped_plot_by_column( layout=layout, ) - _axes = flatten_axes(axes) - # GH 45465: move the "by" label based on "vert" xlabel, ylabel = kwargs.pop("xlabel", None), kwargs.pop("ylabel", None) if kwargs.get("vert", True): @@ -322,8 +320,7 @@ def _grouped_plot_by_column( ax_values = [] - for i, col in enumerate(columns): - ax = _axes[i] + for ax, col in zip(flatten_axes(axes), columns): gp_col = grouped[col] keys, values = zip(*gp_col) re_plotf = plotf(keys, values, ax, xlabel=xlabel, ylabel=ylabel, **kwargs) @@ -531,10 +528,8 @@ def boxplot_frame_groupby( figsize=figsize, layout=layout, ) - axes = flatten_axes(axes) - data = {} - for (key, group), ax in zip(grouped, axes): + for (key, group), ax in zip(grouped, flatten_axes(axes)): d = group.boxplot( ax=ax, column=column, fontsize=fontsize, rot=rot, grid=grid, **kwds ) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 2d3c81f2512aa..22be9baf1ff5c 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -586,7 +586,7 @@ def _axes_and_fig(self) -> tuple[Sequence[Axes], Figure]: fig.set_size_inches(self.figsize) axes = self.ax - axes = flatten_axes(axes) + axes = np.fromiter(flatten_axes(axes), dtype=object) if self.logx is True or self.loglog is True: [a.set_xscale("log") for a in axes] diff --git a/pandas/plotting/_matplotlib/hist.py b/pandas/plotting/_matplotlib/hist.py index ca635386be335..b1af23be8124c 100644 --- a/pandas/plotting/_matplotlib/hist.py +++ b/pandas/plotting/_matplotlib/hist.py @@ -95,7 +95,9 @@ def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]): def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray: """Calculate bins given data""" nd_values = data.infer_objects()._get_numeric_data() - values = np.ravel(nd_values) + values = nd_values.values + if nd_values.ndim == 1: + values = values.reshape(-1) values = values[~isna(values)] hist, bins = np.histogram(values, bins=bins, range=self._bin_range) @@ -322,10 +324,7 @@ def _grouped_plot( naxes=naxes, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax, layout=layout ) - _axes = flatten_axes(axes) - - for i, (key, group) in enumerate(grouped): - ax = _axes[i] + for ax, (key, group) in zip(flatten_axes(axes), grouped): if numeric_only and isinstance(group, ABCDataFrame): group = group._get_numeric_data() plotf(group, ax, **kwargs) @@ -557,12 +556,9 @@ def hist_frame( figsize=figsize, layout=layout, ) - _axes = flatten_axes(axes) - can_set_label = "label" not in kwds - for i, col in enumerate(data.columns): - ax = _axes[i] + for ax, col in zip(flatten_axes(axes), data.columns): if legend and can_set_label: kwds["label"] = col ax.hist(data[col].dropna().values, bins=bins, **kwds) diff --git a/pandas/plotting/_matplotlib/tools.py b/pandas/plotting/_matplotlib/tools.py index ae82f0232aee0..8eb123c05a6bc 100644 --- a/pandas/plotting/_matplotlib/tools.py +++ b/pandas/plotting/_matplotlib/tools.py @@ -18,7 +18,10 @@ ) if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import ( + Generator, + Iterable, + ) from matplotlib.axes import Axes from matplotlib.axis import Axis @@ -231,7 +234,7 @@ def create_subplots( else: if is_list_like(ax): if squeeze: - ax = flatten_axes(ax) + ax = np.fromiter(flatten_axes(ax), dtype=object) if layout is not None: warnings.warn( "When passing multiple axes, layout keyword is ignored.", @@ -260,7 +263,7 @@ def create_subplots( if squeeze: return fig, ax else: - return fig, flatten_axes(ax) + return fig, np.fromiter(flatten_axes(ax), dtype=object) else: warnings.warn( "To output multiple subplots, the figure containing " @@ -439,12 +442,13 @@ def handle_shared_axes( _remove_labels_from_axis(ax.yaxis) -def flatten_axes(axes: Axes | Iterable[Axes]) -> np.ndarray: +def flatten_axes(axes: Axes | Iterable[Axes]) -> Generator[Axes, None, None]: if not is_list_like(axes): - return np.array([axes]) + yield axes elif isinstance(axes, (np.ndarray, ABCIndex)): - return np.asarray(axes).ravel() - return np.array(axes) + yield from np.asarray(axes).reshape(-1) + else: + yield from axes def set_ticks_props( From e29deb94f0ceba3b7d30ea1a381c2ed54ccbea76 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 10 Jun 2024 11:12:35 -0700 Subject: [PATCH 2/3] Use reshape instead of ravel --- pandas/plotting/_matplotlib/hist.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pandas/plotting/_matplotlib/hist.py b/pandas/plotting/_matplotlib/hist.py index b1af23be8124c..2c4d714bf1a0c 100644 --- a/pandas/plotting/_matplotlib/hist.py +++ b/pandas/plotting/_matplotlib/hist.py @@ -96,12 +96,11 @@ def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray: """Calculate bins given data""" nd_values = data.infer_objects()._get_numeric_data() values = nd_values.values - if nd_values.ndim == 1: + if nd_values.ndim == 2: values = values.reshape(-1) values = values[~isna(values)] - hist, bins = np.histogram(values, bins=bins, range=self._bin_range) - return bins + return np.histogram_bin_edges(values, bins=bins, range=self._bin_range) # error: Signature of "_plot" incompatible with supertype "LinePlot" @classmethod From bfa72c8a2b6c221b5ed105ba7ce572a1a0d79cb3 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 10 Jun 2024 15:00:22 -0700 Subject: [PATCH 3/3] Add type ignore --- pandas/plotting/_matplotlib/tools.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pandas/plotting/_matplotlib/tools.py b/pandas/plotting/_matplotlib/tools.py index 8eb123c05a6bc..f9c370b2486fd 100644 --- a/pandas/plotting/_matplotlib/tools.py +++ b/pandas/plotting/_matplotlib/tools.py @@ -444,11 +444,11 @@ def handle_shared_axes( def flatten_axes(axes: Axes | Iterable[Axes]) -> Generator[Axes, None, None]: if not is_list_like(axes): - yield axes + yield axes # type: ignore[misc] elif isinstance(axes, (np.ndarray, ABCIndex)): yield from np.asarray(axes).reshape(-1) else: - yield from axes + yield from axes # type: ignore[misc] def set_ticks_props( @@ -460,13 +460,13 @@ def set_ticks_props( ): for ax in flatten_axes(axes): if xlabelsize is not None: - mpl.artist.setp(ax.get_xticklabels(), fontsize=xlabelsize) + mpl.artist.setp(ax.get_xticklabels(), fontsize=xlabelsize) # type: ignore[arg-type] if xrot is not None: - mpl.artist.setp(ax.get_xticklabels(), rotation=xrot) + mpl.artist.setp(ax.get_xticklabels(), rotation=xrot) # type: ignore[arg-type] if ylabelsize is not None: - mpl.artist.setp(ax.get_yticklabels(), fontsize=ylabelsize) + mpl.artist.setp(ax.get_yticklabels(), fontsize=ylabelsize) # type: ignore[arg-type] if yrot is not None: - mpl.artist.setp(ax.get_yticklabels(), rotation=yrot) + mpl.artist.setp(ax.get_yticklabels(), rotation=yrot) # type: ignore[arg-type] return axes