
Commit 1235862

linjiapro and yiyixuxu authored

Improve control net block index for sd3 (#9758)

* improve control net index

Co-authored-by: YiYi Xu <yixu310@gmail.com>

1 parent 805aa93 · commit 1235862

File tree

3 files changed: +12 -4 lines changed

src/diffusers/models/controlnets/controlnet_sd3.py

Lines changed: 5 additions & 1 deletion

@@ -56,6 +56,8 @@ def __init__(
         out_channels: int = 16,
         pos_embed_max_size: int = 96,
         extra_conditioning_channels: int = 0,
+        dual_attention_layers: Tuple[int, ...] = (),
+        qk_norm: Optional[str] = None,
     ):
         super().__init__()
         default_out_channels = in_channels
@@ -84,6 +86,8 @@ def __init__(
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=self.config.attention_head_dim,
                     context_pre_only=False,
+                    qk_norm=qk_norm,
+                    use_dual_attention=True if i in dual_attention_layers else False,
                 )
                 for i in range(num_layers)
             ]
@@ -248,7 +252,7 @@ def from_transformer(
         config = transformer.config
         config["num_layers"] = num_layers or config.num_layers
         config["extra_conditioning_channels"] = num_extra_conditioning_channels
-        controlnet = cls(**config)
+        controlnet = cls.from_config(config)

         if load_weights_from_transformer:
             controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
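
The from_transformer change swaps cls(**config) for cls.from_config(config), presumably so that entries in the transformer's config that the controlnet's __init__ does not declare are filtered out rather than passed through to the constructor. A minimal, hypothetical usage sketch of the method touched here; the repository id and the argument values are illustrative, not taken from the commit:

    from diffusers import SD3ControlNetModel, SD3Transformer2DModel

    # Load a base SD3 transformer (example repo id, assumed to be available on the Hub or locally).
    transformer = SD3Transformer2DModel.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer"
    )

    # Build a controlnet from the transformer's config. After this commit the config goes
    # through cls.from_config, so config keys the controlnet does not define are ignored
    # instead of causing a TypeError in __init__.
    controlnet = SD3ControlNetModel.from_transformer(
        transformer, num_layers=3, num_extra_conditioning_channels=1
    )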

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 3 additions & 1 deletion

@@ -15,6 +15,7 @@

 from typing import Any, Dict, List, Optional, Tuple, Union

+import numpy as np
 import torch
 import torch.nn as nn

@@ -349,7 +350,8 @@ def custom_forward(*inputs):

             # controlnet residual
             if block_controlnet_hidden_states is not None and block.context_pre_only is False:
-                interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
+                interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
+                interval_control = int(np.ceil(interval_control))
                 hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]

         hidden_states = self.norm_out(hidden_states, temb)
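
This is the block-index fix the commit title refers to: with floor division, a transformer whose depth is not an exact multiple of the number of controlnet hidden states maps its last blocks to an out-of-range controlnet index, while rounding the interval up keeps index_block // interval_control inside the list. A standalone sketch with illustrative block counts (not values from the commit):

    import numpy as np

    num_transformer_blocks = 13  # illustrative transformer depth
    num_controlnet_blocks = 3    # illustrative number of controlnet hidden states

    old_interval = num_transformer_blocks // num_controlnet_blocks               # 4
    new_interval = int(np.ceil(num_transformer_blocks / num_controlnet_blocks))  # 5

    old_indices = [i // old_interval for i in range(num_transformer_blocks)]
    new_indices = [i // new_interval for i in range(num_transformer_blocks)]

    print(max(old_indices))  # 3 -> would index past a 3-element list of controlnet states
    print(max(new_indices))  # 2 -> always within range(num_controlnet_blocks)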

tests/pipelines/controlnet_sd3/test_controlnet_sd3.py

Lines changed: 4 additions & 2 deletions

@@ -15,6 +15,7 @@

 import gc
 import unittest
+from typing import Optional

 import numpy as np
 import pytest
@@ -59,7 +60,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
     )
     batch_params = frozenset(["prompt", "negative_prompt"])

-    def get_dummy_components(self):
+    def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm"):
         torch.manual_seed(0)
         transformer = SD3Transformer2DModel(
             sample_size=32,
@@ -72,14 +73,15 @@ def get_dummy_components(self):
             caption_projection_dim=32,
             pooled_projection_dim=64,
             out_channels=8,
+            qk_norm=qk_norm,
         )

         torch.manual_seed(0)
         controlnet = SD3ControlNetModel(
             sample_size=32,
             patch_size=1,
             in_channels=8,
-            num_layers=1,
+            num_layers=num_controlnet_layers,
             attention_head_dim=8,
             num_attention_heads=4,
             joint_attention_dim=32,
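
The test helper now exposes the controlnet depth and qk_norm, so the fast tests can build a controlnet whose layer count differs from the transformer's. For reference, a hypothetical standalone construction of a tiny controlnet exercising the new constructor arguments; the sizes mirror the dummy components above, and dual_attention_layers=(0,) is an illustrative choice rather than something set in this commit:

    import torch
    from diffusers import SD3ControlNetModel

    torch.manual_seed(0)
    controlnet = SD3ControlNetModel(
        sample_size=32,
        patch_size=1,
        in_channels=8,
        num_layers=3,
        attention_head_dim=8,
        num_attention_heads=4,
        joint_attention_dim=32,
        caption_projection_dim=32,
        pooled_projection_dim=64,
        out_channels=8,
        qk_norm="rms_norm",          # new argument added by this commit
        dual_attention_layers=(0,),  # new argument; marks block 0 as a dual-attention block
    )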
