Skip to content

Commit 23cd067

Browse files
committed
ENH: make the nus functions more filtering and powerful
Make it possible to rename as part of the transforms
1 parent 4ad3c49 commit 23cd067

File tree

2 files changed

+34
-9
lines changed

2 files changed

+34
-9
lines changed

data_prototype/wrappers.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Dict, Any, Protocol, Tuple, get_type_hints
2+
import inspect
23

34
import numpy as np
45

@@ -88,6 +89,7 @@ class ProxyWrapperBase:
8889
data: DataContainer
8990
axes: _Axes
9091
stale: bool
92+
required_keys: set = set()
9193

9294
@_stale_wrapper
9395
def draw(self, renderer):
@@ -137,18 +139,38 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]
137139
# doing the nu work here is nice because we can write it once, but we
138140
# really want to push this computation down a layer
139141
# TODO sort out how this interoperates with the transform stack
140-
data = {k: self.nus.get(k, lambda x: x)(v) for k, v in data.items()}
141-
self._cache[cache_key] = data
142-
return data
142+
transformed_data = {}
143+
for k, (nu, sig) in self._sigs.items():
144+
to_pass = set(sig.parameters)
145+
transformed_data[k] = nu(**{k: data[k] for k in to_pass})
146+
self._cache[cache_key] = transformed_data
147+
return transformed_data
143148

144149
def __init__(self, data, nus, **kwargs):
145150
super().__init__(**kwargs)
146151
self.data = data
147152
self._cache = LFUCache(64)
148153
# TODO make sure mutating this will invalidate the cache!
149-
self.nus = nus or {}
154+
self._nus = nus or {}
155+
for k in self.required_keys:
156+
157+
def identity(**kwargs):
158+
(_,) = kwargs.values()
159+
return _
160+
161+
identity.__signature__ = inspect.Signature(
162+
[inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD)]
163+
)
164+
165+
self._nus.setdefault(k, identity)
166+
self._sigs = {k: (nu, inspect.signature(nu)) for k, nu in self._nus.items()}
150167
self.stale = True
151168

169+
# TODO add a setter
170+
@property
171+
def nus(self):
172+
return dict(self._nus)
173+
152174

153175
class ProxyWrapper(ProxyWrapperBase):
154176
_privtized_methods: Tuple[str, ...] = ()
@@ -163,7 +185,7 @@ def __getattr__(self, key):
163185
return getattr(self._wrapped_instance, key)
164186

165187
def __setattr__(self, key, value):
166-
if key in ("_wrapped_instance", "data", "_cache", "nus", "stale"):
188+
if key in ("_wrapped_instance", "data", "_cache", "_nus", "stale", "_sigs"):
167189
super().__setattr__(key, value)
168190
elif hasattr(self, "_wrapped_instance") and hasattr(self._wrapped_instance, key):
169191
setattr(self._wrapped_instance, key, value)
@@ -174,6 +196,7 @@ def __setattr__(self, key, value):
174196
class LineWrapper(ProxyWrapper):
175197
_wrapped_class = _Line2D
176198
_privtized_methods = ("set_xdata", "set_ydata", "set_data", "get_xdata", "get_ydata", "get_data")
199+
required_keys = {"x", "y"}
177200

178201
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
179202
super().__init__(data, nus)
@@ -188,6 +211,7 @@ def draw(self, renderer):
188211

189212
def _update_wrapped(self, data):
190213
for k, v in data.items():
214+
k = {"x": "xdata", "y": "ydata"}.get(k, k)
191215
getattr(self._wrapped_instance, f"set_{k}")(v)
192216

193217

@@ -244,10 +268,9 @@ class FormatedText(ProxyWrapper):
244268
_wrapped_class = _Text
245269
_privtized_methods = ("set_text",)
246270

247-
def __init__(self, data: DataContainer, format_func, nus=None, /, **kwargs):
271+
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
248272
super().__init__(data, nus)
249273
self._wrapped_instance = self._wrapped_class(text="", **kwargs)
250-
self._format_func = format_func
251274

252275
@_stale_wrapper
253276
def draw(self, renderer):
@@ -257,7 +280,9 @@ def draw(self, renderer):
257280
return self._wrapped_instance.draw(renderer)
258281

259282
def _update_wrapped(self, data):
260-
self._wrapped_instance.set_text(self._format_func(**data))
283+
for k, v in data.items():
284+
k = {"x": "xdata", "y": "ydata"}.get(k, k)
285+
getattr(self._wrapped_instance, f"set_{k}")(v)
261286

262287

263288
@_forwarder(

examples/animation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def update(frame, art):
6161
lw = LineWrapper(sot_c, lw=5, color="green", label="sin(time)")
6262
fc = FormatedText(
6363
sot_c,
64-
"ϕ={phase:.2f} ".format,
64+
{'text': lambda phase: f"ϕ={phase:.2f}"},
6565
x=2 * np.pi,
6666
y=1,
6767
ha="right",

0 commit comments

Comments
 (0)