8k Stable Diffusion with tiled VAE #1441
Conversation
The documentation is not available anymore as the PR was closed or merged.
This looks really nice already, nice job @kig! Could we add a test for this?
Thanks @patrickvonplaten! And yes, let me write a test for it. I guess a smoke test, one that verifies that 512x512 output matches non-tiled, and one where 1024x1024 output is mostly similar to non-tiled? Would that last one run ok on the test infra? :) The tiling can be done in at least a couple of ways that I tried:
The nicest way to do this would be to make the VAE attention use xformers and make the convolution layers run in a fixed amount of memory. That way it'd produce actually correct results... There's a PR for the xformers part (#1507), and I got the convolution-layers bit sort of working with the channels_last memory format, but they still use tens of GB of RAM -- that one's in https://github.com/kig/diffusers/blob/sd-vae-hires/src/diffusers/models/vae.py#L298 but it's quite messy.
[Going on a tangent.] VAE decoder: only the mid_block has an attention layer; the others are a mix of Conv2d, GroupNorm, Dropout and SiLU. I guess the tiling artifacts would come from the mid_block and the GroupNorms. The mid_block can be run on the full image; it's not very memory-intensive. The rest of the pipeline you'd have to tile. Fixing tiling artifacts coming from GroupNorms... I guess you could create a downsampled version of the image, compute the group norm parameters for it, and apply those instead of the dynamically computed per-tile group norm. In a nutshell: tile the image after the mid_block, replace the up_block GroupNorms with a fixed whole-image group norm, run the rest of the pipeline tiled, and put the tiles back together for the decoder output.
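A minimal sketch of that whole-image statistics idea, assuming plain PyTorch tensors (a hypothetical helper, not this PR's code):

```python
import torch
import torch.nn.functional as F

def whole_image_groupnorm_stats(x, num_groups, max_hw=512):
    # x: (N, C, H, W) full-resolution decoder features. Downsample to bound
    # memory, then compute the per-group mean/variance GroupNorm would use,
    # so every tile can later be normalized with the same statistics.
    scale = min(1.0, max_hw / max(x.shape[-2:]))
    small = F.interpolate(x, scale_factor=scale, mode="bilinear") if scale < 1.0 else x
    n, c, h, w = small.shape
    grouped = small.view(n, num_groups, c // num_groups, h, w)
    mean = grouped.mean(dim=(2, 3, 4), keepdim=True)
    var = grouped.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
    return mean, var  # each (N, G, 1, 1, 1), broadcastable over reshaped tiles
```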
Sorry for being so slow here - will try to look into it this week! |
@patil-suraj could you pick this up maybe? |
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Now that the UNet has various options for slicing and memory-efficient attention, it's not uncommon to generate results that are too big to fit through the VAE. I haven't reviewed this myself, but it sounds like it could be one way to address that problem.
I have tried to split the VAE decoder's upsampling part. I can confirm that the seams come from globally-aware operators, specifically the attention, most of which can be safely removed, except the group norms inside the ResNet blocks. How can I keep the mean and the variance of these group norms the same?
Note that we also have a "slice_vae" functionality: https://huggingface.co/docs/diffusers/v0.13.0/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.enable_vae_slicing
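For reference, a minimal usage sketch of that documented API (the model ID is just an example):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Decode latents one image at a time instead of the whole batch at once,
# trading a little speed for a smaller VAE memory peak.
pipe.enable_vae_slicing()
```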
Can reproduce the results from above when using the faster UniPC sampler (just 20 steps).
Changed some things to make naming more general. The PR looks clean and the tests are nice! Thanks a lot for the work here @kig, and sorry to be so extremely late!
Good to merge for me! @patil-suraj @williamberman @pcuenca can you take a final look?
@pkuliyi2015 Looking at the group_norm source, it shouldn't be too difficult to make a custom Python version. From what I understood digging into the PyTorch ATen C++ & CUDA implementations, the mean and standard deviation are computed inside the kernels (e.g. group_norm_kernel.cu calls into RowwiseMomentsCUDAKernel, which is using WelfordOps to compute the mean and standard deviation.) Making them use custom params would require changing PyTorch or adding a new custom op.
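In that spirit, a pure-Python stand-in (hypothetical, not part of this PR) that reproduces GroupNorm's normalize-and-affine step but uses caller-supplied statistics, e.g. the whole-image ones sketched earlier, sidestepping the ATen kernels entirely:

```python
import torch
import torch.nn as nn

class GroupNormFixedStats(nn.Module):
    # Wraps an existing nn.GroupNorm but normalizes with fixed mean/var
    # (each of shape (N, G, 1, 1, 1)) instead of per-input statistics.
    def __init__(self, gn: nn.GroupNorm, mean: torch.Tensor, var: torch.Tensor):
        super().__init__()
        self.num_groups, self.eps = gn.num_groups, gn.eps
        self.weight, self.bias = gn.weight, gn.bias
        self.register_buffer("mean", mean)
        self.register_buffer("var", var)

    def forward(self, x):
        n, c, h, w = x.shape
        g = x.view(n, self.num_groups, c // self.num_groups, h, w)
        g = (g - self.mean) / torch.sqrt(self.var + self.eps)
        x = g.view(n, c, h, w)
        return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
```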
Great work! I do agree that blending the tiles is a good compromise for reasonable results.
```
@@ -158,6 +189,108 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        return DecoderOutput(sample=decoded)

    def blend_v(self, a, b, blend_width):
```
Would avoiding the use of `width` in both these methods prevent confusion? Not convinced about `blend_size` though. Maybe something like `blend_extent`? If you think it doesn't help then it's ok.
Sounds good!
```
            b[:, :, :, x] = a[:, :, :, -blend_width + x] * (1 - x / blend_width) + b[:, :, :, x] * (x / blend_width)
        return b
```
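For context, the pair of blending helpers sketched in full, using the `blend_extent` naming suggested above (a paraphrase of the PR's approach, not a verbatim copy of its code):

```python
def blend_v(self, a, b, blend_extent):
    # Blend the bottom rows of tile `a` into the top rows of tile `b`
    # with a linear vertical ramp across the overlap region.
    for y in range(blend_extent):
        b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
    return b

def blend_h(self, a, b, blend_extent):
    # Same linear ramp, applied horizontally across the left edge of `b`.
    for x in range(blend_extent):
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
    return b
```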
```
    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
```
Just a question: what use case does tiled encoding fulfill? Also, do we need blending during the encoding phase? If we used this as an autoencoder for a large image, I would have thought that blending during decoding would be enough to avoid seams between the tiles.
I guess it's just to save memory? Encoding a large image is pretty memory-intensive, no?
Yeah, I meant that we don't need encoding for inference, and I don't think we'll train with very large images. It was just out of curiosity.
I did need it for something. IIRC img2img with large images.
Oh, you did img2img on huge images; cool, understood.
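A sketch of that img2img use case, assuming the `enable_vae_tiling` API from this PR's commit list is exposed on the img2img pipeline as well (file name and sizes are made up):

```python
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# tiled_encode handles the large init image, tiled_decode the large output.
pipe.enable_vae_tiling()

init_image = load_image("landscape_4096x2048.png")  # hypothetical large source image
image = pipe("a detailed matte painting", image=init_image, strength=0.5).images[0]
```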
```
@@ -96,6 +96,7 @@ def dummy_vqvae_and_unet(self):
        )
        return vqvae, unet

    @slow
```
Is this something we need?
Yes, this test takes a minute and this model has pretty much 0 usage, so disabling it for fast tests.
But the decorator is on `test_audio_diffusion`, seemed unrelated :)
Hello, I have completed a wild hack that achieves exactly what you may want! https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111
I have completed a tricky optimization on VAEs. After tons of tricks, I found my implementation to be nearly perfect in terms of no seams, except that you must use an fp32 VAE for 8K images, otherwise it reports NaNs. You also need a giant amount of CPU RAM (~85 GB for 8K images) to store intermediate tensors. My hack is implemented as an Automatic1111 WebUI extension, with recommended and user-changeable tiling sizes so people can fit their own GPU VRAM. See https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111
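For example, keeping the UNet in fp16 while upcasting just the VAE might look like this in diffusers (a generic sketch, not the extension's code):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# Run the VAE in fp32 to avoid the NaNs reported above at 8K resolutions;
# the UNet stays in fp16 to keep the memory savings. Depending on the
# pipeline version, the latents may also need casting to fp32 before decode.
pipe.vae = pipe.vae.to(torch.float32)
```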
The defect also occurs in other images when using the current tiled VAE, e.g. `vae.enable_tiling`.
Wow, this looks great, avoids the burnt-out spots! Do you have code or a PR for the TiledVAE?
Would be cool to see a PR here :-)
Hello, I'm the original author of this engineering implementation. I'm willing to make a PR, but I have little time, as I'm exploring the possibility of tiling.
Yes! I simply incorporated TiledVAE into the StableDiffusionUpscalePipeline and made some code modifications to adapt the model structure definition of the decoder portion in the diffusers library. Looking forward to your PR.
Yes, it is very challenging. However, the text information is injected via the cross-attention mechanism, where the softmax(QKᵀ/√d)·V computation is a linear process, so you can just compute a small picture's QK and then do a bilinear interpolation on the result matrix... This is very novel and tricky, eh?
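A rough sketch of that interpolation trick (a hypothetical helper assuming square feature maps and single-head attention; not the extension's actual code):

```python
import torch
import torch.nn.functional as F

def interpolated_cross_attention(q, k, v, downscale=4):
    # q: (B, H*W, C) queries from a square feature map; k, v: (B, L, C) text
    # keys/values. Compute softmax(Q K^T / sqrt(d)) V on a downsampled query
    # grid, then bilinearly upsample the result, exploiting that the
    # attention output varies smoothly across the image.
    b, hw, c = q.shape
    h = w = int(hw ** 0.5)                      # assumes a square feature map
    q2d = q.transpose(1, 2).reshape(b, c, h, w)
    q_lo = F.interpolate(q2d, scale_factor=1.0 / downscale, mode="bilinear")
    hl, wl = q_lo.shape[-2:]
    q_lo = q_lo.reshape(b, c, hl * wl).transpose(1, 2)
    attn = torch.softmax(q_lo @ k.transpose(1, 2) / c ** 0.5, dim=-1)
    out_lo = (attn @ v).transpose(1, 2).reshape(b, c, hl, wl)
    out = F.interpolate(out_lo, size=(h, w), mode="bilinear")
    return out.reshape(b, c, hw).transpose(1, 2)  # back to (B, H*W, C)
```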
* Tiled VAE for high-res text2img and img2img
* vae tiling, fix formatting
* enable_vae_tiling API and tests
* tiled vae docs, disable tiling for images that would have only one tile
* tiled vae tests, use channels_last memory format
* tiled vae tests, use smaller test image
* tiled vae tests, remove tiling test from fast tests
* up
* up
* make style
* Apply suggestions from code review
* Apply suggestions from code review
* Apply suggestions from code review
* make style
* improve naming
* finish
* apply suggestions
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up

---------

Co-authored-by: Ilmari Heikkinen <ilmari@fhtr.org>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This breaks the tiling (seamless) option, it seems, by not connecting/overlapping the outer edges over the opposite tile edges. Could this be fixed, to generate 8K seamless tiles?
Hey @GitHub1712, could you open a new issue here?
This PR makes it possible to generate 4k images in 8 GB of VRAM using a tiled VAE codec combined with `enable_xformers_memory_efficient_attention()`. With 24 GB of VRAM, you can generate 8k images.

The tiled codec splits the input into overlapping tiles, processes the tiles sequentially, and blends the output tiles together for the final output.

It's not perfect. Each tile is decoded independently, so uniform surfaces tend to have tile-to-tile tone variation. You also want to disable this for smaller images, which I'm doing at the pipeline level.

If the tiling artifacts are giving you grief, there's another way to do this: add xformers support to the VAE, switch the VAE to the channels_last memory format, and run the up_blocks on the CPU. But that's a different PR.
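For completeness, a usage sketch of what's described above, using the `enable_vae_tiling` API from this PR's commit list (model ID and prompt are examples; the 4k generation still assumes enough VRAM for the UNet pass):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()  # keep the UNet pass tractable
pipe.enable_vae_tiling()  # decode (and encode) the latents tile by tile

image = pipe(
    "an aerial photo of a coastline, highly detailed",
    height=4096, width=4096,
).images[0]
image.save("coastline_4k.png")
```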
Example output: