@@ -627,6 +627,7 @@ def main(args):
     ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
     perceptual_loss = lpips.LPIPS(net="vgg").eval()
     discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
+    discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)

     # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
     def unwrap_model(model):
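The added line converts every BatchNorm layer in the discriminator to SyncBatchNorm, so that under multi-process training the batch statistics are all-reduced across ranks instead of being computed per process. A minimal sketch of what the conversion does, using a small stand-in module rather than the actual NLayerDiscriminator:

import torch
import torch.nn as nn

# Stand-in for the patch discriminator: a small conv stack with a BatchNorm layer.
toy_disc = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2),
)

# convert_sync_batchnorm recursively swaps every BatchNorm*d module for
# SyncBatchNorm; synchronization only kicks in once a distributed process
# group exists (Accelerate initializes one when the script is launched with
# more than one process).
toy_disc = torch.nn.SyncBatchNorm.convert_sync_batchnorm(toy_disc)
print(toy_disc)  # the BatchNorm2d layer is now SyncBatchNorm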
@@ -951,13 +952,20 @@ def load_model_hook(models, input_dir):
                 logits_fake = discriminator(reconstructions)
                 disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
                 disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
-                disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
+                d_loss = disc_factor * disc_loss(logits_real, logits_fake)
                 logs = {
-                    "disc_loss": disc_loss.detach().mean().item(),
+                    "disc_loss": d_loss.detach().mean().item(),
                     "logits_real": logits_real.detach().mean().item(),
                     "logits_fake": logits_fake.detach().mean().item(),
                     "disc_lr": disc_lr_scheduler.get_last_lr()[0],
                 }
+                accelerator.backward(d_loss)
+                if accelerator.sync_gradients:
+                    params_to_clip = discriminator.parameters()
+                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                disc_optimizer.step()
+                disc_lr_scheduler.step()
+                disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
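Renaming the result to d_loss avoids clobbering disc_loss, which at that point still holds the loss function selected via args.disc_loss; disc_factor keeps the loss at zero until global_step reaches args.disc_start, and the new block backpropagates the result, clips gradients, and steps the discriminator's optimizer and LR scheduler following the usual Accelerate pattern. The hinge_d_loss and vanilla_d_loss callables are defined elsewhere in the script; the sketch below shows the standard taming-transformers-style definitions (an assumption, not necessarily this script's exact code), with hypothetical patch-discriminator logits as input:

import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    # Push real logits above +1 and fake logits below -1.
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)

def vanilla_d_loss(logits_real, logits_fake):
    # Non-saturating BCE-style loss written with softplus.
    return 0.5 * (
        torch.mean(F.softplus(-logits_real)) + torch.mean(F.softplus(logits_fake))
    )

# Hypothetical per-patch logits, shaped like an NLayerDiscriminator output.
logits_real = torch.randn(4, 1, 30, 30)
logits_fake = torch.randn(4, 1, 30, 30)
print(hinge_d_loss(logits_real, logits_fake), vanilla_d_loss(logits_real, logits_fake))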