diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index a40be8558499..70dcf0a5f9cb 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -418,6 +418,8 @@ jobs: test_location: "gguf" - backend: "torchao" test_location: "torchao" + - backend: "optimum_quanto" + test_location: "quanto" runs-on: group: aws-g6e-xlarge-plus container: diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 9438fe1a55e1..8811fca5f5a2 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -173,6 +173,8 @@ title: gguf - local: quantization/torchao title: torchao + - local: quantization/quanto + title: quanto title: Quantization Methods - sections: - local: optimization/fp16 diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md index 168a9a03473f..2c728cff3c07 100644 --- a/docs/source/en/api/quantization.md +++ b/docs/source/en/api/quantization.md @@ -31,6 +31,11 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui ## GGUFQuantizationConfig [[autodoc]] GGUFQuantizationConfig + +## QuantoConfig + +[[autodoc]] QuantoConfig + ## TorchAoConfig [[autodoc]] TorchAoConfig diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 794098e210a6..93323f86c7fc 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -36,5 +36,6 @@ Diffusers currently supports the following quantization methods. - [BitsandBytes](./bitsandbytes) - [TorchAO](./torchao) - [GGUF](./gguf) +- [Quanto](./quanto.md) [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. diff --git a/docs/source/en/quantization/quanto.md b/docs/source/en/quantization/quanto.md new file mode 100644 index 000000000000..d322d76be267 --- /dev/null +++ b/docs/source/en/quantization/quanto.md @@ -0,0 +1,148 @@ + + +# Quanto + +[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind: + +- All features are available in eager mode (works with non-traceable models) +- Supports quantization aware training +- Quantized models are compatible with `torch.compile` +- Quantized models are Device agnostic (e.g CUDA,XPU,MPS,CPU) + +In order to use the Quanto backend, you will first need to install `optimum-quanto>=0.2.6` and `accelerate` + +```shell +pip install optimum-quanto accelerate +``` + +Now you can quantize a model by passing the `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library does allow quantizing `nn.Conv2d` and `nn.LayerNorm` modules, currently, Diffusers only supports quantizing the weights in the `nn.Linear` layers of a model. The following snippet demonstrates how to apply `float8` quantization with Quanto. 

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```

## Skipping Quantization on Specific Modules

It is possible to skip quantizing certain modules with the `modules_to_not_convert` argument of `QuantoConfig`. Ensure that the modules passed to this argument match the keys of the modules in the model's `state_dict`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"])
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```

## Using `from_single_file` with the Quanto Backend

`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
```

## Saving Quantized Models

Diffusers supports serializing Quanto models using the `~ModelMixin.save_pretrained` method.

The serialization and loading requirements are different for models quantized directly with the Quanto library and models quantized with Diffusers using Quanto as the backend.
It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
# save the quantized model to reuse it later
transformer.save_pretrained("<your quantized model save path>")

# reload the quantized model with
model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
```

## Using `torch.compile` with Quanto

Currently, the Quanto backend supports `torch.compile` for the following quantization types:

- `int8` weights

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="int8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
image = pipe("A cat holding a sign that says hello").images[0]
image.save("flux-quanto-compile.png")
```

## Supported Quantization Types

### Weights

- float8
- int8
- int4
- int2

diff --git a/setup.py b/setup.py
index 93945ae040dd..fdc166a81ecf 100644
--- a/setup.py
+++ b/setup.py
@@ -128,6 +128,10 @@
    "GitPython<3.1.19",
    "scipy",
    "onnx",
+    "optimum_quanto>=0.2.6",
+    "gguf>=0.10.0",
+    "torchao>=0.7.0",
+    "bitsandbytes>=0.43.3",
    "regex!=2019.12.17",
    "requests",
    "tensorboard",
@@ -235,6 +239,11 @@ def run(self):
)
extras["torch"] = deps_list("torch", "accelerate")
+extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate")
+extras["gguf"] = deps_list("gguf", "accelerate")
+extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
+extras["torchao"] = deps_list("torchao", "accelerate")
+
if os.name == "nt":  # windows
    extras["flax"] = []  # jax is not supported on windows
else:
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index cfb0bd08f818..6d3d2e109581 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -2,6 +2,15 @@

from typing import TYPE_CHECKING

+from diffusers.quantizers import quantization_config
+from diffusers.utils import dummy_gguf_objects
+from diffusers.utils.import_utils import (
+    is_bitsandbytes_available,
+    is_gguf_available,
+    is_optimum_quanto_version,
+    is_torchao_available,
+)
+
from .utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
@@ -11,6 +20,7 @@
    is_librosa_available,
    is_note_seq_available,
    is_onnx_available,
+    is_optimum_quanto_available,
    is_scipy_available,
    is_sentencepiece_available,
    is_torch_available,
@@ -32,7 +42,7 @@
    "loaders": ["FromOriginalModelMixin"],
    "models": [],
    "pipelines": [],
-    "quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
+    "quantizers.quantization_config": [],
    "schedulers": [],
    "utils": [
        "OptionalDependencyNotAvailable",
@@ -54,6 +64,55 @@
    ],
}

+try:
+    if not is_bitsandbytes_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_bitsandbytes_objects
+
_import_structure["utils.dummy_bitsandbytes_objects"] = [ + name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_") + ] +else: + _import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig") + +try: + if not is_gguf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_gguf_objects + + _import_structure["utils.dummy_gguf_objects"] = [ + name for name in dir(dummy_gguf_objects) if not name.startswith("_") + ] +else: + _import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig") + +try: + if not is_torchao_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_torchao_objects + + _import_structure["utils.dummy_torchao_objects"] = [ + name for name in dir(dummy_torchao_objects) if not name.startswith("_") + ] +else: + _import_structure["quantizers.quantization_config"].append("TorchAoConfig") + +try: + if not is_optimum_quanto_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_optimum_quanto_objects + + _import_structure["utils.dummy_optimum_quanto_objects"] = [ + name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_") + ] +else: + _import_structure["quantizers.quantization_config"].append("QuantoConfig") + + try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() @@ -598,7 +657,38 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .configuration_utils import ConfigMixin - from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig + + try: + if not is_bitsandbytes_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_bitsandbytes_objects import * + else: + from .quantizers.quantization_config import BitsAndBytesConfig + + try: + if not is_gguf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_gguf_objects import * + else: + from .quantizers.quantization_config import GGUFQuantizationConfig + + try: + if not is_torchao_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_torchao_objects import * + else: + from .quantizers.quantization_config import TorchAoConfig + + try: + if not is_optimum_quanto_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_optimum_quanto_objects import * + else: + from .quantizers.quantization_config import QuantoConfig try: if not is_onnx_available(): diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 17d5da60347d..8ec95ed6fc8d 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -35,6 +35,10 @@ "GitPython": "GitPython<3.1.19", "scipy": "scipy", "onnx": "onnx", + "optimum_quanto": "optimum_quanto>=0.2.6", + "gguf": "gguf>=0.10.0", + "torchao": "torchao>=0.7.0", + "bitsandbytes": "bitsandbytes>=0.43.3", "regex": "regex!=2019.12.17", "requests": "requests", "tensorboard": "tensorboard", diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index f019a3cc67a6..741f7075d76d 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -245,6 +245,9 @@ def 
load_model_dict_into_meta(
        ):
            param = param.to(torch.float32)
            set_module_kwargs["dtype"] = torch.float32
+        # For quantizers that have saved weights using torch.float8_e4m3fn
+        elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
+            pass
        else:
            param = param.to(dtype)
            set_module_kwargs["dtype"] = dtype
@@ -292,7 +295,9 @@ def load_model_dict_into_meta(
        elif is_quantized and (
            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
        ):
-            hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
+            hf_quantizer.create_quantized_param(
+                model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
+            )
        else:
            set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py
index d9874cc282ae..ce214ae7bc17 100644
--- a/src/diffusers/quantizers/auto.py
+++ b/src/diffusers/quantizers/auto.py
@@ -26,8 +26,10 @@
    GGUFQuantizationConfig,
    QuantizationConfigMixin,
    QuantizationMethod,
+    QuantoConfig,
    TorchAoConfig,
)
+from .quanto import QuantoQuantizer
from .torchao import TorchAoHfQuantizer

@@ -35,6 +37,7 @@
    "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
    "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
    "gguf": GGUFQuantizer,
+    "quanto": QuantoQuantizer,
    "torchao": TorchAoHfQuantizer,
}

@@ -42,6 +45,7 @@
    "bitsandbytes_4bit": BitsAndBytesConfig,
    "bitsandbytes_8bit": BitsAndBytesConfig,
    "gguf": GGUFQuantizationConfig,
+    "quanto": QuantoConfig,
    "torchao": TorchAoConfig,
}

diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index 4fac8dd3829f..0bc433be0ff3 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -45,6 +45,7 @@ class QuantizationMethod(str, Enum):
    BITS_AND_BYTES = "bitsandbytes"
    GGUF = "gguf"
    TORCHAO = "torchao"
+    QUANTO = "quanto"

if is_torchao_available():
@@ -686,3 +687,38 @@ def __repr__(self):
        return (
            f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
        )
+
+
+@dataclass
+class QuantoConfig(QuantizationConfigMixin):
+    """
+    This is a wrapper class that covers the attributes and features you can configure for a model that has been
+    loaded with `quanto` quantization.
+
+    Args:
+        weights_dtype (`str`, *optional*, defaults to `"int8"`):
+            The target dtype for the weights after quantization. Supported values are ("float8", "int8", "int4", "int2").
+        modules_to_not_convert (`list`, *optional*, defaults to `None`):
+            The list of modules to skip when quantizing, useful for models that explicitly require some modules to be
+            left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
+ """ + + def __init__( + self, + weights_dtype: str = "int8", + modules_to_not_convert: Optional[List[str]] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.QUANTO + self.weights_dtype = weights_dtype + self.modules_to_not_convert = modules_to_not_convert + + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + accepted_weights = ["float8", "int8", "int4", "int2"] + if self.weights_dtype not in accepted_weights: + raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}") diff --git a/src/diffusers/quantizers/quanto/__init__.py b/src/diffusers/quantizers/quanto/__init__.py new file mode 100644 index 000000000000..a4e8a1f41a1e --- /dev/null +++ b/src/diffusers/quantizers/quanto/__init__.py @@ -0,0 +1 @@ +from .quanto_quantizer import QuantoQuantizer diff --git a/src/diffusers/quantizers/quanto/quanto_quantizer.py b/src/diffusers/quantizers/quanto/quanto_quantizer.py new file mode 100644 index 000000000000..0120163804c9 --- /dev/null +++ b/src/diffusers/quantizers/quanto/quanto_quantizer.py @@ -0,0 +1,177 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from diffusers.utils.import_utils import is_optimum_quanto_version + +from ...utils import ( + get_module_from_name, + is_accelerate_available, + is_accelerate_version, + is_optimum_quanto_available, + is_torch_available, + logging, +) +from ..base import DiffusersQuantizer + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + + +if is_torch_available(): + import torch + +if is_accelerate_available(): + from accelerate.utils import CustomDtype, set_module_tensor_to_device + +if is_optimum_quanto_available(): + from .utils import _replace_with_quanto_layers + +logger = logging.get_logger(__name__) + + +class QuantoQuantizer(DiffusersQuantizer): + r""" + Diffusers Quantizer for Optimum Quanto + """ + + use_keep_in_fp32_modules = True + requires_calibration = False + required_packages = ["quanto", "accelerate"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + def validate_environment(self, *args, **kwargs): + if not is_optimum_quanto_available(): + raise ImportError( + "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)" + ) + if not is_optimum_quanto_version(">=", "0.2.6"): + raise ImportError( + "Loading an optimum-quanto quantized model requires `optimum-quanto>=0.2.6`. " + "Please upgrade your installation with `pip install --upgrade optimum-quanto" + ) + + if not is_accelerate_available(): + raise ImportError( + "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)" + ) + + device_map = kwargs.get("device_map", None) + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + raise ValueError( + "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend" + ) + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ): + # Quanto imports diffusers internally. 
This is here to prevent circular imports + from optimum.quanto import QModuleMixin, QTensor + from optimum.quanto.tensor.packed import PackedTensor + + module, tensor_name = get_module_from_name(model, param_name) + if self.pre_quantized and any(isinstance(module, t) for t in [QTensor, PackedTensor]): + return True + elif isinstance(module, QModuleMixin) and "weight" in tensor_name: + return not module.frozen + + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + *args, + **kwargs, + ): + """ + Create the quantized parameter by calling .freeze() after setting it to the module. + """ + + dtype = kwargs.get("dtype", torch.float32) + module, tensor_name = get_module_from_name(model, param_name) + if self.pre_quantized: + setattr(module, tensor_name, param_value) + else: + set_module_tensor_to_device(model, param_name, target_device, param_value, dtype) + module.freeze() + module.weight.requires_grad = False + + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if is_accelerate_version(">=", "0.27.0"): + mapping = { + "int8": torch.int8, + "float8": CustomDtype.FP8, + "int4": CustomDtype.INT4, + "int2": CustomDtype.INT2, + } + target_dtype = mapping[self.quantization_config.weights_dtype] + + return target_dtype + + def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype": + if torch_dtype is None: + logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.") + torch_dtype = torch.float32 + return torch_dtype + + def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: + # Quanto imports diffusers internally. 
This is here to prevent circular imports + from optimum.quanto import QModuleMixin + + not_missing_keys = [] + for name, module in model.named_modules(): + if isinstance(module, QModuleMixin): + for missing in missing_keys: + if ( + (name in missing or name in f"{prefix}.{missing}") + and not missing.endswith(".weight") + and not missing.endswith(".bias") + ): + not_missing_keys.append(missing) + return [k for k in missing_keys if k not in not_missing_keys] + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + self.modules_to_not_convert = self.quantization_config.modules_to_not_convert + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + model = _replace_with_quanto_layers( + model, + modules_to_not_convert=self.modules_to_not_convert, + quantization_config=self.quantization_config, + pre_quantized=self.pre_quantized, + ) + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model, **kwargs): + return model + + @property + def is_trainable(self): + return True + + @property + def is_serializable(self): + return True diff --git a/src/diffusers/quantizers/quanto/utils.py b/src/diffusers/quantizers/quanto/utils.py new file mode 100644 index 000000000000..6f41fd36b43a --- /dev/null +++ b/src/diffusers/quantizers/quanto/utils.py @@ -0,0 +1,60 @@ +import torch.nn as nn + +from ...utils import is_accelerate_available, logging + + +logger = logging.get_logger(__name__) + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list, pre_quantized=False): + # Quanto imports diffusers internally. These are placed here to avoid circular imports + from optimum.quanto import QLinear, freeze, qfloat8, qint2, qint4, qint8 + + def _get_weight_type(dtype: str): + return {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}[dtype] + + def _replace_layers(model, quantization_config, modules_to_not_convert): + has_children = list(model.children()) + if not has_children: + return model + + for name, module in model.named_children(): + _replace_layers(module, quantization_config, modules_to_not_convert) + + if name in modules_to_not_convert: + continue + + if isinstance(module, nn.Linear): + with init_empty_weights(): + qlinear = QLinear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + dtype=module.weight.dtype, + weights=_get_weight_type(quantization_config.weights_dtype), + ) + model._modules[name] = qlinear + model._modules[name].source_cls = type(module) + model._modules[name].requires_grad_(False) + + return model + + model = _replace_layers(model, quantization_config, modules_to_not_convert) + has_been_replaced = any(isinstance(replaced_module, QLinear) for _, replaced_module in model.named_modules()) + + if not has_been_replaced: + logger.warning( + f"{model.__class__.__name__} does not appear to have any `nn.Linear` modules. Quantization will not be applied." + " Please check your model architecture, or submit an issue on Github if you think this is a bug." 
+ " https://github.com/huggingface/diffusers/issues/new" + ) + + # We need to freeze the pre_quantized model in order for the loaded state_dict and model state dict + # to match when trying to load weights with load_model_dict_into_meta + if pre_quantized: + freeze(model) + + return model diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 6702ea2efbc8..1684c434f55e 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -79,6 +79,8 @@ is_matplotlib_available, is_note_seq_available, is_onnx_available, + is_optimum_quanto_available, + is_optimum_quanto_version, is_peft_available, is_peft_version, is_safetensors_available, diff --git a/src/diffusers/utils/dummy_bitsandbytes_objects.py b/src/diffusers/utils/dummy_bitsandbytes_objects.py new file mode 100644 index 000000000000..2dc589428de9 --- /dev/null +++ b/src/diffusers/utils/dummy_bitsandbytes_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class BitsAndBytesConfig(metaclass=DummyObject): + _backends = ["bitsandbytes"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["bitsandbytes"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["bitsandbytes"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["bitsandbytes"]) diff --git a/src/diffusers/utils/dummy_gguf_objects.py b/src/diffusers/utils/dummy_gguf_objects.py new file mode 100644 index 000000000000..4a6d9a060a13 --- /dev/null +++ b/src/diffusers/utils/dummy_gguf_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class GGUFQuantizationConfig(metaclass=DummyObject): + _backends = ["gguf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["gguf"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["gguf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["gguf"]) diff --git a/src/diffusers/utils/dummy_optimum_quanto_objects.py b/src/diffusers/utils/dummy_optimum_quanto_objects.py new file mode 100644 index 000000000000..44f8eaffc246 --- /dev/null +++ b/src/diffusers/utils/dummy_optimum_quanto_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class QuantoConfig(metaclass=DummyObject): + _backends = ["optimum_quanto"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["optimum_quanto"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["optimum_quanto"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["optimum_quanto"]) diff --git a/src/diffusers/utils/dummy_torchao_objects.py b/src/diffusers/utils/dummy_torchao_objects.py new file mode 100644 index 000000000000..16f0f6a55f64 --- /dev/null +++ b/src/diffusers/utils/dummy_torchao_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. 
+from ..utils import DummyObject, requires_backends
+
+
+class TorchAoConfig(metaclass=DummyObject):
+    _backends = ["torchao"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torchao"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torchao"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torchao"])
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index ae1b9cae6edc..b6aa8e96e619 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -365,6 +365,15 @@ def is_timm_available():
    _is_torchao_available = False

+_is_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
+if _is_optimum_quanto_available:
+    try:
+        _optimum_quanto_version = importlib_metadata.version("optimum_quanto")
+        logger.debug(f"Successfully imported optimum-quanto version {_optimum_quanto_version}")
+    except importlib_metadata.PackageNotFoundError:
+        _is_optimum_quanto_available = False
+
+
def is_torch_available():
    return _torch_available

@@ -493,6 +502,10 @@ def is_torchao_available():
    return _is_torchao_available

+def is_optimum_quanto_available():
+    return _is_optimum_quanto_available
+
+
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -636,6 +649,11 @@ def is_torchao_available():
torchao`
"""

+QUANTO_IMPORT_ERROR = """
+{0} requires the optimum-quanto library but it was not found in your environment. You can install it with pip: `pip
+install optimum-quanto`
+"""
+
BACKENDS_MAPPING = OrderedDict(
    [
        ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -663,6 +681,7 @@ def is_torchao_available():
        ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
        ("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
        ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
+        ("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
    ]
)

@@ -864,6 +883,21 @@ def is_k_diffusion_version(operation: str, version: str):
    return compare_versions(parse(_k_diffusion_version), operation, version)

+def is_optimum_quanto_version(operation: str, version: str):
+    """
+    Compares the current optimum-quanto version to a given reference with an operation.
+ + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _is_optimum_quanto_available: + return False + return compare_versions(parse(_optimum_quanto_version), operation, version) + + def get_objects_from_module(module): """ Returns a dict of object names and values in a module, while skipping private/internal objects diff --git a/tests/quantization/quanto/test_quanto.py b/tests/quantization/quanto/test_quanto.py new file mode 100644 index 000000000000..89a56c15ed24 --- /dev/null +++ b/tests/quantization/quanto/test_quanto.py @@ -0,0 +1,346 @@ +import gc +import tempfile +import unittest + +from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig +from diffusers.models.attention_processor import Attention +from diffusers.utils import is_optimum_quanto_available, is_torch_available +from diffusers.utils.testing_utils import ( + nightly, + numpy_cosine_similarity_distance, + require_accelerate, + require_big_gpu_with_torch_cuda, + torch_device, +) + + +if is_optimum_quanto_available(): + from optimum.quanto import QLinear + +if is_torch_available(): + import torch + import torch.nn as nn + + class LoRALayer(nn.Module): + """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only + + Taken from + https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77 + """ + + def __init__(self, module: nn.Module, rank: int): + super().__init__() + self.module = module + self.adapter = nn.Sequential( + nn.Linear(module.in_features, rank, bias=False), + nn.Linear(rank, module.out_features, bias=False), + ) + small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 + nn.init.normal_(self.adapter[0].weight, std=small_std) + nn.init.zeros_(self.adapter[1].weight) + self.adapter.to(module.weight.device) + + def forward(self, input, *args, **kwargs): + return self.module(input, *args, **kwargs) + self.adapter(input) + + +@nightly +@require_big_gpu_with_torch_cuda +@require_accelerate +class QuantoBaseTesterMixin: + model_id = None + pipeline_model_id = None + model_cls = None + torch_dtype = torch.bfloat16 + # the expected reduction in peak memory used compared to an unquantized model expressed as a percentage + expected_memory_reduction = 0.0 + keep_in_fp32_module = "" + modules_to_not_convert = "" + _test_torch_compile = False + + def setUp(self): + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + gc.collect() + + def tearDown(self): + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + gc.collect() + + def get_dummy_init_kwargs(self): + return {"weights_dtype": "float8"} + + def get_dummy_model_init_kwargs(self): + return { + "pretrained_model_name_or_path": self.model_id, + "torch_dtype": self.torch_dtype, + "quantization_config": QuantoConfig(**self.get_dummy_init_kwargs()), + } + + def test_quanto_layers(self): + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + assert isinstance(module, QLinear) + + def test_quanto_memory_usage(self): + unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype) + unquantized_model_memory = unquantized_model.get_memory_footprint() / 1024**3 + + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + inputs = self.get_dummy_inputs() + + 
torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + + model.to(torch_device) + with torch.no_grad(): + model(**inputs) + max_memory = torch.cuda.max_memory_allocated() / 1024**3 + assert (1.0 - (max_memory / unquantized_model_memory)) >= self.expected_memory_reduction + + def test_keep_modules_in_fp32(self): + r""" + A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32. + Also ensures if inference works. + """ + _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules + self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module + + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + model.to("cuda") + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if name in model._keep_in_fp32_modules: + assert module.weight.dtype == torch.float32 + self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules + + def test_modules_to_not_convert(self): + init_kwargs = self.get_dummy_model_init_kwargs() + + quantization_config_kwargs = self.get_dummy_init_kwargs() + quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert}) + quantization_config = QuantoConfig(**quantization_config_kwargs) + + init_kwargs.update({"quantization_config": quantization_config}) + + model = self.model_cls.from_pretrained(**init_kwargs) + model.to("cuda") + + for name, module in model.named_modules(): + if name in self.modules_to_not_convert: + assert not isinstance(module, QLinear) + + def test_dtype_assignment(self): + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + + with self.assertRaises(ValueError): + # Tries with a `dtype` + model.to(torch.float16) + + with self.assertRaises(ValueError): + # Tries with a `device` and `dtype` + model.to(device="cuda:0", dtype=torch.float16) + + with self.assertRaises(ValueError): + # Tries with a cast + model.float() + + with self.assertRaises(ValueError): + # Tries with a cast + model.half() + + # This should work + model.to("cuda") + + def test_serialization(self): + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + inputs = self.get_dummy_inputs() + + model.to(torch_device) + with torch.no_grad(): + model_output = model(**inputs) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + saved_model = self.model_cls.from_pretrained( + tmp_dir, + torch_dtype=torch.bfloat16, + ) + + saved_model.to(torch_device) + with torch.no_grad(): + saved_model_output = saved_model(**inputs) + + assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5) + + def test_torch_compile(self): + if not self._test_torch_compile: + return + + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False) + + model.to(torch_device) + with torch.no_grad(): + model_output = model(**self.get_dummy_inputs()).sample + + compiled_model.to(torch_device) + with torch.no_grad(): + compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample + + model_output = model_output.detach().float().cpu().numpy() + compiled_model_output = compiled_model_output.detach().float().cpu().numpy() + + max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten()) + assert max_diff < 1e-3 + + def test_device_map_error(self): + with self.assertRaises(ValueError): + _ = self.model_cls.from_pretrained( + 
**self.get_dummy_model_init_kwargs(), device_map={0: "8GB", "cpu": "16GB"} + ) + + +class FluxTransformerQuantoMixin(QuantoBaseTesterMixin): + model_id = "hf-internal-testing/tiny-flux-transformer" + model_cls = FluxTransformer2DModel + pipeline_cls = FluxPipeline + torch_dtype = torch.bfloat16 + keep_in_fp32_module = "proj_out" + modules_to_not_convert = ["proj_out"] + _test_torch_compile = False + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "pooled_projections": torch.randn( + (1, 768), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype), + } + + def get_dummy_training_inputs(self, device=None, seed: int = 0): + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + height = width = 4 + sequence_length = 48 + embedding_dim = 32 + + torch.manual_seed(seed) + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( + device, dtype=torch.bfloat16 + ) + + torch.manual_seed(seed) + pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) + + torch.manual_seed(seed) + image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) + + timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_prompt_embeds, + "txt_ids": text_ids, + "img_ids": image_ids, + "timestep": timestep, + } + + def test_model_cpu_offload(self): + init_kwargs = self.get_dummy_init_kwargs() + transformer = self.model_cls.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + quantization_config=QuantoConfig(**init_kwargs), + subfolder="transformer", + torch_dtype=torch.bfloat16, + ) + pipe = self.pipeline_cls.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", transformer=transformer, torch_dtype=torch.bfloat16 + ) + pipe.enable_model_cpu_offload(device=torch_device) + _ = pipe("a cat holding a sign that says hello", num_inference_steps=2) + + def test_training(self): + quantization_config = QuantoConfig(**self.get_dummy_init_kwargs()) + quantized_model = self.model_cls.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ).to(torch_device) + + for param in quantized_model.parameters(): + # freeze the model as only adapter layers will be trained + param.requires_grad = False + if param.ndim == 1: + param.data = param.data.to(torch.float32) + + for _, module in 
quantized_model.named_modules(): + if isinstance(module, Attention): + module.to_q = LoRALayer(module.to_q, rank=4) + module.to_k = LoRALayer(module.to_k, rank=4) + module.to_v = LoRALayer(module.to_v, rank=4) + + with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16): + inputs = self.get_dummy_training_inputs(torch_device) + output = quantized_model(**inputs)[0] + output.norm().backward() + + for module in quantized_model.modules(): + if isinstance(module, LoRALayer): + self.assertTrue(module.adapter[1].weight.grad is not None) + + +class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase): + expected_memory_reduction = 0.3 + + def get_dummy_init_kwargs(self): + return {"weights_dtype": "float8"} + + +class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase): + expected_memory_reduction = 0.3 + _test_torch_compile = True + + def get_dummy_init_kwargs(self): + return {"weights_dtype": "int8"} + + +class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase): + expected_memory_reduction = 0.55 + + def get_dummy_init_kwargs(self): + return {"weights_dtype": "int4"} + + +class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase): + expected_memory_reduction = 0.65 + + def get_dummy_init_kwargs(self): + return {"weights_dtype": "int2"}
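
As a quick sanity check outside the test suite, the following is a minimal sketch (not part of the changeset above) that exercises the new backend the same way these tests do; it assumes the `hf-internal-testing/tiny-flux-pipe` checkpoint referenced in the tests is available.

```python
# Illustrative sketch only: load the tiny Flux transformer used in the tests with an int8
# QuantoConfig and confirm that its nn.Linear layers were swapped for quanto QLinear modules.
import torch
from optimum.quanto import QLinear

from diffusers import FluxTransformer2DModel, QuantoConfig

quantization_config = QuantoConfig(weights_dtype="int8")
transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

# Count the layers replaced by _replace_with_quanto_layers during loading.
quantized_layers = [name for name, module in transformer.named_modules() if isinstance(module, QLinear)]
print(f"{len(quantized_layers)} linear layers were quantized")
```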