add OnnxStableDiffusionUpscalePipeline pipeline #2158

Merged: 16 commits into huggingface:main on Mar 6, 2023

Conversation

@ssube (Contributor) commented Jan 30, 2023

I think I have a working implementation of an OnnxStableDiffusionUpscalePipeline, which extends StableDiffusionUpscalePipeline to be compatible with OnnxRuntimeModel. I'm hoping to get some feedback on whether this is the right approach and, if so, what else I need to do before this can be merged, besides writing tests. There are a few spots in the code that I have questions about, marked with # TODOs and noted at the bottom here.
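For orientation, a minimal sketch of the subclassing approach (the signatures here are illustrative assumptions, not the exact merged code):

from diffusers import OnnxRuntimeModel, StableDiffusionUpscalePipeline

class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
    # Hypothetical skeleton: every model component is an OnnxRuntimeModel, and
    # the parent pipeline's logic is reused wherever the call sites allow it.
    def __init__(
        self,
        vae: OnnxRuntimeModel,
        text_encoder: OnnxRuntimeModel,
        tokenizer,
        unet: OnnxRuntimeModel,
        low_res_scheduler,
        scheduler,
        max_noise_level: int = 350,
    ):
        super().__init__(
            vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level
        )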

Motivation

Running the current StableDiffusionUpscalePipeline on a machine without CUDA acceleration can be pretty slow, even with relatively small 128x128 input images. I am writing a web UI for running ONNX pipelines that lets you run a series of upscaling models (or one model repeatedly), but running StableDiffusionUpscalePipeline on a 1024px square input (split into 128px tiles) can easily take 60+ minutes on a 16-core CPU. Using the ONNX runtime is much faster, but that combination was not available, so I wrote this pipeline.

  • Per 128x128 tile:
    • Using StableDiffusionUpscalePipeline: 2.98s/it or 02:28 per tile
    • Using OnnxStableDiffusionUpscalePipeline w/ ROCmExecutionProvider: 6.46it/s or 00:07 per tile
    • Using OnnxStableDiffusionUpscalePipeline w/ DMLExecutionProvider: 1.17it/s or 00:42 per tile
  • Upscaling 512x512 -> 2048x2048, 16 runs with 50 inference steps each:
    • Using StableDiffusionUpscalePipeline: finished pipeline in 0:41:00.270845
    • Using OnnxStableDiffusionUpscalePipeline w/ ROCmExecutionProvider: finished pipeline in 0:02:10.359478
  • Upscaling 1024x1024 -> 4096x4096, 64 runs with 50 inference steps each:
    • Using StableDiffusionUpscalePipeline: still running
    • Using OnnxStableDiffusionUpscalePipeline w/ ROCmExecutionProvider: finished pipeline in 0:05:53.323918

I have only tested this using the CPUExecutionProvider and ROCmExecutionProvider so far, but I have machines set up for testing the CUDAExecutionProvider and DMLExecutionProvider and will check on them as well.

I tried to make the fewest changes necessary and ended up only overriding a few methods. It looks like the preference in some of the other pipelines is to copy methods, which I can also do, but I wanted to find the minimum viable diff. Most of the changes are around passing named parameters to the models and replacing .sample with [0] (sketched below), but there are a few ndarray.int() calls that I'm not sure about, and the StableDiffusionUpscalePipeline code used some config values that do not appear to exist on OnnxRuntimeModel.
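As a concrete illustration, the typical change looks like this (a sketch with illustrative variable names, not the exact diff):

import numpy as np

# torch pipeline:
#   noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# ONNX version: named numpy inputs, and [0] instead of .sample on the output.
noise_pred = self.unet(
    sample=latent_model_input,
    timestep=np.array([t], dtype=np.float32),
    encoder_hidden_states=text_embeddings,
    class_labels=noise_level.astype(np.int64),
)[0]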

Example

prompt = "an astronaut eating a hamburger"
steps = 50

txt2img = StableDiffusionOnnxPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="onnx",
    provider="CUDAExecutionProvider",
)
small_image = txt2img(
    prompt,
    num_inference_steps=steps,
).images[0]

generator = torch.manual_seed(0)
upscale = OnnxStableDiffusionUpscalePipeline.from_pretrained(
    "ssube/stable-diffusion-x4-upscaler-onnx",
    provider="CUDAExecutionProvider",
)
large_image = upscale(
    prompt,
    small_image,
    generator=generator,
    num_inference_steps=steps,
).images[0]

TODOs

@HuggingFaceDocBuilderDev commented Jan 30, 2023

The documentation is not available anymore as the PR was closed or merged.

@ssube (Contributor, Author) commented Jan 31, 2023

I added a basic test, which is passing locally (13 passed, 10 skipped in 67.31s (0:01:07)), but relies on an ONNX revision of stabilityai/stable-diffusion-x4-upscaler that does not exist in the https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/tree/main repo.

@ForserX (Contributor) commented Jan 31, 2023

@ssube How did you convert this model into ONNX format? I'm hitting a bunch of errors (AMD GPU & Windows). I want to check it in DML mode.

@ssube (Contributor, Author) commented Jan 31, 2023

@ForserX I'm using this script: https://github.com/ssube/onnx-web/blob/main/api/onnx_web/convert.py#L206
It's very close to https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py, but the single_vae branches are new for upscaling. ssube/onnx-web@bacce0a#diff-2b8422f2625f7e1cd0ca3fa3e9975deed7d4962823108c2fc29f14c53e2c0cc6 is the bulk of the changes. I got it to work by switching between class_labels and return_dict on the UNet inputs and exporting a single VAE rather than splitting the encoder/decoder. No idea if that's right. 😄

@ForserX (Contributor) commented Jan 31, 2023

This is all pretty involved... I'll give it a try, and if it doesn't work out, I'll ask for a ready-made model. :)

@ssube (Contributor, Author) commented Jan 31, 2023

Using that convert.py script, I was able to convert the model on Windows 10 and run it using the DirectMLExecutionProvider on an AMD GPU. The output looks about right, nothing unusual showing up. I've added the iteration and 128px tile times to the description. It's not as fast as ROCm, from initial testing, but still much faster than CPU (roughly 5x).

Some logs from that:

Fetching 17 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [02:49<00:00,  9.95s/it]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
C:\Users\ssube\stabdiff\onnx-try-2\onnx-web\api\onnx_env\lib\site-packages\transformers\models\clip\modeling_clip.py:754: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  mask.fill_(torch.tensor(torch.finfo(dtype).min))
C:\Users\ssube\stabdiff\onnx-try-2\onnx-web\api\onnx_env\lib\site-packages\torch\onnx\symbolic_opset9.py:5408: UserWarning: Exporting aten::index operator of advanced indexing in opset 14 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.
  warnings.warn(
[2023-01-30 21:03:37,446] INFO: __main__: UNET config: FrozenDict([('sample_size', 128), ('in_channels', 7), ('out_channels', 4), ('center_input_sample', False), ('flip_sin_to_cos', True), ('freq_shift', 0), ('down_block_types', ['DownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D']), ('mid_block_type', 'UNetMidBlock2DCrossAttn'), ('up_block_types', ['CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'UpBlock2D']), ('only_cross_attention', [True, True, True, False]), ('block_out_channels', [256, 512, 512, 1024]), ('layers_per_block', 2), ('downsample_padding', 1), ('mid_block_scale_factor', 1), ('act_fn', 'silu'), ('norm_num_groups', 32), ('norm_eps', 1e-05), ('cross_attention_dim', 1024), ('attention_head_dim', 8), ('dual_cross_attention', False), ('use_linear_projection', True), ('class_embed_type', None), ('num_class_embeds', 1000), ('upcast_attention', False), ('resnet_time_scale_shift', 'default'), ('_class_name', 'UNet2DConditionModel'), ('_diffusers_version', '0.9.0.dev0'), ('_name_or_path', 'C:\\Users\\ssube/.cache\\huggingface\\diffusers\\models--stabilityai--stable-diffusion-x4-upscaler\\snapshots\\19b610c68ca7572defb6e09e64d1063f32b4db83\\unet')])
[2023-01-30 21:04:33,172] INFO: __main__: VAE config: FrozenDict([('in_channels', 3), ('out_channels', 3), ('down_block_types', ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D']), ('up_block_types', ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D']), ('block_out_channels', [128, 256, 512]), ('layers_per_block', 2), ('act_fn', 'silu'), ('latent_channels', 4), ('norm_num_groups', 32), ('sample_size', 256), ('_class_name', 'AutoencoderKL'), ('_diffusers_version', '0.9.0.dev0'), ('_name_or_path', 'C:\\Users\\ssube/.cache\\huggingface\\diffusers\\models--stabilityai--stable-diffusion-x4-upscaler\\snapshots\\19b610c68ca7572defb6e09e64d1063f32b4db83\\vae')])
[2023-01-30 21:04:43,174] INFO: __main__: exporting ONNX model
[2023-01-30 21:04:43,225] INFO: __main__: ONNX pipeline saved to ..\models\upscaling-stable-diffusion-x4
[2023-01-30 21:04:47,267] INFO: __main__: ONNX pipeline is loadable

and

[2023-01-30 21:29:02,983] INFO: onnx_web.chain.upscale_outpaint: final output image size: 1024x1024
[2023-01-30 21:29:02,984] INFO: onnx_web.chain.base: finished stage expand, result size: 1024x1024
[2023-01-30 21:29:02,984] INFO: onnx_web.chain.base: running stage upscale on image with dimensions 1024x1024, dict_keys(['output', 'size', 'prompt', 'scale', 'outscale', 'tile_size', 'upscale'])
[2023-01-30 21:29:02,984] INFO: onnx_web.chain.base: image larger than tile size of SizeChart.mini, tiling stage
[2023-01-30 21:29:02,992] INFO: onnx_web.chain.utils: processing tile 1 of 64, 0.0
[2023-01-30 21:29:02,993] INFO: onnx_web.chain.upscale_stable_diffusion: upscaling with Stable Diffusion, 50 steps
2023-01-30 21:29:03.0777243 [W:onnxruntime:, inference_session.cc:493 onnxruntime::InferenceSession::RegisterExecutionProvider] Having memory pattern enabled is not supported while using the DML Execution Provider. So disabling it for this session since it uses the DML Execution Provider.
2023-01-30 21:29:04.0214862 [W:onnxruntime:, session_state.cc:1030 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2023-01-30 21:29:04.0253514 [W:onnxruntime:, session_state.cc:1032 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
2023-01-30 21:29:05.8783892 [W:onnxruntime:, inference_session.cc:493 onnxruntime::InferenceSession::RegisterExecutionProvider] Having memory pattern enabled is not supported while using the DML Execution Provider. So disabling it for this session since it uses the DML Execution Provider.
2023-01-30 21:29:05.9192614 [W:onnxruntime:, session_state.cc:1030 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2023-01-30 21:29:05.9231882 [W:onnxruntime:, session_state.cc:1032 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
2023-01-30 21:29:06.5044530 [W:onnxruntime:, inference_session.cc:493 onnxruntime::InferenceSession::RegisterExecutionProvider] Having memory pattern enabled is not supported while using the DML Execution Provider. So disabling it for this session since it uses the DML Execution Provider.
2023-01-30 21:29:06.7290335 [W:onnxruntime:, session_state.cc:1030 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2023-01-30 21:29:06.7331632 [W:onnxruntime:, session_state.cc:1032 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
  8%|███████████▏                                                                                                                                | 4/50 [00:09<01:53,  2.47s/it]

@ForserX (Contributor) commented Jan 31, 2023

(roughly 5x)

The Vulkan variant of ESRGAN works even faster.

Please check your mail.

@ssube (Contributor, Author) commented Jan 31, 2023

I pushed a copy of the model that I have been using to https://huggingface.co/ssube/stable-diffusion-x4-upscaler-onnx and updated the tests accordingly 🤞

@patrickvonplaten (Contributor)

Cool, cc @anton-l @echarlaix for review

@ForserX (Contributor) commented Jan 31, 2023

Now we just have to wait for custom VAE and LoRA support for ONNX. :)

@ssube (Contributor, Author) commented Feb 1, 2023

I added another, longer test and fixed up a few of the TODOs. The remaining ones are all related to hard-coded channel counts and the text_embeddings dtype, and I'm not sure where to look those up; they don't seem to be present on the OnnxRuntimeModel.

I also tried adding attention_mask back to the text encoder, but I don't see it being used in the other ONNX pipelines, and attempting to add it causes a 2 : INVALID_ARGUMENT : Invalid Feed Input Name:attention_mask error.
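For reference, one way to sidestep that error would be to feed attention_mask only when the exported encoder actually declares it; a sketch, assuming the OnnxRuntimeModel exposes its ORT InferenceSession as .model:

import numpy as np

# Only pass inputs the exported text encoder actually declares.
input_names = {inp.name for inp in text_encoder.model.get_inputs()}
encoder_inputs = {"input_ids": text_input_ids.astype(np.int32)}
if "attention_mask" in input_names:
    encoder_inputs["attention_mask"] = attention_mask.astype(np.int32)
text_embeddings = text_encoder(**encoder_inputs)[0]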

@ssube changed the title from "[WIP] add OnnxStableDiffusionUpscalePipeline pipeline" to "add OnnxStableDiffusionUpscalePipeline pipeline" on Feb 3, 2023
@ssube mentioned this pull request on Feb 6, 2023
@patrickvonplaten (Contributor)

cc @anton-l

@ssube (Contributor, Author) commented Feb 9, 2023

Is there anything else I can/should add to this? I'm not sure where to look up the vae.config/unet.config equivalents, or how important that is.

@patrickvonplaten (Contributor)

@anton-l can you take a look here?

@ssube force-pushed the feature/onnx-upscale branch from 39bdc34 to 295a96d on February 15, 2023
@ssube (Contributor, Author) commented Feb 15, 2023

I've been using and testing this pipeline more, with more schedulers, and fixed a couple of issues related to the mix of numpy and torch types. There was an unsupported operand type(s) for *: 'numpy.ndarray' and 'Tensor' error with some (but not all) schedulers, which I fixed based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py#L437. I've added tests for all of the same schedulers that are tested in https://github.com/huggingface/diffusers/blob/main/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py except for a fast test for LMS discrete, which was timing out.

There were a few .config lookups that I wasn't sure about, but it looks like the other ONNX pipelines declare them as constants, so I did the same: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py#L34

The last issue I'm aware of is a slight difference in the parameter types passed to the scheduler.step() call: many of the other ONNX pipelines use something like torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs, but I think due to the restrictions in ORT, converting latents to a tensor causes a TypeError: expected np.ndarray (got Tensor) error and does not seem right here. torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs does appear to work.
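For reference, the working variant reads roughly like this (a sketch using the names above; noise_pred and latents are assumed to be numpy arrays):

import torch

# The noise prediction is converted to a tensor for the scheduler math, while
# latents stays a numpy array, which is what worked with the schedulers tested.
latents = self.scheduler.step(
    torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
).prev_sample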

I did run into one issue with int32 vs int64 types, but that appears to be related to how the model is trained or serialized, and exporting it again with the 4th input as a torch.long solved that:

     # UNET
     if single_vae:
         unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
-        unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int)
+        unet_scale = torch.tensor(4).to(
+            device=ctx.training_device, dtype=torch.long
+        )

@anton-l (Member) left a comment


Very impressive work @ssube, thank you so much for contributing!
Overall your implementation looks good to me, just left a couple of minor comments :)

For the int32 vs int64 issue: maybe it would be possible to infer the type at runtime, similar to the input-type lookup used in the other ONNX pipelines?
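That runtime lookup looks roughly like this (a sketch; the ORT type-string mapping is abbreviated here):

import numpy as np

ORT_TO_NP_TYPE = {"tensor(int32)": np.int32, "tensor(int64)": np.int64, "tensor(float)": np.float32}

# Ask the exported model which dtype it declared for the class_labels input,
# instead of hard-coding np.int64 in the pipeline.
class_labels_dtype = ORT_TO_NP_TYPE[
    next(
        (inp.type for inp in unet.model.get_inputs() if inp.name == "class_labels"),
        "tensor(int64)",
    )
]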

Comment on lines +17 to +18
NUM_LATENT_CHANNELS = 4
NUM_UNET_INPUT_CHANNELS = 7
Member:

Yes, this works 👍

NUM_UNET_INPUT_CHANNELS = 7

# TODO: should this be a lookup? it needs to match the conversion script
class_labels_dtype = np.int64
Member:

The integer types stay the same even in fp16 mode, so you can safely move it inline
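In other words (a one-line sketch, assuming noise_level is a numpy array):

import numpy as np
class_labels = noise_level.astype(np.int64)  # integer dtype is identical for fp16 and fp32 exports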


# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype)
Member:

text_embeddings_dtype can be inferred from text_embeddings (fp32 or fp16), so this shouldn't be a constant

Contributor (Author):

I thought so, let me fix that up

Comment on lines 26 to 30
###
# This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline:
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
# https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
###
Member:

Probably ok to remove this disclaimer now 😄

Comment on lines +91 to +94
if hasattr(vae, "config"):
# check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
is_vae_scaling_factor_set_to_0_08333 = (
hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
Member:

cc @patrickvonplaten @patil-suraj for this change

Contributor (Author):

I wasn't sure about this part, but if the VAE doesn't have .config, the current implementation will throw without logging much.

Contributor:

Ok for me!

@ssube (Contributor, Author) commented Feb 17, 2023

I inlined the integer type and put in lookups for the other two. One of them needed to go from numpy to the torch dtype, since that's what the StableDiffusionUpscalePipeline expects, so I put in a little lookup table for that; hopefully that is OK: 75cadf2#diff-3815a0888bb607ca69fe4022fa3b4a809687fe2b3ae4d0ea0397288fac3c920bR20-R23
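That lookup table is along these lines (a sketch of the idea, not the exact committed code):

import numpy as np
import torch

# Map the numpy dtype coming out of the ONNX text encoder to the torch dtype
# that the torch-based parent pipeline expects when building noise tensors.
NUMPY_TO_TORCH_DTYPE = {
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
}

text_embeddings_dtype = NUMPY_TO_TORCH_DTYPE[text_embeddings.dtype]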

For the int32/64 issue that I mentioned, I tested that a little bit more, and everything seems to work as long as the types in the convert/export code and the pipeline match. Is there any reason not to use int64 there? For more context, this is my convert script and the relevant part is:

    # UNET
    if single_vae: # upscale pipeline
        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
        unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.long) # <- this is the type that needs to match
    else:
        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
        unet_scale = torch.tensor(False).to(
            device=ctx.training_device, dtype=torch.bool
        )

    unet_in_channels = pipeline.unet.config.in_channels
    unet_sample_size = pipeline.unet.config.sample_size
    unet_path = output_path / "unet" / "model.onnx"
    onnx_export(
        pipeline.unet,
        model_args=(
            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
                device=ctx.training_device, dtype=dtype
            ),
            torch.randn(2).to(device=ctx.training_device, dtype=dtype),
            torch.randn(2, num_tokens, text_hidden_size).to(
                device=ctx.training_device, dtype=dtype
            ),
            unet_scale,
        ),
        output_path=unet_path,
        ordered_input_names=unet_inputs,
        # has to be different from "sample" for correct tracing
        output_names=["out_sample"],
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "timestep": {0: "batch"},
            "encoder_hidden_states": {0: "batch", 1: "sequence"},
        },
        opset=ctx.opset,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )

@ssube force-pushed the feature/onnx-upscale branch from 3d102b0 to 9b7810f on February 20, 2023
@patrickvonplaten (Contributor)

Looks good to me - thanks for checking the PR @anton-l :-)

cc @williamberman could you also take a quick look?

@patrickvonplaten (Contributor)

Merging to not block the community contributor here

@patrickvonplaten merged commit 9920c33 into huggingface:main on Mar 6, 2023
mengfei25 pushed a commit to mengfei25/diffusers that referenced this pull request Mar 27, 2023
* [Onnx] add Stable Diffusion Upscale pipeline

* add a test for the OnnxStableDiffusionUpscalePipeline

* check for VAE config before adjusting scaling factor

* update test assertions, lint fixes

* run fix-copies target

* switch test checkpoint to one hosted on huggingface

* partially restore attention mask

* reshape embeddings after running text encoder

* add longer nightly test for ONNX upscale pipeline

* use package import to fix tests

* fix scheduler compatibility and class labels dtype

* use more precise type

* remove LMS from fast tests

* lookup latent and timestamp types

* add docs for ONNX upscaling, rename lookup table

* replace deprecated pipeline names in ONNX docs
w4ffl35 pushed a commit to w4ffl35/diffusers that referenced this pull request Apr 14, 2023
@zetyquickly (Contributor)

Hello. On diffusers > 0.16.0 this pipeline throws an exception, because the vae.config attribute check was removed.

File "/opt/conda/envs/lora/lib/python3.9/site-packages/diffusers/pipelines/pipeline_utils.py", line 1101, in from_pretrained
    model = pipeline_class(**init_kwargs)
  File "/opt/conda/envs/lora/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py", line 59, in __init__
    super().__init__(
  File "/opt/conda/envs/lora/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py", line 134, in __init__
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
AttributeError: 'OnnxRuntimeModel' object has no attribute 'config'
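A hedged sketch of the kind of guard that would avoid this (hypothetical; the proper fix belongs in the pipeline's __init__):

# The failing line in StableDiffusionUpscalePipeline.__init__:
#   self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# A guarded version (sketch):
if hasattr(vae, "config") and hasattr(vae.config, "block_out_channels"):
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
else:
    vae_scale_factor = 4  # assumed fallback; the x4 upscaler VAE has 3 blocks, giving 2**2 = 4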

@patrickvonplaten (Contributor)

Thanks for the ping @zetyquickly! Would you like to open an issue to fix it?

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024