
Pyramid Attention Broadcast rewrite + introduce hooks #9826


Merged
merged 5 commits into pyramid-attention-broadcast from pab-hook-impl on Nov 8, 2024

Conversation

@a-r-r-o-w (Member) commented Oct 31, 2024

Introduces accelerate-like hook support, which would be useful for:

  • Partial CPU Offloading
  • Cache Offloading in VAE
  • Pyramid Attention Broadcast ([core] Pyramid Attention Broadcast #9562) attention state caching
  • Split inference across non-channel dims
  • Intermediate layer output caching for conditions that don't change between inference steps
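
To give a rough idea of the interface this implies, here is a minimal sketch of an accelerate-style hook (illustrative only, not the final diffusers API):

from typing import Any, Dict, Tuple

import torch


class ModelHook:
    """Base class whose methods wrap a module's forward pass."""

    def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        # Called once when the hook is attached to the module.
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        # Called just before module.forward; may inspect or modify args/kwargs.
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        # Called just after module.forward; may cache or modify the output.
        return output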

Internal discussion: https://huggingface.slack.com/archives/C03HBN1C8CW/p1730123963082119

(Most of the discussion is a private conversation between Dhruv and me, so I'll summarize and share that soon.)

This is just a prototype/test of how the hook-based implementation would look for PAB. Taking a look at making it compatible with torch.compile soon.

cc @DN6 @yiyixuxu @sayakpaul

@a-r-r-o-w changed the title from "Pab hook impl" to "Pyramid Attention Broadcast rewrite + introduce hooks" on Oct 31, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member)

This is just a prototype/test of how the hook-based implementation would look for PAB. Taking a look at making it compatible with torch.compile soon.

See an internal thread here: https://huggingface.slack.com/archives/C01Q6JPP6NA/p1726820287833739

TLDR was to use CUDA streams for this but I would be positively surprised if hooks can work with torch.compile(). Cc: @SunMarc too.

@sayakpaul (Member) left a comment

Left some comments. LMK if they make sense.



# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
class ModelHook:
Member

If it's not copy-pasted verbatim, then let's add a note describing how it differs from accelerate.

Member Author

Copied verbatim, except for a single attribute (which I think I will remove soon)

Member

Then why can't we import directly?

Member Author

For now it would be okay to import, but the implementations may diverge soon since the design I have in mind might require diffusers to have some additional class attributes. It also doesn't make sense, IMO, to depend on another library for something core to diffusers (any change they make upstream will affect us, and since a base class like this is very simple, it might be better to have our own implementation). I'll use imports if we don't need anything additional at the moment.

Member

Oh then I am okay to have them copy-pasted with this exact note.

return module


class SequentialHook(ModelHook):
Member

Same as above.

Member Author

Copied completely

return module._diffusers_hook.post_forward(module, output)

def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
self.attention_cache = output
Member

We could consider making this configurable i.e., if we want to offload the cache to CPU/disk while executing a certain computation on the GPU. This can be done with CUDA streams. This way, we should be able to overlap computation and device transfer, saving time.

https://github.com/huggingface/transformers/blob/f83917ed89292072e57c4ba013d1cc6477538b11/src/transformers/cache_utils.py is a good reference for this.
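
For illustration, a rough sketch of that idea (assumes a CUDA device and pinned CPU memory; not part of this PR):

import torch

copy_stream = torch.cuda.Stream()

def offload_cache_to_cpu(cache: torch.Tensor) -> torch.Tensor:
    # Copy the cached attention output to pinned CPU memory on a side stream so
    # the transfer overlaps with compute running on the default stream.
    cpu_cache = torch.empty(cache.shape, dtype=cache.dtype, device="cpu", pin_memory=True)
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        cpu_cache.copy_(cache, non_blocking=True)
    cache.record_stream(copy_stream)  # keep the GPU buffer alive until the copy finishes
    return cpu_cache  # synchronize copy_stream before reading this buffer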

Comment on lines +94 to +96
spatial_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"],
temporal_attn_layer_identifiers: List[str] = ["temporal_transformer_blocks"],
cross_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"],
Member

Why can't we just specify "transformer_blocks" instead of ("blocks", "transformer_blocks")?

Member Author

We have some implementations that use "blocks" and some that use "transformer_blocks" because conventions weren't strictly followed. Maybe we can create an alias attribute called "transformer_blocks" for the places that use "blocks" and just default to that one.
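
(For illustration, a minimal sketch of that alias idea, with a made-up model class:)

import torch.nn as nn

class SomeVideoTransformer(nn.Module):
    def __init__(self, num_layers: int = 4):
        super().__init__()
        # Some existing models name this list "blocks" ...
        self.blocks = nn.ModuleList([nn.Identity() for _ in range(num_layers)])

    @property
    def transformer_blocks(self) -> nn.ModuleList:
        # ... so an alias lets generic utilities always look up "transformer_blocks".
        return self.blocks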

Member

Yeah better approach indeed!

Comment on lines +109 to +114
temporal_attn_skip_range (`int`, *optional*):
The attention block to execute after skipping intermediate temporal attention blocks. If set to the
value `N`, `N - 1` attention blocks are skipped and every N'th block is executed. Different models have
different tolerances to how much attention computation can be reused based on the differences between
successive blocks. So, this parameter must be adjusted per model after performing experimentation.
Setting this value to `4` is recommended for different models PAB has been experimented with.
Member

(doc nit): we could consider adding the additional piece of info as a "Tip". This way it would stand out better for a user.
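
(For reference, a tiny illustration of the skip-range semantics described in the docstring above; not the PR's code:)

def should_compute(iteration: int, skip_range: int) -> bool:
    # With skip_range N, attention is recomputed on every N'th call and the
    # cached output is reused for the N - 1 calls in between.
    return iteration % skip_range == 0

for step in range(8):
    print(step, "compute" if should_compute(step, skip_range=4) else "reuse cache")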

@a-r-r-o-w (Member Author) commented Nov 1, 2024

See an internal thread here: https://huggingface.slack.com/archives/C01Q6JPP6NA/p1726820287833739

TLDR was to use CUDA streams for this but I would be positively surprised if hooks can work with torch.compile(). Cc: @SunMarc too.

I think that is a separate issue, for offloading things that might not be compatible with torch.compile and might require us to use CUDA streams.

For now, torch.compile might be broken for PAB due to conditional statements in the forward pass that cause different branches to be followed depending on whether attention should be computed or not. This did not work until a few weeks/months back, but it seems to have been resolved (see pytorch/pytorch#124717, which contains good commit/cross-references) if you set torch._dynamo.config.guard_nn_modules=True (which has a performance penalty).
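
(For reference, that flag is just a dynamo config toggle; a minimal illustration with a made-up model, assuming a CUDA device:)

import torch
import torch._dynamo

torch._dynamo.config.guard_nn_modules = True  # needed when hooks flip control flow; has a perf penalty

model = torch.nn.Linear(128, 128).cuda()
compiled = torch.compile(model, fullgraph=True)
print(compiled(torch.randn(4, 128, device="cuda")).shape)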

@sayakpaul (Member)

My comment was not particularly about offloading + torch.compile but more about hooks + torch.compile, which I thought applied here.

But maybe not since you mentioned:

For now, torch.compile might be broken for PAB due to conditional statements in the forward pass that cause different branches to be followed depending on whether attention should be computed or not. This did not work until a few weeks/months back, but it seems to have been resolved (see pytorch/pytorch#124717, which contains good commit/cross-references) if you set torch._dynamo.config.guard_nn_modules=True (which has a performance penalty).

Can we try using torch.where() to get rid of the conditionals?

@a-r-r-o-w (Member Author) commented Nov 1, 2024

Hooks are perfectly fine with torch.compile if they are not mutating the state of the module itself or changing the control flow (PAB changes control flow, so this would not have worked until a few months ago).

In this case, I'm not sure torch.where would be applicable, since we want to determine IF a specific branch of the module will be executed or not. If the branch is executed, we use its outputs; if not, we return the cached values. For torch.where to work here, we would already need to have the result of the execution on hand to place in the conditional, and doing that computation would mean not realizing the speedup gains from PAB.

If I'm missing something, please LMK because it would be really nice if there was a clean way to make this torch.compile compatible, and not have the perf penalty that comes with guard_nn_modules.
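
(To make the control flow concrete, here is a rough sketch of the caching decision; not the PR's exact code:)

import torch

class AttentionCacheState:
    def __init__(self, skip_range: int = 4):
        self.skip_range = skip_range
        self.iteration = 0
        self.cache = None

    def maybe_attention(self, attn_module: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        should_compute = self.cache is None or self.iteration % self.skip_range == 0
        self.iteration += 1
        if should_compute:
            # Only here do we pay for the attention computation ...
            self.cache = attn_module(*args, **kwargs)
        # ... otherwise the previously cached output is broadcast/reused.
        return self.cache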

@sayakpaul (Member)

Hooks are perfectly fine with torch.compile, if they are not mutating state of the module itself or changing the control flow

Do you have a very minimal working example here?

If I'm missing something, please LMK because it would be really nice if there was a clean way to make this torch.compile compatible, and not have the perf penalty that comes with guard_nn_modules.

I can look into it. Would you be able to provide a minimal working example for me to work with?

@a-r-r-o-w (Member Author) commented Nov 1, 2024

Do you have a very minimal working example here?

When there are no modifications to the module/input state, or a flip on a conditional that would affect the forward pass (in which case guard_nn_modules=True would be required), there is nothing that violates the conditions required for the preferred inductor mode (max-autotune with cudagraphs), so it should be compatible (please correct me if I'm not understanding what you mean). This is a quick example I wrote to verify:

When we just clone the inputs or do something that does not affect module/input state (for example, a device transfer would affect module state, but caching inputs would not [imagine ResNet-like residual caching that can later be used in the post_forward hook]), I believe all cudagraphs conditions are met, so max-autotune does not emit a warning saying that it is skipping cudagraphs. When we modify the input state, it is expected that cudagraphs will be skipped since their requirements are not met. We can still use torch.compile without recompiles/warnings, just without the advantages of cudagraphs, by using max-autotune-no-cudagraphs.

Code
import os

os.environ["TORCH_LOGS"] = "+output_code,recompiles,graph_breaks"

import torch
import torch.nn as nn
import torch._inductor.config

# torch._inductor.config.triton.cudagraph_support_input_mutation = True

from diffusers.models.hooks import ModelHook, add_hook_to_module

class MyModule(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return self.linear(x)


class InputNonModifierAndOutputCacher(ModelHook):
    def __init__(self):
        super().__init__()
        self.args_register = None
        self.cache = None
    
    def pre_forward(self, module, *args, **kwargs):
        self.args_register = tuple(arg.clone() for arg in args if isinstance(arg, torch.Tensor))
        return args, kwargs
    
    def post_forward(self, module, output):
        self.cache = output
        return output
    
    def reset_state(self, module):
        self.args_register = None
        self.cache = None


class InputModifierAndOutputCacher(ModelHook):
    def __init__(self):
        super().__init__()
        self.cache = None
    
    def pre_forward(self, module, *args, **kwargs):
        for arg in args:
            if isinstance(arg, torch.Tensor):
                arg += 1.0
        return args, kwargs
    
    def post_forward(self, module, output):
        self.cache = output
        return output
    
    def reset_state(self, module):
        self.cache = None


batch_size = 4
dim = 128

input = torch.randn((batch_size, dim), device="cuda")

model = MyModule(dim).to("cuda")
print(model(input).shape)

compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True)
print(compiled_model(input).shape)


hook = InputNonModifierAndOutputCacher()
hook_model = add_hook_to_module(model, hook, append=True)
# Can just use model too since the modification is inplace
print(hook_model(input).shape)

hook_compiled_model = torch.compile(hook_model, mode="max-autotune", fullgraph=True)
print(hook_compiled_model(input).shape)

hook_compiled_model_no_cudagraphs = torch.compile(hook_model, mode="max-autotune-no-cudagraphs", fullgraph=True)
print(hook_compiled_model_no_cudagraphs(input).shape)


hook = InputModifierAndOutputCacher()
hook_model = add_hook_to_module(model, hook, append=True)
# Can just use model too since the modification is inplace
print(hook_model(input).shape)

hook_compiled_model = torch.compile(hook_model, mode="max-autotune", fullgraph=True)
print(hook_compiled_model(input).shape)

hook_compiled_model_no_cudagraphs = torch.compile(hook_model, mode="max-autotune-no-cudagraphs", fullgraph=True)
print(hook_compiled_model_no_cudagraphs(input).shape)
Output
torch.Size([4, 128])
/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:150: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] Output code: 
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] 
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] # AOT ID: ['0_forward']
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] from ctypes import c_void_p, c_long
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] import torch
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] import math
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] import random
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] import os
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] import tempfile
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] from math import inf, nan
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] from torch._inductor.utils import maybe_profile
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] 
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] from torch import device, empty_strided
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] from torch._inductor.async_compile import AsyncCompile
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] 
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] aten = torch.ops.aten
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] inductor_ops = torch.ops.inductor
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] _quantized = torch.ops._quantized
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] async_compile = AsyncCompile()
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] 
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] 
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] async_compile.wait(globals())
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] del async_compile
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] 
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] def call(args):
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     primals_1, primals_2, primals_3 = args
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     args.clear()
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     assert_size_stride(primals_1, (128, 128), (128, 1))
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     assert_size_stride(primals_2, (128, ), (1, ))
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     assert_size_stride(primals_3, (4, 128), (128, 1))
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     with torch.cuda._DeviceGuard(0):
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]         torch.cuda.set_device(0)
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]         buf0 = empty_strided_cuda((4, 128), (128, 1), torch.float32)
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]         # Source Nodes: [], Original ATen: []
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]         extern_kernels.bias_addmm(reinterpret_tensor(primals_2, (4, 128), (0, 1), 0), primals_3, reinterpret_tensor(primals_1, (128, 128), (1, 128), 0), alpha=1, beta=1, out=buf0)
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]         del primals_1
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]         del primals_2
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     return (buf0, primals_3, )
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] 
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] 
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     from torch._dynamo.testing import rand_strided
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     from torch._inductor.utils import print_performance
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     primals_1 = rand_strided((128, 128), (128, 1), device='cuda:0', dtype=torch.float32)
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     primals_2 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     primals_3 = rand_strided((4, 128), (128, 1), device='cuda:0', dtype=torch.float32)
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     fn = lambda: call([primals_1, primals_2, primals_3])
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     return print_performance(fn, times=times, repeat=repeat)
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] 
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] 
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] if __name__ == "__main__":
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code]     compiled_module_main('None', benchmark_compiled_module)
V1101 08:30:10.503000 139697423578944 torch/_inductor/graph.py:1683] [0/0] [__output_code] 
I1101 08:30:10.511000 139697423578944 torch/_inductor/graph.py:1717] [0/0] [__output_code] Output code written to: /tmp/torchinductor_aryan/ot/cotczyvjay3jbsjf3yzkqrom47ivrjfu2k62fxzyepe5quowoi3p.py
torch.Size([4, 128])
torch.Size([4, 128])
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] Output code: 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] # AOT ID: ['1_forward']
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from ctypes import c_void_p, c_long
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] import torch
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] import math
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] import random
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] import os
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] import tempfile
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from math import inf, nan
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from torch._inductor.utils import maybe_profile
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from torch import device, empty_strided
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from torch._inductor.async_compile import AsyncCompile
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] aten = torch.ops.aten
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] inductor_ops = torch.ops.inductor
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] _quantized = torch.ops._quantized
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] async_compile = AsyncCompile()
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] # kernel path: /tmp/torchinductor_aryan/u6/cu64ropwbt2pxrdxblburuobklxxwnraxxnbq6hv72gwbybopx6w.py
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] # Source Nodes: [clone], Original ATen: [aten.clone]
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] # clone => clone
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] triton_poi_fused_clone_0 = async_compile.triton('triton_', '''
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] import triton
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] import triton.language as tl
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from triton.compiler.compiler import AttrsDescriptor
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from torch._inductor.runtime import triton_helpers, triton_heuristics
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] @triton_heuristics.pointwise(
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     size_hints=[512], 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     filename=__file__,
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '620BE4B991CCC88079FD29703288616B3BF60DB5BAD26CDCE55A0789E6044C3A', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     min_elem_per_thread=0
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] )
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] @triton.jit
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     xnumel = 512
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     xoffset = tl.program_id(0) * XBLOCK
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     xindex = xoffset + tl.arange(0, XBLOCK)[:]
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     xmask = xindex < xnumel
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     x0 = xindex
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     tmp0 = tl.load(in_ptr0 + (x0), xmask)
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     tl.store(out_ptr0 + (x0), tmp0, xmask)
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] ''', device_str='cuda')
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] import triton
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] import triton.language as tl
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] async_compile.wait(globals())
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] del async_compile
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] def call(args):
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     primals_1, primals_2, primals_3 = args
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     args.clear()
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     assert_size_stride(primals_1, (128, 128), (128, 1))
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     assert_size_stride(primals_2, (128, ), (1, ))
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     assert_size_stride(primals_3, (4, 128), (128, 1))
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     with torch.cuda._DeviceGuard(0):
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]         torch.cuda.set_device(0)
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]         buf0 = empty_strided_cuda((4, 128), (128, 1), torch.float32)
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]         # Source Nodes: [clone], Original ATen: [aten.clone]
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]         stream0 = get_raw_stream(0)
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]         triton_poi_fused_clone_0.run(primals_3, buf0, 512, grid=grid(512), stream=stream0)
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]         buf1 = empty_strided_cuda((4, 128), (128, 1), torch.float32)
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]         # Source Nodes: [], Original ATen: []
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]         extern_kernels.bias_addmm(reinterpret_tensor(primals_2, (4, 128), (0, 1), 0), primals_3, reinterpret_tensor(primals_1, (128, 128), (1, 128), 0), alpha=1, beta=1, out=buf1)
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]         del primals_1
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]         del primals_2
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     return (buf1, buf0, primals_3, )
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     from torch._dynamo.testing import rand_strided
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     from torch._inductor.utils import print_performance
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     primals_1 = rand_strided((128, 128), (128, 1), device='cuda:0', dtype=torch.float32)
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     primals_2 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     primals_3 = rand_strided((4, 128), (128, 1), device='cuda:0', dtype=torch.float32)
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     fn = lambda: call([primals_1, primals_2, primals_3])
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     return print_performance(fn, times=times, repeat=repeat)
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] if __name__ == "__main__":
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code]     compiled_module_main('None', benchmark_compiled_module)
V1101 08:30:11.695000 139697423578944 torch/_inductor/graph.py:1683] [1/0] [__output_code] 
I1101 08:30:12.472000 139697423578944 torch/_inductor/graph.py:1717] [1/0] [__output_code] Output code written to: /tmp/torchinductor_aryan/gx/cgx2kybdq7js6gaehcl36tsf62detys5tx74kt5kglk3ywquhvm7.py
torch.Size([4, 128])
torch.Size([4, 128])
torch.Size([4, 128])
V1101 08:30:12.501000 139697423578944 torch/_dynamo/guards.py:2611] [1/1] [__recompiles] Recompiling function new_forward in /home/aryan/work/diffusers/src/diffusers/models/hooks.py:208
V1101 08:30:12.501000 139697423578944 torch/_dynamo/guards.py:2611] [1/1] [__recompiles]     triggered by the following guard failure(s):
V1101 08:30:12.501000 139697423578944 torch/_dynamo/guards.py:2611] [1/1] [__recompiles]     - ___check_type_id(L['module']._diffusers_hook, 94402397383072)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] Output code: 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] # AOT ID: ['2_forward']
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from ctypes import c_void_p, c_long
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] import torch
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] import math
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] import random
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] import os
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] import tempfile
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from math import inf, nan
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from torch._inductor.utils import maybe_profile
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from torch import device, empty_strided
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from torch._inductor.async_compile import AsyncCompile
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] aten = torch.ops.aten
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] inductor_ops = torch.ops.inductor
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] _quantized = torch.ops._quantized
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] async_compile = AsyncCompile()
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] # kernel path: /tmp/torchinductor_aryan/3t/c3twisufybu2tkvj7o6pwopgsy3ylum4ppxt7eqifhfrh2lvc5vz.py
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] # Source Nodes: [arg, clone], Original ATen: [aten.add, aten.clone]
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] # arg => add
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] # clone => clone
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] triton_poi_fused_add_clone_0 = async_compile.triton('triton_', '''
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] import triton
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] import triton.language as tl
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from triton.compiler.compiler import AttrsDescriptor
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from torch._inductor.runtime import triton_helpers, triton_heuristics
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] @triton_heuristics.pointwise(
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     size_hints=[512], 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     filename=__file__,
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, multi_processor_count=108), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_clone_0', 'mutated_arg_names': ['in_ptr0', 'out_ptr2'], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '620BE4B991CCC88079FD29703288616B3BF60DB5BAD26CDCE55A0789E6044C3A', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     min_elem_per_thread=0
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] )
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] @triton.jit
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] def triton_(in_ptr0, out_ptr0, out_ptr1, out_ptr2, xnumel, XBLOCK : tl.constexpr):
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     xnumel = 512
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     xoffset = tl.program_id(0) * XBLOCK
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     xindex = xoffset + tl.arange(0, XBLOCK)[:]
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     xmask = xindex < xnumel
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     x0 = xindex
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     tmp0 = tl.load(in_ptr0 + (x0), xmask)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     tmp1 = 1.0
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     tmp2 = tmp0 + tmp1
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     tl.store(out_ptr0 + (x0), tmp0, xmask)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     tl.store(out_ptr1 + (x0), tmp2, xmask)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     tl.store(out_ptr2 + (x0), tmp2, xmask)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] ''', device_str='cuda')
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] import triton
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] import triton.language as tl
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] async_compile.wait(globals())
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] del async_compile
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] def call(args):
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     primals_1, primals_2, primals_3 = args
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     args.clear()
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     assert_size_stride(primals_1, (128, 128), (128, 1))
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     assert_size_stride(primals_2, (128, ), (1, ))
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     assert_size_stride(primals_3, (4, 128), (128, 1))
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     with torch.cuda._DeviceGuard(0):
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]         torch.cuda.set_device(0)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]         buf0 = empty_strided_cuda((4, 128), (128, 1), torch.float32)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]         buf1 = empty_strided_cuda((4, 128), (128, 1), torch.float32)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]         # Source Nodes: [arg, clone], Original ATen: [aten.add, aten.clone]
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]         stream0 = get_raw_stream(0)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]         triton_poi_fused_add_clone_0.run(primals_3, buf0, buf1, primals_3, 512, grid=grid(512), stream=stream0)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]         del primals_3
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]         buf2 = empty_strided_cuda((4, 128), (128, 1), torch.float32)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]         # Source Nodes: [], Original ATen: []
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]         extern_kernels.bias_addmm(reinterpret_tensor(primals_2, (4, 128), (0, 1), 0), buf1, reinterpret_tensor(primals_1, (128, 128), (1, 128), 0), alpha=1, beta=1, out=buf2)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]         del primals_1
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]         del primals_2
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     return (buf2, buf0, buf1, )
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     from torch._dynamo.testing import rand_strided
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     from torch._inductor.utils import print_performance
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     primals_1 = rand_strided((128, 128), (128, 1), device='cuda:0', dtype=torch.float32)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     primals_2 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     primals_3 = rand_strided((4, 128), (128, 1), device='cuda:0', dtype=torch.float32)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     fn = lambda: call([primals_1, primals_2, primals_3])
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     return print_performance(fn, times=times, repeat=repeat)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] if __name__ == "__main__":
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code]     compiled_module_main('None', benchmark_compiled_module)
V1101 08:30:12.671000 139697423578944 torch/_inductor/graph.py:1683] [1/1] [__output_code] 
I1101 08:30:13.430000 139697423578944 torch/_inductor/graph.py:1717] [1/1] [__output_code] Output code written to: /tmp/torchinductor_aryan/fx/cfx44vzm7bfkn72526fudeugztee644jp5pfvtvuo2kii4idn4nv.py
skipping cudagraphs due to mutated inputs (1 instances)
torch.Size([4, 128])
torch.Size([4, 128])

I can look into it. Would you be able to provide a minimal working example for me to work with?

Thanks! Here is the code for running a benchmark comparison between normal inference and PAB:

Code
import gc

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from tabulate import tabulate


def reset_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()


def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


@torch.no_grad()
def test_cogvideox_5b():
    reset_memory()
    
    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    model_memory = torch.cuda.memory_allocated() / 1024**3

    prompt = (
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        "atmosphere of this unique musical performance."
    )

    # Warmup
    _ = pipe(
        prompt=prompt,
        guidance_scale=6,
        use_dynamic_cfg=True,
        num_inference_steps=2,
        generator=torch.Generator().manual_seed(31337),
    )

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    
    start.record()
    latent = pipe(
        prompt=prompt,
        guidance_scale=6,
        use_dynamic_cfg=True,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(31337),
        output_type="latent",
    ).frames
    end.record()
    torch.cuda.synchronize()
    
    normal_time = start.elapsed_time(end) / 1000
    normal_memory = torch.cuda.max_memory_reserved() / 1024**3
    
    video = pipe.decode_latents(latent)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, "outputs/cogvideox_5b.mp4", fps=8)

    pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=[100, 850])

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    
    start.record()
    latent = pipe(
        prompt=prompt,
        guidance_scale=6,
        use_dynamic_cfg=True,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(31337),
        output_type="latent",
    ).frames
    end.record()
    torch.cuda.synchronize()
    
    pab_time = start.elapsed_time(end) / 1000
    pab_memory = torch.cuda.max_memory_reserved() / 1024**3

    video = pipe.decode_latents(latent)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, "outputs/cogvideox_pab_5b.mp4", fps=8)

    return {
        "model_memory": model_memory,
        "normal_memory": normal_memory,
        "pab_memory": pab_memory,
        "normal_time": normal_time,
        "pab_time": pab_time,
    }

results = test_cogvideox_5b()
pretty_print_results(results)

This is the absolute minimal example:

Code
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=[100, 850])

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]

export_to_video(video, "output.mp4", fps=8)

@sayakpaul (Member)

When there are no modifications to the module/input state, or a flip on a conditional that would affect the forward pass

Makes sense. Thanks for explaining! We can revisit the requirement of using CUDAgraphs later depending on the level of performance penalty it incurs. It might well be that we won't use CUDAgraphs at all to support the design.

Here is the code for running benchmark comparison between normal vs PAB:

I guess we will need to call torch.compile() after PAB is applied here? Just confirming.
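
(If so, a sketch of that ordering might look like the following, with PAB enabled before compiling the transformer; illustrative only:)

import torch
import torch._dynamo
from diffusers import CogVideoXPipeline

torch._dynamo.config.guard_nn_modules = True  # per the discussion above; has a perf penalty

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

# Enable PAB first so its hooks are attached, then compile the transformer.
pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=[100, 850])
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")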

@a-r-r-o-w marked this pull request as ready for review November 5, 2024 11:42
@a-r-r-o-w requested a review from DN6 November 5, 2024 11:53
@a-r-r-o-w merged commit c52cf42 into pyramid-attention-broadcast Nov 8, 2024
2 checks passed
@a-r-r-o-w deleted the pab-hook-impl branch November 8, 2024 16:16
a-r-r-o-w added a commit that referenced this pull request Jan 27, 2025
* start pyramid attention broadcast

* add coauthor

Co-Authored-By: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>

* update

* make style

* update

* make style

* add docs

* add tests

* update

* Update docs/source/en/api/pipelines/cogvideox.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Pyramid Attention Broadcast rewrite + introduce hooks (#9826)

* rewrite implementation with hooks

* make style

* update

* merge pyramid-attention-rewrite-2

* make style

* remove changes from latte transformer

* revert docs changes

* better debug message

* add todos for future

* update tests

* make style

* cleanup

* fix

* improve log message; fix latte test

* refactor

* update

* update

* update

* revert changes to tests

* update docs

* update tests

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* fix flux test

* reorder

* refactor

* make fix-copies

* update docs

* fixes

* more fixes

* make style

* update tests

* update code example

* make fix-copies

* refactor based on reviews

* use maybe_free_model_hooks

* CacheMixin

* make style

* update

* add current_timestep property; update docs

* make fix-copies

* update

* improve tests

* try circular import fix

* apply suggestions from review

* address review comments

* Apply suggestions from code review

* refactor hook implementation

* add test suite for hooks

* PAB Refactor (#10667)

* update

* update

* update

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>

* update

* fix remove hook behaviour

---------

Co-authored-by: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: DN6 <dhruv.nair@gmail.com>