Skip to content

Fix tests which exit before returning a figure or use unittest.TestCase #171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 45 additions & 30 deletions pytest_mpl/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,33 @@ def pathify(path):
return Path(path + ext)


def _pytest_pyfunc_call(obj, pyfuncitem):
testfunction = pyfuncitem.obj
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
obj.result = testfunction(**testargs)
return True
def generate_test_name(item):
"""
Generate a unique name for the hash for this test.
"""
if item.cls is not None:
name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}"
else:
name = f"{item.module.__name__}.{item.name}"
return name


def wrap_figure_interceptor(plugin, item):
"""
Intercept and store figures returned by test functions.
"""
# Only intercept figures on marked figure tests
if get_compare(item) is not None:

# Use the full test name as a key to ensure correct figure is being retrieved
test_name = generate_test_name(item)

def figure_interceptor(store, obj):
def wrapper(*args, **kwargs):
store.return_value[test_name] = obj(*args, **kwargs)
return wrapper

item.obj = figure_interceptor(plugin, item.obj)


def pytest_report_header(config, startdir):
Expand Down Expand Up @@ -275,6 +296,7 @@ def __init__(self,
self._generated_hash_library = {}
self._test_results = {}
self._test_stats = None
self.return_value = {}

# https://stackoverflow.com/questions/51737378/how-should-i-log-in-my-pytest-plugin
# turn debug prints on only if "-vv" or more passed
Expand All @@ -287,7 +309,7 @@ def generate_filename(self, item):
Given a pytest item, generate the figure filename.
"""
if self.config.getini('mpl-use-full-test-name'):
filename = self.generate_test_name(item) + '.png'
filename = generate_test_name(item) + '.png'
else:
compare = get_compare(item)
# Find test name to use as plot name
Expand All @@ -298,21 +320,11 @@ def generate_filename(self, item):
filename = str(pathify(filename))
return filename

def generate_test_name(self, item):
"""
Generate a unique name for the hash for this test.
"""
if item.cls is not None:
name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}"
else:
name = f"{item.module.__name__}.{item.name}"
return name

def make_test_results_dir(self, item):
"""
Generate the directory to put the results in.
"""
test_name = pathify(self.generate_test_name(item))
test_name = pathify(generate_test_name(item))
results_dir = self.results_dir / test_name
results_dir.mkdir(exist_ok=True, parents=True)
return results_dir
Expand Down Expand Up @@ -526,7 +538,7 @@ def compare_image_to_hash_library(self, item, fig, result_dir, summary=None):
pytest.fail(f"Can't find hash library at path {hash_library_filename}")

hash_library = self.load_hash_library(hash_library_filename)
hash_name = self.generate_test_name(item)
hash_name = generate_test_name(item)
baseline_hash = hash_library.get(hash_name, None)
summary['baseline_hash'] = baseline_hash

Expand Down Expand Up @@ -607,13 +619,17 @@ def pytest_runtest_call(self, item): # noqa
with plt.style.context(style, after_reset=True), switch_backend(backend):

# Run test and get figure object
wrap_figure_interceptor(self, item)
yield
fig = self.result
test_name = generate_test_name(item)
if test_name not in self.return_value:
# Test function did not complete successfully
return
fig = self.return_value[test_name]

if remove_text:
remove_ticks_and_titles(fig)

test_name = self.generate_test_name(item)
result_dir = self.make_test_results_dir(item)

summary = {
Expand Down Expand Up @@ -677,10 +693,6 @@ def pytest_runtest_call(self, item): # noqa
if summary['status'] == 'skipped':
pytest.skip(summary['status_msg'])

@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(self, pyfuncitem):
return _pytest_pyfunc_call(self, pyfuncitem)

def generate_summary_json(self):
json_file = self.results_dir / 'results.json'
with open(json_file, 'w') as f:
Expand Down Expand Up @@ -732,13 +744,16 @@ class FigureCloser:

def __init__(self, config):
self.config = config
self.return_value = {}

@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(self, item):
wrap_figure_interceptor(self, item)
yield
if get_compare(item) is not None:
close_mpl_figure(self.result)

@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(self, pyfuncitem):
return _pytest_pyfunc_call(self, pyfuncitem)
test_name = generate_test_name(item)
if test_name not in self.return_value:
# Test function did not complete successfully
return
fig = self.return_value[test_name]
close_mpl_figure(fig)
7 changes: 7 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ test =

[tool:pytest]
testpaths = "tests"
markers =
image: run test during image comparison only mode.
hash: run test during hash comparison only mode.
filterwarnings =
error
ignore:distutils Version classes are deprecated
ignore:the imp module is deprecated in favour of importlib

[flake8]
max-line-length = 100
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest
from packaging.version import Version

pytest_plugins = ["pytester"]

if Version(pytest.__version__) < Version("6.2.0"):
@pytest.fixture
def pytester(testdir):
return testdir
143 changes: 142 additions & 1 deletion tests/test_pytest_mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import subprocess
from pathlib import Path
from unittest import TestCase

import matplotlib
import matplotlib.ft2font
Expand Down Expand Up @@ -259,6 +260,23 @@ def test_succeeds(self):
return fig


class TestClassWithTestCase(TestCase):

# Regression test for a bug that occurred when using unittest.TestCase

def setUp(self):
self.x = [1, 2, 3]

@pytest.mark.mpl_image_compare(baseline_dir=baseline_dir_local,
filename='test_succeeds.png',
tolerance=DEFAULT_TOLERANCE)
def test_succeeds(self):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(self.x)
return fig


# hashlib

@pytest.mark.skipif(not hash_library.exists(), reason="No hash library for this mpl version")
Expand Down Expand Up @@ -514,8 +532,27 @@ def test_fails(self):
return fig
"""

TEST_FAILING_UNITTEST_TESTCASE = """
from unittest import TestCase
import pytest
import matplotlib.pyplot as plt
class TestClassWithTestCase(TestCase):
def setUp(self):
self.x = [1, 2, 3]
@pytest.mark.mpl_image_compare
def test_fails(self):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(self.x)
return fig
"""

@pytest.mark.parametrize("code", [TEST_FAILING_CLASS, TEST_FAILING_CLASS_SETUP_METHOD])

@pytest.mark.parametrize("code", [
TEST_FAILING_CLASS,
TEST_FAILING_CLASS_SETUP_METHOD,
TEST_FAILING_UNITTEST_TESTCASE,
])
def test_class_fail(code, tmpdir):

test_file = tmpdir.join('test.py').strpath
Expand All @@ -529,3 +566,107 @@ def test_class_fail(code, tmpdir):
# If we don't use --mpl option, the test should succeed
code = call_pytest([test_file])
assert code == 0


@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_fail(pytester, runpytest_args):
pytester.makepyfile(
"""
import pytest
@pytest.mark.mpl_image_compare
def test_fail():
pytest.fail("Manually failed by user.")
"""
)
result = pytester.runpytest(*runpytest_args)
result.assert_outcomes(failed=1)
result.stdout.fnmatch_lines("FAILED*Manually failed by user.*")


@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_skip(pytester, runpytest_args):
pytester.makepyfile(
"""
import pytest
@pytest.mark.mpl_image_compare
def test_skip():
pytest.skip("Manually skipped by user.")
"""
)
result = pytester.runpytest(*runpytest_args)
result.assert_outcomes(skipped=1)


@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_importorskip(pytester, runpytest_args):
pytester.makepyfile(
"""
import pytest
@pytest.mark.mpl_image_compare
def test_importorskip():
pytest.importorskip("nonexistantmodule")
"""
)
result = pytester.runpytest(*runpytest_args)
result.assert_outcomes(skipped=1)


@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_xfail(pytester, runpytest_args):
pytester.makepyfile(
"""
import pytest
@pytest.mark.mpl_image_compare
def test_xfail():
pytest.xfail()
"""
)
result = pytester.runpytest(*runpytest_args)
result.assert_outcomes(xfailed=1)


@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_exit_success(pytester, runpytest_args):
pytester.makepyfile(
"""
import pytest
@pytest.mark.mpl_image_compare
def test_exit_success():
pytest.exit("Manually exited by user.", returncode=0)
"""
)
result = pytester.runpytest(*runpytest_args)
result.assert_outcomes()
assert result.ret == 0
result.stdout.fnmatch_lines("*Exit*Manually exited by user.*")


@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_exit_failure(pytester, runpytest_args):
pytester.makepyfile(
"""
import pytest
@pytest.mark.mpl_image_compare
def test_exit_fail():
pytest.exit("Manually exited by user.", returncode=1)
"""
)
result = pytester.runpytest(*runpytest_args)
result.assert_outcomes()
assert result.ret == 1
result.stdout.fnmatch_lines("*Exit*Manually exited by user.*")


@pytest.mark.parametrize("runpytest_args", [(), ("--mpl",)])
def test_user_function_raises(pytester, runpytest_args):
pytester.makepyfile(
"""
import pytest
@pytest.mark.mpl_image_compare
def test_raises():
raise ValueError("User code raised an exception.")
"""
)
result = pytester.runpytest(*runpytest_args)
result.assert_outcomes(failed=1)
result.stdout.fnmatch_lines("FAILED*ValueError*User code*")
5 changes: 0 additions & 5 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,3 @@ description = check code style, e.g. with flake8
deps = pre-commit
commands =
pre-commit run --all-files

[pytest]
markers =
image: run test during image comparison only mode.
hash: run test during hash comparison only mode.