
Commit 872c91e

add docs.

1 parent dc90b06 commit 872c91e

4 files changed (+106, -2 lines)

docs/source/en/api/quantization.md

Lines changed: 3 additions & 0 deletions

```diff
@@ -23,6 +23,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) guide.
 
 </Tip>
 
+## PipelineQuantizationConfig
+
+[[autodoc]] PipelineQuantizationConfig
 
 ## BitsAndBytesConfig
```
docs/source/en/quantization/overview.md

Lines changed: 59 additions & 0 deletions

@@ -39,3 +39,62 @@ Diffusers currently supports the following quantization methods.

- [Quanto](./quanto.md)

[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
## Pipeline-level quantization

Diffusers can initialize pipelines directly from checkpoints that already contain quantized models ([example](https://huggingface.co/hf-internal-testing/flux.1-dev-nf4-pkg)). However, you may want to apply quantization on the fly when initializing a pipeline from a pretrained, non-quantized checkpoint. You can do this with [`PipelineQuantizationConfig`].

Start by defining a `PipelineQuantizationConfig`:
```py
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers.quantization_config import QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import BitsAndBytesConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        "transformer": QuantoConfig(weights_dtype="int8"),
        "text_encoder_2": BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
        ),
    }
)
```
Then pass it to [`~DiffusionPipeline.from_pretrained`] and run inference:
```py
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe("photo of a cute dog").images[0]
```
This method gives you granular control over the quantization of each model-level component in the pipeline, and it lets you use different quantization backends for different components. In the example above, you used a combination of Quanto and bitsandbytes.
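To sanity-check that both backends were applied, you can compare the memory footprint of each quantized component. Below is a minimal sketch, assuming the loaded models expose `get_memory_footprint` (transformers models do, and recent diffusers versions add it to `ModelMixin`):

```py
# Sanity check (a sketch, not part of the commit): compare per-component memory
# after quantization. Assumes `pipe` was created with the config above and that
# both components expose `get_memory_footprint`.
for name in ["transformer", "text_encoder_2"]:
    component = getattr(pipe, name)
    print(f"{name}: {component.get_memory_footprint() / 1e9:.2f} GB")
```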
The other method is simpler to use but less flexible. Start by defining a `PipelineQuantizationConfig`, this time in a different way:
```py
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder_2"],
)
```
This `pipeline_quant_config` can now be passed to [`~DiffusionPipeline.from_pretrained`] just like in the example above:
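```py
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe("photo of a cute dog").images[0]
```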
In this case, `quant_kwargs` is used to initialize the quantization configuration class of the chosen `quant_backend`, and `components_to_quantize` denotes the components that will be quantized. For most pipelines, you will want to keep `transformer` in the list, as it is usually the most compute- and memory-intensive component.
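Conceptually, this simpler form expands into a per-component mapping like the first method. The following is a rough sketch of the idea, not diffusers' actual implementation (in practice the library resolves the backend's config class from `diffusers` or `transformers` per component):

```py
# Illustrative only: roughly how quant_backend + quant_kwargs relate to quant_mapping.
# The real implementation picks the appropriate config class per component.
import torch
from transformers import BitsAndBytesConfig

quant_kwargs = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
}
equivalent_mapping = {
    name: BitsAndBytesConfig(**quant_kwargs)
    for name in ["transformer", "text_encoder_2"]
}
```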

src/diffusers/quantizers/__init__.py

Lines changed: 43 additions & 1 deletion

@@ -33,7 +33,49 @@ class TransformersQuantConfigMixin:

The placeholder docstring (`"""TODO"""`) on `PipelineQuantizationConfig` is replaced with full documentation:

```py
class PipelineQuantizationConfig:
    """
    Configuration class to be used when applying quantization on the fly to [`~DiffusionPipeline.from_pretrained`].

    Args:
        quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
            is available to both `diffusers` and `transformers`.
        quant_kwargs (`dict`): Params to initialize the quantization backend class.
        components_to_quantize (`list`): Components of a pipeline to be quantized.
        quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
            components. When using this argument, users are not expected to provide `quant_backend`,
            `quant_kwargs`, and `components_to_quantize`.

    Examples:

    When using with `quant_backend`:

    >>> import torch
    >>> from diffusers import DiffusionPipeline
    >>> from diffusers.quantizers import PipelineQuantizationConfig

    >>> pipeline_quant_config = PipelineQuantizationConfig(
    ...     quant_backend="bitsandbytes_4bit",
    ...     quant_kwargs={
    ...         "load_in_4bit": True,
    ...         "bnb_4bit_quant_type": "nf4",
    ...         "bnb_4bit_compute_dtype": torch.bfloat16,
    ...     },
    ...     components_to_quantize=["transformer", "text_encoder_2"],
    ... )

    >>> pipe = DiffusionPipeline.from_pretrained(
    ...     "black-forest-labs/FLUX.1-dev",
    ...     quantization_config=pipeline_quant_config,
    ...     torch_dtype=torch.bfloat16,
    ... ).to("cuda")

    >>> image = pipe("photo of a cute dog").images[0]

    When using with `quant_mapping`:

    >>> import torch
    >>> from diffusers import DiffusionPipeline
    >>> from diffusers.quantizers.quantization_config import QuantoConfig
    >>> from diffusers.quantizers import PipelineQuantizationConfig
    >>> from transformers import BitsAndBytesConfig

    >>> pipeline_quant_config = PipelineQuantizationConfig(
    ...     quant_mapping={
    ...         "transformer": QuantoConfig(weights_dtype="int8"),
    ...         "text_encoder_2": BitsAndBytesConfig(
    ...             load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
    ...         ),
    ...     }
    ... )

    >>> pipe = DiffusionPipeline.from_pretrained(
    ...     "black-forest-labs/FLUX.1-dev",
    ...     quantization_config=pipeline_quant_config,
    ...     torch_dtype=torch.bfloat16,
    ... ).to("cuda")

    >>> image = pipe("photo of a cute dog").images[0]
    """

    def __init__(
        self,
```

src/diffusers/quantizers/quantization_config.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -75,7 +75,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
     Args:
         config_dict (`Dict[str, Any]`):
             Dictionary that will be used to instantiate the configuration object.
-        return_unused_kwargs (`bool`,*optional*, defaults to `False`):
+        return_unused_kwargs (`bool`, *optional*, defaults to `False`):
             Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
             `PreTrainedModel`.
         kwargs (`Dict[str, Any]`):
```
