diff --git a/data_prototype/axes.py b/data_prototype/axes.py new file mode 100644 index 0000000..3b9dd48 --- /dev/null +++ b/data_prototype/axes.py @@ -0,0 +1,141 @@ +import numpy as np + + +import matplotlib as mpl +from matplotlib.axes._axes import Axes as MPLAxes, _preprocess_data +import matplotlib.collections as mcoll +import matplotlib.cbook as cbook +import matplotlib.markers as mmarkers +import matplotlib.projections as mprojections + +from .containers import ArrayContainer, DataUnion +from .conversion_node import ( + DelayedConversionNode, + FunctionConversionNode, + RenameConversionNode, +) +from .wrappers import PathCollectionWrapper + + +class Axes(MPLAxes): + # Name for registering as a projection so we can experiment with it + name = "data-prototype" + + @_preprocess_data( + replace_names=[ + "x", + "y", + "s", + "linewidths", + "edgecolors", + "c", + "facecolor", + "facecolors", + "color", + ], + label_namer="y", + ) + def scatter( + self, + x, + y, + s=None, + c=None, + marker=None, + cmap=None, + norm=None, + vmin=None, + vmax=None, + alpha=None, + linewidths=None, + *, + edgecolors=None, + plotnonfinite=False, + **kwargs + ): + # TODO implement normalize kwargs as a pipeline stage + # add edgecolors and linewidths to kwargs so they can be processed by + # normalize_kwargs + if edgecolors is not None: + kwargs.update({"edgecolors": edgecolors}) + if linewidths is not None: + kwargs.update({"linewidths": linewidths}) + + kwargs = cbook.normalize_kwargs(kwargs, mcoll.Collection) + c, colors, edgecolors = self._parse_scatter_color_args( + c, + edgecolors, + kwargs, + np.ma.ravel(x).size, + get_next_color_func=self._get_patches_for_fill.get_next_color, + ) + + inputs = ArrayContainer( + x=x, + y=y, + s=s, + c=c, + marker=marker, + cmap=cmap, + norm=norm, + vmin=vmin, + vmax=vmax, + alpha=alpha, + plotnonfinite=plotnonfinite, + facecolors=colors, + edgecolors=edgecolors, + **kwargs + ) + # TODO should more go in here? + # marker/s are always in Container, but require overriding if None + # Color handling is odd too + defaults = ArrayContainer( + linewidths=mpl.rcParams["lines.linewidth"], + ) + + cont = DataUnion(defaults, inputs) + + pipeline = [] + xconvert = DelayedConversionNode.from_keys(("x",), converter_key="xunits") + yconvert = DelayedConversionNode.from_keys(("y",), converter_key="yunits") + pipeline.extend([xconvert, yconvert]) + pipeline.append(lambda x: np.ma.ravel(x)) + pipeline.append(lambda y: np.ma.ravel(y)) + pipeline.append( + lambda s: np.ma.ravel(s) + if s is not None + else [20] + if mpl.rcParams["_internal.classic_mode"] + else [mpl.rcParams["lines.markersize"] ** 2.0] + ) + # TODO plotnonfinite/mask combining + pipeline.append( + lambda marker: marker + if marker is not None + else mpl.rcParams["scatter.marker"] + ) + pipeline.append( + lambda marker: marker + if isinstance(marker, mmarkers.MarkerStyle) + else mmarkers.MarkerStyle(marker) + ) + pipeline.append( + FunctionConversionNode.from_funcs( + { + "paths": lambda marker: [ + marker.get_path().transformed(marker.get_transform()) + ] + } + ) + ) + pipeline.append(RenameConversionNode.from_mapping({"s": "sizes"})) + + # TODO classic mode margin override? + pcw = PathCollectionWrapper(cont, pipeline, offset_transform=self.transData) + self.add_artist(pcw) + self._request_autoscale_view() + return pcw + + +# This is a handy trick to allow e.g. plt.subplots(subplot_kw={'projection': 'data-prototype'}) +mprojections.register_projection(Axes) diff --git a/data_prototype/containers.py b/data_prototype/containers.py index 4d87446..b311e4d 100644 --- a/data_prototype/containers.py +++ b/data_prototype/containers.py @@ -85,7 +85,12 @@ class ArrayContainer: def __init__(self, **data): self._data = data self._cache_key = str(uuid.uuid4()) - self._desc = {k: Desc(v.shape, v.dtype) for k, v in data.items()} + self._desc = { + k: Desc(v.shape, v.dtype) + if isinstance(v, np.ndarray) + else Desc((), type(v)) + for k, v in data.items() + } def query( self, diff --git a/data_prototype/conversion_node.py b/data_prototype/conversion_node.py index 834896c..48f97d0 100644 --- a/data_prototype/conversion_node.py +++ b/data_prototype/conversion_node.py @@ -9,9 +9,19 @@ from typing import Any -def evaluate_pipeline(nodes: Sequence[ConversionNode], input: dict[str, Any]): +def evaluate_pipeline( + nodes: Sequence[ConversionNode], + input: dict[str, Any], + delayed_converters: dict[str, Callable] | None = None, +): for node in nodes: - input = node.evaluate(input) + if isinstance(node, Callable): + k = list(inspect.signature(node).parameters.keys())[0] + node = FunctionConversionNode.from_funcs({k: node}) + if isinstance(node, DelayedConversionNode): + input = node.evaluate(input, delayed_converters) + else: + input = node.evaluate(input) return input @@ -113,3 +123,27 @@ def from_keys(cls, keys: Sequence[str]): def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: return {k: v for k, v in input.items() if k in self.keys} + + +@dataclass +class DelayedConversionNode(ConversionNode): + converter_key: str + + @classmethod + def from_keys(cls, keys: Sequence[str], converter_key: str): + return cls( + tuple(keys), tuple(keys), trim_keys=False, converter_key=converter_key + ) + + def evaluate( + self, input: dict[str, Any], converters: dict[str, Callable] | None = None + ) -> dict[str, Any]: + return super().evaluate( + { + **input, + **{ + k: converters[self.converter_key](input[k]) + for k in self.required_keys + }, + } + ) diff --git a/data_prototype/patches.py b/data_prototype/patches.py index b062024..3c05c01 100644 --- a/data_prototype/patches.py +++ b/data_prototype/patches.py @@ -30,8 +30,6 @@ class PatchWrapper(ProxyWrapper): "set_joinstyle", "set_path", ) - _xunits = () - _yunits = () required_keys = { "edgecolor", "facecolor", @@ -50,11 +48,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs): @_stale_wrapper def draw(self, renderer): - self._update_wrapped( - self._query_and_transform( - renderer, xunits=self._xunits, yunits=self._yunits - ) - ) + self._update_wrapped(self._query_and_transform(renderer)) return self._wrapped_instance.draw(renderer) def _update_wrapped(self, data): @@ -77,8 +71,6 @@ class RectangleWrapper(PatchWrapper): "set_angle", "set_rotation_point", ) - _xunits = ("x", "width") - _yunits = ("y", "height") required_keys = PatchWrapper.required_keys | { "x", "y", diff --git a/data_prototype/wrappers.py b/data_prototype/wrappers.py index 14ac51c..8c80779 100644 --- a/data_prototype/wrappers.py +++ b/data_prototype/wrappers.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Any, Protocol, Tuple, get_type_hints +from typing import Dict, Any, Protocol, Tuple, get_type_hints import inspect import numpy as np @@ -121,17 +121,13 @@ def draw(self, renderer): def _update_wrapped(self, data): raise NotImplementedError - def _query_and_transform( - self, renderer, *, xunits: List[str], yunits: List[str] - ) -> Dict[str, Any]: + def _query_and_transform(self, renderer) -> Dict[str, Any]: """ Helper to centralize the data querying and python-side transforms Parameters ---------- renderer : RendererBase - xunits, yunits : List[str] - The list of keys that need to be run through the x and y unit machinery. """ # extract what we need to about the axes to query the data ax = self.axes @@ -153,8 +149,11 @@ def _query_and_transform( return self._cache[cache_key] except KeyError: ... - # TODO units - transformed_data = evaluate_pipeline(self._converters, data) + delayed_conversion = { + "xunits": ax.xaxis.convert_units, + "yunits": ax.yaxis.convert_units, + } + transformed_data = evaluate_pipeline(self._converters, data, delayed_conversion) self._cache[cache_key] = transformed_data return transformed_data @@ -232,7 +231,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs): @_stale_wrapper def draw(self, renderer): self._update_wrapped( - self._query_and_transform(renderer, xunits=["x"], yunits=["y"]), + self._query_and_transform(renderer), ) return self._wrapped_instance.draw(renderer) @@ -265,7 +264,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs): @_stale_wrapper def draw(self, renderer): self._update_wrapped( - self._query_and_transform(renderer, xunits=["x"], yunits=["y"]), + self._query_and_transform(renderer), ) return self._wrapped_instance.draw(renderer) @@ -304,7 +303,7 @@ def __init__( @_stale_wrapper def draw(self, renderer): self._update_wrapped( - self._query_and_transform(renderer, xunits=["xextent"], yunits=["yextent"]), + self._query_and_transform(renderer), ) return self._wrapped_instance.draw(renderer) @@ -325,7 +324,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs): @_stale_wrapper def draw(self, renderer): self._update_wrapped( - self._query_and_transform(renderer, xunits=["edges"], yunits=["density"]), + self._query_and_transform(renderer), ) return self._wrapped_instance.draw(renderer) @@ -344,7 +343,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs): @_stale_wrapper def draw(self, renderer): self._update_wrapped( - self._query_and_transform(renderer, xunits=[], yunits=[]), + self._query_and_transform(renderer), ) return self._wrapped_instance.draw(renderer) @@ -425,11 +424,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs): @_stale_wrapper def draw(self, renderer): self._update_wrapped( - self._query_and_transform( - renderer, - xunits=["x", "xupper", "xlower"], - yunits=["y", "yupper", "ylower"], - ), + self._query_and_transform(renderer), ) for k, v in self._wrapped_instances.items(): v.draw(renderer) diff --git a/examples/scatter_with_custom_axes.py b/examples/scatter_with_custom_axes.py new file mode 100644 index 0000000..aa37c76 --- /dev/null +++ b/examples/scatter_with_custom_axes.py @@ -0,0 +1,33 @@ +""" +========================================= +An simple scatter plot using `ax.scatter` +========================================= + +This is a quick comparison between the current Matplotlib `scatter` and +the version in :file:`data_prototype/axes.py`, which uses data containers +and a conversion pipeline. + +This is here to show what does work and what does not work with the current +implementation of container-based artist drawing. +""" + + +import data_prototype.axes # side-effect registers projection # noqa + +import matplotlib.pyplot as plt + +fig = plt.figure() +newstyle = fig.add_subplot(2, 1, 1, projection="data-prototype") +oldstyle = fig.add_subplot(2, 1, 2) + +newstyle.scatter([0, 1, 2], [2, 5, 1]) +oldstyle.scatter([0, 1, 2], [2, 5, 1]) +newstyle.scatter([0, 1, 2], [3, 1, 2]) +oldstyle.scatter([0, 1, 2], [3, 1, 2]) + + +# Autoscaling not working +newstyle.set_xlim(oldstyle.get_xlim()) +newstyle.set_ylim(oldstyle.get_ylim()) + +plt.show() diff --git a/examples/units.py b/examples/units.py new file mode 100644 index 0000000..ae6702b --- /dev/null +++ b/examples/units.py @@ -0,0 +1,47 @@ +""" +=========================================== +Using pint units with PathCollectionWrapper +=========================================== + +Using third party units functionality in conjunction with Matplotlib Axes +""" +import numpy as np + +import matplotlib.pyplot as plt +import matplotlib.markers as mmarkers + +from data_prototype.containers import ArrayContainer +from data_prototype.conversion_node import DelayedConversionNode + +from data_prototype.wrappers import PathCollectionWrapper + +import pint + +ureg = pint.UnitRegistry() +ureg.setup_matplotlib() + +marker_obj = mmarkers.MarkerStyle("o") + +cont = ArrayContainer( + x=np.array([0, 1, 2]) * ureg.m, + y=np.array([1, 4, 2]) * ureg.m, + paths=np.array([marker_obj.get_path()]), + sizes=np.array([12]), + edgecolors=np.array(["k"]), + facecolors=np.array(["C3"]), +) + +fig, ax = plt.subplots() +ax.set_xlim(-0.5, 7) +ax.set_ylim(0, 5) + +# DelayedConversionNode is used to identify the keys which undergo unit transformations +# The actual method which does conversions in this example is added by the +# `Axis`/`Axes`, but `PathCollectionWrapper` does not natively interact with the units. +xconv = DelayedConversionNode.from_keys(("x",), converter_key="xunits") +yconv = DelayedConversionNode.from_keys(("y",), converter_key="yunits") +lw = PathCollectionWrapper(cont, [xconv, yconv], offset_transform=ax.transData) +ax.add_artist(lw) +ax.xaxis.set_units(ureg.feet) +ax.yaxis.set_units(ureg.m) +plt.show() diff --git a/requirements-doc.txt b/requirements-doc.txt index 73d145b..bdaeb96 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -7,3 +7,4 @@ sphinx-copybutton sphinx-gallery ipython scikit-image +pint