Skip to content

Commit c0e503d

Browse files
committed
Treat units info as nu, make nu a list rather than single function
1 parent de25dd1 commit c0e503d

File tree

1 file changed

+33
-20
lines changed

1 file changed

+33
-20
lines changed

data_prototype/wrappers.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Dict, Any, Protocol, Tuple, get_type_hints
2+
import inspect
23

34
import numpy as np
45

@@ -96,7 +97,7 @@ def draw(self, renderer):
9697
def _update_wrapped(self, data):
9798
raise NotImplementedError
9899

99-
def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]) -> Dict[str, Any]:
100+
def _query_and_transform(self, renderer) -> Dict[str, Any]:
100101
"""
101102
Helper to centralize the data querying and python-side transforms
102103
@@ -126,27 +127,39 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]
126127
return self._cache[cache_key]
127128
except KeyError:
128129
...
129-
# TODO decide if units go pre-nu or post-nu?
130-
for x_like in xunits:
131-
if x_like in data:
132-
data[x_like] = ax.xaxis.convert_units(data[x_like])
133-
for y_like in yunits:
134-
if y_like in data:
135-
data[y_like] = ax.xaxis.convert_units(data[y_like])
136-
137130
# doing the nu work here is nice because we can write it once, but we
138131
# really want to push this computation down a layer
139132
# TODO sort out how this interoperates with the transform stack
140-
data = {k: self.nus.get(k, lambda x: x)(v) for k, v in data.items()}
133+
def linearize(funclist):
134+
if inspect.isfunction(funclist):
135+
return funclist
136+
def ret(x):
137+
for func in funclist:
138+
x = func(x)
139+
return x
140+
return ret
141+
data = {k: linearize(self.nus.get(k, lambda x: x))(v) for k, v in data.items()}
141142
self._cache[cache_key] = data
142143
return data
143144

144-
def __init__(self, data, nus, **kwargs):
145+
def __init__(self, data, nus, xunits: List[str] = [], yunits: List[str] = [], **kwargs):
145146
super().__init__(**kwargs)
146147
self.data = data
147148
self._cache = LFUCache(64)
148149
# TODO make sure mutating this will invalidate the cache!
149150
self.nus = nus or {}
151+
for field in xunits:
152+
if field not in self.nus:
153+
self.nus[field] = []
154+
if inspect.isfunction(self.nus[field]):
155+
self.nus[field] = [self.nus[field]]
156+
self.nus[field].append(lambda x: self.axes.xaxis.convert_units(x))
157+
for field in yunits:
158+
if field not in self.nus:
159+
self.nus[field] = []
160+
if inspect.isfunction(self.nus[field]):
161+
self.nus[field] = [self.nus[field]]
162+
self.nus[field].append(lambda y: self.axes.yaxis.convert_units(y))
150163
self.stale = True
151164

152165

@@ -176,13 +189,13 @@ class LineWrapper(ProxyWrapper):
176189
_privtized_methods = ("set_xdata", "set_ydata", "set_data", "get_xdata", "get_ydata", "get_data")
177190

178191
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
179-
super().__init__(data, nus)
192+
super().__init__(data, nus, xunits=["x"], yunits=["y"])
180193
self._wrapped_instance = self._wrapped_class([], [], **kwargs)
181194

182195
@_stale_wrapper
183196
def draw(self, renderer):
184197
self._update_wrapped(
185-
self._query_and_transform(renderer, xunits=["x"], yunits=["y"]),
198+
self._query_and_transform(renderer),
186199
)
187200
return self._wrapped_instance.draw(renderer)
188201

@@ -204,14 +217,14 @@ def __init__(self, data: DataContainer, nus=None, /, cmap=None, norm=None, **kwa
204217
if norm is None:
205218
raise ValueError("not sure how to do autoscaling yet")
206219
nus["image"] = lambda image: cmap(norm(image))
207-
super().__init__(data, nus)
220+
super().__init__(data, nus, xunits=["xextent"], yunits=["yextent"])
208221
kwargs.setdefault("origin", "lower")
209222
self._wrapped_instance = self._wrapped_class(None, **kwargs)
210223

211224
@_stale_wrapper
212225
def draw(self, renderer):
213226
self._update_wrapped(
214-
self._query_and_transform(renderer, xunits=["xextent"], yunits=["yextent"]),
227+
self._query_and_transform(renderer),
215228
)
216229
return self._wrapped_instance.draw(renderer)
217230

@@ -225,13 +238,13 @@ class StepWrapper(ProxyWrapper):
225238
_privtized_methods = () # ("set_data", "get_data")
226239

227240
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
228-
super().__init__(data, nus)
241+
super().__init__(data, nus, xunits=["edges"], yunits=["density"])
229242
self._wrapped_instance = self._wrapped_class([], [1], **kwargs)
230243

231244
@_stale_wrapper
232245
def draw(self, renderer):
233246
self._update_wrapped(
234-
self._query_and_transform(renderer, xunits=["edges"], yunits=["density"]),
247+
self._query_and_transform(renderer),
235248
)
236249
return self._wrapped_instance.draw(renderer)
237250

@@ -251,7 +264,7 @@ def __init__(self, data: DataContainer, format_func, nus=None, /, **kwargs):
251264
@_stale_wrapper
252265
def draw(self, renderer):
253266
self._update_wrapped(
254-
self._query_and_transform(renderer, xunits=[], yunits=[]),
267+
self._query_and_transform(renderer),
255268
)
256269
return self._wrapped_instance.draw(renderer)
257270

@@ -297,7 +310,7 @@ def get_children(self):
297310

298311
class ErrorbarWrapper(MultiProxyWrapper):
299312
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
300-
super().__init__(data, nus)
313+
super().__init__(data, nus, xunits=["x", "xupper", "xlower"], yunits=["y", "yupper", "ylower"])
301314
# TODO all of the kwarg teasing apart that is needed
302315
color = kwargs.pop("color", "k")
303316
lw = kwargs.pop("lw", 2)
@@ -325,7 +338,7 @@ def __init__(self, data: DataContainer, nus=None, /, **kwargs):
325338
def draw(self, renderer):
326339
self._update_wrapped(
327340
self._query_and_transform(
328-
renderer, xunits=["x", "xupper", "xlower"], yunits=["y", "yupper", "ylower"]
341+
renderer
329342
),
330343
)
331344
for k, v in self._wrapped_instances.items():

0 commit comments

Comments
 (0)