Skip to content

Commit 65b295b

Browse files
committed
Treat units info as nu, make nu a list rather than single function
1 parent 045ba27 commit 65b295b

File tree

2 files changed

+50
-45
lines changed

2 files changed

+50
-45
lines changed

data_prototype/patches.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ class PatchWrapper(ProxyWrapper):
4545
}
4646

4747
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
48-
super().__init__(data, nus)
48+
super().__init__(data, nus, xunits=self._xunits, yunits=self._yunits)
4949
self._wrapped_instance = self._wrapped_class([0, 0], 0, 0, **kwargs)
5050

5151
@_stale_wrapper
5252
def draw(self, renderer):
53-
self._update_wrapped(self._query_and_transform(renderer, xunits=self._xunits, yunits=self._yunits))
53+
self._update_wrapped(self._query_and_transform(renderer))
5454
return self._wrapped_instance.draw(renderer)
5555

5656
def _update_wrapped(self, data):

data_prototype/wrappers.py

Lines changed: 48 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from typing import List, Dict, Any, Protocol, Tuple, get_type_hints
1+
from typing import Any, Protocol, get_type_hints
22
import inspect
33

44
import numpy as np
55

66
from cachetools import LFUCache
7+
from collections.abc import Sequence
78
from functools import partial, wraps
89

910
import matplotlib as mpl
@@ -19,7 +20,7 @@
1920

2021

2122
class _BBox(Protocol):
22-
size: Tuple[float, float]
23+
size: tuple[float, float]
2324

2425

2526
class _Axis(Protocol):
@@ -34,10 +35,10 @@ class _Axes(Protocol):
3435
transData: _MatplotlibTransform
3536
transAxes: _MatplotlibTransform
3637

37-
def get_xlim(self) -> Tuple[float, float]:
38+
def get_xlim(self) -> tuple[float, float]:
3839
...
3940

40-
def get_ylim(self) -> Tuple[float, float]:
41+
def get_ylim(self) -> tuple[float, float]:
4142
...
4243

4344
def get_window_extent(self, renderer) -> _BBox:
@@ -47,15 +48,16 @@ def get_window_extent(self, renderer) -> _BBox:
4748
class _Aritst(Protocol):
4849
axes: _Axes
4950

51+
def _make_param_name(k, func):
52+
def wrapped(**kwargs):
53+
(arg,) = kwargs.values()
54+
return func(arg)
5055

51-
def _make_identity(k):
52-
def identity(**kwargs):
53-
(_,) = kwargs.values()
54-
return _
55-
56-
identity.__signature__ = inspect.Signature([inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD)])
57-
return identity
56+
wrapped.__signature__ = inspect.Signature([inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD)])
57+
return wrapped
5858

59+
def _make_identity(k):
60+
return _make_param_name(k, lambda x: x)
5961

6062
def _forwarder(forwards, cls=None):
6163
if cls is None:
@@ -109,7 +111,7 @@ def draw(self, renderer):
109111
def _update_wrapped(self, data):
110112
raise NotImplementedError
111113

112-
def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]) -> Dict[str, Any]:
114+
def _query_and_transform(self, renderer) -> dict[str, Any]:
113115
"""
114116
Helper to centralize the data querying and python-side transforms
115117
@@ -139,38 +141,43 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]
139141
return self._cache[cache_key]
140142
except KeyError:
141143
...
142-
# TODO decide if units go pre-nu or post-nu?
143-
for x_like in xunits:
144-
if x_like in data:
145-
data[x_like] = ax.xaxis.convert_units(data[x_like])
146-
for y_like in yunits:
147-
if y_like in data:
148-
data[y_like] = ax.xaxis.convert_units(data[y_like])
149-
150144
# doing the nu work here is nice because we can write it once, but we
151145
# really want to push this computation down a layer
152146
# TODO sort out how this interoperates with the transform stack
153147
transformed_data = {}
154-
for k, (nu, sig) in self._sigs.items():
155-
to_pass = set(sig.parameters)
156-
transformed_data[k] = nu(**{k: data[k] for k in to_pass})
148+
for k, nu_list in self._sigs.items():
149+
for nu, sig in nu_list:
150+
to_pass = set(sig.parameters)
151+
transformed_data[k] = nu(**{k: transformed_data.get(k, data[k]) for k in to_pass})
157152

158153
self._cache[cache_key] = transformed_data
159154
return transformed_data
160155

161-
def __init__(self, data, nus, **kwargs):
156+
def __init__(self, data, nus, xunits: tuple[str, ...] = (), yunits: tuple[str, ...] = (), **kwargs):
162157
super().__init__(**kwargs)
163158
self.data = data
164159
self._cache = LFUCache(64)
165160
# TODO make sure mutating this will invalidate the cache!
166161
self._nus = nus or {}
167162
for k in self.required_keys:
168-
self._nus.setdefault(k, _make_identity(k))
163+
self._nus.setdefault(k, [_make_identity(k)])
164+
169165
desc = data.describe()
170166
for k in self.expected_keys:
171167
if k in desc:
172-
self._nus.setdefault(k, _make_identity(k))
173-
self._sigs = {k: (nu, inspect.signature(nu)) for k, nu in self._nus.items()}
168+
self._nus.setdefault(k, [_make_identity(k)])
169+
170+
for field in self._nus:
171+
if inspect.isfunction(self._nus[field]):
172+
self._nus[field] = [self._nus[field]]
173+
174+
for field in xunits:
175+
self._nus[field].append(_make_param_name(field, lambda x: self.axes.xaxis.convert_units(x)))
176+
177+
for field in yunits:
178+
self._nus[field].append(_make_param_name(field, lambda y: self.axes.yaxis.convert_units(y)))
179+
180+
self._sigs = {k: [(nu, inspect.signature(nu)) for nu in nu_list] for k, nu_list in self._nus.items()}
174181
self.stale = True
175182

176183
# TODO add a setter
@@ -180,7 +187,7 @@ def nus(self):
180187

181188

182189
class ProxyWrapper(ProxyWrapperBase):
183-
_privtized_methods: Tuple[str, ...] = ()
190+
_privtized_methods: tuple[str, ...] = ()
184191
_wrapped_class = None
185192
_wrapped_instance: _Aritst
186193

@@ -206,13 +213,13 @@ class LineWrapper(ProxyWrapper):
206213
required_keys = {"x", "y"}
207214

208215
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
209-
super().__init__(data, nus)
216+
super().__init__(data, nus, xunits=["x"], yunits=["y"])
210217
self._wrapped_instance = self._wrapped_class(np.array([]), np.array([]), **kwargs)
211218

212219
@_stale_wrapper
213220
def draw(self, renderer):
214221
self._update_wrapped(
215-
self._query_and_transform(renderer, xunits=["x"], yunits=["y"]),
222+
self._query_and_transform(renderer),
216223
)
217224
return self._wrapped_instance.draw(renderer)
218225

@@ -239,14 +246,14 @@ class PathCollectionWrapper(ProxyWrapper):
239246
)
240247

241248
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
242-
super().__init__(data, nus)
249+
super().__init__(data, nus, xunits = ("x",), yunits = ("y",))
243250
self._wrapped_instance = self._wrapped_class([], **kwargs)
244251
self._wrapped_instance.set_transform(mtransforms.IdentityTransform())
245252

246253
@_stale_wrapper
247254
def draw(self, renderer):
248255
self._update_wrapped(
249-
self._query_and_transform(renderer, xunits=["x"], yunits=["y"]),
256+
self._query_and_transform(renderer),
250257
)
251258
return self._wrapped_instance.draw(renderer)
252259

@@ -272,14 +279,14 @@ def __init__(self, data: DataContainer, nus=None, /, cmap=None, norm=None, **kwa
272279
if norm is None:
273280
raise ValueError("not sure how to do autoscaling yet")
274281
nus["image"] = lambda image: cmap(norm(image))
275-
super().__init__(data, nus)
282+
super().__init__(data, nus, xunits=["xextent"], yunits=["yextent"])
276283
kwargs.setdefault("origin", "lower")
277284
self._wrapped_instance = self._wrapped_class(None, **kwargs)
278285

279286
@_stale_wrapper
280287
def draw(self, renderer):
281288
self._update_wrapped(
282-
self._query_and_transform(renderer, xunits=["xextent"], yunits=["yextent"]),
289+
self._query_and_transform(renderer),
283290
)
284291
return self._wrapped_instance.draw(renderer)
285292

@@ -294,13 +301,13 @@ class StepWrapper(ProxyWrapper):
294301
required_keys = {"edges", "density"}
295302

296303
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
297-
super().__init__(data, nus)
304+
super().__init__(data, nus, xunits=["edges"], yunits=["density"])
298305
self._wrapped_instance = self._wrapped_class([], [1], **kwargs)
299306

300307
@_stale_wrapper
301308
def draw(self, renderer):
302309
self._update_wrapped(
303-
self._query_and_transform(renderer, xunits=["edges"], yunits=["density"]),
310+
self._query_and_transform(renderer),
304311
)
305312
return self._wrapped_instance.draw(renderer)
306313

@@ -319,7 +326,7 @@ def __init__(self, data: DataContainer, nus=None, /, **kwargs):
319326
@_stale_wrapper
320327
def draw(self, renderer):
321328
self._update_wrapped(
322-
self._query_and_transform(renderer, xunits=[], yunits=[]),
329+
self._query_and_transform(renderer),
323330
)
324331
return self._wrapped_instance.draw(renderer)
325332

@@ -342,8 +349,8 @@ def _update_wrapped(self, data):
342349
)
343350
# _Artist has to go last for now because it is not (yet) MI friendly.
344351
class MultiProxyWrapper(ProxyWrapperBase, _Artist):
345-
_privtized_methods: Tuple[str, ...] = ()
346-
_wrapped_instances: Dict[str, _Aritst]
352+
_privtized_methods: tuple[str, ...] = ()
353+
_wrapped_instances: dict[str, _Aritst]
347354

348355
def __setattr__(self, key, value):
349356
attrs = set(get_type_hints(type(self)))
@@ -369,7 +376,7 @@ class ErrorbarWrapper(MultiProxyWrapper):
369376
expected_keys = {f"{axis}{dirc}" for axis in ["x", "y"] for dirc in ["upper", "lower"]}
370377

371378
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
372-
super().__init__(data, nus)
379+
super().__init__(data, nus, xunits=["x", "xupper", "xlower"], yunits=["y", "yupper", "ylower"])
373380
# TODO all of the kwarg teasing apart that is needed
374381
color = kwargs.pop("color", "k")
375382
lw = kwargs.pop("lw", 2)
@@ -396,9 +403,7 @@ def __init__(self, data: DataContainer, nus=None, /, **kwargs):
396403
@_stale_wrapper
397404
def draw(self, renderer):
398405
self._update_wrapped(
399-
self._query_and_transform(
400-
renderer, xunits=["x", "xupper", "xlower"], yunits=["y", "yupper", "ylower"]
401-
),
406+
self._query_and_transform(renderer),
402407
)
403408
for k, v in self._wrapped_instances.items():
404409
v.draw(renderer)

0 commit comments

Comments
 (0)