Group offloading with cuda stream prefetching #10516


Merged: 2 commits merged into groupwise-offloading on Jan 11, 2025

Conversation

@a-r-r-o-w (Member) commented on Jan 10, 2025

See #10503 for more details.

Code
import argparse
import gc
import os
import time

# os.environ["TORCH_LOGS"] = "+recompiles_verbose,guards"
os.environ["TORCH_LOGS"] = "+recompiles_verbose"

import torch
import torch._dynamo
import torch._dynamo.utils
from diffusers import LTXPipeline
from diffusers.utils import export_to_video
from diffusers.hooks import apply_group_offloading

torch._dynamo.config.cache_size_limit = 4

def cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output


def normal():
    pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    
    cleanup()
    print(f"Model memory: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")

    prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
    negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

    t1 = time.time()
    video = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=768,
        height=512,
        num_frames=161,
        num_inference_steps=50,
    ).frames[0]
    torch.cuda.synchronize()
    t2 = time.time()
    export_to_video(video, "output.mp4", fps=24)
    
    print(f"Inference time: {t2 - t1:.2f} s")
    print(f"Inference memory: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")


def model_cpu_offload(compile: bool = False):
    pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16)
    pipe.enable_model_cpu_offload()

    if compile:
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
    
    cleanup()
    print(f"Model memory: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")

    prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
    negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

    t1 = time.time()
    video = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=768,
        height=512,
        num_frames=161,
        num_inference_steps=50,
    ).frames[0]
    torch.cuda.synchronize()
    t2 = time.time()
    export_to_video(video, "output.mp4", fps=24)
    
    print(f"Inference time: {t2 - t1:.2f} s")
    print(f"Inference memory: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")


def sequential_cpu_offload(compile: bool = False):
    pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16)
    pipe.enable_sequential_cpu_offload()

    if compile:
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
    
    cleanup()
    print(f"Model memory: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")

    prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
    negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

    t1 = time.time()
    video = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=768,
        height=512,
        num_frames=161,
        num_inference_steps=50,
    ).frames[0]
    torch.cuda.synchronize()
    t2 = time.time()
    export_to_video(video, "output.mp4", fps=24)
    
    print(f"Inference time: {t2 - t1:.2f} s")
    print(f"Inference memory: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")


def group_offloading(offload_group_patterns: str = "diffusers_block", num_blocks_per_group: int = 4, compile: bool = False):
    pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16)
    pipe.text_encoder.to("cuda")
    pipe.vae.to("cuda")
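    # Group-offload only the transformer; the text encoder and VAE stay fully on the GPU.
    # The commented-out `cuda_stream=True` kwarg below toggles the stream-prefetching path
    # introduced by this PR.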
    apply_group_offloading(
        module=pipe.transformer,
        offload_group_patterns=offload_group_patterns,
        num_blocks_per_group=num_blocks_per_group,
        offload_device=torch.device("cpu"),
        onload_device=torch.device("cuda"),
        force_offload=True,
        non_blocking=True,
        # cuda_stream=True,
    )
    if compile:
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
        # pipe.transformer = torch.compile(pipe.transformer, backend="eager")
    
    cleanup()
    print(f"Model memory: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")

    prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
    negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

    t1 = time.time()
    video = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=768,
        height=512,
        num_frames=161,
        num_inference_steps=50,
    ).frames[0]
    torch.cuda.synchronize()
    t2 = time.time()
    export_to_video(video, "output.mp4", fps=24)

    compile_times = torch._dynamo.utils.compile_times()
    print(f"Compile times: {compile_times}")
    
    print(f"Inference time: {t2 - t1:.2f} s")
    print(f"Inference memory: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--offload_type", type=str, default="normal", choices=["normal", "model", "sequential", "group"])
    parser.add_argument("--offload_group_patterns", type=str, default="diffusers_block")
    parser.add_argument("--num_blocks_per_group", type=int, default=4)
    parser.add_argument("--compile", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    if args.offload_type == "normal":
        normal()
    elif args.offload_type == "model":
        model_cpu_offload(args.compile)
    elif args.offload_type == "sequential":
        sequential_cpu_offload(args.compile)
    else:
        group_offloading(args.offload_group_patterns, args.num_blocks_per_group, args.compile)

Results for:

python3 test.py --offload_type group --offload_group_patterns diffusers_block --num_blocks_per_group 4

Without CUDA stream prefetching:

Model memory: 9.70 GB
Inference time: 85.36 s
Inference memory: 21.25 GB

With CUDA stream prefetching:

# 4 blocks per group
Model memory: 10.26 GB
Inference time: 47.71 s
Inference memory: 21.89 GB

# 1 block per group
Model memory: 9.84 GB
Inference time: 46.15 s
Inference memory: 21.04 GB

This is currently not compatible with torch.compile and triggers several recompilations.

Reading:

cc @gau-nernst because the original ideas for layer prefetching come from his implementation

cc @DN6 @yiyixuxu

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +204 to +207
if len(module_groups) > 1:
    # Assign the first module_group as the next_group for the last module_group
    hook_registry = HookRegistry.check_if_exists_or_initialize(module_groups[-1].onload_leader)
    hook_registry.hooks["group_offloading"].next_group = module_groups[0]
@a-r-r-o-w (Member Author) commented on Jan 10, 2025

This is a bit hacky for the moment, just so that I could get it running quickly without putting much thought into it. Will try to improve it soon.
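
For context, the idea behind the wrap-around assignment is that while group i runs on the default stream, group i + 1's weights are copied host-to-device on a separate transfer stream, and the last group prefetches the first one again for the next denoising step. A minimal sketch of that pattern (with hypothetical onload_/offload_ helpers, not the actual hook implementation in this PR):

import torch

transfer_stream = torch.cuda.Stream()

def run_with_prefetch(groups, x):
    # Warm up: start copying the first group's weights to the GPU before the loop.
    with torch.cuda.stream(transfer_stream):
        groups[0].onload_()  # hypothetical async H2D copy of pinned params
    for i, group in enumerate(groups):
        # Make sure group i's copy has finished before computing with it.
        torch.cuda.current_stream().wait_stream(transfer_stream)
        # Kick off the copy of the *next* group; wrap around so that the last group
        # prefetches the first one again for the next denoising step.
        with torch.cuda.stream(transfer_stream):
            groups[(i + 1) % len(groups)].onload_()
        x = group.forward(x)   # compute group i on the default stream
        group.offload_()       # hypothetical: move group i's params back to CPU
    return x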

@a-r-r-o-w (Member Author) commented:

To avoid confusion: the numbers in this PR are not a 1:1 comparison with those in the group offloading PR.

In that PR, the benchmark only times, and records the memory requirement of, the transformer forward pass (prompt embeddings are precomputed and latents are decoded separately), whereas the numbers here include both the prompt-embedding and decoding steps. A fair comparison will be reported shortly, once the benchmark script is complete.
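
In the meantime, here is one rough way to isolate the transformer-only time with the script above (a sketch only, not the benchmark from the group offloading PR; it reuses pipe, prompt, and negative_prompt from the code and simply wraps pipe.transformer.forward):

import time
import torch

transformer_time = 0.0
_original_forward = pipe.transformer.forward

def _timed_forward(*args, **kwargs):
    # Accumulate wall-clock time spent inside the transformer only.
    global transformer_time
    torch.cuda.synchronize()
    t0 = time.time()
    out = _original_forward(*args, **kwargs)
    torch.cuda.synchronize()
    transformer_time += time.time() - t0
    return out

pipe.transformer.forward = _timed_forward
video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=768,
    height=512,
    num_frames=161,
    num_inference_steps=50,
).frames[0]
print(f"Transformer-only time: {transformer_time:.2f} s")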

@a-r-r-o-w merged commit d579037 into groupwise-offloading on Jan 11, 2025
2 checks passed
@a-r-r-o-w deleted the cuda-stream-group-offloading branch on January 11, 2025 at 07:16
DN6 pushed a commit that referenced this pull request Feb 14, 2025
* update

* fix

* non_blocking; handle parameters and buffers

* update

* Group offloading with cuda stream prefetching (#10516)

* cuda stream prefetch

* remove breakpoints

* update

* copy model hook implementation from pab

* update; ~very workaround based implementation but it seems to work as expected; needs cleanup and rewrite

* more workarounds to make it actually work

* cleanup

* rewrite

* update

* make sure to sync current stream before overwriting with pinned params

not doing so will lead to erroneous computations on the GPU and cause bad results

* better check

* update

* remove hook implementation to not deal with merge conflict

* re-add hook changes

* why use more memory when less memory do trick

* why still use slightly more memory when less memory do trick

* optimise

* add model tests

* add pipeline tests

* update docs

* add layernorm and groupnorm

* address review comments

* improve tests; add docs

* improve docs

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* apply suggestions from code review

* update tests

* apply suggestions from review

* enable_group_offloading -> enable_group_offload for naming consistency

* raise errors if multiple offloading strategies used; add relevant tests

* handle .to() when group offload applied

* refactor some repeated code

* remove unintentional change from merge conflict

* handle .cuda()

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
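
Regarding the commit note above about syncing the current stream before overwriting pinned params, a minimal standalone sketch of the hazard (hypothetical tensors, not the hook code from this PR):

import torch

pinned = torch.empty(1024, pin_memory=True)      # staging buffer in pinned host memory
device_tensor = torch.empty(1024, device="cuda")

pinned.copy_(torch.randn(1024))
device_tensor.copy_(pinned, non_blocking=True)   # async H2D copy queued on the current stream

# Overwriting `pinned` right away could race with the in-flight copy and corrupt
# what lands on the GPU. Wait for the stream (or an event recorded after the copy) first.
torch.cuda.current_stream().synchronize()
pinned.copy_(torch.randn(1024))                  # now safe to reuse the pinned buffer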