
Commit 25f826e

Merge branch 'main' into metadata-lora
2 parents: ba546bc + 7edace9

File tree: 13 files changed (+632 −45 lines)


docs/source/en/api/loaders/lora.md

Lines changed: 4 additions & 0 deletions

@@ -20,6 +20,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
 - [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux).
 - [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox).
 - [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi).
+- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
 - [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
 - [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).
 - [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).

@@ -56,6 +57,9 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
 ## Mochi1LoraLoaderMixin

 [[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
+## AuraFlowLoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin

 ## LTXVideoLoraLoaderMixin

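For orientation, the [`AuraFlowLoraLoaderMixin`] documented above is meant to give [`AuraFlowPipeline`] the same `load_lora_weights` entry point as the other per-pipeline LoRA mixins in this list. A minimal usage sketch (the LoRA repository id and adapter name below are placeholders, not artifacts referenced by this commit):

import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# `load_lora_weights` is provided by AuraFlowLoraLoaderMixin, which this commit adds to the pipeline.
pipe.load_lora_weights("your-namespace/your-auraflow-lora", adapter_name="my_lora")  # placeholder repo id

image = pipe(prompt="a photo of an astronaut riding a horse").images[0]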
src/diffusers/loaders/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -65,6 +65,7 @@ def text_encoder_attn_modules(text_encoder):
             "AmusedLoraLoaderMixin",
             "StableDiffusionLoraLoaderMixin",
             "SD3LoraLoaderMixin",
+            "AuraFlowLoraLoaderMixin",
             "StableDiffusionXLLoraLoaderMixin",
             "LTXVideoLoraLoaderMixin",
             "LoraLoaderMixin",
@@ -103,6 +104,7 @@ def text_encoder_attn_modules(text_encoder):
         )
         from .lora_pipeline import (
             AmusedLoraLoaderMixin,
+            AuraFlowLoraLoaderMixin,
             CogVideoXLoraLoaderMixin,
             CogView4LoraLoaderMixin,
             FluxLoraLoaderMixin,

src/diffusers/loaders/lora_pipeline.py

Lines changed: 333 additions & 0 deletions (large diff not rendered by default)

src/diffusers/loaders/peft.py

Lines changed: 1 addition & 0 deletions

@@ -53,6 +53,7 @@
     "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
     "LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
     "SanaTransformer2DModel": lambda model_cls, weights: weights,
+    "AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
     "Lumina2Transformer2DModel": lambda model_cls, weights: weights,
     "WanTransformer3DModel": lambda model_cls, weights: weights,
     "CogView4Transformer2DModel": lambda model_cls, weights: weights,

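Registering `AuraFlowTransformer2DModel` in this mapping is what lets the standard adapter-scale helpers treat AuraFlow LoRA layers like those of the other transformers. A rough sketch of the expected user-facing effect, assuming the usual `LoraBaseMixin` helpers apply and given an `AuraFlowPipeline` named `pipe` with adapters already loaded under these hypothetical names:

# "style_lora" and "detail_lora" are hypothetical adapter names loaded earlier via load_lora_weights.
pipe.set_adapters(["style_lora", "detail_lora"], adapter_weights=[0.8, 0.5])

# LoRA can also be switched off and back on between runs.
pipe.disable_lora()
pipe.enable_lora()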
src/diffusers/models/autoencoders/vae.py

Lines changed: 1 addition & 1 deletion

@@ -255,7 +255,7 @@ def __init__(
                 num_layers=self.layers_per_block + 1,
                 in_channels=prev_output_channel,
                 out_channels=output_channel,
-                prev_output_channel=None,
+                prev_output_channel=prev_output_channel,
                 add_upsample=not is_final_block,
                 resnet_eps=1e-6,
                 resnet_act_fn=act_fn,

src/diffusers/models/lora.py

Lines changed: 2 additions & 2 deletions

@@ -38,7 +38,7 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def text_encoder_attn_modules(text_encoder):
+def text_encoder_attn_modules(text_encoder: nn.Module):
     attn_modules = []

     if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
@@ -52,7 +52,7 @@ def text_encoder_attn_modules(text_encoder):
     return attn_modules


-def text_encoder_mlp_modules(text_encoder):
+def text_encoder_mlp_modules(text_encoder: nn.Module):
     mlp_modules = []

     if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 48 additions & 10 deletions

@@ -13,15 +13,15 @@
 # limitations under the License.


-from typing import Dict, Union
+from typing import Any, Dict, Optional, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin
-from ...utils import logging
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_processor import (
     Attention,
@@ -160,14 +160,20 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
         self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
         self.ff = AuraFlowFeedForward(dim, dim * 4)

-    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         residual = hidden_states
+        attention_kwargs = attention_kwargs or {}

         # Norm + Projection.
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

         # Attention.
-        attn_output = self.attn(hidden_states=norm_hidden_states)
+        attn_output = self.attn(hidden_states=norm_hidden_states, **attention_kwargs)

         # Process attention outputs for the `hidden_states`.
         hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
@@ -223,10 +229,15 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
         self.ff_context = AuraFlowFeedForward(dim, dim * 4)

     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         residual = hidden_states
         residual_context = encoder_hidden_states
+        attention_kwargs = attention_kwargs or {}

         # Norm + Projection.
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
@@ -236,7 +247,9 @@ def forward(

         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **attention_kwargs,
         )

         # Process attention outputs for the `hidden_states`.
@@ -254,7 +267,7 @@ def forward(
         return encoder_hidden_states, hidden_states


-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).

@@ -449,8 +462,24 @@ def forward(
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor = None,
         timestep: torch.LongTensor = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
         height, width = hidden_states.shape[-2:]

         # Apply patch embedding, timestep embedding, and project the caption embeddings.
@@ -474,7 +503,10 @@ def forward(

             else:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    attention_kwargs=attention_kwargs,
                 )

         # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
@@ -491,7 +523,9 @@ def forward(
                 )

             else:
-                combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb)
+                combined_hidden_states = block(
+                    hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
+                )

         hidden_states = combined_hidden_states[:, encoder_seq_len:]

@@ -512,6 +546,10 @@ def forward(
             shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
         )

+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (output,)

src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

Lines changed: 41 additions & 4 deletions

@@ -12,17 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import T5Tokenizer, UMT5EncoderModel

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
+from ...loaders import AuraFlowLoraLoaderMixin
 from ...models import AuraFlowTransformer2DModel, AutoencoderKL
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import (
+    USE_PEFT_BACKEND,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

@@ -112,7 +120,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class AuraFlowPipeline(DiffusionPipeline):
+class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
     r"""
     Args:
         tokenizer (`T5TokenizerFast`):
@@ -233,6 +241,7 @@ def encode_prompt(
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         max_sequence_length: int = 256,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -259,10 +268,20 @@ def encode_prompt(
             negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                 Pre-generated attention mask for negative text embeddings.
             max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
         if device is None:
             device = self._execution_device
-
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -346,6 +365,11 @@ def encode_prompt(
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None

+        if self.text_encoder is not None:
+            if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
         return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
@@ -403,6 +427,10 @@ def upcast_vae(self):
     def guidance_scale(self):
         return self._guidance_scale

+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
     @property
     def num_timesteps(self):
         return self._num_timesteps
@@ -428,6 +456,7 @@ def __call__(
         max_sequence_length: int = 256,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
@@ -486,6 +515,10 @@ def __call__(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             callback_on_step_end (`Callable`, *optional*):
                 A function that calls at the end of each denoising steps during the inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -520,6 +553,7 @@ def __call__(
         )

         self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs

         # 2. Determine batch size.
         if prompt is not None and isinstance(prompt, str):
@@ -530,6 +564,7 @@ def __call__(
             batch_size = prompt_embeds.shape[0]

         device = self._execution_device
+        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -553,6 +588,7 @@ def __call__(
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
             max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
         )
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
@@ -594,6 +630,7 @@ def __call__(
                 encoder_hidden_states=prompt_embeds,
                 timestep=timestep,
                 return_dict=False,
+                attention_kwargs=self.attention_kwargs,
             )[0]

             # perform guidance

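Taken together, the pipeline changes route a runtime LoRA scale through `attention_kwargs`: the transformer pops the `scale` key and applies it with `scale_lora_layers`/`unscale_lora_layers`, and `encode_prompt` does the same for the text encoder. A hedged end-to-end sketch (the LoRA repository id is a placeholder):

import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("your-namespace/your-auraflow-lora")  # placeholder repo id

# With the PEFT backend enabled, the "scale" entry is popped from attention_kwargs
# and applied to every LoRA layer for this call only.
image = pipe(
    prompt="a watercolor painting of a lighthouse at dusk",
    num_inference_steps=50,
    guidance_scale=3.5,
    attention_kwargs={"scale": 0.7},
).images[0]
image.save("auraflow_lora.png")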