Pyramid Attention Broadcast rewrite + introduce hooks #9826
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
See an internal thread here: https://huggingface.slack.com/archives/C01Q6JPP6NA/p1726820287833739. TLDR was to use CUDA streams for this, but I would be positively surprised if hooks can work with torch.compile.
Left some comments. LMK if they make sense.
# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
class ModelHook:
If it's not copy-pasted verbatim, then let's add a note describing how it differs from accelerate.
Copied verbatim, except for a single attribute (which I think I will remove soon)
Then why can't we import directly?
For now, it would be okay to import, but the implementations may diverge soon: the design I have in mind might require diffusers to have some additional class attributes, and it does not make sense to depend on another library for something core to diffusers IMO (any change they make upstream will affect us, and since the base class is very simple, it might be better to have our own implementation). Will use imports if we don't need anything additional at the moment.
Oh then I am okay to have them copy-pasted with this exact note.
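(For reference, the accelerate ModelHook interface being copied looks roughly like the sketch below. This is based on the linked hooks.py, not the diffusers copy itself, which, as noted above, may add extra attributes.)

```python
import torch


class ModelHook:
    """Sketch of the hook interface from accelerate's hooks.py (not verbatim diffusers code)."""

    # If True, the wrapped forward is executed under torch.no_grad().
    no_grad = False

    def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        # Called when the hook is attached to the module.
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        # Called right before forward; may rewrite the positional/keyword arguments.
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        # Called right after forward; may rewrite the output.
        return output

    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        # Called when the hook is removed from the module.
        return module
```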
return module
class SequentialHook(ModelHook):
Same as above.
Copied completely
return module._diffusers_hook.post_forward(module, output)
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
    self.attention_cache = output
We could consider making this configurable i.e., if we want to offload the cache to CPU/disk while executing a certain computation on the GPU. This can be done with CUDA streams. This way, we should be able to overlap computation and device transfer, saving time.
https://github.com/huggingface/transformers/blob/f83917ed89292072e57c4ba013d1cc6477538b11/src/transformers/cache_utils.py is a good reference for this.
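A rough sketch of what such stream-based offloading could look like; the names here (`attention_cache`, the side stream, the pinned buffer) are illustrative assumptions, not this PR's API:

```python
import torch


class OffloadedCacheSketch:
    """Hypothetical hook state: offload the cached attention output to pinned CPU
    memory on a side stream so the copy overlaps with ongoing GPU compute."""

    def __init__(self):
        self.attention_cache = None
        self._transfer_stream = torch.cuda.Stream()

    def post_forward(self, module, output: torch.Tensor) -> torch.Tensor:
        # Ensure `output` is fully materialized before the side stream reads it.
        self._transfer_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._transfer_stream):
            # Pinned CPU memory + non_blocking=True are needed for the copy to overlap.
            buffer = torch.empty(output.shape, dtype=output.dtype, device="cpu", pin_memory=True)
            buffer.copy_(output, non_blocking=True)
        # Keep `output`'s memory alive until the async copy on the side stream finishes.
        output.record_stream(self._transfer_stream)
        self.attention_cache = buffer
        return output

    def load_cache(self, device: torch.device) -> torch.Tensor:
        # Synchronize the transfer stream before moving the cache back to the GPU.
        self._transfer_stream.synchronize()
        return self.attention_cache.to(device, non_blocking=True)
```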
spatial_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"],
temporal_attn_layer_identifiers: List[str] = ["temporal_transformer_blocks"],
cross_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"],
Why can't we just specify "transformer_blocks" instead of ("blocks", "transformer_blocks")?
We have some implementations that use "blocks" and some that use "transformer_blocks" because some conventions were not strictly followed. Maybe we can create an alias attribute called "transformer_blocks" for the places that use "blocks" and just default to that single identifier.
Yeah better approach indeed!
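Something along these lines, presumably; the class below is a made-up stand-in just to show the alias idea:

```python
import torch.nn as nn


class SomeVideoTransformer(nn.Module):
    """Hypothetical model that historically exposes its layers as `blocks`."""

    def __init__(self, num_layers: int = 4, dim: int = 64):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    @property
    def transformer_blocks(self) -> nn.ModuleList:
        # Alias so PAB (and other callers) only ever need the "transformer_blocks" identifier.
        return self.blocks
```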
temporal_attn_skip_range (`int`, *optional*):
    The attention block to execute after skipping intermediate temporal attention blocks. If set to the
    value `N`, `N - 1` attention blocks are skipped and every Nth block is executed. Different models have
    different tolerances for how much attention computation can be reused based on the differences between
    successive blocks. So, this parameter must be adjusted per model after experimentation.
    Setting this value to `4` is recommended for the different models PAB has been experimented with.
(doc nit): we could consider adding the additional piece of info as a "Tip". This way it would stand out better for a user.
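To make the `N` / `N - 1` wording above concrete, here is a tiny hypothetical helper (not the PR's actual logic):

```python
def should_compute_attention(call_index: int, skip_range: int) -> bool:
    # With skip_range = N, attention is recomputed on every Nth call and the
    # cached output is broadcast for the N - 1 calls in between.
    return call_index % skip_range == 0


# skip_range=4 -> compute at calls 0, 4, ...; reuse the cache at calls 1, 2, 3, 5, 6, 7.
print([should_compute_attention(i, skip_range=4) for i in range(8)])
```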
I think that is a separate issue about offloading things that might not be compatible with torch.compile and might require us to use that. For now, torch.compile might be broken for PAB due to conditional statements in the forward pass that cause different branches to be followed based on whether attention should be computed or not. This did not work until a few weeks/months back, but it seems to have been resolved (see pytorch/pytorch#124717, which contains good commit/cross-references) if you set `torch._dynamo.config.guard_nn_modules = True`.
My comment was not particularly about offloading + torch.compile but more about hooks + torch.compile, which I thought applied here. But maybe not, since you mentioned:
Can we try using torch.where?
Hooks are perfectly fine with torch.compile if they are not mutating the state of the module itself or changing the control flow (PAB changes control flow, so this would not have worked until a few months ago). In this case, I'm not sure torch.where would be applicable, since we want to determine IF a specific branch of the module will be executed or not. If the branch is executed, we use its returns; if not, we return the cached values. For torch.where to work here, we would already need to have the result of the execution on hand to place in the conditional, but doing that computation would mean not realizing the speedup gains from PAB. If I'm missing something, please LMK, because it would be really nice if there was a clean way to make this torch.compile compatible without the perf penalty that comes with guard_nn_modules.
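A toy illustration of that point (hypothetical helper names): the PAB-style branch decides whether the block runs at all, whereas torch.where would need both operands computed up front.

```python
import torch


def pab_style(block, hidden_states: torch.Tensor, cache: dict, should_compute: bool) -> torch.Tensor:
    # Data-dependent control flow: the block is skipped entirely on cached steps.
    # This is where the speedup comes from, and also what torch.compile has to guard on.
    if should_compute:
        cache["attn"] = block(hidden_states)
    return cache["attn"]


def where_style(block, hidden_states: torch.Tensor, cache: dict, should_compute: bool) -> torch.Tensor:
    # torch.where only selects between tensors that already exist, so the block
    # has to run on every step and the PAB savings are lost.
    fresh = block(hidden_states)
    return torch.where(torch.tensor(should_compute), fresh, cache["attn"])
```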
Do you have a very minimal working example here?
I can look into it. Would you be able to provide a minimal working example for me to work with?
When there are no modifications to the module/input state, nor a flip on a conditional that would affect the forward pass, i.e. when we just clone the inputs or do something that does not affect module/input state (for example, device transfer would affect module state but caching inputs would not [imagine resnet-like residual caching that can later be used in a post_forward hook]), I believe all cudagraphs conditions are met, so:

import os

os.environ["TORCH_LOGS"] = "+output_code,recompiles,graph_breaks"

import torch
import torch.nn as nn
import torch._inductor.config
# torch._inductor.config.triton.cudagraph_support_input_mutation = True

from diffusers.models.hooks import ModelHook, add_hook_to_module


class MyModule(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return self.linear(x)


class InputNonModifierAndOutputCacher(ModelHook):
    def __init__(self):
        super().__init__()
        self.args_register = None
        self.cache = None

    def pre_forward(self, module, *args, **kwargs):
        # Materialize the clones in a tuple so they can actually be stored and reused
        self.args_register = tuple(arg.clone() for arg in args if isinstance(arg, torch.Tensor))
        return args, kwargs

    def post_forward(self, module, output):
        self.cache = output
        return output

    def reset_state(self, module):
        self.args_register = None
        self.cache = None


class InputModifierAndOutputCacher(ModelHook):
    def __init__(self):
        super().__init__()
        self.cache = None

    def pre_forward(self, module, *args, **kwargs):
        for arg in args:
            if isinstance(arg, torch.Tensor):
                arg += 1.0
        return args, kwargs

    def post_forward(self, module, output):
        self.cache = output
        return output

    def reset_state(self, module):
        self.cache = None


batch_size = 4
dim = 128

input = torch.randn((batch_size, dim), device="cuda")
model = MyModule(dim).to("cuda")
print(model(input).shape)

compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True)
print(compiled_model(input).shape)

hook = InputNonModifierAndOutputCacher()
hook_model = add_hook_to_module(model, hook, append=True)
# Can just use model too since the modification is inplace
print(hook_model(input).shape)

hook_compiled_model = torch.compile(hook_model, mode="max-autotune", fullgraph=True)
print(hook_compiled_model(input).shape)

hook_compiled_model_no_cudagraphs = torch.compile(hook_model, mode="max-autotune-no-cudagraphs", fullgraph=True)
print(hook_compiled_model_no_cudagraphs(input).shape)

hook = InputModifierAndOutputCacher()
hook_model = add_hook_to_module(model, hook, append=True)
# Can just use model too since the modification is inplace
print(hook_model(input).shape)

hook_compiled_model = torch.compile(hook_model, mode="max-autotune", fullgraph=True)
print(hook_compiled_model(input).shape)

hook_compiled_model_no_cudagraphs = torch.compile(hook_model, mode="max-autotune-no-cudagraphs", fullgraph=True)
print(hook_compiled_model_no_cudagraphs(input).shape)
Thanks! Here is the code for running a benchmark comparison between normal vs PAB:

import gc

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from tabulate import tabulate


def reset_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()


def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


@torch.no_grad()
def test_cogvideox_5b():
    reset_memory()

    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    model_memory = torch.cuda.memory_allocated() / 1024**3

    prompt = (
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        "atmosphere of this unique musical performance."
    )

    # Warmup
    _ = pipe(
        prompt=prompt,
        guidance_scale=6,
        use_dynamic_cfg=True,
        num_inference_steps=2,
        generator=torch.Generator().manual_seed(31337),
    )

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    latent = pipe(
        prompt=prompt,
        guidance_scale=6,
        use_dynamic_cfg=True,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(31337),
        output_type="latent",
    ).frames
    end.record()
    torch.cuda.synchronize()

    normal_time = start.elapsed_time(end) / 1000
    normal_memory = torch.cuda.max_memory_reserved() / 1024**3

    video = pipe.decode_latents(latent)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, "outputs/cogvideox_5b.mp4", fps=8)

    pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=[100, 850])

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    latent = pipe(
        prompt=prompt,
        guidance_scale=6,
        use_dynamic_cfg=True,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(31337),
        output_type="latent",
    ).frames
    end.record()
    torch.cuda.synchronize()

    pab_time = start.elapsed_time(end) / 1000
    pab_memory = torch.cuda.max_memory_reserved() / 1024**3

    video = pipe.decode_latents(latent)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, "outputs/cogvideox_pab_5b.mp4", fps=8)

    return {
        "model_memory": model_memory,
        "normal_memory": normal_memory,
        "pab_memory": pab_memory,
        "normal_time": normal_time,
        "pab_time": pab_time,
    }


results = test_cogvideox_5b()
pretty_print_results(results)

This is the absolute minimal example:

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=[100, 850])

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
Makes sense. Thanks for explaining! We can visit the requirement of using CUDAgraphs later depending on the level of performance penalty it incurs. It might well be that we won't use CUDAgraphs at all to support the design.
I guess we will need to call torch.compile() after PAB is applied here? Just confirming.
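If so, presumably something along these lines (a sketch reusing the arguments from the example above, not verified against this PR):

```python
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

# Attach the PAB hooks first...
pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=[100, 850])

# ...then compile the hooked denoiser (no-cudagraphs mode, as in the experiment above).
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
```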
* start pyramid attention broadcast
* add coauthor (Co-Authored-By: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>)
* update
* make style
* update
* make style
* add docs
* add tests
* update
* Update docs/source/en/api/pipelines/cogvideox.md (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* Update docs/source/en/api/pipelines/cogvideox.md (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* Pyramid Attention Broadcast rewrite + introduce hooks (#9826)
* rewrite implementation with hooks
* make style
* update
* merge pyramid-attention-rewrite-2
* make style
* remove changes from latte transformer
* revert docs changes
* better debug message
* add todos for future
* update tests
* make style
* cleanup
* fix
* improve log message; fix latte test
* refactor
* update
* update
* update
* revert changes to tests
* update docs
* update tests
* Apply suggestions from code review (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* update
* fix flux test
* reorder
* refactor
* make fix-copies
* update docs
* fixes
* more fixes
* make style
* update tests
* update code example
* make fix-copies
* refactor based on reviews
* use maybe_free_model_hooks
* CacheMixin
* make style
* update
* add current_timestep property; update docs
* make fix-copies
* update
* improve tests
* try circular import fix
* apply suggestions from review
* address review comments
* Apply suggestions from code review
* refactor hook implementation
* add test suite for hooks
* PAB Refactor (#10667): update; update; update (Co-authored-by: DN6 <dhruv.nair@gmail.com>)
* update
* fix remove hook behaviour

Co-authored-by: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: DN6 <dhruv.nair@gmail.com>
Introduces accelerate-like hook support, which would be useful for:
Internal discussion: https://huggingface.slack.com/archives/C03HBN1C8CW/p1730123963082119
(Most of the discussion is a private convo b/w Dhruv and me, so I'll summarize and share that soon)
This is just a prototype/test of how the hook-based implementation would look for PAB. Taking a look at making it compatible with torch.compile soon.
cc @DN6 @yiyixuxu @sayakpaul