diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index f622791b572f..db261ca1ea4b 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -244,30 +244,34 @@ class CogView4RotaryPosEmbed(nn.Module): def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None: super().__init__() + self.dim = dim self.patch_size = patch_size self.rope_axes_dim = rope_axes_dim - - dim_h, dim_w = dim // 2, dim // 2 - h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)) - w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)) - h_seq = torch.arange(self.rope_axes_dim[0]) - w_seq = torch.arange(self.rope_axes_dim[1]) - self.freqs_h = torch.outer(h_seq, h_inv_freq) - self.freqs_w = torch.outer(w_seq, w_inv_freq) + self.theta = theta def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: batch_size, num_channels, height, width = hidden_states.shape height, width = height // self.patch_size, width // self.patch_size - h_idx = torch.arange(height) - w_idx = torch.arange(width) + dim_h, dim_w = self.dim // 2, self.dim // 2 + h_inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h) + ) + w_inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w) + ) + h_seq = torch.arange(self.rope_axes_dim[0]) + w_seq = torch.arange(self.rope_axes_dim[1]) + freqs_h = torch.outer(h_seq, h_inv_freq) + freqs_w = torch.outer(w_seq, w_inv_freq) + + h_idx = torch.arange(height, device=freqs_h.device) + w_idx = torch.arange(width, device=freqs_w.device) inner_h_idx = h_idx * self.rope_axes_dim[0] // height inner_w_idx = w_idx * self.rope_axes_dim[1] // width - self.freqs_h = self.freqs_h.to(hidden_states.device) - self.freqs_w = self.freqs_w.to(hidden_states.device) - freqs_h = self.freqs_h[inner_h_idx] - freqs_w = self.freqs_w[inner_w_idx] + freqs_h = freqs_h[inner_h_idx] + freqs_w = freqs_w[inner_w_idx] # Create position matrices for height and width # [height, 1, dim//4] and [1, width, dim//4]