Skip to content

Commit a1b3c0c

Browse files
committed
Inject sub-modules when importing with importlib
1 parent 3aa74ea commit a1b3c0c

File tree

2 files changed

+131
-91
lines changed

2 files changed

+131
-91
lines changed

src/_pytest/pathlib.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from posixpath import sep as posix_sep
2424
from types import ModuleType
2525
from typing import Callable
26+
from typing import Dict
2627
from typing import Iterable
2728
from typing import Iterator
2829
from typing import Optional
@@ -508,6 +509,7 @@ def import_path(
508509
mod = importlib.util.module_from_spec(spec)
509510
sys.modules[module_name] = mod
510511
spec.loader.exec_module(mod) # type: ignore[union-attr]
512+
insert_missing_modules(sys.modules, module_name)
511513
return mod
512514

513515
pkg_path = resolve_package_path(path)
@@ -593,6 +595,26 @@ def module_name_from_path(path: Path, root: Path) -> str:
593595
return ".".join(path_parts)
594596

595597

598+
def insert_missing_modules(modules: Dict[str, ModuleType], module_name: str) -> None:
599+
"""
600+
Used by ``import_path`` to create intermediate modules when using mode=importlib.
601+
602+
When we want to import a module as "src.tests.test_foo" for example, we need
603+
to create empty modules "src" and "src.tests" after inserting "src.tests.test_foo",
604+
otherwise "src.tests.test_foo" is not importable by ``__import__``.
605+
"""
606+
module_parts = module_name.split(".")
607+
while module_name:
608+
if module_name not in modules:
609+
module = ModuleType(
610+
module_name,
611+
doc="Empty module created by pytest's importmode=importlib.",
612+
)
613+
modules[module_name] = module
614+
module_parts.pop(-1)
615+
module_name = ".".join(module_parts)
616+
617+
596618
def resolve_package_path(path: Path) -> Optional[Path]:
597619
"""Return the Python package path by looking for the last
598620
directory upwards which still contains an __init__.py.

testing/test_pathlib.py

Lines changed: 109 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import unittest.mock
44
from pathlib import Path
55
from textwrap import dedent
6+
from types import ModuleType
67
from typing import Any
78

89
import py
@@ -17,6 +18,7 @@
1718
from _pytest.pathlib import get_lock_path
1819
from _pytest.pathlib import import_path
1920
from _pytest.pathlib import ImportPathMismatchError
21+
from _pytest.pathlib import insert_missing_modules
2022
from _pytest.pathlib import maybe_delete_a_numbered_dir
2123
from _pytest.pathlib import module_name_from_path
2224
from _pytest.pathlib import resolve_package_path
@@ -263,8 +265,8 @@ def test_invalid_path(self, tmpdir):
263265

264266
@pytest.fixture
265267
def simple_module(self, tmpdir):
266-
tmpdir.join("src/tests").ensure_dir()
267-
fn = tmpdir.join("src/tests/mymod.py")
268+
tmpdir.join("_src/tests").ensure_dir()
269+
fn = tmpdir.join("_src/tests/mymod.py")
268270
fn.write(
269271
dedent(
270272
"""
@@ -280,7 +282,9 @@ def test_importmode_importlib(self, simple_module, tmpdir):
280282
assert module.foo(2) == 42 # type: ignore[attr-defined]
281283
assert simple_module.dirname not in sys.path
282284
assert module.__name__ in sys.modules
283-
assert module.__name__ == "src.tests.mymod"
285+
assert module.__name__ == "_src.tests.mymod"
286+
assert "_src" in sys.modules
287+
assert "_src.tests" in sys.modules
284288

285289
def test_importmode_twice_is_different_module(self, simple_module, tmpdir):
286290
"""`importlib` mode always returns a new module."""
@@ -442,112 +446,126 @@ def test_samefile_false_negatives(tmp_path: Path, monkeypatch: MonkeyPatch) -> N
442446
assert getattr(module, "foo")() == 42
443447

444448

445-
@pytest.mark.skipif(sys.version_info < (3, 7), reason="Dataclasses in Python3.7+")
446-
def test_importmode_importlib_with_dataclass(tmp_path: Path) -> None:
447-
"""Ensure that importlib mode works with a module containing dataclasses (#7856)."""
448-
fn = tmp_path.joinpath("src/tests/test_dataclass.py")
449-
fn.parent.mkdir(parents=True)
450-
fn.write_text(
451-
dedent(
452-
"""
453-
from dataclasses import dataclass
449+
class TestImportLibMode:
450+
@pytest.mark.skipif(sys.version_info < (3, 7), reason="Dataclasses in Python3.7+")
451+
def test_importmode_importlib_with_dataclass(self, tmp_path: Path) -> None:
452+
"""Ensure that importlib mode works with a module containing dataclasses (#7856)."""
453+
fn = tmp_path.joinpath("_src/tests/test_dataclass.py")
454+
fn.parent.mkdir(parents=True)
455+
fn.write_text(
456+
dedent(
457+
"""
458+
from dataclasses import dataclass
454459
455-
@dataclass
456-
class Data:
457-
value: str
458-
"""
460+
@dataclass
461+
class Data:
462+
value: str
463+
"""
464+
)
459465
)
460-
)
461-
462-
module = import_path(fn, mode="importlib", root=tmp_path)
463-
Data: Any = getattr(module, "Data")
464-
data = Data(value="foo")
465-
assert data.value == "foo"
466-
assert data.__module__ == "src.tests.test_dataclass"
467466

467+
module = import_path(fn, mode="importlib", root=tmp_path)
468+
Data: Any = getattr(module, "Data")
469+
data = Data(value="foo")
470+
assert data.value == "foo"
471+
assert data.__module__ == "_src.tests.test_dataclass"
472+
473+
def test_importmode_importlib_with_pickle(self, tmp_path: Path) -> None:
474+
"""Ensure that importlib mode works with pickle (#7859)."""
475+
fn = tmp_path.joinpath("_src/tests/test_pickle.py")
476+
fn.parent.mkdir(parents=True)
477+
fn.write_text(
478+
dedent(
479+
"""
480+
import pickle
468481
469-
def test_importmode_importlib_with_pickle(tmp_path: Path) -> None:
470-
"""Ensure that importlib mode works with pickle (#7859)."""
471-
fn = tmp_path.joinpath("src/tests/test_pickle.py")
472-
fn.parent.mkdir(parents=True)
473-
fn.write_text(
474-
dedent(
475-
"""
476-
import pickle
477-
478-
def _action():
479-
return 42
482+
def _action():
483+
return 42
480484
481-
def round_trip():
482-
s = pickle.dumps(_action)
483-
return pickle.loads(s)
484-
"""
485+
def round_trip():
486+
s = pickle.dumps(_action)
487+
return pickle.loads(s)
488+
"""
489+
)
485490
)
486-
)
487-
488-
module = import_path(fn, mode="importlib", root=tmp_path)
489-
round_trip = getattr(module, "round_trip")
490-
action = round_trip()
491-
assert action() == 42
492491

492+
module = import_path(fn, mode="importlib", root=tmp_path)
493+
round_trip = getattr(module, "round_trip")
494+
action = round_trip()
495+
assert action() == 42
493496

494-
def test_importmode_importlib_with_pickle_separate_modules(tmp_path: Path) -> None:
495-
"""
496-
Ensure that importlib mode works can load pickles that look similar but are
497-
defined in separate modules.
498-
"""
499-
fn1 = tmp_path.joinpath("src/m1/tests/test.py")
500-
fn1.parent.mkdir(parents=True)
501-
fn1.write_text(
502-
dedent(
503-
"""
504-
import attr
505-
import pickle
497+
def test_importmode_importlib_with_pickle_separate_modules(
498+
self, tmp_path: Path
499+
) -> None:
500+
"""
501+
Ensure that importlib mode works can load pickles that look similar but are
502+
defined in separate modules.
503+
"""
504+
fn1 = tmp_path.joinpath("_src/m1/tests/test.py")
505+
fn1.parent.mkdir(parents=True)
506+
fn1.write_text(
507+
dedent(
508+
"""
509+
import attr
510+
import pickle
506511
507-
@attr.s(auto_attribs=True)
508-
class Data:
509-
x: int = 42
510-
"""
512+
@attr.s(auto_attribs=True)
513+
class Data:
514+
x: int = 42
515+
"""
516+
)
511517
)
512-
)
513518

514-
fn2 = tmp_path.joinpath("src/m2/tests/test.py")
515-
fn2.parent.mkdir(parents=True)
516-
fn2.write_text(
517-
dedent(
518-
"""
519-
import attr
520-
import pickle
519+
fn2 = tmp_path.joinpath("_src/m2/tests/test.py")
520+
fn2.parent.mkdir(parents=True)
521+
fn2.write_text(
522+
dedent(
523+
"""
524+
import attr
525+
import pickle
521526
522-
@attr.s(auto_attribs=True)
523-
class Data:
524-
x: str = ""
525-
"""
527+
@attr.s(auto_attribs=True)
528+
class Data:
529+
x: str = ""
530+
"""
531+
)
526532
)
527-
)
528533

529-
import pickle
534+
import pickle
535+
536+
def round_trip(obj):
537+
s = pickle.dumps(obj)
538+
return pickle.loads(s)
539+
540+
module = import_path(fn1, mode="importlib", root=tmp_path)
541+
Data1 = getattr(module, "Data")
530542

531-
def round_trip(obj):
532-
s = pickle.dumps(obj)
533-
return pickle.loads(s)
543+
module = import_path(fn2, mode="importlib", root=tmp_path)
544+
Data2 = getattr(module, "Data")
534545

535-
module = import_path(fn1, mode="importlib", root=tmp_path)
536-
Data1 = getattr(module, "Data")
546+
assert round_trip(Data1(20)) == Data1(20)
547+
assert round_trip(Data2("hello")) == Data2("hello")
548+
assert Data1.__module__ == "_src.m1.tests.test"
549+
assert Data2.__module__ == "_src.m2.tests.test"
537550

538-
module = import_path(fn2, mode="importlib", root=tmp_path)
539-
Data2 = getattr(module, "Data")
551+
def test_module_name_from_path(self, tmp_path: Path) -> None:
552+
result = module_name_from_path(tmp_path / "src/tests/test_foo.py", tmp_path)
553+
assert result == "src.tests.test_foo"
540554

541-
assert round_trip(Data1(20)) == Data1(20)
542-
assert round_trip(Data2("hello")) == Data2("hello")
543-
assert Data1.__module__ == "src.m1.tests.test"
544-
assert Data2.__module__ == "src.m2.tests.test"
555+
# Path is not relative to root dir: use the full path to obtain the module name.
556+
result = module_name_from_path(Path("/home/foo/test_foo.py"), Path("/bar"))
557+
assert result == "home.foo.test_foo"
545558

559+
def test_insert_missing_modules(self) -> None:
560+
modules = {"src.tests.foo": ModuleType("src.tests.foo")}
561+
insert_missing_modules(modules, "src.tests.foo")
562+
assert sorted(modules) == ["src", "src.tests", "src.tests.foo"]
546563

547-
def test_module_name_from_path(tmp_path: Path) -> None:
548-
result = module_name_from_path(tmp_path / "src/tests/test_foo.py", tmp_path)
549-
assert result == "src.tests.test_foo"
564+
mod = ModuleType("mod", doc="My Module")
565+
modules = {"src": mod}
566+
insert_missing_modules(modules, "src")
567+
assert modules == {"src": mod}
550568

551-
# Path is not relative to root dir: use the full path to obtain the module name.
552-
result = module_name_from_path(Path("/home/foo/test_foo.py"), Path("/bar"))
553-
assert result == "home.foo.test_foo"
569+
modules = {}
570+
insert_missing_modules(modules, "")
571+
assert modules == {}

0 commit comments

Comments
 (0)