 from diffusers import UNet2DConditionModel
 from diffusers.models.prior_transformer import PriorTransformer
 from diffusers.models.vq_model import VQModel
-from diffusers.pipelines.kandinsky.text_proj import KandinskyTextProjModel
 
 
 """
@@ -225,37 +224,55 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix
 
 UNET_CONFIG = {
     "act_fn": "silu",
+    "addition_embed_type": "text_image",
+    "addition_embed_type_num_heads": 64,
     "attention_head_dim": 64,
-    "block_out_channels": (384, 768, 1152, 1536),
+    "block_out_channels": [384, 768, 1152, 1536],
     "center_input_sample": False,
-    "class_embed_type": "identity",
+    "class_embed_type": None,
+    "class_embeddings_concat": False,
+    "conv_in_kernel": 3,
+    "conv_out_kernel": 3,
     "cross_attention_dim": 768,
-    "down_block_types": (
+    "cross_attention_norm": None,
+    "down_block_types": [
         "ResnetDownsampleBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
-    ),
+    ],
     "downsample_padding": 1,
     "dual_cross_attention": False,
+    "encoder_hid_dim": 1024,
+    "encoder_hid_dim_type": "text_image_proj",
     "flip_sin_to_cos": True,
     "freq_shift": 0,
     "in_channels": 4,
     "layers_per_block": 3,
+    "mid_block_only_cross_attention": None,
     "mid_block_scale_factor": 1,
     "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
     "norm_eps": 1e-05,
     "norm_num_groups": 32,
+    "num_class_embeds": None,
     "only_cross_attention": False,
     "out_channels": 8,
+    "projection_class_embeddings_input_dim": None,
+    "resnet_out_scale_factor": 1.0,
+    "resnet_skip_time_act": False,
     "resnet_time_scale_shift": "scale_shift",
     "sample_size": 64,
-    "up_block_types": (
+    "time_cond_proj_dim": None,
+    "time_embedding_act_fn": None,
+    "time_embedding_dim": None,
+    "time_embedding_type": "positional",
+    "timestep_post_act": None,
+    "up_block_types": [
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "ResnetUpsampleBlock2D",
-    ),
+    ],
     "upcast_attention": False,
     "use_linear_projection": False,
 }
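Note: with `class_embed_type` switched from `"identity"` to `None` and the text/image conditioning moved into `addition_embed_type`/`encoder_hid_dim_type`, this config now targets a stock `UNet2DConditionModel` with no Kandinsky-specific wrapper. A minimal sketch of how the dict is consumed (mirroring the `*_model_from_original_config` helpers elsewhere in this script):

    from diffusers import UNet2DConditionModel

    # Every key in UNET_CONFIG above is a UNet2DConditionModel constructor
    # kwarg, so the config dict can be splatted straight into the model.
    unet = UNet2DConditionModel(**UNET_CONFIG)
    assert unet.config.addition_embed_type == "text_image"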
@@ -274,6 +291,8 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
 
     diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
     diffusers_checkpoint.update(unet_conv_in(checkpoint))
+    diffusers_checkpoint.update(unet_add_embedding(checkpoint))
+    diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))
 
     # <original>.input_blocks -> <diffusers>.down_blocks
 
@@ -336,37 +355,55 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
 
 INPAINT_UNET_CONFIG = {
     "act_fn": "silu",
+    "addition_embed_type": "text_image",
+    "addition_embed_type_num_heads": 64,
     "attention_head_dim": 64,
-    "block_out_channels": (384, 768, 1152, 1536),
+    "block_out_channels": [384, 768, 1152, 1536],
     "center_input_sample": False,
-    "class_embed_type": "identity",
+    "class_embed_type": None,
+    "class_embeddings_concat": None,
+    "conv_in_kernel": 3,
+    "conv_out_kernel": 3,
     "cross_attention_dim": 768,
-    "down_block_types": (
+    "cross_attention_norm": None,
+    "down_block_types": [
         "ResnetDownsampleBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
-    ),
+    ],
     "downsample_padding": 1,
     "dual_cross_attention": False,
+    "encoder_hid_dim": 1024,
+    "encoder_hid_dim_type": "text_image_proj",
     "flip_sin_to_cos": True,
     "freq_shift": 0,
     "in_channels": 9,
     "layers_per_block": 3,
+    "mid_block_only_cross_attention": None,
     "mid_block_scale_factor": 1,
     "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
     "norm_eps": 1e-05,
     "norm_num_groups": 32,
+    "num_class_embeds": None,
     "only_cross_attention": False,
     "out_channels": 8,
+    "projection_class_embeddings_input_dim": None,
+    "resnet_out_scale_factor": 1.0,
+    "resnet_skip_time_act": False,
     "resnet_time_scale_shift": "scale_shift",
     "sample_size": 64,
-    "up_block_types": (
+    "time_cond_proj_dim": None,
+    "time_embedding_act_fn": None,
+    "time_embedding_dim": None,
+    "time_embedding_type": "positional",
+    "timestep_post_act": None,
+    "up_block_types": [
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "ResnetUpsampleBlock2D",
-    ),
+    ],
     "upcast_attention": False,
     "use_linear_projection": False,
 }
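Note: the inpaint config differs from `UNET_CONFIG` essentially only in `in_channels` (9 rather than 4), because the inpaint UNet sees the noisy latents concatenated channel-wise with the masked-image latents and the mask. A hedged sketch of the arithmetic (the concatenation order is an assumption; the pipeline defines the actual layout):

    import torch

    latents = torch.randn(1, 4, 64, 64)        # noisy MoVQ latents
    image_latents = torch.randn(1, 4, 64, 64)  # latents of the masked source image
    mask = torch.ones(1, 1, 64, 64)            # binary inpainting mask
    unet_input = torch.cat([latents, image_latents, mask], dim=1)
    assert unet_input.shape[1] == INPAINT_UNET_CONFIG["in_channels"]  # 4 + 4 + 1 = 9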
@@ -381,10 +418,12 @@ def inpaint_unet_model_from_original_config():
 def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
     diffusers_checkpoint = {}
 
-    num_head_channels = UNET_CONFIG["attention_head_dim"]
+    num_head_channels = INPAINT_UNET_CONFIG["attention_head_dim"]
 
     diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
     diffusers_checkpoint.update(unet_conv_in(checkpoint))
+    diffusers_checkpoint.update(unet_add_embedding(checkpoint))
+    diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))
 
     # <original>.input_blocks -> <diffusers>.down_blocks
 
@@ -440,38 +479,6 @@ def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
 
     # done inpaint unet
 
-# text proj
-
-TEXT_PROJ_CONFIG = {}
-
-
-def text_proj_from_original_config():
-    model = KandinskyTextProjModel(**TEXT_PROJ_CONFIG)
-    return model
-
-
-# Note that the input checkpoint is the original text2img model checkpoint
-def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint):
-    diffusers_checkpoint = {
-        # <original>.text_seq_proj.0 -> <diffusers>.encoder_hidden_states_proj
-        "encoder_hidden_states_proj.weight": checkpoint["to_model_dim_n.weight"],
-        "encoder_hidden_states_proj.bias": checkpoint["to_model_dim_n.bias"],
-        # <original>.clip_tok_proj -> <diffusers>.clip_extra_context_tokens_proj
-        "clip_extra_context_tokens_proj.weight": checkpoint["clip_to_seq.weight"],
-        "clip_extra_context_tokens_proj.bias": checkpoint["clip_to_seq.bias"],
-        # <original>.proj_n -> <diffusers>.embedding_proj
-        "embedding_proj.weight": checkpoint["proj_n.weight"],
-        "embedding_proj.bias": checkpoint["proj_n.bias"],
-        # <original>.ln_model_n -> <diffusers>.embedding_norm
-        "embedding_norm.weight": checkpoint["ln_model_n.weight"],
-        "embedding_norm.bias": checkpoint["ln_model_n.bias"],
-        # <original>.clip_emb -> <diffusers>.clip_image_embeddings_project_to_time_embeddings
-        "clip_image_embeddings_project_to_time_embeddings.weight": checkpoint["img_layer.weight"],
-        "clip_image_embeddings_project_to_time_embeddings.bias": checkpoint["img_layer.bias"],
-    }
-
-    return diffusers_checkpoint
-
 
 # unet utils
 
@@ -506,6 +513,38 @@ def unet_conv_in(checkpoint):
     return diffusers_checkpoint
 
 
+def unet_add_embedding(checkpoint):
+    diffusers_checkpoint = {}
+
+    diffusers_checkpoint.update(
+        {
+            "add_embedding.text_norm.weight": checkpoint["ln_model_n.weight"],
+            "add_embedding.text_norm.bias": checkpoint["ln_model_n.bias"],
+            "add_embedding.text_proj.weight": checkpoint["proj_n.weight"],
+            "add_embedding.text_proj.bias": checkpoint["proj_n.bias"],
+            "add_embedding.image_proj.weight": checkpoint["img_layer.weight"],
+            "add_embedding.image_proj.bias": checkpoint["img_layer.bias"],
+        }
+    )
+
+    return diffusers_checkpoint
+
+
+def unet_encoder_hid_proj(checkpoint):
+    diffusers_checkpoint = {}
+
+    diffusers_checkpoint.update(
+        {
+            "encoder_hid_proj.image_embeds.weight": checkpoint["clip_to_seq.weight"],
+            "encoder_hid_proj.image_embeds.bias": checkpoint["clip_to_seq.bias"],
+            "encoder_hid_proj.text_proj.weight": checkpoint["to_model_dim_n.weight"],
+            "encoder_hid_proj.text_proj.bias": checkpoint["to_model_dim_n.bias"],
+        }
+    )
+
+    return diffusers_checkpoint
+
+
 # <original>.out.0 -> <diffusers>.conv_norm_out
 def unet_conv_norm_out(checkpoint):
     diffusers_checkpoint = {}
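Note: these two helpers absorb the mapping previously done by the deleted `text_proj_original_checkpoint_to_diffusers_checkpoint`: the same original tensors (`ln_model_n`, `proj_n`, `img_layer`, `clip_to_seq`, `to_model_dim_n`) now land on the UNet's built-in `add_embedding` and `encoder_hid_proj` modules. A hedged sanity check, assuming `checkpoint` holds the loaded original state dict:

    # All remapped keys must exist on the diffusers UNet, otherwise the
    # strict=True loads later in this script would raise.
    unet = UNet2DConditionModel(**UNET_CONFIG)
    remapped = {**unet_add_embedding(checkpoint), **unet_encoder_hid_proj(checkpoint)}
    assert set(remapped) <= set(unet.state_dict())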
@@ -857,25 +896,13 @@ def text2img(*, args, checkpoint_map_location):
 
     unet_diffusers_checkpoint = unet_original_checkpoint_to_diffusers_checkpoint(unet_model, text2img_checkpoint)
 
-    # text proj interlude
-
-    # The original decoder implementation includes a set of parameters that are used
-    # for creating the `encoder_hidden_states` which are what the U-net is conditioned
-    # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
-    # the parameters into the KandinskyTextProjModel class
-    text_proj_model = text_proj_from_original_config()
-
-    text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(text2img_checkpoint)
-
-    load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)
-
     del text2img_checkpoint
 
     load_checkpoint_to_model(unet_diffusers_checkpoint, unet_model, strict=True)
 
     print("done loading text2img")
 
-    return unet_model, text_proj_model
+    return unet_model
 
 
 def inpaint_text2img(*, args, checkpoint_map_location):
@@ -891,25 +918,13 @@ def inpaint_text2img(*, args, checkpoint_map_location):
         inpaint_unet_model, inpaint_text2img_checkpoint
     )
 
-    # text proj interlude
-
-    # The original decoder implementation includes a set of parameters that are used
-    # for creating the `encoder_hidden_states` which are what the U-net is conditioned
-    # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
-    # the parameters into the KandinskyTextProjModel class
-    text_proj_model = text_proj_from_original_config()
-
-    text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(inpaint_text2img_checkpoint)
-
-    load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)
-
     del inpaint_text2img_checkpoint
 
     load_checkpoint_to_model(inpaint_unet_diffusers_checkpoint, inpaint_unet_model, strict=True)
 
     print("done loading inpaint text2img")
 
-    return inpaint_unet_model, text_proj_model
+    return inpaint_unet_model
 
 
 # movq
@@ -1384,15 +1399,11 @@ def load_checkpoint_to_model(checkpoint, model, strict=False):
         prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
         prior_model.save_pretrained(args.dump_path)
     elif args.debug == "text2img":
-        unet_model, text_proj_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
+        unet_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
         unet_model.save_pretrained(f"{args.dump_path}/unet")
-        text_proj_model.save_pretrained(f"{args.dump_path}/text_proj")
     elif args.debug == "inpaint_text2img":
-        inpaint_unet_model, inpaint_text_proj_model = inpaint_text2img(
-            args=args, checkpoint_map_location=checkpoint_map_location
-        )
+        inpaint_unet_model = inpaint_text2img(args=args, checkpoint_map_location=checkpoint_map_location)
         inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet")
-        inpaint_text_proj_model.save_pretrained(f"{args.dump_path}/inpaint_text_proj")
     elif args.debug == "decoder":
         decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location)
         decoder.save_pretrained(f"{args.dump_path}/decoder")
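Note: each debug branch above can be exercised on its own; a hedged invocation sketch (the script filename and any checkpoint-path flags are assumptions — only `--debug` and `--dump_path` are visible in this diff):

    python convert_kandinsky_to_diffusers.py --debug text2img --dump_path ./kandinsky-2-1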