12
12
from matplotlib import transforms , collections
13
13
from matplotlib .backends .backend_agg import FigureCanvasAgg
14
14
15
+
15
16
class Exporter (object ):
16
17
"""Matplotlib Exporter
17
18
@@ -44,15 +45,17 @@ def run(self, fig):
44
45
# in the correct place.
45
46
if fig .canvas is None :
46
47
canvas = FigureCanvasAgg (fig )
47
- fig .savefig (io .BytesIO (), format = ' png' , dpi = fig .dpi )
48
+ fig .savefig (io .BytesIO (), format = " png" , dpi = fig .dpi )
48
49
if self .close_mpl :
49
50
import matplotlib .pyplot as plt
51
+
50
52
plt .close (fig )
51
53
self .crawl_fig (fig )
52
54
53
55
@staticmethod
54
- def process_transform (transform , ax = None , data = None , return_trans = False ,
55
- force_trans = None ):
56
+ def process_transform (
57
+ transform , ax = None , data = None , return_trans = False , force_trans = None
58
+ ):
56
59
"""Process the transform and convert data to figure or data coordinates
57
60
58
61
Parameters
@@ -81,8 +84,10 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
81
84
Returned only if data is specified
82
85
"""
83
86
if isinstance (transform , transforms .BlendedGenericTransform ):
84
- warnings .warn ("Blended transforms not yet supported. "
85
- "Zoom behavior may not work as expected." )
87
+ warnings .warn (
88
+ "Blended transforms not yet supported. "
89
+ "Zoom behavior may not work as expected."
90
+ )
86
91
87
92
if force_trans is not None :
88
93
if data is not None :
@@ -91,10 +96,12 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
91
96
92
97
code = "display"
93
98
if ax is not None :
94
- for (c , trans ) in [("data" , ax .transData ),
95
- ("axes" , ax .transAxes ),
96
- ("figure" , ax .figure .transFigure ),
97
- ("display" , transforms .IdentityTransform ())]:
99
+ for (c , trans ) in [
100
+ ("data" , ax .transData ),
101
+ ("axes" , ax .transAxes ),
102
+ ("figure" , ax .figure .transFigure ),
103
+ ("display" , transforms .IdentityTransform ()),
104
+ ]:
98
105
if transform .contains_branch (trans ):
99
106
code , transform = (c , transform - trans )
100
107
break
@@ -112,24 +119,23 @@ def process_transform(transform, ax=None, data=None, return_trans=False,
112
119
113
120
def crawl_fig (self , fig ):
114
121
"""Crawl the figure and process all axes"""
115
- with self .renderer .draw_figure (fig = fig ,
116
- props = utils .get_figure_properties (fig )):
122
+ with self .renderer .draw_figure (fig = fig , props = utils .get_figure_properties (fig )):
117
123
for ax in fig .axes :
118
124
self .crawl_ax (ax )
119
125
120
126
def crawl_ax (self , ax ):
121
127
"""Crawl the axes and process all elements within"""
122
- with self .renderer .draw_axes (ax = ax ,
123
- props = utils .get_axes_properties (ax )):
128
+ with self .renderer .draw_axes (ax = ax , props = utils .get_axes_properties (ax )):
124
129
for line in ax .lines :
125
130
self .draw_line (ax , line )
126
131
for text in ax .texts :
127
132
self .draw_text (ax , text )
128
- for (text , ttp ) in zip ([ax .xaxis .label , ax .yaxis .label , ax .title ],
129
- ["xlabel" , "ylabel" , "title" ]):
130
- if (hasattr (text , 'get_text' ) and text .get_text ()):
131
- self .draw_text (ax , text , force_trans = ax .transAxes ,
132
- text_type = ttp )
133
+ for (text , ttp ) in zip (
134
+ [ax .xaxis .label , ax .yaxis .label , ax .title ],
135
+ ["xlabel" , "ylabel" , "title" ],
136
+ ):
137
+ if hasattr (text , "get_text" ) and text .get_text ():
138
+ self .draw_text (ax , text , force_trans = ax .transAxes , text_type = ttp )
133
139
for artist in ax .artists :
134
140
# TODO: process other artists
135
141
if isinstance (artist , matplotlib .text .Text ):
@@ -145,107 +151,122 @@ def crawl_ax(self, ax):
145
151
if legend is not None :
146
152
props = utils .get_legend_properties (ax , legend )
147
153
with self .renderer .draw_legend (legend = legend , props = props ):
148
- if props [' visible' ]:
154
+ if props [" visible" ]:
149
155
self .crawl_legend (ax , legend )
150
156
151
157
def crawl_legend (self , ax , legend ):
152
158
"""
153
159
Recursively look through objects in legend children
154
160
"""
155
- legendElements = list (utils .iter_all_children (legend ._legend_box ,
156
- skipContainers = True ))
161
+ legendElements = list (
162
+ utils .iter_all_children (legend ._legend_box , skipContainers = True )
163
+ )
157
164
legendElements .append (legend .legendPatch )
158
165
for child in legendElements :
159
166
# force a large zorder so it appears on top
160
- child .set_zorder (1E6 + child .get_zorder ())
167
+ child .set_zorder (1e6 + child .get_zorder ())
161
168
162
169
# reorder border box to make sure marks are visible
163
170
if isinstance (child , matplotlib .patches .FancyBboxPatch ):
164
- child .set_zorder (child .get_zorder ()- 1 )
171
+ child .set_zorder (child .get_zorder () - 1 )
165
172
166
173
try :
167
174
# What kind of object...
168
175
if isinstance (child , matplotlib .patches .Patch ):
169
176
self .draw_patch (ax , child , force_trans = ax .transAxes )
170
177
elif isinstance (child , matplotlib .text .Text ):
171
- if child .get_text () != ' None' :
178
+ if child .get_text () != " None" :
172
179
self .draw_text (ax , child , force_trans = ax .transAxes )
173
180
elif isinstance (child , matplotlib .lines .Line2D ):
174
181
self .draw_line (ax , child , force_trans = ax .transAxes )
175
182
elif isinstance (child , matplotlib .collections .Collection ):
176
- self .draw_collection (ax , child ,
177
- force_pathtrans = ax .transAxes )
183
+ self .draw_collection (ax , child , force_pathtrans = ax .transAxes )
178
184
else :
179
185
warnings .warn ("Legend element %s not impemented" % child )
180
186
except NotImplementedError :
181
187
warnings .warn ("Legend element %s not impemented" % child )
182
188
183
189
def draw_line (self , ax , line , force_trans = None ):
184
190
"""Process a matplotlib line and call renderer.draw_line"""
185
- coordinates , data = self .process_transform (line . get_transform (),
186
- ax , line .get_xydata (),
187
- force_trans = force_trans )
191
+ coordinates , data = self .process_transform (
192
+ line . get_transform (), ax , line .get_xydata (), force_trans = force_trans
193
+ )
188
194
linestyle = utils .get_line_style (line )
189
- if (linestyle ['dasharray' ] is None
190
- and linestyle ['drawstyle' ] == 'default' ):
195
+ if linestyle ["dasharray" ] is None and linestyle ["drawstyle" ] == "default" :
191
196
linestyle = None
192
197
markerstyle = utils .get_marker_style (line )
193
- if (markerstyle ['marker' ] in ['None' , 'none' , None ]
194
- or markerstyle ['markerpath' ][0 ].size == 0 ):
198
+ if (
199
+ markerstyle ["marker" ] in ["None" , "none" , None ]
200
+ or markerstyle ["markerpath" ][0 ].size == 0
201
+ ):
195
202
markerstyle = None
196
203
label = line .get_label ()
197
204
if markerstyle or linestyle :
198
- self .renderer .draw_marked_line (data = data , coordinates = coordinates ,
199
- linestyle = linestyle ,
200
- markerstyle = markerstyle ,
201
- label = label ,
202
- mplobj = line )
205
+ self .renderer .draw_marked_line (
206
+ data = data ,
207
+ coordinates = coordinates ,
208
+ linestyle = linestyle ,
209
+ markerstyle = markerstyle ,
210
+ label = label ,
211
+ mplobj = line ,
212
+ )
203
213
204
214
def draw_text (self , ax , text , force_trans = None , text_type = None ):
205
215
"""Process a matplotlib text object and call renderer.draw_text"""
206
216
content = text .get_text ()
207
217
if content :
208
218
transform = text .get_transform ()
209
219
position = text .get_position ()
210
- coords , position = self .process_transform (transform , ax ,
211
- position ,
212
- force_trans = force_trans )
220
+ coords , position = self .process_transform (
221
+ transform , ax , position , force_trans = force_trans
222
+ )
213
223
style = utils .get_text_style (text )
214
- self .renderer .draw_text (text = content , position = position ,
215
- coordinates = coords ,
216
- text_type = text_type ,
217
- style = style , mplobj = text )
224
+ self .renderer .draw_text (
225
+ text = content ,
226
+ position = position ,
227
+ coordinates = coords ,
228
+ text_type = text_type ,
229
+ style = style ,
230
+ mplobj = text ,
231
+ )
218
232
219
233
def draw_patch (self , ax , patch , force_trans = None ):
220
234
"""Process a matplotlib patch object and call renderer.draw_path"""
221
235
vertices , pathcodes = utils .SVG_path (patch .get_path ())
222
236
transform = patch .get_transform ()
223
- coordinates , vertices = self .process_transform (transform ,
224
- ax , vertices ,
225
- force_trans = force_trans )
237
+ coordinates , vertices = self .process_transform (
238
+ transform , ax , vertices , force_trans = force_trans
239
+ )
226
240
linestyle = utils .get_path_style (patch , fill = patch .get_fill ())
227
- self .renderer .draw_path (data = vertices ,
228
- coordinates = coordinates ,
229
- pathcodes = pathcodes ,
230
- style = linestyle ,
231
- mplobj = patch )
241
+ self .renderer .draw_path (
242
+ data = vertices ,
243
+ coordinates = coordinates ,
244
+ pathcodes = pathcodes ,
245
+ style = linestyle ,
246
+ mplobj = patch ,
247
+ )
232
248
233
- def draw_collection (self , ax , collection ,
234
- force_pathtrans = None ,
235
- force_offsettrans = None ):
249
+ def draw_collection (
250
+ self , ax , collection , force_pathtrans = None , force_offsettrans = None
251
+ ):
236
252
"""Process a matplotlib collection and call renderer.draw_collection"""
237
- (transform , transOffset ,
238
- offsets , paths ) = collection ._prepare_points ()
253
+ (transform , transOffset , offsets , paths ) = collection ._prepare_points ()
239
254
240
255
offset_coords , offsets = self .process_transform (
241
- transOffset , ax , offsets , force_trans = force_offsettrans )
242
- path_coords = self . process_transform (
243
- transform , ax , force_trans = force_pathtrans )
256
+ transOffset , ax , offsets , force_trans = force_offsettrans
257
+ )
258
+ path_coords = self . process_transform ( transform , ax , force_trans = force_pathtrans )
244
259
245
260
processed_paths = [utils .SVG_path (path ) for path in paths ]
246
- processed_paths = [(self .process_transform (
247
- transform , ax , path [0 ], force_trans = force_pathtrans )[1 ], path [1 ])
248
- for path in processed_paths ]
261
+ processed_paths = [
262
+ (
263
+ self .process_transform (
264
+ transform , ax , path [0 ], force_trans = force_pathtrans
265
+ )[1 ],
266
+ path [1 ],
267
+ )
268
+ for path in processed_paths
269
+ ]
249
270
250
271
path_transforms = collection .get_transforms ()
251
272
try :
@@ -256,30 +277,34 @@ def draw_collection(self, ax, collection,
256
277
# matplotlib 1.4: path transforms are already numpy arrays.
257
278
pass
258
279
259
- styles = {'linewidth' : collection .get_linewidths (),
260
- 'facecolor' : collection .get_facecolors (),
261
- 'edgecolor' : collection .get_edgecolors (),
262
- 'alpha' : collection ._alpha ,
263
- 'zorder' : collection .get_zorder ()}
280
+ styles = {
281
+ "linewidth" : collection .get_linewidths (),
282
+ "facecolor" : collection .get_facecolors (),
283
+ "edgecolor" : collection .get_edgecolors (),
284
+ "alpha" : collection ._alpha ,
285
+ "zorder" : collection .get_zorder (),
286
+ }
264
287
265
- offset_dict = {"data" : "before" ,
266
- "screen" : "after" }
288
+ offset_dict = {"data" : "before" , "screen" : "after" }
267
289
offset_order = offset_dict [collection .get_offset_position ()]
268
290
269
- self .renderer .draw_path_collection (paths = processed_paths ,
270
- path_coordinates = path_coords ,
271
- path_transforms = path_transforms ,
272
- offsets = offsets ,
273
- offset_coordinates = offset_coords ,
274
- offset_order = offset_order ,
275
- styles = styles ,
276
- mplobj = collection )
291
+ self .renderer .draw_path_collection (
292
+ paths = processed_paths ,
293
+ path_coordinates = path_coords ,
294
+ path_transforms = path_transforms ,
295
+ offsets = offsets ,
296
+ offset_coordinates = offset_coords ,
297
+ offset_order = offset_order ,
298
+ styles = styles ,
299
+ mplobj = collection ,
300
+ )
277
301
278
302
def draw_image (self , ax , image ):
279
303
"""Process a matplotlib image object and call renderer.draw_image"""
280
- self .renderer .draw_image (imdata = utils .image_to_base64 (image ),
281
- extent = image .get_extent (),
282
- coordinates = "data" ,
283
- style = {"alpha" : image .get_alpha (),
284
- "zorder" : image .get_zorder ()},
285
- mplobj = image )
304
+ self .renderer .draw_image (
305
+ imdata = utils .image_to_base64 (image ),
306
+ extent = image .get_extent (),
307
+ coordinates = "data" ,
308
+ style = {"alpha" : image .get_alpha (), "zorder" : image .get_zorder ()},
309
+ mplobj = image ,
310
+ )
0 commit comments