Skip to content

update conversion script for Kandinsky unet #3766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 89 additions & 78 deletions scripts/convert_kandinsky_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from diffusers import UNet2DConditionModel
from diffusers.models.prior_transformer import PriorTransformer
from diffusers.models.vq_model import VQModel
from diffusers.pipelines.kandinsky.text_proj import KandinskyTextProjModel


"""
Expand Down Expand Up @@ -225,37 +224,55 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix

UNET_CONFIG = {
"act_fn": "silu",
"addition_embed_type": "text_image",
"addition_embed_type_num_heads": 64,
"attention_head_dim": 64,
"block_out_channels": (384, 768, 1152, 1536),
"block_out_channels": [384, 768, 1152, 1536],
"center_input_sample": False,
"class_embed_type": "identity",
"class_embed_type": None,
"class_embeddings_concat": False,
"conv_in_kernel": 3,
"conv_out_kernel": 3,
"cross_attention_dim": 768,
"down_block_types": (
"cross_attention_norm": None,
"down_block_types": [
"ResnetDownsampleBlock2D",
"SimpleCrossAttnDownBlock2D",
"SimpleCrossAttnDownBlock2D",
"SimpleCrossAttnDownBlock2D",
),
],
"downsample_padding": 1,
"dual_cross_attention": False,
"encoder_hid_dim": 1024,
"encoder_hid_dim_type": "text_image_proj",
"flip_sin_to_cos": True,
"freq_shift": 0,
"in_channels": 4,
"layers_per_block": 3,
"mid_block_only_cross_attention": None,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_class_embeds": None,
"only_cross_attention": False,
"out_channels": 8,
"projection_class_embeddings_input_dim": None,
"resnet_out_scale_factor": 1.0,
"resnet_skip_time_act": False,
"resnet_time_scale_shift": "scale_shift",
"sample_size": 64,
"up_block_types": (
"time_cond_proj_dim": None,
"time_embedding_act_fn": None,
"time_embedding_dim": None,
"time_embedding_type": "positional",
"timestep_post_act": None,
"up_block_types": [
"SimpleCrossAttnUpBlock2D",
"SimpleCrossAttnUpBlock2D",
"SimpleCrossAttnUpBlock2D",
"ResnetUpsampleBlock2D",
),
],
"upcast_attention": False,
"use_linear_projection": False,
}
Expand All @@ -274,6 +291,8 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):

diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
diffusers_checkpoint.update(unet_conv_in(checkpoint))
diffusers_checkpoint.update(unet_add_embedding(checkpoint))
diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))

# <original>.input_blocks -> <diffusers>.down_blocks

Expand Down Expand Up @@ -336,37 +355,55 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):

INPAINT_UNET_CONFIG = {
"act_fn": "silu",
"addition_embed_type": "text_image",
"addition_embed_type_num_heads": 64,
"attention_head_dim": 64,
"block_out_channels": (384, 768, 1152, 1536),
"block_out_channels": [384, 768, 1152, 1536],
"center_input_sample": False,
"class_embed_type": "identity",
"class_embed_type": None,
"class_embeddings_concat": None,
"conv_in_kernel": 3,
"conv_out_kernel": 3,
"cross_attention_dim": 768,
"down_block_types": (
"cross_attention_norm": None,
"down_block_types": [
"ResnetDownsampleBlock2D",
"SimpleCrossAttnDownBlock2D",
"SimpleCrossAttnDownBlock2D",
"SimpleCrossAttnDownBlock2D",
),
],
"downsample_padding": 1,
"dual_cross_attention": False,
"encoder_hid_dim": 1024,
"encoder_hid_dim_type": "text_image_proj",
"flip_sin_to_cos": True,
"freq_shift": 0,
"in_channels": 9,
"layers_per_block": 3,
"mid_block_only_cross_attention": None,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_class_embeds": None,
"only_cross_attention": False,
"out_channels": 8,
"projection_class_embeddings_input_dim": None,
"resnet_out_scale_factor": 1.0,
"resnet_skip_time_act": False,
"resnet_time_scale_shift": "scale_shift",
"sample_size": 64,
"up_block_types": (
"time_cond_proj_dim": None,
"time_embedding_act_fn": None,
"time_embedding_dim": None,
"time_embedding_type": "positional",
"timestep_post_act": None,
"up_block_types": [
"SimpleCrossAttnUpBlock2D",
"SimpleCrossAttnUpBlock2D",
"SimpleCrossAttnUpBlock2D",
"ResnetUpsampleBlock2D",
),
],
"upcast_attention": False,
"use_linear_projection": False,
}
Expand All @@ -381,10 +418,12 @@ def inpaint_unet_model_from_original_config():
def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}

num_head_channels = UNET_CONFIG["attention_head_dim"]
num_head_channels = INPAINT_UNET_CONFIG["attention_head_dim"]

diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
diffusers_checkpoint.update(unet_conv_in(checkpoint))
diffusers_checkpoint.update(unet_add_embedding(checkpoint))
diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))

# <original>.input_blocks -> <diffusers>.down_blocks

Expand Down Expand Up @@ -440,38 +479,6 @@ def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):

# done inpaint unet

# text proj

TEXT_PROJ_CONFIG = {}


def text_proj_from_original_config():
    """Instantiate a ``KandinskyTextProjModel`` from the module-level default config."""
    return KandinskyTextProjModel(**TEXT_PROJ_CONFIG)


# Note that the input checkpoint is the original text2img model checkpoint
def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint):
    """Rename the original text-projection parameters to their diffusers names.

    Returns a new dict containing only the renamed weight/bias tensors; the
    input checkpoint is not modified.
    """
    # <original> module name -> <diffusers> KandinskyTextProjModel module name
    renames = {
        "to_model_dim_n": "encoder_hidden_states_proj",
        "clip_to_seq": "clip_extra_context_tokens_proj",
        "proj_n": "embedding_proj",
        "ln_model_n": "embedding_norm",
        "img_layer": "clip_image_embeddings_project_to_time_embeddings",
    }

    diffusers_checkpoint = {}
    for original_name, diffusers_name in renames.items():
        # Each original module contributes a weight and a bias tensor.
        for param in ("weight", "bias"):
            diffusers_checkpoint[f"{diffusers_name}.{param}"] = checkpoint[f"{original_name}.{param}"]

    return diffusers_checkpoint


# unet utils

Expand Down Expand Up @@ -506,6 +513,38 @@ def unet_conv_in(checkpoint):
return diffusers_checkpoint


def unet_add_embedding(checkpoint):
    """Map the original additional-embedding weights onto diffusers ``add_embedding`` names.

    Returns a new dict with the renamed weight/bias tensors; the input
    checkpoint is not modified.
    """
    # <original> module name -> <diffusers> add_embedding submodule name
    renames = {
        "ln_model_n": "add_embedding.text_norm",
        "proj_n": "add_embedding.text_proj",
        "img_layer": "add_embedding.image_proj",
    }

    diffusers_checkpoint = {}
    for original_name, diffusers_name in renames.items():
        for param in ("weight", "bias"):
            diffusers_checkpoint[f"{diffusers_name}.{param}"] = checkpoint[f"{original_name}.{param}"]

    return diffusers_checkpoint


def unet_encoder_hid_proj(checkpoint):
    """Map the original encoder hidden-state projection weights onto diffusers ``encoder_hid_proj`` names.

    Returns a new dict with the renamed weight/bias tensors; the input
    checkpoint is not modified.
    """
    # <original> module name -> <diffusers> encoder_hid_proj submodule name
    renames = {
        "clip_to_seq": "encoder_hid_proj.image_embeds",
        "to_model_dim_n": "encoder_hid_proj.text_proj",
    }

    diffusers_checkpoint = {}
    for original_name, diffusers_name in renames.items():
        for param in ("weight", "bias"):
            diffusers_checkpoint[f"{diffusers_name}.{param}"] = checkpoint[f"{original_name}.{param}"]

    return diffusers_checkpoint


# <original>.out.0 -> <diffusers>.conv_norm_out
def unet_conv_norm_out(checkpoint):
diffusers_checkpoint = {}
Expand Down Expand Up @@ -857,25 +896,13 @@ def text2img(*, args, checkpoint_map_location):

unet_diffusers_checkpoint = unet_original_checkpoint_to_diffusers_checkpoint(unet_model, text2img_checkpoint)

# text proj interlude

# The original decoder implementation includes a set of parameters that are used
# for creating the `encoder_hidden_states` which are what the U-net is conditioned
# on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
# the parameters into the KandinskyTextProjModel class
text_proj_model = text_proj_from_original_config()

text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(text2img_checkpoint)

load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)

del text2img_checkpoint

load_checkpoint_to_model(unet_diffusers_checkpoint, unet_model, strict=True)

print("done loading text2img")

return unet_model, text_proj_model
return unet_model


def inpaint_text2img(*, args, checkpoint_map_location):
Expand All @@ -891,25 +918,13 @@ def inpaint_text2img(*, args, checkpoint_map_location):
inpaint_unet_model, inpaint_text2img_checkpoint
)

# text proj interlude

# The original decoder implementation includes a set of parameters that are used
# for creating the `encoder_hidden_states` which are what the U-net is conditioned
# on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
# the parameters into the KandinskyTextProjModel class
text_proj_model = text_proj_from_original_config()

text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(inpaint_text2img_checkpoint)

load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)

del inpaint_text2img_checkpoint

load_checkpoint_to_model(inpaint_unet_diffusers_checkpoint, inpaint_unet_model, strict=True)

print("done loading inpaint text2img")

return inpaint_unet_model, text_proj_model
return inpaint_unet_model


# movq
Expand Down Expand Up @@ -1384,15 +1399,11 @@ def load_checkpoint_to_model(checkpoint, model, strict=False):
prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
prior_model.save_pretrained(args.dump_path)
elif args.debug == "text2img":
unet_model, text_proj_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
unet_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
unet_model.save_pretrained(f"{args.dump_path}/unet")
text_proj_model.save_pretrained(f"{args.dump_path}/text_proj")
elif args.debug == "inpaint_text2img":
inpaint_unet_model, inpaint_text_proj_model = inpaint_text2img(
args=args, checkpoint_map_location=checkpoint_map_location
)
inpaint_unet_model = inpaint_text2img(args=args, checkpoint_map_location=checkpoint_map_location)
inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet")
inpaint_text_proj_model.save_pretrained(f"{args.dump_path}/inpaint_text_proj")
elif args.debug == "decoder":
decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location)
decoder.save_pretrained(f"{args.dump_path}/decoder")
Expand Down