     DDPMScheduler,
     DiffusionPipeline,
     DPMSolverMultistepScheduler,
+    StableDiffusionPipeline,
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
@@ -111,18 +112,15 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None)
         f.write(yaml + model_card)


-def log_validation(
-    text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch
-):
+def log_validation(text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    # create pipeline (note: unet and vae are loaded again in float32)
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
         text_encoder=accelerator.unwrap_model(text_encoder_1),
-        text_encoder_2=accelerator.unwrap_model(text_encoder_2),
+        text_encoder_2=text_encoder_2,
         tokenizer=tokenizer_1,
         tokenizer_2=tokenizer_2,
         unet=unet,
@@ -361,7 +359,7 @@ def parse_args():
     parser.add_argument(
         "--validation_prompt",
         type=str,
-        default=None,
+        default="A <cat-toy> backpack",
         help="A prompt that is used during validation to verify that the model is learning.",
     )
     parser.add_argument(
@@ -380,16 +378,6 @@ def parse_args():
             " and logging the images."
         ),
     )
-    parser.add_argument(
-        "--validation_epochs",
-        type=int,
-        default=None,
-        help=(
-            "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
-            " `args.validation_prompt` multiple times: `args.num_validation_images`"
-            " and logging the images."
-        ),
-    )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument(
         "--checkpointing_steps",
@@ -418,11 +406,6 @@ def parse_args():
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
-    parser.add_argument(
-        "--no_safe_serialization",
-        action="store_true",
-        help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.",
-    )

     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -529,6 +512,7 @@ def __init__(

         self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
         self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
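+        # Deterministic center crop or random crop; the chosen offset is recorded per sample in __getitem__.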
+        self.crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)

     def __len__(self):
         return self._length
@@ -543,6 +527,18 @@ def __getitem__(self, i):
             placeholder_string = self.placeholder_token
         text = random.choice(self.templates).format(placeholder_string)

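+        # Record the original image size and the crop's top-left corner so they can feed SDXL's size/crop conditioning.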
+        example["original_size"] = (image.height, image.width)
+
+        if self.center_crop:
+            y1 = max(0, int(round((image.height - self.size) / 2.0)))
+            x1 = max(0, int(round((image.width - self.size) / 2.0)))
+            image = self.crop(image)
+        else:
+            y1, x1, h, w = self.crop.get_params(image, (self.size, self.size))
+            image = transforms.functional.crop(image, y1, x1, h, w)
+
+        example["crop_top_left"] = (y1, x1)
+
         example["input_ids_1"] = self.tokenizer_1(
             text,
             padding="max_length",
@@ -564,13 +560,7 @@ def __getitem__(self, i):

         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            (
-                h,
-                w,
-            ) = (
-                img.shape[0],
-                img.shape[1],
-            )
+            (h, w,) = (img.shape[0], img.shape[1],)
             img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

         image = Image.fromarray(img)
@@ -646,6 +636,7 @@ def main():
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

+
     # Add the placeholder token in tokenizer_1
     placeholder_tokens = [args.placeholder_token]

@@ -686,21 +677,14 @@ def main():
     # Freeze vae and unet
     vae.requires_grad_(False)
     unet.requires_grad_(False)
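+    # The second text encoder is kept fully frozen; only the first text encoder's token embeddings are trained.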
+    text_encoder_2.requires_grad_(False)
     # Freeze all parameters except for the token embeddings in text encoder
     text_encoder_1.text_model.encoder.requires_grad_(False)
     text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
     text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
-    text_encoder_2.text_model.encoder.requires_grad_(False)
-    text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)

     if args.gradient_checkpointing:
-        # Keep unet in train mode if we are using gradient checkpointing to save memory.
-        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
-        unet.train()
         text_encoder_1.gradient_checkpointing_enable()
-        text_encoder_2.gradient_checkpointing_enable()
-        unet.enable_gradient_checkpointing()

     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
@@ -749,15 +733,6 @@ def main():
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
     )
-    if args.validation_epochs is not None:
-        warnings.warn(
-            f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
-            " Deprecated validation_epochs in favor of `validation_steps`"
-            f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
-            FutureWarning,
-            stacklevel=2,
-        )
-        args.validation_steps = args.validation_epochs * len(train_dataset)

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -791,7 +766,7 @@ def main():
     # Move vae and unet and text_encoder_2 to device and cast to weight_dtype
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
-    text_encoder_2 = text_encoder_2.to(accelerator.device, dtype=weight_dtype)
+    text_encoder_2.to(accelerator.device, dtype=weight_dtype)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -876,27 +851,18 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                 # Get the text embedding for conditioning
-                encoder_hidden_states_1 = (
-                    text_encoder_1(batch["input_ids_1"], output_hidden_states=True)
-                    .hidden_states[-2]
-                    .to(dtype=weight_dtype)
-                )
-                encoder_output_2 = text_encoder_2(
-                    batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True
-                )
+                encoder_hidden_states_1 = text_encoder_1(batch["input_ids_1"], output_hidden_states=True).hidden_states[-2].to(dtype=weight_dtype)
+                encoder_output_2 = text_encoder_2(batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True)
                 encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
-                sample_size = unet.config.sample_size * (2 ** (len(vae.config.block_out_channels) - 1))
-                original_size = (sample_size, sample_size)
-                add_time_ids = torch.tensor(
-                    [list(original_size + (0, 0) + original_size)], dtype=weight_dtype, device=accelerator.device
-                )
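+                # SDXL micro-conditioning: pack per-sample (original_size, crop_top_left, target_size) into add_time_ids.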
+                original_size = [(batch["original_size"][0][i].item(), batch["original_size"][1][i].item()) for i in range(args.train_batch_size)]
+                crop_top_left = [(batch["crop_top_left"][0][i].item(), batch["crop_top_left"][1][i].item()) for i in range(args.train_batch_size)]
+                target_size = (args.resolution, args.resolution)
+                add_time_ids = torch.cat([torch.tensor(original_size[i] + crop_top_left[i] + target_size) for i in range(args.train_batch_size)]).to(accelerator.device, dtype=weight_dtype)
                 added_cond_kwargs = {"text_embeds": encoder_output_2[0], "time_ids": add_time_ids}
                 encoder_hidden_states = torch.cat([encoder_hidden_states_1, encoder_hidden_states_2], dim=-1)

                 # Predict the noise residual
-                model_pred = unet(
-                    noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
-                ).sample
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs).sample

                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
@@ -929,19 +895,15 @@ def main():
                 progress_bar.update(1)
                 global_step += 1
                 if global_step % args.save_steps == 0:
-                    weight_name = (
-                        f"learned_embeds-steps-{global_step}.bin"
-                        if args.no_safe_serialization
-                        else f"learned_embeds-steps-{global_step}.safetensors"
-                    )
+                    weight_name = (f"learned_embeds-steps-{global_step}.safetensors")
                     save_path = os.path.join(args.output_dir, weight_name)
                     save_progress(
                         text_encoder_1,
                         placeholder_token_ids,
                         accelerator,
                         args,
                         save_path,
-                        safe_serialization=not args.no_safe_serialization,
+                        safe_serialization=True,
                     )

                 if accelerator.is_main_process:
@@ -972,16 +934,7 @@ def main():

                     if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                         images = log_validation(
-                            text_encoder_1,
-                            text_encoder_2,
-                            tokenizer_1,
-                            tokenizer_2,
-                            unet,
-                            vae,
-                            args,
-                            accelerator,
-                            weight_dtype,
-                            epoch,
+                            text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch
                         )

             logs = {"loss": loss.detach().item(), "lr": lr_scheduler_1.get_last_lr()[0]}
@@ -993,6 +946,10 @@ def main():
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
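+        # Run one final validation pass with the learned embeddings before saving or pushing the model.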
+        images = log_validation(
+            text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, unet, vae, args, accelerator, weight_dtype, epoch
+        )
+
         if args.push_to_hub and not args.save_as_full_pipeline:
             logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
             save_full_model = True
@@ -1002,23 +959,23 @@ def main():
             pipeline = DiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 text_encoder=accelerator.unwrap_model(text_encoder_1),
-                text_encoder_2=accelerator.unwrap_model(text_encoder_2),
+                text_encoder_2=text_encoder_2,
                 vae=vae,
                 unet=unet,
                 tokenizer=tokenizer_1,
                 tokenizer_2=tokenizer_2,
             )
             pipeline.save_pretrained(args.output_dir)
         # Save the newly trained embeddings
-        weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors"
+        weight_name = "learned_embeds.safetensors"
         save_path = os.path.join(args.output_dir, weight_name)
         save_progress(
             text_encoder_1,
             placeholder_token_ids,
             accelerator,
             args,
             save_path,
-            safe_serialization=not args.no_safe_serialization,
+            safe_serialization=True,
         )

         if args.push_to_hub:
@@ -1035,6 +992,9 @@ def main():
             ignore_patterns=["step_*", "epoch_*"],
         )

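+        # Save the final validation images to disk for a quick visual check.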
+        for i in range(len(images)):
+            images[i].save(f"cat-backpack_sdxl_test_{i}.png")
+

     accelerator.end_training()
