diff --git a/examples/community/README.md b/examples/community/README.md
index 7d8d190f037f..4262615ef97c 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -2989,7 +2989,7 @@ pipe = DiffusionPipeline.from_pretrained(
     custom_pipeline="pipeline_animatediff_controlnet",
 ).to(device="cuda", dtype=torch.float16)
 pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
-    model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
+    model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
 )
 pipe.enable_vae_slicing()
 
@@ -3005,7 +3005,7 @@ result = pipe(
     width=512,
     height=768,
     conditioning_frames=conditioning_frames,
-    num_inference_steps=12,
+    num_inference_steps=20,
 ).frames[0]
 
 from diffusers.utils import export_to_gif
@@ -3029,6 +3029,82 @@ export_to_gif(result.frames[0], "result.gif")
 
 
+You can also use multiple controlnets at once!
+
+```python
+import imageio
+import requests
+import torch
+from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter
+from diffusers.pipelines import DiffusionPipeline
+from diffusers.schedulers import DPMSolverMultistepScheduler
+from io import BytesIO
+from PIL import Image
+
+motion_id = "guoyww/animatediff-motion-adapter-v1-5-2"
+adapter = MotionAdapter.from_pretrained(motion_id)
+controlnet1 = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
+controlnet2 = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+
+model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
+pipe = DiffusionPipeline.from_pretrained(
+    model_id,
+    motion_adapter=adapter,
+    controlnet=[controlnet1, controlnet2],
+    vae=vae,
+    custom_pipeline="pipeline_animatediff_controlnet",
+).to(device="cuda", dtype=torch.float16)
+pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
+    model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
+)
+pipe.enable_vae_slicing()
+
+def load_video(file_path: str):
+    images = []
+
+    if file_path.startswith(("http://", "https://")):
+        # If the file_path is a URL, download it into memory first
+        response = requests.get(file_path)
+        response.raise_for_status()
+        content = BytesIO(response.content)
+        vid = imageio.get_reader(content)
+    else:
+        # Otherwise, assume it is a local file path
+        vid = imageio.get_reader(file_path)
+
+    for frame in vid:
+        pil_image = Image.fromarray(frame)
+        images.append(pil_image)
+
+    return images
+
+video = load_video("dance.gif")
+
+# Requires `pip install controlnet_aux`
+from controlnet_aux.processor import Processor
+
+p1 = Processor("openpose_full")
+cn1 = [p1(frame) for frame in video]
+
+p2 = Processor("canny")
+cn2 = [p2(frame) for frame in video]
+
+prompt = "astronaut in space, dancing"
+negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
+result = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=512,
+    height=768,
+    conditioning_frames=[cn1, cn2],
+    num_inference_steps=20,
+)
+
+from diffusers.utils import export_to_gif
+export_to_gif(result.frames[0], "result.gif")
+```
+
 ### DemoFusion
 
 This pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973).
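For reference, the multi-controlnet example above passes `conditioning_frames` as a nested list: the outer list holds one frame batch per controlnet, and each inner list holds one processed image per generated frame. The sketch below illustrates that shape contract; the `validate_conditioning_frames` helper is hypothetical, not part of the pipeline, and only mirrors the checks added to `check_inputs` in the diff that follows.

```python
from typing import List

from PIL import Image


def validate_conditioning_frames(frames: List[List[Image.Image]], num_frames: int) -> None:
    """Pre-validate multi-controlnet conditioning frames before calling the pipeline.

    Mirrors the shape checks in the updated check_inputs: `frames` must be a
    list of lists, every sublist must be the same length, and that length must
    equal the number of frames being generated.
    """
    if not isinstance(frames, list) or not all(isinstance(batch, list) for batch in frames):
        raise TypeError("conditioning_frames must be a list of per-controlnet frame lists")
    if any(len(batch) != num_frames for batch in frames):
        raise ValueError(f"every frame batch must contain exactly {num_frames} images")


# e.g. for the README example above, with a 16-frame GIF:
# validate_conditioning_frames([cn1, cn2], num_frames=16)
```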
diff --git a/examples/community/pipeline_animatediff_controlnet.py b/examples/community/pipeline_animatediff_controlnet.py
index 785f1ee55ec2..cf0c66bb50d0 100644
--- a/examples/community/pipeline_animatediff_controlnet.py
+++ b/examples/community/pipeline_animatediff_controlnet.py
@@ -14,7 +14,7 @@
 import inspect
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
@@ -66,7 +66,7 @@
 ...     custom_pipeline="pipeline_animatediff_controlnet",
 ... ).to(device="cuda", dtype=torch.float16)
 >>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
-...     model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
+...     model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
 ... )
 >>> pipe.enable_vae_slicing()
 
@@ -83,7 +83,7 @@
 ...     height=768,
 ...     conditioning_frames=conditioning_frames,
 ...     num_inference_steps=12,
-... ).frames[0]
+... )
 
 >>> from diffusers.utils import export_to_gif
 >>> export_to_gif(result.frames[0], "result.gif")
@@ -151,7 +151,7 @@ def __init__(
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         motion_adapter: MotionAdapter,
-        controlnet: Union[ControlNetModel, MultiControlNetModel],
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: Union[
             DDIMScheduler,
             PNDMScheduler,
@@ -166,6 +166,9 @@ def __init__(
         super().__init__()
 
         unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetModel(controlnet)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -488,6 +491,7 @@ def check_inputs(
         prompt,
         height,
         width,
+        num_frames,
         callback_steps,
         negative_prompt=None,
         prompt_embeds=None,
@@ -557,31 +561,21 @@ def check_inputs(
             or is_compiled
             and isinstance(self.controlnet._orig_mod, ControlNetModel)
         ):
-            if isinstance(image, list):
-                for image_ in image:
-                    self.check_image(image_, prompt, prompt_embeds)
-            else:
-                self.check_image(image, prompt, prompt_embeds)
+            if not isinstance(image, list):
+                raise TypeError(f"For a single controlnet, `image` must be a `list` but got {type(image)}")
+            if len(image) != num_frames:
+                raise ValueError(f"Expected `image` to have length {num_frames} but got {len(image)=}")
         elif (
             isinstance(self.controlnet, MultiControlNetModel)
             or is_compiled
             and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
         ):
-            if not isinstance(image, list):
-                raise TypeError("For multiple controlnets: `image` must be type `list`")
-
-            # When `image` is a nested list:
-            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
-            elif any(isinstance(i, list) for i in image):
-                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
-            elif len(image) != len(self.controlnet.nets):
-                raise ValueError(
-                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
-                )
-
-            for control_ in image:
-                for image_ in control_:
-                    self.check_image(image_, prompt, prompt_embeds)
+            if not isinstance(image, list) or not isinstance(image[0], list):
+                raise TypeError(f"For multiple controlnets, `image` must be a list of lists but got {type(image)=}")
+            if len(image[0]) != num_frames:
+                raise ValueError(f"Expected each `image` sublist to have length {num_frames} but got {len(image[0])=}")
+            if any(len(img) != len(image[0]) for img in image):
+                raise ValueError("All conditioning frame batches for multicontrolnet must be the same size")
         else:
             assert False
 
@@ -913,6 +907,7 @@ def __call__(
             prompt=prompt,
             height=height,
             width=width,
+            num_frames=num_frames,
             callback_steps=callback_steps,
             negative_prompt=negative_prompt,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
@@ -1000,9 +995,7 @@ def __call__(
                     do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )
-
                 cond_prepared_frames.append(prepared_frame)
-
             conditioning_frames = cond_prepared_frames
         else:
             assert False
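Taken together, the `__init__` change above means callers no longer need to construct a `MultiControlNetModel` themselves: a plain list or tuple of `ControlNetModel`s is wrapped at registration time. Below is a minimal standalone sketch of that normalization pattern; it assumes `MultiControlNetModel` is importable from `diffusers.pipelines.controlnet`, and the `normalize_controlnet` helper name is hypothetical.

```python
from typing import List, Tuple, Union

from diffusers import ControlNetModel
from diffusers.pipelines.controlnet import MultiControlNetModel


def normalize_controlnet(
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel, ...], MultiControlNetModel],
) -> Union[ControlNetModel, MultiControlNetModel]:
    # Wrap bare sequences so downstream code only ever sees a single
    # ControlNetModel or a MultiControlNetModel, mirroring the __init__ change.
    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)
    return controlnet
```

This is why the README example can pass `controlnet=[controlnet1, controlnet2]` directly to `DiffusionPipeline.from_pretrained`.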