1
1
from typing import List , Dict , Any , Protocol , Tuple , get_type_hints
2
+ import inspect
2
3
3
4
import numpy as np
4
5
@@ -96,7 +97,7 @@ def draw(self, renderer):
96
97
def _update_wrapped (self , data ):
97
98
raise NotImplementedError
98
99
99
- def _query_and_transform (self , renderer , * , xunits : List [ str ], yunits : List [ str ] ) -> Dict [str , Any ]:
100
+ def _query_and_transform (self , renderer ) -> Dict [str , Any ]:
100
101
"""
101
102
Helper to centralize the data querying and python-side transforms
102
103
@@ -126,27 +127,39 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]
126
127
return self ._cache [cache_key ]
127
128
except KeyError :
128
129
...
129
- # TODO decide if units go pre-nu or post-nu?
130
- for x_like in xunits :
131
- if x_like in data :
132
- data [x_like ] = ax .xaxis .convert_units (data [x_like ])
133
- for y_like in yunits :
134
- if y_like in data :
135
- data [y_like ] = ax .xaxis .convert_units (data [y_like ])
136
-
137
130
# doing the nu work here is nice because we can write it once, but we
138
131
# really want to push this computation down a layer
139
132
# TODO sort out how this interoperates with the transform stack
140
- data = {k : self .nus .get (k , lambda x : x )(v ) for k , v in data .items ()}
133
+ def linearize (funclist ):
134
+ if inspect .isfunction (funclist ):
135
+ return funclist
136
+ def ret (x ):
137
+ for func in funclist :
138
+ x = func (x )
139
+ return x
140
+ return ret
141
+ data = {k : linearize (self .nus .get (k , lambda x : x ))(v ) for k , v in data .items ()}
141
142
self ._cache [cache_key ] = data
142
143
return data
143
144
144
- def __init__ (self , data , nus , ** kwargs ):
145
+ def __init__ (self , data , nus , xunits : List [ str ] = [], yunits : List [ str ] = [], ** kwargs ):
145
146
super ().__init__ (** kwargs )
146
147
self .data = data
147
148
self ._cache = LFUCache (64 )
148
149
# TODO make sure mutating this will invalidate the cache!
149
150
self .nus = nus or {}
151
+ for field in xunits :
152
+ if field not in self .nus :
153
+ self .nus [field ] = []
154
+ if inspect .isfunction (self .nus [field ]):
155
+ self .nus [field ] = [self .nus [field ]]
156
+ self .nus [field ].append (lambda x : self .axes .xaxis .convert_units (x ))
157
+ for field in yunits :
158
+ if field not in self .nus :
159
+ self .nus [field ] = []
160
+ if inspect .isfunction (self .nus [field ]):
161
+ self .nus [field ] = [self .nus [field ]]
162
+ self .nus [field ].append (lambda y : self .axes .yaxis .convert_units (y ))
150
163
self .stale = True
151
164
152
165
@@ -176,13 +189,13 @@ class LineWrapper(ProxyWrapper):
176
189
_privtized_methods = ("set_xdata" , "set_ydata" , "set_data" , "get_xdata" , "get_ydata" , "get_data" )
177
190
178
191
def __init__ (self , data : DataContainer , nus = None , / , ** kwargs ):
179
- super ().__init__ (data , nus )
192
+ super ().__init__ (data , nus , xunits = [ "x" ], yunits = [ "y" ] )
180
193
self ._wrapped_instance = self ._wrapped_class ([], [], ** kwargs )
181
194
182
195
@_stale_wrapper
183
196
def draw (self , renderer ):
184
197
self ._update_wrapped (
185
- self ._query_and_transform (renderer , xunits = [ "x" ], yunits = [ "y" ] ),
198
+ self ._query_and_transform (renderer ),
186
199
)
187
200
return self ._wrapped_instance .draw (renderer )
188
201
@@ -204,14 +217,14 @@ def __init__(self, data: DataContainer, nus=None, /, cmap=None, norm=None, **kwa
204
217
if norm is None :
205
218
raise ValueError ("not sure how to do autoscaling yet" )
206
219
nus ["image" ] = lambda image : cmap (norm (image ))
207
- super ().__init__ (data , nus )
220
+ super ().__init__ (data , nus , xunits = [ "xextent" ], yunits = [ "yextent" ] )
208
221
kwargs .setdefault ("origin" , "lower" )
209
222
self ._wrapped_instance = self ._wrapped_class (None , ** kwargs )
210
223
211
224
@_stale_wrapper
212
225
def draw (self , renderer ):
213
226
self ._update_wrapped (
214
- self ._query_and_transform (renderer , xunits = [ "xextent" ], yunits = [ "yextent" ] ),
227
+ self ._query_and_transform (renderer ),
215
228
)
216
229
return self ._wrapped_instance .draw (renderer )
217
230
@@ -225,13 +238,13 @@ class StepWrapper(ProxyWrapper):
225
238
_privtized_methods = () # ("set_data", "get_data")
226
239
227
240
def __init__ (self , data : DataContainer , nus = None , / , ** kwargs ):
228
- super ().__init__ (data , nus )
241
+ super ().__init__ (data , nus , xunits = [ "edges" ], yunits = [ "density" ] )
229
242
self ._wrapped_instance = self ._wrapped_class ([], [1 ], ** kwargs )
230
243
231
244
@_stale_wrapper
232
245
def draw (self , renderer ):
233
246
self ._update_wrapped (
234
- self ._query_and_transform (renderer , xunits = [ "edges" ], yunits = [ "density" ] ),
247
+ self ._query_and_transform (renderer ),
235
248
)
236
249
return self ._wrapped_instance .draw (renderer )
237
250
@@ -251,7 +264,7 @@ def __init__(self, data: DataContainer, format_func, nus=None, /, **kwargs):
251
264
@_stale_wrapper
252
265
def draw (self , renderer ):
253
266
self ._update_wrapped (
254
- self ._query_and_transform (renderer , xunits = [], yunits = [] ),
267
+ self ._query_and_transform (renderer ),
255
268
)
256
269
return self ._wrapped_instance .draw (renderer )
257
270
@@ -297,7 +310,7 @@ def get_children(self):
297
310
298
311
class ErrorbarWrapper (MultiProxyWrapper ):
299
312
def __init__ (self , data : DataContainer , nus = None , / , ** kwargs ):
300
- super ().__init__ (data , nus )
313
+ super ().__init__ (data , nus , xunits = [ "x" , "xupper" , "xlower" ], yunits = [ "y" , "yupper" , "ylower" ] )
301
314
# TODO all of the kwarg teasing apart that is needed
302
315
color = kwargs .pop ("color" , "k" )
303
316
lw = kwargs .pop ("lw" , 2 )
@@ -325,7 +338,7 @@ def __init__(self, data: DataContainer, nus=None, /, **kwargs):
325
338
def draw (self , renderer ):
326
339
self ._update_wrapped (
327
340
self ._query_and_transform (
328
- renderer , xunits = [ "x" , "xupper" , "xlower" ], yunits = [ "y" , "yupper" , "ylower" ]
341
+ renderer
329
342
),
330
343
)
331
344
for k , v in self ._wrapped_instances .items ():
0 commit comments