Skip to content

Commit 21262e3

Browse files
bottlerfacebook-github-bot
authored andcommitted
Optional ReplaceableBase
Summary: Allow things like `renderer:Optional[BaseRenderer]` in configurables. Reviewed By: davnov134 Differential Revision: D35118339 fbshipit-source-id: 1219321b2817ed4b26fe924c6d6f73887095c985
1 parent e332f9f commit 21262e3

File tree

4 files changed

+172
-47
lines changed

4 files changed

+172
-47
lines changed

pytorch3d/common/datatypes.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import sys
78
from typing import Optional, Union
89

910
import torch
@@ -56,3 +57,20 @@ def get_device(x, device: Optional[Device] = None) -> torch.device:
5657

5758
# Default device is cpu
5859
return torch.device("cpu")
60+
61+
62+
# Provide get_origin and get_args even in Python 3.7.
63+
64+
if sys.version_info >= (3, 8, 0):
65+
from typing import get_args, get_origin
66+
elif sys.version_info >= (3, 7, 0):
67+
68+
def get_origin(cls): # pragma: no cover
69+
return getattr(cls, "__origin__", None)
70+
71+
def get_args(cls): # pragma: no cover
72+
return getattr(cls, "__args__", None)
73+
74+
75+
else:
76+
raise ImportError("This module requires Python 3.7+")

pytorch3d/implicitron/dataset/types.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,15 @@
88
import dataclasses
99
import gzip
1010
import json
11-
import sys
1211
from dataclasses import MISSING, Field, dataclass
1312
from typing import IO, Any, Optional, Tuple, Type, TypeVar, Union, cast
1413

1514
import numpy as np
15+
from pytorch3d.common.datatypes import get_args, get_origin
1616

1717

1818
_X = TypeVar("_X")
1919

20-
21-
if sys.version_info >= (3, 8, 0):
22-
from typing import get_args, get_origin
23-
elif sys.version_info >= (3, 7, 0):
24-
25-
def get_origin(cls):
26-
return getattr(cls, "__origin__", None)
27-
28-
def get_args(cls):
29-
return getattr(cls, "__args__", None)
30-
31-
32-
else:
33-
raise ImportError("This module requires Python 3.7+")
34-
35-
3620
TF3 = Tuple[float, float, float]
3721

3822

pytorch3d/implicitron/tools/config.py

Lines changed: 118 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
import warnings
1111
from collections import Counter, defaultdict
1212
from enum import Enum
13-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast
13+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
1414

1515
from omegaconf import DictConfig, OmegaConf, open_dict
16+
from pytorch3d.common.datatypes import get_args, get_origin
1617

1718

1819
"""
@@ -97,6 +98,8 @@ class A2(A):
9798
class B(Configurable):
9899
a: A
99100
a_class_type: str = "A2"
101+
b: Optional[A]
102+
b_class_type: Optional[str] = "A2"
100103
101104
def __post_init__(self):
102105
run_auto_creation(self)
@@ -124,6 +127,13 @@ class B:
124127
a_A2_args: DictConfig = dataclasses.field(
125128
default_factory=lambda: DictConfig({"k": 1, "m": 3}
126129
)
130+
b_class_type: Optional[str] = "A2"
131+
b_A1_args: DictConfig = dataclasses.field(
132+
default_factory=lambda: DictConfig({"k": 1, "m": 3}
133+
)
134+
b_A2_args: DictConfig = dataclasses.field(
135+
default_factory=lambda: DictConfig({"k": 1, "m": 3}
136+
)
127137
128138
def __post_init__(self):
129139
if self.a_class_type == "A1":
@@ -133,6 +143,15 @@ def __post_init__(self):
133143
else:
134144
raise ValueError(...)
135145
146+
if self.b_class_type is None:
147+
self.b = None
148+
elif self.b_class_type == "A1":
149+
self.b = A1(**self.b_A1_args)
150+
elif self.b_class_type == "A2":
151+
self.b = A2(**self.b_A2_args)
152+
else:
153+
raise ValueError(...)
154+
136155
3. Aside from these classes, the members of these classes should be things
137156
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
138157
can be built from them with DictConfigs and lists of them.
@@ -324,16 +343,28 @@ def _base_class_from_class(
324343
registry = _Registry()
325344

326345

327-
def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any], None]:
346+
class _ProcessType(Enum):
347+
"""
348+
Type of member which gets rewritten by expand_args_fields.
349+
"""
350+
351+
CONFIGURABLE = 1
352+
REPLACEABLE = 2
353+
OPTIONAL_REPLACEABLE = 3
354+
355+
356+
def _default_create(
357+
name: str, type_: Type, process_type: _ProcessType
358+
) -> Callable[[Any], None]:
328359
"""
329360
Return the default creation function for a member. This is a function which
330361
could be called in __post_init__ to initialise the member, and will be called
331362
from run_auto_creation.
332363
333364
Args:
334365
name: name of the member
335-
type_: declared type of the member
336-
pluggable: True if the member's declared type inherits ReplaceableBase,
366+
type_: type of the member (with any Optional removed)
367+
process_type: Shows whether member's declared type inherits ReplaceableBase,
337368
in which case the actual type to be created is decided at
338369
runtime.
339370
@@ -349,6 +380,10 @@ def inner(self):
349380

350381
def inner_pluggable(self):
351382
type_name = getattr(self, name + TYPE_SUFFIX)
383+
if type_name is None:
384+
setattr(self, name, None)
385+
return
386+
352387
chosen_class = registry.get(type_, type_name)
353388
if self._known_implementations.get(type_name, chosen_class) is not chosen_class:
354389
# If this warning is raised, it means that a new definition of
@@ -362,7 +397,7 @@ def inner_pluggable(self):
362397
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
363398
setattr(self, name, chosen_class(**args))
364399

365-
return inner_pluggable if pluggable else inner
400+
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
366401

367402

368403
def run_auto_creation(self: Any) -> None:
@@ -499,7 +534,7 @@ def expand_args_fields(
499534
500535
The transformations this function makes, before the concluding
501536
dataclasses.dataclass, are as follows. if X is a base class with registered
502-
subclasses Y and Z, replace
537+
subclasses Y and Z, replace a class member
503538
504539
x: X
505540
@@ -518,7 +553,32 @@ def create_x(self):
518553
)
519554
x_class_type: str = "UNDEFAULTED"
520555
521-
without adding the optional things if they are already there.
556+
without adding the optional attributes if they are already there.
557+
558+
Similarly, replace
559+
560+
x: Optional[X]
561+
562+
and optionally
563+
564+
x_class_type: Optional[str] = "Y"
565+
def create_x(self):...
566+
567+
with
568+
569+
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
570+
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
571+
def create_x(self):
572+
if self.x_class_type is None:
573+
self.x = None
574+
return
575+
576+
self.x = registry.get(X, self.x_class_type)(
577+
**self.getattr(f"x_{self.x_class_type}_args)
578+
)
579+
x_class_type: Optional[str] = "UNDEFAULTED"
580+
581+
without adding the optional attributes if they are already there.
522582
523583
Similarly, if X is a subclass of Configurable,
524584
@@ -587,26 +647,21 @@ def create_x(self):
587647
if "_processed_members" in base.__dict__:
588648
processed_members.update(base._processed_members)
589649

590-
to_process: List[Tuple[str, Type, bool]] = []
650+
to_process: List[Tuple[str, Type, _ProcessType]] = []
591651
if "__annotations__" in some_class.__dict__:
592652
for name, type_ in some_class.__annotations__.items():
593-
if not isinstance(type_, type):
594-
# type_ could be something like typing.Tuple
653+
underlying_and_process_type = _get_type_to_process(type_)
654+
if underlying_and_process_type is None:
595655
continue
596-
if (
597-
issubclass(type_, ReplaceableBase)
598-
and ReplaceableBase in type_.__bases__
599-
):
600-
to_process.append((name, type_, True))
601-
elif issubclass(type_, Configurable):
602-
to_process.append((name, type_, False))
603-
604-
for name, type_, pluggable in to_process:
656+
underlying_type, process_type = underlying_and_process_type
657+
to_process.append((name, underlying_type, process_type))
658+
659+
for name, underlying_type, process_type in to_process:
605660
_process_member(
606661
name=name,
607-
type_=type_,
608-
pluggable=pluggable,
609-
some_class=cast(type, some_class),
662+
type_=underlying_type,
663+
process_type=process_type,
664+
some_class=some_class,
610665
creation_functions=creation_functions,
611666
_do_not_process=_do_not_process,
612667
known_implementations=known_implementations,
@@ -641,11 +696,39 @@ def create():
641696
return dataclasses.field(default_factory=create)
642697

643698

699+
def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
700+
"""
701+
If a member is annotated as `type_`, and that should expanded in
702+
expand_args_fields, return how it should be expanded.
703+
"""
704+
if get_origin(type_) == Union:
705+
# We look for Optional[X] which is a Union of X with None.
706+
args = get_args(type_)
707+
if len(args) != 2 or all(a is not type(None) for a in args): # noqa: E721
708+
return
709+
underlying = args[0] if args[1] is type(None) else args[1] # noqa: E721
710+
if (
711+
issubclass(underlying, ReplaceableBase)
712+
and ReplaceableBase in underlying.__bases__
713+
):
714+
return underlying, _ProcessType.OPTIONAL_REPLACEABLE
715+
716+
if not isinstance(type_, type):
717+
# e.g. any other Union or Tuple
718+
return
719+
720+
if issubclass(type_, ReplaceableBase) and ReplaceableBase in type_.__bases__:
721+
return type_, _ProcessType.REPLACEABLE
722+
723+
if issubclass(type_, Configurable):
724+
return type_, _ProcessType.CONFIGURABLE
725+
726+
644727
def _process_member(
645728
*,
646729
name: str,
647730
type_: Type,
648-
pluggable: bool,
731+
process_type: _ProcessType,
649732
some_class: Type,
650733
creation_functions: List[str],
651734
_do_not_process: Tuple[type, ...],
@@ -656,8 +739,8 @@ def _process_member(
656739
657740
Args:
658741
name: member name
659-
type_: member declared type
660-
plugglable: whether member has dynamic type
742+
type_: member type (with Optional removed if needed)
743+
process_type: whether member has dynamic type
661744
some_class: (MODIFIED IN PLACE) the class being processed
662745
creation_functions: (MODIFIED IN PLACE) the names of the create functions
663746
_do_not_process: as for expand_args_fields.
@@ -668,10 +751,13 @@ def _process_member(
668751
# there are non-defaulted standard class members.
669752
del some_class.__annotations__[name]
670753

671-
if pluggable:
754+
if process_type != _ProcessType.CONFIGURABLE:
672755
type_name = name + TYPE_SUFFIX
673756
if type_name not in some_class.__annotations__:
674-
some_class.__annotations__[type_name] = str
757+
if process_type == _ProcessType.OPTIONAL_REPLACEABLE:
758+
some_class.__annotations__[type_name] = Optional[str]
759+
else:
760+
some_class.__annotations__[type_name] = str
675761
setattr(some_class, type_name, "UNDEFAULTED")
676762

677763
for derived_type in registry.get_all(type_):
@@ -720,7 +806,7 @@ def _process_member(
720806
setattr(
721807
some_class,
722808
creation_function_name,
723-
_default_create(name, type_, pluggable),
809+
_default_create(name, type_, process_type),
724810
)
725811
creation_functions.append(creation_function_name)
726812

@@ -743,7 +829,10 @@ def remove_unused_components(dict_: DictConfig) -> None:
743829
args_keys = [key for key in keys if key.endswith(ARGS_SUFFIX)]
744830
for replaceable in replaceables:
745831
selected_type = dict_[replaceable + TYPE_SUFFIX]
746-
expect = replaceable + "_" + selected_type + ARGS_SUFFIX
832+
if selected_type is None:
833+
expect = ""
834+
else:
835+
expect = replaceable + "_" + selected_type + ARGS_SUFFIX
747836
with open_dict(dict_):
748837
for key in args_keys:
749838
if key.startswith(replaceable + "_") and key != expect:

0 commit comments

Comments
 (0)