Commit b261322

controlnet
1 parent bd12cb3 commit b261322

File tree

1 file changed: +204 -1 lines changed

docs/source/en/using-diffusers/controlnet.md

Lines changed: 204 additions & 1 deletion
@@ -44,7 +44,7 @@ image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
```

- Pass the canny image to the pipeline.
+ Pass the canny image to the pipeline. Use the `controlnet_conditioning_scale` parameter to determine how much weight to assign to the control.

```py
import torch
@@ -91,17 +91,220 @@ pipeline(
</hfoption>
<hfoption id="image-to-image">

Generate a depth map with a depth estimation pipeline from Transformers.

```py
import torch
import numpy as np
from PIL import Image
from transformers import DPTImageProcessor, DPTForDepthEstimation
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
from diffusers.utils import load_image

depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")

def get_depth_map(image):
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)
    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image

# load the image to condition on and compute its depth map
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png"
).resize((1024, 1024))
depth_image = get_depth_map(image)
```
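
Alternatively, a more compact sketch using the Transformers `pipeline` API (this assumes the same `Intel/dpt-hybrid-midas` checkpoint and skips the manual interpolation above, so results may differ slightly):

```py
from transformers import pipeline

# the depth-estimation pipeline returns a ready-made PIL depth image in its "depth" key
depth_estimator = pipeline("depth-estimation", model="Intel/dpt-hybrid-midas", device="cuda")
depth_image = depth_estimator(image)["depth"].convert("RGB").resize((1024, 1024))
```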

Pass the depth map to the pipeline. Use the `controlnet_conditioning_scale` parameter to determine how much weight to assign to the control.

```py
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0-small",
    torch_dtype=torch.float16,
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipeline = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
).to("cuda")

prompt = """
A photorealistic overhead image of a cat reclining sideways in a flamingo pool floatie holding a margarita.
The cat is floating leisurely in the pool and completely relaxed and happy.
"""
controlnet_conditioning_scale = 0.5
pipeline(
    prompt,
    image=image,
    control_image=depth_image,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    strength=0.99,
    num_inference_steps=100,
).images[0]
```

<div style="display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;">
<figure>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png" width="300" alt="Original input image"/>
<figcaption style="text-align: center;">original image</figcaption>
</figure>
<figure>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_depth_image.png" width="300" alt="Depth map control image"/>
<figcaption style="text-align: center;">depth map</figcaption>
</figure>
<figure>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_depth_cat.png" width="300" alt="Generated image (ControlNet + prompt)"/>
<figcaption style="text-align: center;">generated image</figcaption>
</figure>
</div>

</hfoption>
<hfoption id="inpainting">

Load the initial image and the mask image, and prepare the control image, in this case a canny image generated from the initial image.

```py
import cv2
import torch
import numpy as np
from PIL import Image
from diffusers.utils import load_image
from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel

init_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png"
)
init_image = init_image.resize((1024, 1024))
mask_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat_mask.png"
)
mask_image = mask_image.resize((1024, 1024))

def make_canny_condition(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    image = Image.fromarray(image)
    return image

control_image = make_canny_condition(init_image)
```
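
If you don't already have a mask image, a minimal sketch of painting one with PIL (the ellipse coordinates below are purely illustrative; white marks the region to inpaint):

```py
from PIL import Image, ImageDraw

# illustrative only: a rough elliptical mask; white pixels are repainted, black pixels are kept
mask_image = Image.new("L", (1024, 1024), 0)
ImageDraw.Draw(mask_image).ellipse((250, 250, 800, 900), fill=255)
```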

Pass the mask and control image to the pipeline. Use the `controlnet_conditioning_scale` parameter to determine how much weight to assign to the control.

```py
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")
pipeline(
    "a cute and fluffy bunny rabbit",
    num_inference_steps=100,
    strength=0.99,
    controlnet_conditioning_scale=0.5,
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
).images[0]
```

<div style="display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;">
<figure>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png" width="300" alt="Original input image"/>
<figcaption style="text-align: center;">original image</figcaption>
</figure>
<figure>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat_mask.png" width="300" alt="Mask image marking the area to inpaint"/>
<figcaption style="text-align: center;">mask image</figcaption>
</figure>
<figure>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_rabbit_inpaint.png" width="300" alt="Generated image (ControlNet + prompt)"/>
<figcaption style="text-align: center;">generated image</figcaption>
</figure>
</div>

</hfoption>
</hfoptions>

## Multi-ControlNet

You can compose multiple ControlNet conditionings, such as a canny image and a depth map, to create a *MultiControlNet*. For the best results, mask the conditionings so they don't overlap and experiment with different `controlnet_conditioning_scale` values to adjust how much weight is assigned to each control input.
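
For example, a minimal sketch of keeping two conditionings from overlapping by zeroing one of them wherever a region mask is active (the helper and the region mask are illustrative, not part of the example below):

```py
import numpy as np
from PIL import Image

def mask_conditioning(conditioning_image, region_mask):
    # zero out the conditioning wherever the (white) region mask is active;
    # both images are assumed to have the same size
    conditioning = np.array(conditioning_image)
    conditioning[np.array(region_mask.convert("L")) > 127] = 0
    return Image.fromarray(conditioning)

# e.g. suppress canny edges inside the area already covered by the depth conditioning
# canny_image = mask_conditioning(canny_image, depth_region_mask)
```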

The example below composes a canny image and a depth map.

Pass the ControlNets as a list to the pipeline and resize the images to the expected input size.

```py
import torch
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL

controlnets = [
    ControlNetModel.from_pretrained(
        "diffusers/controlnet-depth-sdxl-1.0-small", torch_dtype=torch.float16
    ),
    ControlNetModel.from_pretrained(
        "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16,
    ),
]

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnets, vae=vae, torch_dtype=torch.float16
).to("cuda")

prompt = """
a relaxed rabbit sitting on a striped towel next to a pool with a tropical drink nearby,
bright sunny day, vacation scene, 35mm photograph, film, professional, 4k, highly detailed
"""
negative_prompt = "lowres, bad anatomy, worst quality, low quality, deformed, ugly"

# the control images must be in the same order as the ControlNets above (depth first, then canny)
images = [depth_image.resize((1024, 1024)), canny_image.resize((1024, 1024))]

pipeline(
    prompt,
    negative_prompt=negative_prompt,
    image=images,
    num_inference_steps=100,
    controlnet_conditioning_scale=[0.5, 0.5],
).images[0]
```

<div style="display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;">
<figure>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat.png" width="300" alt="Canny control image"/>
<figcaption style="text-align: center;">canny image</figcaption>
</figure>
<figure>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/multicontrolnet_depth.png" width="300" alt="Depth map control image"/>
<figcaption style="text-align: center;">depth map</figcaption>
</figure>
<figure>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_multi_controlnet.png" width="300" alt="Generated image (MultiControlNet + prompt)"/>
<figcaption style="text-align: center;">generated image</figcaption>
</figure>
</div>

## guess_mode

[Guess mode](https://github.com/lllyasviel/ControlNet/discussions/188) generates an image from **only** the control input (canny edge, depth map, pose, etc.), without any guidance from a prompt. It scales the ControlNet's output residuals by a fixed ratio that depends on block depth: the shallowest `DownBlock` is scaled by `0.1`, and the scale increases with depth until the `MidBlock` output is fully scaled by `1.0`.
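
A minimal sketch of enabling it, reusing the canny image from the earlier examples (the checkpoints match those used in the examples above; the low `guidance_scale` is a common suggestion for guess mode, not a requirement):

```py
import torch
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

pipeline(
    "",                 # no prompt; the control image alone drives generation
    image=canny_image,
    guess_mode=True,
    guidance_scale=3.0, # lower guidance is commonly recommended with guess mode
).images[0]
```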
