Skip to content

Commit 0e1c99f

Browse files
committed
Get examples working with conversion node infrastructure
1 parent 01b5746 commit 0e1c99f

File tree

8 files changed

+66
-51
lines changed

8 files changed

+66
-51
lines changed

data_prototype/conversion_node.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,15 @@ def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
9595
**{k: func(**{p: input[p] for p in sig.parameters}) for (k, (func, sig)) in self._sigs.items()},
9696
}
9797
)
98+
99+
100+
@dataclass
101+
class LimitKeysConversionNode(ConversionNode):
102+
keys: set[str]
103+
104+
@classmethod
105+
def from_keys(cls, name: str, keys: Sequence[str]):
106+
return cls(name, (), tuple(keys), trim_keys=True, keys=set(keys))
107+
108+
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
109+
return {k: v for k, v in input.items() if k in self.keys}

data_prototype/wrappers.py

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616
from matplotlib.artist import Artist as _Artist
1717

1818
from data_prototype.containers import DataContainer, _MatplotlibTransform
19+
from data_prototype.conversion_node import (
20+
ConversionNode,
21+
RenameConversionNode,
22+
evaluate_pipeline,
23+
FunctionConversionNode,
24+
LimitKeysConversionNode,
25+
)
1926

2027

2128
class _BBox(Protocol):
@@ -139,45 +146,26 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]
139146
return self._cache[cache_key]
140147
except KeyError:
141148
...
142-
# TODO decide if units go pre-nu or post-nu?
143-
for x_like in xunits:
144-
if x_like in data:
145-
data[x_like] = ax.xaxis.convert_units(data[x_like])
146-
for y_like in yunits:
147-
if y_like in data:
148-
data[y_like] = ax.xaxis.convert_units(data[y_like])
149-
150-
# doing the nu work here is nice because we can write it once, but we
151-
# really want to push this computation down a layer
152-
# TODO sort out how this interoperates with the transform stack
153-
transformed_data = {}
154-
for k, (nu, sig) in self._sigs.items():
155-
to_pass = set(sig.parameters)
156-
transformed_data[k] = nu(**{k: data[k] for k in to_pass})
149+
# TODO units
150+
transformed_data = evaluate_pipeline(self._converters, data)
157151

158152
self._cache[cache_key] = transformed_data
159153
return transformed_data
160154

161-
def __init__(self, data, converters, **kwargs):
155+
def __init__(self, data, converters: ConversionNode | list[ConversionNode] | None, **kwargs):
162156
super().__init__(**kwargs)
163157
self.data = data
164158
self._cache = LFUCache(64)
165159
# TODO make sure mutating this will invalidate the cache!
166-
self._converters = converters or {}
167-
for k in self.required_keys:
168-
self._converters.setdefault(k, _make_identity(k))
169-
desc = data.describe()
170-
for k in self.expected_keys:
171-
if k in desc:
172-
self._converters.setdefault(k, _make_identity(k))
173-
self._sigs = {k: (nu, inspect.signature(nu)) for k, nu in self._converters.items()}
160+
if isinstance(converters, ConversionNode):
161+
converters = [converters]
162+
self._converters: list[ConversionNode] = converters or []
163+
setters = list(self.expected_keys | self.required_keys)
164+
if hasattr(self, "_wrapped_class"):
165+
setters += [f[4:] for f in dir(self._wrapped_class) if f.startswith("set_")]
166+
self._converters.append(LimitKeysConversionNode.from_keys("setters", setters))
174167
self.stale = True
175168

176-
# TODO add a setter
177-
@property
178-
def converters(self):
179-
return dict(self._converters)
180-
181169

182170
class ProxyWrapper(ProxyWrapperBase):
183171
_privtized_methods: Tuple[str, ...] = ()
@@ -208,6 +196,9 @@ class LineWrapper(ProxyWrapper):
208196
def __init__(self, data: DataContainer, converters=None, /, **kwargs):
209197
super().__init__(data, converters)
210198
self._wrapped_instance = self._wrapped_class(np.array([]), np.array([]), **kwargs)
199+
self._converters.insert(-1, RenameConversionNode.from_mapping("xydata", {"x": "xdata", "y": "ydata"}))
200+
setters = [f[4:] for f in dir(self._wrapped_class) if f.startswith("set_")]
201+
self._converters[-1] = LimitKeysConversionNode.from_keys("setters", setters)
211202

212203
@_stale_wrapper
213204
def draw(self, renderer):
@@ -218,7 +209,6 @@ def draw(self, renderer):
218209

219210
def _update_wrapped(self, data):
220211
for k, v in data.items():
221-
k = {"x": "xdata", "y": "ydata"}.get(k, k)
222212
getattr(self._wrapped_instance, f"set_{k}")(v)
223213

224214

@@ -263,15 +253,17 @@ class ImageWrapper(ProxyWrapper):
263253
required_keys = {"xextent", "yextent", "image"}
264254

265255
def __init__(self, data: DataContainer, converters=None, /, cmap=None, norm=None, **kwargs):
266-
converters = dict(converters or {})
256+
converters = converters or []
267257
if cmap is not None or norm is not None:
268258
if converters is not None and "image" in converters:
269259
raise ValueError("Conflicting input")
270260
if cmap is None:
271261
cmap = mpl.colormaps["viridis"]
272262
if norm is None:
273263
raise ValueError("not sure how to do autoscaling yet")
274-
converters["image"] = lambda image: cmap(norm(image))
264+
converters.append(
265+
FunctionConversionNode.from_funcs("map colors", {"image": lambda image: cmap(norm(image))})
266+
)
275267
super().__init__(data, converters)
276268
kwargs.setdefault("origin", "lower")
277269
self._wrapped_instance = self._wrapped_class(None, **kwargs)

examples/2Dfunc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828
cmap = mpl.colormaps["viridis"]
2929
norm = Normalize(-1, 1)
30-
im = ImageWrapper(fc, {"image": lambda image: cmap(norm(image))})
30+
im = ImageWrapper(fc)
3131

3232
fig, ax = plt.subplots()
3333
ax.add_artist(im)

examples/animation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from matplotlib.animation import FuncAnimation
1818

1919
from data_prototype.containers import _MatplotlibTransform, Desc
20+
from data_prototype.conversion_node import FunctionConversionNode
2021

2122
from data_prototype.wrappers import LineWrapper, FormatedText
2223

@@ -63,9 +64,9 @@ def update(frame, art):
6364
lw = LineWrapper(sot_c, lw=5, color="green", label="sin(time)")
6465
fc = FormatedText(
6566
sot_c,
66-
{"text": lambda phase: f"ϕ={phase:.2f}"},
67-
x=2 * np.pi,
68-
y=1,
67+
FunctionConversionNode.from_funcs(
68+
"fmt", {"text": lambda phase: f"ϕ={phase:.2f}", "x": lambda: 2 * np.pi, "y": lambda: 1}
69+
),
6970
ha="right",
7071
)
7172
fig, ax = plt.subplots()

examples/mapped.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,33 @@
1414

1515
from data_prototype.wrappers import LineWrapper, FormatedText
1616
from data_prototype.containers import ArrayContainer
17+
from data_prototype.conversion_node import FunctionConversionNode
1718

1819
cmap = plt.colormaps["viridis"]
1920
cmap.set_over("k")
2021
cmap.set_under("r")
2122
norm = Normalize(1, 8)
2223

23-
line_converter = {
24-
# arbitrary functions
25-
"lw": lambda lw: min(1 + lw, 5),
26-
# standard color mapping
27-
"color": lambda j: cmap(norm(j)),
28-
# categorical
29-
"ls": lambda cat: {"A": "-", "B": ":", "C": "--"}[cat[()]],
30-
}
24+
line_converter = FunctionConversionNode.from_funcs(
25+
"line converter",
26+
{
27+
# arbitrary functions
28+
"lw": lambda lw: min(1 + lw, 5),
29+
# standard color mapping
30+
"color": lambda j: cmap(norm(j)),
31+
# categorical
32+
"ls": lambda cat: {"A": "-", "B": ":", "C": "--"}[cat[()]],
33+
},
34+
)
3135

32-
text_converter = {
33-
"text": lambda j, cat: f"index={j[()]} class={cat[()]!r}",
34-
"y": lambda j: j,
35-
}
36+
text_converter = FunctionConversionNode.from_funcs(
37+
"text converter",
38+
{
39+
"text": lambda j, cat: f"index={j[()]} class={cat[()]!r}",
40+
"y": lambda j: j,
41+
"x": lambda x: 2 * np.pi,
42+
},
43+
)
3644

3745

3846
th = np.linspace(0, 2 * np.pi, 128)

examples/mulivariate_cmap.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from data_prototype.wrappers import ImageWrapper
1515
from data_prototype.containers import FuncContainer
16+
from data_prototype.conversion_node import FunctionConversionNode
1617

1718
from matplotlib.colors import hsv_to_rgb
1819

@@ -40,7 +41,7 @@ def image_nu(image):
4041
},
4142
)
4243

43-
im = ImageWrapper(fc, {"image": image_nu})
44+
im = ImageWrapper(fc, FunctionConversionNode.from_funcs("bivariate", {"image": image_nu}))
4445

4546
fig, ax = plt.subplots()
4647
ax.add_artist(im)

examples/subsample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def query(
6868
sub = Subsample()
6969
cmap = mpl.colormaps["coolwarm"]
7070
norm = Normalize(-2.2, 2.2)
71-
im = ImageWrapper(sub, {"image": lambda image: cmap(norm(image))})
71+
im = ImageWrapper(sub)
7272

7373
fig, ax = plt.subplots()
7474
ax.add_artist(im)

examples/widgets.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from data_prototype.wrappers import LineWrapper
1717
from data_prototype.containers import FuncContainer
18+
from data_prototype.conversion_node import FunctionConversionNode
1819

1920

2021
class SliderContainer(FuncContainer):
@@ -117,7 +118,7 @@ def _query_hash(self, coord_transform, size):
117118
lw = LineWrapper(
118119
fc,
119120
# color map phase (scaled to 2pi and wrapped to [0, 1])
120-
{"color": lambda color: cmap((color / (2 * np.pi)) % 1)},
121+
FunctionConversionNode.from_funcs("cmap", {"color": lambda color: cmap((color / (2 * np.pi)) % 1)}),
121122
lw=5,
122123
)
123124
ax.add_artist(lw)

0 commit comments

Comments
 (0)