Skip to content

Fixes class disambiguation logic when there is nested inheritance #170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions bump_pydantic/codemods/class_def_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def __init__(self, context: CodemodContext) -> None:
self.context.scratch.setdefault(self.NO_BASE_MODEL_CONTEXT_KEY, set())
self.context.scratch.setdefault(self.CLS_CONTEXT_KEY, defaultdict(set))

def _recursively_disambiguate(self, classname: str, context_set: set[str]) -> None:
if classname in context_set and classname in self.context.scratch[self.CLS_CONTEXT_KEY]:
for child_classname in self.context.scratch[self.CLS_CONTEXT_KEY].pop(classname):
context_set.add(child_classname)
self._recursively_disambiguate(child_classname, context_set)

def visit_ClassDef(self, node: cst.ClassDef) -> None:
fqn_set = self.get_metadata(FullyQualifiedNameProvider, node)

Expand All @@ -60,30 +66,24 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name)

# In case we have the following scenario:
# class ChildA(A):
# class A(B): ...
# class B(BaseModel): ...
# class D(C): ...
# class C: ...
# We want to disambiguate `A` as soon as we see `B` is a `BaseModel`.
if (
fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY]
and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY]
):
for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name):
self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(parent_class)
# We want to disambiguate `A` and then `ChildA` as soon as we see `B` is a `BaseModel`.
# We recursively add child classes to self.BASE_MODEL_CONTEXT_KEY.
self._recursively_disambiguate(fqn.name, self.context.scratch[self.BASE_MODEL_CONTEXT_KEY])

# In case we have the following scenario:
# class A(B): ...
# class B(BaseModel): ...
# class E(D): ...
# class D(C): ...
# class C: ...
# We want to disambiguate `D` as soon as we see `C` is NOT a `BaseModel`.
if (
fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY]
and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY]
):
for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name):
self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(parent_class)
# We want to disambiguate `D` and then `E` as soon as we see `C` is NOT a `BaseModel`.
# We recursively add child classes to self.NO_BASE_MODEL_CONTEXT_KEY.
self._recursively_disambiguate(fqn.name, self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY])

# In case we have the following scenario:
# class A(B): ...
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/cases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .field import cases as generic_model_cases
from .folder_inside_folder import cases as folder_inside_folder_cases
from .is_base_model import cases as is_base_model_cases
from .nested_inheritance import cases as nested_inheritance_cases
from .replace_validator import cases as replace_validator_cases
from .root_model import cases as root_model_cases
from .unicode import cases as unicode_cases
Expand All @@ -22,6 +23,7 @@
*base_settings_cases,
*add_none_cases,
*is_base_model_cases,
*nested_inheritance_cases,
*replace_validator_cases,
*config_to_model_cases,
*root_model_cases,
Expand Down
77 changes: 77 additions & 0 deletions tests/integration/cases/nested_inheritance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from ..case import Case
from ..file import File
from ..folder import Folder

cases = [
Case(
name="Nested Inheritance",
source=Folder(
"nested_inheritance",
File("__init__.py", content=[]),
File(
"bar.py",
content=[
"from .foo import Foo",
"",
"",
"class Bar(Foo):",
" b: str | None",
],
),
File(
"baz.py",
content=[
"from .bar import Bar",
"",
"",
"class Baz(Bar):",
" c: str | None",
],
),
File(
"foo.py",
content=[
"from pydantic import BaseModel",
"",
"",
"class Foo(BaseModel):",
" a: str | None",
],
),
),
expected=Folder(
"nested_inheritance",
File("__init__.py", content=[]),
File(
"bar.py",
content=[
"from .foo import Foo",
"",
"",
"class Bar(Foo):",
" b: str | None = None",
],
),
File(
"baz.py",
content=[
"from .bar import Bar",
"",
"",
"class Baz(Bar):",
" c: str | None = None",
],
),
File(
"foo.py",
content=[
"from pydantic import BaseModel",
"",
"",
"class Foo(BaseModel):",
" a: str | None = None",
],
),
),
)
]
6 changes: 3 additions & 3 deletions tests/unit/test_add_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def add_annotations(self, file_path: str, code: str) -> cst.Module:
mod = MetadataWrapper(
parse_module(CodemodTest.make_fixture_data(code)),
cache={
FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(Path(""), [file_path], None).get(
file_path, ""
)
FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(
Path(""), [file_path], timeout=None
).get(file_path, "")
},
)
mod.resolve_many(AddAnnotationsCommand.METADATA_DEPENDENCIES)
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_add_default_none.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def add_default_none(self, file_path: str, code: str) -> cst.Module:
mod = MetadataWrapper(
parse_module(CodemodTest.make_fixture_data(code)),
cache={
FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(Path(""), [file_path], None).get(
file_path, ""
)
FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(
Path(""), [file_path], timeout=None
).get(file_path, "")
},
)
mod.resolve_many(AddDefaultNoneCommand.METADATA_DEPENDENCIES)
Expand Down