@@ -15,8 +15,9 @@
 
 """
 
+from collections.abc import Mapping
 from dataclasses import dataclass, field
-from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Type
+from typing import Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
 
 import torch
 from omegaconf import DictConfig
@@ -164,8 +165,9 @@ def get_output_dim(args: DictConfig) -> int:
 
     def change_resolution(
         self,
-        epoch: int,
         grid_values: VoxelGridValuesBase,
+        epoch: int,
+        *,
         mode: str = "linear",
         align_corners: bool = True,
         antialias: bool = False,
@@ -177,8 +179,8 @@ def change_resolution(
             epoch: current training epoch, used to see if the grid needs regridding
             grid_values: instance of self.values_type which contains
                 the voxel grid which will be interpolated to create the new grid
-            wanted_resolution: tuple of (x, y, z) resolutions which determine
-                new grid's resolution
+            epoch: epoch which is used to get the resolution of the new
+                `grid_values` using `self.resolution_changes`.
             align_corners: as for torch.nn.functional.interpolate
             mode: as for torch.nn.functional.interpolate
                 'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'.
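
The reordered signature above puts `grid_values` before `epoch` and makes the trailing arguments keyword-only via the bare `*`. A minimal standalone sketch of that calling convention (plain Python stub, not the pytorch3d classes; a real `VoxelGridBase` subclass would interpolate the grid):

```python
# Standalone sketch: mirrors only the argument ordering and keyword-only rule.
def change_resolution(
    grid_values, epoch, *, mode="linear", align_corners=True, antialias=False
):
    # A real implementation would regrid grid_values to the resolution
    # registered for this epoch; here we only echo the request.
    return f"regridded for epoch {epoch} with mode={mode!r}", True


print(change_resolution({"voxel_grid": None}, epoch=7))
print(change_resolution({"voxel_grid": None}, 7, mode="nearest"))
# change_resolution({"voxel_grid": None}, 7, "nearest")  # TypeError: mode is keyword-only
```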
@@ -225,11 +227,17 @@ def change_individual_resolution(tensor, wanted_resolution):
         # pyre-ignore[29]
         return self.values_type(**params), True
 
-    def get_resolution_change_epochs(self) -> List[int]:
+    def get_resolution_change_epochs(self) -> Tuple[int, ...]:
         """
         Returns epochs at which this grid should change resolution.
         """
-        return list(self.resolution_changes.keys())
+        return tuple(self.resolution_changes.keys())
+
+    def get_align_corners(self) -> bool:
+        """
+        Returns True if the voxel grid uses align_corners=True.
+        """
+        return self.align_corners
 
 
 @dataclass
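
A small standalone sketch of what the two accessors above return, assuming `resolution_changes` maps an epoch to an (x, y, z) resolution (the rest of the `VoxelGridBase` machinery is omitted; `_SketchGrid` is a hypothetical stand-in):

```python
from typing import Tuple


class _SketchGrid:
    """Illustration only: the two accessor methods shown in the hunk above."""

    def __init__(self) -> None:
        self.resolution_changes = {0: (64, 64, 64), 100: (128, 128, 128)}
        self.align_corners = True

    def get_resolution_change_epochs(self) -> Tuple[int, ...]:
        # A tuple is immutable and hashable, unlike the list returned before.
        return tuple(self.resolution_changes.keys())

    def get_align_corners(self) -> bool:
        return self.align_corners


grid = _SketchGrid()
print(grid.get_resolution_change_epochs())  # (0, 100)
print(grid.get_align_corners())             # True
```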
@@ -583,6 +591,8 @@ class VoxelGridModule(Configurable, torch.nn.Module):
     """
     A wrapper torch.nn.Module for the VoxelGrid classes, which
     contains parameters that are needed to train the VoxelGrid classes.
+    Can contain the parameters for the voxel grid as pytorch parameters
+    or as registered buffers.
 
     Members:
         voxel_grid_class_type: The name of the class to use for voxel_grid,
@@ -596,17 +606,21 @@ class VoxelGridModule(Configurable, torch.nn.Module):
             with mean=init_mean and std=init_std. Default 0.1
         init_mean: Parameters are initialized using the gaussian distribution
             with mean=init_mean and std=init_std. Default 0.
+        hold_voxel_grid_as_parameters: if True, components of the underlying voxel grids
+            will be saved as parameters and therefore be trainable. Default True.
     """
 
     voxel_grid_class_type: str = "FullResolutionVoxelGrid"
     voxel_grid: VoxelGridBase
 
-    extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
+    extents: Tuple[float, float, float] = (2.0, 2.0, 2.0)
     translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
 
     init_std: float = 0.1
     init_mean: float = 0
 
+    hold_voxel_grid_as_parameters: bool = True
+
     def __post_init__(self):
         super().__init__()
         run_auto_creation(self)
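
Reading the defaults above, `extents` appears to be the full side length of the grid's bounding box centred at `translation`, so the new default spans the normalized [-1, 1] cube rather than [-0.5, 0.5]. A small arithmetic sketch under that assumption (illustration only, not the library's code):

```python
# Assumed interpretation: the grid covers translation ± extents / 2 on each axis.
extents = (2.0, 2.0, 2.0)      # new default
translation = (0.0, 0.0, 0.0)  # unchanged default

lower = tuple(t - e / 2 for t, e in zip(translation, extents))
upper = tuple(t + e / 2 for t, e in zip(translation, extents))
print(lower, upper)  # (-1.0, -1.0, -1.0) (1.0, 1.0, 1.0)
```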
@@ -619,7 +633,8 @@ def __post_init__(self):
             )
             for name, shape in shapes.items()
         }
-        self.params = torch.nn.ParameterDict(params)
+
+        self.set_voxel_grid_parameters(self.voxel_grid.values_type(**params))
         self._register_load_state_dict_pre_hook(self._create_parameters_with_new_size)
 
     def forward(self, points: torch.Tensor) -> torch.Tensor:
@@ -632,31 +647,29 @@ def forward(self, points: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor of shape (..., n_features)
         """
-        locator = VolumeLocator(
-            batch_size=1,
-            # The resolution of the voxel grid does not need to be known
-            # to the locator object. It is easiest to fix the resolution of the locator.
-            # In particular we fix it to (2, 2, 2) so that there is exactly one voxel of the
-            # desired size. The locator object uses (z, y, x) convention for the grid_size,
-            # and this module uses (x, y, z) convention so the order has to be reversed
-            # (irrelevant in this case since they are all equal).
-            # It is (2, 2, 2) because the VolumeLocator object behaves like
-            # align_corners=True, which means that the points are in the corners of
-            # the volume. So in the grid of (2, 2, 2) there is only one voxel.
-            grid_sizes=(2, 2, 2),
-            # The locator object uses (x, y, z) convention for the
-            # voxel size and translation.
-            voxel_size=tuple(self.extents),
-            volume_translation=tuple(self.translation),
-            # pyre-ignore[29]
-            device=next(val for val in self.params.values() if val is not None).device,
-        )
+        locator = self._get_volume_locator()
         # pyre-fixme[29]: `Union[torch._tensor.Tensor,
         #  torch.nn.modules.module.Module]` is not a function.
         grid_values = self.voxel_grid.values_type(**self.params)
         # voxel grids operate with extra n_grids dimension, which we fix to one
         return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
 
+    def set_voxel_grid_parameters(self, params: VoxelGridValuesBase) -> None:
+        """
+        Sets the parameters of the underlying voxel grid.
+
+        Args:
+            params: parameters of type `self.voxel_grid.values_type` which will
+                replace current parameters
+        """
+        if self.hold_voxel_grid_as_parameters:
+            # pyre-ignore [16]
+            self.params = torch.nn.ParameterDict(vars(params))
+        else:
+            # Torch Module to hold parameters since they can only be registered
+            # at object level.
+            self.params = _RegistratedBufferDict(vars(params))
+
     @staticmethod
     def get_output_dim(args: DictConfig) -> int:
         """
@@ -672,12 +685,12 @@ def get_output_dim(args: DictConfig) -> int:
             args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"]
         )
 
-    def subscribe_to_epochs(self) -> Tuple[List[int], Callable[[int], bool]]:
+    def subscribe_to_epochs(self) -> Tuple[Tuple[int, ...], Callable[[int], bool]]:
         """
        Method which expresses interest in subscribing to optimization epoch updates.
 
         Returns:
-            list of epochs on which to call a callable and callable to be called on
+            tuple of epochs on which to call a callable and callable to be called on
             particular epoch. The callable returns True if parameter change has
             happened else False and it must be supplied with one argument, epoch.
         """
@@ -697,13 +710,12 @@ def _apply_epochs(self, epoch: int) -> bool:
         """
         # pyre-ignore[29]
         grid_values = self.voxel_grid.values_type(**self.params)
-        grid_values, change = self.voxel_grid.change_resolution(epoch, grid_values)
+        grid_values, change = self.voxel_grid.change_resolution(
+            grid_values, epoch=epoch
+        )
         if change:
-            # pyre-ignore[16]
-            self.params = torch.nn.ParameterDict(
-                {name: tensor for name, tensor in vars(grid_values).items()}
-            )
-        return change
+            self.set_voxel_grid_parameters(grid_values)
+        return change and self.hold_voxel_grid_as_parameters
 
     def _create_parameters_with_new_size(
         self,
@@ -749,5 +761,113 @@ def _create_parameters_with_new_size(
             key = prefix + "params." + name
             if key in state_dict:
                 new_params[name] = torch.zeros_like(state_dict[key])
-        # pyre-ignore[16]
-        self.params = torch.nn.ParameterDict(new_params)
+        # pyre-ignore[29]
+        self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params))
+
+    def get_device(self) -> torch.device:
+        """
+        Returns torch.device on which module parameters are located
+        """
+        # pyre-ignore[29]
+        return next(val for val in self.params.values() if val is not None).device
+
+    def _get_volume_locator(self) -> VolumeLocator:
+        """
+        Returns VolumeLocator calculated from `extents` and `translation` members.
+        """
+        return VolumeLocator(
+            batch_size=1,
+            # The resolution of the voxel grid does not need to be known
+            # to the locator object. It is easiest to fix the resolution of the locator.
+            # In particular we fix it to (2, 2, 2) so that there is exactly one voxel of the
+            # desired size. The locator object uses (z, y, x) convention for the grid_size,
+            # and this module uses (x, y, z) convention so the order has to be reversed
+            # (irrelevant in this case since they are all equal).
+            # It is (2, 2, 2) because the VolumeLocator object behaves like
+            # align_corners=True, which means that the points are in the corners of
+            # the volume. So in the grid of (2, 2, 2) there is only one voxel.
+            grid_sizes=(2, 2, 2),
+            # The locator object uses (x, y, z) convention for the
+            # voxel size and translation.
+            voxel_size=tuple(self.extents),
+            # volume_translation is defined in `VolumeLocator` as a vector from the origin
+            # of the local coordinate frame to the origin of the world coordinate frame:
+            # x_world = x_local * extents/2 - translation.
+            # To get the reverse we need to negate it.
+            volume_translation=tuple(-t for t in self.translation),
+            device=self.get_device(),
+        )
+
+    def get_grid_points(self, epoch: int) -> torch.Tensor:
+        """
+        Returns a grid of points that represent centers of voxels of the
+        underlying voxel grid in world coordinates at specific epoch.
+
+        Args:
+            epoch: underlying voxel grids change resolution depending on the
+                epoch, this argument is used to determine the resolution
+                of the voxel grid at that epoch.
+        Returns:
+            tensor of shape [xresolution, yresolution, zresolution, 3] where
+            xresolution, yresolution, zresolution are resolutions of the
+            underlying voxel grid
+        """
+        xresolution, yresolution, zresolution = self.voxel_grid.get_resolution(epoch)
+        width, height, depth = self.extents
+        if not self.voxel_grid.get_align_corners():
+            width = (
+                width * (xresolution - 1) / xresolution if xresolution > 1 else width
+            )
+            height = (
+                height * (yresolution - 1) / yresolution if yresolution > 1 else height
+            )
+            depth = (
+                depth * (zresolution - 1) / zresolution if zresolution > 1 else depth
+            )
+        xs = torch.linspace(
+            -width / 2, width / 2, xresolution, device=self.get_device()
+        )
+        ys = torch.linspace(
+            -height / 2, height / 2, yresolution, device=self.get_device()
+        )
+        zs = torch.linspace(
+            -depth / 2, depth / 2, zresolution, device=self.get_device()
+        )
+        xmesh, ymesh, zmesh = torch.meshgrid(xs, ys, zs, indexing="ij")
+        return torch.stack((xmesh, ymesh, zmesh), dim=3)
+
+
+class _RegistratedBufferDict(torch.nn.Module, Mapping):
+    """
+    Mapping class and a torch.nn.Module that registers its values
+    with `self.register_buffer`. Can be indexed like a regular Python
+    dictionary, but torch.Tensors it contains are properly registered, and will be visible
+    by all Module methods. Supports only `torch.Tensor` as value and str as key.
+    """
+
+    def __init__(self, init_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
+        """
+        Args:
+            init_dict: dictionary which will be used to populate the object
+        """
+        super().__init__()
+        self._keys = set()
+        if init_dict is not None:
+            for k, v in init_dict.items():
+                self[k] = v
+
+    def __iter__(self) -> Iterator[str]:
+        return iter({k: self[k] for k in self._keys})
+
+    def __len__(self) -> int:
+        return len(self._keys)
+
+    def __getitem__(self, key: str) -> torch.Tensor:
+        return getattr(self, key)
+
+    def __setitem__(self, key, value) -> None:
+        self._keys.add(key)
+        self.register_buffer(key, value)
+
+    def __hash__(self) -> int:
+        return hash(repr(self))
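
The `_RegistratedBufferDict` added at the end of the patch backs a Mapping with `register_buffer`, so non-trainable grid values still travel with `state_dict()` and `.to(device)`. A runnable standalone re-implementation of the same idea, with no pytorch3d dependency (`BufferDict` is a hypothetical name for this sketch):

```python
from collections.abc import Mapping
from typing import Dict, Iterator, Optional

import torch


class BufferDict(torch.nn.Module, Mapping):
    """Dict-like module whose values are registered buffers (illustration only)."""

    def __init__(self, init_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
        super().__init__()
        self._keys = set()
        for k, v in (init_dict or {}).items():
            self[k] = v

    def __iter__(self) -> Iterator[str]:
        return iter(self._keys)

    def __len__(self) -> int:
        return len(self._keys)

    def __getitem__(self, key: str) -> torch.Tensor:
        # register_buffer stores the tensor as a module attribute.
        return getattr(self, key)

    def __setitem__(self, key: str, value: torch.Tensor) -> None:
        self._keys.add(key)
        self.register_buffer(key, value)

    def __hash__(self) -> int:
        # collections.abc.Mapping sets __hash__ = None (it defines __eq__), but
        # torch's module traversal stores modules in a set, so restore identity hashing.
        return id(self)


bd = BufferDict({"voxel_grid": torch.ones(2, 2, 2)})
print(list(bd))                    # ['voxel_grid']
print(list(bd.state_dict()))       # ['voxel_grid']: buffers are saved
print(len(list(bd.parameters())))  # 0: nothing is trainable
```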