
allow models to run with a user-provided dtype map instead of a single dtype #10301


Merged
merged 13 commits on Apr 2, 2025
14 changes: 12 additions & 2 deletions src/diffusers/pipelines/pipeline_loading_utils.py
@@ -592,6 +592,11 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
loaded_sub_model = passed_class_obj[name]

else:
sub_model_dtype = (
torch_dtype.get(name, torch_dtype.get("default", torch.float32))
if isinstance(torch_dtype, dict)
else torch_dtype
)
loaded_sub_model = _load_empty_model(
library_name=library_name,
class_name=class_name,
@@ -600,7 +605,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
is_pipeline_module=is_pipeline_module,
pipeline_class=pipeline_class,
name=name,
-torch_dtype=torch_dtype,
+torch_dtype=sub_model_dtype,
cached_folder=kwargs.get("cached_folder", None),
force_download=kwargs.get("force_download", None),
proxies=kwargs.get("proxies", None),
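The hunk above resolves one dtype per component before loading the empty model. A minimal standalone sketch of that lookup rule (the helper name is mine, not code from the PR):

```python
import torch

def resolve_component_dtype(torch_dtype, name):
    # Explicit per-component entry wins, then the "default" entry,
    # then torch.float32; a plain torch.dtype is passed through.
    if isinstance(torch_dtype, dict):
        return torch_dtype.get(name, torch_dtype.get("default", torch.float32))
    return torch_dtype

assert resolve_component_dtype({"vae": torch.float16}, "vae") is torch.float16
assert resolve_component_dtype({"vae": torch.float16}, "unet") is torch.float32
assert resolve_component_dtype({"default": torch.bfloat16}, "unet") is torch.bfloat16
assert resolve_component_dtype(torch.float16, "unet") is torch.float16
```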
@@ -616,7 +621,12 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
# Obtain a sorted dictionary for mapping the model-level components
# to their sizes.
module_sizes = {
-module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
+module_name: compute_module_sizes(
+module,
+dtype=torch_dtype.get(module_name, torch_dtype.get("default", torch.float32))
+if isinstance(torch_dtype, dict)
+else torch_dtype,
+)[""]
for module_name, module in init_empty_modules.items()
if isinstance(module, torch.nn.Module)
}
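With a dict-valued `torch_dtype`, each module's size estimate for the device map is computed under its own dtype. A quick sketch of why that matters, assuming `accelerate`'s `compute_module_sizes` behaves as used above:

```python
import torch
from accelerate.utils import compute_module_sizes

# The returned mapping is keyed by submodule path; "" holds the module total.
layer = torch.nn.Linear(1024, 1024)
fp32_bytes = compute_module_sizes(layer, dtype=torch.float32)[""]
bf16_bytes = compute_module_sizes(layer, dtype=torch.bfloat16)[""]
assert bf16_bytes * 2 == fp32_bytes  # bf16 weights take half the fp32 bytes
```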
30 changes: 26 additions & 4 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -552,9 +552,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
saved using
[`~DiffusionPipeline.save_pretrained`].
- A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file
-torch_dtype (`str` or `torch.dtype`, *optional*):
+torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
-dtype is automatically derived from the model's weights.
+dtype is automatically derived from the model's weights. To load submodels with different dtypes, pass a
+`dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for
+unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default':
+torch.float16}`). If a component is not specified and no default is set, `torch.float32` is used.
custom_pipeline (`str`, *optional*):

<Tip warning={true}>
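For illustration, the dict form of `torch_dtype` documented in this hunk could be used like this (the repo id is a placeholder and the component names assume a UNet-based pipeline):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "some-org/some-pipeline",  # placeholder repo id
    torch_dtype={"unet": torch.bfloat16, "default": torch.float16},
)
# "unet" gets its explicit dtype; other torch.nn.Module components fall
# back to "default" (or torch.float32 if no default had been given).
print(pipe.unet.dtype)  # torch.bfloat16
print(pipe.vae.dtype)   # torch.float16
```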
@@ -703,7 +706,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)

-if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
logger.warning(
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
)
torch_dtype = torch.float32
@@ -884,6 +887,20 @@ def load_module(name, value):

init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

# Check `torch_dtype` map for unused keys
if isinstance(torch_dtype, dict):
extra_keys_dtype = set(torch_dtype.keys()) - set(passed_class_obj.keys())
extra_keys_obj = set(passed_class_obj.keys()) - set(torch_dtype.keys())
if len(extra_keys_dtype) > 0:
logger.warning(
f"Expected `{list(passed_class_obj.keys())}`, got extra `torch_dtype` keys `{extra_keys_dtype}`."
)
if len(extra_keys_obj) > 0:
logger.warning(
[Inline review comment from a Collaborator:]
I don't think we need this warning. I think the expectation for passed class objects is that their dtype is already set, and if it isn't, it's handled at the model level, where a `dtype=None` results in the FP32 default.

f"Expected `{list(passed_class_obj.keys())}`, missing `torch_dtype` keys `{extra_keys_dtype}`."
" using `default` or `torch.float32`."
)

# Special case: safety_checker must be loaded separately when using `from_flax`
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
raise NotImplementedError(
@@ -950,14 +967,19 @@ def load_module(name, value):
loaded_sub_model = passed_class_obj[name]
else:
# load sub model
sub_model_dtype = (
torch_dtype.get(name, torch_dtype.get("default", torch.float32))
if isinstance(torch_dtype, dict)
else torch_dtype
)
loaded_sub_model = load_sub_model(
library_name=library_name,
class_name=class_name,
importable_classes=importable_classes,
pipelines=pipelines,
is_pipeline_module=is_pipeline_module,
pipeline_class=pipeline_class,
-torch_dtype=torch_dtype,
+torch_dtype=sub_model_dtype,
provider=provider,
sess_options=sess_options,
device_map=current_device_map,
23 changes: 23 additions & 0 deletions tests/pipelines/test_pipelines_common.py
@@ -2283,6 +2283,29 @@ def run_forward(pipe):
self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))

def test_torch_dtype_dict(self):
components = self.get_dummy_components()
if not components:
self.skipTest("No dummy components defined.")

pipe = self.pipeline_class(**components)

specified_key = next(iter(components.keys()))

with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
pipe.save_pretrained(tmpdirname)
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)

for name, component in loaded_pipe.components.items():
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
self.assertEqual(
component.dtype,
expected_dtype,
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
)
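Assuming a standard pytest setup, this test can be run on its own with `pytest tests/pipelines/test_pipelines_common.py -k test_torch_dtype_dict`.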


@is_staging_test
class PipelinePushToHubTester(unittest.TestCase):