From 39f07026f155677fb79474b88b0a69df001b4e51 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 8 Nov 2023 16:20:30 -0800 Subject: [PATCH 1/2] REF: Ensure MPLPlot.data is a DataFrame after __init__ --- pandas/plotting/_matplotlib/boxplot.py | 6 +++--- pandas/plotting/_matplotlib/core.py | 22 +++++++++++++++++----- pandas/plotting/_matplotlib/hist.py | 7 +++++-- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/pandas/plotting/_matplotlib/boxplot.py b/pandas/plotting/_matplotlib/boxplot.py index e6481aab50f6e..52457991944f7 100644 --- a/pandas/plotting/_matplotlib/boxplot.py +++ b/pandas/plotting/_matplotlib/boxplot.py @@ -215,9 +215,9 @@ def _make_plot(self, fig: Figure) -> None: # When `by` is assigned, the ticklabels will become unique grouped # values, instead of label which is used as subtitle in this case. - ticklabels = [ - pprint_thing(col) for col in self.data.columns.levels[0] - ] + # error: "Index" has no attribute "levels"; maybe "nlevels"? + levels = self.data.columns.levels # type: ignore[attr-defined] + ticklabels = [pprint_thing(col) for col in levels[0]] else: ticklabels = [pprint_thing(label)] diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index b67a8186c8c2b..36e108ea7c8ce 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -13,6 +13,7 @@ TYPE_CHECKING, Any, Literal, + cast, final, ) import warnings @@ -91,7 +92,10 @@ npt, ) - from pandas import Series + from pandas import ( + PeriodIndex, + Series, + ) def _color_in_style(style: str) -> bool: @@ -127,6 +131,7 @@ def orientation(self) -> str | None: return None axes: np.ndarray # of Axes objects + data: DataFrame def __init__( self, @@ -268,6 +273,7 @@ def __init__( self.kwds = kwds self._validate_color_args() + self.data = self._ensure_frame(self.data) @final def _validate_sharex(self, sharex: bool | None, ax, by) -> bool: @@ -600,9 +606,7 @@ def _convert_to_ndarray(data): return data @final - def _compute_plot_data(self): - data = self.data - + def _ensure_frame(self, data) -> DataFrame: if isinstance(data, ABCSeries): label = self.label if label is None and data.name is None: @@ -615,6 +619,11 @@ def _compute_plot_data(self): elif self._kind in ("hist", "box"): cols = self.columns if self.by is None else self.columns + self.by data = data.loc[:, cols] + return data + + @final + def _compute_plot_data(self): + data = self.data # GH15079 reconstruct data if by is defined if self.by is not None: @@ -872,10 +881,13 @@ def _get_xticks(self, convert_period: bool = False): index = self.data.index is_datetype = index.inferred_type in ("datetime", "date", "datetime64", "time") + # TODO: be stricter about x? + x: np.ndarray | list if self.use_index: if convert_period and isinstance(index, ABCPeriodIndex): self.data = self.data.reindex(index=index.sort_values()) - x = self.data.index.to_timestamp()._mpl_repr() + index = cast("PeriodIndex", self.data.index) + x = index.to_timestamp()._mpl_repr() elif is_any_real_numeric_dtype(index.dtype): # Matplotlib supports numeric values or datetime objects as # xaxis values. Taking LBYL approach here, by the time diff --git a/pandas/plotting/_matplotlib/hist.py b/pandas/plotting/_matplotlib/hist.py index fd0dde40c0ab3..ba50ae0e2a977 100644 --- a/pandas/plotting/_matplotlib/hist.py +++ b/pandas/plotting/_matplotlib/hist.py @@ -45,7 +45,10 @@ from pandas._typing import PlottingOrientation - from pandas import DataFrame + from pandas import ( + DataFrame, + Series, + ) class HistPlot(LinePlot): @@ -87,7 +90,7 @@ def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]): bins = self._calculate_bins(self.data, bins) return bins - def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray: + def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray: """Calculate bins given data""" nd_values = data.infer_objects(copy=False)._get_numeric_data() values = np.ravel(nd_values) From 3dbdc081e1e73ccfaa1db472058f34292568d4bc Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 9 Nov 2023 09:26:17 -0800 Subject: [PATCH 2/2] mypy fixup --- pandas/plotting/_matplotlib/boxplot.py | 5 ++++- pandas/plotting/_matplotlib/core.py | 9 +++++++-- pandas/plotting/_matplotlib/hist.py | 4 +++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/pandas/plotting/_matplotlib/boxplot.py b/pandas/plotting/_matplotlib/boxplot.py index 34316f0b96495..fa45e6b21fbef 100644 --- a/pandas/plotting/_matplotlib/boxplot.py +++ b/pandas/plotting/_matplotlib/boxplot.py @@ -204,7 +204,10 @@ def _make_plot(self, fig: Figure) -> None: else self.data ) - for i, (label, y) in enumerate(self._iter_data(data=data)): + # error: Argument "data" to "_iter_data" of "MPLPlot" has + # incompatible type "object"; expected "DataFrame | + # dict[Hashable, Series | DataFrame]" + for i, (label, y) in enumerate(self._iter_data(data=data)): # type: ignore[arg-type] ax = self._get_ax(i) kwds = self.kwds.copy() diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index c0a904797aa9f..72e3bae10a205 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -1435,7 +1435,10 @@ def _make_plot(self, fig: Figure) -> None: # "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has # type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]") plotf = self._plot # type: ignore[assignment] - it = self._iter_data(data=self.data) + # error: Incompatible types in assignment (expression has type + # "Iterator[tuple[Hashable, ndarray[Any, Any]]]", variable has + # type "Iterable[tuple[Hashable, Series]]") + it = self._iter_data(data=self.data) # type: ignore[assignment] stacking_id = self._get_stacking_id() is_errorbar = com.any_not_none(*self.errors.values()) @@ -1448,7 +1451,9 @@ def _make_plot(self, fig: Figure) -> None: colors, kwds, i, - label, # pyright: ignore[reportGeneralTypeIssues] + # error: Argument 4 to "_apply_style_colors" of "MPLPlot" has + # incompatible type "Hashable"; expected "str" + label, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues] ) errors = self._get_errorbars(label=label, index=i) diff --git a/pandas/plotting/_matplotlib/hist.py b/pandas/plotting/_matplotlib/hist.py index e42914a9802dd..cdb0c4da203e9 100644 --- a/pandas/plotting/_matplotlib/hist.py +++ b/pandas/plotting/_matplotlib/hist.py @@ -134,7 +134,9 @@ def _make_plot(self, fig: Figure) -> None: else self.data ) - for i, (label, y) in enumerate(self._iter_data(data=data)): + # error: Argument "data" to "_iter_data" of "MPLPlot" has incompatible + # type "object"; expected "DataFrame | dict[Hashable, Series | DataFrame]" + for i, (label, y) in enumerate(self._iter_data(data=data)): # type: ignore[arg-type] ax = self._get_ax(i) kwds = self.kwds.copy()