From 7fc465fe4df5a4c1ba381f217d512dcc8caef583 Mon Sep 17 00:00:00 2001 From: C Date: Tue, 4 Mar 2025 18:45:56 +0800 Subject: [PATCH 1/4] Fix Graph Breaks When Compiling CogView4 Eliminate the following recompilations, triggered by guard failures on `rope.freqs_h`: ``` V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] Recompiling function forward in /home/zeyi/repos/diffusers/src/diffusers/models/transformers/transformer_cogview4.py:374 V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] triggered by the following guard failure(s): V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] - 0/3: ___check_obj_id(L['self'].rope.freqs_h, 139976127328032) V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] - 0/2: ___check_obj_id(L['self'].rope.freqs_h, 139976107780960) V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] - 0/1: ___check_obj_id(L['self'].rope.freqs_h, 140022511848960) V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] - 0/0: ___check_obj_id(L['self'].rope.freqs_h, 140024081342416) ``` --- .../models/transformers/transformer_cogview4.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index f622791b572f..c6bfb6fb50da 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -252,20 +252,18 @@ def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], th w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)) h_seq = torch.arange(self.rope_axes_dim[0]) w_seq = torch.arange(self.rope_axes_dim[1]) - self.freqs_h = torch.outer(h_seq, h_inv_freq) - self.freqs_w = torch.outer(w_seq, w_inv_freq) + self.freqs_h = torch.nn.Buffer(torch.outer(h_seq, h_inv_freq)) +
self.freqs_w = torch.nn.Buffer(torch.outer(w_seq, w_inv_freq)) def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: batch_size, num_channels, height, width = hidden_states.shape height, width = height // self.patch_size, width // self.patch_size - h_idx = torch.arange(height) - w_idx = torch.arange(width) + h_idx = torch.arange(height, device=self.freqs_h.device) + w_idx = torch.arange(width, device=self.freqs_w.device) inner_h_idx = h_idx * self.rope_axes_dim[0] // height inner_w_idx = w_idx * self.rope_axes_dim[1] // width - self.freqs_h = self.freqs_h.to(hidden_states.device) - self.freqs_w = self.freqs_w.to(hidden_states.device) freqs_h = self.freqs_h[inner_h_idx] freqs_w = self.freqs_w[inner_w_idx] From 9080bd66e9fd04c11df3f8e7fbae5877ceb095e9 Mon Sep 17 00:00:00 2001 From: C Date: Tue, 4 Mar 2025 18:56:32 +0800 Subject: [PATCH 2/4] Update transformer_cogview4.py --- src/diffusers/models/transformers/transformer_cogview4.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index c6bfb6fb50da..bd027eb9587e 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -252,8 +252,8 @@ def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], th w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)) h_seq = torch.arange(self.rope_axes_dim[0]) w_seq = torch.arange(self.rope_axes_dim[1]) - self.freqs_h = torch.nn.Buffer(torch.outer(h_seq, h_inv_freq)) - self.freqs_w = torch.nn.Buffer(torch.outer(w_seq, w_inv_freq)) + self.freqs_h = self.register_buffer("freqs_h", torch.outer(h_seq, h_inv_freq), persistent=False) + self.freqs_w = self.register_buffer("freqs_h", torch.outer(w_seq, w_inv_freq), persistent=False) def forward(self, hidden_states: 
torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: batch_size, num_channels, height, width = hidden_states.shape From 74b591a4f30485f41234fa7a37cb3194b158283b Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Thu, 6 Mar 2025 00:37:47 +0000 Subject: [PATCH 3/4] fix cogview4 rotary pos embed --- .../transformers/transformer_cogview4.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index bd027eb9587e..c060dec55c65 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -244,28 +244,30 @@ class CogView4RotaryPosEmbed(nn.Module): def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None: super().__init__() + self.dim = dim self.patch_size = patch_size self.rope_axes_dim = rope_axes_dim - - dim_h, dim_w = dim // 2, dim // 2 - h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)) - w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)) - h_seq = torch.arange(self.rope_axes_dim[0]) - w_seq = torch.arange(self.rope_axes_dim[1]) - self.freqs_h = self.register_buffer("freqs_h", torch.outer(h_seq, h_inv_freq), persistent=False) - self.freqs_w = self.register_buffer("freqs_h", torch.outer(w_seq, w_inv_freq), persistent=False) + self.theta = theta def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: batch_size, num_channels, height, width = hidden_states.shape height, width = height // self.patch_size, width // self.patch_size - h_idx = torch.arange(height, device=self.freqs_h.device) - w_idx = torch.arange(width, device=self.freqs_w.device) + dim_h, dim_w = self.dim // 2, self.dim // 2 + h_inv_freq = 1.0 / (self.theta ** (torch.arange(0, dim_h, 2, 
dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)) + w_inv_freq = 1.0 / (self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)) + h_seq = torch.arange(self.rope_axes_dim[0]) + w_seq = torch.arange(self.rope_axes_dim[1]) + freqs_h = torch.outer(h_seq, h_inv_freq) + freqs_w = torch.outer(w_seq, w_inv_freq) + + h_idx = torch.arange(height, device=freqs_h.device) + w_idx = torch.arange(width, device=freqs_w.device) inner_h_idx = h_idx * self.rope_axes_dim[0] // height inner_w_idx = w_idx * self.rope_axes_dim[1] // width - freqs_h = self.freqs_h[inner_h_idx] - freqs_w = self.freqs_w[inner_w_idx] + freqs_h = freqs_h[inner_h_idx] + freqs_w = freqs_w[inner_w_idx] # Create position matrices for height and width # [height, 1, dim//4] and [1, width, dim//4] From 7d18503d071c7279f78ab888296f342797d9c9a7 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 6 Mar 2025 20:26:53 +0000 Subject: [PATCH 4/4] Apply style fixes --- src/diffusers/models/transformers/transformer_cogview4.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index c060dec55c65..db261ca1ea4b 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -254,8 +254,12 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens height, width = height // self.patch_size, width // self.patch_size dim_h, dim_w = self.dim // 2, self.dim // 2 - h_inv_freq = 1.0 / (self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)) - w_inv_freq = 1.0 / (self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)) + h_inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h) + ) + w_inv_freq = 1.0 / ( + 
self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w) + ) h_seq = torch.arange(self.rope_axes_dim[0]) w_seq = torch.arange(self.rope_axes_dim[1]) freqs_h = torch.outer(h_seq, h_inv_freq)