From 298709e13b2137cea7f3747157a3fd887e55d34e Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Tue, 18 Mar 2025 21:05:43 +0200
Subject: [PATCH 01/19] remove custom scheduler

---
 .../train_dreambooth_lora_flux_advanced.py    | 107 +-----------------
 1 file changed, 3 insertions(+), 104 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index b8194507d822..0a1d0a919fd0 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -1265,109 +1265,6 @@ def encode_prompt(
     return prompt_embeds, pooled_prompt_embeds, text_ids
 
-
-# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer:
-# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
-class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        with torch.no_grad():
-            # create weights for timesteps
-            num_timesteps = 1000
-
-            # generate the multiplier based on cosmap loss weighing
-            # this is only used on linear timesteps for now
-
-            # cosine map weighing is higher in the middle and lower at the ends
-            # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
-            # cosmap_weighing = 2 / (math.pi * bot)
-
-            # sigma sqrt weighing is significantly higher at the end and lower at the beginning
-            sigma_sqrt_weighing = (self.sigmas**-2.0).float()
-            # clip at 1e4 (1e6 is too high)
-            sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
-            # bring to a mean of 1
-            sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
-
-            # Create linear timesteps from 1000 to 0
-            timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu")
-
-            self.linear_timesteps = timesteps
-            # self.linear_timesteps_weights = cosmap_weighing
-            self.linear_timesteps_weights = sigma_sqrt_weighing
-
-            # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
-            pass
-
-    def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
-        # Get the indices of the timesteps
-        step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
-
-        # Get the weights for the timesteps
-        weights = self.linear_timesteps_weights[step_indices].flatten()
-
-        return weights
-
-    def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
-        sigmas = self.sigmas.to(device=device, dtype=dtype)
-        schedule_timesteps = self.timesteps.to(device)
-        timesteps = timesteps.to(device)
-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
-
-        sigma = sigmas[step_indices].flatten()
-        while len(sigma.shape) < n_dim:
-            sigma = sigma.unsqueeze(-1)
-
-        return sigma
-
-    def add_noise(
-        self,
-        original_samples: torch.Tensor,
-        noise: torch.Tensor,
-        timesteps: torch.Tensor,
-    ) -> torch.Tensor:
-        ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
-        ## Add noise according to flow matching.
-        ## zt = (1 - texp) * x + texp * z1
-
-        # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-        # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
-
-        # timestep needs to be in [0, 1], we store them in [0, 1000]
-        # noisy_sample = (1 - timestep) * latent + timestep * noise
-        t_01 = (timesteps / 1000).to(original_samples.device)
-        noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
-
-        # n_dim = original_samples.ndim
-        # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
-        # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
-        return noisy_model_input
-
-    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
-        return sample
-
-    def set_train_timesteps(self, num_timesteps, device, linear=False):
-        if linear:
-            timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
-            self.timesteps = timesteps
-            return timesteps
-        else:
-            # distribute them closer to center. Inference distributes them as a bias toward first
-            # Generate values from 0 to 1
-            t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
-
-            # Scale and reverse the values to go from 1000 to 0
-            timesteps = (1 - t) * 1000
-
-            # Sort the timesteps in descending order
-            timesteps, _ = torch.sort(timesteps, descending=True)
-
-            self.timesteps = timesteps.to(device=device)
-
-            return timesteps
-
-
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
@@ -1499,7 +1396,7 @@ def main(args):
     )
 
     # Load scheduler and models
-    noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
+    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="scheduler"
     )
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
@@ -2337,6 +2234,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                         shutil.rmtree(removing_checkpoint)
 
+                    # save embeddings
+
                     save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                     accelerator.save_state(save_path)
                     logger.info(f"Saved state to {save_path}")

From 364f478d3f54dc51054c95163dfa3081823aa9ec Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Tue, 18 Mar 2025 21:24:04 +0200
Subject: [PATCH 02/19] update requirements.txt

---
 examples/advanced_diffusion_training/requirements.txt | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/advanced_diffusion_training/requirements.txt b/examples/advanced_diffusion_training/requirements.txt
index 3f86855e1d1e..dbc124ff6526 100644
--- a/examples/advanced_diffusion_training/requirements.txt
+++ b/examples/advanced_diffusion_training/requirements.txt
@@ -1,7 +1,8 @@
-accelerate>=0.16.0
+accelerate>=0.31.0
 torchvision
-transformers>=4.25.1
+transformers>=4.41.2
 ftfy
 tensorboard
 Jinja2
-peft==0.7.0
\ No newline at end of file
+peft>=0.11.1
+sentencepiece
\ No newline at end of file

From 90e9517fb8e25f65a67330d09f516eaff3c1f1a4 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Tue, 18 Mar 2025 23:03:08 +0200
Subject: [PATCH 03/19] log_validation with mixed precision

---
 .../train_dreambooth_lora_flux_advanced.py    | 17 +++++++++++----
 .../dreambooth/train_dreambooth_lora_flux.py  | 21 ++++++++++++-------
 2 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 0a1d0a919fd0..6b87c9fd197d 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -228,10 +228,19 @@ def log_validation(
 
     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
-    autocast_ctx = nullcontext()
-
-    with autocast_ctx:
-        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+    autocast_ctx = torch.autocast(accelerator.device.type)
+
+    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+        pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
+    )
+    images = []
+    for _ in range(args.num_validation_images):
+        with autocast_ctx:
+            image = pipeline(
+                prompt_embeds=prompt_embeds,
+                pooled_prompt_embeds=pooled_prompt_embeds,
+                generator=generator).images[0]
+            images.append(image)
 
     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index dda3300d65cc..478e73927789 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -177,16 +177,24 @@ def log_validation(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
-    # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
-    autocast_ctx = nullcontext()
+    autocast_ctx = torch.autocast(accelerator.device.type)
 
-    with autocast_ctx:
-        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+        pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
+    )
+    images = []
+    for _ in range(args.num_validation_images):
+        with autocast_ctx:
+            image = pipeline(
+                prompt_embeds=prompt_embeds,
+                pooled_prompt_embeds=pooled_prompt_embeds,
+                generator=generator).images[0]
+            images.append(image)
 
     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
@@ -203,8 +211,7 @@ def log_validation(
     )
 
     del pipeline
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    free_memory()
 
     return images

From bdd6caefaa9d6caf4194e24b351cfcb78197c4b8 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Tue, 18 Mar 2025 23:16:22 +0200
Subject: [PATCH 04/19] add intermediate embeddings saving when checkpointing
 is enabled

---
 .../train_dreambooth_lora_flux_advanced.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 6b87c9fd197d..cd5f271199b7 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -2247,6 +2247,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                     accelerator.save_state(save_path)
+                    if args.train_text_encoder_ti:
+                        embedding_handler.save_embeddings(
+                            f"{args.output_dir}/{Path(args.output_dir).name}_emb_checkpoint_{global_step}.safetensors")
                     logger.info(f"Saved state to {save_path}")
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}

From c8e165b5f314019c9b1ce13490b3d4737c26aab4 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Tue, 18 Mar 2025 23:16:46 +0200
Subject: [PATCH 05/19] remove comment

---
 .../train_dreambooth_lora_flux_advanced.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index cd5f271199b7..0e66fda8a455 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -2243,8 +2243,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                         shutil.rmtree(removing_checkpoint)
 
-                    # save embeddings
-
                     save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                     accelerator.save_state(save_path)
                     if args.train_text_encoder_ti:

From 710fcae49a9858ff7bada6356582cef984aee0f5 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Wed, 19 Mar 2025 00:57:49 +0200
Subject: [PATCH 06/19] fix validation

---
 .../train_dreambooth_lora_flux_advanced.py    | 32 ++++++++++++-------
 .../dreambooth/train_dreambooth_lora_flux.py  |  1 +
 2 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 0e66fda8a455..ed80c47f0dac 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -230,6 +230,7 @@ def log_validation(
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
     autocast_ctx = torch.autocast(accelerator.device.type)
 
+    # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
     prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
         pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
     )
@@ -2194,16 +2195,25 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     if not freeze_text_encoder:
-                        if args.train_text_encoder:
+                        if args.train_text_encoder: # text encoder tuning
                             params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters())
                         elif pure_textual_inversion:
-                            params_to_clip = itertools.chain(
-                                text_encoder_one.parameters(), text_encoder_two.parameters()
-                            )
+                            if args.enable_t5_ti:
+                                params_to_clip = itertools.chain(
+                                    text_encoder_one.parameters(), text_encoder_two.parameters()
+                                )
+                            else:
+                                params_to_clip = itertools.chain(
+                                    text_encoder_one.parameters()
+                                )
                         else:
-                            params_to_clip = itertools.chain(
-                                transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
-                            )
+                            if args.enable_t5_ti:
+                                params_to_clip = itertools.chain(
+                                    transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
+                                )
+                            else:
+                                params_to_clip = itertools.chain(transformer.parameters(),
+                                                                 text_encoder_one.parameters())
                     else:
                         params_to_clip = itertools.chain(transformer.parameters())
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -2260,8 +2270,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                 # create pipeline
-                if freeze_text_encoder:
+                if freeze_text_encoder: # no text encoder one, two optimizations
                     text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+                    text_encoder_one.to(weight_dtype)
+                    text_encoder_two.to(weight_dtype)
+
                 pipeline = FluxPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
@@ -2287,9 +2300,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if freeze_text_encoder:
                     del text_encoder_one, text_encoder_two
                     free_memory()
-                elif args.train_text_encoder:
-                    del text_encoder_two
-                    free_memory()
 
     # Save the lora layers
     accelerator.wait_for_everyone()
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 478e73927789..b36cf6d0e6c4 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -184,6 +184,7 @@ def log_validation(
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
     autocast_ctx = torch.autocast(accelerator.device.type)
 
+    # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
     prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
         pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
     )

From 0565932e7d74b14e65fb0295ddf96a78922b98d1 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Wed, 19 Mar 2025 09:30:37 +0200
Subject: [PATCH 07/19] add unwrap_model for accelerator, torch.no_grad context
 for validation, fix accelerator.accumulate call in advanced script

---
 .../train_dreambooth_lora_flux_advanced.py        | 14 ++++++++++----
 examples/dreambooth/train_dreambooth_lora_flux.py | 13 +++++++------
 2 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index ed80c47f0dac..ac15c5ce2112 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -231,9 +231,10 @@ def log_validation(
     autocast_ctx = torch.autocast(accelerator.device.type)
 
     # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
-    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
-        pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
-    )
+    with torch.no_grad():
+        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+            pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
+        )
     images = []
     for _ in range(args.num_validation_images):
         with autocast_ctx:
@@ -2044,6 +2045,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 pivoted_tr = True
 
         for step, batch in enumerate(train_dataloader):
+            models_to_accumulate = [transformer]
+            if not freeze_text_encoder:
+                models_to_accumulate.extend([text_encoder_one])
+                if args.enable_t5_ti:
+                    models_to_accumulate.extend([text_encoder_two])
             if pivoted_te:
                 # stopping optimization of text_encoder params
                 optimizer.param_groups[te_idx]["lr"] = 0.0
@@ -2052,7 +2058,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 logger.info(f"PIVOT TRANSFORMER {epoch}")
                 optimizer.param_groups[0]["lr"] = 0.0
 
-            with accelerator.accumulate(transformer):
+            with accelerator.accumulate(models_to_accumulate):
                 prompts = batch["prompts"]
 
                 # encode batch prompts when custom prompts are provided for each image -
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index b36cf6d0e6c4..3ffa0e0ee0e6 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -185,9 +185,10 @@ def log_validation(
     autocast_ctx = torch.autocast(accelerator.device.type)
 
     # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
-    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
-        pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
-    )
+    with torch.no_grad():
+        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+            pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
+        )
     images = []
     for _ in range(args.num_validation_images):
         with autocast_ctx:
@@ -940,7 +941,7 @@ def _encode_prompt_with_t5(
 
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
-    dtype = text_encoder.dtype
+    dtype = unwrap_model(text_encoder).dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     _, seq_len, _ = prompt_embeds.shape
@@ -983,7 +984,7 @@ def _encode_prompt_with_clip(
 
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=unwrap_model(text_encoder).dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1002,7 +1003,7 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    dtype = text_encoders[0].dtype
+    dtype = unwrap_model(text_encoders[0]).dtype
 
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],

From ba4decee8ae3df550f167bec06cdb6f21445ff99 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Wed, 19 Mar 2025 10:23:10 +0200
Subject: [PATCH 08/19] revert unwrap_model change temp

---
 examples/dreambooth/train_dreambooth_lora_flux.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 3ffa0e0ee0e6..53badfd4c5b0 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -941,7 +941,7 @@ def _encode_prompt_with_t5(
 
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
-    dtype = unwrap_model(text_encoder).dtype
+    dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     _, seq_len, _ = prompt_embeds.shape
@@ -984,7 +984,7 @@ def _encode_prompt_with_clip(
 
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=unwrap_model(text_encoder).dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1003,7 +1003,7 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    dtype = unwrap_model(text_encoders[0]).dtype
+    dtype = text_encoders[0].dtype
 
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],

From c155f22a4753a9b7d4da9a9c303b40bb19da8d1a Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Wed, 19 Mar 2025 12:38:50 +0200
Subject: [PATCH 09/19] add .module to address distributed training bug +
 replace accelerator.unwrap_model with unwrap model

---
 .../dreambooth/train_dreambooth_lora_flux.py | 27 +++++++++++++------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 53badfd4c5b0..9d1bf03eaf4a 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -941,7 +941,10 @@ def _encode_prompt_with_t5(
 
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
-    dtype = text_encoder.dtype
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     _, seq_len, _ = prompt_embeds.shape
@@ -982,9 +985,13 @@ def _encode_prompt_with_clip(
 
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1003,7 +1010,11 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    dtype = text_encoders[0].dtype
+
+    if hasattr(text_encoders[0], "module"):
+        dtype = text_encoders[0].module.dtype
+    else:
+        dtype = text_encoders[0].dtype
 
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
@@ -1628,7 +1639,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         if args.train_text_encoder:
             text_encoder_one.train()
             # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
 
         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
@@ -1719,7 +1730,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
 
                 # handle guidance
-                if accelerator.unwrap_model(transformer).config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -1837,9 +1848,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         pipeline = FluxPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             vae=vae,
-            text_encoder=accelerator.unwrap_model(text_encoder_one),
-            text_encoder_2=accelerator.unwrap_model(text_encoder_two),
-            transformer=accelerator.unwrap_model(transformer),
+            text_encoder=unwrap_model(text_encoder_one),
+            text_encoder_2=unwrap_model(text_encoder_two),
+            transformer=unwrap_model(transformer),
             revision=args.revision,
             variant=args.variant,
             torch_dtype=weight_dtype,

From 9c4368d8f978953915169e04bffad1e2fa49132e Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Wed, 19 Mar 2025 12:48:21 +0200
Subject: [PATCH 10/19] changes to align advanced script with canonical script

---
 .../train_dreambooth_lora_flux_advanced.py | 56 +++++++++++--------
 1 file changed, 33 insertions(+), 23 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index ac15c5ce2112..87818bd1059d 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -662,12 +662,11 @@ def parse_args(input_args=None):
         "uses the value of square root of beta2. Ignored if optimizer is adamW",
     )
     parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
-    parser.add_argument(
-        "--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params"
-    )
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params")
     parser.add_argument(
         "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
     )
+
     parser.add_argument(
         "--lora_layers",
         type=str,
         default=None,
         help=(
             'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
         ),
     )
+
     parser.add_argument(
         "--adam_epsilon",
         type=float,
@@ -749,6 +749,15 @@ def parse_args(input_args=None):
         " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
         ),
     )
+    parser.add_argument(
+        "--upcast_before_saving",
+        action="store_true",
+        default=False,
+        help=(
+            "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+            "Defaults to precision dtype used for training to save memory"
+        ),
+    )
     parser.add_argument(
         "--prior_generation_precision",
         type=str,
@@ -1158,7 +1167,7 @@ def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=F
     return text_input_ids
 
 
-def _get_t5_prompt_embeds(
+def _encode_prompt_with_t5(
     text_encoder,
     tokenizer,
     max_sequence_length=512,
@@ -1199,7 +1208,7 @@ def _get_t5_prompt_embeds(
     return prompt_embeds
 
 
-def _get_clip_prompt_embeds(
+def _encode_prompt_with_clip(
     text_encoder,
     tokenizer,
     prompt: str,
@@ -1249,33 +1258,32 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    batch_size = len(prompt)
     dtype = text_encoders[0].dtype
 
-    pooled_prompt_embeds = _get_clip_prompt_embeds(
+    pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
         tokenizer=tokenizers[0],
         prompt=prompt,
         device=device if device is not None else text_encoders[0].device,
         num_images_per_prompt=num_images_per_prompt,
-        text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None,
+        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
     )
 
-    prompt_embeds = _get_t5_prompt_embeds(
+    prompt_embeds = _encode_prompt_with_t5(
         text_encoder=text_encoders[1],
         tokenizer=tokenizers[1],
         max_sequence_length=max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
         device=device if device is not None else text_encoders[1].device,
-        text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None,
+        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
     )
 
-    text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-    text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
     return prompt_embeds, pooled_prompt_embeds, text_ids
 
+
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
@@ -1527,7 +1535,6 @@ def main(args):
         target_modules=target_modules,
     )
     transformer.add_adapter(transformer_lora_config)
-
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
@@ -1635,7 +1642,6 @@ def load_model_hook(models, input_dir):
         cast_training_params(models, dtype=torch.float32)
 
     transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
-
    if args.train_text_encoder:
         text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
     # if we use textual inversion, we freeze all parameters except for the token embeddings
@@ -1736,6 +1742,7 @@ def load_model_hook(models, input_dir):
         optimizer_class = bnb.optim.AdamW8bit
     else:
         optimizer_class = torch.optim.AdamW
+
     optimizer = optimizer_class(
         params_to_optimize,
         betas=(args.adam_beta1, args.adam_beta2),
@@ -2102,7 +2109,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                     model_input = model_input.to(dtype=weight_dtype)
 
-                    vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+                    vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
 
                     latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                         model_input.shape[0],
@@ -2141,7 +2148,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
 
                 # handle guidance
-                if transformer.config.guidance_embeds:
+                if accelerator.unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -2280,7 +2287,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
                     text_encoder_one.to(weight_dtype)
                     text_encoder_two.to(weight_dtype)
-
                 pipeline = FluxPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
@@ -2300,18 +2306,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     epoch=epoch,
                     torch_dtype=weight_dtype,
                 )
-                images = None
-                del pipeline
-
-                if freeze_text_encoder:
+                if not freeze_text_encoder:
                     del text_encoder_one, text_encoder_two
                     free_memory()
 
+                images = None
+                del pipeline
+
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
-        transformer = transformer.to(weight_dtype)
+        if args.upcast_before_saving:
+            transformer.to(torch.float32)
+        else:
+            transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
 
         if args.train_text_encoder:
@@ -2353,8 +2362,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 accelerator=accelerator,
                 pipeline_args=pipeline_args,
                 epoch=epoch,
-                torch_dtype=weight_dtype,
                 is_final_validation=True,
+                torch_dtype=weight_dtype,
             )
 
         save_model_card(
@@ -2377,6 +2386,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 commit_message="End of training",
                 ignore_patterns=["step_*", "epoch_*"],
             )
+
     images = None
     del pipeline

From 7492e92542f85176819e10b1c172fc961a234617 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Wed, 19 Mar 2025 12:50:44 +0200
Subject: [PATCH 11/19] make changes for distributed training + unify
 unwrap_model calls in advanced script

---
 .../train_dreambooth_lora_flux_advanced.py | 26 +++++++++++++------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 87818bd1059d..33a58b32ef99 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -1196,7 +1196,10 @@ def _encode_prompt_with_t5(
 
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
-    dtype = text_encoder.dtype
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     _, seq_len, _ = prompt_embeds.shape
@@ -1237,9 +1240,13 @@ def _encode_prompt_with_clip(
 
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1258,7 +1265,10 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    dtype = text_encoders[0].dtype
+    if hasattr(text_encoders[0], "module"):
+        dtype = text_encoders[0].module.dtype
+    else:
+        dtype = text_encoders[0].dtype
 
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
@@ -2040,7 +2050,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             if args.train_text_encoder:
                 text_encoder_one.train()
                 # set top parameter requires_grad = True for gradient checkpointing works
-                accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+                unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
             elif args.train_text_encoder_ti:  # textual inversion / pivotal tuning
                 text_encoder_one.train()
             if args.enable_t5_ti:
@@ -2148,7 +2158,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
 
                 # handle guidance
-                if accelerator.unwrap_model(transformer).config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -2290,9 +2300,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 pipeline = FluxPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
-                    text_encoder=accelerator.unwrap_model(text_encoder_one),
-                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),
-                    transformer=accelerator.unwrap_model(transformer),
+                    text_encoder=unwrap_model(text_encoder_one),
+                    text_encoder_2=unwrap_model(text_encoder_two),
+                    transformer=unwrap_model(transformer),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,

From 0729c6624568127f303d2bd6e60d241729cc27e3 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Wed, 19 Mar 2025 12:57:58 +0200
Subject: [PATCH 12/19] add module.dtype fix to dreambooth script

---
 examples/dreambooth/train_dreambooth_flux.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index 66f533e52a8a..3df4e04b5b75 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -895,7 +895,10 @@ def _encode_prompt_with_t5(
 
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
-    dtype = text_encoder.dtype
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     _, seq_len, _ = prompt_embeds.shape
@@ -936,9 +939,13 @@ def _encode_prompt_with_clip(
 
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -958,7 +965,12 @@ def encode_prompt(
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
-    dtype = text_encoders[0].dtype
+
+    if hasattr(text_encoders[0], "module"):
+        dtype = text_encoders[0].module.dtype
+    else:
+        dtype = text_encoders[0].dtype
+
     device = device if device is not None else text_encoders[1].device
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],

From cc1d2ad85eeae788517cac8b2950793c0b43fcbe Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Wed, 19 Mar 2025 12:59:12 +0200
Subject: [PATCH 13/19] unify unwrap_model calls in dreambooth script

---
 examples/dreambooth/train_dreambooth_flux.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index 3df4e04b5b75..6b5adb7a10a8 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -1602,7 +1602,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
 
                 # handle guidance
-                if accelerator.unwrap_model(transformer).config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -1728,9 +1728,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         pipeline = FluxPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             vae=vae,
-            text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
-            text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
-            transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
+            text_encoder=unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
+            text_encoder_2=unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
+            transformer=unwrap_model(transformer, keep_fp32_wrapper=False),
             revision=args.revision,
             variant=args.variant,
             torch_dtype=weight_dtype,

From 8bf49c77102b3785983498dd0a865bb8834288a4 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Thu, 27 Mar 2025 11:51:06 +0200
Subject: [PATCH 14/19] fix condition in validation run

---
 .../train_dreambooth_lora_flux_advanced.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 33a58b32ef99..ced201411f9d 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -2316,7 +2316,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     epoch=epoch,
                     torch_dtype=weight_dtype,
                 )
-                if not freeze_text_encoder:
+                if freeze_text_encoder:
                     del text_encoder_one, text_encoder_two
                     free_memory()

From 9b2917fd126a244fa9887c5d16b59e60db8b662d Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Thu, 3 Apr 2025 17:10:04 +0300
Subject: [PATCH 15/19] mixed precision

---
 .../train_dreambooth_lora_flux_advanced.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index ced201411f9d..9c87a39c8f0a 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -1661,7 +1661,8 @@ def load_model_hook(models, input_dir):
             for name, param in text_encoder_one.named_parameters():
                 if "token_embedding" in name:
                     # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                    param.data = param.to(dtype=torch.float32)
+                    if args.mixed_precision == "fp16":
+                        param.data = param.to(dtype=torch.float32)
                     param.requires_grad = True
                     text_lora_parameters_one.append(param)
                 else:
@@ -1671,7 +1672,8 @@ def load_model_hook(models, input_dir):
             for name, param in text_encoder_two.named_parameters():
                 if "shared" in name:
                     # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                    param.data = param.to(dtype=torch.float32)
+                    if args.mixed_precision == "fp16":
+                        param.data = param.to(dtype=torch.float32)
                     param.requires_grad = True
                     text_lora_parameters_two.append(param)
                 else:
@@ -1946,6 +1948,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             lr_scheduler,
         )
     else:
+        print("I SHOULD BE HERE")
         transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
         )

From 5d249a779130f5fa00f65f712480701459a4f650 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Date: Tue, 8 Apr 2025 14:45:28 +0300
Subject: [PATCH 16/19] Update
 examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Co-authored-by: Sayak Paul
---
 .../train_dreambooth_lora_flux_advanced.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 9c87a39c8f0a..a498ee60c80c 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -239,9 +239,10 @@ def log_validation(
     for _ in range(args.num_validation_images):
         with autocast_ctx:
             image = pipeline(
-                prompt_embeds=prompt_embeds,
-                pooled_prompt_embeds=pooled_prompt_embeds,
-                generator=generator).images[0]
+                prompt_embeds=prompt_embeds,
+                pooled_prompt_embeds=pooled_prompt_embeds,
+                generator=generator
+            ).images[0]
             images.append(image)

From 57ee3cf033672b80b2acbc524f015715f73d21b3 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Tue, 8 Apr 2025 14:50:14 +0300
Subject: [PATCH 17/19] smol style change

---
 examples/dreambooth/train_dreambooth_lora_flux.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 9d1bf03eaf4a..29068445b72c 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -193,9 +193,10 @@ def log_validation(
     for _ in range(args.num_validation_images):
         with autocast_ctx:
             image = pipeline(
-                prompt_embeds=prompt_embeds,
-                pooled_prompt_embeds=pooled_prompt_embeds,
-                generator=generator).images[0]
+                prompt_embeds=prompt_embeds,
+                pooled_prompt_embeds=pooled_prompt_embeds,
+                generator=generator
+            ).images[0]
             images.append(image)

From 8b991a5c0807a1b0e632288326fb6e3ee885a51f Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Tue, 8 Apr 2025 15:05:47 +0300
Subject: [PATCH 18/19] change autocast

---
 .../train_dreambooth_lora_flux_advanced.py        | 2 +-
 examples/dreambooth/train_dreambooth_lora_flux.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index a498ee60c80c..0d935ba0ad0b 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -228,7 +228,7 @@ def log_validation(
 
     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
-    autocast_ctx = torch.autocast(accelerator.device.type)
+    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
 
     # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
     with torch.no_grad():
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 29068445b72c..08bf6f54a745 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -182,7 +182,7 @@ def log_validation(
 
     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
-    autocast_ctx = torch.autocast(accelerator.device.type)
+    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
 
     # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
     with torch.no_grad():

From bfd1df62739efffe637e510efde111282b5b4724 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Tue, 8 Apr 2025 14:08:57 +0000
Subject: [PATCH 19/19] Apply style fixes

---
 .../train_dreambooth_lora_flux_advanced.py    | 30 ++++++++++---------
 .../dreambooth/train_dreambooth_lora_flux.py  |  4 +--
 2 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 0d935ba0ad0b..f45e0a51d226 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -24,7 +24,7 @@ import shutil
 from contextlib import nullcontext
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional
 
 import numpy as np
 import torch
@@ -239,9 +239,7 @@ def log_validation(
     for _ in range(args.num_validation_images):
         with autocast_ctx:
             image = pipeline(
-                prompt_embeds=prompt_embeds,
-                pooled_prompt_embeds=pooled_prompt_embeds,
-                generator=generator
+                prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
             ).images[0]
             images.append(image)
@@ -663,7 +661,9 @@ def parse_args(input_args=None):
         "uses the value of square root of beta2. Ignored if optimizer is adamW",
     )
     parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
-    parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params")
+    parser.add_argument(
+        "--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params"
+    )
     parser.add_argument(
         "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
     )
@@ -2222,7 +2222,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     if not freeze_text_encoder:
-                        if args.train_text_encoder: # text encoder tuning
+                        if args.train_text_encoder:  # text encoder tuning
                             params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters())
                         elif pure_textual_inversion:
                             if args.enable_t5_ti:
                                 params_to_clip = itertools.chain(
                                     text_encoder_one.parameters(), text_encoder_two.parameters()
                                 )
                             else:
-                                params_to_clip = itertools.chain(
-                                    text_encoder_one.parameters()
-                                )
+                                params_to_clip = itertools.chain(text_encoder_one.parameters())
                         else:
                             if args.enable_t5_ti:
                                 params_to_clip = itertools.chain(
-                                    transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
+                                    transformer.parameters(),
+                                    text_encoder_one.parameters(),
+                                    text_encoder_two.parameters(),
                                 )
                             else:
-                                params_to_clip = itertools.chain(transformer.parameters(),
-                                                                 text_encoder_one.parameters())
+                                params_to_clip = itertools.chain(
+                                    transformer.parameters(), text_encoder_one.parameters()
+                                )
                     else:
                         params_to_clip = itertools.chain(transformer.parameters())
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -2284,7 +2285,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     accelerator.save_state(save_path)
                     if args.train_text_encoder_ti:
                         embedding_handler.save_embeddings(
-                            f"{args.output_dir}/{Path(args.output_dir).name}_emb_checkpoint_{global_step}.safetensors")
+                            f"{args.output_dir}/{Path(args.output_dir).name}_emb_checkpoint_{global_step}.safetensors"
+                        )
                     logger.info(f"Saved state to {save_path}")
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -2297,7 +2299,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                 # create pipeline
-                if freeze_text_encoder: # no text encoder one, two optimizations
+                if freeze_text_encoder:  # no text encoder one, two optimizations
                     text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
                     text_encoder_one.to(weight_dtype)
                     text_encoder_two.to(weight_dtype)
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 08bf6f54a745..debdafd04ba1 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -193,9 +193,7 @@ def log_validation(
     for _ in range(args.num_validation_images):
         with autocast_ctx:
             image = pipeline(
-                prompt_embeds=prompt_embeds,
-                pooled_prompt_embeds=pooled_prompt_embeds,
-                generator=generator
+                prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
             ).images[0]
             images.append(image)
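
Series note on the `.module` checks and `unwrap_model` calls introduced in
patches 09-13: under multi-GPU training, `accelerator.prepare()` wraps each
trainable model in `torch.nn.parallel.DistributedDataParallel`, so attributes
such as `.dtype` and `.config` must be read from the underlying module rather
than from the wrapper. The snippet below is a minimal, self-contained sketch
of the two idioms the patches converge on, not the scripts' verbatim code: in
the diffusers examples the equivalent helper typically closes over the
`accelerator` created in `main()` instead of taking it as an argument, and
`encoder_dtype` is a name introduced here purely for illustration.

    import torch
    from accelerate import Accelerator
    from diffusers.utils.torch_utils import is_compiled_module

    def unwrap_model(accelerator: Accelerator, model: torch.nn.Module) -> torch.nn.Module:
        # Strip the accelerate/DDP wrapper (and a torch.compile wrapper, if any)
        # so attributes like `.config` and `.dtype` resolve on the real module.
        model = accelerator.unwrap_model(model)
        return model._orig_mod if is_compiled_module(model) else model

    def encoder_dtype(text_encoder: torch.nn.Module) -> torch.dtype:
        # Inline fallback for helpers that only receive the (possibly wrapped)
        # encoder: DistributedDataParallel exposes the original module under
        # `.module`, which is where `.dtype` actually lives.
        return text_encoder.module.dtype if hasattr(text_encoder, "module") else text_encoder.dtype

The inline `hasattr(..., "module")` form is what the patches use inside
`_encode_prompt_with_t5`, `_encode_prompt_with_clip`, and `encode_prompt`,
where no accelerator handle is in scope.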