Skip to content

Commit 7186bb4

Browse files
authored
Add enable_vae_tiling to AllegroPipeline, fix example (#10212)
1 parent 438bd60 commit 7186bb4

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
6060
>>> vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32)
6161
>>> pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda")
62+
>>> pipe.enable_vae_tiling()
6263
6364
>>> prompt = (
6465
... "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, "
@@ -636,6 +637,35 @@ def _prepare_rotary_positional_embeddings(
636637

637638
return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w)
638639

640+
def enable_vae_slicing(self):
641+
r"""
642+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
643+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
644+
"""
645+
self.vae.enable_slicing()
646+
647+
def disable_vae_slicing(self):
648+
r"""
649+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
650+
computing decoding in one step.
651+
"""
652+
self.vae.disable_slicing()
653+
654+
def enable_vae_tiling(self):
655+
r"""
656+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
657+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
658+
processing larger images.
659+
"""
660+
self.vae.enable_tiling()
661+
662+
def disable_vae_tiling(self):
663+
r"""
664+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
665+
computing decoding in one step.
666+
"""
667+
self.vae.disable_tiling()
668+
639669
@property
640670
def guidance_scale(self):
641671
return self._guidance_scale

0 commit comments

Comments
 (0)