diff --git a/examples/advanced_diffusion_training/requirements.txt b/examples/advanced_diffusion_training/requirements.txt index 3f86855e1d1e..dbc124ff6526 100644 --- a/examples/advanced_diffusion_training/requirements.txt +++ b/examples/advanced_diffusion_training/requirements.txt @@ -1,7 +1,8 @@ -accelerate>=0.16.0 +accelerate>=0.31.0 torchvision -transformers>=4.25.1 +transformers>=4.41.2 ftfy tensorboard Jinja2 -peft==0.7.0 \ No newline at end of file +peft>=0.11.1 +sentencepiece \ No newline at end of file diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index b8194507d822..f45e0a51d226 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -24,7 +24,7 @@ import shutil from contextlib import nullcontext from pathlib import Path -from typing import List, Optional, Union +from typing import List, Optional import numpy as np import torch @@ -228,10 +228,20 @@ def log_validation( # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None - autocast_ctx = nullcontext() + autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() - with autocast_ctx: - images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt( + pipeline_args["prompt"], prompt_2=pipeline_args["prompt"] + ) + images = [] + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" @@ -657,6 +667,7 @@ def parse_args(input_args=None): parser.add_argument( "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" ) + parser.add_argument( "--lora_layers", type=str, @@ -666,6 +677,7 @@ def parse_args(input_args=None): 'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md' ), ) + parser.add_argument( "--adam_epsilon", type=float, @@ -738,6 +750,15 @@ def parse_args(input_args=None): " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " + "Defaults to precision dtype used for training to save memory" + ), + ) parser.add_argument( "--prior_generation_precision", type=str, @@ -1147,7 +1168,7 @@ def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=F return text_input_ids -def _get_t5_prompt_embeds( +def _encode_prompt_with_t5( text_encoder, tokenizer, max_sequence_length=512, @@ -1176,7 +1197,10 @@ def _get_t5_prompt_embeds( prompt_embeds = text_encoder(text_input_ids.to(device))[0] - dtype = text_encoder.dtype + if hasattr(text_encoder, "module"): + dtype = text_encoder.module.dtype + else: + dtype = text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape @@ -1188,7 +1212,7 @@ def _get_t5_prompt_embeds( return prompt_embeds -def _get_clip_prompt_embeds( +def _encode_prompt_with_clip( text_encoder, tokenizer, prompt: str, @@ -1217,9 +1241,13 @@ def _get_clip_prompt_embeds( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) + if hasattr(text_encoder, "module"): + dtype = text_encoder.module.dtype + else: + dtype = text_encoder.dtype # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -1238,136 +1266,35 @@ def encode_prompt( text_input_ids_list=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - dtype = text_encoders[0].dtype + if hasattr(text_encoders[0], "module"): + dtype = text_encoders[0].module.dtype + else: + dtype = text_encoders[0].dtype - pooled_prompt_embeds = _get_clip_prompt_embeds( + pooled_prompt_embeds = _encode_prompt_with_clip( text_encoder=text_encoders[0], tokenizer=tokenizers[0], prompt=prompt, device=device if device is not None else text_encoders[0].device, num_images_per_prompt=num_images_per_prompt, - text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None, + text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, ) - prompt_embeds = _get_t5_prompt_embeds( + prompt_embeds = _encode_prompt_with_t5( text_encoder=text_encoders[1], tokenizer=tokenizers[1], max_sequence_length=max_sequence_length, prompt=prompt, num_images_per_prompt=num_images_per_prompt, device=device if device is not None else text_encoders[1].device, - text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None, + text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, ) - text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids -# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer: -# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95 -class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - with torch.no_grad(): - # create weights for timesteps - num_timesteps = 1000 - - # generate the multiplier based on cosmap loss weighing - # this is only used on linear timesteps for now - - # cosine map weighing is higher in the middle and lower at the ends - # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2 - # cosmap_weighing = 2 / (math.pi * bot) - - # sigma sqrt weighing is significantly higher at the end and lower at the beginning - sigma_sqrt_weighing = (self.sigmas**-2.0).float() - # clip at 1e4 (1e6 is too high) - sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4) - # bring to a mean of 1 - sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean() - - # Create linear timesteps from 1000 to 0 - timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu") - - self.linear_timesteps = timesteps - # self.linear_timesteps_weights = cosmap_weighing - self.linear_timesteps_weights = sigma_sqrt_weighing - - # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu') - pass - - def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor: - # Get the indices of the timesteps - step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] - - # Get the weights for the timesteps - weights = self.linear_timesteps_weights[step_indices].flatten() - - return weights - - def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor: - sigmas = self.sigmas.to(device=device, dtype=dtype) - schedule_timesteps = self.timesteps.to(device) - timesteps = timesteps.to(device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - - return sigma - - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.Tensor, - ) -> torch.Tensor: - ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 - ## Add noise according to flow matching. - ## zt = (1 - texp) * x + texp * z1 - - # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise - - # timestep needs to be in [0, 1], we store them in [0, 1000] - # noisy_sample = (1 - timestep) * latent + timestep * noise - t_01 = (timesteps / 1000).to(original_samples.device) - noisy_model_input = (1 - t_01) * original_samples + t_01 * noise - - # n_dim = original_samples.ndim - # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) - # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise - return noisy_model_input - - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: - return sample - - def set_train_timesteps(self, num_timesteps, device, linear=False): - if linear: - timesteps = torch.linspace(1000, 0, num_timesteps, device=device) - self.timesteps = timesteps - return timesteps - else: - # distribute them closer to center. Inference distributes them as a bias toward first - # Generate values from 0 to 1 - t = torch.sigmoid(torch.randn((num_timesteps,), device=device)) - - # Scale and reverse the values to go from 1000 to 0 - timesteps = (1 - t) * 1000 - - # Sort the timesteps in descending order - timesteps, _ = torch.sort(timesteps, descending=True) - - self.timesteps = timesteps.to(device=device) - - return timesteps - - def main(args): if args.report_to == "wandb" and args.hub_token is not None: raise ValueError( @@ -1499,7 +1426,7 @@ def main(args): ) # Load scheduler and models - noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained( + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler" ) noise_scheduler_copy = copy.deepcopy(noise_scheduler) @@ -1619,7 +1546,6 @@ def main(args): target_modules=target_modules, ) transformer.add_adapter(transformer_lora_config) - if args.train_text_encoder: text_lora_config = LoraConfig( r=args.rank, @@ -1727,7 +1653,6 @@ def load_model_hook(models, input_dir): cast_training_params(models, dtype=torch.float32) transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) - if args.train_text_encoder: text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) # if we use textual inversion, we freeze all parameters except for the token embeddings @@ -1737,7 +1662,8 @@ def load_model_hook(models, input_dir): for name, param in text_encoder_one.named_parameters(): if "token_embedding" in name: # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - param.data = param.to(dtype=torch.float32) + if args.mixed_precision == "fp16": + param.data = param.to(dtype=torch.float32) param.requires_grad = True text_lora_parameters_one.append(param) else: @@ -1747,7 +1673,8 @@ def load_model_hook(models, input_dir): for name, param in text_encoder_two.named_parameters(): if "shared" in name: # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - param.data = param.to(dtype=torch.float32) + if args.mixed_precision == "fp16": + param.data = param.to(dtype=torch.float32) param.requires_grad = True text_lora_parameters_two.append(param) else: @@ -1828,6 +1755,7 @@ def load_model_hook(models, input_dir): optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW + optimizer = optimizer_class( params_to_optimize, betas=(args.adam_beta1, args.adam_beta2), @@ -2021,6 +1949,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): lr_scheduler, ) else: + print("I SHOULD BE HERE") transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler ) @@ -2125,7 +2054,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: text_encoder_one.train() # set top parameter requires_grad = True for gradient checkpointing works - accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) elif args.train_text_encoder_ti: # textual inversion / pivotal tuning text_encoder_one.train() if args.enable_t5_ti: @@ -2137,6 +2066,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pivoted_tr = True for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + if not freeze_text_encoder: + models_to_accumulate.extend([text_encoder_one]) + if args.enable_t5_ti: + models_to_accumulate.extend([text_encoder_two]) if pivoted_te: # stopping optimization of text_encoder params optimizer.param_groups[te_idx]["lr"] = 0.0 @@ -2145,7 +2079,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): logger.info(f"PIVOT TRANSFORMER {epoch}") optimizer.param_groups[0]["lr"] = 0.0 - with accelerator.accumulate(transformer): + with accelerator.accumulate(models_to_accumulate): prompts = batch["prompts"] # encode batch prompts when custom prompts are provided for each image - @@ -2189,7 +2123,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) - vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) + vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1) latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], @@ -2228,7 +2162,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) # handle guidance - if transformer.config.guidance_embeds: + if unwrap_model(transformer).config.guidance_embeds: guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: @@ -2288,16 +2222,26 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.backward(loss) if accelerator.sync_gradients: if not freeze_text_encoder: - if args.train_text_encoder: + if args.train_text_encoder: # text encoder tuning params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters()) elif pure_textual_inversion: - params_to_clip = itertools.chain( - text_encoder_one.parameters(), text_encoder_two.parameters() - ) + if args.enable_t5_ti: + params_to_clip = itertools.chain( + text_encoder_one.parameters(), text_encoder_two.parameters() + ) + else: + params_to_clip = itertools.chain(text_encoder_one.parameters()) else: - params_to_clip = itertools.chain( - transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters() - ) + if args.enable_t5_ti: + params_to_clip = itertools.chain( + transformer.parameters(), + text_encoder_one.parameters(), + text_encoder_two.parameters(), + ) + else: + params_to_clip = itertools.chain( + transformer.parameters(), text_encoder_one.parameters() + ) else: params_to_clip = itertools.chain(transformer.parameters()) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) @@ -2339,6 +2283,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) + if args.train_text_encoder_ti: + embedding_handler.save_embeddings( + f"{args.output_dir}/{Path(args.output_dir).name}_emb_checkpoint_{global_step}.safetensors" + ) logger.info(f"Saved state to {save_path}") logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} @@ -2351,14 +2299,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if accelerator.is_main_process: if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # create pipeline - if freeze_text_encoder: + if freeze_text_encoder: # no text encoder one, two optimizations text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + text_encoder_one.to(weight_dtype) + text_encoder_two.to(weight_dtype) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, - text_encoder=accelerator.unwrap_model(text_encoder_one), - text_encoder_2=accelerator.unwrap_model(text_encoder_two), - transformer=accelerator.unwrap_model(transformer), + text_encoder=unwrap_model(text_encoder_one), + text_encoder_2=unwrap_model(text_encoder_two), + transformer=unwrap_model(transformer), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, @@ -2372,21 +2322,21 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): epoch=epoch, torch_dtype=weight_dtype, ) - images = None - del pipeline - if freeze_text_encoder: del text_encoder_one, text_encoder_two free_memory() - elif args.train_text_encoder: - del text_encoder_two - free_memory() + + images = None + del pipeline # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: transformer = unwrap_model(transformer) - transformer = transformer.to(weight_dtype) + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) transformer_lora_layers = get_peft_model_state_dict(transformer) if args.train_text_encoder: @@ -2428,8 +2378,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator=accelerator, pipeline_args=pipeline_args, epoch=epoch, - torch_dtype=weight_dtype, is_final_validation=True, + torch_dtype=weight_dtype, ) save_model_card( @@ -2452,6 +2402,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): commit_message="End of training", ignore_patterns=["step_*", "epoch_*"], ) + images = None del pipeline diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 66f533e52a8a..6b5adb7a10a8 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -895,7 +895,10 @@ def _encode_prompt_with_t5( prompt_embeds = text_encoder(text_input_ids.to(device))[0] - dtype = text_encoder.dtype + if hasattr(text_encoder, "module"): + dtype = text_encoder.module.dtype + else: + dtype = text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape @@ -936,9 +939,13 @@ def _encode_prompt_with_clip( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) + if hasattr(text_encoder, "module"): + dtype = text_encoder.module.dtype + else: + dtype = text_encoder.dtype # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -958,7 +965,12 @@ def encode_prompt( ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - dtype = text_encoders[0].dtype + + if hasattr(text_encoders[0], "module"): + dtype = text_encoders[0].module.dtype + else: + dtype = text_encoders[0].dtype + device = device if device is not None else text_encoders[1].device pooled_prompt_embeds = _encode_prompt_with_clip( text_encoder=text_encoders[0], @@ -1590,7 +1602,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) # handle guidance - if accelerator.unwrap_model(transformer).config.guidance_embeds: + if unwrap_model(transformer).config.guidance_embeds: guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: @@ -1716,9 +1728,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, - text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False), - text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False), - transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False), + text_encoder=unwrap_model(text_encoder_one, keep_fp32_wrapper=False), + text_encoder_2=unwrap_model(text_encoder_two, keep_fp32_wrapper=False), + transformer=unwrap_model(transformer, keep_fp32_wrapper=False), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index dda3300d65cc..debdafd04ba1 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -177,16 +177,25 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None - # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() - autocast_ctx = nullcontext() + autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() - with autocast_ctx: - images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt( + pipeline_args["prompt"], prompt_2=pipeline_args["prompt"] + ) + images = [] + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" @@ -203,8 +212,7 @@ def log_validation( ) del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + free_memory() return images @@ -932,7 +940,10 @@ def _encode_prompt_with_t5( prompt_embeds = text_encoder(text_input_ids.to(device))[0] - dtype = text_encoder.dtype + if hasattr(text_encoder, "module"): + dtype = text_encoder.module.dtype + else: + dtype = text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape @@ -973,9 +984,13 @@ def _encode_prompt_with_clip( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) + if hasattr(text_encoder, "module"): + dtype = text_encoder.module.dtype + else: + dtype = text_encoder.dtype # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output - prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -994,7 +1009,11 @@ def encode_prompt( text_input_ids_list=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt - dtype = text_encoders[0].dtype + + if hasattr(text_encoders[0], "module"): + dtype = text_encoders[0].module.dtype + else: + dtype = text_encoders[0].dtype pooled_prompt_embeds = _encode_prompt_with_clip( text_encoder=text_encoders[0], @@ -1619,7 +1638,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.train_text_encoder: text_encoder_one.train() # set top parameter requires_grad = True for gradient checkpointing works - accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) for step, batch in enumerate(train_dataloader): models_to_accumulate = [transformer] @@ -1710,7 +1729,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) # handle guidance - if accelerator.unwrap_model(transformer).config.guidance_embeds: + if unwrap_model(transformer).config.guidance_embeds: guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: @@ -1828,9 +1847,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, - text_encoder=accelerator.unwrap_model(text_encoder_one), - text_encoder_2=accelerator.unwrap_model(text_encoder_two), - transformer=accelerator.unwrap_model(transformer), + text_encoder=unwrap_model(text_encoder_one), + text_encoder_2=unwrap_model(text_encoder_two), + transformer=unwrap_model(transformer), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype,