Skip to content

Commit 4f0141a

Browse files
authored
Fix ONNX checkpoint loading (#2544)
* Revert "Disable ONNX tests (#2509)" This reverts commit a0549fe. * add external weights * + pb * style
1 parent 1021929 commit 4f0141a

File tree

4 files changed

+24
-3
lines changed

4 files changed

+24
-3
lines changed

.github/workflows/pr_tests.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ jobs:
3131
runner: docker-cpu
3232
image: diffusers/diffusers-flax-cpu
3333
report: flax_cpu
34+
- name: Fast ONNXRuntime CPU tests on Ubuntu
35+
framework: onnxruntime
36+
runner: docker-cpu
37+
image: diffusers/diffusers-onnxruntime-cpu
38+
report: onnx_cpu
3439
- name: PyTorch Example CPU tests on Ubuntu
3540
framework: pytorch_examples
3641
runner: docker-cpu

.github/workflows/push_tests.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ jobs:
2929
runner: docker-tpu
3030
image: diffusers/diffusers-flax-tpu
3131
report: flax_tpu
32+
- name: Slow ONNXRuntime CUDA tests on Ubuntu
33+
framework: onnxruntime
34+
runner: docker-gpu
35+
image: diffusers/diffusers-onnxruntime-cuda
36+
report: onnx_cuda
3237

3338
name: ${{ matrix.config.name }}
3439

.github/workflows/push_tests_fast.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ jobs:
2929
runner: docker-cpu
3030
image: diffusers/diffusers-flax-cpu
3131
report: flax_cpu
32+
- name: Fast ONNXRuntime CPU tests on Ubuntu
33+
framework: onnxruntime
34+
runner: docker-cpu
35+
image: diffusers/diffusers-onnxruntime-cpu
36+
report: onnx_cpu
3237
- name: PyTorch Example CPU tests on Ubuntu
3338
framework: pytorch_examples
3439
runner: docker-cpu

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
6464
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
6565

66-
from ..utils import FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
66+
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
6767

6868

6969
INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -176,7 +176,13 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
176176

177177
def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
178178
filenames = set(sibling.rfilename for sibling in info.siblings)
179-
weight_names = [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME]
179+
weight_names = [
180+
WEIGHTS_NAME,
181+
SAFETENSORS_WEIGHTS_NAME,
182+
FLAX_WEIGHTS_NAME,
183+
ONNX_WEIGHTS_NAME,
184+
ONNX_EXTERNAL_WEIGHTS_NAME,
185+
]
180186

181187
if is_transformers_available():
182188
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
@@ -604,7 +610,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
604610
]
605611

606612
if from_flax:
607-
ignore_patterns = ["*.bin", "*.safetensors", ".onnx"]
613+
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
608614
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
609615
ignore_patterns = ["*.bin", "*.msgpack"]
610616

0 commit comments

Comments
 (0)