From d0085c2880081827d2fdb04f6240bd626bb5b4bc Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 5 Jun 2024 11:16:15 -0700 Subject: [PATCH 1/6] Use standard import matplotlib as mpl --- pandas/plotting/_matplotlib/boxplot.py | 12 ++++++------ pandas/plotting/_matplotlib/core.py | 6 ++---- pandas/plotting/_matplotlib/misc.py | 9 ++++----- pandas/plotting/_matplotlib/style.py | 4 +--- pandas/plotting/_matplotlib/tools.py | 23 ++++++++++------------- 5 files changed, 23 insertions(+), 31 deletions(-) diff --git a/pandas/plotting/_matplotlib/boxplot.py b/pandas/plotting/_matplotlib/boxplot.py index 2a28cd94b64e5..11c0ba01fff64 100644 --- a/pandas/plotting/_matplotlib/boxplot.py +++ b/pandas/plotting/_matplotlib/boxplot.py @@ -7,7 +7,7 @@ ) import warnings -from matplotlib.artist import setp +import matplotlib as mpl import numpy as np from pandas._libs import lib @@ -274,13 +274,13 @@ def maybe_color_bp(bp, color_tup, **kwds) -> None: # GH#30346, when users specifying those arguments explicitly, our defaults # for these four kwargs should be overridden; if not, use Pandas settings if not kwds.get("boxprops"): - setp(bp["boxes"], color=color_tup[0], alpha=1) + mpl.artist.setp(bp["boxes"], color=color_tup[0], alpha=1) if not kwds.get("whiskerprops"): - setp(bp["whiskers"], color=color_tup[1], alpha=1) + mpl.artist.setp(bp["whiskers"], color=color_tup[1], alpha=1) if not kwds.get("medianprops"): - setp(bp["medians"], color=color_tup[2], alpha=1) + mpl.artist.setp(bp["medians"], color=color_tup[2], alpha=1) if not kwds.get("capprops"): - setp(bp["caps"], color=color_tup[3], alpha=1) + mpl.artist.setp(bp["caps"], color=color_tup[3], alpha=1) def _grouped_plot_by_column( @@ -455,7 +455,7 @@ def plot_group(keys, values, ax: Axes, **kwds): if ax is None: rc = {"figure.figsize": figsize} if figsize is not None else {} - with plt.rc_context(rc): + with mpl.rc_context(rc): ax = plt.gca() data = data._get_numeric_data() naxes = len(data.columns) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index fffeb9b82492f..934473aecc2e1 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -176,8 +176,6 @@ def __init__( style=None, **kwds, ) -> None: - import matplotlib.pyplot as plt - # if users assign an empty list or tuple, raise `ValueError` # similar to current `df.box` and `df.hist` APIs. if by in ([], ()): @@ -238,7 +236,7 @@ def __init__( self.rot = self._default_rot if grid is None: - grid = False if secondary_y else plt.rcParams["axes.grid"] + grid = False if secondary_y else mpl.rcParams["axes.grid"] self.grid = grid self.legend = legend @@ -1386,7 +1384,7 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool): if c is not None and color is not None: raise TypeError("Specify exactly one of `c` and `color`") if c is None and color is None: - c_values = self.plt.rcParams["patch.facecolor"] + c_values = mpl.rcParams["patch.facecolor"] elif color is not None: c_values = color elif color_by_categorical: diff --git a/pandas/plotting/_matplotlib/misc.py b/pandas/plotting/_matplotlib/misc.py index 1f9212587e05e..4a891ec27e8cb 100644 --- a/pandas/plotting/_matplotlib/misc.py +++ b/pandas/plotting/_matplotlib/misc.py @@ -3,8 +3,7 @@ import random from typing import TYPE_CHECKING -from matplotlib import patches -import matplotlib.lines as mlines +import matplotlib as mpl import numpy as np from pandas.core.dtypes.missing import notna @@ -129,7 +128,7 @@ def scatter_matrix( def _get_marker_compat(marker): - if marker not in mlines.lineMarkers: + if marker not in mpl.lines.lineMarkers: return "o" return marker @@ -190,10 +189,10 @@ def normalize(series): ) ax.legend() - ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor="none")) + ax.add_patch(mpl.patches.Circle((0.0, 0.0), radius=1.0, facecolor="none")) for xy, name in zip(s, df.columns): - ax.add_patch(patches.Circle(xy, radius=0.025, facecolor="gray")) + ax.add_patch(mpl.patches.Circle(xy, radius=0.025, facecolor="gray")) if xy[0] < 0.0 and xy[1] < 0.0: ax.text( diff --git a/pandas/plotting/_matplotlib/style.py b/pandas/plotting/_matplotlib/style.py index d725d53bd21ec..962f9711d9916 100644 --- a/pandas/plotting/_matplotlib/style.py +++ b/pandas/plotting/_matplotlib/style.py @@ -260,9 +260,7 @@ def _get_colors_from_color_type(color_type: str, num_colors: int) -> list[Color] def _get_default_colors(num_colors: int) -> list[Color]: """Get `num_colors` of default colors from matplotlib rc params.""" - import matplotlib.pyplot as plt - - colors = [c["color"] for c in plt.rcParams["axes.prop_cycle"]] + colors = [c["color"] for c in mpl.rcParams["axes.prop_cycle"]] return colors[0:num_colors] diff --git a/pandas/plotting/_matplotlib/tools.py b/pandas/plotting/_matplotlib/tools.py index 50cfdbd967ea7..ae82f0232aee0 100644 --- a/pandas/plotting/_matplotlib/tools.py +++ b/pandas/plotting/_matplotlib/tools.py @@ -5,8 +5,7 @@ from typing import TYPE_CHECKING import warnings -from matplotlib import ticker -import matplotlib.table +import matplotlib as mpl import numpy as np from pandas.util._exceptions import find_stack_level @@ -77,7 +76,7 @@ def table( # error: Argument "cellText" to "table" has incompatible type "ndarray[Any, # Any]"; expected "Sequence[Sequence[str]] | None" - return matplotlib.table.table( + return mpl.table.table( ax, cellText=cellText, # type: ignore[arg-type] rowLabels=rowLabels, @@ -327,10 +326,10 @@ def _remove_labels_from_axis(axis: Axis) -> None: # set_visible will not be effective if # minor axis has NullLocator and NullFormatter (default) - if isinstance(axis.get_minor_locator(), ticker.NullLocator): - axis.set_minor_locator(ticker.AutoLocator()) - if isinstance(axis.get_minor_formatter(), ticker.NullFormatter): - axis.set_minor_formatter(ticker.FormatStrFormatter("")) + if isinstance(axis.get_minor_locator(), mpl.ticker.NullLocator): + axis.set_minor_locator(mpl.ticker.AutoLocator()) + if isinstance(axis.get_minor_formatter(), mpl.ticker.NullFormatter): + axis.set_minor_formatter(mpl.ticker.FormatStrFormatter("")) for t in axis.get_minorticklabels(): t.set_visible(False) @@ -455,17 +454,15 @@ def set_ticks_props( ylabelsize: int | None = None, yrot=None, ): - import matplotlib.pyplot as plt - for ax in flatten_axes(axes): if xlabelsize is not None: - plt.setp(ax.get_xticklabels(), fontsize=xlabelsize) + mpl.artist.setp(ax.get_xticklabels(), fontsize=xlabelsize) if xrot is not None: - plt.setp(ax.get_xticklabels(), rotation=xrot) + mpl.artist.setp(ax.get_xticklabels(), rotation=xrot) if ylabelsize is not None: - plt.setp(ax.get_yticklabels(), fontsize=ylabelsize) + mpl.artist.setp(ax.get_yticklabels(), fontsize=ylabelsize) if yrot is not None: - plt.setp(ax.get_yticklabels(), rotation=yrot) + mpl.artist.setp(ax.get_yticklabels(), rotation=yrot) return axes From a959b65496925280e5accf627c3043b6689c5aca Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 5 Jun 2024 11:47:06 -0700 Subject: [PATCH 2/6] Standardaize more matplotlib imports --- pandas/plotting/_core.py | 2 +- pandas/plotting/_matplotlib/__init__.py | 2 +- pandas/plotting/_matplotlib/converter.py | 79 +++++++++---------- pandas/plotting/_matplotlib/core.py | 33 +++----- pandas/plotting/_matplotlib/timeseries.py | 4 +- .../tests/io/formats/style/test_matplotlib.py | 4 +- pandas/tests/plotting/frame/test_frame.py | 33 ++------ .../tests/plotting/frame/test_frame_color.py | 32 ++++---- .../tests/plotting/frame/test_frame_legend.py | 7 +- pandas/tests/plotting/test_converter.py | 35 ++++---- 10 files changed, 93 insertions(+), 138 deletions(-) diff --git a/pandas/plotting/_core.py b/pandas/plotting/_core.py index c83985917591c..0daf3cfafe81c 100644 --- a/pandas/plotting/_core.py +++ b/pandas/plotting/_core.py @@ -1598,7 +1598,7 @@ def area( See Also -------- - DataFrame.plot : Make plots of DataFrame using matplotlib / pylab. + DataFrame.plot : Make plots of DataFrame using matplotlib. Examples -------- diff --git a/pandas/plotting/_matplotlib/__init__.py b/pandas/plotting/_matplotlib/__init__.py index 75c61da03795a..87f3ca09ad346 100644 --- a/pandas/plotting/_matplotlib/__init__.py +++ b/pandas/plotting/_matplotlib/__init__.py @@ -69,7 +69,7 @@ def plot(data, kind, **kwargs): kwargs["ax"] = getattr(ax, "left_ax", ax) plot_obj = PLOT_CLASSES[kind](data, **kwargs) plot_obj.generate() - plot_obj.draw() + plt.draw_if_interactive() return plot_obj.result diff --git a/pandas/plotting/_matplotlib/converter.py b/pandas/plotting/_matplotlib/converter.py index 50fa722f6dd72..89ca1eef646d3 100644 --- a/pandas/plotting/_matplotlib/converter.py +++ b/pandas/plotting/_matplotlib/converter.py @@ -14,14 +14,7 @@ ) import warnings -import matplotlib.dates as mdates -from matplotlib.ticker import ( - AutoLocator, - Formatter, - Locator, -) -from matplotlib.transforms import nonsingular -import matplotlib.units as munits +import matplotlib as mpl import numpy as np from pandas._libs import lib @@ -122,25 +115,27 @@ def register() -> None: pairs = get_pairs() for type_, cls in pairs: # Cache previous converter if present - if type_ in munits.registry and not isinstance(munits.registry[type_], cls): - previous = munits.registry[type_] + if type_ in mpl.units.registry and not isinstance( + mpl.units.registry[type_], cls + ): + previous = mpl.units.registry[type_] _mpl_units[type_] = previous # Replace with pandas converter - munits.registry[type_] = cls() + mpl.units.registry[type_] = cls() def deregister() -> None: # Renamed in pandas.plotting.__init__ for type_, cls in get_pairs(): # We use type to catch our classes directly, no inheritance - if type(munits.registry.get(type_)) is cls: - munits.registry.pop(type_) + if type(mpl.units.registry.get(type_)) is cls: + mpl.units.registry.pop(type_) # restore the old keys for unit, formatter in _mpl_units.items(): if type(formatter) not in {DatetimeConverter, PeriodConverter, TimeConverter}: # make it idempotent by excluding ours. - munits.registry[unit] = formatter + mpl.units.registry[unit] = formatter def _to_ordinalf(tm: pydt.time) -> float: @@ -157,7 +152,7 @@ def time2num(d): return d -class TimeConverter(munits.ConversionInterface): +class TimeConverter(mpl.units.ConversionInterface): @staticmethod def convert(value, unit, axis): valid_types = (str, pydt.time) @@ -170,13 +165,13 @@ def convert(value, unit, axis): return value @staticmethod - def axisinfo(unit, axis) -> munits.AxisInfo | None: + def axisinfo(unit, axis) -> mpl.units.AxisInfo | None: if unit != "time": return None - majloc = AutoLocator() + majloc = mpl.ticker.AutoLocator() majfmt = TimeFormatter(majloc) - return munits.AxisInfo(majloc=majloc, majfmt=majfmt, label="time") + return mpl.units.AxisInfo(majloc=majloc, majfmt=majfmt, label="time") @staticmethod def default_units(x, axis) -> str: @@ -184,7 +179,7 @@ def default_units(x, axis) -> str: # time formatter -class TimeFormatter(Formatter): +class TimeFormatter(mpl.ticker.Formatter): def __init__(self, locs) -> None: self.locs = locs @@ -227,7 +222,7 @@ def __call__(self, x, pos: int | None = 0) -> str: # Period Conversion -class PeriodConverter(mdates.DateConverter): +class PeriodConverter(mpl.dates.DateConverter): @staticmethod def convert(values, units, axis): if is_nested_list_like(values): @@ -284,7 +279,7 @@ def get_datevalue(date, freq): # Datetime Conversion -class DatetimeConverter(mdates.DateConverter): +class DatetimeConverter(mpl.dates.DateConverter): @staticmethod def convert(values, unit, axis): # values might be a 1-d array, or a list-like of arrays. @@ -298,12 +293,12 @@ def convert(values, unit, axis): def _convert_1d(values, unit, axis): def try_parse(values): try: - return mdates.date2num(tools.to_datetime(values)) + return mpl.dates.date2num(tools.to_datetime(values)) except Exception: return values if isinstance(values, (datetime, pydt.date, np.datetime64, pydt.time)): - return mdates.date2num(values) + return mpl.dates.date2num(values) elif is_integer(values) or is_float(values): return values elif isinstance(values, str): @@ -326,12 +321,12 @@ def try_parse(values): except Exception: pass - values = mdates.date2num(values) + values = mpl.dates.date2num(values) return values @staticmethod - def axisinfo(unit: tzinfo | None, axis) -> munits.AxisInfo: + def axisinfo(unit: tzinfo | None, axis) -> mpl.units.AxisInfo: """ Return the :class:`~matplotlib.units.AxisInfo` for *unit*. @@ -345,17 +340,17 @@ def axisinfo(unit: tzinfo | None, axis) -> munits.AxisInfo: datemin = pydt.date(2000, 1, 1) datemax = pydt.date(2010, 1, 1) - return munits.AxisInfo( + return mpl.units.AxisInfo( majloc=majloc, majfmt=majfmt, label="", default_limits=(datemin, datemax) ) -class PandasAutoDateFormatter(mdates.AutoDateFormatter): +class PandasAutoDateFormatter(mpl.dates.AutoDateFormatter): def __init__(self, locator, tz=None, defaultfmt: str = "%Y-%m-%d") -> None: - mdates.AutoDateFormatter.__init__(self, locator, tz, defaultfmt) + mpl.dates.AutoDateFormatter.__init__(self, locator, tz, defaultfmt) -class PandasAutoDateLocator(mdates.AutoDateLocator): +class PandasAutoDateLocator(mpl.dates.AutoDateLocator): def get_locator(self, dmin, dmax): """Pick the best locator based on a distance.""" tot_sec = (dmax - dmin).total_seconds() @@ -375,17 +370,17 @@ def get_locator(self, dmin, dmax): ) return locator - return mdates.AutoDateLocator.get_locator(self, dmin, dmax) + return mpl.dates.AutoDateLocator.get_locator(self, dmin, dmax) def _get_unit(self): return MilliSecondLocator.get_unit_generic(self._freq) -class MilliSecondLocator(mdates.DateLocator): +class MilliSecondLocator(mpl.dates.DateLocator): UNIT = 1.0 / (24 * 3600 * 1000) def __init__(self, tz) -> None: - mdates.DateLocator.__init__(self, tz) + mpl.dates.DateLocator.__init__(self, tz) self._interval = 1.0 def _get_unit(self): @@ -393,7 +388,7 @@ def _get_unit(self): @staticmethod def get_unit_generic(freq): - unit = mdates.RRuleLocator.get_unit_generic(freq) + unit = mpl.dates.RRuleLocator.get_unit_generic(freq) if unit < 0: return MilliSecondLocator.UNIT return unit @@ -406,7 +401,7 @@ def __call__(self): return [] # We need to cap at the endpoints of valid datetime - nmax, nmin = mdates.date2num((dmax, dmin)) + nmax, nmin = mpl.dates.date2num((dmax, dmin)) num = (nmax - nmin) * 86400 * 1000 max_millis_ticks = 6 @@ -435,12 +430,12 @@ def __call__(self): try: if len(all_dates) > 0: - locs = self.raise_if_exceeds(mdates.date2num(all_dates)) + locs = self.raise_if_exceeds(mpl.dates.date2num(all_dates)) return locs except Exception: # pragma: no cover pass - lims = mdates.date2num([dmin, dmax]) + lims = mpl.dates.date2num([dmin, dmax]) return lims def _get_interval(self): @@ -453,8 +448,8 @@ def autoscale(self): # We need to cap at the endpoints of valid datetime dmin, dmax = self.datalim_to_dt() - vmin = mdates.date2num(dmin) - vmax = mdates.date2num(dmax) + vmin = mpl.dates.date2num(dmin) + vmax = mpl.dates.date2num(dmax) return self.nonsingular(vmin, vmax) @@ -917,7 +912,7 @@ def get_finder(freq: BaseOffset): raise NotImplementedError(f"Unsupported frequency: {dtype_code}") -class TimeSeries_DateLocator(Locator): +class TimeSeries_DateLocator(mpl.ticker.Locator): """ Locates the ticks along an axis controlled by a :class:`Series`. @@ -998,7 +993,7 @@ def autoscale(self): if vmin == vmax: vmin -= 1 vmax += 1 - return nonsingular(vmin, vmax) + return mpl.transforms.nonsingular(vmin, vmax) # ------------------------------------------------------------------------- @@ -1006,7 +1001,7 @@ def autoscale(self): # ------------------------------------------------------------------------- -class TimeSeries_DateFormatter(Formatter): +class TimeSeries_DateFormatter(mpl.ticker.Formatter): """ Formats the ticks along an axis controlled by a :class:`PeriodIndex`. @@ -1082,7 +1077,7 @@ def __call__(self, x, pos: int | None = 0) -> str: return period.strftime(fmt) -class TimeSeries_TimedeltaFormatter(Formatter): +class TimeSeries_TimedeltaFormatter(mpl.ticker.Formatter): """ Formats the ticks along an axis controlled by a :class:`TimedeltaIndex`. """ diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index 934473aecc2e1..2d3c81f2512aa 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -107,9 +107,7 @@ def _color_in_style(style: str) -> bool: """ Check if there is a color letter in the style string. """ - from matplotlib.colors import BASE_COLORS - - return not set(BASE_COLORS).isdisjoint(style) + return not set(mpl.colors.BASE_COLORS).isdisjoint(style) class MPLPlot(ABC): @@ -495,10 +493,6 @@ def _get_nseries(self, data: Series | DataFrame) -> int: def nseries(self) -> int: return self._get_nseries(self.data) - @final - def draw(self) -> None: - self.plt.draw_if_interactive() - @final def generate(self) -> None: self._compute_plot_data() @@ -568,6 +562,8 @@ def axes(self) -> Sequence[Axes]: @final @cache_readonly def _axes_and_fig(self) -> tuple[Sequence[Axes], Figure]: + import matplotlib.pyplot as plt + if self.subplots: naxes = ( self.nseries if isinstance(self.subplots, bool) else len(self.subplots) @@ -582,7 +578,7 @@ def _axes_and_fig(self) -> tuple[Sequence[Axes], Figure]: layout_type=self._layout_type, ) elif self.ax is None: - fig = self.plt.figure(figsize=self.figsize) + fig = plt.figure(figsize=self.figsize) axes = fig.add_subplot(111) else: fig = self.ax.get_figure() @@ -916,13 +912,6 @@ def _get_ax_legend(ax: Axes): ax = other_ax return ax, leg - @final - @cache_readonly - def plt(self): - import matplotlib.pyplot as plt - - return plt - _need_to_set_index = False @final @@ -1217,9 +1206,9 @@ def _get_errorbars( @final def _get_subplots(self, fig: Figure) -> list[Axes]: if Version(mpl.__version__) < Version("3.8"): - from matplotlib.axes import Subplot as Klass + Klass = mpl.axes.Subplot else: - from matplotlib.axes import Axes as Klass + Klass = mpl.axes.Axes return [ ax @@ -1409,12 +1398,10 @@ def _get_norm_and_cmap(self, c_values, color_by_categorical: bool): cmap = None if color_by_categorical and cmap is not None: - from matplotlib import colors - n_cats = len(self.data[c].cat.categories) - cmap = colors.ListedColormap([cmap(i) for i in range(cmap.N)]) + cmap = mpl.colors.ListedColormap([cmap(i) for i in range(cmap.N)]) bounds = np.linspace(0, n_cats, n_cats + 1) - norm = colors.BoundaryNorm(bounds, cmap.N) + norm = mpl.colors.BoundaryNorm(bounds, cmap.N) # TODO: warn that we are ignoring self.norm if user specified it? # Doesn't happen in any tests 2023-11-09 else: @@ -1674,8 +1661,6 @@ def _update_stacker(cls, ax: Axes, stacking_id: int | None, values) -> None: ax._stacker_neg_prior[stacking_id] += values # type: ignore[attr-defined] def _post_plot_logic(self, ax: Axes, data) -> None: - from matplotlib.ticker import FixedLocator - def get_label(i): if is_float(i) and i.is_integer(): i = int(i) @@ -1689,7 +1674,7 @@ def get_label(i): xticklabels = [get_label(x) for x in xticks] # error: Argument 1 to "FixedLocator" has incompatible type "ndarray[Any, # Any]"; expected "Sequence[float]" - ax.xaxis.set_major_locator(FixedLocator(xticks)) # type: ignore[arg-type] + ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(xticks)) # type: ignore[arg-type] ax.set_xticklabels(xticklabels) # If the index is an irregular time series, then by default diff --git a/pandas/plotting/_matplotlib/timeseries.py b/pandas/plotting/_matplotlib/timeseries.py index d438f521c0dbc..d95ccad2da565 100644 --- a/pandas/plotting/_matplotlib/timeseries.py +++ b/pandas/plotting/_matplotlib/timeseries.py @@ -333,7 +333,7 @@ def format_dateaxis( default, changing the limits of the x axis will intelligently change the positions of the ticks. """ - from matplotlib import pylab + import matplotlib.pyplot as plt # handle index specific formatting # Note: DatetimeIndex does not use this @@ -365,4 +365,4 @@ def format_dateaxis( else: raise TypeError("index type not supported") - pylab.draw_if_interactive() + plt.draw_if_interactive() diff --git a/pandas/tests/io/formats/style/test_matplotlib.py b/pandas/tests/io/formats/style/test_matplotlib.py index 70ddd65c02d14..296fb20d855c4 100644 --- a/pandas/tests/io/formats/style/test_matplotlib.py +++ b/pandas/tests/io/formats/style/test_matplotlib.py @@ -7,11 +7,9 @@ Series, ) -pytest.importorskip("matplotlib") +mpl = pytest.importorskip("matplotlib") pytest.importorskip("jinja2") -import matplotlib as mpl - from pandas.io.formats.style import Styler pytestmark = pytest.mark.usefixtures("mpl_cleanup") diff --git a/pandas/tests/plotting/frame/test_frame.py b/pandas/tests/plotting/frame/test_frame.py index adb56a40b0071..e809bd33610f1 100644 --- a/pandas/tests/plotting/frame/test_frame.py +++ b/pandas/tests/plotting/frame/test_frame.py @@ -1177,20 +1177,16 @@ def test_hist_df_series(self): _check_ticks_props(axes, xrot=40, yrot=0) def test_hist_df_series_cumulative_density(self): - from matplotlib.patches import Rectangle - series = Series(np.random.default_rng(2).random(10)) ax = series.plot.hist(cumulative=True, bins=4, density=True) # height of last bin (index 5) must be 1.0 - rects = [x for x in ax.get_children() if isinstance(x, Rectangle)] + rects = [x for x in ax.get_children() if isinstance(x, mpl.patches.Rectangle)] tm.assert_almost_equal(rects[-1].get_height(), 1.0) def test_hist_df_series_cumulative(self): - from matplotlib.patches import Rectangle - series = Series(np.random.default_rng(2).random(10)) ax = series.plot.hist(cumulative=True, bins=4) - rects = [x for x in ax.get_children() if isinstance(x, Rectangle)] + rects = [x for x in ax.get_children() if isinstance(x, mpl.patches.Rectangle)] tm.assert_almost_equal(rects[-2].get_height(), 10.0) @@ -1385,8 +1381,6 @@ def test_plot_int_columns(self): ], ) def test_style_by_column(self, markers): - import matplotlib.pyplot as plt - fig = plt.gcf() fig.clf() fig.add_subplot(111) @@ -1969,9 +1963,6 @@ def test_sharex_and_ax(self): # https://github.com/pandas-dev/pandas/issues/9737 using gridspec, # the axis in fig.get_axis() are sorted differently than pandas # expected them, so make sure that only the right ones are removed - import matplotlib.pyplot as plt - - plt.close("all") gs, axes = _generate_4_axes_via_gridspec() df = DataFrame( @@ -2009,8 +2000,6 @@ def test_sharex_false_and_ax(self): # https://github.com/pandas-dev/pandas/issues/9737 using gridspec, # the axis in fig.get_axis() are sorted differently than pandas # expected them, so make sure that only the right ones are removed - import matplotlib.pyplot as plt - df = DataFrame( { "a": [1, 2, 3, 4, 5, 6], @@ -2035,8 +2024,6 @@ def test_sharey_and_ax(self): # https://github.com/pandas-dev/pandas/issues/9737 using gridspec, # the axis in fig.get_axis() are sorted differently than pandas # expected them, so make sure that only the right ones are removed - import matplotlib.pyplot as plt - gs, axes = _generate_4_axes_via_gridspec() df = DataFrame( @@ -2073,8 +2060,6 @@ def _check(axes): def test_sharey_and_ax_tight(self): # https://github.com/pandas-dev/pandas/issues/9737 using gridspec, - import matplotlib.pyplot as plt - df = DataFrame( { "a": [1, 2, 3, 4, 5, 6], @@ -2134,9 +2119,6 @@ def test_memory_leak(self, kind): def test_df_gridspec_patterns_vert_horiz(self): # GH 10819 - from matplotlib import gridspec - import matplotlib.pyplot as plt - ts = Series( np.random.default_rng(2).standard_normal(10), index=date_range("1/1/2000", periods=10), @@ -2149,14 +2131,14 @@ def test_df_gridspec_patterns_vert_horiz(self): ) def _get_vertical_grid(): - gs = gridspec.GridSpec(3, 1) + gs = mpl.gridspec.GridSpec(3, 1) fig = plt.figure() ax1 = fig.add_subplot(gs[:2, :]) ax2 = fig.add_subplot(gs[2, :]) return ax1, ax2 def _get_horizontal_grid(): - gs = gridspec.GridSpec(1, 3) + gs = mpl.gridspec.GridSpec(1, 3) fig = plt.figure() ax1 = fig.add_subplot(gs[:, :2]) ax2 = fig.add_subplot(gs[:, 2]) @@ -2217,9 +2199,6 @@ def _get_horizontal_grid(): def test_df_gridspec_patterns_boxed(self): # GH 10819 - from matplotlib import gridspec - import matplotlib.pyplot as plt - ts = Series( np.random.default_rng(2).standard_normal(10), index=date_range("1/1/2000", periods=10), @@ -2227,7 +2206,7 @@ def test_df_gridspec_patterns_boxed(self): # boxed def _get_boxed_grid(): - gs = gridspec.GridSpec(3, 3) + gs = mpl.gridspec.GridSpec(3, 3) fig = plt.figure() ax1 = fig.add_subplot(gs[:2, :2]) ax2 = fig.add_subplot(gs[:2, 2]) @@ -2595,8 +2574,6 @@ def test_plot_period_index_makes_no_right_shift(self, freq): def _generate_4_axes_via_gridspec(): - import matplotlib.pyplot as plt - gs = mpl.gridspec.GridSpec(2, 2) ax_tl = plt.subplot(gs[0, 0]) ax_ll = plt.subplot(gs[1, 0]) diff --git a/pandas/tests/plotting/frame/test_frame_color.py b/pandas/tests/plotting/frame/test_frame_color.py index 76d3b20aaa2c6..4b35e896e1a6c 100644 --- a/pandas/tests/plotting/frame/test_frame_color.py +++ b/pandas/tests/plotting/frame/test_frame_color.py @@ -364,14 +364,16 @@ def test_line_colors_and_styles_subplots_list_styles(self): _check_colors(ax.get_lines(), linecolors=[c]) def test_area_colors(self): - from matplotlib.collections import PolyCollection - custom_colors = "rgcby" df = DataFrame(np.random.default_rng(2).random((5, 5))) ax = df.plot.area(color=custom_colors) _check_colors(ax.get_lines(), linecolors=custom_colors) - poly = [o for o in ax.get_children() if isinstance(o, PolyCollection)] + poly = [ + o + for o in ax.get_children() + if isinstance(o, mpl.collections.PolyCollection) + ] _check_colors(poly, facecolors=custom_colors) handles, _ = ax.get_legend_handles_labels() @@ -381,14 +383,15 @@ def test_area_colors(self): assert h.get_alpha() is None def test_area_colors_poly(self): - from matplotlib import cm - from matplotlib.collections import PolyCollection - df = DataFrame(np.random.default_rng(2).random((5, 5))) ax = df.plot.area(colormap="jet") - jet_colors = [cm.jet(n) for n in np.linspace(0, 1, len(df))] + jet_colors = [mpl.cm.jet(n) for n in np.linspace(0, 1, len(df))] _check_colors(ax.get_lines(), linecolors=jet_colors) - poly = [o for o in ax.get_children() if isinstance(o, PolyCollection)] + poly = [ + o + for o in ax.get_children() + if isinstance(o, mpl.collections.PolyCollection) + ] _check_colors(poly, facecolors=jet_colors) handles, _ = ax.get_legend_handles_labels() @@ -397,15 +400,16 @@ def test_area_colors_poly(self): assert h.get_alpha() is None def test_area_colors_stacked_false(self): - from matplotlib import cm - from matplotlib.collections import PolyCollection - df = DataFrame(np.random.default_rng(2).random((5, 5))) - jet_colors = [cm.jet(n) for n in np.linspace(0, 1, len(df))] + jet_colors = [mpl.cm.jet(n) for n in np.linspace(0, 1, len(df))] # When stacked=False, alpha is set to 0.5 - ax = df.plot.area(colormap=cm.jet, stacked=False) + ax = df.plot.area(colormap=mpl.cm.jet, stacked=False) _check_colors(ax.get_lines(), linecolors=jet_colors) - poly = [o for o in ax.get_children() if isinstance(o, PolyCollection)] + poly = [ + o + for o in ax.get_children() + if isinstance(o, mpl.collections.PolyCollection) + ] jet_with_alpha = [(c[0], c[1], c[2], 0.5) for c in jet_colors] _check_colors(poly, facecolors=jet_with_alpha) diff --git a/pandas/tests/plotting/frame/test_frame_legend.py b/pandas/tests/plotting/frame/test_frame_legend.py index 402a4b9531e5d..a9723fe4ef871 100644 --- a/pandas/tests/plotting/frame/test_frame_legend.py +++ b/pandas/tests/plotting/frame/test_frame_legend.py @@ -26,9 +26,6 @@ class TestFrameLegend: ) def test_mixed_yerr(self): # https://github.com/pandas-dev/pandas/issues/39522 - from matplotlib.collections import LineCollection - from matplotlib.lines import Line2D - df = DataFrame([{"x": 1, "a": 1, "b": 1}, {"x": 2, "a": 2, "b": 3}]) ax = df.plot("x", "a", c="orange", yerr=0.1, label="orange") @@ -40,8 +37,8 @@ def test_mixed_yerr(self): else: result_handles = legend.legend_handles - assert isinstance(result_handles[0], LineCollection) - assert isinstance(result_handles[1], Line2D) + assert isinstance(result_handles[0], mpl.collections.LineCollection) + assert isinstance(result_handles[1], mpl.lines.Line2D) def test_legend_false(self): # https://github.com/pandas-dev/pandas/issues/40044 diff --git a/pandas/tests/plotting/test_converter.py b/pandas/tests/plotting/test_converter.py index 6a1777b098de0..f2619e754d4ef 100644 --- a/pandas/tests/plotting/test_converter.py +++ b/pandas/tests/plotting/test_converter.py @@ -34,9 +34,8 @@ Second, ) +mpl = pytest.importorskip("matplotlib") plt = pytest.importorskip("matplotlib.pyplot") -dates = pytest.importorskip("matplotlib.dates") -units = pytest.importorskip("matplotlib.units") from pandas.plotting._matplotlib import converter @@ -97,8 +96,8 @@ def test_matplotlib_formatters(self): with cf.option_context("plotting.matplotlib.register_converters", True): with cf.option_context("plotting.matplotlib.register_converters", False): - assert Timestamp not in units.registry - assert Timestamp in units.registry + assert Timestamp not in mpl.units.registry + assert Timestamp in mpl.units.registry def test_option_no_warning(self): s = Series(range(12), index=date_range("2017", periods=12)) @@ -115,25 +114,25 @@ def test_option_no_warning(self): def test_registry_resets(self): # make a copy, to reset to - original = dict(units.registry) + original = dict(mpl.units.registry) try: # get to a known state - units.registry.clear() - date_converter = dates.DateConverter() - units.registry[datetime] = date_converter - units.registry[date] = date_converter + mpl.units.registry.clear() + date_converter = mpl.dates.DateConverter() + mpl.units.registry[datetime] = date_converter + mpl.units.registry[date] = date_converter register_matplotlib_converters() - assert units.registry[date] is not date_converter + assert mpl.units.registry[date] is not date_converter deregister_matplotlib_converters() - assert units.registry[date] is date_converter + assert mpl.units.registry[date] is date_converter finally: # restore original stater - units.registry.clear() + mpl.units.registry.clear() for k, v in original.items(): - units.registry[k] = v + mpl.units.registry[k] = v class TestDateTimeConverter: @@ -148,7 +147,7 @@ def test_convert_accepts_unicode(self, dtc): def test_conversion(self, dtc): rs = dtc.convert(["2012-1-1"], None, None)[0] - xp = dates.date2num(datetime(2012, 1, 1)) + xp = mpl.dates.date2num(datetime(2012, 1, 1)) assert rs == xp rs = dtc.convert("2012-1-1", None, None) @@ -196,7 +195,7 @@ def test_conversion_float(self, dtc): rtol = 0.5 * 10**-9 rs = dtc.convert(Timestamp("2012-1-1 01:02:03", tz="UTC"), None, None) - xp = converter.mdates.date2num(Timestamp("2012-1-1 01:02:03", tz="UTC")) + xp = mpl.dates.date2num(Timestamp("2012-1-1 01:02:03", tz="UTC")) tm.assert_almost_equal(rs, xp, rtol=rtol) rs = dtc.convert( @@ -217,10 +216,10 @@ def test_conversion_float(self, dtc): def test_conversion_outofbounds_datetime(self, dtc, values): # 2579 rs = dtc.convert(values, None, None) - xp = converter.mdates.date2num(values) + xp = mpl.dates.date2num(values) tm.assert_numpy_array_equal(rs, xp) rs = dtc.convert(values[0], None, None) - xp = converter.mdates.date2num(values[0]) + xp = mpl.dates.date2num(values[0]) assert rs == xp @pytest.mark.parametrize( @@ -243,7 +242,7 @@ def test_dateindex_conversion(self, freq, dtc): rtol = 10**-9 dateindex = date_range("2020-01-01", periods=10, freq=freq) rs = dtc.convert(dateindex, None, None) - xp = converter.mdates.date2num(dateindex._mpl_repr()) + xp = mpl.dates.date2num(dateindex._mpl_repr()) tm.assert_almost_equal(rs, xp, rtol=rtol) @pytest.mark.parametrize("offset", [Second(), Milli(), Micro(50)]) From ab860ce83147ee8bf4f810f24f6896a6d59f09b0 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 5 Jun 2024 12:49:01 -0700 Subject: [PATCH 3/6] Fix matplotlib units --- pandas/plotting/_matplotlib/converter.py | 25 ++++++++++++------------ pandas/tests/plotting/test_converter.py | 21 ++++++++++---------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/pandas/plotting/_matplotlib/converter.py b/pandas/plotting/_matplotlib/converter.py index 89ca1eef646d3..b2740b78189bc 100644 --- a/pandas/plotting/_matplotlib/converter.py +++ b/pandas/plotting/_matplotlib/converter.py @@ -15,6 +15,7 @@ import warnings import matplotlib as mpl +import matplotlib.units as munits import numpy as np from pandas._libs import lib @@ -115,27 +116,25 @@ def register() -> None: pairs = get_pairs() for type_, cls in pairs: # Cache previous converter if present - if type_ in mpl.units.registry and not isinstance( - mpl.units.registry[type_], cls - ): - previous = mpl.units.registry[type_] + if type_ in munits.registry and not isinstance(munits.registry[type_], cls): + previous = munits.registry[type_] _mpl_units[type_] = previous # Replace with pandas converter - mpl.units.registry[type_] = cls() + munits.registry[type_] = cls() def deregister() -> None: # Renamed in pandas.plotting.__init__ for type_, cls in get_pairs(): # We use type to catch our classes directly, no inheritance - if type(mpl.units.registry.get(type_)) is cls: - mpl.units.registry.pop(type_) + if type(munits.registry.get(type_)) is cls: + munits.registry.pop(type_) # restore the old keys for unit, formatter in _mpl_units.items(): if type(formatter) not in {DatetimeConverter, PeriodConverter, TimeConverter}: # make it idempotent by excluding ours. - mpl.units.registry[unit] = formatter + munits.registry[unit] = formatter def _to_ordinalf(tm: pydt.time) -> float: @@ -152,7 +151,7 @@ def time2num(d): return d -class TimeConverter(mpl.units.ConversionInterface): +class TimeConverter(munits.ConversionInterface): @staticmethod def convert(value, unit, axis): valid_types = (str, pydt.time) @@ -165,13 +164,13 @@ def convert(value, unit, axis): return value @staticmethod - def axisinfo(unit, axis) -> mpl.units.AxisInfo | None: + def axisinfo(unit, axis) -> munits.AxisInfo | None: if unit != "time": return None majloc = mpl.ticker.AutoLocator() majfmt = TimeFormatter(majloc) - return mpl.units.AxisInfo(majloc=majloc, majfmt=majfmt, label="time") + return munits.AxisInfo(majloc=majloc, majfmt=majfmt, label="time") @staticmethod def default_units(x, axis) -> str: @@ -326,7 +325,7 @@ def try_parse(values): return values @staticmethod - def axisinfo(unit: tzinfo | None, axis) -> mpl.units.AxisInfo: + def axisinfo(unit: tzinfo | None, axis) -> munits.AxisInfo: """ Return the :class:`~matplotlib.units.AxisInfo` for *unit*. @@ -340,7 +339,7 @@ def axisinfo(unit: tzinfo | None, axis) -> mpl.units.AxisInfo: datemin = pydt.date(2000, 1, 1) datemax = pydt.date(2010, 1, 1) - return mpl.units.AxisInfo( + return munits.AxisInfo( majloc=majloc, majfmt=majfmt, label="", default_limits=(datemin, datemax) ) diff --git a/pandas/tests/plotting/test_converter.py b/pandas/tests/plotting/test_converter.py index f2619e754d4ef..31c514bf4aec8 100644 --- a/pandas/tests/plotting/test_converter.py +++ b/pandas/tests/plotting/test_converter.py @@ -36,6 +36,7 @@ mpl = pytest.importorskip("matplotlib") plt = pytest.importorskip("matplotlib.pyplot") +munits = pytest.importorskip("matplotlib.units") from pandas.plotting._matplotlib import converter @@ -96,8 +97,8 @@ def test_matplotlib_formatters(self): with cf.option_context("plotting.matplotlib.register_converters", True): with cf.option_context("plotting.matplotlib.register_converters", False): - assert Timestamp not in mpl.units.registry - assert Timestamp in mpl.units.registry + assert Timestamp not in munits.registry + assert Timestamp in munits.registry def test_option_no_warning(self): s = Series(range(12), index=date_range("2017", periods=12)) @@ -114,25 +115,25 @@ def test_option_no_warning(self): def test_registry_resets(self): # make a copy, to reset to - original = dict(mpl.units.registry) + original = dict(munits.registry) try: # get to a known state - mpl.units.registry.clear() + munits.registry.clear() date_converter = mpl.dates.DateConverter() - mpl.units.registry[datetime] = date_converter - mpl.units.registry[date] = date_converter + munits.registry[datetime] = date_converter + munits.registry[date] = date_converter register_matplotlib_converters() - assert mpl.units.registry[date] is not date_converter + assert munits.registry[date] is not date_converter deregister_matplotlib_converters() - assert mpl.units.registry[date] is date_converter + assert munits.registry[date] is date_converter finally: # restore original stater - mpl.units.registry.clear() + munits.registry.clear() for k, v in original.items(): - mpl.units.registry[k] = v + munits.registry[k] = v class TestDateTimeConverter: From 9940ef85c86327d60af07b7a4fc5ffd1a223d222 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 5 Jun 2024 12:52:32 -0700 Subject: [PATCH 4/6] Reduce diff a little more --- pandas/tests/plotting/test_converter.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pandas/tests/plotting/test_converter.py b/pandas/tests/plotting/test_converter.py index 31c514bf4aec8..66059b6ca2988 100644 --- a/pandas/tests/plotting/test_converter.py +++ b/pandas/tests/plotting/test_converter.py @@ -36,7 +36,7 @@ mpl = pytest.importorskip("matplotlib") plt = pytest.importorskip("matplotlib.pyplot") -munits = pytest.importorskip("matplotlib.units") +units = pytest.importorskip("matplotlib.units") from pandas.plotting._matplotlib import converter @@ -97,8 +97,8 @@ def test_matplotlib_formatters(self): with cf.option_context("plotting.matplotlib.register_converters", True): with cf.option_context("plotting.matplotlib.register_converters", False): - assert Timestamp not in munits.registry - assert Timestamp in munits.registry + assert Timestamp not in units.registry + assert Timestamp in units.registry def test_option_no_warning(self): s = Series(range(12), index=date_range("2017", periods=12)) @@ -115,25 +115,25 @@ def test_option_no_warning(self): def test_registry_resets(self): # make a copy, to reset to - original = dict(munits.registry) + original = dict(units.registry) try: # get to a known state - munits.registry.clear() + units.registry.clear() date_converter = mpl.dates.DateConverter() - munits.registry[datetime] = date_converter - munits.registry[date] = date_converter + units.registry[datetime] = date_converter + units.registry[date] = date_converter register_matplotlib_converters() - assert munits.registry[date] is not date_converter + assert units.registry[date] is not date_converter deregister_matplotlib_converters() - assert munits.registry[date] is date_converter + assert units.registry[date] is date_converter finally: # restore original stater - munits.registry.clear() + units.registry.clear() for k, v in original.items(): - munits.registry[k] = v + units.registry[k] = v class TestDateTimeConverter: From daa81e60a4520eb560d9e7409b7b6d0d27e91f5e Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 5 Jun 2024 13:50:08 -0700 Subject: [PATCH 5/6] Import matplotlib dates --- pandas/plotting/_matplotlib/converter.py | 35 ++++++++++++------------ pandas/tests/plotting/test_converter.py | 14 +++++----- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/pandas/plotting/_matplotlib/converter.py b/pandas/plotting/_matplotlib/converter.py index b2740b78189bc..6a37179719c54 100644 --- a/pandas/plotting/_matplotlib/converter.py +++ b/pandas/plotting/_matplotlib/converter.py @@ -15,6 +15,7 @@ import warnings import matplotlib as mpl +import matplotlib.dates as mdates import matplotlib.units as munits import numpy as np @@ -221,7 +222,7 @@ def __call__(self, x, pos: int | None = 0) -> str: # Period Conversion -class PeriodConverter(mpl.dates.DateConverter): +class PeriodConverter(mdates.DateConverter): @staticmethod def convert(values, units, axis): if is_nested_list_like(values): @@ -278,7 +279,7 @@ def get_datevalue(date, freq): # Datetime Conversion -class DatetimeConverter(mpl.dates.DateConverter): +class DatetimeConverter(mdates.DateConverter): @staticmethod def convert(values, unit, axis): # values might be a 1-d array, or a list-like of arrays. @@ -292,12 +293,12 @@ def convert(values, unit, axis): def _convert_1d(values, unit, axis): def try_parse(values): try: - return mpl.dates.date2num(tools.to_datetime(values)) + return mdates.date2num(tools.to_datetime(values)) except Exception: return values if isinstance(values, (datetime, pydt.date, np.datetime64, pydt.time)): - return mpl.dates.date2num(values) + return mdates.date2num(values) elif is_integer(values) or is_float(values): return values elif isinstance(values, str): @@ -320,7 +321,7 @@ def try_parse(values): except Exception: pass - values = mpl.dates.date2num(values) + values = mdates.date2num(values) return values @@ -344,12 +345,12 @@ def axisinfo(unit: tzinfo | None, axis) -> munits.AxisInfo: ) -class PandasAutoDateFormatter(mpl.dates.AutoDateFormatter): +class PandasAutoDateFormatter(mdates.AutoDateFormatter): def __init__(self, locator, tz=None, defaultfmt: str = "%Y-%m-%d") -> None: - mpl.dates.AutoDateFormatter.__init__(self, locator, tz, defaultfmt) + mdates.AutoDateFormatter.__init__(self, locator, tz, defaultfmt) -class PandasAutoDateLocator(mpl.dates.AutoDateLocator): +class PandasAutoDateLocator(mdates.AutoDateLocator): def get_locator(self, dmin, dmax): """Pick the best locator based on a distance.""" tot_sec = (dmax - dmin).total_seconds() @@ -369,17 +370,17 @@ def get_locator(self, dmin, dmax): ) return locator - return mpl.dates.AutoDateLocator.get_locator(self, dmin, dmax) + return mdates.AutoDateLocator.get_locator(self, dmin, dmax) def _get_unit(self): return MilliSecondLocator.get_unit_generic(self._freq) -class MilliSecondLocator(mpl.dates.DateLocator): +class MilliSecondLocator(mdates.DateLocator): UNIT = 1.0 / (24 * 3600 * 1000) def __init__(self, tz) -> None: - mpl.dates.DateLocator.__init__(self, tz) + mdates.DateLocator.__init__(self, tz) self._interval = 1.0 def _get_unit(self): @@ -387,7 +388,7 @@ def _get_unit(self): @staticmethod def get_unit_generic(freq): - unit = mpl.dates.RRuleLocator.get_unit_generic(freq) + unit = mdates.RRuleLocator.get_unit_generic(freq) if unit < 0: return MilliSecondLocator.UNIT return unit @@ -400,7 +401,7 @@ def __call__(self): return [] # We need to cap at the endpoints of valid datetime - nmax, nmin = mpl.dates.date2num((dmax, dmin)) + nmax, nmin = mdates.date2num((dmax, dmin)) num = (nmax - nmin) * 86400 * 1000 max_millis_ticks = 6 @@ -429,12 +430,12 @@ def __call__(self): try: if len(all_dates) > 0: - locs = self.raise_if_exceeds(mpl.dates.date2num(all_dates)) + locs = self.raise_if_exceeds(mdates.date2num(all_dates)) return locs except Exception: # pragma: no cover pass - lims = mpl.dates.date2num([dmin, dmax]) + lims = mdates.date2num([dmin, dmax]) return lims def _get_interval(self): @@ -447,8 +448,8 @@ def autoscale(self): # We need to cap at the endpoints of valid datetime dmin, dmax = self.datalim_to_dt() - vmin = mpl.dates.date2num(dmin) - vmax = mpl.dates.date2num(dmax) + vmin = mdates.date2num(dmin) + vmax = mdates.date2num(dmax) return self.nonsingular(vmin, vmax) diff --git a/pandas/tests/plotting/test_converter.py b/pandas/tests/plotting/test_converter.py index 66059b6ca2988..cfdfa7f723599 100644 --- a/pandas/tests/plotting/test_converter.py +++ b/pandas/tests/plotting/test_converter.py @@ -34,8 +34,8 @@ Second, ) -mpl = pytest.importorskip("matplotlib") plt = pytest.importorskip("matplotlib.pyplot") +dates = pytest.importorskip("matplotlib.dates") units = pytest.importorskip("matplotlib.units") from pandas.plotting._matplotlib import converter @@ -120,7 +120,7 @@ def test_registry_resets(self): try: # get to a known state units.registry.clear() - date_converter = mpl.dates.DateConverter() + date_converter = dates.DateConverter() units.registry[datetime] = date_converter units.registry[date] = date_converter @@ -148,7 +148,7 @@ def test_convert_accepts_unicode(self, dtc): def test_conversion(self, dtc): rs = dtc.convert(["2012-1-1"], None, None)[0] - xp = mpl.dates.date2num(datetime(2012, 1, 1)) + xp = dates.date2num(datetime(2012, 1, 1)) assert rs == xp rs = dtc.convert("2012-1-1", None, None) @@ -196,7 +196,7 @@ def test_conversion_float(self, dtc): rtol = 0.5 * 10**-9 rs = dtc.convert(Timestamp("2012-1-1 01:02:03", tz="UTC"), None, None) - xp = mpl.dates.date2num(Timestamp("2012-1-1 01:02:03", tz="UTC")) + xp = dates.date2num(Timestamp("2012-1-1 01:02:03", tz="UTC")) tm.assert_almost_equal(rs, xp, rtol=rtol) rs = dtc.convert( @@ -217,10 +217,10 @@ def test_conversion_float(self, dtc): def test_conversion_outofbounds_datetime(self, dtc, values): # 2579 rs = dtc.convert(values, None, None) - xp = mpl.dates.date2num(values) + xp = dates.date2num(values) tm.assert_numpy_array_equal(rs, xp) rs = dtc.convert(values[0], None, None) - xp = mpl.dates.date2num(values[0]) + xp = dates.date2num(values[0]) assert rs == xp @pytest.mark.parametrize( @@ -243,7 +243,7 @@ def test_dateindex_conversion(self, freq, dtc): rtol = 10**-9 dateindex = date_range("2020-01-01", periods=10, freq=freq) rs = dtc.convert(dateindex, None, None) - xp = mpl.dates.date2num(dateindex._mpl_repr()) + xp = dates.date2num(dateindex._mpl_repr()) tm.assert_almost_equal(rs, xp, rtol=rtol) @pytest.mark.parametrize("offset", [Second(), Milli(), Micro(50)]) From 13c3b66d71f1fc88e3a07dcb40b6a1ec9fd2a36a Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 5 Jun 2024 15:10:26 -0700 Subject: [PATCH 6/6] satisfy pyright --- pandas/plotting/_matplotlib/converter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pandas/plotting/_matplotlib/converter.py b/pandas/plotting/_matplotlib/converter.py index 6a37179719c54..a8f08769ceae2 100644 --- a/pandas/plotting/_matplotlib/converter.py +++ b/pandas/plotting/_matplotlib/converter.py @@ -169,7 +169,7 @@ def axisinfo(unit, axis) -> munits.AxisInfo | None: if unit != "time": return None - majloc = mpl.ticker.AutoLocator() + majloc = mpl.ticker.AutoLocator() # pyright: ignore[reportAttributeAccessIssue] majfmt = TimeFormatter(majloc) return munits.AxisInfo(majloc=majloc, majfmt=majfmt, label="time") @@ -179,7 +179,7 @@ def default_units(x, axis) -> str: # time formatter -class TimeFormatter(mpl.ticker.Formatter): +class TimeFormatter(mpl.ticker.Formatter): # pyright: ignore[reportAttributeAccessIssue] def __init__(self, locs) -> None: self.locs = locs @@ -912,7 +912,7 @@ def get_finder(freq: BaseOffset): raise NotImplementedError(f"Unsupported frequency: {dtype_code}") -class TimeSeries_DateLocator(mpl.ticker.Locator): +class TimeSeries_DateLocator(mpl.ticker.Locator): # pyright: ignore[reportAttributeAccessIssue] """ Locates the ticks along an axis controlled by a :class:`Series`. @@ -1001,7 +1001,7 @@ def autoscale(self): # ------------------------------------------------------------------------- -class TimeSeries_DateFormatter(mpl.ticker.Formatter): +class TimeSeries_DateFormatter(mpl.ticker.Formatter): # pyright: ignore[reportAttributeAccessIssue] """ Formats the ticks along an axis controlled by a :class:`PeriodIndex`. @@ -1077,7 +1077,7 @@ def __call__(self, x, pos: int | None = 0) -> str: return period.strftime(fmt) -class TimeSeries_TimedeltaFormatter(mpl.ticker.Formatter): +class TimeSeries_TimedeltaFormatter(mpl.ticker.Formatter): # pyright: ignore[reportAttributeAccessIssue] """ Formats the ticks along an axis controlled by a :class:`TimedeltaIndex`. """