Skip to content

Commit a1f9a71

Browse files
authored
fix offload gpu tests etc (#10366)
* add
* style
1 parent ec37e20 commit a1f9a71

File tree

3 files changed

+24
-39
lines changed

3 files changed

+24
-39
lines changed

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
8282
return hidden_states
8383

8484

85+
class SanaModulatedNorm(nn.Module):
    """Layer normalization followed by a shift/scale modulation.

    The modulation terms are built on every call by adding ``temb`` to an
    externally supplied ``scale_shift_table``; this module itself only holds
    the wrapped ``nn.LayerNorm``.

    Args:
        dim: Size passed to ``nn.LayerNorm`` as the normalized shape.
        elementwise_affine: Forwarded to ``nn.LayerNorm``; defaults to ``False``.
        eps: Forwarded to ``nn.LayerNorm``; defaults to ``1e-6``.
    """

    def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        scale_shift_table: torch.Tensor,
    ) -> torch.Tensor:
        """Normalize ``hidden_states`` and apply ``x * (1 + scale) + shift``.

        Args:
            hidden_states: Tensor to normalize and modulate.
            temb: Conditioning embedding added to the table
                (assumes shape ``(batch, dim)`` -- TODO confirm against caller).
            scale_shift_table: Table whose sum with ``temb`` is split in two
                along dim 1 (assumes shape ``(2, dim)`` -- TODO confirm).

        Returns:
            The modulated tensor.
        """
        normalized = self.norm(hidden_states)

        # `temb` is moved onto the table's device before the addition so the
        # two operands are guaranteed to live on the same device.
        temb_on_table_device = temb[:, None].to(scale_shift_table.device)
        modulation = scale_shift_table[None] + temb_on_table_device

        # First chunk along dim 1 is the shift, second is the scale.
        shift, scale = modulation.chunk(2, dim=1)
        return normalized * (1 + scale) + shift
97+
98+
8599
class SanaTransformerBlock(nn.Module):
86100
r"""
87101
Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
@@ -221,7 +235,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
221235
"""
222236

223237
_supports_gradient_checkpointing = True
224-
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
238+
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]
225239

226240
@register_to_config
227241
def __init__(
@@ -288,8 +302,7 @@ def __init__(
288302

289303
# 4. Output blocks
290304
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
291-
292-
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
305+
self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
293306
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
294307

295308
self.gradient_checkpointing = False
@@ -462,13 +475,8 @@ def custom_forward(*inputs):
462475
)
463476

464477
# 3. Normalization
465-
shift, scale = (
466-
self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
467-
).chunk(2, dim=1)
468-
hidden_states = self.norm_out(hidden_states)
478+
hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
469479

470-
# 4. Modulation
471-
hidden_states = hidden_states * (1 + scale) + shift
472480
hidden_states = self.proj_out(hidden_states)
473481

474482
# 5. Unpatchify

tests/models/test_modeling_common.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import requests_mock
3030
import torch
3131
import torch.nn as nn
32-
from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
32+
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
3333
from huggingface_hub import ModelCard, delete_repo, snapshot_download
3434
from huggingface_hub.utils import is_jinja_available
3535
from parameterized import parameterized
@@ -1080,7 +1080,7 @@ def test_cpu_offload(self):
10801080
torch.manual_seed(0)
10811081
base_output = model(**inputs_dict)
10821082

1083-
model_size = compute_module_persistent_sizes(model)[""]
1083+
model_size = compute_module_sizes(model)[""]
10841084
# We test several splits of sizes to make sure it works.
10851085
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
10861086
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1110,7 +1110,7 @@ def test_disk_offload_without_safetensors(self):
11101110
torch.manual_seed(0)
11111111
base_output = model(**inputs_dict)
11121112

1113-
model_size = compute_module_persistent_sizes(model)[""]
1113+
model_size = compute_module_sizes(model)[""]
11141114
with tempfile.TemporaryDirectory() as tmp_dir:
11151115
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
11161116

@@ -1144,7 +1144,7 @@ def test_disk_offload_with_safetensors(self):
11441144
torch.manual_seed(0)
11451145
base_output = model(**inputs_dict)
11461146

1147-
model_size = compute_module_persistent_sizes(model)[""]
1147+
model_size = compute_module_sizes(model)[""]
11481148
with tempfile.TemporaryDirectory() as tmp_dir:
11491149
model.cpu().save_pretrained(tmp_dir)
11501150

@@ -1172,7 +1172,7 @@ def test_model_parallelism(self):
11721172
torch.manual_seed(0)
11731173
base_output = model(**inputs_dict)
11741174

1175-
model_size = compute_module_persistent_sizes(model)[""]
1175+
model_size = compute_module_sizes(model)[""]
11761176
# We test several splits of sizes to make sure it works.
11771177
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
11781178
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1183,6 +1183,7 @@ def test_model_parallelism(self):
11831183
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
11841184
# Making sure part of the model will actually end up offloaded
11851185
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
1186+
print(f" new_model.hf_device_map:{new_model.hf_device_map}")
11861187

11871188
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
11881189

tests/models/transformers/test_models_transformer_sana.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import unittest
1616

17-
import pytest
1817
import torch
1918

2019
from diffusers import SanaTransformer2DModel
@@ -33,6 +32,7 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
3332
model_class = SanaTransformer2DModel
3433
main_input_name = "hidden_states"
3534
uses_custom_attn_processor = True
35+
model_split_percents = [0.7, 0.7, 0.9]
3636

3737
@property
3838
def dummy_input(self):
@@ -81,27 +81,3 @@ def prepare_init_args_and_inputs_for_common(self):
8181
def test_gradient_checkpointing_is_applied(self):
8282
expected_set = {"SanaTransformer2DModel"}
8383
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
84-
85-
@pytest.mark.xfail(
86-
condition=torch.device(torch_device).type == "cuda",
87-
reason="Test currently fails.",
88-
strict=True,
89-
)
90-
def test_cpu_offload(self):
91-
return super().test_cpu_offload()
92-
93-
@pytest.mark.xfail(
94-
condition=torch.device(torch_device).type == "cuda",
95-
reason="Test currently fails.",
96-
strict=True,
97-
)
98-
def test_disk_offload_with_safetensors(self):
99-
return super().test_disk_offload_with_safetensors()
100-
101-
@pytest.mark.xfail(
102-
condition=torch.device(torch_device).type == "cuda",
103-
reason="Test currently fails.",
104-
strict=True,
105-
)
106-
def test_disk_offload_without_safetensors(self):
107-
return super().test_disk_offload_without_safetensors()

0 commit comments

Comments
 (0)