diff --git a/data_prototype/axes.py b/data_prototype/axes.py new file mode 100644 index 0000000..cc974a8 --- /dev/null +++ b/data_prototype/axes.py @@ -0,0 +1,141 @@ +import numpy as np + + +import matplotlib as mpl +from matplotlib.axes._axes import Axes as MPLAxes, _preprocess_data +import matplotlib.collections as mcoll +import matplotlib.cbook as cbook +import matplotlib.markers as mmarkers +import matplotlib.projections as mprojections + +from .containers import ArrayContainer, DataUnion +from .conversion_node import ( + MatplotlibUnitConversion, + FunctionConversionNode, + RenameConversionNode, +) +from .wrappers import PathCollectionWrapper + + +class Axes(MPLAxes): + # Name for registering as a projection so we can experiment with it + name = "data-prototype" + + @_preprocess_data( + replace_names=[ + "x", + "y", + "s", + "linewidths", + "edgecolors", + "c", + "facecolor", + "facecolors", + "color", + ], + label_namer="y", + ) + def scatter( + self, + x, + y, + s=None, + c=None, + marker=None, + cmap=None, + norm=None, + vmin=None, + vmax=None, + alpha=None, + linewidths=None, + *, + edgecolors=None, + plotnonfinite=False, + **kwargs + ): + # TODO implement normalize kwargs as a pipeline stage + # add edgecolors and linewidths to kwargs so they can be processed by + # normalize_kwargs + if edgecolors is not None: + kwargs.update({"edgecolors": edgecolors}) + if linewidths is not None: + kwargs.update({"linewidths": linewidths}) + + kwargs = cbook.normalize_kwargs(kwargs, mcoll.Collection) + c, colors, edgecolors = self._parse_scatter_color_args( + c, + edgecolors, + kwargs, + np.ma.ravel(x).size, + get_next_color_func=self._get_patches_for_fill.get_next_color, + ) + + inputs = ArrayContainer( + x=x, + y=y, + s=s, + c=c, + marker=marker, + cmap=cmap, + norm=norm, + vmin=vmin, + vmax=vmax, + alpha=alpha, + plotnonfinite=plotnonfinite, + facecolors=colors, + edgecolors=edgecolors, + **kwargs + ) + # TODO should more go in here? + # marker/s are always in Container, but require overriding if None + # Color handling is odd too + defaults = ArrayContainer( + linewidths=mpl.rcParams["lines.linewidth"], + ) + + cont = DataUnion(defaults, inputs) + + pipeline = [] + xconvert = MatplotlibUnitConversion.from_keys(("x",), axis=self.xaxis) + yconvert = MatplotlibUnitConversion.from_keys(("y",), axis=self.yaxis) + pipeline.extend([xconvert, yconvert]) + pipeline.append(lambda x: np.ma.ravel(x)) + pipeline.append(lambda y: np.ma.ravel(y)) + pipeline.append( + lambda s: np.ma.ravel(s) + if s is not None + else [20] + if mpl.rcParams["_internal.classic_mode"] + else [mpl.rcParams["lines.markersize"] ** 2.0] + ) + # TODO plotnonfinite/mask combining + pipeline.append( + lambda marker: marker + if marker is not None + else mpl.rcParams["scatter.marker"] + ) + pipeline.append( + lambda marker: marker + if isinstance(marker, mmarkers.MarkerStyle) + else mmarkers.MarkerStyle(marker) + ) + pipeline.append( + FunctionConversionNode.from_funcs( + { + "paths": lambda marker: [ + marker.get_path().transformed(marker.get_transform()) + ] + } + ) + ) + pipeline.append(RenameConversionNode.from_mapping({"s": "sizes"})) + + # TODO classic mode margin override? + pcw = PathCollectionWrapper(cont, pipeline, offset_transform=self.transData) + self.add_artist(pcw) + self._request_autoscale_view() + return pcw + + +# This is a handy trick to allow e.g. plt.subplots(subplot_kw={'projection': 'data-prototype'}) +mprojections.register_projection(Axes) diff --git a/data_prototype/containers.py b/data_prototype/containers.py index 4d87446..b311e4d 100644 --- a/data_prototype/containers.py +++ b/data_prototype/containers.py @@ -85,7 +85,12 @@ class ArrayContainer: def __init__(self, **data): self._data = data self._cache_key = str(uuid.uuid4()) - self._desc = {k: Desc(v.shape, v.dtype) for k, v in data.items()} + self._desc = { + k: Desc(v.shape, v.dtype) + if isinstance(v, np.ndarray) + else Desc((), type(v)) + for k, v in data.items() + } def query( self, diff --git a/data_prototype/conversion_node.py b/data_prototype/conversion_node.py index 834896c..0902fa5 100644 --- a/data_prototype/conversion_node.py +++ b/data_prototype/conversion_node.py @@ -6,11 +6,17 @@ import inspect from functools import cached_property +from matplotlib.axis import Axis + from typing import Any def evaluate_pipeline(nodes: Sequence[ConversionNode], input: dict[str, Any]): for node in nodes: + if isinstance(node, Callable): + k = list(inspect.signature(node).parameters.keys())[0] + node = FunctionConversionNode.from_funcs({k: node}) + input = node.evaluate(input) return input @@ -113,3 +119,20 @@ def from_keys(cls, keys: Sequence[str]): def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: return {k: v for k, v in input.items() if k in self.keys} + + +@dataclass +class MatplotlibUnitConversion(ConversionNode): + axis: Axis + + @classmethod + def from_keys(cls, keys: Sequence[str], axis: Axis): + return cls(tuple(keys), tuple(keys), trim_keys=False, axis=axis) + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + return super().evaluate( + { + **input, + **{k: self.axis.convert_units(input[k]) for k in self.required_keys}, + } + ) diff --git a/data_prototype/wrappers.py b/data_prototype/wrappers.py index 14ac51c..a05f3cf 100644 --- a/data_prototype/wrappers.py +++ b/data_prototype/wrappers.py @@ -265,7 +265,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs): @_stale_wrapper def draw(self, renderer): self._update_wrapped( - self._query_and_transform(renderer, xunits=["x"], yunits=["y"]), + self._query_and_transform(renderer, xunits=[], yunits=[]), ) return self._wrapped_instance.draw(renderer) diff --git a/examples/scatter_with_custom_axes.py b/examples/scatter_with_custom_axes.py new file mode 100644 index 0000000..cd9ed58 --- /dev/null +++ b/examples/scatter_with_custom_axes.py @@ -0,0 +1,19 @@ +import data_prototype.axes # side-effect registers projection # noqa + +import matplotlib.pyplot as plt + +fig = plt.figure() +newstyle = fig.add_subplot(2, 1, 1, projection="data-prototype") +oldstyle = fig.add_subplot(2, 1, 2) + +newstyle.scatter([0, 1, 2], [2, 5, 1]) +oldstyle.scatter([0, 1, 2], [2, 5, 1]) +newstyle.scatter([0, 1, 2], [3, 1, 2]) +oldstyle.scatter([0, 1, 2], [3, 1, 2]) + + +# Autoscaling not working +newstyle.set_xlim(oldstyle.get_xlim()) +newstyle.set_ylim(oldstyle.get_ylim()) + +plt.show() diff --git a/examples/units.py b/examples/units.py new file mode 100644 index 0000000..0e042dd --- /dev/null +++ b/examples/units.py @@ -0,0 +1,42 @@ +""" +================================================== +An simple scatter plot using PathCollectionWrapper +================================================== + +A quick scatter plot using :class:`.containers.ArrayContainer` and +:class:`.wrappers.PathCollectionWrapper`. +""" +import numpy as np + +import matplotlib.pyplot as plt +import matplotlib.markers as mmarkers + +from data_prototype.containers import ArrayContainer +from data_prototype.conversion_node import MatplotlibUnitConversion + +from data_prototype.wrappers import PathCollectionWrapper + +import pint + +ureg = pint.UnitRegistry() +ureg.setup_matplotlib() + +marker_obj = mmarkers.MarkerStyle("o") + +cont = ArrayContainer( + x=np.array([0, 1, 2]) * ureg.m, + y=np.array([1, 4, 2]) * ureg.m, + paths=np.array([marker_obj.get_path()]), + sizes=np.array([12]), + edgecolors=np.array(["k"]), + facecolors=np.array(["C3"]), +) + +fig, ax = plt.subplots() +ax.set_xlim(-0.5, 7) +ax.set_ylim(0, 5) +conv = MatplotlibUnitConversion.from_keys(("x",), axis=ax.xaxis) +lw = PathCollectionWrapper(cont, [conv], offset_transform=ax.transData) +ax.add_artist(lw) +ax.xaxis.set_units(ureg.feet) +plt.show()