
Commit 8014848

kig (Ilmari Heikkinen), patrickvonplaten, and pcuenca authored
8k Stable Diffusion with tiled VAE (#1441)
* Tiled VAE for high-res text2img and img2img
* vae tiling, fix formatting
* enable_vae_tiling API and tests
* tiled vae docs, disable tiling for images that would have only one tile
* tiled vae tests, use channels_last memory format
* tiled vae tests, use smaller test image
* tiled vae tests, remove tiling test from fast tests
* up
* up
* make style
* Apply suggestions from code review
* Apply suggestions from code review
* Apply suggestions from code review
* make style
* improve naming
* finish
* apply suggestions
* Apply suggestions from code review

  Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* up

---------

Co-authored-by: Ilmari Heikkinen <ilmari@fhtr.org>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
1 parent 8dfff7c commit 8014848

File tree

7 files changed (+288 / -17 lines)


docs/source/en/api/pipelines/stable_diffusion/text2img.mdx

Lines changed: 3 additions & 1 deletion
@@ -36,4 +36,6 @@ Available Checkpoints are:
 - enable_vae_slicing
 - disable_vae_slicing
 - enable_xformers_memory_efficient_attention
-- disable_xformers_memory_efficient_attention
+- disable_xformers_memory_efficient_attention
+- enable_vae_tiling
+- disable_vae_tiling

docs/source/en/optimization/fp16.mdx

Lines changed: 28 additions & 0 deletions
@@ -133,6 +133,34 @@ images = pipe([prompt] * 32).images
 You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches.
 
 
+## Tiled VAE decode and encode for large images
+
+Tiled VAE processing makes it possible to work with large images on limited VRAM. For example, generating 4k images in 8GB of VRAM. Tiled VAE decoder splits the image into overlapping tiles, decodes the tiles, and blends the outputs to make the final image.
+
+You want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
+
+To use tiled VAE processing, invoke [`~StableDiffusionPipeline.enable_vae_tiling`] in your pipeline before inference. For example:
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    torch_dtype=torch.float16,
+)
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to("cuda")
+prompt = "a beautiful landscape photograph"
+pipe.enable_vae_tiling()
+pipe.enable_xformers_memory_efficient_attention()
+
+image = pipe([prompt], width=3840, height=2224, num_inference_steps=20).images[0]
+```
+
+The output image will have some tile-to-tile tone variation from the tiles having separate decoders, but you shouldn't see sharp seams between the tiles. The tiling is turned off for images that are 512x512 or smaller.
+
+
 <a name="sequential_offloading"></a>
 ## Offloading to CPU with accelerate for memory savings

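Note: the pipeline-level `enable_vae_tiling` shown above simply forwards to the VAE, so the same switch is available on `AutoencoderKL` directly. Below is a minimal sketch, not part of this commit: the checkpoint mirrors the docs example, and the random tensor stands in for a real image scaled to [-1, 1].

```python
# Minimal sketch (not from this commit): tiled encode/decode directly on the VAE.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_tiling()  # inputs larger than tile_sample_min_size are split into overlapping tiles

# stand-in for a large RGB image scaled to [-1, 1]
image = torch.randn(1, 3, 2048, 2048, dtype=torch.float16, device="cuda")
with torch.no_grad():
    posterior = vae.encode(image).latent_dist       # takes the tiled_encode path
    latents = posterior.sample()                    # 8x downsampled: 256x256 latent grid
    reconstruction = vae.decode(latents).sample     # takes the tiled_decode path, back to 2048x2048
```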
src/diffusers/models/autoencoder_kl.py

Lines changed: 149 additions & 16 deletions
@@ -107,10 +107,54 @@ def __init__(
 
         self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
         self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
+
+        self.use_slicing = False
+        self.use_tiling = False
+
+        # only relevant if vae tiling is enabled
+        self.tile_sample_min_size = self.config.sample_size
+        sample_size = (
+            self.config.sample_size[0]
+            if isinstance(self.config.sample_size, (list, tuple))
+            else self.config.sample_size
+        )
+        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1)))
+        self.tile_overlap_factor = 0.25
+
+    def enable_tiling(self, use_tiling: bool = True):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+        the processing of larger images.
+        """
+        self.use_tiling = use_tiling
+
+    def disable_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.enable_tiling(False)
+
+    def enable_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
+        decoding in one step.
+        """
         self.use_slicing = False
 
     @apply_forward_hook
     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+            return self.tiled_encode(x, return_dict=return_dict)
+
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
@@ -121,6 +165,9 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
         return AutoencoderKLOutput(latent_dist=posterior)
 
     def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+            return self.tiled_decode(z, return_dict=return_dict)
+
         z = self.post_quant_conv(z)
         dec = self.decoder(z)
 
@@ -129,22 +176,6 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
 
         return DecoderOutput(sample=dec)
 
-    def enable_slicing(self):
-        r"""
-        Enable sliced VAE decoding.
-
-        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
-        steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.use_slicing = True
-
-    def disable_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
-        decoding in one step.
-        """
-        self.use_slicing = False
-
     @apply_forward_hook
     def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
         if self.use_slicing and z.shape[0] > 1:
@@ -158,6 +189,108 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
 
         return DecoderOutput(sample=decoded)
 
+    def blend_v(self, a, b, blend_extent):
+        for y in range(blend_extent):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a, b, blend_extent):
+        for x in range(blend_extent):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
+    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.
+        Args:
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is:
+        different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        look of the output, but they should be much less noticeable.
+            x (`torch.FloatTensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple.
+        """
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into 512x512 tiles and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self.encoder(tile)
+                tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        moments = torch.cat(result_rows, dim=2)
+        posterior = DiagonalGaussianDistribution(moments)
+
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""Decode a batch of images using a tiled decoder.
+        Args:
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is:
+        different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        look of the output, but they should be much less noticeable.
+            z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to
+            `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+        """
+        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_sample_min_size - blend_extent
+
+        # Split z into overlapping 64x64 tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, z.shape[2], overlap_size):
+            row = []
+            for j in range(0, z.shape[3], overlap_size):
+                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+                tile = self.post_quant_conv(tile)
+                decoded = self.decoder(tile)
+                row.append(decoded)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        dec = torch.cat(result_rows, dim=2)
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.FloatTensor,

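The seam handling lives in the `blend_v`/`blend_h` helpers above: each tile is processed independently, then its top and left edges are linearly cross-faded with the neighbouring tile over `blend_extent` rows/columns before the tile is cropped to `row_limit`. With a typical Stable Diffusion VAE config (`sample_size=512`, four `block_out_channels`, `tile_overlap_factor=0.25`), that works out to 512-pixel sample tiles stepped every 384 pixels with a 16-latent blend on encode, and 64-latent tiles stepped every 48 latents with a 128-pixel blend on decode. The standalone sketch below is not part of the commit; it just reproduces the horizontal crossfade on toy tensors to show the ramp.

```python
# Standalone sketch (not from the commit): the linear crossfade that blend_h
# applies across the horizontal overlap between a tile and its left neighbour.
import torch


def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Overwrite b's first `blend_extent` columns with a ramp that fades from
    # a's last `blend_extent` columns (weight 1 -> 0) into b's own columns (weight 0 -> 1).
    for x in range(blend_extent):
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
    return b


left = torch.zeros(1, 1, 2, 8)   # left tile decoded to all zeros
right = torch.ones(1, 1, 2, 8)   # right tile decoded to all ones
blended = blend_h(left, right, blend_extent=4)
print(blended[0, 0, 0].tolist())  # [0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0]
```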
src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 16 additions & 0 deletions
@@ -183,6 +183,22 @@ def disable_vae_slicing(self):
         """
         self.vae.disable_slicing()
 
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,

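`AltDiffusionPipeline` gains the same two switches, which just forward to the VAE. A short usage sketch (the multilingual checkpoint name is an assumption, not taken from this diff):

```python
# Sketch only: enabling tiled VAE on AltDiffusionPipeline; checkpoint name assumed.
import torch
from diffusers import AltDiffusionPipeline

pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16).to("cuda")
pipe.enable_vae_tiling()   # forwards to pipe.vae.enable_tiling()

image = pipe("a fantasy landscape, highly detailed", width=1536, height=1536).images[0]
pipe.disable_vae_tiling()  # back to single-pass VAE decode for smaller renders
```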
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 16 additions & 0 deletions
@@ -186,6 +186,22 @@ def disable_vae_slicing(self):
         """
         self.vae.disable_slicing()
 
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,

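Because the new methods only toggle the VAE, they compose with the pipeline's other memory savers visible in the surrounding code, `enable_sequential_cpu_offload` and `enable_attention_slicing`. A hedged sketch (not from this commit) of a very low-VRAM, high-resolution setup:

```python
# Sketch (not from the commit): combining tiled VAE with the other memory savers
# exposed on StableDiffusionPipeline.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.enable_sequential_cpu_offload()  # weights live on CPU and are streamed to the GPU on demand
pipe.enable_attention_slicing(1)      # slice attention to cut UNet peak memory
pipe.enable_vae_tiling()              # decode the final image in overlapping tiles

image = pipe("a beautiful landscape photograph", width=2560, height=1440, num_inference_steps=20).images[0]
```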
tests/pipelines/audio_diffusion/test_audio_diffusion.py

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ def dummy_vqvae_and_unet(self):
         )
         return vqvae, unet
 
+    @slow
     def test_audio_diffusion(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         mel = Mel()

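The only functional change here is marking `test_audio_diffusion` as `@slow`, moving it out of the fast CI run. Roughly, the marker skips the test unless slow tests are explicitly enabled via the `RUN_SLOW` environment variable; a sketch of the idea (not the library's exact implementation):

```python
# Sketch of the idea behind the @slow marker (assumed behaviour, not diffusers' exact code):
# the decorated test is skipped unless RUN_SLOW is set in the environment.
import os
import unittest


def slow(test_case):
    run_slow = os.getenv("RUN_SLOW", "0") not in ("0", "", "false")
    return unittest.skipUnless(run_slow, "test is slow")(test_case)
```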
tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 75 additions & 0 deletions
@@ -422,6 +422,29 @@ def test_stable_diffusion_vae_slicing(self):
         # there is a small discrepancy at image borders vs. full batch decode
         assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3
 
+    def test_stable_diffusion_vae_tiling(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+
+        # make sure here that pndm scheduler skips prk
+        components["safety_checker"] = None
+        sd_pipe = StableDiffusionPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        # Test that tiled decode at 512x512 yields the same result as the non-tiled decode
+        generator = torch.Generator(device=device).manual_seed(0)
+        output_1 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+
+        # make sure tiled vae decode yields the same result
+        sd_pipe.enable_vae_tiling()
+        generator = torch.Generator(device=device).manual_seed(0)
+        output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+
+        assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 5e-1
+
     def test_stable_diffusion_negative_prompt(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -702,6 +725,58 @@ def test_stable_diffusion_vae_slicing(self):
         # There is a small discrepancy at the image borders vs. a fully batched version.
         assert np.abs(image_sliced - image).max() < 1e-2
 
+    def test_stable_diffusion_vae_tiling(self):
+        torch.cuda.reset_peak_memory_stats()
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+        pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
+        pipe.vae = pipe.vae.to(memory_format=torch.channels_last)
+
+        prompt = "a photograph of an astronaut riding a horse"
+
+        # enable vae tiling
+        pipe.enable_vae_tiling()
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast(torch_device):
+            output_chunked = pipe(
+                [prompt],
+                width=640,
+                height=640,
+                generator=generator,
+                guidance_scale=7.5,
+                num_inference_steps=2,
+                output_type="numpy",
+            )
+            image_chunked = output_chunked.images
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+        # make sure that less than 4 GB is allocated
+        assert mem_bytes < 4e9
+
+        # disable vae tiling
+        pipe.disable_vae_tiling()
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast(torch_device):
+            output = pipe(
+                [prompt],
+                width=640,
+                height=640,
+                generator=generator,
+                guidance_scale=7.5,
+                num_inference_steps=2,
+                output_type="numpy",
+            )
+            image = output.images
+
+        # make sure that more than 4 GB is allocated
+        mem_bytes = torch.cuda.max_memory_allocated()
+        assert mem_bytes > 4e9
+        assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2
+
     def test_stable_diffusion_fp16_vs_autocast(self):
         # this test makes sure that the original model with autocast
         # and the new model with fp16 yield the same result

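The slow test above asserts that a 640x640 render stays under 4 GB of peak VRAM with tiling enabled and exceeds 4 GB without it, while the two images agree to within 1e-2. The same measurement can be reproduced outside the test suite; a rough benchmark sketch (not part of the commit):

```python
# Rough benchmark sketch (not from the commit): peak VRAM with and without VAE tiling.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.enable_attention_slicing()


def peak_vram_gb(use_tiling: bool) -> float:
    # Toggle tiling, reset the CUDA peak-memory counter, run one generation, and report the peak.
    pipe.enable_vae_tiling() if use_tiling else pipe.disable_vae_tiling()
    torch.cuda.reset_peak_memory_stats()
    generator = torch.Generator(device="cuda").manual_seed(0)
    pipe("an astronaut riding a horse", width=1024, height=1024, num_inference_steps=2, generator=generator)
    return torch.cuda.max_memory_allocated() / 1e9


print("tiled:  ", peak_vram_gb(True), "GB")
print("untiled:", peak_vram_gb(False), "GB")
```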