Commit 3d102b0

add docs for ONNX upscaling, rename lookup table
1 parent 75cadf2 commit 3d102b0

File tree

2 files changed: +33 −2 lines changed


docs/source/en/optimization/onnx.mdx

Lines changed: 31 additions & 0 deletions
@@ -37,6 +37,37 @@ prompt = "a photo of an astronaut riding a horse on mars"
 image = pipe(prompt).images[0]
 ```
 
+The snippet below demonstrates how to use the ONNX runtime with the Stable Diffusion upscaling pipeline.
+
+```python
+import torch
+
+from diffusers import OnnxStableDiffusionUpscalePipeline, StableDiffusionOnnxPipeline
+
+prompt = "a photo of an astronaut riding a horse on mars"
+steps = 50
+
+txt2img = StableDiffusionOnnxPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    revision="onnx",
+    provider="CUDAExecutionProvider",
+)
+small_image = txt2img(
+    prompt,
+    num_inference_steps=steps,
+).images[0]
+
+generator = torch.manual_seed(0)
+upscale = OnnxStableDiffusionUpscalePipeline.from_pretrained(
+    "ssube/stable-diffusion-x4-upscaler-onnx",
+    provider="CUDAExecutionProvider",
+)
+large_image = upscale(
+    prompt,
+    small_image,
+    generator=generator,
+    num_inference_steps=steps,
+).images[0]
+```
+
 ## Known Issues
 
 - Generating multiple prompts in a batch seems to take too much memory. While we look into it, you may need to iterate instead of batching.
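
The known issue above recommends iterating instead of batching. A minimal sketch of that workaround (not part of this commit) reuses the same txt2img pipeline as the docs snippet and runs one prompt per call, keeping peak memory close to the single-prompt case; the prompt list is illustrative:

```python
from diffusers import StableDiffusionOnnxPipeline

pipe = StableDiffusionOnnxPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="onnx",
    provider="CUDAExecutionProvider",
)

# Illustrative prompts; instead of passing the whole list to a single
# pipeline call, run one prompt per call to limit memory use.
prompts = [
    "a photo of an astronaut riding a horse on mars",
    "a photo of an astronaut riding a horse on the moon",
]

images = [pipe(p, num_inference_steps=50).images[0] for p in prompts]
```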

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 NUM_LATENT_CHANNELS = 4
 NUM_UNET_INPUT_CHANNELS = 7
 
-TORCH_DTYPES = {
+ORT_TO_PT_TYPE = {
     "float16": torch.float16,
     "float32": torch.float32,
 }
@@ -91,7 +91,7 @@ def __call__(
             prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
         )
 
-        latents_dtype = TORCH_DTYPES[str(text_embeddings.dtype)]
+        latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
 
         # 4. Preprocess image
         image = preprocess(image)
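
For context on the rename: the table maps the string form of an ONNX Runtime (numpy) dtype to the matching PyTorch dtype, so the latents can be created with the same precision as the text embeddings. A standalone sketch of the assumed lookup behavior, with a dummy numpy array standing in for the text-encoder output:

```python
import numpy as np
import torch

# Same mapping as the pipeline constant: ONNX Runtime returns numpy arrays,
# and str(array.dtype) is "float16" or "float32".
ORT_TO_PT_TYPE = {
    "float16": torch.float16,
    "float32": torch.float32,
}

# Dummy stand-in for the text embeddings returned by the ONNX text encoder.
text_embeddings = np.zeros((1, 77, 768), dtype=np.float32)

latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
print(latents_dtype)  # torch.float32
```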
