Skip to content

Commit 6a0137e

Browse files
chengzeyigithub-actions[bot]yiyixuxu
authored
Fix Graph Breaks When Compiling CogView4 (#10959)
* Fix Graph Breaks When Compiling CogView4 Eliminate this: ``` t]V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] Recompiling function forward in /home/zeyi/repos/diffusers/src/diffusers/models/transformers/transformer_cogview4.py:374 V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] triggered by the following guard failure(s): V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] - 0/3: ___check_obj_id(L['self'].rope.freqs_h, 139976127328032) V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] - 0/2: ___check_obj_id(L['self'].rope.freqs_h, 139976107780960) V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] - 0/1: ___check_obj_id(L['self'].rope.freqs_h, 140022511848960) V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] - 0/0: ___check_obj_id(L['self'].rope.freqs_h, 140024081342416) ``` * Update transformer_cogview4.py * fix cogview4 rotary pos embed * Apply style fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent 2e5203b commit 6a0137e

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -244,30 +244,34 @@ class CogView4RotaryPosEmbed(nn.Module):
244244
def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None:
245245
super().__init__()
246246

247+
self.dim = dim
247248
self.patch_size = patch_size
248249
self.rope_axes_dim = rope_axes_dim
249-
250-
dim_h, dim_w = dim // 2, dim // 2
251-
h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h))
252-
w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w))
253-
h_seq = torch.arange(self.rope_axes_dim[0])
254-
w_seq = torch.arange(self.rope_axes_dim[1])
255-
self.freqs_h = torch.outer(h_seq, h_inv_freq)
256-
self.freqs_w = torch.outer(w_seq, w_inv_freq)
250+
self.theta = theta
257251

258252
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
259253
batch_size, num_channels, height, width = hidden_states.shape
260254
height, width = height // self.patch_size, width // self.patch_size
261255

262-
h_idx = torch.arange(height)
263-
w_idx = torch.arange(width)
256+
dim_h, dim_w = self.dim // 2, self.dim // 2
257+
h_inv_freq = 1.0 / (
258+
self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
259+
)
260+
w_inv_freq = 1.0 / (
261+
self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
262+
)
263+
h_seq = torch.arange(self.rope_axes_dim[0])
264+
w_seq = torch.arange(self.rope_axes_dim[1])
265+
freqs_h = torch.outer(h_seq, h_inv_freq)
266+
freqs_w = torch.outer(w_seq, w_inv_freq)
267+
268+
h_idx = torch.arange(height, device=freqs_h.device)
269+
w_idx = torch.arange(width, device=freqs_w.device)
264270
inner_h_idx = h_idx * self.rope_axes_dim[0] // height
265271
inner_w_idx = w_idx * self.rope_axes_dim[1] // width
266272

267-
self.freqs_h = self.freqs_h.to(hidden_states.device)
268-
self.freqs_w = self.freqs_w.to(hidden_states.device)
269-
freqs_h = self.freqs_h[inner_h_idx]
270-
freqs_w = self.freqs_w[inner_w_idx]
273+
freqs_h = freqs_h[inner_h_idx]
274+
freqs_w = freqs_w[inner_w_idx]
271275

272276
# Create position matrices for height and width
273277
# [height, 1, dim//4] and [1, width, dim//4]

0 commit comments

Comments
 (0)