`docs/source/en/optimization/memory.md`

Use the [`~DiffusionPipeline.reset_device_map`] method to reset the `device_map`. This is necessary if you want to use methods like `.to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`] on a pipeline that was device-mapped.
```py
pipeline.reset_device_map()
```
## VAE slicing

VAE slicing saves memory by splitting a large batch of inputs into single batches of data and processing them separately. This method works best when generating more than one image at a time.

For example, if you're generating 4 images at once, decoding would increase peak activation memory by 4x. VAE slicing reduces this by only decoding 1 image at a time instead of all 4 images at once.

Call [`~StableDiffusionPipeline.enable_vae_slicing`] to enable VAE slicing. You can expect a small increase in performance when decoding multi-image batches and no performance impact for single-image batches.
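
The snippet below is a minimal sketch of enabling it on a pipeline; the pipeline class, checkpoint, prompt, and batch size are placeholders rather than part of the original example.

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# decode the batch one image at a time instead of all four at once
pipeline.enable_vae_slicing()

images = pipeline(
    "an astronaut riding a horse on the moon",
    num_images_per_prompt=4,
).images
```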
## VAE tiling

VAE tiling saves memory by dividing an image into smaller overlapping tiles instead of processing the entire image at once. This also reduces peak memory usage because the GPU is only processing a tile at a time.

Call [`~StableDiffusionPipeline.enable_vae_tiling`] to enable VAE tiling. The generated image may have some tone variation from tile-to-tile because they're decoded separately, but there shouldn't be any obvious seams between the tiles. Tiling is disabled for resolutions lower than a pre-specified (but configurable) limit. For example, this limit is 512x512 for the VAE in [`StableDiffusionPipeline`].
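
As a rough sketch (the pipeline class, checkpoint, prompt, and resolution below are placeholders), tiling is typically paired with a resolution large enough that the limit above doesn't disable it.

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# decode the latents in overlapping tiles instead of in one pass
pipeline.enable_vae_tiling()

image = pipeline(
    "a wide panorama of a coastal city at dusk",
    height=2048,
    width=2048,
).images[0]
```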
## CPU offloading

CPU offloading selectively moves weights from the GPU to the CPU. When a component is required, it is transferred to the GPU, and when it isn't required, it is moved to the CPU. This method works on submodules rather than whole models. It saves memory by avoiding storing the entire model on the GPU.
CPU offloading dramatically reduces memory usage, but it is also **extremely slow** because submodules are passed back and forth multiple times between devices. It can often be impractical due to how slow it is.

> [!WARNING]
> Don't move the pipeline to CUDA before calling [`~DiffusionPipeline.enable_sequential_cpu_offload`], otherwise the amount of memory saved is only minimal (refer to this [issue](https://github.com/huggingface/diffusers/issues/1934) for more details). This is a stateful operation that installs hooks on the model.
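
A minimal sketch of enabling it is shown below (the pipeline class, checkpoint, and prompt are placeholders). In line with the warning above, the pipeline is not moved to CUDA first.

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# don't call .to("cuda") here; the offloading hooks manage device placement
pipeline.enable_sequential_cpu_offload()

image = pipeline("a photo of a dog in a field of flowers").images[0]
```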
## Model offloading

Model offloading moves entire models to the GPU instead of selectively moving *some* layers or model components. One of the main pipeline models, usually the text encoder, UNet, and VAE, is placed on the GPU while the other components are held on the CPU. Components like the UNet that run multiple times stay on the GPU until they're completely finished and no longer needed. This eliminates the communication overhead of [CPU offloading](#cpu-offloading) and makes model offloading a faster alternative. The tradeoff is memory savings won't be as large.
> [!WARNING]
> Keep in mind that if models are reused outside the pipeline after hooks have been installed (see [Removing Hooks](https://huggingface.co/docs/accelerate/en/package_reference/big_modeling#accelerate.hooks.remove_hook_from_module) for more details), you need to run the entire pipeline and models in the expected order to properly offload them. This is a stateful operation that installs hooks on the model.
Call [`~DiffusionPipeline.enable_model_cpu_offload`] to enable it on a pipeline.
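
For example, a sketch of model offloading could look like the following (the pipeline class, checkpoint, and prompt are placeholders):

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# whole models are moved to the GPU only while they're needed
pipeline.enable_model_cpu_offload()

image = pipeline("a watercolor painting of a fox").images[0]
```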
[`~DiffusionPipeline.enable_model_cpu_offload`] also helps when you're using the [`~StableDiffusionXLPipeline.encode_prompt`] method on its own to generate the text encoder's hidden state.
## Group offloading
Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html) or [torch.nn.Sequential](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)) to the CPU. It uses less memory than [model offloading](#model-offloading) and it is faster than [CPU offloading](#cpu-offloading) because it reduces communication overhead.
The `low_cpu_mem_usage` parameter can be set to `True` to reduce CPU memory usage when using streams during group offloading. It is best for `leaf_level` offloading and when CPU memory is bottlenecked. Memory is saved by creating pinned tensors on the fly instead of pre-pinning them. However, this may increase overall execution time.
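
The sketch below assumes the `apply_group_offloading` helper from `diffusers.hooks` and the parameters described above (`offload_type`, `use_stream`, `low_cpu_mem_usage`); the pipeline class, checkpoint, and prompt are placeholders, and the exact signature should be checked against your installed version.

```py
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.hooks import apply_group_offloading

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# offload groups of layers in each major component; a group is moved to the GPU
# only while it runs and is sent back to the CPU afterwards
for model in (pipeline.unet, pipeline.vae, pipeline.text_encoder, pipeline.text_encoder_2):
    apply_group_offloading(
        model,
        onload_device=onload_device,
        offload_device=offload_device,
        offload_type="leaf_level",
        use_stream=True,
        low_cpu_mem_usage=True,
    )

image = pipeline("a photo of a lighthouse in a storm").images[0]
```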
## Layerwise casting
Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.

> [!WARNING]
> Layerwise casting may not work with all models if the forward implementation contains internal typecasting of weights. The current implementation of layerwise casting assumes the forward pass is independent of the weight precision and the input datatypes are always specified in `compute_dtype` (see [here](https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/t5/modeling_t5.py#L294-L299) for an incompatible implementation).
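
As a sketch, assuming the `enable_layerwise_casting` method exposed on Diffusers model components and its `storage_dtype`/`compute_dtype` arguments (the pipeline class, checkpoint, and prompt are placeholders):

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

# store the UNet weights in fp8 and upcast them to bfloat16 only for computation
pipeline.unet.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)

image = pipeline("an ink drawing of a mountain village").images[0]
```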