@@ -83,7 +83,10 @@ def onload_(self):
83
83
84
84
with context :
85
85
for group_module in self .modules :
86
- group_module .to (self .onload_device , non_blocking = self .non_blocking )
86
+ for param in group_module .parameters ():
87
+ param .data = param .data .to (self .onload_device , non_blocking = self .non_blocking )
88
+ for buffer in group_module .buffers ():
89
+ buffer .data = buffer .data .to (self .onload_device , non_blocking = self .non_blocking )
87
90
if self .parameters is not None :
88
91
for param in self .parameters :
89
92
param .data = param .data .to (self .onload_device , non_blocking = self .non_blocking )
@@ -98,6 +101,12 @@ def offload_(self):
98
101
for group_module in self .modules :
99
102
for param in group_module .parameters ():
100
103
param .data = self .cpu_param_dict [param ]
104
+ if self .parameters is not None :
105
+ for param in self .parameters :
106
+ param .data = self .cpu_param_dict [param ]
107
+ if self .buffers is not None :
108
+ for buffer in self .buffers :
109
+ buffer .data = self .cpu_param_dict [buffer ]
101
110
else :
102
111
for group_module in self .modules :
103
112
group_module .to (self .offload_device , non_blocking = self .non_blocking )
@@ -387,9 +396,7 @@ def _apply_group_offloading_block_level(
387
396
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
388
397
cpu_param_dict = None
389
398
if stream is not None :
390
- for param in module .parameters ():
391
- param .data = param .data .cpu ().pin_memory ()
392
- cpu_param_dict = {param : param .data for param in module .parameters ()}
399
+ cpu_param_dict = _get_pinned_cpu_param_dict (module )
393
400
394
401
# Create module groups for ModuleList and Sequential blocks
395
402
modules_with_group_offloading = set ()
@@ -486,9 +493,7 @@ def _apply_group_offloading_leaf_level(
486
493
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
487
494
cpu_param_dict = None
488
495
if stream is not None :
489
- for param in module .parameters ():
490
- param .data = param .data .cpu ().pin_memory ()
491
- cpu_param_dict = {param : param .data for param in module .parameters ()}
496
+ cpu_param_dict = _get_pinned_cpu_param_dict (module )
492
497
493
498
# Create module groups for leaf modules and apply group offloading hooks
494
499
modules_with_group_offloading = set ()
@@ -604,6 +609,17 @@ def _apply_lazy_group_offloading_hook(
604
609
registry .register_hook (lazy_prefetch_hook , _LAZY_PREFETCH_GROUP_OFFLOADING )
605
610
606
611
612
+ def _get_pinned_cpu_param_dict (module : torch .nn .Module ) -> Dict [torch .nn .Parameter , torch .Tensor ]:
613
+ cpu_param_dict = {}
614
+ for param in module .parameters ():
615
+ param .data = param .data .cpu ().pin_memory ()
616
+ cpu_param_dict [param ] = param .data
617
+ for buffer in module .buffers ():
618
+ buffer .data = buffer .data .cpu ().pin_memory ()
619
+ cpu_param_dict [buffer ] = buffer .data
620
+ return cpu_param_dict
621
+
622
+
607
623
def _gather_parameters_with_no_group_offloading_parent (
608
624
module : torch .nn .Module , modules_with_group_offloading : Set [str ]
609
625
) -> List [torch .nn .Parameter ]:
0 commit comments