1
- from typing import List , Dict , Any , Protocol , Tuple , get_type_hints
1
+ from typing import Any , Protocol , get_type_hints
2
2
import inspect
3
3
4
4
import numpy as np
5
5
6
6
from cachetools import LFUCache
7
+ from collections .abc import Sequence
7
8
from functools import partial , wraps
8
9
9
10
import matplotlib as mpl
19
20
20
21
21
22
class _BBox (Protocol ):
22
- size : Tuple [float , float ]
23
+ size : tuple [float , float ]
23
24
24
25
25
26
class _Axis (Protocol ):
@@ -34,10 +35,10 @@ class _Axes(Protocol):
34
35
transData : _MatplotlibTransform
35
36
transAxes : _MatplotlibTransform
36
37
37
- def get_xlim (self ) -> Tuple [float , float ]:
38
+ def get_xlim (self ) -> tuple [float , float ]:
38
39
...
39
40
40
- def get_ylim (self ) -> Tuple [float , float ]:
41
+ def get_ylim (self ) -> tuple [float , float ]:
41
42
...
42
43
43
44
def get_window_extent (self , renderer ) -> _BBox :
@@ -47,15 +48,16 @@ def get_window_extent(self, renderer) -> _BBox:
47
48
class _Aritst (Protocol ):
48
49
axes : _Axes
49
50
51
+ def _make_param_name (k , func ):
52
+ def wrapped (** kwargs ):
53
+ (arg ,) = kwargs .values ()
54
+ return func (arg )
50
55
51
- def _make_identity (k ):
52
- def identity (** kwargs ):
53
- (_ ,) = kwargs .values ()
54
- return _
55
-
56
- identity .__signature__ = inspect .Signature ([inspect .Parameter (k , inspect .Parameter .POSITIONAL_OR_KEYWORD )])
57
- return identity
56
+ wrapped .__signature__ = inspect .Signature ([inspect .Parameter (k , inspect .Parameter .POSITIONAL_OR_KEYWORD )])
57
+ return wrapped
58
58
59
+ def _make_identity (k ):
60
+ return _make_param_name (k , lambda x : x )
59
61
60
62
def _forwarder (forwards , cls = None ):
61
63
if cls is None :
@@ -109,7 +111,7 @@ def draw(self, renderer):
109
111
def _update_wrapped (self , data ):
110
112
raise NotImplementedError
111
113
112
- def _query_and_transform (self , renderer , * , xunits : List [ str ], yunits : List [ str ] ) -> Dict [str , Any ]:
114
+ def _query_and_transform (self , renderer ) -> dict [str , Any ]:
113
115
"""
114
116
Helper to centralize the data querying and python-side transforms
115
117
@@ -139,38 +141,43 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]
139
141
return self ._cache [cache_key ]
140
142
except KeyError :
141
143
...
142
- # TODO decide if units go pre-nu or post-nu?
143
- for x_like in xunits :
144
- if x_like in data :
145
- data [x_like ] = ax .xaxis .convert_units (data [x_like ])
146
- for y_like in yunits :
147
- if y_like in data :
148
- data [y_like ] = ax .xaxis .convert_units (data [y_like ])
149
-
150
144
# doing the nu work here is nice because we can write it once, but we
151
145
# really want to push this computation down a layer
152
146
# TODO sort out how this interoperates with the transform stack
153
147
transformed_data = {}
154
- for k , (nu , sig ) in self ._sigs .items ():
155
- to_pass = set (sig .parameters )
156
- transformed_data [k ] = nu (** {k : data [k ] for k in to_pass })
148
+ for k , nu_list in self ._sigs .items ():
149
+ for nu , sig in nu_list :
150
+ to_pass = set (sig .parameters )
151
+ transformed_data [k ] = nu (** {k : transformed_data .get (k , data [k ]) for k in to_pass })
157
152
158
153
self ._cache [cache_key ] = transformed_data
159
154
return transformed_data
160
155
161
- def __init__ (self , data , nus , ** kwargs ):
156
+ def __init__ (self , data , nus , xunits : tuple [ str , ...] = (), yunits : tuple [ str , ...] = (), ** kwargs ):
162
157
super ().__init__ (** kwargs )
163
158
self .data = data
164
159
self ._cache = LFUCache (64 )
165
160
# TODO make sure mutating this will invalidate the cache!
166
161
self ._nus = nus or {}
167
162
for k in self .required_keys :
168
- self ._nus .setdefault (k , _make_identity (k ))
163
+ self ._nus .setdefault (k , [_make_identity (k )])
164
+
169
165
desc = data .describe ()
170
166
for k in self .expected_keys :
171
167
if k in desc :
172
- self ._nus .setdefault (k , _make_identity (k ))
173
- self ._sigs = {k : (nu , inspect .signature (nu )) for k , nu in self ._nus .items ()}
168
+ self ._nus .setdefault (k , [_make_identity (k )])
169
+
170
+ for field in self ._nus :
171
+ if inspect .isfunction (self ._nus [field ]):
172
+ self ._nus [field ] = [self ._nus [field ]]
173
+
174
+ for field in xunits :
175
+ self ._nus [field ].append (_make_param_name (field , lambda x : self .axes .xaxis .convert_units (x )))
176
+
177
+ for field in yunits :
178
+ self ._nus [field ].append (_make_param_name (field , lambda y : self .axes .yaxis .convert_units (y )))
179
+
180
+ self ._sigs = {k : [(nu , inspect .signature (nu )) for nu in nu_list ] for k , nu_list in self ._nus .items ()}
174
181
self .stale = True
175
182
176
183
# TODO add a setter
@@ -180,7 +187,7 @@ def nus(self):
180
187
181
188
182
189
class ProxyWrapper (ProxyWrapperBase ):
183
- _privtized_methods : Tuple [str , ...] = ()
190
+ _privtized_methods : tuple [str , ...] = ()
184
191
_wrapped_class = None
185
192
_wrapped_instance : _Aritst
186
193
@@ -206,13 +213,13 @@ class LineWrapper(ProxyWrapper):
206
213
required_keys = {"x" , "y" }
207
214
208
215
def __init__ (self , data : DataContainer , nus = None , / , ** kwargs ):
209
- super ().__init__ (data , nus )
216
+ super ().__init__ (data , nus , xunits = [ "x" ], yunits = [ "y" ] )
210
217
self ._wrapped_instance = self ._wrapped_class (np .array ([]), np .array ([]), ** kwargs )
211
218
212
219
@_stale_wrapper
213
220
def draw (self , renderer ):
214
221
self ._update_wrapped (
215
- self ._query_and_transform (renderer , xunits = [ "x" ], yunits = [ "y" ] ),
222
+ self ._query_and_transform (renderer ),
216
223
)
217
224
return self ._wrapped_instance .draw (renderer )
218
225
@@ -239,14 +246,14 @@ class PathCollectionWrapper(ProxyWrapper):
239
246
)
240
247
241
248
def __init__ (self , data : DataContainer , nus = None , / , ** kwargs ):
242
- super ().__init__ (data , nus )
249
+ super ().__init__ (data , nus , xunits = ( "x" ,), yunits = ( "y" ,) )
243
250
self ._wrapped_instance = self ._wrapped_class ([], ** kwargs )
244
251
self ._wrapped_instance .set_transform (mtransforms .IdentityTransform ())
245
252
246
253
@_stale_wrapper
247
254
def draw (self , renderer ):
248
255
self ._update_wrapped (
249
- self ._query_and_transform (renderer , xunits = [ "x" ], yunits = [ "y" ] ),
256
+ self ._query_and_transform (renderer ),
250
257
)
251
258
return self ._wrapped_instance .draw (renderer )
252
259
@@ -272,14 +279,14 @@ def __init__(self, data: DataContainer, nus=None, /, cmap=None, norm=None, **kwa
272
279
if norm is None :
273
280
raise ValueError ("not sure how to do autoscaling yet" )
274
281
nus ["image" ] = lambda image : cmap (norm (image ))
275
- super ().__init__ (data , nus )
282
+ super ().__init__ (data , nus , xunits = [ "xextent" ], yunits = [ "yextent" ] )
276
283
kwargs .setdefault ("origin" , "lower" )
277
284
self ._wrapped_instance = self ._wrapped_class (None , ** kwargs )
278
285
279
286
@_stale_wrapper
280
287
def draw (self , renderer ):
281
288
self ._update_wrapped (
282
- self ._query_and_transform (renderer , xunits = [ "xextent" ], yunits = [ "yextent" ] ),
289
+ self ._query_and_transform (renderer ),
283
290
)
284
291
return self ._wrapped_instance .draw (renderer )
285
292
@@ -294,13 +301,13 @@ class StepWrapper(ProxyWrapper):
294
301
required_keys = {"edges" , "density" }
295
302
296
303
def __init__ (self , data : DataContainer , nus = None , / , ** kwargs ):
297
- super ().__init__ (data , nus )
304
+ super ().__init__ (data , nus , xunits = [ "edges" ], yunits = [ "density" ] )
298
305
self ._wrapped_instance = self ._wrapped_class ([], [1 ], ** kwargs )
299
306
300
307
@_stale_wrapper
301
308
def draw (self , renderer ):
302
309
self ._update_wrapped (
303
- self ._query_and_transform (renderer , xunits = [ "edges" ], yunits = [ "density" ] ),
310
+ self ._query_and_transform (renderer ),
304
311
)
305
312
return self ._wrapped_instance .draw (renderer )
306
313
@@ -319,7 +326,7 @@ def __init__(self, data: DataContainer, nus=None, /, **kwargs):
319
326
@_stale_wrapper
320
327
def draw (self , renderer ):
321
328
self ._update_wrapped (
322
- self ._query_and_transform (renderer , xunits = [], yunits = [] ),
329
+ self ._query_and_transform (renderer ),
323
330
)
324
331
return self ._wrapped_instance .draw (renderer )
325
332
@@ -342,8 +349,8 @@ def _update_wrapped(self, data):
342
349
)
343
350
# _Artist has to go last for now because it is not (yet) MI friendly.
344
351
class MultiProxyWrapper (ProxyWrapperBase , _Artist ):
345
- _privtized_methods : Tuple [str , ...] = ()
346
- _wrapped_instances : Dict [str , _Aritst ]
352
+ _privtized_methods : tuple [str , ...] = ()
353
+ _wrapped_instances : dict [str , _Aritst ]
347
354
348
355
def __setattr__ (self , key , value ):
349
356
attrs = set (get_type_hints (type (self )))
@@ -369,7 +376,7 @@ class ErrorbarWrapper(MultiProxyWrapper):
369
376
expected_keys = {f"{ axis } { dirc } " for axis in ["x" , "y" ] for dirc in ["upper" , "lower" ]}
370
377
371
378
def __init__ (self , data : DataContainer , nus = None , / , ** kwargs ):
372
- super ().__init__ (data , nus )
379
+ super ().__init__ (data , nus , xunits = [ "x" , "xupper" , "xlower" ], yunits = [ "y" , "yupper" , "ylower" ] )
373
380
# TODO all of the kwarg teasing apart that is needed
374
381
color = kwargs .pop ("color" , "k" )
375
382
lw = kwargs .pop ("lw" , 2 )
@@ -396,9 +403,7 @@ def __init__(self, data: DataContainer, nus=None, /, **kwargs):
396
403
@_stale_wrapper
397
404
def draw (self , renderer ):
398
405
self ._update_wrapped (
399
- self ._query_and_transform (
400
- renderer , xunits = ["x" , "xupper" , "xlower" ], yunits = ["y" , "yupper" , "ylower" ]
401
- ),
406
+ self ._query_and_transform (renderer ),
402
407
)
403
408
for k , v in self ._wrapped_instances .items ():
404
409
v .draw (renderer )
0 commit comments