# limitations under the License.

import math
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
@@ -102,6 +102,7 @@ def __init__(
        patch_size: int = 1,
        patch_size_t: int = 1,
        theta: float = 10000.0,
+        _causal_rope_fix: bool = False,
    ) -> None:
        super().__init__()

@@ -112,13 +113,15 @@ def __init__(
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.theta = theta
+        self._causal_rope_fix = _causal_rope_fix

    def forward(
        self,
        hidden_states: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
+        frame_rate: Optional[int] = None,
        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = hidden_states.size(0)
@@ -132,9 +135,24 @@ def forward(
        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

        if rope_interpolation_scale is not None:
-            grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
-            grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
-            grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
+            if isinstance(rope_interpolation_scale, tuple):
+                # This will be deprecated in v0.34.0
+                grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
+                grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
+                grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
+            else:
+                if not self._causal_rope_fix:
+                    grid[:, 0:1] = (
+                        grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
+                    )
+                else:
+                    grid[:, 0:1] = (
+                        ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
+                        * self.patch_size_t
+                        / self.base_num_frames
+                    )
+                grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height
+                grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width

        grid = grid.flatten(2, 4).transpose(1, 2)

@@ -315,6 +333,7 @@ def __init__(
        caption_channels: int = 4096,
        attention_bias: bool = True,
        attention_out_bias: bool = True,
+        _causal_rope_fix: bool = False,
    ) -> None:
        super().__init__()

@@ -336,6 +355,7 @@ def __init__(
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            theta=10000.0,
+            _causal_rope_fix=_causal_rope_fix,
        )

        self.transformer_blocks = nn.ModuleList(
@@ -370,7 +390,8 @@ def forward(
        num_frames: int,
        height: int,
        width: int,
-        rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
+        frame_rate: int,
+        rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> torch.Tensor:
@@ -389,7 +410,11 @@ def forward(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

-        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
+        if not isinstance(rope_interpolation_scale, torch.Tensor):
+            msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0."
+            deprecate("rope_interpolation_scale", "0.34.0", msg)
+
+        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
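For context on what the new `else` branch changes, here is a minimal, standalone sketch of the two temporal-coordinate formulas (legacy path vs. `_causal_rope_fix=True`). The values chosen for `patch_size_t`, `base_num_frames`, `frame_rate`, and the temporal entry of `rope_interpolation_scale` are arbitrary assumptions for illustration; the snippet reimplements only the scaling arithmetic from the diff and does not touch the model classes.

```python
import torch

# Assumed toy values; in the model these come from the config and the pipeline.
patch_size_t = 1
base_num_frames = 20
frame_rate = 25
scale_t = 0.4  # stands in for the temporal slice `rope_interpolation_scale[0:1]`

# Latent-frame indices along the temporal axis of the RoPE grid.
t = torch.arange(5, dtype=torch.float32)

# Legacy path (`_causal_rope_fix=False`): plain linear scaling of the frame index.
legacy = t * scale_t * patch_size_t / base_num_frames

# Causal path (`_causal_rope_fix=True`): frame 0 is clamped to 0, and every later
# frame is offset by one frame period (1 / frame_rate) before the same scaling.
causal = ((t - 1) * scale_t + 1 / frame_rate).clamp(min=0) * patch_size_t / base_num_frames

print(legacy)  # tensor([0.0000, 0.0200, 0.0400, 0.0600, 0.0800])
print(causal)  # tensor([0.0000, 0.0020, 0.0220, 0.0420, 0.0620])
```

In other words, the clamp and the `1 / frame_rate` offset keep the first latent frame at temporal position zero and measure every subsequent frame from the end of that first frame, instead of scaling the raw frame index directly.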