
Commit 75cadf2

lookup latent and timestamp types
1 parent ea0fbdd commit 75cadf2

File tree

1 file changed (+16, -16 lines)


src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py

Lines changed: 16 additions & 16 deletions
@@ -6,7 +6,7 @@
 import torch

 from ...schedulers import DDPMScheduler
-from ..onnx_utils import OnnxRuntimeModel
+from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from ..pipeline_utils import ImagePipelineOutput
 from . import StableDiffusionUpscalePipeline

@@ -17,17 +17,10 @@
 NUM_LATENT_CHANNELS = 4
 NUM_UNET_INPUT_CHANNELS = 7

-# TODO: should this be a lookup? it needs to match the conversion script
-class_labels_dtype = np.int64
-
-# TODO: should this be a lookup or converted? can it vary on ONNX?
-text_embeddings_dtype = torch.float32
-
-###
-# This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline:
-# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
-# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
-###
+TORCH_DTYPES = {
+    "float16": torch.float16,
+    "float32": torch.float32,
+}


 def preprocess(image):
@@ -98,6 +91,8 @@ def __call__(
             prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
         )

+        latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)]
+
         # 4. Preprocess image
         image = preprocess(image)
         image = image.cpu()
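
The lookup above replaces the hardcoded text_embeddings_dtype: in the ONNX pipelines the prompt embeddings come back as a NumPy array, so str() of their dtype ("float16" or "float32") keys the TORCH_DTYPES table. A minimal standalone sketch of the idea, using a stand-in array in place of the pipeline's real text_embeddings:

import numpy as np
import torch

# str() of a NumPy dtype yields "float16" / "float32", which keys the torch dtype table.
TORCH_DTYPES = {
    "float16": torch.float16,
    "float32": torch.float32,
}

# Stand-in for the ONNX text encoder output; shape and values are illustrative.
text_embeddings = np.zeros((2, 77, 1024), dtype=np.float16)

latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)]
print(latents_dtype)  # torch.float16
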
@@ -108,7 +103,7 @@ def __call__(

         # 5. Add noise to image
         noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
-        noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
+        noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype)
         image = self.low_res_scheduler.add_noise(image, noise, noise_level)

         batch_multiplier = 2 if do_classifier_free_guidance else 1
@@ -122,7 +117,7 @@ def __call__(
             NUM_LATENT_CHANNELS,
             height,
             width,
-            text_embeddings_dtype,
+            latents_dtype,
             device,
             generator,
             latents,
@@ -142,6 +137,11 @@ def __call__(
         # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

+        timestep_dtype = next(
+            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+        )
+        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
         # 9. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
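
Here the timestep dtype is read from the exported UNet itself: ONNX Runtime lists model inputs via get_inputs(), each with a name and a type string such as "tensor(float)", and ORT_TO_NP_TYPE (imported above from onnx_utils) translates that string into a NumPy dtype. A rough standalone sketch of the same lookup, with an illustrative subset of the mapping and a hypothetical model path:

import numpy as np
import onnxruntime as ort

# Illustrative subset of the type-string table; the full mapping ships in
# diffusers' onnx_utils as ORT_TO_NP_TYPE.
ORT_TYPE_TO_NP = {
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
    "tensor(int32)": np.int32,
    "tensor(int64)": np.int64,
}

def lookup_timestep_dtype(session: ort.InferenceSession):
    # Each NodeArg returned by get_inputs() exposes .name and a .type string like "tensor(float)".
    ort_type = next(
        (inp.type for inp in session.get_inputs() if inp.name == "timestep"),
        "tensor(float)",  # fall back to float32 if no input is named "timestep"
    )
    return ORT_TYPE_TO_NP[ort_type]

# Hypothetical usage against an exported UNet:
# session = ort.InferenceSession("unet/model.onnx")
# timestep_dtype = lookup_timestep_dtype(session)
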
@@ -154,14 +154,14 @@ def __call__(
                 latent_model_input = np.concatenate([latent_model_input, image], axis=1)

                 # timestep to tensor
-                timestep = np.array([t], dtype=np.float32)
+                timestep = np.array([t], dtype=timestep_dtype)

                 # predict the noise residual
                 noise_pred = self.unet(
                     sample=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=text_embeddings,
-                    class_labels=noise_level.astype(class_labels_dtype),
+                    class_labels=noise_level.astype(np.int64),
                 )[0]

                 # perform guidance
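
Inside the loop, the looked-up NumPy dtype is applied when the scheduler timestep is wrapped for the ONNX UNet, while class_labels keeps the fixed int64 that the conversion script is expected to produce. A small illustration with stand-in values:

import numpy as np

timestep_dtype = np.float32   # assumed result of the lookup above
t = 981                       # stand-in scheduler timestep
noise_level = np.array([20])  # stand-in noise level array

timestep = np.array([t], dtype=timestep_dtype)  # matches the UNet's "timestep" input type
class_labels = noise_level.astype(np.int64)     # int64 to match the conversion script
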
