# T2I-Adapter

[T2I-Adapter](https://huggingface.co/papers/2302.08453) is an adapter that enables controllable generation like [ControlNet](./controlnet). A T2I-Adapter works by learning a *mapping* between a control signal (for example, a depth map) and a pretrained model's internal knowledge. The adapter is plugged into the base model to provide extra guidance based on the control signal during generation.

Load a T2I-Adapter conditioned on a specific control, such as canny edge, and pass it to the pipeline in [`~DiffusionPipeline.from_pretrained`].

```py
import torch
from diffusers import T2IAdapter, StableDiffusionXLAdapterPipeline, AutoencoderKL

t2i_adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-canny-sdxl-1.0",
    torch_dtype=torch.float16,
)
```
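
T2I-Adapter is deliberately lightweight: the condition passes through a few feature extraction and downsample blocks, and the adapter only runs once during the diffusion process, so it is smaller (~77M parameters) and faster than a comparable [ControlNet](./controlnet), at the cost of slightly weaker control. A minimal sketch to check the size of the adapter loaded above:

```py
# T2IAdapter is a torch.nn.Module, so its parameters can be counted
# directly. Expect a value in the tens of millions.
num_params = sum(p.numel() for p in t2i_adapter.parameters())
print(f"{num_params / 1e6:.1f}M parameters")
```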

Generate a canny image with [opencv-python](https://github.com/opencv/opencv-python).

```py
import cv2
import numpy as np
from PIL import Image
from diffusers.utils import load_image

original_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png"
)

image = np.array(original_image)

low_threshold = 100
high_threshold = 200

image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
```
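
If you would rather not tune the Canny thresholds by hand, the [controlnet_aux](https://github.com/huggingface/controlnet_aux) library wraps the same preprocessing. A sketch, assuming controlnet_aux is installed:

```py
from controlnet_aux.canny import CannyDetector

# CannyDetector handles resizing and edge detection and returns a
# three-channel PIL image at the requested output resolution.
canny_detector = CannyDetector()
canny_image = canny_detector(original_image, detect_resolution=384, image_resolution=1024)
```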

Pass the canny image to the pipeline to generate an image.

```py
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    adapter=t2i_adapter,
    vae=vae,
    torch_dtype=torch.float16,
).to("cuda")

prompt = """
A photorealistic overhead image of a cat reclining sideways in a flamingo pool floatie holding a margarita.
The cat is floating leisurely in the pool and completely relaxed and happy.
"""

pipeline(
    prompt,
    image=canny_image,
    num_inference_steps=100,
    guidance_scale=10,
).images[0]
```
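
Generation is stochastic by default. To reproduce a result, seed a `torch.Generator` on the same device as the pipeline and pass it through the `generator` parameter; a minimal sketch:

```py
# The same seed, prompt, and control image produce the same output.
generator = torch.Generator("cuda").manual_seed(0)

image = pipeline(
    prompt,
    image=canny_image,
    num_inference_steps=100,
    guidance_scale=10,
    generator=generator,
).images[0]
```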

<div style="display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;">
  <figure>
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png" width="300" alt="Original input image"/>
    <figcaption style="text-align: center;">original image</figcaption>
  </figure>
  <figure>
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat.png" width="300" alt="Canny edge control image"/>
    <figcaption style="text-align: center;">canny image</figcaption>
  </figure>
  <figure>
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/t2i-canny-cat-generated.png" width="300" alt="Image generated with T2I-Adapter canny conditioning"/>
    <figcaption style="text-align: center;">generated image</figcaption>
  </figure>
</div>

## MultiAdapter

You can compose multiple controls, such as a canny image and a depth map, with the [`MultiAdapter`] class.

The example below composes a canny image and a depth map.

Load the control images and T2I-Adapters as a list.

```py
import torch
from diffusers.utils import load_image
from diffusers import StableDiffusionXLAdapterPipeline, AutoencoderKL, MultiAdapter, T2IAdapter

canny_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat.png"
)
depth_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_depth_image.png"
)
controls = [canny_image, depth_image]
prompt = ["""
a relaxed rabbit sitting on a striped towel next to a pool with a tropical drink nearby,
bright sunny day, vacation scene, 35mm photograph, film, professional, 4k, highly detailed
"""]

adapters = MultiAdapter(
    [
        T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16),
        T2IAdapter.from_pretrained("TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16),
    ]
)
```
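
The depth map here is downloaded ready-made. To produce one from your own image, one option is the depth-estimation pipeline from the transformers library; a sketch, assuming the default depth-estimation checkpoint:

```py
import numpy as np
from PIL import Image
from transformers import pipeline as transformers_pipeline

# Estimate depth, then replicate the single channel to RGB since the
# adapter expects a three-channel image. `original_image` is the photo
# loaded in the canny example above.
depth_estimator = transformers_pipeline("depth-estimation")
depth = np.array(depth_estimator(original_image)["depth"])[:, :, None]
depth_image = Image.fromarray(np.concatenate([depth, depth, depth], axis=2))
```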

Pass the adapters, prompt, and control images to [`StableDiffusionXLAdapterPipeline`]. Use the `adapter_conditioning_scale` parameter to determine how much weight to assign to each control.

```py
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    vae=vae,
    adapter=adapters,
).to("cuda")

pipeline(
    prompt,
    image=controls,
    height=1024,
    width=1024,
    adapter_conditioning_scale=[0.7, 0.7]
).images[0]
```
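
`adapter_conditioning_scale` takes one value per adapter, so the controls can be weighted unevenly. The values below are illustrative, not tuned; for example, to let the depth map dominate:

```py
# Lower the canny weight and raise the depth weight so the overall
# shape follows the depth map more closely than the edges.
pipeline(
    prompt,
    image=controls,
    height=1024,
    width=1024,
    adapter_conditioning_scale=[0.5, 0.9]
).images[0]
```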

<div style="display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;">
  <figure>
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat.png" width="300" alt="Canny edge control image"/>
    <figcaption style="text-align: center;">canny image</figcaption>
  </figure>
  <figure>
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_depth_image.png" width="300" alt="Depth map control image"/>
    <figcaption style="text-align: center;">depth map</figcaption>
  </figure>
  <figure>
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/t2i-multi-rabbbit.png" width="300" alt="Image generated with multiple T2I-Adapter controls"/>
    <figcaption style="text-align: center;">generated image</figcaption>
  </figure>
</div>