diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md
index fb38687e882e..4e1670df7717 100644
--- a/docs/source/en/api/pipelines/animatediff.md
+++ b/docs/source/en/api/pipelines/animatediff.md
@@ -235,6 +235,62 @@ export_to_gif(frames, "animation.gif")
+## Using FreeInit
+
+[FreeInit: Bridging Initialization Gap in Video Diffusion Models](https://arxiv.org/abs/2312.07537) by Tianxing Wu, Chenyang Si, Yuming Jiang, Ziqi Huang, Ziwei Liu.
+
+FreeInit is an effective method that improves the temporal consistency and overall quality of videos generated with video diffusion models, without any additional training. It can be applied seamlessly at inference time to AnimateDiff, ModelScope, VideoCrafter, and various other video generation models, and works by iteratively refining the latent initialization noise, as sketched after the example below. More details can be found in the paper.
+
+The following example demonstrates the usage of FreeInit.
+
+```python
+import torch
+from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
+from diffusers.utils import export_to_gif
+
+adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16).to("cuda")
+pipe.scheduler = DDIMScheduler.from_pretrained(
+ model_id,
+ subfolder="scheduler",
+ beta_schedule="linear",
+ clip_sample=False,
+ timestep_spacing="linspace",
+ steps_offset=1
+)
+
+# enable memory savings
+pipe.enable_vae_slicing()
+pipe.enable_vae_tiling()
+
+# enable FreeInit
+# Refer to the enable_free_init documentation for a full list of configurable parameters
+pipe.enable_free_init(method="butterworth", use_fast_sampling=True)
+
+# run inference
+output = pipe(
+ prompt="a panda playing a guitar, on a boat, in the ocean, high quality",
+ negative_prompt="bad quality, worse quality",
+ num_frames=16,
+ guidance_scale=7.5,
+ num_inference_steps=20,
+ generator=torch.Generator("cpu").manual_seed(666),
+)
+
+# disable FreeInit
+pipe.disable_free_init()
+
+frames = output.frames[0]
+export_to_gif(frames, "animation.gif")
+```
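+
+Under the hood, each FreeInit iteration diffuses the previously denoised latents back to pure noise and then replaces their high-frequency components with freshly sampled noise, keeping only the low-frequency "layout" information. The snippet below is a minimal, self-contained sketch of that frequency mixing step (the shapes and the random low-pass filter are illustrative; the pipeline builds the actual filter and runs this step internally, so you never need to call it yourself):
+
+```python
+import torch
+import torch.fft as fft
+
+# Illustrative latent video tensors of shape (batch, channels, frames, height, width)
+latents = torch.randn(1, 4, 16, 64, 64)
+new_noise = torch.randn(1, 4, 16, 64, 64)
+
+# Illustrative low-pass filter mask with values in [0, 1]; the pipeline constructs this
+# from the `method`, `order`, `spatial_stop_frequency` and `temporal_stop_frequency` arguments
+low_pass_filter = torch.rand(1, 4, 16, 64, 64)
+
+# Move both signals to the frequency domain over the spatio-temporal dimensions
+latents_freq = fft.fftshift(fft.fftn(latents, dim=(-3, -2, -1)), dim=(-3, -2, -1))
+noise_freq = fft.fftshift(fft.fftn(new_noise, dim=(-3, -2, -1)), dim=(-3, -2, -1))
+
+# Keep the low frequencies of the diffused latents and the high frequencies of the new noise
+mixed_freq = latents_freq * low_pass_filter + noise_freq * (1 - low_pass_filter)
+
+# Back to the spatial domain; this becomes the initial noise for the next FreeInit iteration
+refined_noise = fft.ifftn(fft.ifftshift(mixed_freq, dim=(-3, -2, -1)), dim=(-3, -2, -1)).real
+```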
+
+<Tip warning={true}>
+
+FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter set when enabling it. Setting `use_fast_sampling=True` improves overall performance at the cost of lower quality compared to `use_fast_sampling=False`, while still producing better results than the vanilla video generation model.
+
+</Tip>
+
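+For example, assuming the `pipe` from the snippet above, the quality/speed trade-off can be tuned roughly like this (the values are illustrative, not recommendations):
+
+```python
+# Faster, but likely lower quality: fewer FreeInit iterations plus coarse-to-fine sampling
+pipe.enable_free_init(num_iters=2, use_fast_sampling=True)
+
+# Slower, but higher quality: more iterations, each running the full number of inference steps
+pipe.enable_free_init(num_iters=5, use_fast_sampling=False)
+```
+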
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
@@ -248,6 +304,8 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- __call__
- enable_freeu
- disable_freeu
+ - enable_free_init
+ - disable_free_init
- enable_vae_slicing
- disable_vae_slicing
- enable_vae_tiling
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index b0fe790c2222..0fb4637dab7f 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -13,11 +13,13 @@
# limitations under the License.
import inspect
+import math
from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
+import torch.fft as fft
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor
@@ -36,6 +38,7 @@
from ...utils import (
USE_PEFT_BACKEND,
BaseOutput,
+ deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -79,6 +82,71 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
return outputs
+def _get_freeinit_freq_filter(
+ shape: Tuple[int, ...],
+ device: Union[str, torch.device],
+ filter_type: str,
+ order: float,
+ spatial_stop_frequency: float,
+ temporal_stop_frequency: float,
+) -> torch.Tensor:
+ r"""Returns the FreeInit filter based on filter type and other input conditions."""
+
+ T, H, W = shape[-3], shape[-2], shape[-1]
+ mask = torch.zeros(shape)
+
+ if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
+ return mask.to(device)
+
+ if filter_type == "butterworth":
+
+ def retrieve_mask(x):
+ return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
+ elif filter_type == "gaussian":
+
+ def retrieve_mask(x):
+ return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
+ elif filter_type == "ideal":
+
+ def retrieve_mask(x):
+ return 1 if x <= spatial_stop_frequency * 2 else 0
+ else:
+ raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")
+
+ for t in range(T):
+ for h in range(H):
+ for w in range(W):
+ d_square = (
+ ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / T - 1)) ** 2
+ + (2 * h / H - 1) ** 2
+ + (2 * w / W - 1) ** 2
+ )
+ mask[..., t, h, w] = retrieve_mask(d_square)
+
+ return mask.to(device)
+
+
+def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor:
+ r"""Noise reinitialization."""
+ # FFT
+ x_freq = fft.fftn(x, dim=(-3, -2, -1))
+ x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
+ noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
+ noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
+
+ # frequency mix
+ HPF = 1 - LPF
+ x_freq_low = x_freq * LPF
+ noise_freq_high = noise_freq * HPF
+ x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain
+
+ # IFFT
+ x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
+ x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
+
+ return x_mixed
+
+
@dataclass
class AnimateDiffPipelineOutput(BaseOutput):
frames: Union[torch.Tensor, np.ndarray]
@@ -115,6 +183,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["feature_extractor", "image_encoder"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -442,6 +511,58 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
+ @property
+ def free_init_enabled(self):
+ return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None
+
+ def enable_free_init(
+ self,
+ num_iters: int = 3,
+ use_fast_sampling: bool = False,
+ method: str = "butterworth",
+ order: int = 4,
+ spatial_stop_frequency: float = 0.25,
+ temporal_stop_frequency: float = 0.25,
+ generator: Optional[torch.Generator] = None,
+ ):
+ """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.
+
+ This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).
+
+ Args:
+ num_iters (`int`, *optional*, defaults to `3`):
+ Number of FreeInit noise re-initialization iterations.
+ use_fast_sampling (`bool`, *optional*, defaults to `False`):
+ Whether or not to speed up the sampling procedure at the cost of probably lower quality results. Enables
+ the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
+ method (`str`, *optional*, defaults to `butterworth`):
+ Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the
+ FreeInit low pass filter.
+ order (`int`, *optional*, defaults to `4`):
+ Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour
+ whereas lower values lead to `gaussian` method behaviour.
+ spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
+ Normalized stop frequency for spatial dimensions. Must be between 0 and 1. Referred to as `d_s` in
+ the original implementation.
+ temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
+ Normalized stop frequency for temporal dimensions. Must be between 0 and 1. Referred to as `d_t` in
+ the original implementation.
+ generator (`torch.Generator`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ FreeInit generation deterministic.
+ """
+ self._free_init_num_iters = num_iters
+ self._free_init_use_fast_sampling = use_fast_sampling
+ self._free_init_method = method
+ self._free_init_order = order
+ self._free_init_spatial_stop_frequency = spatial_stop_frequency
+ self._free_init_temporal_stop_frequency = temporal_stop_frequency
+ self._free_init_generator = generator
+
+ def disable_free_init(self):
+ """Disables the FreeInit mechanism if enabled."""
+ self._free_init_num_iters = None
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -539,6 +660,185 @@ def prepare_latents(
latents = latents * self.scheduler.init_noise_sigma
return latents
+ def _denoise_loop(
+ self,
+ timesteps,
+ num_inference_steps,
+ do_classifier_free_guidance,
+ guidance_scale,
+ num_warmup_steps,
+ prompt_embeds,
+ negative_prompt_embeds,
+ latents,
+ cross_attention_kwargs,
+ added_cond_kwargs,
+ extra_step_kwargs,
+ callback,
+ callback_steps,
+ callback_on_step_end,
+ callback_on_step_end_tensor_inputs,
+ ):
+ """Denoising loop for AnimateDiff."""
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ return latents
+
+ def _free_init_loop(
+ self,
+ height,
+ width,
+ num_frames,
+ num_channels_latents,
+ batch_size,
+ num_videos_per_prompt,
+ denoise_args,
+ device,
+ ):
+ """Denoising loop for AnimateDiff using FreeInit noise reinitialization technique."""
+
+ latents = denoise_args.get("latents")
+ prompt_embeds = denoise_args.get("prompt_embeds")
+ timesteps = denoise_args.get("timesteps")
+ num_inference_steps = denoise_args.get("num_inference_steps")
+
+ latent_shape = (
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ num_frames,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ free_init_filter_shape = (
+ 1,
+ num_channels_latents,
+ num_frames,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ free_init_freq_filter = _get_freeinit_freq_filter(
+ shape=free_init_filter_shape,
+ device=device,
+ filter_type=self._free_init_method,
+ order=self._free_init_order,
+ spatial_stop_frequency=self._free_init_spatial_stop_frequency,
+ temporal_stop_frequency=self._free_init_temporal_stop_frequency,
+ )
+
+ with self.progress_bar(total=self._free_init_num_iters) as free_init_progress_bar:
+ for i in range(self._free_init_num_iters):
+ # For the first FreeInit iteration, the original latent is used without modification.
+ # Subsequent iterations apply the noise reinitialization technique.
+ if i == 0:
+ initial_noise = latents.detach().clone()
+ else:
+ current_diffuse_timestep = (
+ self.scheduler.config.num_train_timesteps - 1
+ ) # diffuse to t=999 noise level
+ diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long()
+ z_T = self.scheduler.add_noise(
+ original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device)
+ ).to(dtype=torch.float32)
+ z_rand = randn_tensor(
+ shape=latent_shape,
+ generator=self._free_init_generator,
+ device=device,
+ dtype=torch.float32,
+ )
+ latents = _freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter)
+ latents = latents.to(prompt_embeds.dtype)
+
+ # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
+ if self._free_init_use_fast_sampling:
+ current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1))
+ self.scheduler.set_timesteps(current_num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+ denoise_args.update({"timesteps": timesteps, "num_inference_steps": current_num_inference_steps})
+
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps})
+ latents = self._denoise_loop(**denoise_args)
+
+ free_init_progress_bar.update()
+
+ return latents
+
+ def _retrieve_video_frames(self, latents, output_type, return_dict):
+ """Helper function to handle latents to output conversion."""
+ if output_type == "latent":
+ return AnimateDiffPipelineOutput(frames=latents)
+
+ video_tensor = self.decode_latents(latents)
+
+ if output_type == "pt":
+ video = video_tensor
+ else:
+ video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+
+ if not return_dict:
+ return (video,)
+
+ return AnimateDiffPipelineOutput(frames=video)
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -559,10 +859,11 @@ def __call__(
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
- callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -603,25 +904,30 @@ def __call__(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
of a plain tuple.
- callback (`Callable`, *optional*):
- A function that calls every `callback_steps` steps during inference. The function is called with the
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
- callback_steps (`int`, *optional*, defaults to 1):
- The frequency at which the `callback` function is called. If not specified, the callback is called at
- every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
Examples:
Returns:
@@ -629,6 +935,23 @@ def __call__(
If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -637,9 +960,20 @@ def __call__(
# 1. Check inputs. Raise error if not correct
self.check_inputs(
- prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
)
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -649,30 +983,26 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
- cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
- do_classifier_free_guidance,
+ self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
- clip_skip=clip_skip,
+ clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
- if do_classifier_free_guidance:
+ if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
@@ -680,12 +1010,13 @@ def __call__(
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
)
- if do_classifier_free_guidance:
+ if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
+ self._num_timesteps = len(timesteps)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -703,55 +1034,47 @@ def __call__(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 7 Add image embeds for IP-Adapter
+
+ # 7. Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
- # Denoising loop
+ # 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- with self.progress_bar(total=num_inference_steps) as progress_bar:
- for i, t in enumerate(timesteps):
- # expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
- # predict the noise residual
- noise_pred = self.unet(
- latent_model_input,
- t,
- encoder_hidden_states=prompt_embeds,
- cross_attention_kwargs=cross_attention_kwargs,
- added_cond_kwargs=added_cond_kwargs,
- ).sample
-
- # perform guidance
- if do_classifier_free_guidance:
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
- # compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
- # call the callback, if provided
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
- progress_bar.update()
- if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
-
- if output_type == "latent":
- return AnimateDiffPipelineOutput(frames=latents)
-
- # Post-processing
- video_tensor = self.decode_latents(latents)
-
- if output_type == "pt":
- video = video_tensor
+ denoise_args = {
+ "timesteps": timesteps,
+ "num_inference_steps": num_inference_steps,
+ "do_classifier_free_guidance": self.do_classifier_free_guidance,
+ "guidance_scale": guidance_scale,
+ "num_warmup_steps": num_warmup_steps,
+ "prompt_embeds": prompt_embeds,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ "latents": latents,
+ "cross_attention_kwargs": self.cross_attention_kwargs,
+ "added_cond_kwargs": added_cond_kwargs,
+ "extra_step_kwargs": extra_step_kwargs,
+ "callback": callback,
+ "callback_steps": callback_steps,
+ "callback_on_step_end": callback_on_step_end,
+ "callback_on_step_end_tensor_inputs": callback_on_step_end_tensor_inputs,
+ }
+
+ if self.free_init_enabled:
+ latents = self._free_init_loop(
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ num_channels_latents=num_channels_latents,
+ batch_size=batch_size,
+ num_videos_per_prompt=num_videos_per_prompt,
+ denoise_args=denoise_args,
+ device=device,
+ )
else:
- video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+ latents = self._denoise_loop(**denoise_args)
- # Offload all models
- self.maybe_free_model_hooks()
+ video = self._retrieve_video_frames(latents, output_type, return_dict)
- if not return_dict:
- return (video,)
+ # 9. Offload all models
+ self.maybe_free_model_hooks()
- return AnimateDiffPipelineOutput(frames=video)
+ return video
diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py
index 633ed9fc233e..44cb730a9501 100644
--- a/tests/pipelines/animatediff/test_animatediff.py
+++ b/tests/pipelines/animatediff/test_animatediff.py
@@ -38,8 +38,8 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"generator",
"latents",
"return_dict",
- "callback",
- "callback_steps",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
]
)
@@ -233,6 +233,43 @@ def test_prompt_embeds(self):
inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device)
pipe(**inputs)
+ def test_free_init(self):
+ components = self.get_dummy_components()
+ pipe: AnimateDiffPipeline = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs_normal = self.get_dummy_inputs(torch_device)
+ frames_normal = pipe(**inputs_normal).frames[0]
+
+ free_init_generator = torch.Generator(device=torch_device).manual_seed(0)
+ pipe.enable_free_init(
+ num_iters=2,
+ use_fast_sampling=True,
+ method="butterworth",
+ order=4,
+ spatial_stop_frequency=0.25,
+ temporal_stop_frequency=0.25,
+ generator=free_init_generator,
+ )
+ inputs_enable_free_init = self.get_dummy_inputs(torch_device)
+ frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
+
+ pipe.disable_free_init()
+ inputs_disable_free_init = self.get_dummy_inputs(torch_device)
+ frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0]
+
+ sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
+ max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
+ self.assertGreater(
+ sum_enabled, 1e2, "Enabling FreeInit should lead to results different from the default pipeline results"
+ )
+ self.assertLess(
+ max_diff_disabled,
+ 1e-4,
+ "Disabling of FreeInit should lead to results similar to the default pipeline results",
+ )
+
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",