@@ -75,15 +75,15 @@ def __init__(
75
75
76
76
if self .stream is not None and self .cpu_param_dict is None :
77
77
raise ValueError ("`cpu_param_dict` must be provided when using stream for data transfer." )
78
-
78
+
79
79
if self .record_stream and not self .stream :
80
80
raise ValueError ("`record_stream` cannot be True when `stream` is None." )
81
81
82
82
def onload_ (self ):
83
83
r"""Onloads the group of modules to the onload_device."""
84
84
context = nullcontext () if self .stream is None else torch .cuda .stream (self .stream )
85
85
current_stream = torch .cuda .current_stream () if self .record_stream else None
86
-
86
+
87
87
if self .stream is not None :
88
88
# Wait for previous Host->Device transfer to complete
89
89
self .stream .synchronize ()
@@ -283,7 +283,7 @@ def apply_group_offloading(
283
283
num_blocks_per_group : Optional [int ] = None ,
284
284
non_blocking : bool = False ,
285
285
use_stream : bool = False ,
286
- record_stream : bool = False
286
+ record_stream : bool = False ,
287
287
) -> None :
288
288
r"""
289
289
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -381,7 +381,7 @@ def _apply_group_offloading_block_level(
381
381
onload_device : torch .device ,
382
382
non_blocking : bool ,
383
383
stream : Optional [torch .cuda .Stream ] = None ,
384
- record_stream : Optional [bool ] = False
384
+ record_stream : Optional [bool ] = False ,
385
385
) -> None :
386
386
r"""
387
387
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -482,7 +482,7 @@ def _apply_group_offloading_leaf_level(
482
482
onload_device : torch .device ,
483
483
non_blocking : bool ,
484
484
stream : Optional [torch .cuda .Stream ] = None ,
485
- record_stream : Optional [bool ] = False
485
+ record_stream : Optional [bool ] = False ,
486
486
) -> None :
487
487
r"""
488
488
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
0 commit comments