Skip to content

Commit d70f8ee

Browse files
[WAN] fix recompilation issues (#11475)
* [tests] Add torch.compile() test for WanTransformer3DModel * fix wan recompilation issues. * style --------- Co-authored-by: tongyu0924 <winnie920924@gmail.com>
1 parent 06beeca commit d70f8ee

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
202202
p_t, p_h, p_w = self.patch_size
203203
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
204204

205-
self.freqs = self.freqs.to(hidden_states.device)
206-
freqs = self.freqs.split_with_sizes(
205+
freqs = self.freqs.to(hidden_states.device)
206+
freqs = freqs.split_with_sizes(
207207
[
208208
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
209209
self.attention_head_dim // 6,

tests/models/transformers/test_models_transformer_wan.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,14 @@
1717
import torch
1818

1919
from diffusers import WanTransformer3DModel
20-
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
20+
from diffusers.utils.testing_utils import (
21+
enable_full_determinism,
22+
is_torch_compile,
23+
require_torch_2,
24+
require_torch_gpu,
25+
slow,
26+
torch_device,
27+
)
2128

2229
from ..test_modeling_common import ModelTesterMixin
2330

@@ -79,3 +86,18 @@ def prepare_init_args_and_inputs_for_common(self):
7986
def test_gradient_checkpointing_is_applied(self):
8087
expected_set = {"WanTransformer3DModel"}
8188
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
89+
90+
@require_torch_gpu
91+
@require_torch_2
92+
@is_torch_compile
93+
@slow
94+
def test_torch_compile_recompilation_and_graph_break(self):
95+
torch._dynamo.reset()
96+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
97+
98+
model = self.model_class(**init_dict).to(torch_device)
99+
model = torch.compile(model, fullgraph=True)
100+
101+
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
102+
_ = model(**inputs_dict)
103+
_ = model(**inputs_dict)

0 commit comments

Comments
 (0)