Skip to content

Commit 5d0586b

Browse files
committed
Directly compare two returned figures from a test function
1 parent 05c71b9 commit 5d0586b

File tree

2 files changed

+44
-23
lines changed

2 files changed

+44
-23
lines changed

pytest_mpl/plugin.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,10 @@ def item_function_wrapper(*args, **kwargs):
272272
fig = original(*args, **kwargs)
273273

274274
if remove_text:
275-
remove_ticks_and_titles(fig)
275+
if not isinstance(fig, tuple):
276+
remove_ticks_and_titles(fig)
277+
else:
278+
[remove_ticks_and_titles(f) for f in fig]
276279

277280
# Find test name to use as plot name
278281
filename = compare.kwargs.get('filename', None)
@@ -286,33 +289,40 @@ def item_function_wrapper(*args, **kwargs):
286289
# reference images or simply running the test.
287290
if self.generate_dir is None:
288291

289-
# Save the figure
292+
# Save the figure(s)
290293
result_dir = tempfile.mkdtemp(dir=self.results_dir)
291294
test_image = os.path.abspath(os.path.join(result_dir, filename))
295+
baseline_image = os.path.abspath(os.path.join(result_dir,
296+
'baseline-' + filename))
292297

293-
fig.savefig(test_image, **savefig_kwargs)
294-
close_mpl_figure(fig)
298+
if not isinstance(fig, tuple):
299+
fig.savefig(test_image, **savefig_kwargs)
300+
close_mpl_figure(fig)
301+
302+
# Find path to baseline image
303+
if baseline_remote:
304+
baseline_image_ref = _download_file(baseline_dir, filename)
305+
else:
306+
baseline_image_ref = os.path.abspath(os.path.join(
307+
os.path.dirname(item.fspath.strpath), baseline_dir, filename))
308+
309+
if not os.path.exists(baseline_image_ref):
310+
pytest.fail("Image file not found for comparison test in: "
311+
"\n\t{baseline_dir}"
312+
"\n(This is expected for new tests.)\nGenerated Image: "
313+
"\n\t{test}".format(baseline_dir=baseline_dir,
314+
test=test_image),
315+
pytrace=False)
316+
317+
# distutils may put the baseline images in non-accessible places,
318+
# copy to our tmpdir to be sure to keep them in case of failure
319+
shutil.copyfile(baseline_image_ref, baseline_image)
295320

296-
# Find path to baseline image
297-
if baseline_remote:
298-
baseline_image_ref = _download_file(baseline_dir, filename)
299321
else:
300-
baseline_image_ref = os.path.abspath(os.path.join(
301-
os.path.dirname(item.fspath.strpath), baseline_dir, filename))
302-
303-
if not os.path.exists(baseline_image_ref):
304-
pytest.fail("Image file not found for comparison test in: "
305-
"\n\t{baseline_dir}"
306-
"\n(This is expected for new tests.)\nGenerated Image: "
307-
"\n\t{test}".format(baseline_dir=baseline_dir,
308-
test=test_image),
309-
pytrace=False)
310-
311-
# distutils may put the baseline images in non-accessible places,
312-
# copy to our tmpdir to be sure to keep them in case of failure
313-
baseline_image = os.path.abspath(os.path.join(result_dir,
314-
'baseline-' + filename))
315-
shutil.copyfile(baseline_image_ref, baseline_image)
322+
fig[0].savefig(test_image, **savefig_kwargs)
323+
close_mpl_figure(fig[0])
324+
fig[1].savefig(baseline_image, **savefig_kwargs)
325+
close_mpl_figure(fig[1])
316326

317327
_raise_on_image_difference(
318328
expected=baseline_image,

tests/test_pytest_mpl.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,3 +227,14 @@ def test_succeeds(self):
227227
ax = fig.add_subplot(1, 1, 1)
228228
ax.plot(self.x)
229229
return fig
230+
231+
232+
@pytest.mark.mpl_image_compare
233+
def test_check_equal():
234+
fig_test = plt.figure()
235+
fig_test.subplots().plot([1, 3, 5])
236+
237+
fig_ref = plt.figure()
238+
fig_ref.subplots().plot([0, 1, 2], [1, 3, 5])
239+
240+
return fig_test, fig_ref

0 commit comments

Comments
 (0)