From 36432462a2291d03f382000496806af102800d87 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 19 Dec 2024 12:21:17 +0000 Subject: [PATCH 1/7] allow models to run with a user-provided dtype map instead of a single dtype --- src/diffusers/pipelines/pipeline_loading_utils.py | 14 ++++++++++++-- src/diffusers/pipelines/pipeline_utils.py | 12 ++++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 0a7a222ec007..70f05c7d7c52 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -554,6 +554,11 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic loaded_sub_model = passed_class_obj[name] else: + sub_model_dtype = ( + torch_dtype.get(name, torch_dtype.get("_", torch.float32)) + if isinstance(torch_dtype, dict) + else torch_dtype + ) loaded_sub_model = _load_empty_model( library_name=library_name, class_name=class_name, @@ -562,7 +567,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic is_pipeline_module=is_pipeline_module, pipeline_class=pipeline_class, name=name, - torch_dtype=torch_dtype, + torch_dtype=sub_model_dtype, cached_folder=kwargs.get("cached_folder", None), force_download=kwargs.get("force_download", None), proxies=kwargs.get("proxies", None), @@ -578,7 +583,12 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic # Obtain a sorted dictionary for mapping the model-level components # to their sizes. module_sizes = { - module_name: compute_module_sizes(module, dtype=torch_dtype)[""] + module_name: compute_module_sizes( + module, + dtype=torch_dtype.get(module_name, torch_dtype.get("_", torch.float32)) + if isinstance(torch_dtype, dict) + else torch_dtype, + )[""] for module_name, module in init_empty_modules.items() if isinstance(module, torch.nn.Module) } diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c505c5a262a3..3a1af33c5805 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -530,9 +530,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights saved using [`~DiffusionPipeline.save_pretrained`]. - torch_dtype (`str` or `torch.dtype`, *optional*): + torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*): Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the dtype is automatically derived from the model's weights. + To load submodels with different dtype pass a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). + Set the default dtype for unspecified components with `_` (for example `{'transformer': torch.bfloat16, '_': torch.float16}`). + If a component is not specifed and no default is set, `torch.float32` is used. custom_pipeline (`str`, *optional*): @@ -921,6 +924,11 @@ def load_module(name, value): loaded_sub_model = passed_class_obj[name] else: # load sub model + sub_model_dtype = ( + torch_dtype.get(name, torch_dtype.get("_", torch.float32)) + if isinstance(torch_dtype, dict) + else torch_dtype + ) loaded_sub_model = load_sub_model( library_name=library_name, class_name=class_name, @@ -928,7 +936,7 @@ def load_module(name, value): pipelines=pipelines, is_pipeline_module=is_pipeline_module, pipeline_class=pipeline_class, - torch_dtype=torch_dtype, + torch_dtype=sub_model_dtype, provider=provider, sess_options=sess_options, device_map=current_device_map, From db770067104e6e1e5833d75089377a693be538da Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 19 Dec 2024 12:28:03 +0000 Subject: [PATCH 2/7] make style --- src/diffusers/pipelines/pipeline_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 3a1af33c5805..aec41cb15db8 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -532,10 +532,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P [`~DiffusionPipeline.save_pretrained`]. torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*): Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the - dtype is automatically derived from the model's weights. - To load submodels with different dtype pass a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). - Set the default dtype for unspecified components with `_` (for example `{'transformer': torch.bfloat16, '_': torch.float16}`). - If a component is not specifed and no default is set, `torch.float32` is used. + dtype is automatically derived from the model's weights. To load submodels with different dtype pass a + `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for + unspecified components with `_` (for example `{'transformer': torch.bfloat16, '_': torch.float16}`). If + a component is not specifed and no default is set, `torch.float32` is used. custom_pipeline (`str`, *optional*): From 2c58c6478fba9ffbe71b17bb251a35c04fda492e Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 19 Dec 2024 14:49:02 +0000 Subject: [PATCH 3/7] Add warning, change `_` to `default` --- .../pipelines/pipeline_loading_utils.py | 4 ++-- src/diffusers/pipelines/pipeline_utils.py | 20 ++++++++++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 70f05c7d7c52..928bb8fd482e 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -555,7 +555,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic else: sub_model_dtype = ( - torch_dtype.get(name, torch_dtype.get("_", torch.float32)) + torch_dtype.get(name, torch_dtype.get("default", torch.float32)) if isinstance(torch_dtype, dict) else torch_dtype ) @@ -585,7 +585,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic module_sizes = { module_name: compute_module_sizes( module, - dtype=torch_dtype.get(module_name, torch_dtype.get("_", torch.float32)) + dtype=torch_dtype.get(module_name, torch_dtype.get("default", torch.float32)) if isinstance(torch_dtype, dict) else torch_dtype, )[""] diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index aec41cb15db8..dfbd367d0d54 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -534,8 +534,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the dtype is automatically derived from the model's weights. To load submodels with different dtype pass a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for - unspecified components with `_` (for example `{'transformer': torch.bfloat16, '_': torch.float16}`). If - a component is not specifed and no default is set, `torch.float32` is used. + unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default': torch.float16}`). + If a component is not specified and no default is set, `torch.float32` is used. custom_pipeline (`str`, *optional*): @@ -858,6 +858,20 @@ def load_module(name, value): f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}." ) + # Check `torch_dtype` map for unused keys + if isinstance(torch_dtype, dict): + extra_keys_dtype = set(torch_dtype.keys()) - set(passed_class_obj.keys()) + extra_keys_obj = set(passed_class_obj.keys()) - set(torch_dtype.keys()) + if len(extra_keys_dtype) > 0: + logger.warning( + f"Expected `{list(passed_class_obj.keys())}`, got extra `torch_dtype` keys `{extra_keys_dtype}`." + ) + if len(extra_keys_obj) > 0: + logger.warning( + f"Expected `{list(passed_class_obj.keys())}`, missing `torch_dtype` keys `{extra_keys_dtype}`." + " using `default` or `torch.float32`." + ) + # Special case: safety_checker must be loaded separately when using `from_flax` if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: raise NotImplementedError( @@ -925,7 +939,7 @@ def load_module(name, value): else: # load sub model sub_model_dtype = ( - torch_dtype.get(name, torch_dtype.get("_", torch.float32)) + torch_dtype.get(name, torch_dtype.get("default", torch.float32)) if isinstance(torch_dtype, dict) else torch_dtype ) From 2adba04c0146c6d1c5bb47aefff81b12081e96af Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 19 Dec 2024 14:51:54 +0000 Subject: [PATCH 4/7] make style --- src/diffusers/pipelines/pipeline_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index dfbd367d0d54..553d644380cd 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -534,8 +534,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the dtype is automatically derived from the model's weights. To load submodels with different dtype pass a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for - unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default': torch.float16}`). - If a component is not specified and no default is set, `torch.float32` is used. + unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default': + torch.float16}`). If a component is not specified and no default is set, `torch.float32` is used. custom_pipeline (`str`, *optional*): From ec53008a5fb9e63c0d7164f020c58a8da17b6ab9 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 1 Apr 2025 07:30:29 +0100 Subject: [PATCH 5/7] add test --- src/diffusers/pipelines/pipeline_utils.py | 2 +- tests/pipelines/test_pipelines_common.py | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 15bc50a8b385..abe9cb56795f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -706,7 +706,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) - if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype): torch_dtype = torch.float32 logger.warning( f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`." diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index d069def66ecf..cc5008e37292 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2283,6 +2283,29 @@ def run_forward(pipe): self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4)) self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4)) + def test_torch_dtype_dict(self): + components = self.get_dummy_components() + if not components: + self.skipTest("No dummy components defined.") + + pipe = self.pipeline_class(**components) + + specified_key = next(iter(components.keys())) + + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname: + pipe.save_pretrained(tmpdirname) + torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16} + loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict) + + for name, component in loaded_pipe.components.items(): + if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"): + expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32)) + self.assertEqual( + component.dtype, + expected_dtype, + f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}", + ) + @is_staging_test class PipelinePushToHubTester(unittest.TestCase): From e8aa61b0ffe17746da89eac966417c76bde55e24 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 1 Apr 2025 13:30:21 +0100 Subject: [PATCH 6/7] handle shared tensors --- src/diffusers/models/modeling_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 19ac868cdae0..814547d82be4 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -714,7 +714,10 @@ def save_pretrained( if safe_serialization: # At some point we will need to deal better with save_function (used for TPU and other distributed # joyfulness), but for now this enough. - safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) + try: + safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) + except RuntimeError: + safetensors.torch.save_model(model_to_save, filepath, metadata={"format": "pt"}) else: torch.save(shard, filepath) From 8f713111acff4361605df037f576cf6a97a09567 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 1 Apr 2025 13:30:31 +0100 Subject: [PATCH 7/7] remove warning --- src/diffusers/pipelines/pipeline_utils.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index abe9cb56795f..0df4b477e1b9 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -887,20 +887,6 @@ def load_module(name, value): init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} - # Check `torch_dtype` map for unused keys - if isinstance(torch_dtype, dict): - extra_keys_dtype = set(torch_dtype.keys()) - set(passed_class_obj.keys()) - extra_keys_obj = set(passed_class_obj.keys()) - set(torch_dtype.keys()) - if len(extra_keys_dtype) > 0: - logger.warning( - f"Expected `{list(passed_class_obj.keys())}`, got extra `torch_dtype` keys `{extra_keys_dtype}`." - ) - if len(extra_keys_obj) > 0: - logger.warning( - f"Expected `{list(passed_class_obj.keys())}`, missing `torch_dtype` keys `{extra_keys_dtype}`." - " using `default` or `torch.float32`." - ) - # Special case: safety_checker must be loaded separately when using `from_flax` if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: raise NotImplementedError(