Skip to content

[Quantization] Allow loading TorchAO serialized Tensor objects with torch>=2.6 #11018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/quantization/torchao.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```

Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.

```python
import torch
Expand Down
1 change: 0 additions & 1 deletion src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
}
}


if is_accelerate_available():
from accelerate import infer_auto_device_map
from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
Expand Down
41 changes: 40 additions & 1 deletion src/diffusers/quantizers/torchao/torchao_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@

from packaging import version

from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
from ...utils import (
get_module_from_name,
is_torch_available,
is_torch_version,
is_torchao_available,
is_torchao_version,
logging,
)
from ..base import DiffusersQuantizer


Expand Down Expand Up @@ -62,6 +69,38 @@
from torchao.quantization import quantize_


def _update_torch_safe_globals():
safe_globals = [
(torch.uint1, "torch.uint1"),
(torch.uint2, "torch.uint2"),
(torch.uint3, "torch.uint3"),
(torch.uint4, "torch.uint4"),
(torch.uint5, "torch.uint5"),
(torch.uint6, "torch.uint6"),
(torch.uint7, "torch.uint7"),
]
try:
from torchao.dtypes import NF4Tensor
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor

safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])

except (ImportError, ModuleNotFoundError) as e:
logger.warning(
"Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
)
logger.debug(e)

finally:
torch.serialization.add_safe_globals(safe_globals=safe_globals)


if is_torch_version(">=", "2.6.0") and is_torchao_available() and is_torchao_version(">=", "0.7.0"):
_update_torch_safe_globals()


logger = logging.get_logger(__name__)


Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
is_torch_xla_available,
is_torch_xla_version,
is_torchao_available,
is_torchao_version,
is_torchsde_available,
is_torchvision_available,
is_transformers_available,
Expand Down
15 changes: 15 additions & 0 deletions src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,21 @@ def is_gguf_version(operation: str, version: str):
return compare_versions(parse(_gguf_version), operation, version)


def is_torchao_version(operation: str, version: str):
"""
Compares the current torchao version to a given reference with an operation.

Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _is_torchao_available:
return False
return compare_versions(parse(_torchao_version), operation, version)


def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.
Expand Down
Loading