diff --git a/bump_pydantic/codemods/class_def_visitor.py b/bump_pydantic/codemods/class_def_visitor.py index b97aa0e..1ad73e2 100644 --- a/bump_pydantic/codemods/class_def_visitor.py +++ b/bump_pydantic/codemods/class_def_visitor.py @@ -38,6 +38,12 @@ def __init__(self, context: CodemodContext) -> None: self.context.scratch.setdefault(self.NO_BASE_MODEL_CONTEXT_KEY, set()) self.context.scratch.setdefault(self.CLS_CONTEXT_KEY, defaultdict(set)) + def _recursively_disambiguate(self, classname: str, context_set: set[str]) -> None: + if classname in context_set and classname in self.context.scratch[self.CLS_CONTEXT_KEY]: + for child_classname in self.context.scratch[self.CLS_CONTEXT_KEY].pop(classname): + context_set.add(child_classname) + self._recursively_disambiguate(child_classname, context_set) + def visit_ClassDef(self, node: cst.ClassDef) -> None: fqn_set = self.get_metadata(FullyQualifiedNameProvider, node) @@ -60,30 +66,24 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(fqn.name) # In case we have the following scenario: + # class ChildA(A): # class A(B): ... # class B(BaseModel): ... # class D(C): ... # class C: ... - # We want to disambiguate `A` as soon as we see `B` is a `BaseModel`. - if ( - fqn.name in self.context.scratch[self.BASE_MODEL_CONTEXT_KEY] - and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY] - ): - for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name): - self.context.scratch[self.BASE_MODEL_CONTEXT_KEY].add(parent_class) + # We want to disambiguate `A` and then `ChildA` as soon as we see `B` is a `BaseModel`. + # We recursively add child classes to self.BASE_MODEL_CONTEXT_KEY. + self._recursively_disambiguate(fqn.name, self.context.scratch[self.BASE_MODEL_CONTEXT_KEY]) # In case we have the following scenario: # class A(B): ... # class B(BaseModel): ... + # class E(D): ... # class D(C): ... # class C: ... - # We want to disambiguate `D` as soon as we see `C` is NOT a `BaseModel`. - if ( - fqn.name in self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY] - and fqn.name in self.context.scratch[self.CLS_CONTEXT_KEY] - ): - for parent_class in self.context.scratch[self.CLS_CONTEXT_KEY].pop(fqn.name): - self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY].add(parent_class) + # We want to disambiguate `D` and then `E` as soon as we see `C` is NOT a `BaseModel`. + # We recursively add child classes to self.NO_BASE_MODEL_CONTEXT_KEY. + self._recursively_disambiguate(fqn.name, self.context.scratch[self.NO_BASE_MODEL_CONTEXT_KEY]) # In case we have the following scenario: # class A(B): ... diff --git a/tests/integration/cases/__init__.py b/tests/integration/cases/__init__.py index ecb2ffc..c2ae14b 100644 --- a/tests/integration/cases/__init__.py +++ b/tests/integration/cases/__init__.py @@ -9,6 +9,7 @@ from .field import cases as generic_model_cases from .folder_inside_folder import cases as folder_inside_folder_cases from .is_base_model import cases as is_base_model_cases +from .nested_inheritance import cases as nested_inheritance_cases from .replace_validator import cases as replace_validator_cases from .root_model import cases as root_model_cases from .unicode import cases as unicode_cases @@ -22,6 +23,7 @@ *base_settings_cases, *add_none_cases, *is_base_model_cases, + *nested_inheritance_cases, *replace_validator_cases, *config_to_model_cases, *root_model_cases, diff --git a/tests/integration/cases/nested_inheritance.py b/tests/integration/cases/nested_inheritance.py new file mode 100644 index 0000000..a1cdec4 --- /dev/null +++ b/tests/integration/cases/nested_inheritance.py @@ -0,0 +1,77 @@ +from ..case import Case +from ..file import File +from ..folder import Folder + +cases = [ + Case( + name="Nested Inheritance", + source=Folder( + "nested_inheritance", + File("__init__.py", content=[]), + File( + "bar.py", + content=[ + "from .foo import Foo", + "", + "", + "class Bar(Foo):", + " b: str | None", + ], + ), + File( + "baz.py", + content=[ + "from .bar import Bar", + "", + "", + "class Baz(Bar):", + " c: str | None", + ], + ), + File( + "foo.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class Foo(BaseModel):", + " a: str | None", + ], + ), + ), + expected=Folder( + "nested_inheritance", + File("__init__.py", content=[]), + File( + "bar.py", + content=[ + "from .foo import Foo", + "", + "", + "class Bar(Foo):", + " b: str | None = None", + ], + ), + File( + "baz.py", + content=[ + "from .bar import Bar", + "", + "", + "class Baz(Bar):", + " c: str | None = None", + ], + ), + File( + "foo.py", + content=[ + "from pydantic import BaseModel", + "", + "", + "class Foo(BaseModel):", + " a: str | None = None", + ], + ), + ), + ) +] diff --git a/tests/unit/test_add_annotations.py b/tests/unit/test_add_annotations.py index 2750a96..fd05dd3 100644 --- a/tests/unit/test_add_annotations.py +++ b/tests/unit/test_add_annotations.py @@ -16,9 +16,9 @@ def add_annotations(self, file_path: str, code: str) -> cst.Module: mod = MetadataWrapper( parse_module(CodemodTest.make_fixture_data(code)), cache={ - FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(Path(""), [file_path], None).get( - file_path, "" - ) + FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache( + Path(""), [file_path], timeout=None + ).get(file_path, "") }, ) mod.resolve_many(AddAnnotationsCommand.METADATA_DEPENDENCIES) diff --git a/tests/unit/test_add_default_none.py b/tests/unit/test_add_default_none.py index a0f1202..3a4f28e 100644 --- a/tests/unit/test_add_default_none.py +++ b/tests/unit/test_add_default_none.py @@ -17,9 +17,9 @@ def add_default_none(self, file_path: str, code: str) -> cst.Module: mod = MetadataWrapper( parse_module(CodemodTest.make_fixture_data(code)), cache={ - FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache(Path(""), [file_path], None).get( - file_path, "" - ) + FullyQualifiedNameProvider: FullyQualifiedNameProvider.gen_cache( + Path(""), [file_path], timeout=None + ).get(file_path, "") }, ) mod.resolve_many(AddDefaultNoneCommand.METADATA_DEPENDENCIES)