Description
Describe the bug
Certain adapter functions in PeftAdapterMixin, such as enable_adapters and disable_adapters, do not work if the only adapters that were added were LoRAs loaded with load_lora_adapter, instead failing with a ValueError indicating that no adapters are loaded:
>>> transformer.enable_adapters()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/<redacted>/diffusers/loaders/peft.py", line 604, in enable_adapters
raise ValueError("No adapter loaded. Please load an adapter first.")
ValueError: No adapter loaded. Please load an adapter first.
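For context, the check that fires here gates on a private flag rather than on the presence of injected LoRA layers; a paraphrased sketch of its shape (not the actual source) is:

# Paraphrased sketch of the guard in PeftAdapterMixin.enable_adapters (not verbatim)
class _GuardSketch:
    _hf_peft_config_loaded = False  # only flipped by add_adapter, as discussed below

    def enable_adapters(self):
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")
        # ...otherwise walk the modules and enable the injected adapter layers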
I encountered this bug when I was trying to apply a LoRA to a SD3Transformer2DModel instead of a StableDiffusion3Pipeline, as the training script I'm using does not initialize a StableDiffusion3Pipeline during training.
Based on a cursory glance at the code, it looks like the issue is that _hf_peft_config_loaded is not set to True when load_lora_adapter is called, before inject_adapter_in_model:
diffusers/src/diffusers/loaders/peft.py, lines 301 to 322 in 7aac77a
Compare to the corresponding code in add_adapter:
diffusers/src/diffusers/loaders/peft.py, lines 504 to 520 in 7aac77a
Setting the _hf_peft_config_loaded flag to True works around the issue.
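Until a fix lands inside load_lora_adapter itself, the workaround can also be packaged as a small wrapper applied once at import time; the patch below is my own sketch, not part of the diffusers API:

# Hedged sketch: wrap load_lora_adapter so the flag is set the same way add_adapter sets it.
from diffusers.loaders.peft import PeftAdapterMixin

_orig_load_lora_adapter = PeftAdapterMixin.load_lora_adapter

def _load_lora_adapter_and_set_flag(self, *args, **kwargs):
    out = _orig_load_lora_adapter(self, *args, **kwargs)
    # Mirror add_adapter so enable_adapters()/disable_adapters() see the loaded adapter.
    self._hf_peft_config_loaded = True
    return out

PeftAdapterMixin.load_lora_adapter = _load_lora_adapter_and_set_flag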
I intend to submit a PR unless there's a reason why the code is set up this way.
Reproduction
from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
import os
sd3_path='/path/to/stable-diffusion-3-medium-diffusers'
# Prerequisite: we need to have a LoRA to test with.
# In lieu of finding a premade LoRA, we can just make a fresh dummy LoRA from scratch
transformer = SD3Transformer2DModel.from_pretrained(sd3_path, subfolder='transformer')
lc = LoraConfig(
    r=4, lora_alpha=2, init_lora_weights='gaussian',
    target_modules=["attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
                    "attn.to_k", "attn.to_out.0", "attn.to_q", "attn.to_v"],
)
transformer.add_adapter(lc)
test_lora = get_peft_model_state_dict(transformer)
StableDiffusion3Pipeline.save_lora_weights(
    save_directory='test_lora', transformer_lora_layers=test_lora,
    text_encoder_lora_layers=None, text_encoder_2_lora_layers=None)
# Reinitialize the transformer to a fresh state to reproduce the bug
transformer = SD3Transformer2DModel.from_pretrained(sd3_path, subfolder='transformer')
transformer.load_lora_adapter(os.path.join('test_lora', 'pytorch_lora_weights.safetensors'), adapter_name='bugtest_lora')
# Some commands work
transformer.set_adapters(['bugtest_lora'], weights=[0.5])
# This one should work, but it doesn't
transformer.enable_adapters()
# Workaround: set _hf_peft_config_loaded to True
transformer._hf_peft_config_loaded = True
transformer.enable_adapters()
There's very likely a simpler way to reproduce this that doesn't require SD3.
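For instance, a tiny, randomly initialized SD3Transformer2DModel should avoid the checkpoint download entirely; the sketch below is untested and the small config values are guesses that may need adjusting:

import os
from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

# Arbitrarily small dims so the model builds instantly and needs no pretrained weights.
tiny_kwargs = dict(
    sample_size=8, patch_size=2, in_channels=4, out_channels=4, num_layers=2,
    attention_head_dim=8, num_attention_heads=2, joint_attention_dim=16,
    caption_projection_dim=16, pooled_projection_dim=16, pos_embed_max_size=16,
)
transformer = SD3Transformer2DModel(**tiny_kwargs)
lc = LoraConfig(r=4, lora_alpha=2, init_lora_weights='gaussian',
                target_modules=["attn.to_q", "attn.to_k", "attn.to_v", "attn.to_out.0"])
transformer.add_adapter(lc)
StableDiffusion3Pipeline.save_lora_weights(
    save_directory='test_lora',
    transformer_lora_layers=get_peft_model_state_dict(transformer),
    text_encoder_lora_layers=None, text_encoder_2_lora_layers=None)

# Fresh instance with the same config, then load the LoRA and hit the bug.
transformer = SD3Transformer2DModel(**tiny_kwargs)
transformer.load_lora_adapter(os.path.join('test_lora', 'pytorch_lora_weights.safetensors'),
                              adapter_name='bugtest_lora')
transformer.enable_adapters()  # expected to raise the same ValueError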
Logs
System Info
Using a clean test environment with current diffusers git revision 1ddf3f3:
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.11.11
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.29.3
- Transformers version: 4.50.0
- Accelerate version: 1.5.2
- PEFT version: 0.11.1
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA RTX A6000, 49140 MiB
- Using GPU in script?: Yes, but likely irrelevant
- Using distributed or parallel set-up in script?: N/A