diff --git a/codewars_unittest/django.py b/codewars_unittest/django.py index 39df824..cc46942 100644 --- a/codewars_unittest/django.py +++ b/codewars_unittest/django.py @@ -5,4 +5,4 @@ class CodewarsDjangoRunner(DiscoverRunner): def run_suite(self, suite, **kwargs): - return CodewarsTestRunner().run(suite) + return CodewarsTestRunner(group_by_module=True).run(suite) diff --git a/codewars_unittest/test_runner.py b/codewars_unittest/test_runner.py index 416ec7c..e47c71b 100644 --- a/codewars_unittest/test_runner.py +++ b/codewars_unittest/test_runner.py @@ -1,6 +1,6 @@ import sys -import inspect import unittest +from itertools import groupby # Use timeit.default_timer for Python 2 compatibility. # default_timer is time.perf_counter on 3.3+ @@ -14,58 +14,72 @@ def __init__(self, stream=None, group_by_module=False): if stream is None: stream = sys.stdout self.stream = _WritelnDecorator(stream) - self.result = CodewarsTestResult(self.stream) self.group_by_module = group_by_module + self.results = [] def run(self, test): if isinstance(test, unittest.TestSuite): - self._run_each_test_cases(test) - return self.result + self._run_modules(_to_tree(_flatten(test))) else: - return self._run_case(test) - - def _run_each_test_cases(self, suite): - if not isinstance(suite, unittest.TestSuite): - return - - for test in suite: - if _is_test_module(test): - name = "" - if self.group_by_module: - case = _get_test_case(test) - name = _get_module_name(case) - if name: - self.stream.writeln(_group(name)) + self._run_case(test) + return self._make_result() + + def _make_result(self): + accum = unittest.TestResult() + for result in self.results: + accum.failures.extend(result.failures) + accum.errors.extend(result.errors) + accum.testsRun += result.testsRun + return accum + + def _run_modules(self, modules): + for mod in modules: + name = "" + if self.group_by_module: + name = mod.group_name + # Don't group on ImportError + if name == "unittest.loader": + name = "" + if name: + self.stream.writeln(_group(name)) - startTime = perf_counter() - for cases in test: - self._run_cases(cases) + startTime = perf_counter() + for cases in mod: + self._run_cases(cases) - if name: - self.stream.writeln(_completedin(startTime, perf_counter())) - else: - self._run_each_test_cases(test) + if name: + self.stream.writeln(_completedin(startTime, perf_counter())) def _run_cases(self, test): - case = next(iter(test), None) - if not case: - return self.result - - self.stream.writeln(_group(_get_class_name(case))) + name = test.group_name + # Don't group when errored before running tests, e.g., ImportError + if name == "_FailedTest": + name = "" + if name: + self.stream.writeln(_group(name)) startTime = perf_counter() + result = CodewarsTestResult(self.stream) try: - test(self.result) + test(result) finally: pass - self.stream.writeln(_completedin(startTime, perf_counter())) - return self.result + if name: + self.stream.writeln(_completedin(startTime, perf_counter())) + self.results.append(result) def _run_case(self, test): + result = CodewarsTestResult(self.stream) try: - test(self.result) + test(result) finally: pass - return self.result + self.results.append(result) + + +class _NamedTestSuite(unittest.TestSuite): + def __init__(self, tests=(), group_name=None): + super(_NamedTestSuite, self).__init__(tests) + self.group_name = group_name def _group(name): @@ -76,46 +90,37 @@ def _completedin(start, end): return "\n{:.4f}".format(1000 * (end - start)) -# True if test suite directly contains a test case -def _is_test_cases(suite): - return isinstance(suite, unittest.TestSuite) and any( - isinstance(t, unittest.TestCase) for t in suite - ) - - -# True if test suite directly contains test cases -def _is_test_module(suite): - return isinstance(suite, unittest.TestSuite) and any( - _is_test_cases(t) for t in suite - ) +# Flatten nested TestSuite by collecting all test cases. +def _flatten(suites): + tests = [] + for test in suites: + if isinstance(test, unittest.TestSuite): + tests.extend(_flatten(test)) + else: + tests.append(test) + return tests -# Get first test case from a TestSuite created from a test module to find module name -def _get_test_case(suite): - if not isinstance(suite, unittest.TestSuite): - return None - for test in suite: - if not isinstance(test, unittest.TestSuite): - continue - for t in test: - if isinstance(t, unittest.TestCase): - return t - return None +# Group by module name and then by class name +def _to_tree(suite): + tree = unittest.TestSuite() + for k, ms in groupby(suite, _module_name): + sub_trees = _NamedTestSuite(group_name=k) + for c, cs in groupby(ms, _class_name): + sub_trees.addTest(_NamedTestSuite(tests=cs, group_name=c)) + tree.addTest(sub_trees) + return tree -def _get_class_name(x): - cls = x if inspect.isclass(x) else x.__class__ - return cls.__name__ +def _module_name(x): + return x.__class__.__module__ -def _get_module_name(x): - cls = x if inspect.isclass(x) else x.__class__ - mod = cls.__module__ - if mod is None or mod == str.__class__.__module__: - return "" - return mod +def _class_name(x): + return x.__class__.__name__ +# https://github.com/python/cpython/blob/289f1f80ee87a4baf4567a86b3425fb3bf73291d/Lib/unittest/runner.py#L13 class _WritelnDecorator(object): """Used to decorate file-like objects with a handy 'writeln' method"""