Skip to content

Commit d8c617c

Browse files
hlky and sayakpaul
authored
allow models to run with a user-provided dtype map instead of a single dtype (#10301)
* allow models to run with a user-provided dtype map instead of a single dtype
* make style
* Add warning, change `_` to `default`
* make style
* add test
* handle shared tensors
* remove warning

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent fe2b397 commit d8c617c

File tree

4 files changed

+51
-7
lines changed

4 files changed

+51
-7
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,10 @@ def save_pretrained(
714714
if safe_serialization:
715715
# At some point we will need to deal better with save_function (used for TPU and other distributed
716716
# joyfulness), but for now this enough.
717-
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
717+
try:
718+
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
719+
except RuntimeError:
720+
safetensors.torch.save_model(model_to_save, filepath, metadata={"format": "pt"})
718721
else:
719722
torch.save(shard, filepath)
720723

src/diffusers/pipelines/pipeline_loading_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,11 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
592592
loaded_sub_model = passed_class_obj[name]
593593

594594
else:
595+
sub_model_dtype = (
596+
torch_dtype.get(name, torch_dtype.get("default", torch.float32))
597+
if isinstance(torch_dtype, dict)
598+
else torch_dtype
599+
)
595600
loaded_sub_model = _load_empty_model(
596601
library_name=library_name,
597602
class_name=class_name,
@@ -600,7 +605,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
600605
is_pipeline_module=is_pipeline_module,
601606
pipeline_class=pipeline_class,
602607
name=name,
603-
torch_dtype=torch_dtype,
608+
torch_dtype=sub_model_dtype,
604609
cached_folder=kwargs.get("cached_folder", None),
605610
force_download=kwargs.get("force_download", None),
606611
proxies=kwargs.get("proxies", None),
@@ -616,7 +621,12 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
616621
# Obtain a sorted dictionary for mapping the model-level components
617622
# to their sizes.
618623
module_sizes = {
619-
module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
624+
module_name: compute_module_sizes(
625+
module,
626+
dtype=torch_dtype.get(module_name, torch_dtype.get("default", torch.float32))
627+
if isinstance(torch_dtype, dict)
628+
else torch_dtype,
629+
)[""]
620630
for module_name, module in init_empty_modules.items()
621631
if isinstance(module, torch.nn.Module)
622632
}

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -552,9 +552,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
552552
saved using
553553
[`~DiffusionPipeline.save_pretrained`].
554554
- A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file
555-
torch_dtype (`str` or `torch.dtype`, *optional*):
555+
torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
556556
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
557-
dtype is automatically derived from the model's weights.
557+
dtype is automatically derived from the model's weights. To load submodels with different dtype pass a
558+
`dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for
559+
unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default':
560+
torch.float16}`). If a component is not specified and no default is set, `torch.float32` is used.
558561
custom_pipeline (`str`, *optional*):
559562
560563
<Tip warning={true}>
@@ -703,7 +706,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
703706
use_onnx = kwargs.pop("use_onnx", None)
704707
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
705708

706-
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
709+
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
707710
torch_dtype = torch.float32
708711
logger.warning(
709712
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
@@ -950,14 +953,19 @@ def load_module(name, value):
950953
loaded_sub_model = passed_class_obj[name]
951954
else:
952955
# load sub model
956+
sub_model_dtype = (
957+
torch_dtype.get(name, torch_dtype.get("default", torch.float32))
958+
if isinstance(torch_dtype, dict)
959+
else torch_dtype
960+
)
953961
loaded_sub_model = load_sub_model(
954962
library_name=library_name,
955963
class_name=class_name,
956964
importable_classes=importable_classes,
957965
pipelines=pipelines,
958966
is_pipeline_module=is_pipeline_module,
959967
pipeline_class=pipeline_class,
960-
torch_dtype=torch_dtype,
968+
torch_dtype=sub_model_dtype,
961969
provider=provider,
962970
sess_options=sess_options,
963971
device_map=current_device_map,

tests/pipelines/test_pipelines_common.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2283,6 +2283,29 @@ def run_forward(pipe):
22832283
self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
22842284
self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))
22852285

2286+
def test_torch_dtype_dict(self):
2287+
components = self.get_dummy_components()
2288+
if not components:
2289+
self.skipTest("No dummy components defined.")
2290+
2291+
pipe = self.pipeline_class(**components)
2292+
2293+
specified_key = next(iter(components.keys()))
2294+
2295+
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
2296+
pipe.save_pretrained(tmpdirname)
2297+
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
2298+
loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)
2299+
2300+
for name, component in loaded_pipe.components.items():
2301+
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
2302+
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
2303+
self.assertEqual(
2304+
component.dtype,
2305+
expected_dtype,
2306+
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
2307+
)
2308+
22862309

22872310
@is_staging_test
22882311
class PipelinePushToHubTester(unittest.TestCase):

0 commit comments

Comments (0)