Skip to content

Commit af96b02

Browse files
committed
Use a wrapper to intercept and store the figure
Rather than overriding the low level `pytest_pyfunc_call`, we can wrap the test function in a wrapper that stores its return value to the plugin object.
1 parent 4e5981a commit af96b02

File tree

1 file changed

+27
-19
lines changed

1 file changed

+27
-19
lines changed

pytest_mpl/plugin.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,6 @@ def pathify(path):
8181
return Path(path + ext)
8282

8383

84-
def _pytest_pyfunc_call(obj, pyfuncitem):
85-
testfunction = pyfuncitem.obj
86-
funcargs = pyfuncitem.funcargs
87-
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
88-
obj.result = testfunction(**testargs)
89-
return True
90-
91-
9284
def generate_test_name(item):
9385
"""
9486
Generate a unique name for the hash for this test.
@@ -100,6 +92,24 @@ def generate_test_name(item):
10092
return name
10193

10294

95+
def wrap_figure_interceptor(plugin, item):
96+
"""
97+
Intercept and store figures returned by test functions.
98+
"""
99+
# Only intercept figures on marked figure tests
100+
if get_compare(item) is not None:
101+
102+
# Use the full test name as a key to ensure correct figure is being retrieved
103+
test_name = generate_test_name(item)
104+
105+
def figure_interceptor(store, obj):
106+
def wrapper(*args, **kwargs):
107+
store.return_value[test_name] = obj(*args, **kwargs)
108+
return wrapper
109+
110+
item.obj = figure_interceptor(plugin, item.obj)
111+
112+
103113
def pytest_report_header(config, startdir):
104114
import matplotlib
105115
import matplotlib.ft2font
@@ -286,6 +296,7 @@ def __init__(self,
286296
self._generated_hash_library = {}
287297
self._test_results = {}
288298
self._test_stats = None
299+
self.return_value = {}
289300

290301
# https://stackoverflow.com/questions/51737378/how-should-i-log-in-my-pytest-plugin
291302
# turn debug prints on only if "-vv" or more passed
@@ -608,13 +619,14 @@ def pytest_runtest_call(self, item): # noqa
608619
with plt.style.context(style, after_reset=True), switch_backend(backend):
609620

610621
# Run test and get figure object
622+
wrap_figure_interceptor(self, item)
611623
yield
612-
fig = self.result
624+
test_name = generate_test_name(item)
625+
fig = self.return_value[test_name]
613626

614627
if remove_text:
615628
remove_ticks_and_titles(fig)
616629

617-
test_name = generate_test_name(item)
618630
result_dir = self.make_test_results_dir(item)
619631

620632
summary = {
@@ -678,10 +690,6 @@ def pytest_runtest_call(self, item): # noqa
678690
if summary['status'] == 'skipped':
679691
pytest.skip(summary['status_msg'])
680692

681-
@pytest.hookimpl(tryfirst=True)
682-
def pytest_pyfunc_call(self, pyfuncitem):
683-
return _pytest_pyfunc_call(self, pyfuncitem)
684-
685693
def generate_summary_json(self):
686694
json_file = self.results_dir / 'results.json'
687695
with open(json_file, 'w') as f:
@@ -733,13 +741,13 @@ class FigureCloser:
733741

734742
def __init__(self, config):
735743
self.config = config
744+
self.return_value = {}
736745

737746
@pytest.hookimpl(hookwrapper=True)
738747
def pytest_runtest_call(self, item):
748+
wrap_figure_interceptor(self, item)
739749
yield
740750
if get_compare(item) is not None:
741-
close_mpl_figure(self.result)
742-
743-
@pytest.hookimpl(tryfirst=True)
744-
def pytest_pyfunc_call(self, pyfuncitem):
745-
return _pytest_pyfunc_call(self, pyfuncitem)
751+
test_name = generate_test_name(item)
752+
fig = self.return_value[test_name]
753+
close_mpl_figure(fig)

0 commit comments

Comments
 (0)