Commit fc28791

[BUG] Fix Autoencoderkl train script (#11113)
* add disc_optimizer step (not fix)
* support syncbatchnorm in discriminator
1 parent ae14612 commit fc28791
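
The commit touches two things on the discriminator side of the VAE training script: it converts the discriminator's BatchNorm layers to SyncBatchNorm so batch statistics are shared across processes in multi-GPU runs, and it adds the previously missing optimizer step for the discriminator (second hunk of the diff below). The following is a minimal sketch of the SyncBatchNorm conversion only, using a hypothetical Sequential stand-in for the script's NLayerDiscriminator; it is an illustration of the one-line change, not the script itself.

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Hypothetical stand-in for the script's PatchGAN-style NLayerDiscriminator.
discriminator = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
    torch.nn.BatchNorm2d(64),
    torch.nn.LeakyReLU(0.2),
    torch.nn.Conv2d(64, 1, kernel_size=4, stride=1, padding=1),
)

# Replace every BatchNorm layer with SyncBatchNorm so running statistics are
# reduced across processes under DistributedDataParallel. This mirrors the
# line the commit adds right after the discriminator is constructed.
discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)

# Wrap for (possibly distributed) training, as the script does via accelerate.
discriminator = accelerator.prepare(discriminator)

The conversion is done before accelerator.prepare wraps the model for distributed training, which presumably matches the script: the commit inserts it immediately after the discriminator is built, well before prepare is called. In a single-process run the converted layers should fall back to ordinary BatchNorm behavior.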

File tree: 1 file changed, 10 insertions(+), 2 deletions(-)


examples/research_projects/autoencoderkl/train_autoencoderkl.py

Lines changed: 10 additions & 2 deletions
@@ -627,6 +627,7 @@ def main(args):
     ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
     perceptual_loss = lpips.LPIPS(net="vgg").eval()
     discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
+    discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)

     # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
     def unwrap_model(model):
@@ -951,13 +952,20 @@ def load_model_hook(models, input_dir):
                     logits_fake = discriminator(reconstructions)
                     disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
                     disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
-                    disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
+                    d_loss = disc_factor * disc_loss(logits_real, logits_fake)
                     logs = {
-                        "disc_loss": disc_loss.detach().mean().item(),
+                        "disc_loss": d_loss.detach().mean().item(),
                         "logits_real": logits_real.detach().mean().item(),
                         "logits_fake": logits_fake.detach().mean().item(),
                         "disc_lr": disc_lr_scheduler.get_last_lr()[0],
                     }
+                    accelerator.backward(d_loss)
+                    if accelerator.sync_gradients:
+                        params_to_clip = discriminator.parameters()
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                    disc_optimizer.step()
+                    disc_lr_scheduler.step()
+                    disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
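
The second hunk adds the backward pass, gradient clipping, optimizer step, scheduler step, and gradient reset that the discriminator branch was missing: before this commit the discriminator loss was computed and logged, but disc_optimizer was never stepped. The rename from disc_loss to d_loss also stops the computed loss value from shadowing the loss function selected on the line above it. Below is a minimal sketch of that update step in isolation. It assumes the standard hinge discriminator loss (which is what hinge_d_loss conventionally computes) and uses placeholder arguments (discriminator, disc_optimizer, disc_lr_scheduler, disc_factor, max_grad_norm) for objects the real script builds elsewhere.

import torch
import torch.nn.functional as F
from accelerate import Accelerator

accelerator = Accelerator()


def discriminator_step(discriminator, disc_optimizer, disc_lr_scheduler,
                       logits_real, logits_fake, disc_factor, max_grad_norm):
    """Single discriminator update, mirroring the lines this commit adds."""
    # Standard hinge loss for the discriminator (assumed to match the script's
    # hinge_d_loss). Keeping the loss *value* in d_loss avoids shadowing the
    # loss *function*, which is the renaming part of the fix.
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    d_loss = disc_factor * 0.5 * (loss_real + loss_fake)

    # The part the training loop was missing: backward, clip, step, schedule,
    # and zero the gradients of the discriminator's own optimizer.
    accelerator.backward(d_loss)
    if accelerator.sync_gradients:
        accelerator.clip_grad_norm_(discriminator.parameters(), max_grad_norm)
    disc_optimizer.step()
    disc_lr_scheduler.step()
    disc_optimizer.zero_grad(set_to_none=True)
    return d_loss.detach()

The update goes through the discriminator's own optimizer and scheduler (disc_optimizer, disc_lr_scheduler), so stepping it does not disturb the optimizer state used for the VAE in the generator step of the alternating training loop.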
