diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 814547d82be4..19ac868cdae0 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -714,10 +714,7 @@ def save_pretrained(
             if safe_serialization:
                 # At some point we will need to deal better with save_function (used for TPU and other distributed
                 # joyfulness), but for now this enough.
-                try:
-                    safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
-                except RuntimeError:
-                    safetensors.torch.save_model(model_to_save, filepath, metadata={"format": "pt"})
+                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
             else:
                 torch.save(shard, filepath)
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index cc5008e37292..d3e39e363f91 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -2293,7 +2293,7 @@ def test_torch_dtype_dict(self):
         specified_key = next(iter(components.keys()))
 
         with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
-            pipe.save_pretrained(tmpdirname)
+            pipe.save_pretrained(tmpdirname, safe_serialization=False)
             torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
            loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)
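
Not part of the diff above, but a minimal standalone sketch of the behavior the removed fallback was covering: safetensors.torch.save_file raises a RuntimeError when a state dict contains shared (tied) tensors, while torch.save accepts them, which is the assumed reason the test now passes safe_serialization=False. The TiedModel class and file names below are hypothetical, used only for illustration.

# Minimal sketch (assumption: the model being saved holds tied weights).
import os
import tempfile

import safetensors.torch
import torch
import torch.nn as nn


class TiedModel(nn.Module):
    """Hypothetical module whose state dict contains shared tensors."""

    def __init__(self):
        super().__init__()
        self.linear_a = nn.Linear(4, 4)
        self.linear_b = nn.Linear(4, 4)
        # Tie the two weights so they share the same storage.
        self.linear_b.weight = self.linear_a.weight


model = TiedModel()
with tempfile.TemporaryDirectory() as tmpdir:
    try:
        # save_file rejects state dicts with shared tensors ...
        safetensors.torch.save_file(
            model.state_dict(), os.path.join(tmpdir, "model.safetensors"), metadata={"format": "pt"}
        )
    except RuntimeError as err:
        # ... while torch.save (the safe_serialization=False path) accepts them.
        torch.save(model.state_dict(), os.path.join(tmpdir, "model.bin"))
        print(f"safetensors rejected shared tensors: {err}")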