
Commit ffce2d1

implement record_stream for better performance.

1 parent 1001425

File tree: 2 files changed (+31, -5 lines)
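The change centers on `torch.Tensor.record_stream`: it tells PyTorch's CUDA caching allocator that a tensor is used on a stream other than the one it was allocated on, so the allocator will not reuse that memory until the recorded stream's pending work finishes. That is what lets the offload path drop a blocking `synchronize()` (see the `offload_` hunk below). A minimal sketch of the underlying pattern, independent of this commit:

```python
import torch

# Illustrative sketch of the record_stream pattern (not part of this commit).
copy_stream = torch.cuda.Stream()
cpu_tensor = torch.randn(1024, 1024).pin_memory()  # pinned memory enables truly async H2D copies

with torch.cuda.stream(copy_stream):
    # Enqueued on the side stream; returns without waiting.
    gpu_tensor = cpu_tensor.to("cuda", non_blocking=True)

# Make the copy visible to the compute stream before using the tensor.
torch.cuda.current_stream().wait_stream(copy_stream)

# Tell the caching allocator the tensor is consumed on the current stream,
# so its memory is not reclaimed until that stream's work completes.
gpu_tensor.record_stream(torch.cuda.current_stream())
```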

src/diffusers/hooks/group_offloading.py
(29 additions, 4 deletions)
@@ -56,6 +56,7 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
+        record_stream: Optional[bool] = False,
         cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
         onload_self: bool = True,
     ) -> None:
@@ -68,32 +69,45 @@ def __init__(
         self.buffers = buffers
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
+        self.record_stream = record_stream
         self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
 
         if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+            raise ValueError("`cpu_param_dict` must be provided when using stream for data transfer.")
+
+        if self.record_stream and not self.stream:
+            raise ValueError("`record_stream` cannot be True when `stream` is None.")
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
         context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
+        current_stream = torch.cuda.current_stream() if self.record_stream else None
+
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
 
         with context:
             for group_module in self.modules:
                 group_module.to(self.onload_device, non_blocking=self.non_blocking)
+                if self.record_stream:
+                    for param in group_module.parameters():
+                        param.data.record_stream(current_stream)
             if self.parameters is not None:
                 for param in self.parameters:
                     param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        param.data.record_stream(current_stream)
             if self.buffers is not None:
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        buffer.data.record_stream(current_stream)
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
-        if self.stream is not None:
+        if self.stream is not None and not self.record_stream:
             torch.cuda.current_stream().synchronize()
         for group_module in self.modules:
             for param in group_module.parameters():
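The hunk above carries the core of the change: `onload_()` records the compute (current) stream on every parameter and buffer immediately after its async copy, and `offload_()` can then skip `torch.cuda.current_stream().synchronize()` when `record_stream` is enabled, because the allocator now tracks the cross-stream usage itself. A condensed restatement of that control flow, as hypothetical free functions (simplified from the methods above):

```python
import torch

def onload(params, onload_device, copy_stream, record_stream=True):
    # Mirrors onload_: copy on a side stream, record the consumer stream.
    current_stream = torch.cuda.current_stream() if record_stream else None
    copy_stream.synchronize()  # wait for the previous Host->Device transfer
    with torch.cuda.stream(copy_stream):
        for param in params:
            param.data = param.data.to(onload_device, non_blocking=True)
            if record_stream:
                param.data.record_stream(current_stream)

def offload(params, cpu_param_dict, record_stream=True):
    # Mirrors offload_: the blocking synchronize is only needed when the
    # allocator is not tracking cross-stream usage via record_stream.
    if not record_stream:
        torch.cuda.current_stream().synchronize()
    for param in params:
        param.data = cpu_param_dict[param]
```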
@@ -268,6 +282,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    record_stream: bool = False
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -314,6 +329,7 @@ def apply_group_offloading(
         use_stream (`bool`, defaults to `False`):
             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
             overlapping computation and data transfer.
+        record_stream: TODO
 
     Example:
         ```python
@@ -349,10 +365,10 @@ def apply_group_offloading(
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, record_stream
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
+        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream, record_stream)
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -364,6 +380,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    record_stream: Optional[bool] = False
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -382,6 +399,7 @@ def _apply_group_offloading_block_level(
         stream (`torch.cuda.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
+        record_stream: TODO
     """
 
     # Create a pinned CPU parameter dict for async data transfer if streams are to be used
@@ -411,6 +429,7 @@ def _apply_group_offloading_block_level(
                 onload_leader=current_modules[0],
                 non_blocking=non_blocking,
                 stream=stream,
+                record_stream=record_stream,
                 cpu_param_dict=cpu_param_dict,
                 onload_self=stream is None,
             )
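As the comment in the hunk above notes, `cpu_param_dict` maps each parameter to a pinned (page-locked) host copy; pinned memory is what allows the `non_blocking=True` transfers to be genuinely asynchronous. The helper that builds it sits outside this diff; a hedged sketch of what such a construction might look like:

```python
import torch

def build_pinned_cpu_param_dict(module: torch.nn.Module):
    # Hypothetical helper; the real construction is not shown in this diff.
    cpu_param_dict = {}
    for param in module.parameters():
        pinned = param.data.cpu().pin_memory()  # page-locked host copy
        param.data = pinned
        cpu_param_dict[param] = pinned
    return cpu_param_dict
```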
@@ -448,6 +467,7 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
+        record_stream=False,
         cpu_param_dict=None,
         onload_self=True,
     )
@@ -461,6 +481,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    record_stream: Optional[bool] = False
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -481,6 +502,7 @@ def _apply_group_offloading_leaf_level(
         stream (`torch.cuda.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
+        record_stream: TODO
     """
 
     # Create a pinned CPU parameter dict for async data transfer if streams are to be used
@@ -503,6 +525,7 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
+            record_stream=record_stream,
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
@@ -548,6 +571,7 @@ def _apply_group_offloading_leaf_level(
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
+            record_stream=record_stream,
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
@@ -567,6 +591,7 @@ def _apply_group_offloading_leaf_level(
             buffers=None,
             non_blocking=False,
             stream=None,
+            record_stream=False,
             cpu_param_dict=None,
             onload_self=True,
         )
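Note that the catch-all groups in both strategies (the unmatched block-level group and the final leaf-level group) pass `record_stream=False` alongside `stream=None`: these groups transfer synchronously, and the new validation in `__init__` rejects `record_stream=True` whenever no stream is supplied.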

src/diffusers/models/modeling_utils.py
(2 additions, 1 deletion)
@@ -546,6 +546,7 @@ def enable_group_offload(
         num_blocks_per_group: Optional[int] = None,
         non_blocking: bool = False,
         use_stream: bool = False,
+        record_stream: bool = False,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -584,7 +585,7 @@ def enable_group_offload(
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
         apply_group_offloading(
-            self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
+            self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream, record_stream
         )
 
     def save_pretrained(
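End to end, the new flag is surfaced to users through `enable_group_offload`. A usage sketch based on the signature above (the model class and checkpoint are illustrative, not taken from this commit):

```python
import torch
from diffusers import CogVideoXTransformer3DModel

# Illustrative model; any diffusers model exposing enable_group_offload works.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    record_stream=True,  # the flag added by this commit
)
```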
