Skip to content

Commit beacaa5

Browse files
a-r-r-o-wDN6stevhliu
authored
[core] Layerwise Upcasting (#10347)
* update * update * make style * remove dynamo disable * add coauthor Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com> * update * update * update * update mixin * add some basic tests * update * update * non_blocking * improvements * update * norm.* -> norm * apply suggestions from review * add example * update hook implementation to the latest changes from pyramid attention broadcast * deinitialize should raise an error * update doc page * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * update docs * update * refactor * fix _always_upcast_modules for asym ae and vq_model * fix lumina embedding forward to not depend on weight dtype * refactor tests * add simple lora inference tests * _always_upcast_modules -> _precision_sensitive_module_patterns * remove todo comments about review; revert changes to self.dtype in unets because .dtype on ModelMixin should be able to handle fp8 weight case * check layer dtypes in lora test * fix UNet1DModelTests::test_layerwise_upcasting_inference * _precision_sensitive_module_patterns -> _skip_layerwise_casting_patterns based on feedback * skip test in NCSNppModelTests * skip tests for AutoencoderTinyTests * skip tests for AutoencoderOobleckTests * skip tests for UNet1DModelTests - unsupported pytorch operations * layerwise_upcasting -> layerwise_casting * skip tests for UNetRLModelTests; needs next pytorch release for currently unimplemented operation support * add layerwise fp8 pipeline test * use xfail * Apply suggestions from code review Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> * add assertion with fp32 comparison; add tolerance to fp8-fp32 vs fp32-fp32 comparison (required for a few models' test to pass) * add note about memory consumption on tesla CI runner for failing test --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
1 parent a647682 commit beacaa5

File tree

73 files changed

+859
-4
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+859
-4
lines changed

docs/source/en/api/utilities.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,7 @@ Utility and helper functions for working with 🤗 Diffusers.
4141
## randn_tensor
4242

4343
[[autodoc]] utils.torch_utils.randn_tensor
44+
45+
## apply_layerwise_casting
46+
47+
[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting

docs/source/en/optimization/memory.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,43 @@ In order to properly offload models after they're called, it is required to run
158158

159159
</Tip>
160160

161+
## FP8 layerwise weight-casting
162+
163+
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
164+
165+
Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Layerwise weight-casting cuts down the memory footprint of the model weights by approximately half.
166+
167+
```python
168+
import torch
169+
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
170+
from diffusers.utils import export_to_video
171+
172+
model_id = "THUDM/CogVideoX-5b"
173+
174+
# Load the model in bfloat16 and enable layerwise casting
175+
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
176+
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
177+
178+
# Load the pipeline
179+
pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
180+
pipe.to("cuda")
181+
182+
prompt = (
183+
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
184+
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
185+
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
186+
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
187+
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
188+
"atmosphere of this unique musical performance."
189+
)
190+
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
191+
export_to_video(video, "output.mp4", fps=8)
192+
```
193+
194+
In the above example, layerwise casting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default.
195+
196+
However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
197+
161198
## Channels-last memory format
162199

163200
The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model.

src/diffusers/hooks/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from ..utils import is_torch_available
2+
3+
4+
if is_torch_available():
5+
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook

src/diffusers/hooks/hooks.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import functools
16+
from typing import Any, Dict, Optional, Tuple
17+
18+
import torch
19+
20+
from ..utils.logging import get_logger
21+
22+
23+
logger = get_logger(__name__) # pylint: disable=invalid-name
24+
25+
26+
class ModelHook:
27+
r"""
28+
A hook that contains callbacks to be executed just before and after the forward method of a model.
29+
"""
30+
31+
_is_stateful = False
32+
33+
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
34+
r"""
35+
Hook that is executed when a model is initialized.
36+
37+
Args:
38+
module (`torch.nn.Module`):
39+
The module attached to this hook.
40+
"""
41+
return module
42+
43+
def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
44+
r"""
45+
Hook that is executed when a model is deinitalized.
46+
47+
Args:
48+
module (`torch.nn.Module`):
49+
The module attached to this hook.
50+
"""
51+
module.forward = module._old_forward
52+
del module._old_forward
53+
return module
54+
55+
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
56+
r"""
57+
Hook that is executed just before the forward method of the model.
58+
59+
Args:
60+
module (`torch.nn.Module`):
61+
The module whose forward pass will be executed just after this event.
62+
args (`Tuple[Any]`):
63+
The positional arguments passed to the module.
64+
kwargs (`Dict[Str, Any]`):
65+
The keyword arguments passed to the module.
66+
Returns:
67+
`Tuple[Tuple[Any], Dict[Str, Any]]`:
68+
A tuple with the treated `args` and `kwargs`.
69+
"""
70+
return args, kwargs
71+
72+
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
73+
r"""
74+
Hook that is executed just after the forward method of the model.
75+
76+
Args:
77+
module (`torch.nn.Module`):
78+
The module whose forward pass been executed just before this event.
79+
output (`Any`):
80+
The output of the module.
81+
Returns:
82+
`Any`: The processed `output`.
83+
"""
84+
return output
85+
86+
def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
87+
r"""
88+
Hook that is executed when the hook is detached from a module.
89+
90+
Args:
91+
module (`torch.nn.Module`):
92+
The module detached from this hook.
93+
"""
94+
return module
95+
96+
def reset_state(self, module: torch.nn.Module):
97+
if self._is_stateful:
98+
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
99+
return module
100+
101+
102+
class HookRegistry:
103+
def __init__(self, module_ref: torch.nn.Module) -> None:
104+
super().__init__()
105+
106+
self.hooks: Dict[str, ModelHook] = {}
107+
108+
self._module_ref = module_ref
109+
self._hook_order = []
110+
111+
def register_hook(self, hook: ModelHook, name: str) -> None:
112+
if name in self.hooks.keys():
113+
logger.warning(f"Hook with name {name} already exists, replacing it.")
114+
115+
if hasattr(self._module_ref, "_old_forward"):
116+
old_forward = self._module_ref._old_forward
117+
else:
118+
old_forward = self._module_ref.forward
119+
self._module_ref._old_forward = self._module_ref.forward
120+
121+
self._module_ref = hook.initialize_hook(self._module_ref)
122+
123+
if hasattr(hook, "new_forward"):
124+
rewritten_forward = hook.new_forward
125+
126+
def new_forward(module, *args, **kwargs):
127+
args, kwargs = hook.pre_forward(module, *args, **kwargs)
128+
output = rewritten_forward(module, *args, **kwargs)
129+
return hook.post_forward(module, output)
130+
else:
131+
132+
def new_forward(module, *args, **kwargs):
133+
args, kwargs = hook.pre_forward(module, *args, **kwargs)
134+
output = old_forward(*args, **kwargs)
135+
return hook.post_forward(module, output)
136+
137+
self._module_ref.forward = functools.update_wrapper(
138+
functools.partial(new_forward, self._module_ref), old_forward
139+
)
140+
141+
self.hooks[name] = hook
142+
self._hook_order.append(name)
143+
144+
def get_hook(self, name: str) -> Optional[ModelHook]:
145+
if name not in self.hooks.keys():
146+
return None
147+
return self.hooks[name]
148+
149+
def remove_hook(self, name: str, recurse: bool = True) -> None:
150+
if name in self.hooks.keys():
151+
hook = self.hooks[name]
152+
self._module_ref = hook.deinitalize_hook(self._module_ref)
153+
del self.hooks[name]
154+
self._hook_order.remove(name)
155+
156+
if recurse:
157+
for module_name, module in self._module_ref.named_modules():
158+
if module_name == "":
159+
continue
160+
if hasattr(module, "_diffusers_hook"):
161+
module._diffusers_hook.remove_hook(name, recurse=False)
162+
163+
def reset_stateful_hooks(self, recurse: bool = True) -> None:
164+
for hook_name in self._hook_order:
165+
hook = self.hooks[hook_name]
166+
if hook._is_stateful:
167+
hook.reset_state(self._module_ref)
168+
169+
if recurse:
170+
for module_name, module in self._module_ref.named_modules():
171+
if module_name == "":
172+
continue
173+
if hasattr(module, "_diffusers_hook"):
174+
module._diffusers_hook.reset_stateful_hooks(recurse=False)
175+
176+
@classmethod
177+
def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
178+
if not hasattr(module, "_diffusers_hook"):
179+
module._diffusers_hook = cls(module)
180+
return module._diffusers_hook
181+
182+
def __repr__(self) -> str:
183+
hook_repr = ""
184+
for i, hook_name in enumerate(self._hook_order):
185+
hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
186+
if i < len(self._hook_order) - 1:
187+
hook_repr += "\n"
188+
return f"HookRegistry(\n{hook_repr}\n)"

0 commit comments

Comments
 (0)