diff --git a/data_prototype/conversion_node.py b/data_prototype/conversion_node.py new file mode 100644 index 0000000..fefc8bd --- /dev/null +++ b/data_prototype/conversion_node.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from collections.abc import Iterable, Callable, Sequence +from collections import Counter +from dataclasses import dataclass +import inspect +from functools import cached_property + +from typing import Any + + +def evaluate_pipeline(nodes: Sequence[ConversionNode], input: dict[str, Any]): + for node in nodes: + input = node.evaluate(input) + return input + + +@dataclass +class ConversionNode: + required_keys: tuple[str, ...] + output_keys: tuple[str, ...] + trim_keys: bool + + def preview_keys(self, input_keys: Iterable[str]) -> tuple[str, ...]: + if missing_keys := set(self.required_keys) - set(input_keys): + raise ValueError(f"Missing keys: {missing_keys}") + if self.trim_keys: + return tuple(sorted(set(self.output_keys))) + return tuple(sorted(set(input_keys) | set(self.output_keys))) + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + if self.trim_keys: + return {k: input[k] for k in self.output_keys} + else: + if missing_keys := set(self.output_keys) - set(input): + raise ValueError(f"Missing keys: {missing_keys}") + return input + + +@dataclass +class UnionConversionNode(ConversionNode): + nodes: tuple[ConversionNode, ...] + + @classmethod + def from_nodes(cls, *nodes: ConversionNode, trim_keys=False): + required = tuple(set(k for n in nodes for k in n.required_keys)) + output = Counter(k for n in nodes for k in n.output_keys) + if duplicate := {k for k, v in output.items() if v > 1}: + raise ValueError(f"Duplicate keys from multiple input nodes: {duplicate}") + return cls(required, tuple(output), trim_keys, nodes) + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + return super().evaluate({k: v for n in self.nodes for k, v in n.evaluate(input).items()}) + + +@dataclass +class RenameConversionNode(ConversionNode): + mapping: dict[str, str] + + @classmethod + def from_mapping(cls, mapping: dict[str, str], trim_keys=False): + required = tuple(mapping) + output = Counter(mapping.values()) + if duplicate := {k for k, v in output.items() if v > 1}: + raise ValueError(f"Duplicate output keys in mapping: {duplicate}") + return cls(required, tuple(output), trim_keys, mapping) + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + return super().evaluate({**input, **{out: input[inp] for (inp, out) in self.mapping.items()}}) + + +@dataclass +class FunctionConversionNode(ConversionNode): + funcs: dict[str, Callable] + + @cached_property + def _sigs(self): + return {k: (f, inspect.signature(f)) for k, f in self.funcs.items()} + + @classmethod + def from_funcs(cls, funcs: dict[str, Callable], trim_keys=False): + sigs = {k: inspect.signature(f) for k, f in funcs.items()} + output = tuple(sigs) + input = [] + for v in sigs.values(): + input.extend(v.parameters.keys()) + input = tuple(set(input)) + return cls(input, output, trim_keys, funcs) + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + return super().evaluate( + { + **input, + **{k: func(**{p: input[p] for p in sig.parameters}) for (k, (func, sig)) in self._sigs.items()}, + } + ) + + +@dataclass +class LimitKeysConversionNode(ConversionNode): + keys: set[str] + + @classmethod + def from_keys(cls, keys: Sequence[str]): + return cls((), tuple(keys), trim_keys=True, keys=set(keys)) + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in input.items() if k in self.keys} diff --git a/data_prototype/patches.py b/data_prototype/patches.py index 94a50af..d993f2f 100644 --- a/data_prototype/patches.py +++ b/data_prototype/patches.py @@ -44,8 +44,8 @@ class PatchWrapper(ProxyWrapper): "joinstyle", } - def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + def __init__(self, data: DataContainer, converters=None, /, **kwargs): + super().__init__(data, converters) self._wrapped_instance = self._wrapped_class([0, 0], 0, 0, **kwargs) @_stale_wrapper diff --git a/data_prototype/tests/test_containers.py b/data_prototype/tests/test_containers.py index ddea3c6..4fadefb 100644 --- a/data_prototype/tests/test_containers.py +++ b/data_prototype/tests/test_containers.py @@ -14,7 +14,6 @@ def ac(): def _verify_describe(container): - data, cache_key = container.query(IdentityTransform(), [100, 100]) desc = container.describe() diff --git a/data_prototype/wrappers.py b/data_prototype/wrappers.py index 83ab843..e447e63 100644 --- a/data_prototype/wrappers.py +++ b/data_prototype/wrappers.py @@ -16,6 +16,13 @@ from matplotlib.artist import Artist as _Artist from data_prototype.containers import DataContainer, _MatplotlibTransform +from data_prototype.conversion_node import ( + ConversionNode, + RenameConversionNode, + evaluate_pipeline, + FunctionConversionNode, + LimitKeysConversionNode, +) class _BBox(Protocol): @@ -139,45 +146,26 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str] return self._cache[cache_key] except KeyError: ... - # TODO decide if units go pre-nu or post-nu? - for x_like in xunits: - if x_like in data: - data[x_like] = ax.xaxis.convert_units(data[x_like]) - for y_like in yunits: - if y_like in data: - data[y_like] = ax.xaxis.convert_units(data[y_like]) - - # doing the nu work here is nice because we can write it once, but we - # really want to push this computation down a layer - # TODO sort out how this interoperates with the transform stack - transformed_data = {} - for k, (nu, sig) in self._sigs.items(): - to_pass = set(sig.parameters) - transformed_data[k] = nu(**{k: data[k] for k in to_pass}) + # TODO units + transformed_data = evaluate_pipeline(self._converters, data) self._cache[cache_key] = transformed_data return transformed_data - def __init__(self, data, nus, **kwargs): + def __init__(self, data, converters: ConversionNode | list[ConversionNode] | None, **kwargs): super().__init__(**kwargs) self.data = data self._cache = LFUCache(64) # TODO make sure mutating this will invalidate the cache! - self._nus = nus or {} - for k in self.required_keys: - self._nus.setdefault(k, _make_identity(k)) - desc = data.describe() - for k in self.expected_keys: - if k in desc: - self._nus.setdefault(k, _make_identity(k)) - self._sigs = {k: (nu, inspect.signature(nu)) for k, nu in self._nus.items()} + if isinstance(converters, ConversionNode): + converters = [converters] + self._converters: list[ConversionNode] = converters or [] + setters = list(self.expected_keys | self.required_keys) + if hasattr(self, "_wrapped_class"): + setters += [f[4:] for f in dir(self._wrapped_class) if f.startswith("set_")] + self._converters.append(LimitKeysConversionNode.from_keys(setters)) self.stale = True - # TODO add a setter - @property - def nus(self): - return dict(self._nus) - class ProxyWrapper(ProxyWrapperBase): _privtized_methods: Tuple[str, ...] = () @@ -192,7 +180,7 @@ def __getattr__(self, key): return getattr(self._wrapped_instance, key) def __setattr__(self, key, value): - if key in ("_wrapped_instance", "data", "_cache", "_nus", "stale", "_sigs"): + if key in ("_wrapped_instance", "data", "_cache", "_converters", "stale", "_sigs"): super().__setattr__(key, value) elif hasattr(self, "_wrapped_instance") and hasattr(self._wrapped_instance, key): setattr(self._wrapped_instance, key, value) @@ -205,9 +193,12 @@ class LineWrapper(ProxyWrapper): _privtized_methods = ("set_xdata", "set_ydata", "set_data", "get_xdata", "get_ydata", "get_data") required_keys = {"x", "y"} - def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + def __init__(self, data: DataContainer, converters=None, /, **kwargs): + super().__init__(data, converters) self._wrapped_instance = self._wrapped_class(np.array([]), np.array([]), **kwargs) + self._converters.insert(-1, RenameConversionNode.from_mapping({"x": "xdata", "y": "ydata"})) + setters = [f[4:] for f in dir(self._wrapped_class) if f.startswith("set_")] + self._converters[-1] = LimitKeysConversionNode.from_keys(setters) @_stale_wrapper def draw(self, renderer): @@ -218,7 +209,6 @@ def draw(self, renderer): def _update_wrapped(self, data): for k, v in data.items(): - k = {"x": "xdata", "y": "ydata"}.get(k, k) getattr(self._wrapped_instance, f"set_{k}")(v) @@ -238,8 +228,8 @@ class PathCollectionWrapper(ProxyWrapper): "get_paths", ) - def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + def __init__(self, data: DataContainer, converters=None, /, **kwargs): + super().__init__(data, converters) self._wrapped_instance = self._wrapped_class([], **kwargs) self._wrapped_instance.set_transform(mtransforms.IdentityTransform()) @@ -262,17 +252,17 @@ class ImageWrapper(ProxyWrapper): _wrapped_class = _AxesImage required_keys = {"xextent", "yextent", "image"} - def __init__(self, data: DataContainer, nus=None, /, cmap=None, norm=None, **kwargs): - nus = dict(nus or {}) + def __init__(self, data: DataContainer, converters=None, /, cmap=None, norm=None, **kwargs): + converters = converters or [] if cmap is not None or norm is not None: - if nus is not None and "image" in nus: + if converters is not None and "image" in converters: raise ValueError("Conflicting input") if cmap is None: cmap = mpl.colormaps["viridis"] if norm is None: raise ValueError("not sure how to do autoscaling yet") - nus["image"] = lambda image: cmap(norm(image)) - super().__init__(data, nus) + converters.append(FunctionConversionNode.from_funcs({"image": lambda image: cmap(norm(image))})) + super().__init__(data, converters) kwargs.setdefault("origin", "lower") self._wrapped_instance = self._wrapped_class(None, **kwargs) @@ -293,8 +283,8 @@ class StepWrapper(ProxyWrapper): _privtized_methods = () # ("set_data", "get_data") required_keys = {"edges", "density"} - def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + def __init__(self, data: DataContainer, converters=None, /, **kwargs): + super().__init__(data, converters) self._wrapped_instance = self._wrapped_class([], [1], **kwargs) @_stale_wrapper @@ -312,8 +302,8 @@ class FormatedText(ProxyWrapper): _wrapped_class = _Text _privtized_methods = ("set_text",) - def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + def __init__(self, data: DataContainer, converters=None, /, **kwargs): + super().__init__(data, converters) self._wrapped_instance = self._wrapped_class(text="", **kwargs) @_stale_wrapper @@ -368,8 +358,8 @@ class ErrorbarWrapper(MultiProxyWrapper): required_keys = {"x", "y"} expected_keys = {f"{axis}{dirc}" for axis in ["x", "y"] for dirc in ["upper", "lower"]} - def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + def __init__(self, data: DataContainer, converters=None, /, **kwargs): + super().__init__(data, converters) # TODO all of the kwarg teasing apart that is needed color = kwargs.pop("color", "k") lw = kwargs.pop("lw", 2) diff --git a/examples/2Dfunc.py b/examples/2Dfunc.py index 6d200de..623a7d1 100644 --- a/examples/2Dfunc.py +++ b/examples/2Dfunc.py @@ -13,7 +13,6 @@ from data_prototype.wrappers import ImageWrapper from data_prototype.containers import FuncContainer -import matplotlib as mpl from matplotlib.colors import Normalize @@ -25,9 +24,8 @@ "image": (("N", "M"), lambda x, y: np.sin(x).reshape(1, -1) * np.cos(y).reshape(-1, 1)), }, ) -cmap = mpl.colormaps["viridis"] -norm = Normalize(-1, 1) -im = ImageWrapper(fc, {"image": lambda image: cmap(norm(image))}) +norm = Normalize(vmin=-1, vmax=1) +im = ImageWrapper(fc, norm=norm) fig, ax = plt.subplots() ax.add_artist(im) diff --git a/examples/animation.py b/examples/animation.py index 316dbaf..628f313 100644 --- a/examples/animation.py +++ b/examples/animation.py @@ -17,6 +17,7 @@ from matplotlib.animation import FuncAnimation from data_prototype.containers import _MatplotlibTransform, Desc +from data_prototype.conversion_node import FunctionConversionNode from data_prototype.wrappers import LineWrapper, FormatedText @@ -63,9 +64,9 @@ def update(frame, art): lw = LineWrapper(sot_c, lw=5, color="green", label="sin(time)") fc = FormatedText( sot_c, - {"text": lambda phase: f"ϕ={phase:.2f}"}, - x=2 * np.pi, - y=1, + FunctionConversionNode.from_funcs( + {"text": lambda phase: f"ϕ={phase:.2f}", "x": lambda: 2 * np.pi, "y": lambda: 1} + ), ha="right", ) fig, ax = plt.subplots() diff --git a/examples/mapped.py b/examples/mapped.py index 1da081c..0fe7898 100644 --- a/examples/mapped.py +++ b/examples/mapped.py @@ -3,7 +3,7 @@ Mapping Line Properties ======================= -Leveraging the nu functions to transform users space data to visualization data. +Leveraging the converter functions to transform users space data to visualization data. """ @@ -14,25 +14,31 @@ from data_prototype.wrappers import LineWrapper, FormatedText from data_prototype.containers import ArrayContainer +from data_prototype.conversion_node import FunctionConversionNode cmap = plt.colormaps["viridis"] cmap.set_over("k") cmap.set_under("r") norm = Normalize(1, 8) -line_nus = { - # arbitrary functions - "lw": lambda lw: min(1 + lw, 5), - # standard color mapping - "color": lambda j: cmap(norm(j)), - # categorical - "ls": lambda cat: {"A": "-", "B": ":", "C": "--"}[cat[()]], -} +line_converter = FunctionConversionNode.from_funcs( + { + # arbitrary functions + "lw": lambda lw: min(1 + lw, 5), + # standard color mapping + "color": lambda j: cmap(norm(j)), + # categorical + "ls": lambda cat: {"A": "-", "B": ":", "C": "--"}[cat[()]], + }, +) -text_nus = { - "text": lambda j, cat: f"index={j[()]} class={cat[()]!r}", - "y": lambda j: j, -} +text_converter = FunctionConversionNode.from_funcs( + { + "text": lambda j, cat: f"index={j[()]} class={cat[()]!r}", + "y": lambda j: j, + "x": lambda x: 2 * np.pi, + }, +) th = np.linspace(0, 2 * np.pi, 128) @@ -53,13 +59,13 @@ ax.add_artist( LineWrapper( ac, - line_nus, + line_converter, ) ) ax.add_artist( FormatedText( ac, - text_nus, + text_converter, x=2 * np.pi, ha="right", bbox={"facecolor": "gray", "alpha": 0.5}, diff --git a/examples/mulivariate_cmap.py b/examples/mulivariate_cmap.py index 8b6377e..c00b709 100644 --- a/examples/mulivariate_cmap.py +++ b/examples/mulivariate_cmap.py @@ -13,6 +13,7 @@ from data_prototype.wrappers import ImageWrapper from data_prototype.containers import FuncContainer +from data_prototype.conversion_node import FunctionConversionNode from matplotlib.colors import hsv_to_rgb @@ -40,7 +41,7 @@ def image_nu(image): }, ) -im = ImageWrapper(fc, {"image": image_nu}) +im = ImageWrapper(fc, FunctionConversionNode.from_funcs({"image": image_nu})) fig, ax = plt.subplots() ax.add_artist(im) diff --git a/examples/subsample.py b/examples/subsample.py index 2e0e4ab..fc84c5a 100644 --- a/examples/subsample.py +++ b/examples/subsample.py @@ -68,7 +68,7 @@ def query( sub = Subsample() cmap = mpl.colormaps["coolwarm"] norm = Normalize(-2.2, 2.2) -im = ImageWrapper(sub, {"image": lambda image: cmap(norm(image))}) +im = ImageWrapper(sub, cmap=cmap, norm=norm) fig, ax = plt.subplots() ax.add_artist(im) diff --git a/examples/widgets.py b/examples/widgets.py index bd6c40b..32a824f 100644 --- a/examples/widgets.py +++ b/examples/widgets.py @@ -15,6 +15,7 @@ from data_prototype.wrappers import LineWrapper from data_prototype.containers import FuncContainer +from data_prototype.conversion_node import FunctionConversionNode class SliderContainer(FuncContainer): @@ -117,7 +118,7 @@ def _query_hash(self, coord_transform, size): lw = LineWrapper( fc, # color map phase (scaled to 2pi and wrapped to [0, 1]) - {"color": lambda color: cmap((color / (2 * np.pi)) % 1)}, + FunctionConversionNode.from_funcs({"color": lambda color: cmap((color / (2 * np.pi)) % 1)}), lw=5, ) ax.add_artist(lw)