From 04af063f9b00458814a698fd7b11a5f91c24f821 Mon Sep 17 00:00:00 2001 From: Daniel Hahler Date: Mon, 21 Nov 2016 17:15:14 +0100 Subject: [PATCH] Expose fixtures to change Django's {Transaction,}TestCase This adds `django_db_testcase` and `django_transactional_db_testcase`, which allows to override them to e.g. enable the `multi_db` feature: ``` @pytest.fixture def django_db_testcase(django_db_testcase): django_db_testcase.multi_db = True return django_db_testcase ``` Ref: https://github.com/pytest-dev/pytest-django/pull/397 --- pytest_django/fixtures.py | 20 ++++++++++++++++---- pytest_django/plugin.py | 2 ++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/pytest_django/fixtures.py b/pytest_django/fixtures.py index 7745328ee..21a1891c2 100644 --- a/pytest_django/fixtures.py +++ b/pytest_django/fixtures.py @@ -14,6 +14,7 @@ from .lazy_django import get_django_version, skip_if_no_django __all__ = ['django_db_setup', 'db', 'transactional_db', 'admin_user', + 'django_db_testcase', 'django_transactional_db_testcase', 'django_user_model', 'django_username_field', 'client', 'admin_client', 'rf', 'settings', 'live_server', '_live_server_helper'] @@ -108,6 +109,18 @@ def teardown_database(): request.addfinalizer(teardown_database) +@pytest.fixture +def django_db_testcase(request): + from django.test import TestCase + return TestCase + + +@pytest.fixture +def django_transactional_db_testcase(request): + from django.test import TransactionTestCase + return TransactionTestCase + + def _django_db_fixture_helper(transactional, request, django_db_blocker): if is_django_unittest(request): return @@ -119,10 +132,9 @@ def _django_db_fixture_helper(transactional, request, django_db_blocker): django_db_blocker.unblock() request.addfinalizer(django_db_blocker.restore) - if transactional: - from django.test import TransactionTestCase as django_case - else: - from django.test import TestCase as django_case + testcase_class_fixture = ('django_transactional_db_testcase' + if transactional else 'django_db_testcase') + django_case = getfixturevalue(request, testcase_class_fixture) test_case = django_case(methodName='__init__') test_case._pre_setup() diff --git a/pytest_django/plugin.py b/pytest_django/plugin.py index 1fb9120e5..b5344de35 100644 --- a/pytest_django/plugin.py +++ b/pytest_django/plugin.py @@ -31,6 +31,8 @@ from .fixtures import rf # noqa from .fixtures import settings # noqa from .fixtures import transactional_db # noqa +from .fixtures import django_db_testcase # noqa +from .fixtures import django_transactional_db_testcase # noqa from .pytest_compat import getfixturevalue from .lazy_django import (django_settings_is_configured,