@@ -4406,6 +4406,311 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         super().unfuse_lora(components=components)


+class CogView4LoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`CogView4Transformer2DModel`]. Specific to [`CogView4Pipeline`].
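+
+    A minimal usage sketch (the `THUDM/CogView4-6B` repo id, LoRA path, and prompt below are
+    placeholders for illustration):
+
+    ```py
+    import torch
+
+    from diffusers import CogView4Pipeline
+
+    pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
+    pipe.load_lora_weights("path/to/lora", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogview-lora")
+    image = pipe("a watercolor painting of a lighthouse at dawn").images[0]
+    ```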
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
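+    # `_lora_loadable_modules` names the pipeline components that can receive LoRA layers;
+    # `transformer_name` is the prefix those layers use in serialized LoRA state dicts.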
+
+    @classmethod
+    @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
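+        # When `use_safetensors` is not specified, safetensors loading is tried first and
+        # `allow_pickle=True` lets `_fetch_state_dict` fall back to a pickle-based (`.bin`) file.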
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        return state_dict
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
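+        # PEFT-formatted LoRA weights carry a "lora" substring in every key (e.g. `...lora_A.weight`),
+        # so this serves as a lightweight sanity check on the checkpoint format.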
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
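+        # Some pipelines register the denoiser under a custom attribute name, so fall back to the
+        # attribute named by `self.transformer_name` when no plain `.transformer` attribute exists.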
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`CogView4Transformer2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the UNet and text encoder.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
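+            # `pack_weights` prefixes every key with `cls.transformer_name` ("transformer"), which is
+            # the prefix `load_lora_weights` expects when reading these layers back in.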
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+        """
+        super().unfuse_lora(components=components)
+
+
 class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     def __init__(self, *args, **kwargs):
         deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."