From de9c2d0a645701a7cf0229badde13eddb4935596 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Mon, 25 Dec 2023 10:19:40 +0530 Subject: [PATCH 01/16] freeinit --- .../pipelines/animatediff/freeinit_utils.py | 124 ++++++++++++++ .../animatediff/pipeline_animatediff.py | 157 ++++++++++++++---- 2 files changed, 248 insertions(+), 33 deletions(-) create mode 100644 src/diffusers/pipelines/animatediff/freeinit_utils.py diff --git a/src/diffusers/pipelines/animatediff/freeinit_utils.py b/src/diffusers/pipelines/animatediff/freeinit_utils.py new file mode 100644 index 000000000000..1e1071cc25ae --- /dev/null +++ b/src/diffusers/pipelines/animatediff/freeinit_utils.py @@ -0,0 +1,124 @@ +import math + +import torch +import torch.fft as fft + + +def freq_mix_3d(x, noise, LPF): + """ + Noise reinitialization. + + Args: + x: diffused latent + noise: randomly sampled noise + LPF: low pass filter + """ + # FFT + x_freq = fft.fftn(x, dim=(-3, -2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) + noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) + noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) + + # frequency mix + HPF = 1 - LPF + x_freq_low = x_freq * LPF + noise_freq_high = noise_freq * HPF + x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain + + # IFFT + x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) + x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real + + return x_mixed + + +def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25): + """ + Compute the gaussian low pass filter mask. + + Args: + shape: shape of the filter (volume) + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s == 0 or d_t == 0: + return mask + for t in range(T): + for h in range(H): + for w in range(W): + d_square = ((d_s / d_t) * (2 * t / T - 1)) ** 2 + (2 * h / H - 1) ** 2 + (2 * w / W - 1) ** 2 + mask[..., t, h, w] = math.exp(-1 / (2 * d_s**2) * d_square) + return mask + + +def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25): + """ + Compute the butterworth low pass filter mask. + + Args: + shape: shape of the filter (volume) + n: order of the filter, larger n ~ ideal, smaller n ~ gaussian + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s == 0 or d_t == 0: + return mask + for t in range(T): + for h in range(H): + for w in range(W): + d_square = ((d_s / d_t) * (2 * t / T - 1)) ** 2 + (2 * h / H - 1) ** 2 + (2 * w / W - 1) ** 2 + mask[..., t, h, w] = 1 / (1 + (d_square / d_s**2) ** n) + return mask + + +def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25): + """ + Compute the ideal low pass filter mask. + + Args: + shape: shape of the filter (volume) + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s == 0 or d_t == 0: + return mask + for t in range(T): + for h in range(H): + for w in range(W): + d_square = ((d_s / d_t) * (2 * t / T - 1)) ** 2 + (2 * h / H - 1) ** 2 + (2 * w / W - 1) ** 2 + mask[..., t, h, w] = 1 if d_square <= d_s * 2 else 0 + return mask + + +def box_low_pass_filter(shape, d_s=0.25, d_t=0.25): + """ + Compute the ideal low pass filter mask (approximated version). 
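+
+    Rather than evaluating a per-element distance, this keeps a centred box-shaped region of the
+    (fftshifted) spectrum, which approximates the ideal filter at a much lower construction cost.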
+ + Args: + shape: shape of the filter (volume) + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + if d_s == 0 or d_t == 0: + return mask + + threshold_s = round(int(H // 2) * d_s) + threshold_t = round(T // 2 * d_t) + + cframe, crow, ccol = T // 2, H // 2, W // 2 + mask[ + ..., + cframe - threshold_t : cframe + threshold_t, + crow - threshold_s : crow + threshold_s, + ccol - threshold_s : ccol + threshold_s, + ] = 1.0 + + return mask diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 68b358f7645c..33b7032536cf 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -36,6 +36,13 @@ from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .freeinit_utils import ( + box_low_pass_filter, + butterworth_low_pass_filter, + freq_mix_3d, + gaussian_low_pass_filter, + ideal_low_pass_filter, +) logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -72,6 +79,29 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): return outputs +def get_freeinit_freq_filter(shape, device, filter_type, n, d_s, d_t): + """ + Form the frequency filter for noise reinitialization. + + Args: + shape: shape of latent (B, C, T, H, W) + filter_type: type of the freq filter + n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian + d_s: normalized stop frequency for spatial dimensions (0.0-1.0) + d_t: normalized stop frequency for temporal dimension (0.0-1.0) + """ + if filter_type == "gaussian": + return gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) + elif filter_type == "ideal": + return ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) + elif filter_type == "box": + return box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) + elif filter_type == "butterworth": + return butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t).to(device) + else: + raise NotImplementedError + + @dataclass class AnimateDiffPipelineOutput(BaseOutput): frames: Union[torch.Tensor, np.ndarray] @@ -555,6 +585,8 @@ def __call__( callback_steps: Optional[int] = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, + use_freeinit: bool = False, + freeinit_kwargs: Optional[Dict[str, Any]] = None, ): r""" The call function to the pipeline for generation. @@ -695,44 +727,103 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7 Add image embeds for IP-Adapter + + # 7. 
Add image embeds for IP-Adapter added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None - # Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # 7.1 FreeInit + freeinit_num_steps = freeinit_kwargs.get("num_steps", 5) if use_freeinit else 1 + freeinit_use_fast_sampling = freeinit_kwargs.get("use_fast_sampling", False) + + if use_freeinit: + if "method" not in freeinit_kwargs.keys(): + raise ValueError("`use_freeinit` was set to True but required freeinit kwarg `method` was not set") + + freeinit_filter_shape = ( + 1, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + freeinit_freq_filter = get_freeinit_freq_filter( + shape=freeinit_filter_shape, + device=device, + filter_type=freeinit_kwargs.get("method"), + n=freeinit_kwargs.get("n", 4) if freeinit_kwargs.get("method") == "butterworth" else None, + d_s=freeinit_kwargs.get("d_s", 0.25), + d_t=freeinit_kwargs.get("d_t", 0.25), + ) + + # 8. 
Denoising loop + with self.progress_bar(total=freeinit_num_steps) as freeinit_steps_bar: + for freeinit_iter in range(freeinit_num_steps): + if freeinit_iter == 0: + initial_noise = latents.detach().clone() + else: + current_diffuse_timestep = ( + self.scheduler.config.num_train_timesteps - 1 + ) # diffuse to t=999 noise level + diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long() + z_T = self.scheduler.add_noise( + original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device) + ).to(dtype=torch.float32) + z_rand = torch.randn( + ( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ), + device=device, + ) + latents = freq_mix_3d(z_T, z_rand, LPF=freeinit_freq_filter) + latents = latents.to(prompt_embeds.dtype) + + if freeinit_use_fast_sampling: + current_num_inference_steps = int(num_inference_steps / freeinit_num_steps * (freeinit_iter + 1)) + self.scheduler.set_timesteps(current_num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + freeinit_steps_bar.update() if output_type == "latent": return AnimateDiffPipelineOutput(frames=latents) - # Post-processing + # 9. Post-processing video_tensor = self.decode_latents(latents) if output_type == "pt": @@ -740,7 +831,7 @@ def __call__( else: video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) - # Offload all models + # 10. 
Offload all models self.maybe_free_model_hooks() if not return_dict: From eed328b32182d9a691ca86b3e5e642b48511ecf2 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 28 Dec 2023 05:47:37 +0530 Subject: [PATCH 02/16] update freeinit implementation based on review Co-Authored-By: Dhruv Nair --- .../animatediff/pipeline_animatediff.py | 332 ++++++++++++------ 1 file changed, 221 insertions(+), 111 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 33b7032536cf..7e349d869888 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -33,7 +33,14 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .freeinit_utils import ( @@ -79,25 +86,19 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): return outputs -def get_freeinit_freq_filter(shape, device, filter_type, n, d_s, d_t): - """ - Form the frequency filter for noise reinitialization. - - Args: - shape: shape of latent (B, C, T, H, W) - filter_type: type of the freq filter - n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian - d_s: normalized stop frequency for spatial dimensions (0.0-1.0) - d_t: normalized stop frequency for temporal dimension (0.0-1.0) - """ +def _get_freeinit_freq_filter(shape, device, filter_type, order, spatial_stop_frequency, temporal_stop_frequency): if filter_type == "gaussian": - return gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) + return gaussian_low_pass_filter(shape=shape, d_s=spatial_stop_frequency, d_t=temporal_stop_frequency).to( + device + ) elif filter_type == "ideal": - return ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) + return ideal_low_pass_filter(shape=shape, d_s=spatial_stop_frequency, d_t=temporal_stop_frequency).to(device) elif filter_type == "box": - return box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) + return box_low_pass_filter(shape=shape, d_s=spatial_stop_frequency, d_t=temporal_stop_frequency).to(device) elif filter_type == "butterworth": - return butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t).to(device) + return butterworth_low_pass_filter( + shape=shape, n=order, d_s=spatial_stop_frequency, d_t=temporal_stop_frequency + ).to(device) else: raise NotImplementedError @@ -136,7 +137,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
""" - model_cpu_offload_seq = "text_encoder->unet->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["feature_extractor", "image_encoder"] def __init__( @@ -465,6 +466,62 @@ def disable_freeu(self): """Disables the FreeU mechanism if enabled.""" self.unet.disable_freeu() + @property + def free_init_enabled(self): + return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None + + def enable_freeinit( + self, + num_iters: int = 3, + use_fast_sampling: bool = False, + method: str = "butterworth", + order: int = 4, + spatial_stop_frequency: float = 0.25, + temporal_stop_frequency: float = 0.25, + generator: torch.Generator = None, + return_intermediate_results: bool = False, + ): + """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537. + + This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit). + + Args: + num_iters (`int`, *optional*, defaults to `3`): + Number of FreeInit noise re-initialization iterations. + use_fast_sampling (`bool`, *optional*, defaults to `False`): + Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables + the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`. + method (`str`, *optional*, defaults to `butterworth`): + Must be one of `box`, `butterworth`, `ideal` or `gaussian` to use as the filtering method for the + FreeInit low pass filter. + order (`int`, *optional*, defaults to `4`): + Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour + whereas lower values lead to `gaussian` method behaviour. + spatial_stop_frequency (`float`, *optional*, defaults to `0.25`): + Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in + the original implementation. + temporal_stop_frequency (`float`, *optional*, defaults to `0.25`): + Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in + the original implementation. + generator (`torch.Generator`, *optional*, defaults to `0.25`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + FreeInit generation deterministic. + return_intermediate_results (`bool`, *optional*, defaults to `False`): + Whether or not to return intermediate sampling results for every FreeInit iteration. 
+ """ + self._free_init_num_iters = num_iters + self._free_init_use_fast_sampling = use_fast_sampling + self._free_init_method = method + self._free_init_order = order + self._free_init_spatial_stop_frequency = spatial_stop_frequency + self._free_init_temporal_stop_frequency = temporal_stop_frequency + self._free_init_generator = generator + self._free_init_return_intermediate_results = return_intermediate_results + + def disable_freeinit(self): + """Disables the FreeInit mechanism if enabled.""" + self._free_init_num_iters = None + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -562,7 +619,72 @@ def prepare_latents( latents = latents * self.scheduler.init_noise_sigma return latents + def _denoise_loop( + self, + timesteps, + num_inference_steps, + do_classifier_free_guidance, + guidance_scale, + num_warmup_steps, + prompt_embeds, + latents, + cross_attention_kwargs, + added_cond_kwargs, + extra_step_kwargs, + callback, + callback_steps, + ): + """Denoising loop for AnimateDiff.""" + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + return latents + + def _retrieve_output(self, latents, output_type, return_dict): + """Helper function to handle latents to output conversion.""" + if output_type == "latent": + return AnimateDiffPipelineOutput(frames=latents) + + video_tensor = self.decode_latents(latents) + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, @@ -585,8 +707,6 @@ def __call__( callback_steps: Optional[int] = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, - use_freeinit: bool = False, - freeinit_kwargs: Optional[Dict[str, Any]] = None, ): r""" The call function to the pipeline for generation. @@ -627,7 +747,8 @@ def __call__( negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or `np.array`. @@ -731,110 +852,99 @@ def __call__( # 7. Add image embeds for IP-Adapter added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None - # 7.1 FreeInit - freeinit_num_steps = freeinit_kwargs.get("num_steps", 5) if use_freeinit else 1 - freeinit_use_fast_sampling = freeinit_kwargs.get("use_fast_sampling", False) - - if use_freeinit: - if "method" not in freeinit_kwargs.keys(): - raise ValueError("`use_freeinit` was set to True but required freeinit kwarg `method` was not set") - - freeinit_filter_shape = ( + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + denoise_args = { + "timesteps": timesteps, + "num_inference_steps": num_inference_steps, + "do_classifier_free_guidance": do_classifier_free_guidance, + "guidance_scale": guidance_scale, + "num_warmup_steps": num_warmup_steps, + "prompt_embeds": prompt_embeds, + "latents": latents, + "cross_attention_kwargs": cross_attention_kwargs, + "added_cond_kwargs": added_cond_kwargs, + "extra_step_kwargs": extra_step_kwargs, + "callback": callback, + "callback_steps": callback_steps, + } + + if self.free_init_enabled: + video = [] + free_init_filter_shape = ( 1, num_channels_latents, num_frames, height // self.vae_scale_factor, width // self.vae_scale_factor, ) - freeinit_freq_filter = get_freeinit_freq_filter( - shape=freeinit_filter_shape, + free_init_freq_filter = _get_freeinit_freq_filter( + shape=free_init_filter_shape, device=device, - filter_type=freeinit_kwargs.get("method"), - n=freeinit_kwargs.get("n", 4) if freeinit_kwargs.get("method") == "butterworth" else None, - d_s=freeinit_kwargs.get("d_s", 0.25), - d_t=freeinit_kwargs.get("d_t", 0.25), + filter_type=self._free_init_method, + order=self._free_init_order, + spatial_stop_frequency=self._free_init_spatial_stop_frequency, + temporal_stop_frequency=self._free_init_temporal_stop_frequency, ) - - # 8. Denoising loop - with self.progress_bar(total=freeinit_num_steps) as freeinit_steps_bar: - for freeinit_iter in range(freeinit_num_steps): - if freeinit_iter == 0: - initial_noise = latents.detach().clone() - else: - current_diffuse_timestep = ( - self.scheduler.config.num_train_timesteps - 1 - ) # diffuse to t=999 noise level - diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long() - z_T = self.scheduler.add_noise( - original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device) - ).to(dtype=torch.float32) - z_rand = torch.randn( - ( - batch_size * num_videos_per_prompt, - num_channels_latents, - num_frames, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ), - device=device, + with self.progress_bar(total=self._free_init_num_iters) as free_init_progress_bar: + for i in range(self._free_init_num_iters): + # For the first FreeInit iteration, the original latent is used without modification. + # Subsequent iterations apply the noise reinitialization technique. 
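+                    # Concretely (see `freq_mix_3d`): with LPF the low pass filter mask built above,
+                    #     latents = IFFT(FFT(z_T) * LPF + FFT(z_rand) * (1 - LPF)).real   (fftshifts omitted)
+                    # i.e. low-frequency components, which mostly carry coarse layout and temporal
+                    # structure, are kept from the re-diffused latent, while high-frequency detail is
+                    # resampled from fresh noise.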
+ if i == 0: + initial_noise = latents.detach().clone() + else: + current_diffuse_timestep = ( + self.scheduler.config.num_train_timesteps - 1 + ) # diffuse to t=999 noise level + diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long() + z_T = self.scheduler.add_noise( + original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device) + ).to(dtype=torch.float32) + z_rand = randn_tensor( + shape=( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ), + generator=self._free_init_generator, + device=device, + dtype=torch.float32, + ) + latents = freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter) + latents = latents.to(prompt_embeds.dtype) + + # Coarse-to-Fine Sampling for faster inference (can lead to lower quality) + if self._free_init_use_fast_sampling: + current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1)) + self.scheduler.set_timesteps(current_num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + denoise_args.update( + { + "timesteps": timesteps, + "num_inference_steps": current_num_inference_steps, + "latents": latents, + } ) - latents = freq_mix_3d(z_T, z_rand, LPF=freeinit_freq_filter) - latents = latents.to(prompt_embeds.dtype) - - if freeinit_use_fast_sampling: - current_num_inference_steps = int(num_inference_steps / freeinit_num_steps * (freeinit_iter + 1)) - self.scheduler.set_timesteps(current_num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - freeinit_steps_bar.update() + self._denoise_loop(**denoise_args) - if output_type == "latent": - return AnimateDiffPipelineOutput(frames=latents) + # Whether or not to return intermediate generation results + if self._free_init_return_intermediate_results: + intermediate_video = self._retrieve_output(latents, output_type, return_dict) + video.append(intermediate_video) - # 9. 
Post-processing - video_tensor = self.decode_latents(latents) + free_init_progress_bar.update() - if output_type == "pt": - video = video_tensor + if not self._free_init_return_intermediate_results: + video = self._retrieve_output(latents, output_type, return_dict) else: - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + latents = self._denoise_loop(**denoise_args) + video = self._retrieve_output(latents, output_type, return_dict) - # 10. Offload all models + # 9. Offload all models self.maybe_free_model_hooks() - if not return_dict: - return (video,) - - return AnimateDiffPipelineOutput(frames=video) + return video From dab6a6009b98f0abadd0dec56a75eee155c440d0 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 28 Dec 2023 05:56:20 +0530 Subject: [PATCH 03/16] fix --- .../pipelines/animatediff/pipeline_animatediff.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index b488af585ba4..cacd6558d642 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -470,7 +470,7 @@ def disable_freeu(self): def free_init_enabled(self): return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None - def enable_freeinit( + def enable_free_init( self, num_iters: int = 3, use_fast_sampling: bool = False, @@ -518,7 +518,7 @@ def enable_freeinit( self._free_init_generator = generator self._free_init_return_intermediate_results = return_intermediate_results - def disable_freeinit(self): + def disable_free_init(self): """Disables the FreeInit mechanism if enabled.""" self._free_init_num_iters = None @@ -920,15 +920,12 @@ def __call__( current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1)) self.scheduler.set_timesteps(current_num_inference_steps, device=device) timesteps = self.scheduler.timesteps + denoise_args.update( + {"timesteps": timesteps, "num_inference_steps": current_num_inference_steps} + ) num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - denoise_args.update( - { - "timesteps": timesteps, - "num_inference_steps": current_num_inference_steps, - "latents": latents, - } - ) + denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps}) self._denoise_loop(**denoise_args) # Whether or not to return intermediate generation results From 009f3eab4337d2fc8339ec8f262e25d4336babf2 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 28 Dec 2023 06:06:45 +0530 Subject: [PATCH 04/16] another fix --- src/diffusers/pipelines/animatediff/pipeline_animatediff.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index cacd6558d642..f1aef385b780 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -926,7 +926,7 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps}) - self._denoise_loop(**denoise_args) + latents = self._denoise_loop(**denoise_args) # Whether or not to return intermediate generation results if self._free_init_return_intermediate_results: From 574197bec98b6ac8e6018f9a8962ac7037749025 
Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 31 Dec 2023 10:53:42 +0530 Subject: [PATCH 05/16] refactor --- .../animatediff/pipeline_animatediff.py | 164 +++++++++++------- 1 file changed, 99 insertions(+), 65 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index f1aef385b780..3989a9f1e17e 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -666,6 +666,95 @@ def _denoise_loop( return latents + def _free_init_loop( + self, + height, + width, + num_frames, + num_channels_latents, + batch_size, + num_videos_per_prompt, + denoise_args, + device, + output_type, + return_dict, + ): + """Denoising loop for AnimateDiff using FreeInit noise reinitialization technique.""" + + latents = denoise_args.get("latents") + prompt_embeds = denoise_args.get("prompt_embeds") + num_inference_steps = denoise_args.get("num_inference_steps") + + video = [] + latent_shape = ( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + free_init_filter_shape = ( + 1, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + free_init_freq_filter = _get_freeinit_freq_filter( + shape=free_init_filter_shape, + device=device, + filter_type=self._free_init_method, + order=self._free_init_order, + spatial_stop_frequency=self._free_init_spatial_stop_frequency, + temporal_stop_frequency=self._free_init_temporal_stop_frequency, + ) + + with self.progress_bar(total=self._free_init_num_iters) as free_init_progress_bar: + for i in range(self._free_init_num_iters): + # For the first FreeInit iteration, the original latent is used without modification. + # Subsequent iterations apply the noise reinitialization technique. 
+ if i == 0: + initial_noise = latents.detach().clone() + else: + current_diffuse_timestep = ( + self.scheduler.config.num_train_timesteps - 1 + ) # diffuse to t=999 noise level + diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long() + z_T = self.scheduler.add_noise( + original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device) + ).to(dtype=torch.float32) + z_rand = randn_tensor( + shape=latent_shape, + generator=self._free_init_generator, + device=device, + dtype=torch.float32, + ) + latents = freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter) + latents = latents.to(prompt_embeds.dtype) + + # Coarse-to-Fine Sampling for faster inference (can lead to lower quality) + if self._free_init_use_fast_sampling: + current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1)) + self.scheduler.set_timesteps(current_num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + denoise_args.update({"timesteps": timesteps, "num_inference_steps": current_num_inference_steps}) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps}) + latents = self._denoise_loop(**denoise_args) + + # Whether or not to return intermediate generation results + if self._free_init_return_intermediate_results: + intermediate_video = self._retrieve_output(latents, output_type, return_dict) + video.append(intermediate_video) + + free_init_progress_bar.update() + + if not self._free_init_return_intermediate_results: + video = self._retrieve_output(latents, output_type, return_dict) + + return video + def _retrieve_output(self, latents, output_type, return_dict): """Helper function to handle latents to output conversion.""" if output_type == "latent": @@ -870,73 +959,18 @@ def __call__( } if self.free_init_enabled: - video = [] - free_init_filter_shape = ( - 1, - num_channels_latents, - num_frames, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - free_init_freq_filter = _get_freeinit_freq_filter( - shape=free_init_filter_shape, + video = self._free_init_loop( + height=height, + width=width, + num_frames=num_frames, + num_channels_latents=num_channels_latents, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + denoise_args=denoise_args, device=device, - filter_type=self._free_init_method, - order=self._free_init_order, - spatial_stop_frequency=self._free_init_spatial_stop_frequency, - temporal_stop_frequency=self._free_init_temporal_stop_frequency, + output_type=output_type, + return_dict=return_dict, ) - with self.progress_bar(total=self._free_init_num_iters) as free_init_progress_bar: - for i in range(self._free_init_num_iters): - # For the first FreeInit iteration, the original latent is used without modification. - # Subsequent iterations apply the noise reinitialization technique. 
- if i == 0: - initial_noise = latents.detach().clone() - else: - current_diffuse_timestep = ( - self.scheduler.config.num_train_timesteps - 1 - ) # diffuse to t=999 noise level - diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long() - z_T = self.scheduler.add_noise( - original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device) - ).to(dtype=torch.float32) - z_rand = randn_tensor( - shape=( - batch_size * num_videos_per_prompt, - num_channels_latents, - num_frames, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ), - generator=self._free_init_generator, - device=device, - dtype=torch.float32, - ) - latents = freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter) - latents = latents.to(prompt_embeds.dtype) - - # Coarse-to-Fine Sampling for faster inference (can lead to lower quality) - if self._free_init_use_fast_sampling: - current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1)) - self.scheduler.set_timesteps(current_num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - denoise_args.update( - {"timesteps": timesteps, "num_inference_steps": current_num_inference_steps} - ) - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps}) - latents = self._denoise_loop(**denoise_args) - - # Whether or not to return intermediate generation results - if self._free_init_return_intermediate_results: - intermediate_video = self._retrieve_output(latents, output_type, return_dict) - video.append(intermediate_video) - - free_init_progress_bar.update() - - if not self._free_init_return_intermediate_results: - video = self._retrieve_output(latents, output_type, return_dict) else: latents = self._denoise_loop(**denoise_args) video = self._retrieve_output(latents, output_type, return_dict) From 0dc875316677782b8ff26c3cef49c71ac61c6c66 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 31 Dec 2023 10:56:21 +0530 Subject: [PATCH 06/16] fix timesteps missing bug --- src/diffusers/pipelines/animatediff/pipeline_animatediff.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 3989a9f1e17e..b9d94e55ea41 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -683,6 +683,7 @@ def _free_init_loop( latents = denoise_args.get("latents") prompt_embeds = denoise_args.get("prompt_embeds") + timesteps = denoise_args.get("timesteps") num_inference_steps = denoise_args.get("num_inference_steps") video = [] From 54c370539096209109bfa01be5be9384141a2901 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 4 Jan 2024 20:01:46 +0530 Subject: [PATCH 07/16] apply suggestions from review Co-Authored-By: Dhruv Nair --- .../animatediff/pipeline_animatediff.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index b9d94e55ea41..94e51fe5d63c 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -676,8 +676,6 @@ def _free_init_loop( num_videos_per_prompt, denoise_args, device, - output_type, - return_dict, ): """Denoising loop for AnimateDiff using 
FreeInit noise reinitialization technique.""" @@ -686,7 +684,7 @@ def _free_init_loop( timesteps = denoise_args.get("timesteps") num_inference_steps = denoise_args.get("num_inference_steps") - video = [] + video_latents = [] latent_shape = ( batch_size * num_videos_per_prompt, num_channels_latents, @@ -746,17 +744,16 @@ def _free_init_loop( # Whether or not to return intermediate generation results if self._free_init_return_intermediate_results: - intermediate_video = self._retrieve_output(latents, output_type, return_dict) - video.append(intermediate_video) + video_latents.append(latents) free_init_progress_bar.update() if not self._free_init_return_intermediate_results: - video = self._retrieve_output(latents, output_type, return_dict) + video_latents = latents - return video + return video_latents - def _retrieve_output(self, latents, output_type, return_dict): + def _retrieve_video_frames(self, latents, output_type, return_dict): """Helper function to handle latents to output conversion.""" if output_type == "latent": return AnimateDiffPipelineOutput(frames=latents) @@ -960,7 +957,7 @@ def __call__( } if self.free_init_enabled: - video = self._free_init_loop( + latents = self._free_init_loop( height=height, width=width, num_frames=num_frames, @@ -969,12 +966,14 @@ def __call__( num_videos_per_prompt=num_videos_per_prompt, denoise_args=denoise_args, device=device, - output_type=output_type, - return_dict=return_dict, ) + if self._free_init_return_intermediate_results: + video = [self._retrieve_video_frames(latent, output_type, return_dict) for latent in latents] + else: + video = self._retrieve_video_frames(latents, output_type, return_dict) else: latents = self._denoise_loop(**denoise_args) - video = self._retrieve_output(latents, output_type, return_dict) + video = self._retrieve_video_frames(latents, output_type, return_dict) # 9. 
Offload all models self.maybe_free_model_hooks() From ae0656520202c5cd3c6ef20a857b869287ef0ad1 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 4 Jan 2024 20:02:25 +0530 Subject: [PATCH 08/16] add test for freeinit --- .../pipelines/animatediff/test_animatediff.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 633ed9fc233e..9bbfc768d2bc 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -233,6 +233,44 @@ def test_prompt_embeds(self): inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device) pipe(**inputs) + def test_free_init(self): + components = self.get_dummy_components() + pipe: AnimateDiffPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + free_init_generator = torch.Generator(device=torch_device).manual_seed(0) + pipe.enable_free_init( + num_iters=2, + use_fast_sampling=True, + method="butterworth", + order=4, + spatial_stop_frequency=0.25, + temporal_stop_frequency=0.25, + generator=free_init_generator, + return_intermediate_results=False, + ) + inputs_enable_free_init = self.get_dummy_inputs(torch_device) + frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0] + + pipe.disable_free_init() + inputs_disable_free_init = self.get_dummy_inputs(torch_device) + frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0] + + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum() + max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max() + self.assertGreater( + sum_enabled, 1e2, "Enabling of FreeInit should lead to results different from the default pipeline results" + ) + self.assertLess( + max_diff_disabled, + 1e-4, + "Disabling of FreeInit should lead to results similar to the default pipeline results", + ) + @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), reason="XFormers attention is only available with CUDA and `xformers` installed", From 8d3c7e662449eaa0823488daa9014633a0a78251 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 4 Jan 2024 20:19:38 +0530 Subject: [PATCH 09/16] apply suggestions from review Co-Authored-By: Dhruv Nair --- .../animatediff/pipeline_animatediff.py | 119 ++++++++++++++---- 1 file changed, 95 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 94e51fe5d63c..d2519b2ed76d 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -36,6 +36,7 @@ from ...utils import ( USE_PEFT_BACKEND, BaseOutput, + deprecate, logging, replace_example_docstring, scale_lora_layers, @@ -139,6 +140,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["feature_extractor", "image_encoder"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, @@ -627,12 +629,15 @@ def _denoise_loop( guidance_scale, num_warmup_steps, prompt_embeds, + negative_prompt_embeds, latents, cross_attention_kwargs, added_cond_kwargs, extra_step_kwargs, 
callback, callback_steps, + callback_on_step_end, + callback_on_step_end_tensor_inputs, ): """Denoising loop for AnimateDiff.""" with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -658,6 +663,16 @@ def _denoise_loop( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() @@ -770,6 +785,29 @@ def _retrieve_video_frames(self, latents, output_type, return_dict): return AnimateDiffPipelineOutput(frames=video) + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -790,10 +828,11 @@ def __call__( ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, ): r""" The call function to the pipeline for generation. @@ -842,18 +881,22 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. 
callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
         Examples:
 
         Returns:
             [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
             If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
             returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
 
         num_videos_per_prompt = 1
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+            prompt,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            callback_on_step_end_tensor_inputs,
         )
 
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -881,30 +952,26 @@ def __call__(
             batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
         text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
             num_videos_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=text_encoder_lora_scale,
-            clip_skip=clip_skip,
+            clip_skip=self.clip_skip,
         )
 
         # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: @@ -912,12 +979,13 @@ def __call__( image_embeds, negative_image_embeds = self.encode_image( ip_adapter_image, device, num_videos_per_prompt, output_hidden_state ) - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -944,16 +1012,19 @@ def __call__( denoise_args = { "timesteps": timesteps, "num_inference_steps": num_inference_steps, - "do_classifier_free_guidance": do_classifier_free_guidance, + "do_classifier_free_guidance": self.do_classifier_free_guidance, "guidance_scale": guidance_scale, "num_warmup_steps": num_warmup_steps, "prompt_embeds": prompt_embeds, + "negative_prompt_embeds": negative_prompt_embeds, "latents": latents, - "cross_attention_kwargs": cross_attention_kwargs, + "cross_attention_kwargs": self.cross_attention_kwargs, "added_cond_kwargs": added_cond_kwargs, "extra_step_kwargs": extra_step_kwargs, "callback": callback, "callback_steps": callback_steps, + "callback_on_step_end": callback_on_step_end, + "callback_on_step_end_tensor_inputs": callback_on_step_end_tensor_inputs, } if self.free_init_enabled: @@ -967,12 +1038,12 @@ def __call__( denoise_args=denoise_args, device=device, ) - if self._free_init_return_intermediate_results: - video = [self._retrieve_video_frames(latent, output_type, return_dict) for latent in latents] - else: - video = self._retrieve_video_frames(latents, output_type, return_dict) else: latents = self._denoise_loop(**denoise_args) + + if self.free_init_enabled and self._free_init_return_intermediate_results: + video = [self._retrieve_video_frames(latent, output_type, return_dict) for latent in latents] + else: video = self._retrieve_video_frames(latents, output_type, return_dict) # 9. Offload all models From 30b629c74867af4777585cbe5f7ce9da2aea6fea Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 4 Jan 2024 20:47:01 +0530 Subject: [PATCH 10/16] refactor --- .../pipelines/animatediff/freeinit_utils.py | 124 ------------------ .../animatediff/pipeline_animatediff.py | 82 +++++++++--- 2 files changed, 62 insertions(+), 144 deletions(-) delete mode 100644 src/diffusers/pipelines/animatediff/freeinit_utils.py diff --git a/src/diffusers/pipelines/animatediff/freeinit_utils.py b/src/diffusers/pipelines/animatediff/freeinit_utils.py deleted file mode 100644 index 1e1071cc25ae..000000000000 --- a/src/diffusers/pipelines/animatediff/freeinit_utils.py +++ /dev/null @@ -1,124 +0,0 @@ -import math - -import torch -import torch.fft as fft - - -def freq_mix_3d(x, noise, LPF): - """ - Noise reinitialization. 
- - Args: - x: diffused latent - noise: randomly sampled noise - LPF: low pass filter - """ - # FFT - x_freq = fft.fftn(x, dim=(-3, -2, -1)) - x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) - noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) - noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) - - # frequency mix - HPF = 1 - LPF - x_freq_low = x_freq * LPF - noise_freq_high = noise_freq * HPF - x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain - - # IFFT - x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) - x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real - - return x_mixed - - -def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25): - """ - Compute the gaussian low pass filter mask. - - Args: - shape: shape of the filter (volume) - d_s: normalized stop frequency for spatial dimensions (0.0-1.0) - d_t: normalized stop frequency for temporal dimension (0.0-1.0) - """ - T, H, W = shape[-3], shape[-2], shape[-1] - mask = torch.zeros(shape) - if d_s == 0 or d_t == 0: - return mask - for t in range(T): - for h in range(H): - for w in range(W): - d_square = ((d_s / d_t) * (2 * t / T - 1)) ** 2 + (2 * h / H - 1) ** 2 + (2 * w / W - 1) ** 2 - mask[..., t, h, w] = math.exp(-1 / (2 * d_s**2) * d_square) - return mask - - -def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25): - """ - Compute the butterworth low pass filter mask. - - Args: - shape: shape of the filter (volume) - n: order of the filter, larger n ~ ideal, smaller n ~ gaussian - d_s: normalized stop frequency for spatial dimensions (0.0-1.0) - d_t: normalized stop frequency for temporal dimension (0.0-1.0) - """ - T, H, W = shape[-3], shape[-2], shape[-1] - mask = torch.zeros(shape) - if d_s == 0 or d_t == 0: - return mask - for t in range(T): - for h in range(H): - for w in range(W): - d_square = ((d_s / d_t) * (2 * t / T - 1)) ** 2 + (2 * h / H - 1) ** 2 + (2 * w / W - 1) ** 2 - mask[..., t, h, w] = 1 / (1 + (d_square / d_s**2) ** n) - return mask - - -def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25): - """ - Compute the ideal low pass filter mask. - - Args: - shape: shape of the filter (volume) - d_s: normalized stop frequency for spatial dimensions (0.0-1.0) - d_t: normalized stop frequency for temporal dimension (0.0-1.0) - """ - T, H, W = shape[-3], shape[-2], shape[-1] - mask = torch.zeros(shape) - if d_s == 0 or d_t == 0: - return mask - for t in range(T): - for h in range(H): - for w in range(W): - d_square = ((d_s / d_t) * (2 * t / T - 1)) ** 2 + (2 * h / H - 1) ** 2 + (2 * w / W - 1) ** 2 - mask[..., t, h, w] = 1 if d_square <= d_s * 2 else 0 - return mask - - -def box_low_pass_filter(shape, d_s=0.25, d_t=0.25): - """ - Compute the ideal low pass filter mask (approximated version). 
- - Args: - shape: shape of the filter (volume) - d_s: normalized stop frequency for spatial dimensions (0.0-1.0) - d_t: normalized stop frequency for temporal dimension (0.0-1.0) - """ - T, H, W = shape[-3], shape[-2], shape[-1] - mask = torch.zeros(shape) - if d_s == 0 or d_t == 0: - return mask - - threshold_s = round(int(H // 2) * d_s) - threshold_t = round(T // 2 * d_t) - - cframe, crow, ccol = T // 2, H // 2, W // 2 - mask[ - ..., - cframe - threshold_t : cframe + threshold_t, - crow - threshold_s : crow + threshold_s, - ccol - threshold_s : ccol + threshold_s, - ] = 1.0 - - return mask diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index d2519b2ed76d..10218f863c43 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -13,11 +13,13 @@ # limitations under the License. import inspect +import math from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch +import torch.fft as fft from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput, VaeImageProcessor @@ -44,13 +46,6 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .freeinit_utils import ( - box_low_pass_filter, - butterworth_low_pass_filter, - freq_mix_3d, - gaussian_low_pass_filter, - ideal_low_pass_filter, -) logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -88,20 +83,67 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): def _get_freeinit_freq_filter(shape, device, filter_type, order, spatial_stop_frequency, temporal_stop_frequency): - if filter_type == "gaussian": - return gaussian_low_pass_filter(shape=shape, d_s=spatial_stop_frequency, d_t=temporal_stop_frequency).to( - device - ) + T, H, W = shape[-3], shape[-2], shape[-1] + mask = torch.zeros(shape) + + if spatial_stop_frequency == 0 or temporal_stop_frequency == 0: + return mask + + if filter_type == "butterworth": + + def retrieve_mask(x): + return 1 / (1 + (x / spatial_stop_frequency**2) ** order) + elif filter_type == "gaussian": + + def retrieve_mask(x): + return math.exp(-1 / (2 * spatial_stop_frequency**2) * x) elif filter_type == "ideal": - return ideal_low_pass_filter(shape=shape, d_s=spatial_stop_frequency, d_t=temporal_stop_frequency).to(device) - elif filter_type == "box": - return box_low_pass_filter(shape=shape, d_s=spatial_stop_frequency, d_t=temporal_stop_frequency).to(device) - elif filter_type == "butterworth": - return butterworth_low_pass_filter( - shape=shape, n=order, d_s=spatial_stop_frequency, d_t=temporal_stop_frequency - ).to(device) + + def retrieve_mask(x): + return 1 if x <= spatial_stop_frequency * 2 else 0 else: - raise NotImplementedError + raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal") + + for t in range(T): + for h in range(H): + for w in range(W): + d_square = ( + ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / T - 1)) ** 2 + + (2 * h / H - 1) ** 2 + + (2 * w / W - 1) ** 2 + ) + mask[..., t, h, w] = retrieve_mask(d_square) + + mask.to(device) + return mask + + +def _freq_mix_3d(x, noise, LPF): + """ + Noise reinitialization. 
+ + Args: + x: diffused latent + noise: randomly sampled noise + LPF: low pass filter + """ + # FFT + x_freq = fft.fftn(x, dim=(-3, -2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) + noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) + noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) + + # frequency mix + HPF = 1 - LPF + x_freq_low = x_freq * LPF + noise_freq_high = noise_freq * HPF + x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain + + # IFFT + x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) + x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real + + return x_mixed @dataclass @@ -743,7 +785,7 @@ def _free_init_loop( device=device, dtype=torch.float32, ) - latents = freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter) + latents = _freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter) latents = latents.to(prompt_embeds.dtype) # Coarse-to-Fine Sampling for faster inference (can lead to lower quality) From 59e916ffd21a7ac5c8d95a188241749954319dfb Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 4 Jan 2024 20:52:03 +0530 Subject: [PATCH 11/16] fix test --- tests/pipelines/animatediff/test_animatediff.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 9bbfc768d2bc..841da5994925 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -38,8 +38,8 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "generator", "latents", "return_dict", - "callback", - "callback_steps", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", ] ) From 6a5949ea7de88e0d25f80717c5323ea5eebee487 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Fri, 5 Jan 2024 12:22:52 +0530 Subject: [PATCH 12/16] fix tensor not on same device --- src/diffusers/pipelines/animatediff/pipeline_animatediff.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 10218f863c43..44afc6a3b828 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -114,8 +114,7 @@ def retrieve_mask(x): ) mask[..., t, h, w] = retrieve_mask(d_square) - mask.to(device) - return mask + return mask.to(device) def _freq_mix_3d(x, noise, LPF): From 3109a3ea4fceebc29997cf38d6ffcbf300c31ad4 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Fri, 5 Jan 2024 12:28:31 +0530 Subject: [PATCH 13/16] update --- .../animatediff/pipeline_animatediff.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 44afc6a3b828..19ec01e914e7 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -15,7 +15,7 @@ import inspect import math from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -82,7 +82,16 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): return outputs -def _get_freeinit_freq_filter(shape, device, filter_type, order, spatial_stop_frequency, temporal_stop_frequency): +def 
_get_freeinit_freq_filter(
+    shape: Tuple[int, ...],
+    device: Union[str, torch.device],
+    filter_type: str,
+    order: float,
+    spatial_stop_frequency: float,
+    temporal_stop_frequency: float,
+) -> torch.Tensor:
+    r"""Returns the FreeInit filter based on filter type and other input conditions."""
+
     T, H, W = shape[-3], shape[-2], shape[-1]
     mask = torch.zeros(shape)
 
@@ -117,15 +126,8 @@ def retrieve_mask(x):
     return mask.to(device)
 
 
-def _freq_mix_3d(x, noise, LPF):
-    """
-    Noise reinitialization.
-
-    Args:
-        x: diffused latent
-        noise: randomly sampled noise
-        LPF: low pass filter
-    """
+def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor:
+    r"""Noise reinitialization."""
     # FFT
     x_freq = fft.fftn(x, dim=(-3, -2, -1))
     x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
@@ -535,7 +537,7 @@ def enable_free_init(
                 Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables
                 the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
             method (`str`, *optional*, defaults to `butterworth`):
-                Must be one of `box`, `butterworth`, `ideal` or `gaussian` to use as the filtering method for the
+                Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the
                 FreeInit low pass filter.
             order (`int`, *optional*, defaults to `4`):
                 Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour

From af332bb304ed231380416e0b1482bcd8455a704d Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Fri, 5 Jan 2024 18:10:50 +0530
Subject: [PATCH 14/16] remove return_intermediate_results

---
 .../animatediff/pipeline_animatediff.py       | 19 ++-----------------
 1 file changed, 2 insertions(+), 17 deletions(-)

diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index 19ec01e914e7..0fb4637dab7f 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -524,7 +524,6 @@ def enable_free_init(
         spatial_stop_frequency: float = 0.25,
         temporal_stop_frequency: float = 0.25,
         generator: torch.Generator = None,
-        return_intermediate_results: bool = False,
     ):
         """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.

@@ -551,8 +550,6 @@
             generator (`torch.Generator`, *optional*, defaults to `None`):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 FreeInit generation deterministic.
-            return_intermediate_results (`bool`, *optional*, defaults to `False`):
-                Whether or not to return intermediate sampling results for every FreeInit iteration.
         """
         self._free_init_num_iters = num_iters
         self._free_init_use_fast_sampling = use_fast_sampling
         self._free_init_method = method
         self._free_init_order = order
         self._free_init_spatial_stop_frequency = spatial_stop_frequency
         self._free_init_temporal_stop_frequency = temporal_stop_frequency
         self._free_init_generator = generator
-        self._free_init_return_intermediate_results = return_intermediate_results
 
     def disable_free_init(self):
         """Disables the FreeInit mechanism if enabled."""
@@ -742,7 +738,6 @@ def _free_init_loop(
         timesteps = denoise_args.get("timesteps")
         num_inference_steps = denoise_args.get("num_inference_steps")
 
-        video_latents = []
         latent_shape = (
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -800,16 +795,9 @@ def _free_init_loop(
                 denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps})
                 latents = self._denoise_loop(**denoise_args)
 
-                # Whether or not to return intermediate generation results
-                if self._free_init_return_intermediate_results:
-                    video_latents.append(latents)
-
                 free_init_progress_bar.update()
 
-        if not self._free_init_return_intermediate_results:
-            video_latents = latents
-
-        return video_latents
+        return latents
 
     def _retrieve_video_frames(self, latents, output_type, return_dict):
         """Helper function to handle latents to output conversion."""
@@ -1084,10 +1072,7 @@ def __call__(
         else:
             latents = self._denoise_loop(**denoise_args)
 
-        if self.free_init_enabled and self._free_init_return_intermediate_results:
-            video = [self._retrieve_video_frames(latent, output_type, return_dict) for latent in latents]
-        else:
-            video = self._retrieve_video_frames(latents, output_type, return_dict)
+        video = self._retrieve_video_frames(latents, output_type, return_dict)
 
         # 9. Offload all models
         self.maybe_free_model_hooks()

From eb95450d1c4cc218d54101924c860f74c9f1e72a Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Mon, 8 Jan 2024 14:18:06 +0530
Subject: [PATCH 15/16] fix broken freeinit test

---
 tests/pipelines/animatediff/test_animatediff.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py
index 841da5994925..44cb730a9501 100644
--- a/tests/pipelines/animatediff/test_animatediff.py
+++ b/tests/pipelines/animatediff/test_animatediff.py
@@ -251,7 +251,6 @@ def test_free_init(self):
             spatial_stop_frequency=0.25,
             temporal_stop_frequency=0.25,
             generator=free_init_generator,
-            return_intermediate_results=False,
         )
         inputs_enable_free_init = self.get_dummy_inputs(torch_device)
         frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]

From d9fdf1eea9aaa72bf3873a09492890393f3d176d Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Fri, 12 Jan 2024 19:18:59 +0530
Subject: [PATCH 16/16] update animatediff docs

---
 docs/source/en/api/pipelines/animatediff.md | 58 +++++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md
index fb38687e882e..4e1670df7717 100644
--- a/docs/source/en/api/pipelines/animatediff.md
+++ b/docs/source/en/api/pipelines/animatediff.md
@@ -235,6 +235,62 @@ export_to_gif(frames, "animation.gif")
 
+## Using FreeInit
+
+[FreeInit: Bridging Initialization Gap in Video Diffusion Models](https://arxiv.org/abs/2312.07537) by Tianxing Wu, Chenyang Si, Yuming Jiang, Ziqi Huang, Ziwei Liu.
+
+FreeInit is an effective method that improves temporal consistency and overall quality of videos generated using video diffusion models without any additional training. It can be applied to AnimateDiff, ModelScope, VideoCrafter and various other video generation models seamlessly at inference time, and works by iteratively refining the latent initialization noise. More details can be found in the paper.
+
+The following example demonstrates the usage of FreeInit.
+
+```python
+import torch
+from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
+from diffusers.utils import export_to_gif
+
+adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16).to("cuda")
+pipe.scheduler = DDIMScheduler.from_pretrained(
+    model_id,
+    subfolder="scheduler",
+    beta_schedule="linear",
+    clip_sample=False,
+    timestep_spacing="linspace",
+    steps_offset=1
+)
+
+# enable memory savings
+pipe.enable_vae_slicing()
+pipe.enable_vae_tiling()
+
+# enable FreeInit
+# Refer to the enable_free_init documentation for a full list of configurable parameters
+pipe.enable_free_init(method="butterworth", use_fast_sampling=True)
+
+# run inference
+output = pipe(
+    prompt="a panda playing a guitar, on a boat, in the ocean, high quality",
+    negative_prompt="bad quality, worse quality",
+    num_frames=16,
+    guidance_scale=7.5,
+    num_inference_steps=20,
+    generator=torch.Generator("cpu").manual_seed(666),
+)
+
+# disable FreeInit
+pipe.disable_free_init()
+
+frames = output.frames[0]
+export_to_gif(frames, "animation.gif")
+```
+
+<Tip warning={true}>
+
+FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter that is set when enabling it. Setting the `use_fast_sampling` parameter to `True` can improve the overall performance (at the cost of lower quality compared to when `use_fast_sampling=False`, but still with better results than vanilla video generation models).
+
+</Tip>
+
 <Tip>
 
 Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
 
@@ -248,6 +304,8 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
   - __call__
   - enable_freeu
   - disable_freeu
+  - enable_free_init
+  - disable_free_init
   - enable_vae_slicing
   - disable_vae_slicing
   - enable_vae_tiling
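
The frequency-domain machinery this series folds into `pipeline_animatediff.py` can be exercised on its own. The sketch below is a minimal, self-contained illustration and not part of any patch: it rebuilds the `gaussian` branch of the private `_get_freeinit_freq_filter` helper and applies the `_freq_mix_3d` mixing step to random tensors. The latent shape and the `0.25` spatial/temporal cutoffs are arbitrary example values.

```python
import math

import torch
import torch.fft as fft

# Example latent volume: (batch, channels, frames, height, width).
# The shape is illustrative; AnimateDiff latents follow the same layout.
shape = (1, 4, 16, 32, 32)
T, H, W = shape[-3], shape[-2], shape[-1]

# Gaussian low-pass mask over the last three dims, mirroring the
# `filter_type="gaussian"` branch with d_s = d_t = 0.25.
d_s, d_t = 0.25, 0.25
mask = torch.zeros(shape)
for t in range(T):
    for h in range(H):
        for w in range(W):
            d_square = (
                ((d_s / d_t) * (2 * t / T - 1)) ** 2
                + (2 * h / H - 1) ** 2
                + (2 * w / W - 1) ** 2
            )
            mask[..., t, h, w] = math.exp(-1 / (2 * d_s**2) * d_square)

x = torch.randn(shape)      # stands in for the diffused latents z_T
noise = torch.randn(shape)  # stands in for freshly sampled noise

# FFT -> shift -> keep low frequencies of `x`, take high frequencies from
# `noise` -> inverse shift -> inverse FFT, discarding the imaginary part:
# the same recipe implemented by `_freq_mix_3d`.
x_freq = fft.fftshift(fft.fftn(x, dim=(-3, -2, -1)), dim=(-3, -2, -1))
noise_freq = fft.fftshift(fft.fftn(noise, dim=(-3, -2, -1)), dim=(-3, -2, -1))
mixed_freq = x_freq * mask + noise_freq * (1 - mask)
mixed = fft.ifftn(fft.ifftshift(mixed_freq, dim=(-3, -2, -1)), dim=(-3, -2, -1)).real

print(mixed.shape)  # torch.Size([1, 4, 16, 32, 32])
```

In the pipeline, this mixing runs once per FreeInit iteration: the low-frequency band of the diffused latents is kept while the high-frequency band is replaced with freshly sampled noise before the latents are denoised again, which is what progressively improves temporal consistency.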