@@ -6,7 +6,7 @@
 import torch
 
 from ...schedulers import DDPMScheduler
-from ..onnx_utils import OnnxRuntimeModel
+from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from ..pipeline_utils import ImagePipelineOutput
 from . import StableDiffusionUpscalePipeline
 
@@ -17,17 +17,10 @@
 NUM_LATENT_CHANNELS = 4
 NUM_UNET_INPUT_CHANNELS = 7
 
-# TODO: should this be a lookup? it needs to match the conversion script
-class_labels_dtype = np.int64
-
-# TODO: should this be a lookup or converted? can it vary on ONNX?
-text_embeddings_dtype = torch.float32
-
-###
-# This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline:
-# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
-# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
-###
+TORCH_DTYPES = {
+    "float16": torch.float16,
+    "float32": torch.float32,
+}
 
 
 def preprocess(image):
@@ -98,6 +91,8 @@ def __call__(
             prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
         )
 
+        latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)]
+
         # 4. Preprocess image
         image = preprocess(image)
         image = image.cpu()
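A note on the new lookup, as a minimal sketch rather than part of the diff: text_embeddings is the output of the ONNX text encoder and so is presumably a numpy array, which means str(text_embeddings.dtype) is a short name such as "float16" or "float32", and TORCH_DTYPES maps that name onto the torch dtype that torch.randn and prepare_latents expect further down. The array shape below is made up purely for illustration.

import numpy as np
import torch

TORCH_DTYPES = {
    "float16": torch.float16,
    "float32": torch.float32,
}

# hypothetical encoder output; only its dtype matters for the lookup
text_embeddings = np.zeros((2, 77, 1024), dtype=np.float16)
latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)]
assert latents_dtype is torch.float16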
@@ -108,7 +103,7 @@ def __call__(
 
         # 5. Add noise to image
         noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
-        noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
+        noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype)
         image = self.low_res_scheduler.add_noise(image, noise, noise_level)
 
         batch_multiplier = 2 if do_classifier_free_guidance else 1
@@ -122,7 +117,7 @@ def __call__(
             NUM_LATENT_CHANNELS,
             height,
             width,
-            text_embeddings_dtype,
+            latents_dtype,
             device,
             generator,
             latents,
@@ -142,6 +137,11 @@ def __call__(
         # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        timestep_dtype = next(
+            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+        )
+        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
         # 9. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
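The hunk above reads the timestep dtype from the ONNX graph instead of hard-coding float32. A rough sketch of the machinery it leans on, with assumptions called out: onnxruntime describes every graph input as a NodeArg whose .type is a string such as "tensor(float)" or "tensor(float16)"; ORT_TO_NP_TYPE (imported from ..onnx_utils earlier in the diff) is assumed to be a plain table from those strings to numpy dtypes, and timestep_dtype_for is a hypothetical helper name used only in this sketch.

import numpy as np

# assumed shape of the ORT_TO_NP_TYPE table; the real one in onnx_utils may cover more types
ORT_TO_NP_TYPE = {
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
    "tensor(int32)": np.int32,
    "tensor(int64)": np.int64,
}

def timestep_dtype_for(session):
    # session: an onnxruntime.InferenceSession; get_inputs() yields NodeArg objects
    # carrying .name and .type, e.g. ("timestep", "tensor(float)")
    ort_type = next((i.type for i in session.get_inputs() if i.name == "timestep"), "tensor(float)")
    return ORT_TO_NP_TYPE[ort_type]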
@@ -154,14 +154,14 @@ def __call__(
                 latent_model_input = np.concatenate([latent_model_input, image], axis=1)
 
                 # timestep to tensor
-                timestep = np.array([t], dtype=np.float32)
+                timestep = np.array([t], dtype=timestep_dtype)
 
                 # predict the noise residual
                 noise_pred = self.unet(
                     sample=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=text_embeddings,
-                    class_labels=noise_level.astype(class_labels_dtype),
+                    class_labels=noise_level.astype(np.int64),
                 )[0]
 
                 # perform guidance