From 9c5d08f8b913b58eed7d4e501f781506a779ea96 Mon Sep 17 00:00:00 2001
From: hlky
Date: Mon, 27 Jan 2025 08:14:18 +0000
Subject: [PATCH] Add provider_options to OnnxRuntimeModel

---
 src/diffusers/pipelines/onnx_utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/onnx_utils.py b/src/diffusers/pipelines/onnx_utils.py
index f4dbd4092e32..0e12340f6895 100644
--- a/src/diffusers/pipelines/onnx_utils.py
+++ b/src/diffusers/pipelines/onnx_utils.py
@@ -61,7 +61,7 @@ def __call__(self, **kwargs):
         return self.model.run(None, inputs)
 
     @staticmethod
-    def load_model(path: Union[str, Path], provider=None, sess_options=None):
+    def load_model(path: Union[str, Path], provider=None, sess_options=None, provider_options=None):
         """
         Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
 
@@ -75,7 +75,9 @@ def load_model(path: Union[str, Path], provider=None, sess_options=None):
             logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
             provider = "CPUExecutionProvider"
 
-        return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)
+        return ort.InferenceSession(
+            path, providers=[provider], sess_options=sess_options, provider_options=provider_options
+        )
 
     def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
         """