Skip to content

feat: pipeline-level quantization config #11130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 39 commits into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
316ff46
feat: pipeline-level quant config.
sayakpaul Mar 10, 2025
eec5b98
Revert "feat: pipeline-level quant config."
sayakpaul Mar 20, 2025
c94d85a
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul Mar 20, 2025
4d3dede
feat: implement pipeline-level quantization config
sayakpaul Mar 21, 2025
dc79f32
update
sayakpaul Mar 21, 2025
df749e4
fixes
sayakpaul Mar 21, 2025
d0ad15e
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul Mar 21, 2025
f8b514b
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul Mar 26, 2025
9250941
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul Mar 27, 2025
13d5589
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul Apr 2, 2025
f678437
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul Apr 10, 2025
5a85871
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul Apr 11, 2025
557136d
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul Apr 17, 2025
0d9814f
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul Apr 21, 2025
f8d1bd1
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul Apr 24, 2025
c7e0774
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul Apr 25, 2025
6861da1
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul Apr 28, 2025
82bcce0
fix validation.
sayakpaul Apr 29, 2025
78f134b
add tests and other improvements.
sayakpaul Apr 29, 2025
f2b39e0
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul Apr 29, 2025
3b76e0a
add tests
sayakpaul Apr 29, 2025
695061b
import quality
sayakpaul Apr 29, 2025
9693251
remove prints.
sayakpaul Apr 29, 2025
73f1ad1
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul May 1, 2025
dc90b06
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul May 2, 2025
872c91e
add docs.
sayakpaul May 2, 2025
fbdf4c6
fixes to docs.
sayakpaul May 2, 2025
da6df86
doc fixes.
sayakpaul May 2, 2025
9a418a9
doc fixes.
sayakpaul May 2, 2025
5b6ee10
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul May 3, 2025
478a353
add validation to the input quantization_config.
sayakpaul May 6, 2025
f96bcc7
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul May 6, 2025
0ae2a9a
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul May 8, 2025
d6b48ea
clarify recommendations.
sayakpaul May 8, 2025
ca2e116
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul May 8, 2025
ffb974f
docs
sayakpaul May 8, 2025
86ee773
add to ci.
sayakpaul May 8, 2025
037a68b
Merge branch 'main' into feat/pipeline-quant-config
sayakpaul May 9, 2025
7b8a73d
todo.
sayakpaul May 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/source/en/api/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ specific language governing permissions and limitations under the License.

# Quantization

Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index).

Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class.
Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference.

<Tip>

Learn how to quantize models in the [Quantization](../quantization/overview) guide.

</Tip>

## PipelineQuantizationConfig

[[autodoc]] quantizers.PipelineQuantizationConfig

## BitsAndBytesConfig

Expand Down
76 changes: 76 additions & 0 deletions docs/source/en/quantization/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,79 @@ Diffusers currently supports the following quantization methods.
- [Quanto](./quanto.md)

[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.

## Pipeline-level quantization

Diffusers allows users to directly initialize pipelines from checkpoints that may contain quantized models ([example](https://huggingface.co/hf-internal-testing/flux.1-dev-nf4-pkg)). However, users may want to apply
quantization on-the-fly when initializing a pipeline from a pre-trained and non-quantized checkpoint. You can
do this with [`~quantizers.PipelineQuantizationConfig`].

Start by defining a `PipelineQuantizationConfig`:

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers.quantization_config import QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import BitsAndBytesConfig

pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": BitsAndBytesConfig(
load_in_4bit=True, compute_dtype=torch.bfloat16
),
}
)
```

Then pass it to [`~DiffusionPipeline.from_pretrained`] and run inference:

```py
pipe = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe("photo of a cute dog").images[0]
```

This method allows for more granular control over the quantization specifications of individual
model-level components of a pipeline. It also allows for different quantization backends for
different components. In the above example, you used a combination of Quanto and BitsandBytes.

The other method is simpler in terms of experience but is
less-flexible. Start by defining a `PipelineQuantizationConfig` but in a different way:

```py
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
components_to_quantize=["transformer", "text_encoder_2"],
)
```

This `pipeline_quant_config` can now be passed to [`~DiffusionPipeline.from_pretrained`] similar to the above example.

In this case, `quant_kwargs` will be used to initialize the quantization specifications
of the respective quantization configuration class of `quant_backend`. `components_to_quantize`
is used to denote the components that will be quantized. For most pipelines, you would want to
keep `transformer` in the list as that is often the most compute and memory intensive.

The config below will work for most diffusion pipelines that have a `transformer` component present.
In most case, you will want to quantize the `transformer` component as that is often the most compute-
intensive part of a diffusion pipeline.

```py
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
components_to_quantize=["transformer"],
)
```

Diffusion pipelines can have multiple text encoders. [`FluxPipeline`] has two, for example. It's
recommended to quantize the text encoders that are memory-intensive. Some examples include T5,
Llama, Gemma, etc. In the above example, you quantized the T5 model of [`FluxPipeline`] through
`text_encoder_2` while keeping the CLIP model intact (accessible through `text_encoder`).
13 changes: 13 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,8 +675,10 @@ def load_sub_model(
use_safetensors: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
quantization_config: Optional[Any] = None,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
from ..quantizers import PipelineQuantizationConfig

# retrieve class candidates

Expand Down Expand Up @@ -769,6 +771,17 @@ def load_sub_model(
else:
loading_kwargs["low_cpu_mem_usage"] = False

if (
quantization_config is not None
and isinstance(quantization_config, PipelineQuantizationConfig)
and issubclass(class_obj, torch.nn.Module)
):
model_quant_config = quantization_config._resolve_quant_config(
is_diffusers=is_diffusers_model, module_name=name
)
if model_quant_config is not None:
loading_kwargs["quantization_config"] = model_quant_config

# check if the module is in a subdirectory
if dduf_entries:
loading_kwargs["dduf_entries"] = dduf_entries
Expand Down
6 changes: 6 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from ..models import AutoencoderKL
from ..models.attention_processor import FusedAttnProcessor2_0
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
from ..quantizers import PipelineQuantizationConfig
from ..quantizers.bitsandbytes.utils import _check_bnb_status
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
Expand Down Expand Up @@ -725,6 +726,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)

if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
Expand All @@ -741,6 +743,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
" install accelerate\n```\n."
)

if quantization_config is not None and not isinstance(quantization_config, PipelineQuantizationConfig):
raise ValueError("`quantization_config` must be an instance of `PipelineQuantizationConfig`.")

if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
Expand Down Expand Up @@ -1001,6 +1006,7 @@ def load_module(name, value):
use_safetensors=use_safetensors,
dduf_entries=dduf_entries,
provider_options=provider_options,
quantization_config=quantization_config,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
Expand Down
178 changes: 178 additions & 0 deletions src/diffusers/quantizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,183 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Dict, List, Optional, Union

from ..utils import is_transformers_available, logging
from .auto import DiffusersAutoQuantizer
from .base import DiffusersQuantizer
from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin


try:
from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
except ImportError:

class TransformersQuantConfigMixin:
pass


logger = logging.get_logger(__name__)


class PipelineQuantizationConfig:
"""
Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].

Args:
quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
is available to both `diffusers` and `transformers`.
quant_kwargs (`dict`): Params to initialize the quantization backend class.
components_to_quantize (`list`): Components of a pipeline to be quantized.
quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`,
and `components_to_quantize`.
"""

def __init__(
self,
quant_backend: str = None,
quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
components_to_quantize: Optional[List[str]] = None,
quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
):
self.quant_backend = quant_backend
# Initialize kwargs to be {} to set to the defaults.
self.quant_kwargs = quant_kwargs or {}
self.components_to_quantize = components_to_quantize
self.quant_mapping = quant_mapping

self.post_init()

def post_init(self):
quant_mapping = self.quant_mapping
self.is_granular = True if quant_mapping is not None else False

self._validate_init_args()

def _validate_init_args(self):
if self.quant_backend and self.quant_mapping:
raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")

if not self.quant_mapping and not self.quant_backend:
raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")

if not self.quant_kwargs and not self.quant_mapping:
raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")

if self.quant_backend is not None:
self._validate_init_kwargs_in_backends()

if self.quant_mapping is not None:
self._validate_quant_mapping_args()

def _validate_init_kwargs_in_backends(self):
quant_backend = self.quant_backend

self._check_backend_availability(quant_backend)

quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()

if quant_config_mapping_transformers is not None:
init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
else:
init_kwargs_transformers = None

init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}

if init_kwargs_transformers != init_kwargs_diffusers:
raise ValueError(
"The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to the docs to learn more about how "
"this mapping would look like: TODO."
)

def _validate_quant_mapping_args(self):
quant_mapping = self.quant_mapping
transformers_map, diffusers_map = self._get_quant_config_list()

available_transformers = list(transformers_map.values()) if transformers_map else None
available_diffusers = list(diffusers_map.values())

for module_name, config in quant_mapping.items():
if any(isinstance(config, cfg) for cfg in available_diffusers):
continue

if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
continue

if available_transformers:
raise ValueError(
f"Provided config for module_name={module_name} could not be found. "
f"Available diffusers configs: {available_diffusers}; "
f"Available transformers configs: {available_transformers}."
)
else:
raise ValueError(
f"Provided config for module_name={module_name} could not be found. "
f"Available diffusers configs: {available_diffusers}."
)

def _check_backend_availability(self, quant_backend: str):
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()

available_backends_transformers = (
list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
)
available_backends_diffusers = list(quant_config_mapping_diffusers.keys())

if (
available_backends_transformers and quant_backend not in available_backends_transformers
) or quant_backend not in quant_config_mapping_diffusers:
error_message = f"Provided quant_backend={quant_backend} was not found."
if available_backends_transformers:
error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
raise ValueError(error_message)

def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()

quant_mapping = self.quant_mapping
components_to_quantize = self.components_to_quantize

# Granular case
if self.is_granular and module_name in quant_mapping:
logger.debug(f"Initializing quantization config class for {module_name}.")
config = quant_mapping[module_name]
return config

# Global config case
else:
should_quantize = False
# Only quantize the modules requested for.
if components_to_quantize and module_name in components_to_quantize:
should_quantize = True
# No specification for `components_to_quantize` means all modules should be quantized.
elif not self.is_granular and not components_to_quantize:
should_quantize = True

if should_quantize:
logger.debug(f"Initializing quantization config class for {module_name}.")
mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
quant_config_cls = mapping_to_use[self.quant_backend]
quant_kwargs = self.quant_kwargs
return quant_config_cls(**quant_kwargs)

# Fallback: no applicable configuration found.
return None

def _get_quant_config_list(self):
if is_transformers_available():
from transformers.quantizers.auto import (
AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
)
else:
quant_config_mapping_transformers = None

from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers

return quant_config_mapping_transformers, quant_config_mapping_diffusers
2 changes: 1 addition & 1 deletion src/diffusers/quantizers/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
Args:
config_dict (`Dict[str, Any]`):
Dictionary that will be used to instantiate the configuration object.
return_unused_kwargs (`bool`,*optional*, defaults to `False`):
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
`PreTrainedModel`.
kwargs (`Dict[str, Any]`):
Expand Down
8 changes: 8 additions & 0 deletions src/diffusers/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
is_note_seq_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
is_peft_available,
is_timm_available,
is_torch_available,
Expand Down Expand Up @@ -486,6 +487,13 @@ def require_bitsandbytes(test_case):
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)


def require_quanto(test_case):
"""
Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
"""
return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case)


def require_accelerate(test_case):
"""
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
Expand Down
Loading
Loading