@@ -43,27 +43,8 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
     """
     Decoding function is a torch.nn.Module which takes the embedding of a location in
     space and transforms it into the required quantity (for example density and color).
-
-    Members:
-        param_groups: dictionary where keys are names of individual parameters
-            or module members and values are the parameter group where the
-            parameter/member will be sorted to. "self" key is used to denote the
-            parameter group at the module level. Possible keys, including the "self" key
-            do not have to be defined. By default all parameters are put into "default"
-            parameter group and have the learning rate defined in the optimizer,
-            it can be overridden at the:
-                - module level with “self” key, all the parameters and child
-                    module's parameters will be put to that parameter group
-                - member level, which is the same as if the `param_groups` in that
-                    member has key=“self” and value equal to that parameter group.
-                    This is useful if members do not have `param_groups`, for
-                    example torch.nn.Linear.
-                - parameter level, parameter with the same name as the key
-                    will be put to that parameter group.
     """

-    param_groups: Dict[str, str] = field(default_factory=lambda: {})
-
     def __post_init__(self):
         super().__init__()
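
The docstring removed above (and re-added on `MLPDecoder` in the hunk below) describes three levels at which a parameter group can be overridden. A minimal sketch of one plausible reading of that precedence (parameter name first, then member name, then the module-level "self" key) follows; `resolve_param_groups` is a hypothetical helper written only for illustration, not Implicitron's actual resolution code.

from typing import Dict

import torch


def resolve_param_groups(module: torch.nn.Module, overrides: Dict[str, str]) -> Dict[str, str]:
    # Hypothetical helper: assign every parameter of `module` to a group name,
    # letting a parameter-level key win over a member-level key, which in turn
    # wins over the module-level "self" key ("default" group otherwise).
    module_group = overrides.get("self", "default")
    groups = {}
    for name, _ in module.named_parameters():
        member = name.split(".")[0]  # e.g. "network" for "network.weight"
        groups[name] = overrides.get(name, overrides.get(member, module_group))
    return groups


# Toy usage: everything goes to "decoder" except the member named "1".
decoder = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.Linear(64, 4))
print(resolve_param_groups(decoder, {"self": "decoder", "1": "last_layer"}))
# {'0.weight': 'decoder', '0.bias': 'decoder', '1.weight': 'last_layer', '1.bias': 'last_layer'}
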
@@ -280,11 +261,30 @@ def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
 class MLPDecoder(DecoderFunctionBase):
     """
     Decoding function which uses `MLPWithIputSkips` to convert the embedding to output.
-    If using Implicitron config system `input_dim` of the `network` is changed to the
-        value of `input_dim` member and `input_skips` is removed.
+    The `input_dim` of the `network` is set from the value of `input_dim` member.
+
+    Members:
+        input_dim: dimension of input.
+        param_groups: dictionary where keys are names of individual parameters
+            or module members and values are the parameter group where the
+            parameter/member will be sorted to. "self" key is used to denote the
+            parameter group at the module level. Possible keys, including the "self" key
+            do not have to be defined. By default all parameters are put into "default"
+            parameter group and have the learning rate defined in the optimizer,
+            it can be overridden at the:
+                - module level with “self” key, all the parameters and child
+                    module's parameters will be put to that parameter group
+                - member level, which is the same as if the `param_groups` in that
+                    member has key=“self” and value equal to that parameter group.
+                    This is useful if members do not have `param_groups`, for
+                    example torch.nn.Linear.
+                - parameter level, parameter with the same name as the key
+                    will be put to that parameter group.
+        network_args: configuration for MLPWithInputSkips
     """

     input_dim: int = 3
+    param_groups: Dict[str, str] = field(default_factory=lambda: {})
     network: MLPWithInputSkips

     def __post_init__(self):
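
After this change `param_groups` is a member of `MLPDecoder` rather than of the base class. As a rough illustration of how such a mapping could end up as per-group learning rates, the sketch below buckets the parameters of a stand-in module and feeds them to a stock `torch.optim.Adam`; the group names, learning rates, and `ToyDecoder` class are all made up, and the loop is not Implicitron's optimizer factory.

from collections import defaultdict

import torch

# Illustrative mapping in the spirit of MLPDecoder.param_groups: the whole
# module goes to "decoder", but the member named "network" gets its own group.
param_groups = {"self": "decoder", "network": "network"}
learning_rates = {"decoder": 1e-3, "network": 5e-4, "default": 1e-3}  # made-up values


class ToyDecoder(torch.nn.Module):
    # Stand-in with a "network" member, loosely mirroring MLPDecoder's layout.
    def __init__(self):
        super().__init__()
        self.network = torch.nn.Linear(3, 4)
        self.head = torch.nn.Linear(4, 4)


decoder = ToyDecoder()

# Bucket parameters by group, using the same parameter/member/"self" precedence
# sketched after the first hunk above.
buckets = defaultdict(list)
for name, param in decoder.named_parameters():
    member = name.split(".")[0]
    group = param_groups.get(name, param_groups.get(member, param_groups.get("self", "default")))
    buckets[group].append(param)

# torch.optim optimizers accept a list of per-group dicts with their own options.
optimizer = torch.optim.Adam(
    [{"params": params, "lr": learning_rates[group]} for group, params in buckets.items()]
)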