Description
Describe the bug
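The latest update to the train_dreambooth_lora_sdxl.py script seems to make it effectively unusable on a V100 or any lower-end GPU. With --mixed_precision="fp16" (command below) it crashes before completing the first training step with "ValueError: Attempting to unscale FP16 gradients." (see Logs). In full precision it runs out of memory: I tried every compromise I could think of, including lowering the resolution, and still could not get it to run.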
Reproduction
#!/usr/bin/env bash
accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --instance_data_dir="/content/in" \
  --output_dir="/content/out" \
  --variant="fp16" \
  --instance_prompt="Photo of a TOK person" \
  --resolution=1024 \
  --train_batch_size=1 \
  --adam_weight_decay=0.01 \
  --mixed_precision="fp16" \
  --gradient_accumulation_steps=3 \
  --gradient_checkpointing \
  --learning_rate=1e-4 \
  --use_8bit_adam \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --rank=16 \
  --max_train_steps=625 \
  --checkpointing_steps=250 \
  --seed="0"
Logs
Steps: 0% 0/625 [00:02<?, ?it/s, loss=0.0276, lr=0.0001]
Traceback (most recent call last):
  File "/content/train_dreambooth_lora_sdxl.py", line 1716, in <module>
    main(args)
  File "/content/train_dreambooth_lora_sdxl.py", line 1494, in main
    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 2040, in clip_grad_norm_
    self.unscale_gradients()
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 2003, in unscale_gradients
    self.scaler.unscale_(opt)
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/amp/grad_scaler.py", line 307, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/amp/grad_scaler.py", line 229, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")
ValueError: Attempting to unscale FP16 gradients.
Steps: 0% 0/625 [00:03<?, ?it/s, loss=0.0276, lr=0.0001]
Traceback (most recent call last):
  File "/usr/local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 1017, in launch_command
    simple_launcher(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 637, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
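For context: the ValueError is raised by torch's GradScaler, which refuses to unscale a gradient whose dtype is float16. That happens when the parameter being trained is itself stored in float16, which suggests (not confirmed) that the trainable LoRA parameters end up in float16 when --mixed_precision="fp16" is used. The sketch below is a minimal diagnostic under that assumption, in plain PyTorch. The helper names fp16_trainable_params and upcast_trainable_params are hypothetical and not part of diffusers or accelerate, and the two-layer model is only a stand-in for a frozen UNet plus LoRA adapter. Upcasting only the trainable parameters to float32 is a common workaround for this error, not a verified fix for this script.

```python
import torch
import torch.nn as nn


def fp16_trainable_params(model: nn.Module) -> list[str]:
    """Return the names of trainable parameters stored in float16.

    GradScaler.unscale_ raises "Attempting to unscale FP16 gradients."
    for any such parameter once it has a float16 gradient.
    """
    return [
        name
        for name, p in model.named_parameters()
        if p.requires_grad and p.dtype == torch.float16
    ]


def upcast_trainable_params(model: nn.Module) -> None:
    """Cast only the trainable parameters to float32, in place.

    Frozen weights stay in float16 to save memory; under torch.autocast,
    eligible ops such as linear layers still run in float16.
    """
    for p in model.parameters():
        if p.requires_grad:
            p.data = p.data.float()


# Hypothetical stand-in: a frozen fp16 "base" layer and a trainable
# "adapter" layer that was also cast to fp16.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).half()
model[0].requires_grad_(False)

print(fp16_trainable_params(model))  # ['1.weight', '1.bias']
upcast_trainable_params(model)
print(fp16_trainable_params(model))  # []
```

If the first check reports any LoRA parameters right before the training loop, that would be consistent with the traceback above.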
System Info
Google Colab, NVIDIA V100 GPU