diff --git a/doc/source/changes/version_0_33.rst.inc b/doc/source/changes/version_0_33.rst.inc index c288ba1f1..47882d096 100644 --- a/doc/source/changes/version_0_33.rst.inc +++ b/doc/source/changes/version_0_33.rst.inc @@ -51,7 +51,18 @@ New features Miscellaneous improvements ^^^^^^^^^^^^^^^^^^^^^^^^^^ -* improved something. +* greatly improved :py:obj:`Array.plot()` method and "submethods" (:py:obj:`Array.plot.bar()`, etc.) + + - support `x`, `y` and `by` arguments in plot functions where it make sense + When only some of them are specified, the other arguments pick from remaining + available axes. This means a lot of plots can now be expressed more intuitively and concisely + (you do not need to transpose your array to get the result you want, you just + specify the axes you want to use in 'x' or 'y'. + - `subplots` argument now accepts an axis (or tuple of them) in addition to a + boolean to specify *which* axes to use as subplots. + - support for *labels* (instead of axes) in x and y for line plot and scatter. + - support passing a dict as legend to customize the legend. + - many tweaks to make several plots look better out of the box. Fixes diff --git a/larray/core/array.py b/larray/core/array.py index aad1f5e8e..8617a2361 100644 --- a/larray/core/array.py +++ b/larray/core/array.py @@ -324,6 +324,306 @@ def __iter__(self): return self +def _use_pandas_plot_docstring(f): + f.__doc__ = getattr(pd.DataFrame.plot, f.__name__).__doc__ + return f + + +class PlotObject(object): + __slots__ = ('array',) + + def __init__(self, array): + self.array = array + + @staticmethod + def _handle_x_y_axes(axes, x, y, subplots): + label_axis = None + + if np.isscalar(x) and x not in axes: + x = axes._translate_axis_key(x) + label_axis = x.axis + + if np.isscalar(y) and y not in axes: + y = axes._translate_axis_key(y) + if label_axis is not None and y.axis is not x.axis: + raise ValueError(f'{x} and {y} are labels from different axes') + label_axis = y.axis + + def handle_axes_arg(avail_axes, arg): + if arg is not None: + arg = avail_axes[arg] + if isinstance(arg, Axis): + arg = AxisCollection([arg]) + avail_axes = avail_axes - arg + return avail_axes, arg + + if label_axis is not None: + available_axes = axes - label_axis + else: + available_axes, x = handle_axes_arg(axes, x) + available_axes, y = handle_axes_arg(available_axes, y) + + if subplots is True: + # use last available axis by default + subplots = [-1] + + if subplots: + available_axes, subplot_axes = handle_axes_arg(available_axes, subplots) + else: + subplot_axes = AxisCollection() + + if label_axis is not None: + series_axes = available_axes[:-1] + if y is None: + # create a Group with the labels of label_axis not used for x + # the weird construction is to get a Group (and not an Axis) but avoid getting an LSet + # which would evaluate to an OrderedSet which is not supported by later code + y = label_axis.i[label_axis[:].difference(x).translate()] + else: + if x is None and y is None: + # use last available axis by default + x = available_axes[[-1]] + series_axes = available_axes - x + elif x is None: + x = available_axes + series_axes = y + y = None + elif y is None: + series_axes = available_axes + else: + if available_axes: + raise ValueError(f"some axes are not used: {available_axes}") + series_axes = y + y = None + assert isinstance(x, AxisCollection) + assert isinstance(series_axes, AxisCollection) + assert isinstance(subplot_axes, AxisCollection) + assert y is None + + return subplot_axes, series_axes, x, y + + @staticmethod + def _to_pd_obj(array): + if array.ndim == 1: + return array.to_series() + else: + return array.to_frame() + + @staticmethod + def _plot_array(array, *args, x=None, y=None, series=None, _x_axes_last=False, **kwargs): + label_axis = None + if array.ndim == 1: + pass + elif isinstance(x, AxisCollection): + # FIXME: arr.plot(x='b', y='d1', subplots='a') does not work + # (and since arr.plot(y='d1', subplots='a') defaults to x='c' we can't get this easily + # XXX: arr.plot(y='d', subplots=True) => x=(a, b), subplots='c' + # I wonder if x='a', subplots=(b, c) wouldn't be better? + assert y is None + # move x_axes first + array = array.transpose(x) + array = array.combine_axes(x, sep=' ') if len(x) >= 2 else array + if _x_axes_last: + # move combined axis last + array = array.transpose(..., array.axes[0]) + x = None + else: + assert (x is None or isinstance(x, Group)) and (y is None or isinstance(y, Group)) + if isinstance(x, Group): + label_axis = x.axis + x = x.eval() + if isinstance(y, Group): + label_axis = y.axis + y = y.eval() + + if label_axis is not None: + # move label_axis last (it must be a dataframe column) + array = array.transpose(..., label_axis) + + lineplot = 'kind' not in kwargs or kwargs['kind'] == 'line' + if lineplot and label_axis is not None and series is not None and len(series) > 0: + # the problem with this approach (n calls to pandas.plot) is that the color + # cycling and "stacked" bar/area of pandas break for all kinds of plots except "line" + # when we have more than one dimension involved + for series_key, series_data in array.items(series): + series_name = ' '.join(str(k) for k in series_key) + # support for list-like y + if isinstance(y, (list, np.ndarray)): + label = [f'{series_name} {y_label}' for y_label in y] + else: + label = f'{series_name} {y}' + PlotObject._to_pd_obj(series_data).plot(*args, x=x, y=y, label=label, **kwargs) + return kwargs['ax'] + else: + # this version works fine for all kinds of plots as long as we only use axes and not labels + if series is not None and len(series) >= 1: + # move series axes first and combine them + array = array.transpose(series).combine_axes(series, sep=' ') + # move it last (as columns) unless we need x axes or label axis last + if not _x_axes_last and label_axis is None: + array = array.transpose(..., array.axes[0]) + + return PlotObject._to_pd_obj(array).plot(*args, x=x, y=y, **kwargs) + + def __call__(self, x=None, y=None, ax=None, subplots=False, layout=None, figsize=None, + sharex=None, sharey=False, tight_layout=None, constrained_layout=None, title=None, legend=None, + **kwargs): + from matplotlib.figure import Figure + + array = self.array + legend_kwargs = legend if isinstance(legend, dict) else {} + + subplot_axes, series_axes, x, y = PlotObject._handle_x_y_axes(array.axes, x, y, subplots) + + if constrained_layout is None: + constrained_layout = True + + if subplots: + if ax is not None: + raise ValueError("ax cannot be used in combination with subplots argument") + fig = Figure(figsize=figsize, tight_layout=tight_layout, constrained_layout=constrained_layout) + + num_subplots = subplot_axes.size + if layout is None: + subplots_shape = subplot_axes.shape + if len(subplots_shape) > 2: + # default to last axis horizontal, other axes combined vertically + layout = np.prod(subplots_shape[:-1]), subplots_shape[-1] + else: + layout = subplot_axes.shape + + if sharex is None: + sharex = True + ax = fig.subplots(*layout, sharex=sharex, sharey=sharey) + # it is easier to always work with a flat array + flat_ax = ax.flat + # remove blank plot(s) at the end, if any + if len(flat_ax) > num_subplots: + for plot_ax in flat_ax[num_subplots:]: + plot_ax.remove() + # this not strictly necessary but is cleaner in case we reuse flax_ax + flat_ax = flat_ax[:num_subplots] + if title is not None: + fig.suptitle(title) + for i, (ndkey, subarr) in enumerate(array.items(subplot_axes)): + title = ' '.join(str(ak) for ak in ndkey) + self._plot_array(subarr, x=x, y=y, series=series_axes, ax=flat_ax[i], legend=False, title=title, + **kwargs) + else: + if ax is None: + fig = Figure(figsize=figsize, tight_layout=tight_layout, constrained_layout=constrained_layout) + ax = fig.subplots(1, 1) + self._plot_array(array, x=x, y=y, series=series_axes, ax=ax, legend=False, title=title, **kwargs) + + if legend or legend is None: + first_ax = ax.flat[0] if subplots else ax + handles, labels = first_ax.get_legend_handles_labels() + if legend is None: + # if there is a single series (per plot), a legend is useless + legend = len(handles) > 1 or legend_kwargs + + if legend: + if 'title' not in legend_kwargs: + axes_names = series_axes.names + # if y is a label (not an axis), this counts as an extra axis as far as the legend is concerned + if isinstance(y, Group): + axes_names += y.axis.name + legend_kwargs['title'] = ' '.join(axes_names) + # use figure to place legend to add a single legend for all subplots + legend_parent = first_ax.figure if subplots else ax + legend_parent.legend(handles, labels, **legend_kwargs) + return ax + + @_use_pandas_plot_docstring + def line(self, x=None, y=None, **kwds): + return self(kind='line', x=x, y=y, **kwds) + + @_use_pandas_plot_docstring + def bar(self, x=None, y=None, **kwds): + return self(kind='bar', x=x, y=y, **kwds) + + @_use_pandas_plot_docstring + def barh(self, x=None, y=None, **kwds): + return self(kind='barh', x=x, y=y, **kwds) + + @_use_pandas_plot_docstring + def box(self, by=None, **kwds): + x = kwds.pop('x', None) + if x is None: + x = by if by is not None else () + ax = self(kind='box', x=x, _x_axes_last=True, **kwds) + if 'ax' not in kwds and by is None: + # avoid having a single None tick + ax.get_xaxis().set_visible(False) + return ax + + @_use_pandas_plot_docstring + def hist(self, by=None, bins=10, **kwds): + y = kwds.pop('y', None) + if y is None: + if by is None: + y = self.array.axes + if 'legend' not in kwds: + kwds['legend'] = False + else: + y = by + return self(kind='hist', y=y, bins=bins, **kwds) + + @_use_pandas_plot_docstring + def kde(self, by=None, bw_method=None, ind=None, **kwds): + y = kwds.pop('y', None) + if y is None: + if by is None: + y = self.array.axes + if 'legend' not in kwds: + kwds['legend'] = False + else: + y = by + return self(kind='kde', bw_method=bw_method, ind=ind, y=y, **kwds) + + @_use_pandas_plot_docstring + def area(self, x=None, y=None, **kwds): + return self(kind='area', x=x, y=y, **kwds) + + @_use_pandas_plot_docstring + def pie(self, y=None, legend=False, **kwds): + if y is None: + # add a dummy axis with blank name and a 'value' label and plot that label to avoid 'None' labels for + # each subplot (when used) if we had used y = () instead + self = self.array.expand(' =__dummy_value').plot + y = '__dummy_value' + if 'ylabel' not in kwds: + # avoid showing '__dummy_value' as ylabel + kwds['ylabel'] = '' + + # avoid a deprecation warning issued by matplotlib 3.3+ (and not fixed in Pandas as of Pandas 1.3.0) + if 'normalize' not in kwds: + kwds['normalize'] = True + + ax = self(kind='pie', y=y, legend=legend, **kwds) + + # if we created the Axes and we have subplots, hide all x axis because as of now + # (pandas 1.3.0 and matplotlib 3.3.4) there are some ugly and useless x axes + # with a few ticks when have subplots in a vertical layout + if 'ax' not in kwds and isinstance(ax, np.ndarray): + for axes in ax.flat: + axes.get_xaxis().set_visible(False) + return ax + + @_use_pandas_plot_docstring + def scatter(self, x, y, s=None, c=None, **kwds): + # TODO: add support for 'c' and 's' even when x and y are not specified + return self(kind='scatter', x=x, y=y, c=c, s=s, **kwds) + + @_use_pandas_plot_docstring + def hexbin(self, x, y, C=None, reduce_C_function=None, gridsize=None, **kwds): + if reduce_C_function is not None: + kwds['reduce_C_function'] = reduce_C_function + if gridsize is not None: + kwds['gridsize'] = gridsize + return self(kind='hexbin', x=x, y=y, C=C, **kwds) + + # TODO: rename to ArrayIndexIndexer or something like that # TODO: the first slice in the example below should be documented class ArrayPositionalIndexer(object): @@ -3084,6 +3384,7 @@ def indicesofsorted(self, axis=None, ascending=True, kind='quicksort'): # TODO: implement keys_by # XXX: implement expand=True? Unsure it is necessary now that we have zip_array_* + # TODO: add support for groups in addition to entire axes def keys(self, axes=None, ascending=True): r"""Returns a view on the array labels along axes. @@ -3162,6 +3463,7 @@ def keys(self, axes=None, ascending=True): return self.axes.iter_labels(axes, ascending=ascending) # TODO: implement values_by + # TODO: add support for groups in addition to entire axes def values(self, axes=None, ascending=True): r"""Returns a view on the values of the array along axes. @@ -6885,8 +7187,10 @@ def plot(self): - 'scatter' : scatter plot (if array's dimensions >= 2) - 'hexbin' : hexbin plot (if array's dimensions >= 2) ax : matplotlib axes object, default None - subplots : boolean, default False - Make separate subplots for each column + subplots : boolean, Axis, int, str or tuple, default False + Make several subplots. If True, will make subplots for each combination of labels for all axes except the + last. If an Axis, int, str (or tuple of those), it will make subplots for combination of labels of those + axes. sharex : boolean, default True if ax is None else False In case subplots=True, share x axis and set some x axis labels to invisible; defaults to True if ax is None otherwise False if an ax is passed in; @@ -6903,7 +7207,7 @@ def plot(self): grid : boolean, default None (matlab style default) Axis grid lines legend : False/True/'reverse' - Place legend on axis subplots + Place legend on axis subplots. Defaults to True. style : list or dict matplotlib line style per column logx : boolean, default False @@ -6929,8 +7233,6 @@ def plot(self): position : float Specify relative alignments for bar plot layout. From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center) - layout : tuple (optional) - (rows, columns) for the layout of the plot yerr : array-like Error bars on y axis xerr : array-like @@ -6950,45 +7252,46 @@ def plot(self): Examples -------- - >>> import matplotlib.pyplot as plt # doctest: +SKIP - >>> a = ndtest('gender=M,F;age=0..20') + >>> import matplotlib.pyplot as plt + >>> # let us define an array with some made up data + >>> arr = Array([[5, 20, 5, 10], [6, 16, 8, 11]], 'gender=M,F;year=2018..2021') Simple line plot - >>> a.plot() # doctest: +SKIP - >>> # shows figure (reset the current figure after showing it! Do not call it before savefig) - >>> plt.show() # doctest: +SKIP + >>> arr.plot() + >>> # show figure (it also resets it after showing it! Do not call it before savefig) + >>> plt.show() - Line plot with grid, title and both axes in logscale + Line plot with grid and a title - >>> a.plot(grid=True, loglog=True, title='line plot') # doctest: +SKIP - >>> # saves figure in a file (see matplotlib.pyplot.savefig documentation for more details) - >>> plt.savefig('my_file.png') # doctest: +SKIP + >>> arr.plot(grid=True, title='line plot') + >>> # save figure in a file (see matplotlib.pyplot.savefig documentation for more details) + >>> plt.savefig('my_file.png') - 2 bar plots sharing the same x axis (one for males and one for females) + 2 bar plots (one for each gender) sharing the same y axis, which makes sub plots + easier to compare. By default sub plots are independant of each other and the axes + ranges are computed to "fit" just the data for their individual plot. - >>> a.plot.bar(subplots=True, sharex=True) # doctest: +SKIP - >>> plt.show() # doctest: +SKIP + >>> arr.plot.bar(subplots='gender', sharey=True) + >>> plt.show() Create a figure containing 2 x 2 graphs >>> # see matplotlib.pyplot.subplots documentation for more details - >>> fig, ax = plt.subplots(2, 2, figsize=(15, 15)) # doctest: +SKIP - >>> # 2 curves : Males and Females - >>> a.plot(ax=ax[0, 0], title='line plot') # doctest: +SKIP - >>> # bar plot with stacked values - >>> a.plot.bar(ax=ax[0, 1], stacked=True, title='stacked bar plot') # doctest: +SKIP - >>> # same as previously but with colored areas instead of bars - >>> a.plot.area(ax=ax[1, 0], title='area plot') # doctest: +SKIP - >>> # scatter plot - >>> a.plot.scatter(ax=ax[1, 1], x='M', y='F', title='scatter plot') # doctest: +SKIP - >>> plt.show() # doctest: +SKIP + >>> fig, ax = plt.subplots(2, 2, figsize=(15, 15)) # doctest: +SKIP + >>> # line plot with 2 curves (Males and Females) in the top left corner (0, 0) + >>> arr.plot(ax=ax[0, 0], title='line plot') # doctest: +SKIP + >>> # bar plot with stacked values in the top right corner (0, 1) + >>> arr.plot.bar(ax=ax[0, 1], stacked=True, title='stacked bar plot') # doctest: +SKIP + >>> # area plot in the bottom left corner (1, 0) + >>> arr.plot.area(ax=ax[1, 0], title='area plot') # doctest: +SKIP + >>> # scatter plot in the bottom right corner (1, 1) + >>> arr.plot.scatter(ax=ax[1, 1], x='M', y='F', title='scatter plot') # doctest: +SKIP + >>> arr.plot.scatter(ax=ax[1, 1], x='M', y='F', c=arr.year, colormap='viridis', + ... title='scatter plot') # doctest: +SKIP + >>> plt.show() # doctest: +SKIP """ - combined = self.combine_axes(self.axes[:-1], sep=' ') if self.ndim > 2 else self - if combined.ndim == 1: - return combined.to_series().plot - else: - return combined.transpose().to_frame().plot + return PlotObject(self) @property def shape(self): @@ -7566,7 +7869,7 @@ def combine_axes(self, axes=None, sep='_', wildcard=False): transposed_axes = transposed_axes - axes_to_combine transposed_axes = transposed_axes[:min_axis_index] + axes_to_combine + transposed_axes[min_axis_index:] transposed = self.transpose(transposed_axes) - + # XXX: I think this might be problematic if axes to combine are given by position instead of by name/object new_axes = transposed.axes.combine_axes(axes, sep=sep, wildcard=wildcard) return transposed.reshape(new_axes) diff --git a/larray/core/misc.py b/larray/core/misc.py index 28bee6fbc..9b042b52c 100644 --- a/larray/core/misc.py +++ b/larray/core/misc.py @@ -2,8 +2,6 @@ import numpy as np -from larray.core.array import ndtest # noqa: F401 - def isscalar(element: Any) -> bool: r""" @@ -21,6 +19,7 @@ def isscalar(element: Any) -> bool: Examples -------- + >>> from larray import ndtest >>> isscalar(3.1) True >>> isscalar([3.1]) diff --git a/setup.cfg b/setup.cfg index c10ab4a46..e71069fae 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,14 +3,30 @@ test=pytest [tool:pytest] testpaths = larray -# - exclude (doc)tests from ufuncs (because docstrings are copied from numpy -# and many of those doctests are failing -# - deselect Array.astype since doctests fails for Python 3.6 and numpy >= 1.17 addopts = -v --doctest-modules + # exclude (doc)tests from ufuncs (because docstrings are copied from numpy + # and many of those doctests are failing) --ignore=larray/core/npufuncs.py --ignore=larray/ipfp --ignore=larray/inout/xw_reporting.py + # doctest fails for Python 3.6 and numpy >= 1.17 --deselect larray/core/array.py::larray.core.array.Array.astype + # doctest fails (because the plot method returns a matplotlib axis object, + # which we do not mention in the doctest to make it nicer) + --deselect larray/core/array.py::larray.core.array.Array.plot + # skip Pandas-leeched doctests because they are not larray-specific and, + # without Pandas-specific documentation build infrastructure, they leave + # some plots open + --deselect larray/core/array.py::larray.core.array.PlotObject.area + --deselect larray/core/array.py::larray.core.array.PlotObject.bar + --deselect larray/core/array.py::larray.core.array.PlotObject.barh + --deselect larray/core/array.py::larray.core.array.PlotObject.box + --deselect larray/core/array.py::larray.core.array.PlotObject.hexbin + --deselect larray/core/array.py::larray.core.array.PlotObject.hist + --deselect larray/core/array.py::larray.core.array.PlotObject.kde + --deselect larray/core/array.py::larray.core.array.PlotObject.line + --deselect larray/core/array.py::larray.core.array.PlotObject.pie + --deselect larray/core/array.py::larray.core.array.PlotObject.scatter --flake8 #--cov