Skip to content

Commit 0e975e5

Browse files
[Safetensors] Make sure metadata is saved (#2506)
* [Safetensors] Make sure metadata is saved * make style
1 parent 7f43f65 commit 0e975e5

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,6 @@ def save_pretrained(
291291
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
292292
return
293293

294-
if save_function is None:
295-
save_function = safetensors.torch.save_file if safe_serialization else torch.save
296-
297294
os.makedirs(save_directory, exist_ok=True)
298295

299296
model_to_save = self
@@ -310,7 +307,12 @@ def save_pretrained(
310307
weights_name = _add_variant(weights_name, variant)
311308

312309
# Save the model
313-
save_function(state_dict, os.path.join(save_directory, weights_name))
310+
if safe_serialization:
311+
safetensors.torch.save_file(
312+
state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
313+
)
314+
else:
315+
torch.save(state_dict, os.path.join(save_directory, weights_name))
314316

315317
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
316318

0 commit comments

Comments
 (0)