8k Stable Diffusion with tiled VAE #1441

Merged: 24 commits into huggingface:main on Mar 2, 2023
Conversation

@kig (Contributor) commented Nov 27, 2022

This PR makes it possible to generate 4k images in 8GB of VRAM using a tiled VAE codec combined with enable_xformers_memory_efficient_attention(). With 24GB of VRAM, you can generate 8k images.

The tiled codec splits the input into overlapping tiles, processes the tiles sequentially, and blends the output tiles together for the final output.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, use_auth_token=True
)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()
pipe.vae.enable_tiling()

prompt = "a beautiful landscape photo"
image = pipe(prompt, width=4096, height=2048, num_inference_steps=10).images[0]

image.save("4k_landscape.jpg")
```

It's not perfect. Each tile is decoded independently, so uniform surfaces tend to show tile-to-tile tone variation. You also want to disable tiling for smaller images, which I'm doing at the pipeline level.

If the tiling artifacts are giving you grief, there's another way to do this: add xformers support to the VAE, switch the VAE to the channels_last memory format, and run the up_blocks on the CPU.

But that's a different PR.

Example output:

[image: 4k_output]

@HuggingFaceDocBuilderDev commented Nov 27, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor) commented Dec 1, 2022

This looks really nice already, nice job @kig!
Just to better understand: are there multiple ways tiling can be implemented? Is there a reference paper / implementation for this?

Could we add a test for this?

@kig (Contributor, Author) commented Dec 1, 2022

Thanks @patrickvonplaten !
I don't think there's a reference paper / implementation for this; it's based on experimentation. I might be wrong though, and maybe there's a paper out there discussing this. Idea-wise it's similar to the GOBIG upscaler.

And yes, let me write a test for it. I guess a smoke test, one that verifies that 512x512 output matches the non-tiled output, and one where the 1024x1024 output is mostly similar to non-tiled? Would that last one run ok on the test infra? :)

The tiling can be done in at least a couple of ways that I tried:

  1. Split the latents into 64x64 non-overlapping tiles and decode them separately. Produces sharp seams between the tiles.
  2. Add an overlapping border around the tiles (say, pad each tile by a border with width 64, so your (1,4,64,64) tile becomes (1,4,192,192)) and decode them separately, using only the middle part for the output image. This is mostly seamless since the decoders see the neighboring tile latents, but you can still get seams in flat areas. And the per-tile processing time and memory use go way up.
  3. Decode 64x64 tiles, but overlap each tile with the tile on the left and the tile above it. Blend the overlap in the output tiles with a lerp. No visible seams in the output and no increase in per-tile memory use (all the tiles are 64x64), but the processing takes longer compared to non-overlapping tiles. This is what the PR code does: 64x64 tiles with a 48 px stride (see the blend sketch below).
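For concreteness, a minimal sketch of that lerp blend (it mirrors the blend helper the PR adds; NCHW layout, `blend` is the overlap width in pixels, and the names are otherwise illustrative):

```python
import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend: int) -> torch.Tensor:
    # Lerp the left `blend` columns of tile `b` with the right `blend` columns
    # of its left neighbor `a`; the weight ramps from 0 to 1 across the overlap.
    for x in range(blend):
        w = x / blend
        b[:, :, :, x] = a[:, :, :, -blend + x] * (1 - w) + b[:, :, :, x] * w
    return b
```

The vertical blend is the same thing along the height axis.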

The nicest way to do this would be to make the VAE attention use xformers and make the convolution layers run in a fixed amount of memory. That way it'd produce actually correct results...

There's a PR for the xformers part (#1507), and I got the convolution layers bit sort of working with the channels_last memory format, but they still use tens of GB of RAM -- that one's in https://github.com/kig/diffusers/blob/sd-vae-hires/src/diffusers/models/vae.py#L298 but it's quite messy.
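(For reference, the channels_last switch itself is a one-liner; it only changes the memory layout, not the results:)

```python
# Reorder VAE tensors to NHWC-style storage for more memory-efficient convs:
vae = vae.to(memory_format=torch.channels_last)
```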

@kig (Contributor, Author) commented Dec 3, 2022

[Going on a tangent.]
Profiling the memory use a bit further: running the non-tiled decoder with limited memory seems tricky. The decoder's intermediate images have channel counts ranging from 512 down to 128. The convolutions do run memory-efficiently with channels_last, but if your input image is 8 GB and your output image is 4 GB, you're going to need 12 GB.

The VAE Decoder forward() input is a 4-channel image and the output is a 3-channel image. First (pardon the infodump) conv_in goes 4->512, then mid_block 512->512, up_blocks [512 -> 2x res 512 -> 4x res 256 -> 8x res 128], conv_norm_out 128->128, conv_act 128->128, conv_out 128->3. Peak memory use happens in the last up_block, when it F.interpolates to 8x res at 256 channels and then conv2ds that to 8x res at 128 channels.
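As a quick sanity check of that peak, assuming fp16 (2 bytes per element) and a 4096x4096 output:

```python
# The last up_block interpolates to full resolution at 256 channels, then
# convolves that down to 128 channels; both tensors are live at once.
gib = lambda elems: elems * 2 / 2**30  # fp16 element count -> GiB

h = w = 4096
print(f"conv input  (256 ch): {gib(h * w * 256):.1f} GiB")  # ~8 GiB
print(f"conv output (128 ch): {gib(h * w * 128):.1f} GiB")  # ~4 GiB, so ~12 GiB together
```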

Only the mid_block has an attention layer; the others are a mix of Conv2d, GroupNorm, Dropout and SiLU. I'd guess the tiling artifacts come from the mid_block and the GroupNorms. The mid_block can be run on the full image, since it's not very memory-intensive; the rest of the pipeline you'd have to tile. As for fixing the tiling artifacts coming from the GroupNorms: you could create a downsampled version of the image, compute the group norm statistics on that, and apply those instead of the dynamically computed per-tile statistics.

In a nutshell: tile the image after the mid_block, replace the up_block GroupNorms with fixed whole-image group norms, run the rest of the pipeline tiled, and put the tiles back together for the decoder output.
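A rough sketch of that flow, heavily simplified (non-overlapping tiles, no blending, and without the fixed group-norm substitution; it assumes a diffusers-style Decoder with conv_in / mid_block / up_blocks / conv_norm_out / conv_act / conv_out, 8x total upsampling, and latent H/W divisible by the tile size):

```python
import torch

@torch.no_grad()
def decode_tiled_after_mid(decoder, z, tile=64, scale=8):
    h = decoder.conv_in(z)
    h = decoder.mid_block(h)  # attention sees the whole image; not memory-heavy
    _, _, H, W = h.shape
    out = None
    for i in range(0, H, tile):
        for j in range(0, W, tile):
            t = h[:, :, i:i + tile, j:j + tile]
            for up_block in decoder.up_blocks:  # the memory-hungry part, now per tile
                t = up_block(t)
            t = decoder.conv_out(decoder.conv_act(decoder.conv_norm_out(t)))
            if out is None:
                out = t.new_zeros(t.shape[0], t.shape[1], H * scale, W * scale)
            out[:, :, i * scale:(i + tile) * scale, j * scale:(j + tile) * scale] = t
    return out
```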

@kig kig marked this pull request as ready for review December 3, 2022 19:34
@patrickvonplaten (Contributor) commented Dec 5, 2022

Sorry for being so slow here - will try to look into it this week!

@github-actions github-actions bot added the stale Issues that haven't received updates label Dec 30, 2022
@huggingface huggingface deleted a comment from github-actions bot Jan 4, 2023
@patrickvonplaten (Contributor):

@patil-suraj could you pick this up maybe?

@github-actions (bot):

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@keturn (Contributor) commented Feb 26, 2023

Now that the UNet has various options for slicing and memory-efficient attention, it's not uncommon to generate results that are bigger than will fit through the VAE.

I haven't reviewed this myself but it sounds like it could be one way to address that problem.

@pkuliyi2015:

I have tried to split the VAE decoder's upsampling part. I can confirm that the seams come from globally-aware operators, specifically the attention, most of which can be safely removed, except for the group norms inside the ResNet blocks. How can I keep the mean and variance of these group norms the same across tiles?

@patrickvonplaten (Contributor) commented Mar 1, 2023

Note that we also have a "slice_vae" functionality: https://huggingface.co/docs/diffusers/v0.13.0/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.enable_vae_slicing
but that only makes sense for larger batch sizes, so I guess this PR is still very much relevant.

@patrickvonplaten (Contributor) commented Mar 1, 2023

Can reproduce the results from above when using the faster UniPC sampler (just 20 steps).
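(For anyone reproducing this: swapping in UniPC is a one-liner, assuming a diffusers version that ships UniPCMultistepScheduler:)

```python
from diffusers import UniPCMultistepScheduler

# Replace the pipeline's default scheduler and use fewer steps:
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
image = pipe(prompt, width=4096, height=2048, num_inference_steps=20).images[0]
```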

@patrickvonplaten (Contributor) left a review comment:

Changed some things to make the naming more general. The PR looks clean and the tests are nice! Thanks a lot for the work here @kig, and sorry to be so extremely late!

Good to merge for me! @patil-suraj @williamberman @pcuenca can you take a final look?

@kig (Contributor, Author) commented Mar 2, 2023

> I have tried to split the VAE decoder's upsampling part. I can confirm that the seams come from globally-aware operators, specifically the attention, most of which can be safely removed, except for the group norms inside the ResNet blocks. How can I keep the mean and variance of these group norms the same across tiles?

@pkuliyi2015 Looking at the group_norm source, it shouldn't be too difficult to make a custom Python version.

From what I understood digging into the PyTorch ATen C++ & CUDA implementations, the mean and standard deviation are computed inside the kernels (e.g. group_norm_kernel.cu calls into RowwiseMomentsCUDAKernel, which uses WelfordOps to compute the mean and standard deviation). Making them use custom params would require changing PyTorch or adding a new custom op.
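To make that concrete, here's a minimal pure-Python sketch of a group norm that takes externally supplied statistics (not a drop-in diffusers replacement; feeding it mean/variance computed once on a downsampled whole image would normalize every tile consistently):

```python
import torch

def group_stats(x, num_groups):
    # Per-sample, per-group mean and biased variance, as group_norm computes internally.
    n, c, h, w = x.shape
    g = x.reshape(n, num_groups, -1)
    return g.mean(dim=-1), g.var(dim=-1, unbiased=False)

def group_norm_fixed(x, num_groups, weight, bias, mean, var, eps=1e-6):
    # Same math as F.group_norm, but with fixed mean/var instead of per-input stats.
    n, c, h, w = x.shape
    g = x.reshape(n, num_groups, -1)
    g = (g - mean[..., None]) / torch.sqrt(var[..., None] + eps)
    return g.reshape(n, c, h, w) * weight.view(1, c, 1, 1) + bias.view(1, c, 1, 1)
```

With `mean, var = group_stats(x, G)` this reproduces F.group_norm(x, G, weight, bias); whether statistics from a downsampled image are close enough to the full-resolution ones is the open question.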

@pcuenca (Member) left a review comment:

Great work! I do agree that blending the tiles is a good compromise for reasonable results.

@@ -158,6 +189,108 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode

return DecoderOutput(sample=decoded)

def blend_v(self, a, b, blend_width):
(Member):

Would avoiding the use of width in both these methods prevent confusion? I'm not convinced about blend_size though. Maybe something like blend_extent? If you think it doesn't help, then it's ok.

(Contributor):

Sounds good!

b[:, :, :, x] = a[:, :, :, -blend_width + x] * (1 - x / blend_width) + b[:, :, :, x] * (x / blend_width)
return b

def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
(Member):

Just a question: what use case does tiled encoding fulfill? Also, do we need blending during the encoding phase? If we used this as an autoencoder on a large image, I would have thought that blending during decoding would be enough to avoid seams between the tiles.

(Contributor):

I guess it's just to save memory? Encoding a large image is pretty memory-intensive, no?

(Member):

Yeah, I meant that we don't need encoding for inference, and I don't think we'll train with very large images. It was just out of curiosity.

(Contributor, Author):

I did need it for something. IIRC img2img with large images.

(Member):

Oh, you did img2img on huge images; cool, understood.
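For the record, a sketch of that use case (standard diffusers img2img; the file name and sizes are illustrative):

```python
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.vae.enable_tiling()  # without this, encoding a large init image runs out of VRAM

init_image = Image.open("big_landscape.png")  # e.g. 4096x2048
image = pipe("a beautiful landscape photo", image=init_image, strength=0.5).images[0]
```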

@@ -96,6 +96,7 @@ def dummy_vqvae_and_unet(self):
)
return vqvae, unet

@slow
(Member):

Is this something we need?

(Contributor):

Yes, this test takes a minute and this model has pretty much zero usage, so I'm disabling it for the fast test suite.

(Member):

But the decorator is on test_audio_diffusion, seemed unrelated :)

patrickvonplaten and others added 3 commits March 2, 2023 17:25
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
@patrickvonplaten patrickvonplaten merged commit 8014848 into huggingface:main Mar 2, 2023
@pkuliyi2015:

Hello, I have completed a wild hack that achieves exactly what you may want: https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111
This is an upscaler WITHOUT image post-processing; everything happens in latent space. The repo includes a wild hack that breaks the VAE into task queues. Please refer to the repo for details!

@pkuliyi2015 commented Mar 3, 2023

> I have tried to split the VAE decoder's upsampling part. I can confirm that the seams come from globally-aware operators, specifically the attention, most of which can be safely removed, except for the group norms inside the ResNet blocks. How can I keep the mean and variance of these group norms the same across tiles?

> @pkuliyi2015 Looking at the group_norm source, it shouldn't be too difficult to make a custom Python version.
>
> From what I understood digging into the PyTorch ATen C++ & CUDA implementations, the mean and standard deviation are computed inside the kernels (e.g. group_norm_kernel.cu calls into RowwiseMomentsCUDAKernel, which uses WelfordOps to compute the mean and standard deviation). Making them use custom params would require changing PyTorch or adding a new custom op.

I have completed a tricky optimization of the VAE. After tons of tricks I found my implementation to be nearly perfect in terms of seams, except that you must use an fp32 VAE for 8K images, otherwise it reports NaNs. You also need a lot of CPU RAM (~85 GB for 8k images) to store intermediate tensors. My hack is implemented as an Automatic1111 WebUI extension, with recommended and user-changeable tile sizes so people can fit their own GPU VRAM. See https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111

@lizhengwei1992:

> I have tried to split the VAE decoder's upsampling part. I can confirm that the seams come from globally-aware operators, specifically the attention, most of which can be safely removed, except for the group norms inside the ResNet blocks. How can I keep the mean and variance of these group norms the same across tiles?

> @pkuliyi2015 Looking at the group_norm source, it shouldn't be too difficult to make a custom Python version.
>
> From what I understood digging into the PyTorch ATen C++ & CUDA implementations, the mean and standard deviation are computed inside the kernels (e.g. group_norm_kernel.cu calls into RowwiseMomentsCUDAKernel, which uses WelfordOps to compute the mean and standard deviation). Making them use custom params would require changing PyTorch or adding a new custom op.

The current vae.enable_tiling() doesn't work well according to my tests; the tiled VAE implementation by @pkuliyi2015 performs better!

e.g. StableDiffusionUpscalePipeline (with vae.enable_tiling() turned on) vs. StableDiffusionUpscaleTiledVAEPipeline (my own modification using the tiled VAE in the decoder);
input image resolution is 362x512, upscaled 4x to 1448x2048 with the stabilityai/stable-diffusion-x4-upscaler model.

vae.enable_tiling:

[image: vcg_VCG41157379862_RF_resize_sd_v2_enable_tiling]

tiled VAE:

[image: vcg_VCG41157379862_RF_resize_sd_v2_tiled_vae]

The defect also occurs in other images when using vae.enable_tiling, but the tiled VAE works well.

@kig (Contributor, Author) commented Mar 23, 2023

> I have tried to split the VAE decoder's upsampling part. I can confirm that the seams come from globally-aware operators, specifically the attention, most of which can be safely removed, except for the group norms inside the ResNet blocks. How can I keep the mean and variance of these group norms the same across tiles?

> @pkuliyi2015 Looking at the group_norm source, it shouldn't be too difficult to make a custom Python version.
> From what I understood digging into the PyTorch ATen C++ & CUDA implementations, the mean and standard deviation are computed inside the kernels (e.g. group_norm_kernel.cu calls into RowwiseMomentsCUDAKernel, which uses WelfordOps to compute the mean and standard deviation). Making them use custom params would require changing PyTorch or adding a new custom op.

> The current vae.enable_tiling() doesn't work well according to my tests; the tiled VAE implementation by @pkuliyi2015 performs better!
>
> e.g. StableDiffusionUpscalePipeline (with vae.enable_tiling() turned on) vs. StableDiffusionUpscaleTiledVAEPipeline (my own modification using the tiled VAE in the decoder); input image resolution is 362x512, upscaled 4x to 1448x2048 with the stabilityai/stable-diffusion-x4-upscaler model.
>
> vae.enable_tiling: [image: vcg_VCG41157379862_RF_resize_sd_v2_enable_tiling]
>
> tiled VAE: [image: vcg_VCG41157379862_RF_resize_sd_v2_tiled_vae]
>
> The defect also occurs in other images when using vae.enable_tiling, but the tiled VAE works well.

Wow, this looks great, it avoids the burnt-out spots! Do you have code or a PR for the TiledVAE?

@patrickvonplaten (Contributor):

Would be cool to see a PR here :-)

@pkuliyi2015:

Hello, I'm the original author of this engineering implementation. I'm willing to make a PR but I have little time, as I'm exploring the possibility of tiling the UNet (not folding it!). So I will do this PR when I have time. But you are welcome to make a PR to my repo as a new branch. My new workload is extremely heavy and I still need a lot of time for that.

@lizhengwei1992:

> Hello, I'm the original author of this engineering implementation. I'm willing to make a PR but I have little time, as I'm exploring the possibility of tiling the UNet (not folding it!). So I will do this PR when I have time. But you are welcome to make a PR to my repo as a new branch. My new workload is extremely heavy and I still need a lot of time for that.

Yes! I simply incorporated the tiled VAE into the StableDiffusionUpscalePipeline and made some code modifications to adapt it to the decoder's model structure definition in the diffusers library. Looking forward to your PR.
But tiled UNet may not be suitable for a text-driven upscale pipeline.
Different blocks in a large image may contain different semantic information, and if diffusion is performed on the split blocks, the text conditioning the UNet needs may differ per block. That means the text also has to correspond to each image block, which may be challenging to handle.

@pkuliyi2015:

Yes, it is very challenging.

However, the text information is injected via the cross-attention mechanism, where the (QKᵀ/√d)·V product is a linear process, so you could just compute a small picture's QKᵀ and then do a bilinear interpolation on the resulting matrix... Very novel and tricky, eh?

mengfei25 pushed a commit to mengfei25/diffusers that referenced this pull request Mar 27, 2023
* Tiled VAE for high-res text2img and img2img

* vae tiling, fix formatting

* enable_vae_tiling API and tests

* tiled vae docs, disable tiling for images that would have only one tile

* tiled vae tests, use channels_last memory format

* tiled vae tests, use smaller test image

* tiled vae tests, remove tiling test from fast tests

* up

* up

* make style

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

* make style

* improve naming

* finish

* apply suggestions

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up

---------

Co-authored-by: Ilmari Heikkinen <ilmari@fhtr.org>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
@GitHub1712 commented May 14, 2023

This breaks the tiling (seamless texture) option, seemingly because the outer edges are not connected/overlapped with the opposite tile edges. Could this be fixed, to generate 8K seamless tiles?

@patrickvonplaten (Contributor):

Hey @GitHub1712,

Could you open a new issue here?

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024