Skip to content

Commit ea9c19c

Browse files
authored
REF: Ensure MPLPlot.data is a DataFrame after __init__ (#55888)
* REF: Ensure MPLPlot.data is a DataFrame after __init__ * mypy fixup
1 parent f66246e commit ea9c19c

File tree

3 files changed

+27
-10
lines changed

3 files changed

+27
-10
lines changed

pandas/plotting/_matplotlib/boxplot.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,10 @@ def _make_plot(self, fig: Figure) -> None:
204204
else self.data
205205
)
206206

207-
for i, (label, y) in enumerate(self._iter_data(data=data)):
207+
# error: Argument "data" to "_iter_data" of "MPLPlot" has
208+
# incompatible type "object"; expected "DataFrame |
209+
# dict[Hashable, Series | DataFrame]"
210+
for i, (label, y) in enumerate(self._iter_data(data=data)): # type: ignore[arg-type]
208211
ax = self._get_ax(i)
209212
kwds = self.kwds.copy()
210213

@@ -216,9 +219,9 @@ def _make_plot(self, fig: Figure) -> None:
216219

217220
# When `by` is assigned, the ticklabels will become unique grouped
218221
# values, instead of label which is used as subtitle in this case.
219-
ticklabels = [
220-
pprint_thing(col) for col in self.data.columns.levels[0]
221-
]
222+
# error: "Index" has no attribute "levels"; maybe "nlevels"?
223+
levels = self.data.columns.levels # type: ignore[attr-defined]
224+
ticklabels = [pprint_thing(col) for col in levels[0]]
222225
else:
223226
ticklabels = [pprint_thing(label)]
224227

pandas/plotting/_matplotlib/core.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ def _kind(self) -> str:
128128
def orientation(self) -> str | None:
129129
return None
130130

131+
data: DataFrame
132+
131133
def __init__(
132134
self,
133135
data,
@@ -270,6 +272,7 @@ def __init__(
270272
self.kwds = kwds
271273

272274
self._validate_color_args()
275+
self.data = self._ensure_frame(self.data)
273276

274277
@final
275278
@staticmethod
@@ -619,9 +622,7 @@ def _convert_to_ndarray(data):
619622
return data
620623

621624
@final
622-
def _compute_plot_data(self):
623-
data = self.data
624-
625+
def _ensure_frame(self, data) -> DataFrame:
625626
if isinstance(data, ABCSeries):
626627
label = self.label
627628
if label is None and data.name is None:
@@ -634,6 +635,11 @@ def _compute_plot_data(self):
634635
elif self._kind in ("hist", "box"):
635636
cols = self.columns if self.by is None else self.columns + self.by
636637
data = data.loc[:, cols]
638+
return data
639+
640+
@final
641+
def _compute_plot_data(self):
642+
data = self.data
637643

638644
# GH15079 reconstruct data if by is defined
639645
if self.by is not None:
@@ -887,6 +893,7 @@ def _get_xticks(self):
887893
index = self.data.index
888894
is_datetype = index.inferred_type in ("datetime", "date", "datetime64", "time")
889895

896+
# TODO: be stricter about x?
890897
x: list[int] | np.ndarray
891898
if self.use_index:
892899
if isinstance(index, ABCPeriodIndex):
@@ -1468,7 +1475,10 @@ def _make_plot(self, fig: Figure) -> None:
14681475
# "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has
14691476
# type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]")
14701477
plotf = self._plot # type: ignore[assignment]
1471-
it = self._iter_data(data=self.data)
1478+
# error: Incompatible types in assignment (expression has type
1479+
# "Iterator[tuple[Hashable, ndarray[Any, Any]]]", variable has
1480+
# type "Iterable[tuple[Hashable, Series]]")
1481+
it = self._iter_data(data=self.data) # type: ignore[assignment]
14721482

14731483
stacking_id = self._get_stacking_id()
14741484
is_errorbar = com.any_not_none(*self.errors.values())
@@ -1481,7 +1491,9 @@ def _make_plot(self, fig: Figure) -> None:
14811491
colors,
14821492
kwds,
14831493
i,
1484-
label, # pyright: ignore[reportGeneralTypeIssues]
1494+
# error: Argument 4 to "_apply_style_colors" of "MPLPlot" has
1495+
# incompatible type "Hashable"; expected "str"
1496+
label, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
14851497
)
14861498

14871499
errors = self._get_errorbars(label=label, index=i)

pandas/plotting/_matplotlib/hist.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ def _make_plot(self, fig: Figure) -> None:
134134
else self.data
135135
)
136136

137-
for i, (label, y) in enumerate(self._iter_data(data=data)):
137+
# error: Argument "data" to "_iter_data" of "MPLPlot" has incompatible
138+
# type "object"; expected "DataFrame | dict[Hashable, Series | DataFrame]"
139+
for i, (label, y) in enumerate(self._iter_data(data=data)): # type: ignore[arg-type]
138140
ax = self._get_ax(i)
139141

140142
kwds = self.kwds.copy()

0 commit comments

Comments
 (0)