Commit b5cfab4
committed: t2i
1 parent b261322 commit b5cfab4

1 file changed: +84 -146 lines changed

docs/source/en/using-diffusers/t2i_adapter.md (84 additions, 146 deletions)
# T2I-Adapter

[T2I-Adapter](https://huggingface.co/papers/2302.08453) is an adapter that enables controllable generation like [ControlNet](./controlnet). A T2I-Adapter works by learning a *mapping* between a control signal (for example, a depth map) and a pretrained model's internal knowledge. The adapter is plugged into the base model to provide extra guidance based on the control signal during generation.

Load a T2I-Adapter conditioned on a specific control, such as canny edge, and pass it to the pipeline in [`~DiffusionPipeline.from_pretrained`].
```py
import torch
from diffusers import T2IAdapter, StableDiffusionXLAdapterPipeline, AutoencoderKL

t2i_adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-canny-sdxl-1.0",
    torch_dtype=torch.float16,
)
```
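
Other controls have their own checkpoints. As a minimal sketch, the same call loads a depth-conditioned adapter (the checkpoint name is taken from the MultiAdapter example below; the variable name is hypothetical):

```py
# load a depth-conditioned T2I-Adapter the same way (variable name is illustrative)
depth_adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-depth-midas-sdxl-1.0",
    torch_dtype=torch.float16,
)
```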

Generate a canny image with [opencv-python](https://github.com/opencv/opencv-python).

```py
import cv2
import numpy as np
from PIL import Image
from diffusers.utils import load_image

original_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png"
)

image = np.array(original_image)

low_threshold = 100
high_threshold = 200

image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
```
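
The canny map can also be produced without opencv-python. A minimal sketch, assuming the [controlnet_aux](https://github.com/huggingface/controlnet_aux) library is installed, using its `CannyDetector`:

```py
# alternative canny preprocessing with controlnet_aux (assumed installed)
from controlnet_aux.canny import CannyDetector
from diffusers.utils import load_image

canny_detector = CannyDetector()

original_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png"
)
# detect_resolution sets the resolution edges are detected at,
# image_resolution sets the resolution of the returned control image
canny_image = canny_detector(original_image, detect_resolution=384, image_resolution=1024)
```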

Pass the canny image to the pipeline to generate an image.

```py
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    adapter=t2i_adapter,
    vae=vae,
    torch_dtype=torch.float16,
).to("cuda")

prompt = """
A photorealistic overhead image of a cat reclining sideways in a flamingo pool floatie holding a margarita.
The cat is floating leisurely in the pool and completely relaxed and happy.
"""

pipeline(
    prompt,
    image=canny_image,
    num_inference_steps=100,
    guidance_scale=10,
).images[0]
```

<div style="display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;">
  <figure>
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png" width="300" alt="Original image"/>
    <figcaption style="text-align: center;">original image</figcaption>
  </figure>
  <figure>
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat.png" width="300" alt="Canny control image"/>
    <figcaption style="text-align: center;">canny image</figcaption>
  </figure>
  <figure>
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/t2i-canny-cat-generated.png" width="300" alt="Generated image (T2I-Adapter + prompt)"/>
    <figcaption style="text-align: center;">generated image</figcaption>
  </figure>
</div>
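
How closely the output follows the canny edges can typically be tuned at call time. A minimal sketch, reusing the `pipeline`, `prompt`, and `canny_image` from above and assuming the `adapter_conditioning_scale` and `adapter_conditioning_factor` arguments of [`StableDiffusionXLAdapterPipeline`]:

```py
# weaker guidance: scale down the adapter features and only apply them
# for the first 80% of the denoising steps
image = pipeline(
    prompt,
    image=canny_image,
    num_inference_steps=100,
    guidance_scale=10,
    adapter_conditioning_scale=0.8,   # 1.0 follows the control image most closely
    adapter_conditioning_factor=0.8,  # fraction of timesteps the adapter is applied for
).images[0]
image.save("canny_cat.png")
```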
## MultiAdapter

You can compose multiple controls, such as a canny image and a depth map, with the [`MultiAdapter`] class.

The example below composes a canny image and a depth map.

Load the control images and T2I-Adapters as a list.

```py
import torch
from diffusers.utils import load_image
from diffusers import StableDiffusionXLAdapterPipeline, AutoencoderKL, MultiAdapter, T2IAdapter

canny_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat.png"
)
depth_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_depth_image.png"
)
controls = [canny_image, depth_image]
prompt = ["""
a relaxed rabbit sitting on a striped towel next to a pool with a tropical drink nearby,
bright sunny day, vacation scene, 35mm photograph, film, professional, 4k, highly detailed
"""]

adapters = MultiAdapter(
    [
        T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16),
        T2IAdapter.from_pretrained("TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16),
    ]
)
```
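
The depth map above is a precomputed control image. To derive one from your own picture, a monocular depth estimator can be used instead. A minimal sketch, assuming the [controlnet_aux](https://github.com/huggingface/controlnet_aux) library and its `MidasDetector` (the same MiDaS family the depth adapter was trained on):

```py
# estimate a depth control image with controlnet_aux (assumed installed)
from controlnet_aux import MidasDetector
from diffusers.utils import load_image

midas = MidasDetector.from_pretrained("lllyasviel/Annotators")

source = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png"
)
# returns a PIL depth map resized for SDXL-scale generation
depth_image = midas(source, detect_resolution=512, image_resolution=1024)
```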

Pass the adapters, prompt, and control images to [`StableDiffusionXLAdapterPipeline`]. Use the `adapter_conditioning_scale` parameter to determine how much weight to assign to each control.

```py
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    vae=vae,
    adapter=adapters,
).to("cuda")

pipeline(
    prompt,
    image=controls,
    height=1024,
    width=1024,
    adapter_conditioning_scale=[0.7, 0.7],
).images[0]
```

<div style="display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;">
  <figure>
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat.png" width="300" alt="Canny control image"/>
    <figcaption style="text-align: center;">canny image</figcaption>
  </figure>
  <figure>
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_depth_image.png" width="300" alt="Depth map control image"/>
    <figcaption style="text-align: center;">depth map</figcaption>
  </figure>
  <figure>
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/t2i-multi-rabbbit.png" width="300" alt="Generated image (MultiAdapter + prompt)"/>
    <figcaption style="text-align: center;">generated image</figcaption>
  </figure>
</div>
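
Because `adapter_conditioning_scale` takes one value per adapter, the controls can be weighted unevenly. A minimal sketch, reusing the `pipeline`, `prompt`, and `controls` from above, that leans more on the depth map than on the canny edges:

```py
# the order matches the adapters passed to MultiAdapter: [canny, depth]
image = pipeline(
    prompt,
    image=controls,
    height=1024,
    width=1024,
    adapter_conditioning_scale=[0.5, 0.8],
).images[0]
image.save("multi_adapter_rabbit.png")
```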
