10
10
import warnings
11
11
from collections import Counter , defaultdict
12
12
from enum import Enum
13
- from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Type , TypeVar , cast
13
+ from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Type , TypeVar , Union
14
14
15
15
from omegaconf import DictConfig , OmegaConf , open_dict
16
+ from pytorch3d .common .datatypes import get_args , get_origin
16
17
17
18
18
19
"""
@@ -97,6 +98,8 @@ class A2(A):
97
98
class B(Configurable):
98
99
a: A
99
100
a_class_type: str = "A2"
101
+ b: Optional[A]
102
+ b_class_type: Optional[str] = "A2"
100
103
101
104
def __post_init__(self):
102
105
run_auto_creation(self)
@@ -124,6 +127,13 @@ class B:
124
127
a_A2_args: DictConfig = dataclasses.field(
125
128
default_factory=lambda: DictConfig({"k": 1, "m": 3}
126
129
)
130
+ b_class_type: Optional[str] = "A2"
131
+ b_A1_args: DictConfig = dataclasses.field(
132
+ default_factory=lambda: DictConfig({"k": 1, "m": 3}
133
+ )
134
+ b_A2_args: DictConfig = dataclasses.field(
135
+ default_factory=lambda: DictConfig({"k": 1, "m": 3}
136
+ )
127
137
128
138
def __post_init__(self):
129
139
if self.a_class_type == "A1":
@@ -133,6 +143,15 @@ def __post_init__(self):
133
143
else:
134
144
raise ValueError(...)
135
145
146
+ if self.b_class_type is None:
147
+ self.b = None
148
+ elif self.b_class_type == "A1":
149
+ self.b = A1(**self.b_A1_args)
150
+ elif self.b_class_type == "A2":
151
+ self.b = A2(**self.b_A2_args)
152
+ else:
153
+ raise ValueError(...)
154
+
136
155
3. Aside from these classes, the members of these classes should be things
137
156
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
138
157
can be built from them with DictConfigs and lists of them.
@@ -324,16 +343,28 @@ def _base_class_from_class(
324
343
registry = _Registry ()
325
344
326
345
327
- def _default_create (name : str , type_ : Type , pluggable : bool ) -> Callable [[Any ], None ]:
346
+ class _ProcessType (Enum ):
347
+ """
348
+ Type of member which gets rewritten by expand_args_fields.
349
+ """
350
+
351
+ CONFIGURABLE = 1
352
+ REPLACEABLE = 2
353
+ OPTIONAL_REPLACEABLE = 3
354
+
355
+
356
+ def _default_create (
357
+ name : str , type_ : Type , process_type : _ProcessType
358
+ ) -> Callable [[Any ], None ]:
328
359
"""
329
360
Return the default creation function for a member. This is a function which
330
361
could be called in __post_init__ to initialise the member, and will be called
331
362
from run_auto_creation.
332
363
333
364
Args:
334
365
name: name of the member
335
- type_: declared type of the member
336
- pluggable: True if the member's declared type inherits ReplaceableBase,
366
+ type_: type of the member (with any Optional removed)
367
+ process_type: Shows whether member's declared type inherits ReplaceableBase,
337
368
in which case the actual type to be created is decided at
338
369
runtime.
339
370
@@ -349,6 +380,10 @@ def inner(self):
349
380
350
381
def inner_pluggable (self ):
351
382
type_name = getattr (self , name + TYPE_SUFFIX )
383
+ if type_name is None :
384
+ setattr (self , name , None )
385
+ return
386
+
352
387
chosen_class = registry .get (type_ , type_name )
353
388
if self ._known_implementations .get (type_name , chosen_class ) is not chosen_class :
354
389
# If this warning is raised, it means that a new definition of
@@ -362,7 +397,7 @@ def inner_pluggable(self):
362
397
args = getattr (self , f"{ name } _{ type_name } { ARGS_SUFFIX } " )
363
398
setattr (self , name , chosen_class (** args ))
364
399
365
- return inner_pluggable if pluggable else inner
400
+ return inner if process_type == _ProcessType . CONFIGURABLE else inner_pluggable
366
401
367
402
368
403
def run_auto_creation (self : Any ) -> None :
@@ -499,7 +534,7 @@ def expand_args_fields(
499
534
500
535
The transformations this function makes, before the concluding
501
536
dataclasses.dataclass, are as follows. if X is a base class with registered
502
- subclasses Y and Z, replace
537
+ subclasses Y and Z, replace a class member
503
538
504
539
x: X
505
540
@@ -518,7 +553,32 @@ def create_x(self):
518
553
)
519
554
x_class_type: str = "UNDEFAULTED"
520
555
521
- without adding the optional things if they are already there.
556
+ without adding the optional attributes if they are already there.
557
+
558
+ Similarly, replace
559
+
560
+ x: Optional[X]
561
+
562
+ and optionally
563
+
564
+ x_class_type: Optional[str] = "Y"
565
+ def create_x(self):...
566
+
567
+ with
568
+
569
+ x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
570
+ x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
571
+ def create_x(self):
572
+ if self.x_class_type is None:
573
+ self.x = None
574
+ return
575
+
576
+ self.x = registry.get(X, self.x_class_type)(
577
+ **self.getattr(f"x_{self.x_class_type}_args)
578
+ )
579
+ x_class_type: Optional[str] = "UNDEFAULTED"
580
+
581
+ without adding the optional attributes if they are already there.
522
582
523
583
Similarly, if X is a subclass of Configurable,
524
584
@@ -587,26 +647,21 @@ def create_x(self):
587
647
if "_processed_members" in base .__dict__ :
588
648
processed_members .update (base ._processed_members )
589
649
590
- to_process : List [Tuple [str , Type , bool ]] = []
650
+ to_process : List [Tuple [str , Type , _ProcessType ]] = []
591
651
if "__annotations__" in some_class .__dict__ :
592
652
for name , type_ in some_class .__annotations__ .items ():
593
- if not isinstance (type_ , type ):
594
- # type_ could be something like typing.Tuple
653
+ underlying_and_process_type = _get_type_to_process (type_ )
654
+ if underlying_and_process_type is None :
595
655
continue
596
- if (
597
- issubclass (type_ , ReplaceableBase )
598
- and ReplaceableBase in type_ .__bases__
599
- ):
600
- to_process .append ((name , type_ , True ))
601
- elif issubclass (type_ , Configurable ):
602
- to_process .append ((name , type_ , False ))
603
-
604
- for name , type_ , pluggable in to_process :
656
+ underlying_type , process_type = underlying_and_process_type
657
+ to_process .append ((name , underlying_type , process_type ))
658
+
659
+ for name , underlying_type , process_type in to_process :
605
660
_process_member (
606
661
name = name ,
607
- type_ = type_ ,
608
- pluggable = pluggable ,
609
- some_class = cast ( type , some_class ) ,
662
+ type_ = underlying_type ,
663
+ process_type = process_type ,
664
+ some_class = some_class ,
610
665
creation_functions = creation_functions ,
611
666
_do_not_process = _do_not_process ,
612
667
known_implementations = known_implementations ,
@@ -641,11 +696,39 @@ def create():
641
696
return dataclasses .field (default_factory = create )
642
697
643
698
699
+ def _get_type_to_process (type_ ) -> Optional [Tuple [Type , _ProcessType ]]:
700
+ """
701
+ If a member is annotated as `type_`, and that should expanded in
702
+ expand_args_fields, return how it should be expanded.
703
+ """
704
+ if get_origin (type_ ) == Union :
705
+ # We look for Optional[X] which is a Union of X with None.
706
+ args = get_args (type_ )
707
+ if len (args ) != 2 or all (a is not type (None ) for a in args ): # noqa: E721
708
+ return
709
+ underlying = args [0 ] if args [1 ] is type (None ) else args [1 ] # noqa: E721
710
+ if (
711
+ issubclass (underlying , ReplaceableBase )
712
+ and ReplaceableBase in underlying .__bases__
713
+ ):
714
+ return underlying , _ProcessType .OPTIONAL_REPLACEABLE
715
+
716
+ if not isinstance (type_ , type ):
717
+ # e.g. any other Union or Tuple
718
+ return
719
+
720
+ if issubclass (type_ , ReplaceableBase ) and ReplaceableBase in type_ .__bases__ :
721
+ return type_ , _ProcessType .REPLACEABLE
722
+
723
+ if issubclass (type_ , Configurable ):
724
+ return type_ , _ProcessType .CONFIGURABLE
725
+
726
+
644
727
def _process_member (
645
728
* ,
646
729
name : str ,
647
730
type_ : Type ,
648
- pluggable : bool ,
731
+ process_type : _ProcessType ,
649
732
some_class : Type ,
650
733
creation_functions : List [str ],
651
734
_do_not_process : Tuple [type , ...],
@@ -656,8 +739,8 @@ def _process_member(
656
739
657
740
Args:
658
741
name: member name
659
- type_: member declared type
660
- plugglable : whether member has dynamic type
742
+ type_: member type (with Optional removed if needed)
743
+ process_type : whether member has dynamic type
661
744
some_class: (MODIFIED IN PLACE) the class being processed
662
745
creation_functions: (MODIFIED IN PLACE) the names of the create functions
663
746
_do_not_process: as for expand_args_fields.
@@ -668,10 +751,13 @@ def _process_member(
668
751
# there are non-defaulted standard class members.
669
752
del some_class .__annotations__ [name ]
670
753
671
- if pluggable :
754
+ if process_type != _ProcessType . CONFIGURABLE :
672
755
type_name = name + TYPE_SUFFIX
673
756
if type_name not in some_class .__annotations__ :
674
- some_class .__annotations__ [type_name ] = str
757
+ if process_type == _ProcessType .OPTIONAL_REPLACEABLE :
758
+ some_class .__annotations__ [type_name ] = Optional [str ]
759
+ else :
760
+ some_class .__annotations__ [type_name ] = str
675
761
setattr (some_class , type_name , "UNDEFAULTED" )
676
762
677
763
for derived_type in registry .get_all (type_ ):
@@ -720,7 +806,7 @@ def _process_member(
720
806
setattr (
721
807
some_class ,
722
808
creation_function_name ,
723
- _default_create (name , type_ , pluggable ),
809
+ _default_create (name , type_ , process_type ),
724
810
)
725
811
creation_functions .append (creation_function_name )
726
812
@@ -743,7 +829,10 @@ def remove_unused_components(dict_: DictConfig) -> None:
743
829
args_keys = [key for key in keys if key .endswith (ARGS_SUFFIX )]
744
830
for replaceable in replaceables :
745
831
selected_type = dict_ [replaceable + TYPE_SUFFIX ]
746
- expect = replaceable + "_" + selected_type + ARGS_SUFFIX
832
+ if selected_type is None :
833
+ expect = ""
834
+ else :
835
+ expect = replaceable + "_" + selected_type + ARGS_SUFFIX
747
836
with open_dict (dict_ ):
748
837
for key in args_keys :
749
838
if key .startswith (replaceable + "_" ) and key != expect :
0 commit comments