Skip to content

TypeError: BnB4BitDiffusersQuantizer.create_quantized_param() got an unexpected keyword argument 'dtype' #11149

Closed
@Anopoke

Description

@Anopoke

Describe the bug

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 13
     10 model_path = "/mnt/models/AI-ModelScope/stable-diffusion-3.5-large-turbo"
     11 torch_dtype = torch.float16
---> 13 transformer = SD3Transformer2DModel.from_pretrained(
     14     model_path,
     15     subfolder="transformer",
     16     quantization_config=BitsAndBytesConfig(
     17         load_in_4bit=True,
     18         bnb_4bit_quant_type="nf4",
     19         bnb_4bit_compute_dtype=torch.bfloat16,
     20         bnb_4bit_use_double_quant=True
     21     ),
     22     torch_dtype=torch_dtype
     23 )
     24 text_encoder = CLIPTextModelWithProjection.from_pretrained(
     25     model_path,
     26     subfolder="text_encoder",
   (...)     33     torch_dtype=torch_dtype
     34 )
     35 text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
     36     model_path,
     37     subfolder="text_encoder_2",
   (...)     44     torch_dtype=torch_dtype
     45 )

File ~/python_project/StoryFusion/.venv/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py:114, in validate_hf_hub_args.<locals>._inner_fn(*args, **kwargs)
    111 if check_use_auth_token:
    112     kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 114 return fn(*args, **kwargs)

File ~/python_project/StoryFusion/.venv/lib/python3.12/site-packages/diffusers/models/modeling_utils.py:1206, in ModelMixin.from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
   1196 if hf_quantizer is not None:
   1197     hf_quantizer.validate_environment(device_map=device_map)
   1199 (
   1200     model,
   1201     missing_keys,
   1202     unexpected_keys,
   1203     mismatched_keys,
   1204     offload_index,
   1205     error_msgs,
-> 1206 ) = cls._load_pretrained_model(
   1207     model,
   1208     state_dict,
   1209     resolved_model_file,
   1210     pretrained_model_name_or_path,
   1211     loaded_keys,
   1212     ignore_mismatched_sizes=ignore_mismatched_sizes,
   1213     low_cpu_mem_usage=low_cpu_mem_usage,
   1214     device_map=device_map,
   1215     offload_folder=offload_folder,
   1216     offload_state_dict=offload_state_dict,
   1217     dtype=torch_dtype,
   1218     hf_quantizer=hf_quantizer,
   1219     keep_in_fp32_modules=keep_in_fp32_modules,
   1220     dduf_entries=dduf_entries,
   1221 )
   1222 loading_info = {
   1223     "missing_keys": missing_keys,
   1224     "unexpected_keys": unexpected_keys,
   1225     "mismatched_keys": mismatched_keys,
   1226     "error_msgs": error_msgs,
   1227 }
   1229 # Dispatch model with hooks on all devices if necessary

File ~/python_project/StoryFusion/.venv/lib/python3.12/site-packages/diffusers/models/modeling_utils.py:1465, in ModelMixin._load_pretrained_model(cls, model, state_dict, resolved_model_file, pretrained_model_name_or_path, loaded_keys, ignore_mismatched_sizes, assign_to_params_buffers, hf_quantizer, low_cpu_mem_usage, dtype, keep_in_fp32_modules, device_map, offload_state_dict, offload_folder, dduf_entries)
   1457 mismatched_keys += _find_mismatched_keys(
   1458     state_dict,
   1459     model_state_dict,
   1460     loaded_keys,
   1461     ignore_mismatched_sizes,
   1462 )
   1464 if low_cpu_mem_usage:
-> 1465     offload_index, state_dict_index = load_model_dict_into_meta(
   1466         model,
   1467         state_dict,
   1468         device_map=device_map,
   1469         dtype=dtype,
   1470         hf_quantizer=hf_quantizer,
   1471         keep_in_fp32_modules=keep_in_fp32_modules,
   1472         unexpected_keys=unexpected_keys,
   1473         offload_folder=offload_folder,
   1474         offload_index=offload_index,
   1475         state_dict_index=state_dict_index,
   1476         state_dict_folder=state_dict_folder,
   1477     )
   1478 else:
   1479     if assign_to_params_buffers is None:

File ~/python_project/StoryFusion/.venv/lib/python3.12/site-packages/diffusers/models/model_loading_utils.py:298, in load_model_dict_into_meta(model, state_dict, dtype, model_name_or_path, hf_quantizer, keep_in_fp32_modules, device_map, unexpected_keys, offload_folder, offload_index, state_dict_index, state_dict_folder)
    294     state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
    295 elif is_quantized and (
    296     hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
    297 ):
--> 298     hf_quantizer.create_quantized_param(
    299         model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
    300     )
    301 else:
    302     set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)

TypeError: BnB4BitDiffusersQuantizer.create_quantized_param() got an unexpected keyword argument 'dtype'

Reproduction

import warnings

import torch
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
from diffusers.pipelines import StableDiffusion3Pipeline
from transformers import CLIPTextModelWithProjection, T5EncoderModel

warnings.filterwarnings("ignore")

model_path = "/mnt/models/AI-ModelScope/stable-diffusion-3.5-large-turbo"
torch_dtype = torch.float16

transformer = SD3Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True
    ),
    torch_dtype=torch_dtype
)

Logs

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-6.8.0-51-generic-x86_64-with-glibc2.39
  • Running on Google Colab?: No
  • Python version: 3.12.3
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.29.2
  • Transformers version: 4.49.0
  • Accelerate version: 1.5.1
  • PEFT version: 0.14.0
  • Bitsandbytes version: 0.45.3
  • Safetensors version: 0.5.3
  • xFormers version: 0.0.28.post3
  • Accelerator: Tesla V100-SXM2-16GB, 16384 MiB
    Tesla V100-SXM2-16GB, 16384 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions