Skip to content

Commit 295a96d

Browse files
committed
fix scheduler compatibility and class labels dtype
1 parent 7082405 commit 295a96d

File tree

2 files changed

+144
-16
lines changed

2 files changed

+144
-16
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
from typing import Any, Callable, List, Optional, Union
33

44
import numpy as np
5-
import torch
6-
75
import PIL
6+
import torch
87

98
from ...schedulers import DDPMScheduler
109
from ..onnx_utils import OnnxRuntimeModel
@@ -14,11 +13,15 @@
1413

1514
logger = getLogger(__name__)
1615

17-
# TODO: make this dynamic, from self.vae.config.latent_channels
18-
num_channels_latents = 4
1916

20-
# TODO: make this dynamic, from self.unet.config.in_channels
21-
unet_in_channels = 7
17+
NUM_LATENT_CHANNELS = 4
18+
NUM_UNET_INPUT_CHANNELS = 7
19+
20+
# TODO: should this be a lookup? it needs to match the conversion script
21+
# Fixed-width 64-bit integer: the `np.long` alias was removed in NumPy 1.24 and,
# where it exists again (NumPy 2.x), maps to the platform C `long` (32-bit on Windows).
# NOTE(review): assumes the exported ONNX UNet takes int64 `class_labels` - confirm against the conversion script.
class_labels_dtype = np.int64
22+
23+
# TODO: should this be a lookup or converted? can it vary on ONNX?
24+
text_embeddings_dtype = torch.float32
2225

2326
###
2427
# This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline:
@@ -94,7 +97,6 @@ def __call__(
9497
text_embeddings = self._encode_prompt(
9598
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
9699
)
97-
text_embeddings_dtype = torch.float32 # TODO: convert text_embeddings.dtype to torch dtype
98100

99101
# 4. Preprocess image
100102
image = preprocess(image)
@@ -117,7 +119,7 @@ def __call__(
117119
height, width = image.shape[2:]
118120
latents = self.prepare_latents(
119121
batch_size * num_images_per_prompt,
120-
num_channels_latents,
122+
NUM_LATENT_CHANNELS,
121123
height,
122124
width,
123125
text_embeddings_dtype,
@@ -128,12 +130,12 @@ def __call__(
128130

129131
# 7. Check that sizes of image and latents match
130132
num_channels_image = image.shape[1]
131-
if num_channels_latents + num_channels_image != unet_in_channels:
133+
if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS:
132134
raise ValueError(
133135
"Incorrect configuration settings! The config of `pipeline.unet` expects"
134-
f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +"
136+
f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +"
135137
f" `num_channels_image`: {num_channels_image} "
136-
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
138+
f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of"
137139
" `pipeline.unet` or your `image` input."
138140
)
139141

@@ -159,7 +161,7 @@ def __call__(
159161
sample=latent_model_input,
160162
timestep=timestep,
161163
encoder_hidden_states=text_embeddings,
162-
class_labels=noise_level,
164+
class_labels=noise_level.astype(class_labels_dtype),
163165
)[0]
164166

165167
# perform guidance
@@ -168,7 +170,9 @@ def __call__(
168170
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
169171

170172
# compute the previous noisy sample x_t -> x_t-1
171-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
173+
latents = self.scheduler.step(
174+
torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
175+
).prev_sample
172176

173177
# call the callback, if provided
174178
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_upscale.py

Lines changed: 127 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,14 @@
1919
import numpy as np
2020
import torch
2121

22-
from diffusers import OnnxStableDiffusionUpscalePipeline
22+
from diffusers import (
23+
DPMSolverMultistepScheduler,
24+
EulerAncestralDiscreteScheduler,
25+
EulerDiscreteScheduler,
26+
LMSDiscreteScheduler,
27+
OnnxStableDiffusionUpscalePipeline,
28+
PNDMScheduler,
29+
)
2330
from diffusers.utils import floats_tensor
2431
from diffusers.utils.testing_utils import (
2532
is_onnx_available,
@@ -68,6 +75,86 @@ def test_pipeline_default_ddpm(self):
6875
)
6976
assert np.abs(image_slice - expected_slice).max() < 1e-1
7077

78+
def test_pipeline_pndm(self):
79+
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
80+
pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config, skip_prk_steps=True)
81+
pipe.set_progress_bar_config(disable=None)
82+
83+
inputs = self.get_dummy_inputs()
84+
image = pipe(**inputs).images
85+
image_slice = image[0, -3:, -3:, -1]
86+
87+
assert image.shape == (1, 512, 512, 3)
88+
expected_slice = np.array(
89+
[0.6898892, 0.59240556, 0.52499527, 0.58866215, 0.52258235, 0.52572715, 0.62414473, 0.6174387, 0.6214964]
90+
)
91+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
92+
93+
def test_pipeline_dpm_multistep(self):
94+
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
95+
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
96+
pipe.set_progress_bar_config(disable=None)
97+
98+
inputs = self.get_dummy_inputs()
99+
image = pipe(**inputs).images
100+
image_slice = image[0, -3:, -3:, -1]
101+
102+
assert image.shape == (1, 512, 512, 3)
103+
expected_slice = np.array(
104+
[0.7659278, 0.76437664, 0.75579107, 0.7691116, 0.77666986, 0.7727672, 0.7758664, 0.7812226, 0.76942515]
105+
)
106+
107+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
108+
109+
def test_pipeline_lms(self):
110+
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
111+
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
112+
pipe.set_progress_bar_config(disable=None)
113+
114+
# warmup pass to apply optimizations
115+
_ = pipe(**self.get_dummy_inputs())
116+
117+
inputs = self.get_dummy_inputs()
118+
image = pipe(**inputs).images
119+
image_slice = image[0, -3:, -3:, -1]
120+
121+
assert image.shape == (1, 512, 512, 3)
122+
expected_slice = np.array(
123+
[0.6974782, 0.68902093, 0.70135885, 0.7583618, 0.7804545, 0.7854912, 0.78667426, 0.78743863, 0.78070223]
124+
)
125+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
126+
127+
def test_pipeline_euler(self):
128+
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
129+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
130+
pipe.set_progress_bar_config(disable=None)
131+
132+
inputs = self.get_dummy_inputs()
133+
image = pipe(**inputs).images
134+
image_slice = image[0, -3:, -3:, -1]
135+
136+
assert image.shape == (1, 512, 512, 3)
137+
expected_slice = np.array(
138+
[0.6974782, 0.68902093, 0.70135885, 0.7583618, 0.7804545, 0.7854912, 0.78667426, 0.78743863, 0.78070223]
139+
)
140+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
141+
142+
def test_pipeline_euler_ancestral(self):
143+
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
144+
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
145+
pipe.set_progress_bar_config(disable=None)
146+
147+
inputs = self.get_dummy_inputs()
148+
image = pipe(**inputs).images
149+
image_slice = image[0, -3:, -3:, -1]
150+
151+
assert image.shape == (1, 512, 512, 3)
152+
expected_slice = np.array(
153+
[0.77424496, 0.773601, 0.7645288, 0.7769598, 0.7772739, 0.7738688, 0.78187233, 0.77879584, 0.767043]
154+
)
155+
156+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
157+
71158

72159
@nightly
73160
@require_onnxruntime
@@ -98,8 +185,6 @@ def test_inference_default_ddpm(self):
98185
# using the scheduler bundled with the checkpoint (no scheduler override is passed)
99186
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(
100187
"ssube/stable-diffusion-x4-upscaler-onnx",
101-
safety_checker=None,
102-
feature_extractor=None,
103188
provider=self.gpu_provider,
104189
sess_options=self.gpu_options,
105190
)
@@ -124,3 +209,42 @@ def test_inference_default_ddpm(self):
124209
# TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
125210

126211
assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
212+
213+
def test_inference_k_lms(self):
214+
init_image = load_image(
215+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
216+
"/img2img/sketch-mountains-input.jpg"
217+
)
218+
init_image = init_image.resize((128, 128))
219+
lms_scheduler = LMSDiscreteScheduler.from_pretrained(
220+
"ssube/stable-diffusion-x4-upscaler-onnx", subfolder="scheduler"
221+
)
222+
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(
223+
"ssube/stable-diffusion-x4-upscaler-onnx",
224+
scheduler=lms_scheduler,
225+
provider=self.gpu_provider,
226+
sess_options=self.gpu_options,
227+
)
228+
pipe.set_progress_bar_config(disable=None)
229+
230+
prompt = "A fantasy landscape, trending on artstation"
231+
232+
generator = torch.manual_seed(0)
233+
output = pipe(
234+
prompt=prompt,
235+
image=init_image,
236+
guidance_scale=7.5,
237+
num_inference_steps=20,
238+
generator=generator,
239+
output_type="np",
240+
)
241+
images = output.images
242+
image_slice = images[0, 255:258, 383:386, -1]
243+
244+
assert images.shape == (1, 512, 512, 3)
245+
expected_slice = np.array(
246+
[0.50173753, 0.50223356, 0.502039, 0.50233036, 0.5023725, 0.5022601, 0.5018758, 0.50234085, 0.50241566]
247+
)
248+
# TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
249+
250+
assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2

0 commit comments

Comments
 (0)