Skip to content

Commit f4c8d6c

Browse files
committed
fix: missing AutoencoderKL lora adapter
1 parent 0d1d267 commit f4c8d6c

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@
3131
)
3232
from ..modeling_outputs import AutoencoderKLOutput
3333
from ..modeling_utils import ModelMixin
34+
from ...loaders import PeftAdapterMixin
3435
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
3536

3637

37-
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
38+
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
3839
r"""
3940
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
4041

tests/models/autoencoders/test_models_vae.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,39 @@ def test_stable_diffusion_encode_sample(self, seed, expected_slice):
846846
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
847847

848848

849+
def test_lora_adapter(self):
850+
851+
vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
852+
853+
target_modules_vae = [
854+
"conv1",
855+
"conv2",
856+
"conv_in",
857+
"conv_shortcut",
858+
"conv",
859+
"conv_out",
860+
"skip_conv_1",
861+
"skip_conv_2",
862+
"skip_conv_3",
863+
"skip_conv_4",
864+
"to_k",
865+
"to_q",
866+
"to_v",
867+
"to_out.0",
868+
]
869+
vae_lora_config = LoraConfig(
870+
r=16,
871+
init_lora_weights="gaussian",
872+
target_modules=target_modules_vae,
873+
)
874+
875+
vae.add_adapter(vae_lora_config, adapter_name="vae_lora")
876+
877+
active_lora = vae.active_adapters()
878+
self.assertTrue(len(active_lora) == 1)
879+
self.assertTrue(active_lora[0] == "vae_lora")
880+
881+
849882
@slow
850883
class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
851884
def get_file_format(self, seed, shape):

0 commit comments

Comments
 (0)