Skip to content

Commit 9add071

Browse files
authored
[Quantization] Allow loading TorchAO serialized Tensor objects with torch>=2.6 (#11018)
* update * update * update * update * update * update * update * update * update
1 parent b88fef4 commit 9add071

File tree

5 files changed

+70
-16
lines changed

5 files changed

+70
-16
lines changed

docs/source/en/quantization/torchao.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
126126
image.save("output.png")
127127
```
128128

129-
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
129+
If you are using `torch<2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, even though saving them works as expected. To work around this, you can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should only be done if the weights were obtained from a trusted source.
130130

131131
```python
132132
import torch

src/diffusers/__init__.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,14 @@
22

33
from typing import TYPE_CHECKING
44

5-
from diffusers.quantizers import quantization_config
6-
from diffusers.utils import dummy_gguf_objects
7-
from diffusers.utils.import_utils import (
8-
is_bitsandbytes_available,
9-
is_gguf_available,
10-
is_optimum_quanto_version,
11-
is_torchao_available,
12-
)
13-
145
from .utils import (
156
DIFFUSERS_SLOW_IMPORT,
167
OptionalDependencyNotAvailable,
178
_LazyModule,
9+
is_accelerate_available,
10+
is_bitsandbytes_available,
1811
is_flax_available,
12+
is_gguf_available,
1913
is_k_diffusion_available,
2014
is_librosa_available,
2115
is_note_seq_available,
@@ -24,6 +18,7 @@
2418
is_scipy_available,
2519
is_sentencepiece_available,
2620
is_torch_available,
21+
is_torchao_available,
2722
is_torchsde_available,
2823
is_transformers_available,
2924
)
@@ -65,7 +60,7 @@
6560
}
6661

6762
try:
68-
if not is_bitsandbytes_available():
63+
if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
6964
raise OptionalDependencyNotAvailable()
7065
except OptionalDependencyNotAvailable:
7166
from .utils import dummy_bitsandbytes_objects
@@ -77,7 +72,7 @@
7772
_import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
7873

7974
try:
80-
if not is_gguf_available():
75+
if not is_torch_available() and not is_accelerate_available() and not is_gguf_available():
8176
raise OptionalDependencyNotAvailable()
8277
except OptionalDependencyNotAvailable:
8378
from .utils import dummy_gguf_objects
@@ -89,7 +84,7 @@
8984
_import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
9085

9186
try:
92-
if not is_torchao_available():
87+
if not is_torch_available() and not is_accelerate_available() and not is_torchao_available():
9388
raise OptionalDependencyNotAvailable()
9489
except OptionalDependencyNotAvailable:
9590
from .utils import dummy_torchao_objects
@@ -101,7 +96,7 @@
10196
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")
10297

10398
try:
104-
if not is_optimum_quanto_available():
99+
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
105100
raise OptionalDependencyNotAvailable()
106101
except OptionalDependencyNotAvailable:
107102
from .utils import dummy_optimum_quanto_objects
@@ -112,7 +107,6 @@
112107
else:
113108
_import_structure["quantizers.quantization_config"].append("QuantoConfig")
114109

115-
116110
try:
117111
if not is_onnx_available():
118112
raise OptionalDependencyNotAvailable()

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,14 @@
2323

2424
from packaging import version
2525

26-
from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
26+
from ...utils import (
27+
get_module_from_name,
28+
is_torch_available,
29+
is_torch_version,
30+
is_torchao_available,
31+
is_torchao_version,
32+
logging,
33+
)
2734
from ..base import DiffusersQuantizer
2835

2936

@@ -62,6 +69,43 @@
6269
from torchao.quantization import quantize_
6370

6471

72+
def _update_torch_safe_globals():
    """
    Register the objects needed to load `torchao`-serialized checkpoints as torch "safe globals".

    The `torch.uint1` ... `torch.uint7` dtypes are always registered. The `torchao` tensor classes are added on top
    of them when they can be imported; if the import fails, a warning is logged and only the dtypes are registered.
    """
    safe_globals = [
        (torch.uint1, "torch.uint1"),
        (torch.uint2, "torch.uint2"),
        (torch.uint3, "torch.uint3"),
        (torch.uint4, "torch.uint4"),
        (torch.uint5, "torch.uint5"),
        (torch.uint6, "torch.uint6"),
        (torch.uint7, "torch.uint7"),
    ]
    try:
        from torchao.dtypes import NF4Tensor
        from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
        from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
        from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor

        safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])

    except (ImportError, ModuleNotFoundError) as e:
        # This function is called at import time, before the module-level `logger` has been assigned, so
        # referring to that global here would raise `NameError`. Resolve the logger locally instead.
        _logger = logging.get_logger(__name__)
        _logger.warning(
            "Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
        )
        _logger.debug(e)

    finally:
        # Runs on both paths so that at least the uint dtypes are always registered.
        torch.serialization.add_safe_globals(safe_globals=safe_globals)
98+
99+
100+
if (
101+
is_torch_available()
102+
and is_torch_version(">=", "2.6.0")
103+
and is_torchao_available()
104+
and is_torchao_version(">=", "0.7.0")
105+
):
106+
_update_torch_safe_globals()
107+
108+
65109
logger = logging.get_logger(__name__)
66110

67111

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
is_torch_xla_available,
9595
is_torch_xla_version,
9696
is_torchao_available,
97+
is_torchao_version,
9798
is_torchsde_available,
9899
is_torchvision_available,
99100
is_transformers_available,

src/diffusers/utils/import_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,21 @@ def is_gguf_version(operation: str, version: str):
868868
return compare_versions(parse(_gguf_version), operation, version)
869869

870870

871+
def is_torchao_version(operation: str, version: str):
    """
    Check the installed torchao version against a reference version.

    Args:
        operation (`str`):
            The comparison operator as a string, e.g. `">"` or `"<="`.
        version (`str`):
            The reference version string to compare against.

    Returns `False` when torchao is not available; otherwise returns the result of the comparison.
    """
    if _is_torchao_available:
        installed_version = parse(_torchao_version)
        return compare_versions(installed_version, operation, version)
    return False
884+
885+
871886
def is_k_diffusion_version(operation: str, version: str):
872887
"""
873888
Compares the current k-diffusion version to a given reference with an operation.

0 commit comments

Comments
 (0)