diff --git a/examples/research_projects/autoencoderkl/train_autoencoderkl.py b/examples/research_projects/autoencoderkl/train_autoencoderkl.py index cf13ecdbf8ac..31cf8414ac10 100644 --- a/examples/research_projects/autoencoderkl/train_autoencoderkl.py +++ b/examples/research_projects/autoencoderkl/train_autoencoderkl.py @@ -627,6 +627,7 @@ def main(args): ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config) perceptual_loss = lpips.LPIPS(net="vgg").eval() discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init) + discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator) # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files) def unwrap_model(model): @@ -951,13 +952,20 @@ def load_model_hook(models, input_dir): logits_fake = discriminator(reconstructions) disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0 - disc_loss = disc_factor * disc_loss(logits_real, logits_fake) + d_loss = disc_factor * disc_loss(logits_real, logits_fake) logs = { - "disc_loss": disc_loss.detach().mean().item(), + "disc_loss": d_loss.detach().mean().item(), "logits_real": logits_real.detach().mean().item(), "logits_fake": logits_fake.detach().mean().item(), "disc_lr": disc_lr_scheduler.get_last_lr()[0], } + accelerator.backward(d_loss) + if accelerator.sync_gradients: + params_to_clip = discriminator.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + disc_optimizer.step() + disc_lr_scheduler.step() + disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1)