Commit f5edaa7

DN6 and sayakpaul authored
[Quantization] Add Quanto backend (#10756)
* update (repeated across many intermediate commits)
* Update docs/source/en/quantization/quanto.md (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* Update src/diffusers/quantizers/quanto/utils.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 9a1810f commit f5edaa7

21 files changed, +997 -3 lines changed

.github/workflows/nightly_tests.yml

Lines changed: 2 additions & 0 deletions
@@ -418,6 +418,8 @@ jobs:
           test_location: "gguf"
         - backend: "torchao"
           test_location: "torchao"
+        - backend: "optimum_quanto"
+          test_location: "quanto"
     runs-on:
       group: aws-g6e-xlarge-plus
     container:

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -173,6 +173,8 @@
       title: gguf
     - local: quantization/torchao
       title: torchao
+    - local: quantization/quanto
+      title: quanto
     title: Quantization Methods
 - sections:
   - local: optimization/fp16

docs/source/en/api/quantization.md

Lines changed: 5 additions & 0 deletions
@@ -31,6 +31,11 @@ Learn how to quantize models in the [Quantization](../quantization/overview) guide
 ## GGUFQuantizationConfig
 
 [[autodoc]] GGUFQuantizationConfig
+
+## QuantoConfig
+
+[[autodoc]] QuantoConfig
+
 ## TorchAoConfig
 
 [[autodoc]] TorchAoConfig

docs/source/en/quantization/overview.md

Lines changed: 1 addition & 0 deletions
@@ -36,5 +36,6 @@ Diffusers currently supports the following quantization methods.
 - [BitsandBytes](./bitsandbytes)
 - [TorchAO](./torchao)
 - [GGUF](./gguf)
+- [Quanto](./quanto)
 
 [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.

docs/source/en/quantization/quanto.md

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@ (new file)
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Quanto

[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind:

- All features are available in eager mode (works with non-traceable models)
- Supports quantization-aware training
- Quantized models are compatible with `torch.compile`
- Quantized models are device agnostic (e.g. CUDA, XPU, MPS, CPU)

In order to use the Quanto backend, you will first need to install `optimum-quanto>=0.2.6` and `accelerate`:

```shell
pip install optimum-quanto accelerate
```

Now you can quantize a model by passing a `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library does allow quantizing `nn.Conv2d` and `nn.LayerNorm` modules, Diffusers currently only supports quantizing the weights in the `nn.Linear` layers of a model. The following snippet demonstrates how to apply `float8` quantization with Quanto.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```

## Skipping Quantization on specific modules

It is possible to skip applying quantization on certain modules using the `modules_to_not_convert` argument in the `QuantoConfig`. Please ensure that the modules passed to this argument match the keys of the modules in the `state_dict`; one way to inspect these names is sketched after the example below.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"])
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```
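If you are unsure which names are valid for `modules_to_not_convert`, a quick way to list candidates is to enumerate the model's `nn.Linear` modules. This helper is illustrative only (it is not part of the commit):

```python
import torch
from diffusers import FluxTransformer2DModel

# Illustrative only: list the nn.Linear module names of the transformer so you can
# see which keys (for example "proj_out") modules_to_not_convert may refer to.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
linear_names = [
    name for name, module in transformer.named_modules() if isinstance(module, torch.nn.Linear)
]
print(linear_names[:10])
```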

## Using `from_single_file` with the Quanto Backend

`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
```

## Saving Quantized models

Diffusers supports serializing Quanto models using the `~ModelMixin.save_pretrained` method.

The serialization and loading requirements are different for models quantized directly with the Quanto library and models quantized with Diffusers using Quanto as the backend. It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
# save the quantized model so it can be reused
transformer.save_pretrained("<your quantized model save path>")

# you can reload your quantized model with
model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
```

## Using `torch.compile` with Quanto

Currently the Quanto backend supports `torch.compile` for the following quantization types:

- `int8` weights

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="int8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
image = pipe("A cat holding a sign that says hello").images[0]
image.save("flux-quanto-compile.png")
```

## Supported Quantization Types

### Weights

- float8
- int8
- int4
- int2
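The lower-bit dtypes follow the same configuration pattern. As a minimal sketch (mirroring the model and subfolder used in the examples above), `int4` weight-only quantization looks like this:

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

# Minimal sketch: int4 weight-only quantization via the Quanto backend.
# Per the compile section above, torch.compile support currently covers int8 weights only.
quantization_config = QuantoConfig(weights_dtype="int4")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```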

setup.py

Lines changed: 9 additions & 0 deletions
@@ -128,6 +128,10 @@
     "GitPython<3.1.19",
     "scipy",
     "onnx",
+    "optimum_quanto>=0.2.6",
+    "gguf>=0.10.0",
+    "torchao>=0.7.0",
+    "bitsandbytes>=0.43.3",
     "regex!=2019.12.17",
     "requests",
     "tensorboard",
@@ -235,6 +239,11 @@ def run(self):
 )
 extras["torch"] = deps_list("torch", "accelerate")
 
+extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate")
+extras["gguf"] = deps_list("gguf", "accelerate")
+extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
+extras["torchao"] = deps_list("torchao", "accelerate")
+
 if os.name == "nt":  # windows
     extras["flax"] = []  # jax is not supported on windows
 else:

src/diffusers/__init__.py

Lines changed: 92 additions & 2 deletions
@@ -2,6 +2,15 @@
 
 from typing import TYPE_CHECKING
 
+from diffusers.quantizers import quantization_config
+from diffusers.utils import dummy_gguf_objects
+from diffusers.utils.import_utils import (
+    is_bitsandbytes_available,
+    is_gguf_available,
+    is_optimum_quanto_version,
+    is_torchao_available,
+)
+
 from .utils import (
     DIFFUSERS_SLOW_IMPORT,
     OptionalDependencyNotAvailable,
@@ -11,6 +20,7 @@
     is_librosa_available,
     is_note_seq_available,
     is_onnx_available,
+    is_optimum_quanto_available,
     is_scipy_available,
     is_sentencepiece_available,
     is_torch_available,
@@ -32,7 +42,7 @@
     "loaders": ["FromOriginalModelMixin"],
     "models": [],
     "pipelines": [],
-    "quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
+    "quantizers.quantization_config": [],
     "schedulers": [],
     "utils": [
         "OptionalDependencyNotAvailable",
@@ -54,6 +64,55 @@
     ],
 }
 
+try:
+    if not is_bitsandbytes_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_bitsandbytes_objects
+
+    _import_structure["utils.dummy_bitsandbytes_objects"] = [
+        name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
+
+try:
+    if not is_gguf_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_gguf_objects
+
+    _import_structure["utils.dummy_gguf_objects"] = [
+        name for name in dir(dummy_gguf_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
+
+try:
+    if not is_torchao_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_torchao_objects
+
+    _import_structure["utils.dummy_torchao_objects"] = [
+        name for name in dir(dummy_torchao_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("TorchAoConfig")
+
+try:
+    if not is_optimum_quanto_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_optimum_quanto_objects
+
+    _import_structure["utils.dummy_optimum_quanto_objects"] = [
+        name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("QuantoConfig")
+
+
 try:
     if not is_onnx_available():
         raise OptionalDependencyNotAvailable()
@@ -599,7 +658,38 @@
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .configuration_utils import ConfigMixin
-    from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig
+
+    try:
+        if not is_bitsandbytes_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_bitsandbytes_objects import *
+    else:
+        from .quantizers.quantization_config import BitsAndBytesConfig
+
+    try:
+        if not is_gguf_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_gguf_objects import *
+    else:
+        from .quantizers.quantization_config import GGUFQuantizationConfig
+
+    try:
+        if not is_torchao_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_torchao_objects import *
+    else:
+        from .quantizers.quantization_config import TorchAoConfig
+
+    try:
+        if not is_optimum_quanto_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_optimum_quanto_objects import *
+    else:
+        from .quantizers.quantization_config import QuantoConfig
 
     try:
         if not is_onnx_available():
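In practical terms (a hedged illustration rather than part of the diff), the effect of the hunks above is that `QuantoConfig` is only exposed at the package root when `optimum-quanto` is installed; otherwise a dummy object that raises an informative dependency error takes its place:

```python
# Assumes optimum-quanto and accelerate are installed; without them, using
# QuantoConfig would surface the dummy object's dependency error instead.
from diffusers import QuantoConfig

config = QuantoConfig(weights_dtype="float8")
print(type(config).__name__)  # "QuantoConfig"
```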

src/diffusers/dependency_versions_table.py

Lines changed: 4 additions & 0 deletions
@@ -35,6 +35,10 @@
     "GitPython": "GitPython<3.1.19",
     "scipy": "scipy",
     "onnx": "onnx",
+    "optimum_quanto": "optimum_quanto>=0.2.6",
+    "gguf": "gguf>=0.10.0",
+    "torchao": "torchao>=0.7.0",
+    "bitsandbytes": "bitsandbytes>=0.43.3",
     "regex": "regex!=2019.12.17",
     "requests": "requests",
     "tensorboard": "tensorboard",

src/diffusers/models/model_loading_utils.py

Lines changed: 6 additions & 1 deletion
@@ -245,6 +245,9 @@ def load_model_dict_into_meta(
             ):
                 param = param.to(torch.float32)
                 set_module_kwargs["dtype"] = torch.float32
+            # For quantizers that have saved weights using torch.float8_e4m3fn
+            elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
+                pass
             else:
                 param = param.to(dtype)
                 set_module_kwargs["dtype"] = dtype
@@ -292,7 +295,9 @@
         elif is_quantized and (
             hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
         ):
-            hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
+            hf_quantizer.create_quantized_param(
+                model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
+            )
         else:
             set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
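As a hedged aside on the first hunk: `getattr(torch, "float8_e4m3fn", None)` keeps the comparison safe on PyTorch builds without float8 support, since it evaluates to `None` rather than raising. A minimal standalone sketch of the same pattern:

```python
import torch

# getattr returns None when this torch build lacks float8_e4m3fn, so the dtype
# comparison below is simply False instead of raising an AttributeError.
float8_dtype = getattr(torch, "float8_e4m3fn", None)
param = torch.zeros(4, dtype=torch.float16)
skip_dtype_cast = float8_dtype is not None and param.dtype == float8_dtype
print(skip_dtype_cast)  # False for a float16 tensor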

src/diffusers/quantizers/auto.py

Lines changed: 4 additions & 0 deletions
@@ -26,22 +26,26 @@
     GGUFQuantizationConfig,
     QuantizationConfigMixin,
     QuantizationMethod,
+    QuantoConfig,
     TorchAoConfig,
 )
+from .quanto import QuantoQuantizer
 from .torchao import TorchAoHfQuantizer
 
 
 AUTO_QUANTIZER_MAPPING = {
     "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
     "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
     "gguf": GGUFQuantizer,
+    "quanto": QuantoQuantizer,
     "torchao": TorchAoHfQuantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "bitsandbytes_4bit": BitsAndBytesConfig,
     "bitsandbytes_8bit": BitsAndBytesConfig,
     "gguf": GGUFQuantizationConfig,
+    "quanto": QuantoConfig,
     "torchao": TorchAoConfig,
 }
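For orientation, these two entries mean that a quantization config whose method is `"quanto"` now resolves to `QuantoQuantizer` and `QuantoConfig`. A small illustrative sketch (the dictionaries are reproduced here with string placeholders so it runs standalone; the real module maps to the actual classes):

```python
# Simplified, hypothetical stand-in for the mappings added above; the real module
# maps the "quanto" key to the QuantoQuantizer and QuantoConfig classes.
AUTO_QUANTIZER_MAPPING = {"gguf": "GGUFQuantizer", "quanto": "QuantoQuantizer", "torchao": "TorchAoHfQuantizer"}
AUTO_QUANTIZATION_CONFIG_MAPPING = {"gguf": "GGUFQuantizationConfig", "quanto": "QuantoConfig", "torchao": "TorchAoConfig"}

method = "quanto"
print(AUTO_QUANTIZER_MAPPING[method], AUTO_QUANTIZATION_CONFIG_MAPPING[method])
```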
