Skip to content

Commit 1cfaf74

Browse files
authored
Fix compatibility with Django (#3)
Fix compatibility with Django
2 parents b33b51e + 8160edf commit 1cfaf74

File tree

2 files changed

+73
-68
lines changed

2 files changed

+73
-68
lines changed

codewars_unittest/django.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55

66
class CodewarsDjangoRunner(DiscoverRunner):
77
def run_suite(self, suite, **kwargs):
8-
return CodewarsTestRunner().run(suite)
8+
return CodewarsTestRunner(group_by_module=True).run(suite)

codewars_unittest/test_runner.py

Lines changed: 72 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sys
2-
import inspect
32
import unittest
3+
from itertools import groupby
44

55
# Use timeit.default_timer for Python 2 compatibility.
66
# default_timer is time.perf_counter on 3.3+
@@ -14,58 +14,72 @@ def __init__(self, stream=None, group_by_module=False):
1414
if stream is None:
1515
stream = sys.stdout
1616
self.stream = _WritelnDecorator(stream)
17-
self.result = CodewarsTestResult(self.stream)
1817
self.group_by_module = group_by_module
18+
self.results = []
1919

2020
def run(self, test):
2121
if isinstance(test, unittest.TestSuite):
22-
self._run_each_test_cases(test)
23-
return self.result
22+
self._run_modules(_to_tree(_flatten(test)))
2423
else:
25-
return self._run_case(test)
26-
27-
def _run_each_test_cases(self, suite):
28-
if not isinstance(suite, unittest.TestSuite):
29-
return
30-
31-
for test in suite:
32-
if _is_test_module(test):
33-
name = ""
34-
if self.group_by_module:
35-
case = _get_test_case(test)
36-
name = _get_module_name(case)
37-
if name:
38-
self.stream.writeln(_group(name))
24+
self._run_case(test)
25+
return self._make_result()
26+
27+
def _make_result(self):
28+
accum = unittest.TestResult()
29+
for result in self.results:
30+
accum.failures.extend(result.failures)
31+
accum.errors.extend(result.errors)
32+
accum.testsRun += result.testsRun
33+
return accum
34+
35+
def _run_modules(self, modules):
36+
for mod in modules:
37+
name = ""
38+
if self.group_by_module:
39+
name = mod.group_name
40+
# Don't group on ImportError
41+
if name == "unittest.loader":
42+
name = ""
43+
if name:
44+
self.stream.writeln(_group(name))
3945

40-
startTime = perf_counter()
41-
for cases in test:
42-
self._run_cases(cases)
46+
startTime = perf_counter()
47+
for cases in mod:
48+
self._run_cases(cases)
4349

44-
if name:
45-
self.stream.writeln(_completedin(startTime, perf_counter()))
46-
else:
47-
self._run_each_test_cases(test)
50+
if name:
51+
self.stream.writeln(_completedin(startTime, perf_counter()))
4852

4953
def _run_cases(self, test):
50-
case = next(iter(test), None)
51-
if not case:
52-
return self.result
53-
54-
self.stream.writeln(_group(_get_class_name(case)))
54+
name = test.group_name
55+
# Don't group when errored before running tests, e.g., ImportError
56+
if name == "_FailedTest":
57+
name = ""
58+
if name:
59+
self.stream.writeln(_group(name))
5560
startTime = perf_counter()
61+
result = CodewarsTestResult(self.stream)
5662
try:
57-
test(self.result)
63+
test(result)
5864
finally:
5965
pass
60-
self.stream.writeln(_completedin(startTime, perf_counter()))
61-
return self.result
66+
if name:
67+
self.stream.writeln(_completedin(startTime, perf_counter()))
68+
self.results.append(result)
6269

6370
def _run_case(self, test):
71+
result = CodewarsTestResult(self.stream)
6472
try:
65-
test(self.result)
73+
test(result)
6674
finally:
6775
pass
68-
return self.result
76+
self.results.append(result)
77+
78+
79+
class _NamedTestSuite(unittest.TestSuite):
80+
def __init__(self, tests=(), group_name=None):
81+
super(_NamedTestSuite, self).__init__(tests)
82+
self.group_name = group_name
6983

7084

7185
def _group(name):
@@ -76,46 +90,37 @@ def _completedin(start, end):
7690
return "\n<COMPLETEDIN::>{:.4f}".format(1000 * (end - start))
7791

7892

79-
# True if test suite directly contains a test case
80-
def _is_test_cases(suite):
81-
return isinstance(suite, unittest.TestSuite) and any(
82-
isinstance(t, unittest.TestCase) for t in suite
83-
)
84-
85-
86-
# True if test suite directly contains test cases
87-
def _is_test_module(suite):
88-
return isinstance(suite, unittest.TestSuite) and any(
89-
_is_test_cases(t) for t in suite
90-
)
93+
# Flatten nested TestSuite by collecting all test cases.
94+
def _flatten(suites):
95+
tests = []
96+
for test in suites:
97+
if isinstance(test, unittest.TestSuite):
98+
tests.extend(_flatten(test))
99+
else:
100+
tests.append(test)
101+
return tests
91102

92103

93-
# Get first test case from a TestSuite created from a test module to find module name
94-
def _get_test_case(suite):
95-
if not isinstance(suite, unittest.TestSuite):
96-
return None
97-
for test in suite:
98-
if not isinstance(test, unittest.TestSuite):
99-
continue
100-
for t in test:
101-
if isinstance(t, unittest.TestCase):
102-
return t
103-
return None
104+
# Group by module name and then by class name
105+
def _to_tree(suite):
106+
tree = unittest.TestSuite()
107+
for k, ms in groupby(suite, _module_name):
108+
sub_trees = _NamedTestSuite(group_name=k)
109+
for c, cs in groupby(ms, _class_name):
110+
sub_trees.addTest(_NamedTestSuite(tests=cs, group_name=c))
111+
tree.addTest(sub_trees)
112+
return tree
104113

105114

106-
def _get_class_name(x):
107-
cls = x if inspect.isclass(x) else x.__class__
108-
return cls.__name__
115+
def _module_name(x):
116+
return x.__class__.__module__
109117

110118

111-
def _get_module_name(x):
112-
cls = x if inspect.isclass(x) else x.__class__
113-
mod = cls.__module__
114-
if mod is None or mod == str.__class__.__module__:
115-
return ""
116-
return mod
119+
def _class_name(x):
120+
return x.__class__.__name__
117121

118122

123+
# https://github.com/python/cpython/blob/289f1f80ee87a4baf4567a86b3425fb3bf73291d/Lib/unittest/runner.py#L13
119124
class _WritelnDecorator(object):
120125
"""Used to decorate file-like objects with a handy 'writeln' method"""
121126

0 commit comments

Comments
 (0)