Skip to content

[BUG] Fix Autoencoderkl train script #11113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 19, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions examples/research_projects/autoencoderkl/train_autoencoderkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,7 @@ def main(args):
ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
perceptual_loss = lpips.LPIPS(net="vgg").eval()
discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to version-guard torch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. As far as I know, this method first appeared in PyTorch 1.1.0 and is still valid in 2.6.0.

https://pytorch.org/docs/1.1.0/nn.html?highlight=syncbatchnorm#torch.nn.SyncBatchNorm


# Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
def unwrap_model(model):
Expand Down Expand Up @@ -951,13 +952,20 @@ def load_model_hook(models, input_dir):
logits_fake = discriminator(reconstructions)
disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
d_loss = disc_factor * disc_loss(logits_real, logits_fake)
logs = {
"disc_loss": disc_loss.detach().mean().item(),
"disc_loss": d_loss.detach().mean().item(),
"logits_real": logits_real.detach().mean().item(),
"logits_fake": logits_fake.detach().mean().item(),
"disc_lr": disc_lr_scheduler.get_last_lr()[0],
}
accelerator.backward(d_loss)
if accelerator.sync_gradients:
params_to_clip = discriminator.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
disc_optimizer.step()
disc_lr_scheduler.step()
disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
Expand Down
Loading