@@ -244,30 +244,34 @@ class CogView4RotaryPosEmbed(nn.Module):
244
244
def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None:
    """Store the configuration for the 2D rotary position embedding.

    Only plain Python values are kept on the module. The frequency tables
    are built inside ``forward`` instead of being cached here as plain
    tensor attributes, so the module carries no tensor state that would
    have to be moved between devices by hand.

    Args:
        dim: Embedding size; ``forward`` uses ``dim // 2`` for each of the
            height and width axes.
        patch_size: Divisor applied to the input height and width in
            ``forward`` to obtain the patch-grid size.
        rope_axes_dim: ``(height, width)`` lengths of the position
            sequences from which the frequency tables are built.
        theta: Base of the inverse-frequency schedule.
    """
    super().__init__()

    self.dim = dim
    self.patch_size = patch_size
    self.rope_axes_dim = rope_axes_dim
    self.theta = theta
257
251
258
252
def forward (self , hidden_states : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
259
253
batch_size , num_channels , height , width = hidden_states .shape
260
254
height , width = height // self .patch_size , width // self .patch_size
261
255
262
- h_idx = torch .arange (height )
263
- w_idx = torch .arange (width )
256
+ dim_h , dim_w = self .dim // 2 , self .dim // 2
257
+ h_inv_freq = 1.0 / (
258
+ self .theta ** (torch .arange (0 , dim_h , 2 , dtype = torch .float32 )[: (dim_h // 2 )].float () / dim_h )
259
+ )
260
+ w_inv_freq = 1.0 / (
261
+ self .theta ** (torch .arange (0 , dim_w , 2 , dtype = torch .float32 )[: (dim_w // 2 )].float () / dim_w )
262
+ )
263
+ h_seq = torch .arange (self .rope_axes_dim [0 ])
264
+ w_seq = torch .arange (self .rope_axes_dim [1 ])
265
+ freqs_h = torch .outer (h_seq , h_inv_freq )
266
+ freqs_w = torch .outer (w_seq , w_inv_freq )
267
+
268
+ h_idx = torch .arange (height , device = freqs_h .device )
269
+ w_idx = torch .arange (width , device = freqs_w .device )
264
270
inner_h_idx = h_idx * self .rope_axes_dim [0 ] // height
265
271
inner_w_idx = w_idx * self .rope_axes_dim [1 ] // width
266
272
267
- self .freqs_h = self .freqs_h .to (hidden_states .device )
268
- self .freqs_w = self .freqs_w .to (hidden_states .device )
269
- freqs_h = self .freqs_h [inner_h_idx ]
270
- freqs_w = self .freqs_w [inner_w_idx ]
273
+ freqs_h = freqs_h [inner_h_idx ]
274
+ freqs_w = freqs_w [inner_w_idx ]
271
275
272
276
# Create position matrices for height and width
273
277
# [height, 1, dim//4] and [1, width, dim//4]
0 commit comments