Skip to content

Commit 60ab1cd

Browse files
bottlerfacebook-github-bot
authored andcommitted
make x_enabled compulsory
Summary: Optional[some_configurable] won't autogenerate the enabled flag Reviewed By: shapovalov Differential Revision: D41522104 fbshipit-source-id: 555ff6b343faf6f18aad2f92fbb7c341f5e991c6
1 parent 1706eb8 commit 60ab1cd

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

pytorch3d/implicitron/tools/config.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -783,16 +783,16 @@ def create_x_impl(self, enabled, args):
783783
Similarly, replace,
784784
785785
x: Optional[X]
786+
x_enabled: bool = ...
786787
787788
and optionally
788789
789790
def create_x(self):...
790-
x_enabled: bool = ...
791791
792792
with
793793
794794
x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
795-
x_enabled: bool = False
795+
x_enabled: bool = ...
796796
def create_x(self):
797797
self.create_x_impl(self.x_enabled, self.x_args)
798798
@@ -1091,8 +1091,10 @@ def _process_member(
10911091
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
10921092
enabled_name = name + ENABLED_SUFFIX
10931093
if enabled_name not in some_class.__annotations__:
1094-
some_class.__annotations__[enabled_name] = bool
1095-
setattr(some_class, enabled_name, False)
1094+
raise ValueError(
1095+
f"{name} is an Optional[{type_.__name__}] member "
1096+
f"but there is no corresponding member {enabled_name}."
1097+
)
10961098

10971099
creation_function_name = f"{CREATE_PREFIX}{name}"
10981100
if not hasattr(some_class, creation_function_name):

tests/implicitron/test_config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,7 @@ class C(Configurable):
446446
b2: Optional[B]
447447
b3: Optional[B]
448448
b2_enabled: bool = True
449+
b3_enabled: bool = False
449450

450451
def __post_init__(self):
451452
run_auto_creation(self)
@@ -681,9 +682,10 @@ def test_remove_unused_components(self):
681682
def test_remove_unused_components_optional(self):
682683
class MainTestWrapper(Configurable):
683684
mt: Optional[MainTest]
685+
mt_enabled: bool = False
684686

685687
args = get_default_args(MainTestWrapper)
686-
self.assertEqual(list(args.keys()), ["mt_args", "mt_enabled"])
688+
self.assertEqual(list(args.keys()), ["mt_enabled", "mt_args"])
687689
remove_unused_components(args)
688690
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
689691

@@ -775,6 +777,7 @@ class MyClass(Configurable):
775777
boring_o: Optional[BoringConfigurable]
776778
boring_o_enabled: bool = True
777779
boring_0: Optional[BoringConfigurable]
780+
boring_0_enabled: bool = False
778781

779782
def __post_init__(self):
780783
run_auto_creation(self)

0 commit comments

Comments
 (0)