Skip to content

Commit 3b59b06

Browse files
committed
Create MypyResults
1 parent 0b7ca81 commit 3b59b06

File tree

2 files changed

+93
-70
lines changed

2 files changed

+93
-70
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def read(fname):
3535
'setuptools-scm>=3.5',
3636
],
3737
install_requires=[
38+
'attrs>=19.0',
3839
'filelock>=3.0',
3940
'pytest>=3.5',
4041
'mypy>=0.500; python_version<"3.8"',

src/pytest_mypy.py

Lines changed: 92 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Mypy static type checker plugin for Pytest"""
22

3-
import functools
43
import json
54
import os
65
from tempfile import NamedTemporaryFile
6+
from typing import Dict, List, Optional, TextIO
77

8+
import attr
89
from filelock import FileLock # type: ignore
910
import mypy.api
1011
import pytest # type: ignore
@@ -178,9 +179,9 @@ class MypyFileItem(MypyItem):
178179

179180
def runtest(self):
180181
"""Raise an exception if mypy found errors for this item."""
181-
results = _mypy_results(self.session)
182+
results = MypyResults.from_session(self.session)
182183
abspath = os.path.abspath(str(self.fspath))
183-
errors = results['abspath_errors'].get(abspath)
184+
errors = results.abspath_errors.get(abspath)
184185
if errors:
185186
raise MypyError(file_error_formatter(self, results, errors))
186187

@@ -199,76 +200,96 @@ class MypyStatusItem(MypyItem):
199200

200201
def runtest(self):
201202
"""Raise a MypyError if mypy exited with a non-zero status."""
202-
results = _mypy_results(self.session)
203-
if results['status']:
203+
results = MypyResults.from_session(self.session)
204+
if results.status:
204205
raise MypyError(
205206
'mypy exited with status {status}.'.format(
206-
status=results['status'],
207+
status=results.status,
207208
),
208209
)
209210

210211

211-
def _mypy_results(session):
212-
"""Get the cached mypy results for the session, or generate them."""
213-
return _cached_json_results(
214-
results_path=(
212+
@attr.s(frozen=True, kw_only=True)
213+
class MypyResults:
214+
215+
"""Parsed results from Mypy."""
216+
217+
_abspath_errors_type = Dict[str, List[str]]
218+
219+
opts = attr.ib(type=List[str])
220+
stdout = attr.ib(type=str)
221+
stderr = attr.ib(type=str)
222+
status = attr.ib(type=int)
223+
abspath_errors = attr.ib(type=_abspath_errors_type)
224+
unmatched_stdout = attr.ib(type=str)
225+
226+
def dump(self, results_f: TextIO) -> None:
227+
"""Cache results in a format that can be parsed by load()."""
228+
return json.dump(vars(self), results_f)
229+
230+
@classmethod
231+
def load(cls, results_f: TextIO) -> 'MypyResults':
232+
"""Get results cached by dump()."""
233+
return cls(**json.load(results_f))
234+
235+
@classmethod
236+
def from_mypy(
237+
cls,
238+
items: List[MypyFileItem],
239+
*,
240+
opts: Optional[List[str]] = None
241+
) -> 'MypyResults':
242+
"""Generate results from mypy."""
243+
244+
if opts is None:
245+
opts = mypy_argv[:]
246+
abspath_errors = {
247+
os.path.abspath(str(item.fspath)): []
248+
for item in items
249+
} # type: MypyResults._abspath_errors_type
250+
251+
stdout, stderr, status = mypy.api.run(opts + list(abspath_errors))
252+
253+
unmatched_lines = []
254+
for line in stdout.split('\n'):
255+
if not line:
256+
continue
257+
path, _, error = line.partition(':')
258+
abspath = os.path.abspath(path)
259+
try:
260+
abspath_errors[abspath].append(error)
261+
except KeyError:
262+
unmatched_lines.append(line)
263+
264+
return cls(
265+
opts=opts,
266+
stdout=stdout,
267+
stderr=stderr,
268+
status=status,
269+
abspath_errors=abspath_errors,
270+
unmatched_stdout='\n'.join(unmatched_lines),
271+
)
272+
273+
@classmethod
274+
def from_session(cls, session) -> 'MypyResults':
275+
"""Load (or generate) cached mypy results for a pytest session."""
276+
results_path = (
215277
session.config._mypy_results_path
216278
if _is_master(session.config) else
217279
_get_xdist_workerinput(session.config)['_mypy_results_path']
218-
),
219-
results_factory=functools.partial(
220-
_mypy_results_factory,
221-
abspaths=[
222-
os.path.abspath(str(item.fspath))
223-
for item in session.items
224-
if isinstance(item, MypyFileItem)
225-
],
226280
)
227-
)
228-
229-
230-
def _cached_json_results(results_path, results_factory=None):
231-
"""
232-
Read results from results_path if it exists;
233-
otherwise, produce them with results_factory,
234-
and write them to results_path.
235-
"""
236-
with FileLock(results_path + '.lock'):
237-
try:
238-
with open(results_path, mode='r') as results_f:
239-
results = json.load(results_f)
240-
except FileNotFoundError:
241-
if not results_factory:
242-
raise
243-
results = results_factory()
244-
with open(results_path, mode='w') as results_f:
245-
json.dump(results, results_f)
246-
return results
247-
248-
249-
def _mypy_results_factory(abspaths):
250-
"""Run mypy on abspaths and return the results as a JSON-able dict."""
251-
252-
stdout, stderr, status = mypy.api.run(mypy_argv + abspaths)
253-
254-
abspath_errors, unmatched_lines = {}, []
255-
for line in stdout.split('\n'):
256-
if not line:
257-
continue
258-
path, _, error = line.partition(':')
259-
abspath = os.path.abspath(path)
260-
if abspath in abspaths:
261-
abspath_errors[abspath] = abspath_errors.get(abspath, []) + [error]
262-
else:
263-
unmatched_lines.append(line)
264-
265-
return {
266-
'stdout': stdout,
267-
'stderr': stderr,
268-
'status': status,
269-
'abspath_errors': abspath_errors,
270-
'unmatched_stdout': '\n'.join(unmatched_lines),
271-
}
281+
with FileLock(results_path + '.lock'):
282+
try:
283+
with open(results_path, mode='r') as results_f:
284+
results = cls.load(results_f)
285+
except FileNotFoundError:
286+
results = cls.from_mypy([
287+
item for item in session.items
288+
if isinstance(item, MypyFileItem)
289+
])
290+
with open(results_path, mode='w') as results_f:
291+
results.dump(results_f)
292+
return results
272293

273294

274295
class MypyError(Exception):
@@ -282,15 +303,16 @@ def pytest_terminal_summary(terminalreporter):
282303
"""Report stderr and unrecognized lines from stdout."""
283304
config = _pytest_terminal_summary_config
284305
try:
285-
results = _cached_json_results(config._mypy_results_path)
306+
with open(config._mypy_results_path, mode='r') as results_f:
307+
results = MypyResults.load(results_f)
286308
except FileNotFoundError:
287309
# No MypyItems executed.
288310
return
289-
if results['unmatched_stdout'] or results['stderr']:
311+
if results.unmatched_stdout or results.stderr:
290312
terminalreporter.section('mypy')
291-
if results['unmatched_stdout']:
292-
color = {'red': True} if results['status'] else {'green': True}
293-
terminalreporter.write_line(results['unmatched_stdout'], **color)
294-
if results['stderr']:
295-
terminalreporter.write_line(results['stderr'], yellow=True)
313+
if results.unmatched_stdout:
314+
color = {'red': True} if results.status else {'green': True}
315+
terminalreporter.write_line(results.unmatched_stdout, **color)
316+
if results.stderr:
317+
terminalreporter.write_line(results.stderr, yellow=True)
296318
os.remove(config._mypy_results_path)

0 commit comments

Comments
 (0)