
Commit 7761b89

yiyixuxu authored
update conversion script for Kandinsky unet (#3766)
* update kandinsky conversion script

* style

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
1 parent ce55049 commit 7761b89
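
In short, this commit retires the standalone KandinskyTextProjModel: its projection weights are now mapped directly into UNet2DConditionModel through the new `addition_embed_type="text_image"` and `encoder_hid_dim_type="text_image_proj"` config options, so the script emits a single UNet checkpoint. A minimal sketch of loading the result (the dump path below is a hypothetical example):

from diffusers import UNet2DConditionModel

# After e.g. `--debug text2img --dump_path ./kandinsky`, the converted UNet
# loads on its own; the old text_proj subfolder is no longer written.
unet = UNet2DConditionModel.from_pretrained("./kandinsky/unet")
print(unet.config.addition_embed_type)   # "text_image"
print(unet.config.encoder_hid_dim_type)  # "text_image_proj"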

File tree

1 file changed: +89 −78 lines changed

scripts/convert_kandinsky_to_diffusers.py

Lines changed: 89 additions & 78 deletions
@@ -8,7 +8,6 @@
 from diffusers import UNet2DConditionModel
 from diffusers.models.prior_transformer import PriorTransformer
 from diffusers.models.vq_model import VQModel
-from diffusers.pipelines.kandinsky.text_proj import KandinskyTextProjModel


 """
@@ -225,37 +224,55 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix

 UNET_CONFIG = {
     "act_fn": "silu",
+    "addition_embed_type": "text_image",
+    "addition_embed_type_num_heads": 64,
     "attention_head_dim": 64,
-    "block_out_channels": (384, 768, 1152, 1536),
+    "block_out_channels": [384, 768, 1152, 1536],
     "center_input_sample": False,
-    "class_embed_type": "identity",
+    "class_embed_type": None,
+    "class_embeddings_concat": False,
+    "conv_in_kernel": 3,
+    "conv_out_kernel": 3,
     "cross_attention_dim": 768,
-    "down_block_types": (
+    "cross_attention_norm": None,
+    "down_block_types": [
         "ResnetDownsampleBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
-    ),
+    ],
     "downsample_padding": 1,
     "dual_cross_attention": False,
+    "encoder_hid_dim": 1024,
+    "encoder_hid_dim_type": "text_image_proj",
     "flip_sin_to_cos": True,
     "freq_shift": 0,
     "in_channels": 4,
     "layers_per_block": 3,
+    "mid_block_only_cross_attention": None,
     "mid_block_scale_factor": 1,
     "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
     "norm_eps": 1e-05,
     "norm_num_groups": 32,
+    "num_class_embeds": None,
     "only_cross_attention": False,
     "out_channels": 8,
+    "projection_class_embeddings_input_dim": None,
+    "resnet_out_scale_factor": 1.0,
+    "resnet_skip_time_act": False,
     "resnet_time_scale_shift": "scale_shift",
     "sample_size": 64,
-    "up_block_types": (
+    "time_cond_proj_dim": None,
+    "time_embedding_act_fn": None,
+    "time_embedding_dim": None,
+    "time_embedding_type": "positional",
+    "timestep_post_act": None,
+    "up_block_types": [
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "ResnetUpsampleBlock2D",
-    ),
+    ],
     "upcast_attention": False,
     "use_linear_projection": False,
 }
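
For reference, this dict feeds straight into the diffusers UNet constructor. The text2img builder itself is untouched by this diff, so the following is a sketch by analogy with `inpaint_unet_model_from_original_config` (visible in a hunk header further down); `UNet2DConditionModel` is imported at the top of the script.

def unet_model_from_original_config():
    # Assumed shape of the builder: instantiate the diffusers UNet
    # directly from the UNET_CONFIG dict above.
    model = UNet2DConditionModel(**UNET_CONFIG)
    return model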
@@ -274,6 +291,8 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):

     diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
     diffusers_checkpoint.update(unet_conv_in(checkpoint))
+    diffusers_checkpoint.update(unet_add_embedding(checkpoint))
+    diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))

     # <original>.input_blocks -> <diffusers>.down_blocks

@@ -336,37 +355,55 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):

 INPAINT_UNET_CONFIG = {
     "act_fn": "silu",
+    "addition_embed_type": "text_image",
+    "addition_embed_type_num_heads": 64,
     "attention_head_dim": 64,
-    "block_out_channels": (384, 768, 1152, 1536),
+    "block_out_channels": [384, 768, 1152, 1536],
     "center_input_sample": False,
-    "class_embed_type": "identity",
+    "class_embed_type": None,
+    "class_embeddings_concat": None,
+    "conv_in_kernel": 3,
+    "conv_out_kernel": 3,
     "cross_attention_dim": 768,
-    "down_block_types": (
+    "cross_attention_norm": None,
+    "down_block_types": [
         "ResnetDownsampleBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
-    ),
+    ],
     "downsample_padding": 1,
     "dual_cross_attention": False,
+    "encoder_hid_dim": 1024,
+    "encoder_hid_dim_type": "text_image_proj",
     "flip_sin_to_cos": True,
     "freq_shift": 0,
     "in_channels": 9,
     "layers_per_block": 3,
+    "mid_block_only_cross_attention": None,
     "mid_block_scale_factor": 1,
     "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
     "norm_eps": 1e-05,
     "norm_num_groups": 32,
+    "num_class_embeds": None,
     "only_cross_attention": False,
     "out_channels": 8,
+    "projection_class_embeddings_input_dim": None,
+    "resnet_out_scale_factor": 1.0,
+    "resnet_skip_time_act": False,
     "resnet_time_scale_shift": "scale_shift",
     "sample_size": 64,
-    "up_block_types": (
+    "time_cond_proj_dim": None,
+    "time_embedding_act_fn": None,
+    "time_embedding_dim": None,
+    "time_embedding_type": "positional",
+    "timestep_post_act": None,
+    "up_block_types": [
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "ResnetUpsampleBlock2D",
-    ),
+    ],
     "upcast_attention": False,
     "use_linear_projection": False,
 }
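
The inpainting config differs from UNET_CONFIG only in `in_channels` (9 instead of 4) and `class_embeddings_concat`. The 9 channels follow the usual inpainting layout; the concatenation order below is an assumption for illustration, not something this diff shows.

import torch

# Hypothetical shapes for one 64x64 MoVQ latent
latents = torch.randn(1, 4, 64, 64)       # noisy latents
masked_image = torch.randn(1, 4, 64, 64)  # latents of the masked input image
mask = torch.ones(1, 1, 64, 64)           # inpainting mask
unet_input = torch.cat([latents, masked_image, mask], dim=1)  # 4 + 4 + 1 = 9 channels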
@@ -381,10 +418,12 @@ def inpaint_unet_model_from_original_config():
 def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
     diffusers_checkpoint = {}

-    num_head_channels = UNET_CONFIG["attention_head_dim"]
+    num_head_channels = INPAINT_UNET_CONFIG["attention_head_dim"]

     diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
     diffusers_checkpoint.update(unet_conv_in(checkpoint))
+    diffusers_checkpoint.update(unet_add_embedding(checkpoint))
+    diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))

     # <original>.input_blocks -> <diffusers>.down_blocks

@@ -440,38 +479,6 @@ def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):

 # done inpaint unet

-# text proj
-
-TEXT_PROJ_CONFIG = {}
-
-
-def text_proj_from_original_config():
-    model = KandinskyTextProjModel(**TEXT_PROJ_CONFIG)
-    return model
-
-
-# Note that the input checkpoint is the original text2img model checkpoint
-def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint):
-    diffusers_checkpoint = {
-        # <original>.text_seq_proj.0 -> <diffusers>.encoder_hidden_states_proj
-        "encoder_hidden_states_proj.weight": checkpoint["to_model_dim_n.weight"],
-        "encoder_hidden_states_proj.bias": checkpoint["to_model_dim_n.bias"],
-        # <original>.clip_tok_proj -> <diffusers>.clip_extra_context_tokens_proj
-        "clip_extra_context_tokens_proj.weight": checkpoint["clip_to_seq.weight"],
-        "clip_extra_context_tokens_proj.bias": checkpoint["clip_to_seq.bias"],
-        # <original>.proj_n -> <diffusers>.embedding_proj
-        "embedding_proj.weight": checkpoint["proj_n.weight"],
-        "embedding_proj.bias": checkpoint["proj_n.bias"],
-        # <original>.ln_model_n -> <diffusers>.embedding_norm
-        "embedding_norm.weight": checkpoint["ln_model_n.weight"],
-        "embedding_norm.bias": checkpoint["ln_model_n.bias"],
-        # <original>.clip_emb -> <diffusers>.clip_image_embeddings_project_to_time_embeddings
-        "clip_image_embeddings_project_to_time_embeddings.weight": checkpoint["img_layer.weight"],
-        "clip_image_embeddings_project_to_time_embeddings.bias": checkpoint["img_layer.bias"],
-    }
-
-    return diffusers_checkpoint
-

 # unet utils

@@ -506,6 +513,38 @@ def unet_conv_in(checkpoint):
     return diffusers_checkpoint


+def unet_add_embedding(checkpoint):
+    diffusers_checkpoint = {}
+
+    diffusers_checkpoint.update(
+        {
+            "add_embedding.text_norm.weight": checkpoint["ln_model_n.weight"],
+            "add_embedding.text_norm.bias": checkpoint["ln_model_n.bias"],
+            "add_embedding.text_proj.weight": checkpoint["proj_n.weight"],
+            "add_embedding.text_proj.bias": checkpoint["proj_n.bias"],
+            "add_embedding.image_proj.weight": checkpoint["img_layer.weight"],
+            "add_embedding.image_proj.bias": checkpoint["img_layer.bias"],
+        }
+    )
+
+    return diffusers_checkpoint
+
+
+def unet_encoder_hid_proj(checkpoint):
+    diffusers_checkpoint = {}
+
+    diffusers_checkpoint.update(
+        {
+            "encoder_hid_proj.image_embeds.weight": checkpoint["clip_to_seq.weight"],
+            "encoder_hid_proj.image_embeds.bias": checkpoint["clip_to_seq.bias"],
+            "encoder_hid_proj.text_proj.weight": checkpoint["to_model_dim_n.weight"],
+            "encoder_hid_proj.text_proj.bias": checkpoint["to_model_dim_n.bias"],
+        }
+    )
+
+    return diffusers_checkpoint
+
+
 # <original>.out.0 -> <diffusers>.conv_norm_out
 def unet_conv_norm_out(checkpoint):
     diffusers_checkpoint = {}
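
These two helpers replace the deleted text_proj mapping: `ln_model_n`/`proj_n`/`img_layer` now land in the UNet's `add_embedding`, and `clip_to_seq`/`to_model_dim_n` in its `encoder_hid_proj`. A hypothetical sanity check (not part of the script), assuming `checkpoint` holds the loaded original state dict:

# Every remapped key should exist in a freshly built diffusers UNet.
unet = UNet2DConditionModel(**UNET_CONFIG)
expected = set(unet.state_dict().keys())
mapped = {**unet_add_embedding(checkpoint), **unet_encoder_hid_proj(checkpoint)}
missing = [k for k in mapped if k not in expected]
assert not missing, f"keys missing from the diffusers model: {missing}"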
@@ -857,25 +896,13 @@ def text2img(*, args, checkpoint_map_location):

     unet_diffusers_checkpoint = unet_original_checkpoint_to_diffusers_checkpoint(unet_model, text2img_checkpoint)

-    # text proj interlude
-
-    # The original decoder implementation includes a set of parameters that are used
-    # for creating the `encoder_hidden_states` which are what the U-net is conditioned
-    # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
-    # the parameters into the KandinskyTextProjModel class
-    text_proj_model = text_proj_from_original_config()
-
-    text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(text2img_checkpoint)
-
-    load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)
-
     del text2img_checkpoint

     load_checkpoint_to_model(unet_diffusers_checkpoint, unet_model, strict=True)

     print("done loading text2img")

-    return unet_model, text_proj_model
+    return unet_model


 def inpaint_text2img(*, args, checkpoint_map_location):
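
Since the projection now lives inside the UNet, the conditioning formerly produced by KandinskyTextProjModel is supplied at call time instead. A minimal sketch, reusing the `unet_model` returned by text2img above and assuming the standard UNet2DConditionModel calling convention for `addition_embed_type="text_image"`; the tensor shapes are hypothetical (77 text tokens, 1024-dim encoder states per `encoder_hid_dim`, 768-dim pooled/CLIP embeds per `cross_attention_dim`):

import torch

sample = torch.randn(1, 4, 64, 64)                # noisy MoVQ latents
encoder_hidden_states = torch.randn(1, 77, 1024)  # text encoder hidden states
added_cond_kwargs = {
    "text_embeds": torch.randn(1, 768),   # pooled text embedding
    "image_embeds": torch.randn(1, 768),  # CLIP image embedding from the prior
}
noise_pred = unet_model(
    sample,
    timestep=0,
    encoder_hidden_states=encoder_hidden_states,
    added_cond_kwargs=added_cond_kwargs,
).sample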
@@ -891,25 +918,13 @@ def inpaint_text2img(*, args, checkpoint_map_location):
         inpaint_unet_model, inpaint_text2img_checkpoint
     )

-    # text proj interlude
-
-    # The original decoder implementation includes a set of parameters that are used
-    # for creating the `encoder_hidden_states` which are what the U-net is conditioned
-    # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
-    # the parameters into the KandinskyTextProjModel class
-    text_proj_model = text_proj_from_original_config()
-
-    text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(inpaint_text2img_checkpoint)
-
-    load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)
-
     del inpaint_text2img_checkpoint

     load_checkpoint_to_model(inpaint_unet_diffusers_checkpoint, inpaint_unet_model, strict=True)

     print("done loading inpaint text2img")

-    return inpaint_unet_model, text_proj_model
+    return inpaint_unet_model


 # movq
@@ -1384,15 +1399,11 @@ def load_checkpoint_to_model(checkpoint, model, strict=False):
         prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
         prior_model.save_pretrained(args.dump_path)
     elif args.debug == "text2img":
-        unet_model, text_proj_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
+        unet_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
         unet_model.save_pretrained(f"{args.dump_path}/unet")
-        text_proj_model.save_pretrained(f"{args.dump_path}/text_proj")
     elif args.debug == "inpaint_text2img":
-        inpaint_unet_model, inpaint_text_proj_model = inpaint_text2img(
-            args=args, checkpoint_map_location=checkpoint_map_location
-        )
+        inpaint_unet_model = inpaint_text2img(args=args, checkpoint_map_location=checkpoint_map_location)
         inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet")
-        inpaint_text_proj_model.save_pretrained(f"{args.dump_path}/inpaint_text_proj")
     elif args.debug == "decoder":
         decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location)
         decoder.save_pretrained(f"{args.dump_path}/decoder")
