From 14346dc300d30099f7e574fe65f0614621a94bbc Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Thu, 18 May 2023 01:13:16 -0500 Subject: [PATCH 1/6] Initial implementation of conversion node Reference #26 --- data_prototype/conversion_node.py | 97 +++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 data_prototype/conversion_node.py diff --git a/data_prototype/conversion_node.py b/data_prototype/conversion_node.py new file mode 100644 index 0000000..e1eef7f --- /dev/null +++ b/data_prototype/conversion_node.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from collections.abc import Iterable, Callable, Sequence +from collections import Counter +from dataclasses import dataclass +import inspect +from functools import cached_property + +from typing import Any + + +def evaluate_pipeline(nodes: Sequence[ConversionNode], input: dict[str, Any]): + for node in nodes: + input = node.evaluate(input) + return input + + +@dataclass +class ConversionNode: + name: str + required_keys: tuple[str, ...] + output_keys: tuple[str, ...] + trim_keys: bool + + def preview_keys(self, input_keys: Iterable[str]) -> tuple[str, ...]: + if missing_keys := set(self.required_keys) - set(input_keys): + raise ValueError(f"Missing keys: {missing_keys}") + if self.trim_keys: + return tuple(sorted(set(self.output_keys))) + return tuple(sorted(set(input_keys) | set(self.output_keys))) + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + if self.trim_keys: + return {k: input[k] for k in self.output_keys} + else: + if missing_keys := set(self.output_keys) - set(input): + raise ValueError(f"Missing keys: {missing_keys}") + return input + + +@dataclass +class UnionConversionNode(ConversionNode): + nodes: tuple[ConversionNode, ...] + + @classmethod + def from_nodes(cls, name: str, *nodes: ConversionNode, trim_keys=False): + required = tuple(set(k for n in nodes for k in n.required_keys)) + output = Counter(k for n in nodes for k in n.output_keys) + if duplicate := {k for k, v in output.items() if v > 1}: + raise ValueError(f"Duplicate keys from multiple input nodes: {duplicate}") + return cls(name, required, tuple(output), trim_keys, nodes) + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + return super().evaluate({k: v for n in self.nodes for k, v in n.evaluate(input).items()}) + + +@dataclass +class RenameConversionNode(ConversionNode): + mapping: dict[str, str] + + @classmethod + def from_mapping(cls, name: str, mapping: dict[str, str], trim_keys=False): + required = tuple(mapping) + output = Counter(mapping.values()) + if duplicate := {k for k, v in output.items() if v > 1}: + raise ValueError(f"Duplicate output keys in mapping: {duplicate}") + return cls(name, required, tuple(output), trim_keys, mapping) + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + return super().evaluate({**input, **{out: input[inp] for (inp, out) in self.mapping.items()}}) + + +@dataclass +class FunctionConversionNode(ConversionNode): + funcs: dict[str, Callable] + + @cached_property + def _sigs(self): + return {k: (f, inspect.signature(f)) for k, f in self.funcs.items()} + + @classmethod + def from_funcs(cls, name: str, funcs: dict[str, Callable], trim_keys=False): + sigs = {k: inspect.signature(f) for k, f in funcs.items()} + output = tuple(sigs) + input = [] + for v in sigs.values(): + input.extend(v.parameters.keys()) + input = tuple(set(input)) + return cls(name, input, output, trim_keys, funcs) + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: 
+ return super().evaluate( + { + **input, + **{k: func(**{p: input[p] for p in sig.parameters}) for (k, (func, sig)) in self._sigs.items()}, + } + ) From 01b5746dec2b3d3145ae9b4413e9731440d9ee88 Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Tue, 23 May 2023 17:10:50 -0700 Subject: [PATCH 2/6] blindly convert 'nus' to 'converters' --- data_prototype/patches.py | 4 ++-- data_prototype/wrappers.py | 46 +++++++++++++++++++------------------- examples/mapped.py | 10 ++++----- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/data_prototype/patches.py b/data_prototype/patches.py index 94a50af..d993f2f 100644 --- a/data_prototype/patches.py +++ b/data_prototype/patches.py @@ -44,8 +44,8 @@ class PatchWrapper(ProxyWrapper): "joinstyle", } - def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + def __init__(self, data: DataContainer, converters=None, /, **kwargs): + super().__init__(data, converters) self._wrapped_instance = self._wrapped_class([0, 0], 0, 0, **kwargs) @_stale_wrapper diff --git a/data_prototype/wrappers.py b/data_prototype/wrappers.py index 83ab843..44ee85b 100644 --- a/data_prototype/wrappers.py +++ b/data_prototype/wrappers.py @@ -158,25 +158,25 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str] self._cache[cache_key] = transformed_data return transformed_data - def __init__(self, data, nus, **kwargs): + def __init__(self, data, converters, **kwargs): super().__init__(**kwargs) self.data = data self._cache = LFUCache(64) # TODO make sure mutating this will invalidate the cache! - self._nus = nus or {} + self._converters = converters or {} for k in self.required_keys: - self._nus.setdefault(k, _make_identity(k)) + self._converters.setdefault(k, _make_identity(k)) desc = data.describe() for k in self.expected_keys: if k in desc: - self._nus.setdefault(k, _make_identity(k)) - self._sigs = {k: (nu, inspect.signature(nu)) for k, nu in self._nus.items()} + self._converters.setdefault(k, _make_identity(k)) + self._sigs = {k: (nu, inspect.signature(nu)) for k, nu in self._converters.items()} self.stale = True # TODO add a setter @property - def nus(self): - return dict(self._nus) + def converters(self): + return dict(self._converters) class ProxyWrapper(ProxyWrapperBase): @@ -192,7 +192,7 @@ def __getattr__(self, key): return getattr(self._wrapped_instance, key) def __setattr__(self, key, value): - if key in ("_wrapped_instance", "data", "_cache", "_nus", "stale", "_sigs"): + if key in ("_wrapped_instance", "data", "_cache", "_converters", "stale", "_sigs"): super().__setattr__(key, value) elif hasattr(self, "_wrapped_instance") and hasattr(self._wrapped_instance, key): setattr(self._wrapped_instance, key, value) @@ -205,8 +205,8 @@ class LineWrapper(ProxyWrapper): _privtized_methods = ("set_xdata", "set_ydata", "set_data", "get_xdata", "get_ydata", "get_data") required_keys = {"x", "y"} - def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + def __init__(self, data: DataContainer, converters=None, /, **kwargs): + super().__init__(data, converters) self._wrapped_instance = self._wrapped_class(np.array([]), np.array([]), **kwargs) @_stale_wrapper @@ -238,8 +238,8 @@ class PathCollectionWrapper(ProxyWrapper): "get_paths", ) - def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + def __init__(self, data: DataContainer, converters=None, /, **kwargs): + super().__init__(data, converters) self._wrapped_instance = 
self._wrapped_class([], **kwargs) self._wrapped_instance.set_transform(mtransforms.IdentityTransform()) @@ -262,17 +262,17 @@ class ImageWrapper(ProxyWrapper): _wrapped_class = _AxesImage required_keys = {"xextent", "yextent", "image"} - def __init__(self, data: DataContainer, nus=None, /, cmap=None, norm=None, **kwargs): - nus = dict(nus or {}) + def __init__(self, data: DataContainer, converters=None, /, cmap=None, norm=None, **kwargs): + converters = dict(converters or {}) if cmap is not None or norm is not None: - if nus is not None and "image" in nus: + if converters is not None and "image" in converters: raise ValueError("Conflicting input") if cmap is None: cmap = mpl.colormaps["viridis"] if norm is None: raise ValueError("not sure how to do autoscaling yet") - nus["image"] = lambda image: cmap(norm(image)) - super().__init__(data, nus) + converters["image"] = lambda image: cmap(norm(image)) + super().__init__(data, converters) kwargs.setdefault("origin", "lower") self._wrapped_instance = self._wrapped_class(None, **kwargs) @@ -293,8 +293,8 @@ class StepWrapper(ProxyWrapper): _privtized_methods = () # ("set_data", "get_data") required_keys = {"edges", "density"} - def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + def __init__(self, data: DataContainer, converters=None, /, **kwargs): + super().__init__(data, converters) self._wrapped_instance = self._wrapped_class([], [1], **kwargs) @_stale_wrapper @@ -312,8 +312,8 @@ class FormatedText(ProxyWrapper): _wrapped_class = _Text _privtized_methods = ("set_text",) - def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + def __init__(self, data: DataContainer, converters=None, /, **kwargs): + super().__init__(data, converters) self._wrapped_instance = self._wrapped_class(text="", **kwargs) @_stale_wrapper @@ -368,8 +368,8 @@ class ErrorbarWrapper(MultiProxyWrapper): required_keys = {"x", "y"} expected_keys = {f"{axis}{dirc}" for axis in ["x", "y"] for dirc in ["upper", "lower"]} - def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + def __init__(self, data: DataContainer, converters=None, /, **kwargs): + super().__init__(data, converters) # TODO all of the kwarg teasing apart that is needed color = kwargs.pop("color", "k") lw = kwargs.pop("lw", 2) diff --git a/examples/mapped.py b/examples/mapped.py index 1da081c..c5d4ff6 100644 --- a/examples/mapped.py +++ b/examples/mapped.py @@ -3,7 +3,7 @@ Mapping Line Properties ======================= -Leveraging the nu functions to transform users space data to visualization data. +Leveraging the converter functions to transform users space data to visualization data. 
""" @@ -20,7 +20,7 @@ cmap.set_under("r") norm = Normalize(1, 8) -line_nus = { +line_converter = { # arbitrary functions "lw": lambda lw: min(1 + lw, 5), # standard color mapping @@ -29,7 +29,7 @@ "ls": lambda cat: {"A": "-", "B": ":", "C": "--"}[cat[()]], } -text_nus = { +text_converter = { "text": lambda j, cat: f"index={j[()]} class={cat[()]!r}", "y": lambda j: j, } @@ -53,13 +53,13 @@ ax.add_artist( LineWrapper( ac, - line_nus, + line_converter, ) ) ax.add_artist( FormatedText( ac, - text_nus, + text_converter, x=2 * np.pi, ha="right", bbox={"facecolor": "gray", "alpha": 0.5}, From 0e1c99fbc481f05583b0e0640f5a7743218dadff Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Thu, 25 May 2023 13:20:00 -0700 Subject: [PATCH 3/6] Get examples working with conversion node infrastructure --- data_prototype/conversion_node.py | 12 +++++++ data_prototype/wrappers.py | 56 +++++++++++++------------------ examples/2Dfunc.py | 2 +- examples/animation.py | 7 ++-- examples/mapped.py | 32 +++++++++++------- examples/mulivariate_cmap.py | 3 +- examples/subsample.py | 2 +- examples/widgets.py | 3 +- 8 files changed, 66 insertions(+), 51 deletions(-) diff --git a/data_prototype/conversion_node.py b/data_prototype/conversion_node.py index e1eef7f..946869d 100644 --- a/data_prototype/conversion_node.py +++ b/data_prototype/conversion_node.py @@ -95,3 +95,15 @@ def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: **{k: func(**{p: input[p] for p in sig.parameters}) for (k, (func, sig)) in self._sigs.items()}, } ) + + +@dataclass +class LimitKeysConversionNode(ConversionNode): + keys: set[str] + + @classmethod + def from_keys(cls, name: str, keys: Sequence[str]): + return cls(name, (), tuple(keys), trim_keys=True, keys=set(keys)) + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in input.items() if k in self.keys} diff --git a/data_prototype/wrappers.py b/data_prototype/wrappers.py index 44ee85b..b6d221b 100644 --- a/data_prototype/wrappers.py +++ b/data_prototype/wrappers.py @@ -16,6 +16,13 @@ from matplotlib.artist import Artist as _Artist from data_prototype.containers import DataContainer, _MatplotlibTransform +from data_prototype.conversion_node import ( + ConversionNode, + RenameConversionNode, + evaluate_pipeline, + FunctionConversionNode, + LimitKeysConversionNode, +) class _BBox(Protocol): @@ -139,45 +146,26 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str] return self._cache[cache_key] except KeyError: ... - # TODO decide if units go pre-nu or post-nu? - for x_like in xunits: - if x_like in data: - data[x_like] = ax.xaxis.convert_units(data[x_like]) - for y_like in yunits: - if y_like in data: - data[y_like] = ax.xaxis.convert_units(data[y_like]) - - # doing the nu work here is nice because we can write it once, but we - # really want to push this computation down a layer - # TODO sort out how this interoperates with the transform stack - transformed_data = {} - for k, (nu, sig) in self._sigs.items(): - to_pass = set(sig.parameters) - transformed_data[k] = nu(**{k: data[k] for k in to_pass}) + # TODO units + transformed_data = evaluate_pipeline(self._converters, data) self._cache[cache_key] = transformed_data return transformed_data - def __init__(self, data, converters, **kwargs): + def __init__(self, data, converters: ConversionNode | list[ConversionNode] | None, **kwargs): super().__init__(**kwargs) self.data = data self._cache = LFUCache(64) # TODO make sure mutating this will invalidate the cache! 
- self._converters = converters or {} - for k in self.required_keys: - self._converters.setdefault(k, _make_identity(k)) - desc = data.describe() - for k in self.expected_keys: - if k in desc: - self._converters.setdefault(k, _make_identity(k)) - self._sigs = {k: (nu, inspect.signature(nu)) for k, nu in self._converters.items()} + if isinstance(converters, ConversionNode): + converters = [converters] + self._converters: list[ConversionNode] = converters or [] + setters = list(self.expected_keys | self.required_keys) + if hasattr(self, "_wrapped_class"): + setters += [f[4:] for f in dir(self._wrapped_class) if f.startswith("set_")] + self._converters.append(LimitKeysConversionNode.from_keys("setters", setters)) self.stale = True - # TODO add a setter - @property - def converters(self): - return dict(self._converters) - class ProxyWrapper(ProxyWrapperBase): _privtized_methods: Tuple[str, ...] = () @@ -208,6 +196,9 @@ class LineWrapper(ProxyWrapper): def __init__(self, data: DataContainer, converters=None, /, **kwargs): super().__init__(data, converters) self._wrapped_instance = self._wrapped_class(np.array([]), np.array([]), **kwargs) + self._converters.insert(-1, RenameConversionNode.from_mapping("xydata", {"x": "xdata", "y": "ydata"})) + setters = [f[4:] for f in dir(self._wrapped_class) if f.startswith("set_")] + self._converters[-1] = LimitKeysConversionNode.from_keys("setters", setters) @_stale_wrapper def draw(self, renderer): @@ -218,7 +209,6 @@ def draw(self, renderer): def _update_wrapped(self, data): for k, v in data.items(): - k = {"x": "xdata", "y": "ydata"}.get(k, k) getattr(self._wrapped_instance, f"set_{k}")(v) @@ -263,7 +253,7 @@ class ImageWrapper(ProxyWrapper): required_keys = {"xextent", "yextent", "image"} def __init__(self, data: DataContainer, converters=None, /, cmap=None, norm=None, **kwargs): - converters = dict(converters or {}) + converters = converters or [] if cmap is not None or norm is not None: if converters is not None and "image" in converters: raise ValueError("Conflicting input") @@ -271,7 +261,9 @@ def __init__(self, data: DataContainer, converters=None, /, cmap=None, norm=None cmap = mpl.colormaps["viridis"] if norm is None: raise ValueError("not sure how to do autoscaling yet") - converters["image"] = lambda image: cmap(norm(image)) + converters.append( + FunctionConversionNode.from_funcs("map colors", {"image": lambda image: cmap(norm(image))}) + ) super().__init__(data, converters) kwargs.setdefault("origin", "lower") self._wrapped_instance = self._wrapped_class(None, **kwargs) diff --git a/examples/2Dfunc.py b/examples/2Dfunc.py index 6d200de..e94131a 100644 --- a/examples/2Dfunc.py +++ b/examples/2Dfunc.py @@ -27,7 +27,7 @@ ) cmap = mpl.colormaps["viridis"] norm = Normalize(-1, 1) -im = ImageWrapper(fc, {"image": lambda image: cmap(norm(image))}) +im = ImageWrapper(fc) fig, ax = plt.subplots() ax.add_artist(im) diff --git a/examples/animation.py b/examples/animation.py index 316dbaf..6fd712b 100644 --- a/examples/animation.py +++ b/examples/animation.py @@ -17,6 +17,7 @@ from matplotlib.animation import FuncAnimation from data_prototype.containers import _MatplotlibTransform, Desc +from data_prototype.conversion_node import FunctionConversionNode from data_prototype.wrappers import LineWrapper, FormatedText @@ -63,9 +64,9 @@ def update(frame, art): lw = LineWrapper(sot_c, lw=5, color="green", label="sin(time)") fc = FormatedText( sot_c, - {"text": lambda phase: f"ϕ={phase:.2f}"}, - x=2 * np.pi, - y=1, + FunctionConversionNode.from_funcs( + "fmt", 
{"text": lambda phase: f"ϕ={phase:.2f}", "x": lambda: 2 * np.pi, "y": lambda: 1} + ), ha="right", ) fig, ax = plt.subplots() diff --git a/examples/mapped.py b/examples/mapped.py index c5d4ff6..9df022e 100644 --- a/examples/mapped.py +++ b/examples/mapped.py @@ -14,25 +14,33 @@ from data_prototype.wrappers import LineWrapper, FormatedText from data_prototype.containers import ArrayContainer +from data_prototype.conversion_node import FunctionConversionNode cmap = plt.colormaps["viridis"] cmap.set_over("k") cmap.set_under("r") norm = Normalize(1, 8) -line_converter = { - # arbitrary functions - "lw": lambda lw: min(1 + lw, 5), - # standard color mapping - "color": lambda j: cmap(norm(j)), - # categorical - "ls": lambda cat: {"A": "-", "B": ":", "C": "--"}[cat[()]], -} +line_converter = FunctionConversionNode.from_funcs( + "line converter", + { + # arbitrary functions + "lw": lambda lw: min(1 + lw, 5), + # standard color mapping + "color": lambda j: cmap(norm(j)), + # categorical + "ls": lambda cat: {"A": "-", "B": ":", "C": "--"}[cat[()]], + }, +) -text_converter = { - "text": lambda j, cat: f"index={j[()]} class={cat[()]!r}", - "y": lambda j: j, -} +text_converter = FunctionConversionNode.from_funcs( + "text converter", + { + "text": lambda j, cat: f"index={j[()]} class={cat[()]!r}", + "y": lambda j: j, + "x": lambda x: 2 * np.pi, + }, +) th = np.linspace(0, 2 * np.pi, 128) diff --git a/examples/mulivariate_cmap.py b/examples/mulivariate_cmap.py index 8b6377e..c1e10d5 100644 --- a/examples/mulivariate_cmap.py +++ b/examples/mulivariate_cmap.py @@ -13,6 +13,7 @@ from data_prototype.wrappers import ImageWrapper from data_prototype.containers import FuncContainer +from data_prototype.conversion_node import FunctionConversionNode from matplotlib.colors import hsv_to_rgb @@ -40,7 +41,7 @@ def image_nu(image): }, ) -im = ImageWrapper(fc, {"image": image_nu}) +im = ImageWrapper(fc, FunctionConversionNode.from_funcs("bivariate", {"image": image_nu})) fig, ax = plt.subplots() ax.add_artist(im) diff --git a/examples/subsample.py b/examples/subsample.py index 2e0e4ab..978104a 100644 --- a/examples/subsample.py +++ b/examples/subsample.py @@ -68,7 +68,7 @@ def query( sub = Subsample() cmap = mpl.colormaps["coolwarm"] norm = Normalize(-2.2, 2.2) -im = ImageWrapper(sub, {"image": lambda image: cmap(norm(image))}) +im = ImageWrapper(sub) fig, ax = plt.subplots() ax.add_artist(im) diff --git a/examples/widgets.py b/examples/widgets.py index bd6c40b..c5a70e5 100644 --- a/examples/widgets.py +++ b/examples/widgets.py @@ -15,6 +15,7 @@ from data_prototype.wrappers import LineWrapper from data_prototype.containers import FuncContainer +from data_prototype.conversion_node import FunctionConversionNode class SliderContainer(FuncContainer): @@ -117,7 +118,7 @@ def _query_hash(self, coord_transform, size): lw = LineWrapper( fc, # color map phase (scaled to 2pi and wrapped to [0, 1]) - {"color": lambda color: cmap((color / (2 * np.pi)) % 1)}, + FunctionConversionNode.from_funcs("cmap", {"color": lambda color: cmap((color / (2 * np.pi)) % 1)}), lw=5, ) ax.add_artist(lw) From 34aaa72a643c10428e44ddce8b0e8807b83a8f12 Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Tue, 6 Jun 2023 13:46:13 -0500 Subject: [PATCH 4/6] fix subsample example to use cmap/norm by passing to image wrapper instead of as converter --- examples/subsample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/subsample.py b/examples/subsample.py index 978104a..fc84c5a 100644 --- a/examples/subsample.py +++ 
b/examples/subsample.py @@ -68,7 +68,7 @@ def query( sub = Subsample() cmap = mpl.colormaps["coolwarm"] norm = Normalize(-2.2, 2.2) -im = ImageWrapper(sub) +im = ImageWrapper(sub, cmap=cmap, norm=norm) fig, ax = plt.subplots() ax.add_artist(im) From 0bae71914d4f919646b271fcee84394c0282f9ca Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Wed, 7 Jun 2023 17:19:04 -0500 Subject: [PATCH 5/6] Remove unused name field from Converters --- data_prototype/conversion_node.py | 17 ++++++++--------- data_prototype/wrappers.py | 10 ++++------ examples/2Dfunc.py | 6 ++---- examples/animation.py | 2 +- examples/mapped.py | 2 -- examples/mulivariate_cmap.py | 2 +- examples/widgets.py | 2 +- 7 files changed, 17 insertions(+), 24 deletions(-) diff --git a/data_prototype/conversion_node.py b/data_prototype/conversion_node.py index 946869d..fefc8bd 100644 --- a/data_prototype/conversion_node.py +++ b/data_prototype/conversion_node.py @@ -17,7 +17,6 @@ def evaluate_pipeline(nodes: Sequence[ConversionNode], input: dict[str, Any]): @dataclass class ConversionNode: - name: str required_keys: tuple[str, ...] output_keys: tuple[str, ...] trim_keys: bool @@ -43,12 +42,12 @@ class UnionConversionNode(ConversionNode): nodes: tuple[ConversionNode, ...] @classmethod - def from_nodes(cls, name: str, *nodes: ConversionNode, trim_keys=False): + def from_nodes(cls, *nodes: ConversionNode, trim_keys=False): required = tuple(set(k for n in nodes for k in n.required_keys)) output = Counter(k for n in nodes for k in n.output_keys) if duplicate := {k for k, v in output.items() if v > 1}: raise ValueError(f"Duplicate keys from multiple input nodes: {duplicate}") - return cls(name, required, tuple(output), trim_keys, nodes) + return cls(required, tuple(output), trim_keys, nodes) def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: return super().evaluate({k: v for n in self.nodes for k, v in n.evaluate(input).items()}) @@ -59,12 +58,12 @@ class RenameConversionNode(ConversionNode): mapping: dict[str, str] @classmethod - def from_mapping(cls, name: str, mapping: dict[str, str], trim_keys=False): + def from_mapping(cls, mapping: dict[str, str], trim_keys=False): required = tuple(mapping) output = Counter(mapping.values()) if duplicate := {k for k, v in output.items() if v > 1}: raise ValueError(f"Duplicate output keys in mapping: {duplicate}") - return cls(name, required, tuple(output), trim_keys, mapping) + return cls(required, tuple(output), trim_keys, mapping) def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: return super().evaluate({**input, **{out: input[inp] for (inp, out) in self.mapping.items()}}) @@ -79,14 +78,14 @@ def _sigs(self): return {k: (f, inspect.signature(f)) for k, f in self.funcs.items()} @classmethod - def from_funcs(cls, name: str, funcs: dict[str, Callable], trim_keys=False): + def from_funcs(cls, funcs: dict[str, Callable], trim_keys=False): sigs = {k: inspect.signature(f) for k, f in funcs.items()} output = tuple(sigs) input = [] for v in sigs.values(): input.extend(v.parameters.keys()) input = tuple(set(input)) - return cls(name, input, output, trim_keys, funcs) + return cls(input, output, trim_keys, funcs) def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: return super().evaluate( @@ -102,8 +101,8 @@ class LimitKeysConversionNode(ConversionNode): keys: set[str] @classmethod - def from_keys(cls, name: str, keys: Sequence[str]): - return cls(name, (), tuple(keys), trim_keys=True, keys=set(keys)) + def from_keys(cls, keys: Sequence[str]): + return cls((), tuple(keys), 
trim_keys=True, keys=set(keys)) def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: return {k: v for k, v in input.items() if k in self.keys} diff --git a/data_prototype/wrappers.py b/data_prototype/wrappers.py index b6d221b..e447e63 100644 --- a/data_prototype/wrappers.py +++ b/data_prototype/wrappers.py @@ -163,7 +163,7 @@ def __init__(self, data, converters: ConversionNode | list[ConversionNode] | Non setters = list(self.expected_keys | self.required_keys) if hasattr(self, "_wrapped_class"): setters += [f[4:] for f in dir(self._wrapped_class) if f.startswith("set_")] - self._converters.append(LimitKeysConversionNode.from_keys("setters", setters)) + self._converters.append(LimitKeysConversionNode.from_keys(setters)) self.stale = True @@ -196,9 +196,9 @@ class LineWrapper(ProxyWrapper): def __init__(self, data: DataContainer, converters=None, /, **kwargs): super().__init__(data, converters) self._wrapped_instance = self._wrapped_class(np.array([]), np.array([]), **kwargs) - self._converters.insert(-1, RenameConversionNode.from_mapping("xydata", {"x": "xdata", "y": "ydata"})) + self._converters.insert(-1, RenameConversionNode.from_mapping({"x": "xdata", "y": "ydata"})) setters = [f[4:] for f in dir(self._wrapped_class) if f.startswith("set_")] - self._converters[-1] = LimitKeysConversionNode.from_keys("setters", setters) + self._converters[-1] = LimitKeysConversionNode.from_keys(setters) @_stale_wrapper def draw(self, renderer): @@ -261,9 +261,7 @@ def __init__(self, data: DataContainer, converters=None, /, cmap=None, norm=None cmap = mpl.colormaps["viridis"] if norm is None: raise ValueError("not sure how to do autoscaling yet") - converters.append( - FunctionConversionNode.from_funcs("map colors", {"image": lambda image: cmap(norm(image))}) - ) + converters.append(FunctionConversionNode.from_funcs({"image": lambda image: cmap(norm(image))})) super().__init__(data, converters) kwargs.setdefault("origin", "lower") self._wrapped_instance = self._wrapped_class(None, **kwargs) diff --git a/examples/2Dfunc.py b/examples/2Dfunc.py index e94131a..623a7d1 100644 --- a/examples/2Dfunc.py +++ b/examples/2Dfunc.py @@ -13,7 +13,6 @@ from data_prototype.wrappers import ImageWrapper from data_prototype.containers import FuncContainer -import matplotlib as mpl from matplotlib.colors import Normalize @@ -25,9 +24,8 @@ "image": (("N", "M"), lambda x, y: np.sin(x).reshape(1, -1) * np.cos(y).reshape(-1, 1)), }, ) -cmap = mpl.colormaps["viridis"] -norm = Normalize(-1, 1) -im = ImageWrapper(fc) +norm = Normalize(vmin=-1, vmax=1) +im = ImageWrapper(fc, norm=norm) fig, ax = plt.subplots() ax.add_artist(im) diff --git a/examples/animation.py b/examples/animation.py index 6fd712b..628f313 100644 --- a/examples/animation.py +++ b/examples/animation.py @@ -65,7 +65,7 @@ def update(frame, art): fc = FormatedText( sot_c, FunctionConversionNode.from_funcs( - "fmt", {"text": lambda phase: f"ϕ={phase:.2f}", "x": lambda: 2 * np.pi, "y": lambda: 1} + {"text": lambda phase: f"ϕ={phase:.2f}", "x": lambda: 2 * np.pi, "y": lambda: 1} ), ha="right", ) diff --git a/examples/mapped.py b/examples/mapped.py index 9df022e..0fe7898 100644 --- a/examples/mapped.py +++ b/examples/mapped.py @@ -22,7 +22,6 @@ norm = Normalize(1, 8) line_converter = FunctionConversionNode.from_funcs( - "line converter", { # arbitrary functions "lw": lambda lw: min(1 + lw, 5), @@ -34,7 +33,6 @@ ) text_converter = FunctionConversionNode.from_funcs( - "text converter", { "text": lambda j, cat: f"index={j[()]} class={cat[()]!r}", "y": lambda j: j, 
diff --git a/examples/mulivariate_cmap.py b/examples/mulivariate_cmap.py index c1e10d5..c00b709 100644 --- a/examples/mulivariate_cmap.py +++ b/examples/mulivariate_cmap.py @@ -41,7 +41,7 @@ def image_nu(image): }, ) -im = ImageWrapper(fc, FunctionConversionNode.from_funcs("bivariate", {"image": image_nu})) +im = ImageWrapper(fc, FunctionConversionNode.from_funcs({"image": image_nu})) fig, ax = plt.subplots() ax.add_artist(im) diff --git a/examples/widgets.py b/examples/widgets.py index c5a70e5..32a824f 100644 --- a/examples/widgets.py +++ b/examples/widgets.py @@ -118,7 +118,7 @@ def _query_hash(self, coord_transform, size): lw = LineWrapper( fc, # color map phase (scaled to 2pi and wrapped to [0, 1]) - FunctionConversionNode.from_funcs("cmap", {"color": lambda color: cmap((color / (2 * np.pi)) % 1)}), + FunctionConversionNode.from_funcs({"color": lambda color: cmap((color / (2 * np.pi)) % 1)}), lw=5, ) ax.add_artist(lw) From 25f2ebc3f1db39b725cb66b0a51e6c7b8935856c Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Thu, 8 Jun 2023 17:24:09 -0400 Subject: [PATCH 6/6] STY: run black --- data_prototype/tests/test_containers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/data_prototype/tests/test_containers.py b/data_prototype/tests/test_containers.py index ddea3c6..4fadefb 100644 --- a/data_prototype/tests/test_containers.py +++ b/data_prototype/tests/test_containers.py @@ -14,7 +14,6 @@ def ac(): def _verify_describe(container): - data, cache_key = container.query(IdentityTransform(), [100, 100]) desc = container.describe()
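
For reference, below is a minimal usage sketch of the conversion-node pipeline as it stands at the end of this series (after "Remove unused name field from Converters"): the `from_funcs` / `from_mapping` / `from_keys` classmethods take no name argument, and `evaluate_pipeline` threads a plain dict through a sequence of nodes. The import path and classmethod signatures come from data_prototype/conversion_node.py as shown in the diffs above; the sample data and the particular node composition are illustrative only and are not part of the patches.

    import numpy as np

    from data_prototype.conversion_node import (
        FunctionConversionNode,
        RenameConversionNode,
        LimitKeysConversionNode,
        evaluate_pipeline,
    )

    # Raw "user space" data, as a container query might return it.
    data = {"x": np.linspace(0, 2 * np.pi, 8), "phase": 1.5}

    pipeline = [
        # Compute derived keys from existing ones; each function's inputs are
        # matched by parameter name via inspect.signature.
        FunctionConversionNode.from_funcs({"y": lambda x, phase: np.sin(x + phase)}),
        # Rename keys to what a downstream artist's setters expect
        # (mirrors the xydata rename LineWrapper installs for itself).
        RenameConversionNode.from_mapping({"x": "xdata", "y": "ydata"}),
        # Drop everything the downstream consumer does not understand.
        LimitKeysConversionNode.from_keys(["xdata", "ydata"]),
    ]

    out = evaluate_pipeline(pipeline, data)
    print(sorted(out))  # ['xdata', 'ydata']

Each node receives the full dict produced so far and returns a new dict, so intermediate keys ("x", "phase", "y") remain available to later nodes until the trailing LimitKeysConversionNode filters them out, which is the same pattern ProxyWrapperBase uses when it appends a setter-limiting node to the converter list.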