From 8db89e7453ab9b36198530fc63e63555f4c5d14c Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 10 Mar 2025 06:55:44 +0100 Subject: [PATCH 1/9] update --- src/diffusers/models/model_loading_utils.py | 34 ++++++++++++++++++++- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 15 +++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index f019a3cc67a6..5efd95236a13 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -39,6 +39,8 @@ deprecate, is_accelerate_available, is_gguf_available, + is_torchao_available, + is_torchao_version, is_torch_available, is_torch_version, logging, @@ -54,12 +56,42 @@ } } - if is_accelerate_available(): from accelerate import infer_auto_device_map from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device +def _update_torch_safe_globals(): + safe_globals = [ + (torch.uint1, "torch.uint1"), + (torch.uint2, "torch.uint2"), + (torch.uint3, "torch.uint3"), + (torch.uint4, "torch.uint4"), + (torch.uint5, "torch.uint5"), + (torch.uint6, "torch.uint6"), + (torch.uint7, "torch.uint7"), + ] + try: + from torchao.dtypes.uintx.uintx_layout import UintxTensor, UintxAQTTensorImpl + from torchao.dtypes.uintx.uint4_layout import UInt4Tensor + from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl + from torchao.dtypes import NF4Tensor + + safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor]) + + except (ImportError, ModuleNotFoundError, NotImplementedError) as e: + logger.warning( + "Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`" + ) + + finally: + torch.serialization.add_safe_globals(safe_globals=safe_globals) + + +if is_torchao_available() and is_torch_version(">=", "2.6") and is_torchao_version(">=", "0.7.0"): + _update_torch_safe_globals() + + # Adapted from `transformers` (see modeling_utils.py) def _determine_device_map( model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 6702ea2efbc8..94711d51274f 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -92,6 +92,7 @@ is_torch_xla_available, is_torch_xla_version, is_torchao_available, + is_torchao_version, is_torchsde_available, is_torchvision_available, is_transformers_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index ae1b9cae6edc..649705218a35 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -849,6 +849,21 @@ def is_gguf_version(operation: str, version: str): return compare_versions(parse(_gguf_version), operation, version) +def is_torchao_version(operation: str, version: str): + """ + Compares the current torchao version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _is_torchao_available: + return False + return compare_versions(parse(is_torch_version), operation, version) + + def is_k_diffusion_version(operation: str, version: str): """ Compares the current k-diffusion version to a given reference with an operation. 
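For context on the mechanism the first patch relies on: `torch.serialization.add_safe_globals` allowlists classes (or `(object, qualified_name)` pairs, as used above for the sub-byte `torch.uintN` dtypes) for PyTorch's default `weights_only=True` unpickler, so checkpoints serialized with torchao no longer have to be loaded with `weights_only=False`. Below is a minimal sketch of that pattern; the wrapper class and checkpoint path are hypothetical, only `add_safe_globals` and `torch.load` are real APIs, and the version guard mirrors the one introduced in this patch.

```python
import torch

# Hypothetical tensor subclass standing in for torchao's UintxTensor / NF4Tensor etc.
class MyQuantTensor(torch.Tensor):
    pass

# Allowlist the class for the weights_only unpickler (available since torch 2.4);
# the patch above registers torchao's layout/tensor classes and the torch.uint1..uint7
# dtypes the same way, guarded behind torch>=2.6 and torchao>=0.7.0.
torch.serialization.add_safe_globals([MyQuantTensor])

# Once registered, checkpoints that pickle such objects can be loaded without
# falling back to the unsafe weights_only=False path:
# state_dict = torch.load("quantized_checkpoint.pt", weights_only=True)  # hypothetical file
```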
From 56ec287e8a37f4249e1bf265d10673a51d1de196 Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 10 Mar 2025 16:02:48 +0530 Subject: [PATCH 2/9] update --- docs/source/en/quantization/torchao.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index c056876c2f09..19a8970fa9df 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -126,7 +126,7 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0] image.save("output.png") ``` -Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source. +If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source. ```python import torch From 08b8503ffb20bd2201cd343c89d0112e8ef2dd92 Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 10 Mar 2025 16:29:41 +0530 Subject: [PATCH 3/9] update --- src/diffusers/models/model_loading_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 5efd95236a13..22eea49c2894 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -79,16 +79,17 @@ def _update_torch_safe_globals(): safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor]) - except (ImportError, ModuleNotFoundError, NotImplementedError) as e: + except (ImportError, ModuleNotFoundError) as e: logger.warning( - "Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`" + f"Unable to import `torchao` Tensor objects. 
This may affect loading checkpoints serialized with `torchao`" ) + logger.debug(e) finally: torch.serialization.add_safe_globals(safe_globals=safe_globals) -if is_torchao_available() and is_torch_version(">=", "2.6") and is_torchao_version(">=", "0.7.0"): +if is_torch_version(">=", "2.6") and is_torchao_available() and is_torchao_version(">=", "0.7.0"): _update_torch_safe_globals() From 6a0ae75b55435871e9fcfa1d0aa0f6bfb42a4897 Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 10 Mar 2025 16:35:49 +0530 Subject: [PATCH 4/9] update --- src/diffusers/models/model_loading_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 22eea49c2894..e074d0517792 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -39,10 +39,10 @@ deprecate, is_accelerate_available, is_gguf_available, - is_torchao_available, - is_torchao_version, is_torch_available, is_torch_version, + is_torchao_available, + is_torchao_version, logging, ) @@ -72,16 +72,16 @@ def _update_torch_safe_globals(): (torch.uint7, "torch.uint7"), ] try: - from torchao.dtypes.uintx.uintx_layout import UintxTensor, UintxAQTTensorImpl - from torchao.dtypes.uintx.uint4_layout import UInt4Tensor - from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl from torchao.dtypes import NF4Tensor + from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl + from torchao.dtypes.uintx.uint4_layout import UInt4Tensor + from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor]) except (ImportError, ModuleNotFoundError) as e: logger.warning( - f"Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`" + "Unable to import `torchao` Tensor objects. 
This may affect loading checkpoints serialized with `torchao`" ) logger.debug(e) From 9297598dff8fc356a50778b28de800c33d0007f2 Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 10 Mar 2025 16:59:02 +0530 Subject: [PATCH 5/9] update --- src/diffusers/models/model_loading_utils.py | 32 ------------------- .../quantizers/torchao/torchao_quantizer.py | 32 +++++++++++++++++++ src/diffusers/utils/import_utils.py | 2 +- 3 files changed, 33 insertions(+), 33 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index e074d0517792..8b26c49db8b6 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -61,38 +61,6 @@ from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device -def _update_torch_safe_globals(): - safe_globals = [ - (torch.uint1, "torch.uint1"), - (torch.uint2, "torch.uint2"), - (torch.uint3, "torch.uint3"), - (torch.uint4, "torch.uint4"), - (torch.uint5, "torch.uint5"), - (torch.uint6, "torch.uint6"), - (torch.uint7, "torch.uint7"), - ] - try: - from torchao.dtypes import NF4Tensor - from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl - from torchao.dtypes.uintx.uint4_layout import UInt4Tensor - from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor - - safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor]) - - except (ImportError, ModuleNotFoundError) as e: - logger.warning( - "Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`" - ) - logger.debug(e) - - finally: - torch.serialization.add_safe_globals(safe_globals=safe_globals) - - -if is_torch_version(">=", "2.6") and is_torchao_available() and is_torchao_version(">=", "0.7.0"): - _update_torch_safe_globals() - - # Adapted from `transformers` (see modeling_utils.py) def _determine_device_map( model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index e86ce2f64278..97970691f275 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -62,6 +62,38 @@ from torchao.quantization import quantize_ +def _update_torch_safe_globals(): + safe_globals = [ + (torch.uint1, "torch.uint1"), + (torch.uint2, "torch.uint2"), + (torch.uint3, "torch.uint3"), + (torch.uint4, "torch.uint4"), + (torch.uint5, "torch.uint5"), + (torch.uint6, "torch.uint6"), + (torch.uint7, "torch.uint7"), + ] + try: + from torchao.dtypes import NF4Tensor + from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl + from torchao.dtypes.uintx.uint4_layout import UInt4Tensor + from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor + + safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor]) + + except (ImportError, ModuleNotFoundError) as e: + logger.warning( + "Unable to import `torchao` Tensor objects. 
This may affect loading checkpoints serialized with `torchao`" + ) + logger.debug(e) + + finally: + torch.serialization.add_safe_globals(safe_globals=safe_globals) + + +if is_torch_version(">=", "2.6") and is_torchao_available() and is_torchao_version(">=", "0.7.0"): + _update_torch_safe_globals() + + logger = logging.get_logger(__name__) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 649705218a35..f23f576ed4e9 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -861,7 +861,7 @@ def is_torchao_version(operation: str, version: str): """ if not _is_torchao_available: return False - return compare_versions(parse(is_torch_version), operation, version) + return compare_versions(parse(_torchao_version), operation, version) def is_k_diffusion_version(operation: str, version: str): From 280a0aca4c522a50558542eac4dad4c9d2a398fb Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 10 Mar 2025 17:00:43 +0530 Subject: [PATCH 6/9] update --- src/diffusers/models/model_loading_utils.py | 2 -- src/diffusers/quantizers/torchao/torchao_quantizer.py | 9 ++++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 8b26c49db8b6..8d95fb7bfc75 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -41,8 +41,6 @@ is_gguf_available, is_torch_available, is_torch_version, - is_torchao_available, - is_torchao_version, logging, ) diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 97970691f275..c256e983166c 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -23,7 +23,14 @@ from packaging import version -from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging +from ...utils import ( + get_module_from_name, + is_torch_available, + is_torch_version, + is_torchao_available, + is_torchao_version, + logging, +) from ..base import DiffusersQuantizer From 6cf941c69f9361d1de0de918ab3da31b9b49f941 Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 10 Mar 2025 17:07:54 +0530 Subject: [PATCH 7/9] update --- src/diffusers/quantizers/torchao/torchao_quantizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index c256e983166c..b6b2b622eecf 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -97,7 +97,7 @@ def _update_torch_safe_globals(): torch.serialization.add_safe_globals(safe_globals=safe_globals) -if is_torch_version(">=", "2.6") and is_torchao_available() and is_torchao_version(">=", "0.7.0"): +if is_torch_version(">=", "2.6.0") and is_torchao_available() and is_torchao_version(">=", "0.7.0"): _update_torch_safe_globals() From fdf1c11e180d6844f480e741e979a8ee5a2fac2b Mon Sep 17 00:00:00 2001 From: DN6 Date: Mon, 10 Mar 2025 17:10:26 +0530 Subject: [PATCH 8/9] update --- src/diffusers/models/model_loading_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 8d95fb7bfc75..f019a3cc67a6 100644 --- a/src/diffusers/models/model_loading_utils.py +++ 
b/src/diffusers/models/model_loading_utils.py @@ -54,6 +54,7 @@ } } + if is_accelerate_available(): from accelerate import infer_auto_device_map from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device From 16ebfb77549753247a654cb31cbc3c6aa45ad0b9 Mon Sep 17 00:00:00 2001 From: DN6 Date: Tue, 11 Mar 2025 08:46:42 +0530 Subject: [PATCH 9/9] update --- src/diffusers/__init__.py | 22 +++++++------------ .../quantizers/torchao/torchao_quantizer.py | 7 +++++- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c482ed324179..6421ea871a75 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -2,20 +2,14 @@ from typing import TYPE_CHECKING -from diffusers.quantizers import quantization_config -from diffusers.utils import dummy_gguf_objects -from diffusers.utils.import_utils import ( - is_bitsandbytes_available, - is_gguf_available, - is_optimum_quanto_version, - is_torchao_available, -) - from .utils import ( DIFFUSERS_SLOW_IMPORT, OptionalDependencyNotAvailable, _LazyModule, + is_accelerate_available, + is_bitsandbytes_available, is_flax_available, + is_gguf_available, is_k_diffusion_available, is_librosa_available, is_note_seq_available, @@ -24,6 +18,7 @@ is_scipy_available, is_sentencepiece_available, is_torch_available, + is_torchao_available, is_torchsde_available, is_transformers_available, ) @@ -65,7 +60,7 @@ } try: - if not is_bitsandbytes_available(): + if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_bitsandbytes_objects @@ -77,7 +72,7 @@ _import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig") try: - if not is_gguf_available(): + if not is_torch_available() and not is_accelerate_available() and not is_gguf_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_gguf_objects @@ -89,7 +84,7 @@ _import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig") try: - if not is_torchao_available(): + if not is_torch_available() and not is_accelerate_available() and not is_torchao_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_torchao_objects @@ -101,7 +96,7 @@ _import_structure["quantizers.quantization_config"].append("TorchAoConfig") try: - if not is_optimum_quanto_available(): + if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_optimum_quanto_objects @@ -112,7 +107,6 @@ else: _import_structure["quantizers.quantization_config"].append("QuantoConfig") - try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 209135deb979..f9fb217ed6bd 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -97,7 +97,12 @@ def _update_torch_safe_globals(): torch.serialization.add_safe_globals(safe_globals=safe_globals) -if is_torch_version(">=", "2.6.0") and is_torchao_available() and is_torchao_version(">=", "0.7.0"): +if ( + is_torch_available() + and is_torch_version(">=", 
"2.6.0") + and is_torchao_available() + and is_torchao_version(">=", "0.7.0") +): _update_torch_safe_globals()