@@ -56,6 +56,7 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
+        record_stream: Optional[bool] = False,
         cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
         onload_self: bool = True,
     ) -> None:
@@ -68,32 +69,45 @@ def __init__(
         self.buffers = buffers
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
+        self.record_stream = record_stream
         self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
 
         if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+            raise ValueError("`cpu_param_dict` must be provided when using stream for data transfer.")
+
+        if self.record_stream and not self.stream:
+            raise ValueError("`record_stream` cannot be True when `stream` is None.")
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
         context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
+        current_stream = torch.cuda.current_stream() if self.record_stream else None
+
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
 
         with context:
             for group_module in self.modules:
                 group_module.to(self.onload_device, non_blocking=self.non_blocking)
+                if self.record_stream:
+                    for param in group_module.parameters():
+                        param.data.record_stream(current_stream)
             if self.parameters is not None:
                 for param in self.parameters:
                     param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        param.data.record_stream(current_stream)
             if self.buffers is not None:
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        buffer.data.record_stream(current_stream)
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
-        if self.stream is not None:
+        if self.stream is not None and not self.record_stream:
            torch.cuda.current_stream().synchronize()
         for group_module in self.modules:
             for param in group_module.parameters():
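For readers unfamiliar with it, `torch.Tensor.record_stream` tells the CUDA caching allocator that a tensor is in use on a given stream, so its memory is not handed to a new tensor until that stream's pending work completes. That is what lets `offload_` above skip the explicit `torch.cuda.current_stream().synchronize()` when `record_stream` is enabled. A minimal standalone sketch of the two patterns (illustrative only, not code from this diff):

```python
import torch

# A side stream for host->device copies, so transfers can overlap compute
# on the default stream (assumes a CUDA device is available).
transfer_stream = torch.cuda.Stream()
compute_stream = torch.cuda.current_stream()

src = torch.randn(1024, 1024).pin_memory()  # pinned memory enables async copies

with torch.cuda.stream(transfer_stream):
    dst = src.to("cuda", non_blocking=True)  # allocated on transfer_stream

# The consumer must not read `dst` before the copy finishes...
compute_stream.wait_stream(transfer_stream)

# ...and the allocator must not recycle `dst`'s memory while the compute
# stream still uses it. Either block the host:
#     torch.cuda.current_stream().synchronize()
# or record the consumer stream, which is what `record_stream=True` opts into:
dst.record_stream(compute_stream)

out = dst @ dst  # safe: ordered after the copy on compute_stream
```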
@@ -268,6 +282,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    record_stream: bool = False
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -314,6 +329,7 @@ def apply_group_offloading(
         use_stream (`bool`, defaults to `False`):
             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
             overlapping computation and data transfer.
+        record_stream (`bool`, defaults to `False`): when used with `use_stream`, records onloaded tensors on the current stream via `torch.Tensor.record_stream`, avoiding an explicit synchronization before each offload.
 
     Example:
     ```python
@@ -349,10 +365,10 @@ def apply_group_offloading(
         raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, record_stream
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
+        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream, record_stream)
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
 
@@ -364,6 +380,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    record_stream: Optional[bool] = False
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -382,6 +399,7 @@ def _apply_group_offloading_block_level(
         stream (`torch.cuda.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
+        record_stream (`bool`, *optional*, defaults to `False`): if True, records onloaded tensors on the current stream via `torch.Tensor.record_stream`, avoiding an explicit synchronization before offload.
     """
 
     # Create a pinned CPU parameter dict for async data transfer if streams are to be used
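The `cpu_param_dict` referenced in this comment pairs each parameter with a page-locked CPU copy, since `non_blocking` host-to-device copies are only truly asynchronous from pinned memory. A rough sketch of how such a dict can be built (hypothetical helper, not shown in this diff):

```python
import torch

def make_pinned_cpu_param_dict(module: torch.nn.Module):
    # Move each parameter to pinned (page-locked) host memory and remember
    # the pinned copy, so onloads/offloads can reuse the same buffer.
    cpu_param_dict = {}
    for param in module.parameters():
        param.data = param.data.cpu().pin_memory()
        cpu_param_dict[param] = param.data
    return cpu_param_dict
```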
@@ -411,6 +429,7 @@ def _apply_group_offloading_block_level(
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
             stream=stream,
+            record_stream=record_stream,
             cpu_param_dict=cpu_param_dict,
             onload_self=stream is None,
         )
@@ -448,6 +467,7 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
+        record_stream=False,
         cpu_param_dict=None,
         onload_self=True,
     )
@@ -461,6 +481,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    record_stream: Optional[bool] = False
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -481,6 +502,7 @@ def _apply_group_offloading_leaf_level(
         stream (`torch.cuda.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
+        record_stream (`bool`, *optional*, defaults to `False`): if True, records onloaded tensors on the current stream via `torch.Tensor.record_stream`, avoiding an explicit synchronization before offload.
     """
 
     # Create a pinned CPU parameter dict for async data transfer if streams are to be used
@@ -503,6 +525,7 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
+            record_stream=record_stream,
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
@@ -548,6 +571,7 @@ def _apply_group_offloading_leaf_level(
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
+            record_stream=record_stream,
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
@@ -567,6 +591,7 @@ def _apply_group_offloading_leaf_level(
         buffers=None,
         non_blocking=False,
         stream=None,
+        record_stream=False,
         cpu_param_dict=None,
         onload_self=True,
     )
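End to end, a caller would opt into the new flag roughly like this (a sketch: the `onload_device`/`offload_device` parameters and the import path are assumed from the surrounding file rather than shown in this diff):

```python
import torch
from diffusers.hooks import apply_group_offloading  # assumed import path

model = ...  # any torch.nn.Module, e.g. a diffusion transformer

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,      # transfers run on a dedicated CUDA stream
    record_stream=True,   # skip the explicit sync before each offload
)
```

The trade-off is the usual one for `record_stream`: fewer host-side synchronizations in exchange for the caching allocator potentially holding onto memory slightly longer.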