Skip to content

Commit 1e2ad42

Browse files
authored
Merge pull request #31 from matplotlib/conversion_node
Conversion Node implementation of 'nu'
2 parents 446bc5c + 25f2ebc commit 1e2ad42

File tree

10 files changed

+178
-74
lines changed

10 files changed

+178
-74
lines changed

data_prototype/conversion_node.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Iterable, Callable, Sequence
4+
from collections import Counter
5+
from dataclasses import dataclass
6+
import inspect
7+
from functools import cached_property
8+
9+
from typing import Any
10+
11+
12+
def evaluate_pipeline(nodes: Sequence[ConversionNode], input: dict[str, Any]):
13+
for node in nodes:
14+
input = node.evaluate(input)
15+
return input
16+
17+
18+
@dataclass
19+
class ConversionNode:
20+
required_keys: tuple[str, ...]
21+
output_keys: tuple[str, ...]
22+
trim_keys: bool
23+
24+
def preview_keys(self, input_keys: Iterable[str]) -> tuple[str, ...]:
25+
if missing_keys := set(self.required_keys) - set(input_keys):
26+
raise ValueError(f"Missing keys: {missing_keys}")
27+
if self.trim_keys:
28+
return tuple(sorted(set(self.output_keys)))
29+
return tuple(sorted(set(input_keys) | set(self.output_keys)))
30+
31+
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
32+
if self.trim_keys:
33+
return {k: input[k] for k in self.output_keys}
34+
else:
35+
if missing_keys := set(self.output_keys) - set(input):
36+
raise ValueError(f"Missing keys: {missing_keys}")
37+
return input
38+
39+
40+
@dataclass
41+
class UnionConversionNode(ConversionNode):
42+
nodes: tuple[ConversionNode, ...]
43+
44+
@classmethod
45+
def from_nodes(cls, *nodes: ConversionNode, trim_keys=False):
46+
required = tuple(set(k for n in nodes for k in n.required_keys))
47+
output = Counter(k for n in nodes for k in n.output_keys)
48+
if duplicate := {k for k, v in output.items() if v > 1}:
49+
raise ValueError(f"Duplicate keys from multiple input nodes: {duplicate}")
50+
return cls(required, tuple(output), trim_keys, nodes)
51+
52+
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
53+
return super().evaluate({k: v for n in self.nodes for k, v in n.evaluate(input).items()})
54+
55+
56+
@dataclass
57+
class RenameConversionNode(ConversionNode):
58+
mapping: dict[str, str]
59+
60+
@classmethod
61+
def from_mapping(cls, mapping: dict[str, str], trim_keys=False):
62+
required = tuple(mapping)
63+
output = Counter(mapping.values())
64+
if duplicate := {k for k, v in output.items() if v > 1}:
65+
raise ValueError(f"Duplicate output keys in mapping: {duplicate}")
66+
return cls(required, tuple(output), trim_keys, mapping)
67+
68+
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
69+
return super().evaluate({**input, **{out: input[inp] for (inp, out) in self.mapping.items()}})
70+
71+
72+
@dataclass
73+
class FunctionConversionNode(ConversionNode):
74+
funcs: dict[str, Callable]
75+
76+
@cached_property
77+
def _sigs(self):
78+
return {k: (f, inspect.signature(f)) for k, f in self.funcs.items()}
79+
80+
@classmethod
81+
def from_funcs(cls, funcs: dict[str, Callable], trim_keys=False):
82+
sigs = {k: inspect.signature(f) for k, f in funcs.items()}
83+
output = tuple(sigs)
84+
input = []
85+
for v in sigs.values():
86+
input.extend(v.parameters.keys())
87+
input = tuple(set(input))
88+
return cls(input, output, trim_keys, funcs)
89+
90+
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
91+
return super().evaluate(
92+
{
93+
**input,
94+
**{k: func(**{p: input[p] for p in sig.parameters}) for (k, (func, sig)) in self._sigs.items()},
95+
}
96+
)
97+
98+
99+
@dataclass
100+
class LimitKeysConversionNode(ConversionNode):
101+
keys: set[str]
102+
103+
@classmethod
104+
def from_keys(cls, keys: Sequence[str]):
105+
return cls((), tuple(keys), trim_keys=True, keys=set(keys))
106+
107+
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
108+
return {k: v for k, v in input.items() if k in self.keys}

data_prototype/patches.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ class PatchWrapper(ProxyWrapper):
4444
"joinstyle",
4545
}
4646

47-
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
48-
super().__init__(data, nus)
47+
def __init__(self, data: DataContainer, converters=None, /, **kwargs):
48+
super().__init__(data, converters)
4949
self._wrapped_instance = self._wrapped_class([0, 0], 0, 0, **kwargs)
5050

5151
@_stale_wrapper

data_prototype/tests/test_containers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ def ac():
1414

1515

1616
def _verify_describe(container):
17-
1817
data, cache_key = container.query(IdentityTransform(), [100, 100])
1918
desc = container.describe()
2019

data_prototype/wrappers.py

Lines changed: 36 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616
from matplotlib.artist import Artist as _Artist
1717

1818
from data_prototype.containers import DataContainer, _MatplotlibTransform
19+
from data_prototype.conversion_node import (
20+
ConversionNode,
21+
RenameConversionNode,
22+
evaluate_pipeline,
23+
FunctionConversionNode,
24+
LimitKeysConversionNode,
25+
)
1926

2027

2128
class _BBox(Protocol):
@@ -139,45 +146,26 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]
139146
return self._cache[cache_key]
140147
except KeyError:
141148
...
142-
# TODO decide if units go pre-nu or post-nu?
143-
for x_like in xunits:
144-
if x_like in data:
145-
data[x_like] = ax.xaxis.convert_units(data[x_like])
146-
for y_like in yunits:
147-
if y_like in data:
148-
data[y_like] = ax.xaxis.convert_units(data[y_like])
149-
150-
# doing the nu work here is nice because we can write it once, but we
151-
# really want to push this computation down a layer
152-
# TODO sort out how this interoperates with the transform stack
153-
transformed_data = {}
154-
for k, (nu, sig) in self._sigs.items():
155-
to_pass = set(sig.parameters)
156-
transformed_data[k] = nu(**{k: data[k] for k in to_pass})
149+
# TODO units
150+
transformed_data = evaluate_pipeline(self._converters, data)
157151

158152
self._cache[cache_key] = transformed_data
159153
return transformed_data
160154

161-
def __init__(self, data, nus, **kwargs):
155+
def __init__(self, data, converters: ConversionNode | list[ConversionNode] | None, **kwargs):
162156
super().__init__(**kwargs)
163157
self.data = data
164158
self._cache = LFUCache(64)
165159
# TODO make sure mutating this will invalidate the cache!
166-
self._nus = nus or {}
167-
for k in self.required_keys:
168-
self._nus.setdefault(k, _make_identity(k))
169-
desc = data.describe()
170-
for k in self.expected_keys:
171-
if k in desc:
172-
self._nus.setdefault(k, _make_identity(k))
173-
self._sigs = {k: (nu, inspect.signature(nu)) for k, nu in self._nus.items()}
160+
if isinstance(converters, ConversionNode):
161+
converters = [converters]
162+
self._converters: list[ConversionNode] = converters or []
163+
setters = list(self.expected_keys | self.required_keys)
164+
if hasattr(self, "_wrapped_class"):
165+
setters += [f[4:] for f in dir(self._wrapped_class) if f.startswith("set_")]
166+
self._converters.append(LimitKeysConversionNode.from_keys(setters))
174167
self.stale = True
175168

176-
# TODO add a setter
177-
@property
178-
def nus(self):
179-
return dict(self._nus)
180-
181169

182170
class ProxyWrapper(ProxyWrapperBase):
183171
_privtized_methods: Tuple[str, ...] = ()
@@ -192,7 +180,7 @@ def __getattr__(self, key):
192180
return getattr(self._wrapped_instance, key)
193181

194182
def __setattr__(self, key, value):
195-
if key in ("_wrapped_instance", "data", "_cache", "_nus", "stale", "_sigs"):
183+
if key in ("_wrapped_instance", "data", "_cache", "_converters", "stale", "_sigs"):
196184
super().__setattr__(key, value)
197185
elif hasattr(self, "_wrapped_instance") and hasattr(self._wrapped_instance, key):
198186
setattr(self._wrapped_instance, key, value)
@@ -205,9 +193,12 @@ class LineWrapper(ProxyWrapper):
205193
_privtized_methods = ("set_xdata", "set_ydata", "set_data", "get_xdata", "get_ydata", "get_data")
206194
required_keys = {"x", "y"}
207195

208-
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
209-
super().__init__(data, nus)
196+
def __init__(self, data: DataContainer, converters=None, /, **kwargs):
197+
super().__init__(data, converters)
210198
self._wrapped_instance = self._wrapped_class(np.array([]), np.array([]), **kwargs)
199+
self._converters.insert(-1, RenameConversionNode.from_mapping({"x": "xdata", "y": "ydata"}))
200+
setters = [f[4:] for f in dir(self._wrapped_class) if f.startswith("set_")]
201+
self._converters[-1] = LimitKeysConversionNode.from_keys(setters)
211202

212203
@_stale_wrapper
213204
def draw(self, renderer):
@@ -218,7 +209,6 @@ def draw(self, renderer):
218209

219210
def _update_wrapped(self, data):
220211
for k, v in data.items():
221-
k = {"x": "xdata", "y": "ydata"}.get(k, k)
222212
getattr(self._wrapped_instance, f"set_{k}")(v)
223213

224214

@@ -238,8 +228,8 @@ class PathCollectionWrapper(ProxyWrapper):
238228
"get_paths",
239229
)
240230

241-
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
242-
super().__init__(data, nus)
231+
def __init__(self, data: DataContainer, converters=None, /, **kwargs):
232+
super().__init__(data, converters)
243233
self._wrapped_instance = self._wrapped_class([], **kwargs)
244234
self._wrapped_instance.set_transform(mtransforms.IdentityTransform())
245235

@@ -262,17 +252,17 @@ class ImageWrapper(ProxyWrapper):
262252
_wrapped_class = _AxesImage
263253
required_keys = {"xextent", "yextent", "image"}
264254

265-
def __init__(self, data: DataContainer, nus=None, /, cmap=None, norm=None, **kwargs):
266-
nus = dict(nus or {})
255+
def __init__(self, data: DataContainer, converters=None, /, cmap=None, norm=None, **kwargs):
256+
converters = converters or []
267257
if cmap is not None or norm is not None:
268-
if nus is not None and "image" in nus:
258+
if converters is not None and "image" in converters:
269259
raise ValueError("Conflicting input")
270260
if cmap is None:
271261
cmap = mpl.colormaps["viridis"]
272262
if norm is None:
273263
raise ValueError("not sure how to do autoscaling yet")
274-
nus["image"] = lambda image: cmap(norm(image))
275-
super().__init__(data, nus)
264+
converters.append(FunctionConversionNode.from_funcs({"image": lambda image: cmap(norm(image))}))
265+
super().__init__(data, converters)
276266
kwargs.setdefault("origin", "lower")
277267
self._wrapped_instance = self._wrapped_class(None, **kwargs)
278268

@@ -293,8 +283,8 @@ class StepWrapper(ProxyWrapper):
293283
_privtized_methods = () # ("set_data", "get_data")
294284
required_keys = {"edges", "density"}
295285

296-
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
297-
super().__init__(data, nus)
286+
def __init__(self, data: DataContainer, converters=None, /, **kwargs):
287+
super().__init__(data, converters)
298288
self._wrapped_instance = self._wrapped_class([], [1], **kwargs)
299289

300290
@_stale_wrapper
@@ -312,8 +302,8 @@ class FormatedText(ProxyWrapper):
312302
_wrapped_class = _Text
313303
_privtized_methods = ("set_text",)
314304

315-
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
316-
super().__init__(data, nus)
305+
def __init__(self, data: DataContainer, converters=None, /, **kwargs):
306+
super().__init__(data, converters)
317307
self._wrapped_instance = self._wrapped_class(text="", **kwargs)
318308

319309
@_stale_wrapper
@@ -368,8 +358,8 @@ class ErrorbarWrapper(MultiProxyWrapper):
368358
required_keys = {"x", "y"}
369359
expected_keys = {f"{axis}{dirc}" for axis in ["x", "y"] for dirc in ["upper", "lower"]}
370360

371-
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
372-
super().__init__(data, nus)
361+
def __init__(self, data: DataContainer, converters=None, /, **kwargs):
362+
super().__init__(data, converters)
373363
# TODO all of the kwarg teasing apart that is needed
374364
color = kwargs.pop("color", "k")
375365
lw = kwargs.pop("lw", 2)

examples/2Dfunc.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from data_prototype.wrappers import ImageWrapper
1414
from data_prototype.containers import FuncContainer
1515

16-
import matplotlib as mpl
1716
from matplotlib.colors import Normalize
1817

1918

@@ -25,9 +24,8 @@
2524
"image": (("N", "M"), lambda x, y: np.sin(x).reshape(1, -1) * np.cos(y).reshape(-1, 1)),
2625
},
2726
)
28-
cmap = mpl.colormaps["viridis"]
29-
norm = Normalize(-1, 1)
30-
im = ImageWrapper(fc, {"image": lambda image: cmap(norm(image))})
27+
norm = Normalize(vmin=-1, vmax=1)
28+
im = ImageWrapper(fc, norm=norm)
3129

3230
fig, ax = plt.subplots()
3331
ax.add_artist(im)

examples/animation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from matplotlib.animation import FuncAnimation
1818

1919
from data_prototype.containers import _MatplotlibTransform, Desc
20+
from data_prototype.conversion_node import FunctionConversionNode
2021

2122
from data_prototype.wrappers import LineWrapper, FormatedText
2223

@@ -63,9 +64,9 @@ def update(frame, art):
6364
lw = LineWrapper(sot_c, lw=5, color="green", label="sin(time)")
6465
fc = FormatedText(
6566
sot_c,
66-
{"text": lambda phase: f"ϕ={phase:.2f}"},
67-
x=2 * np.pi,
68-
y=1,
67+
FunctionConversionNode.from_funcs(
68+
{"text": lambda phase: f"ϕ={phase:.2f}", "x": lambda: 2 * np.pi, "y": lambda: 1}
69+
),
6970
ha="right",
7071
)
7172
fig, ax = plt.subplots()

0 commit comments

Comments
 (0)