@@ -223,17 +223,23 @@ def __init__(
223
223
224
224
def forward (self , hidden_states : torch .Tensor ) -> torch .Tensor :
225
225
hidden_states = torch .cat ([hidden_states [:, :, : self .stride [0 ] - 1 ], hidden_states ], dim = 2 )
226
- hidden_states = (
226
+
227
+ residual = (
227
228
hidden_states .unflatten (4 , (- 1 , self .stride [2 ]))
228
229
.unflatten (3 , (- 1 , self .stride [1 ]))
229
230
.unflatten (2 , (- 1 , self .stride [0 ]))
230
231
)
231
- hidden_states = hidden_states .permute (0 , 1 , 3 , 5 , 7 , 2 , 4 , 6 ).flatten (1 , 4 )
232
+ residual = residual .permute (0 , 1 , 3 , 5 , 7 , 2 , 4 , 6 ).flatten (1 , 4 )
233
+ residual = residual .unflatten (1 , (- 1 , self .group_size ))
234
+ residual = residual .mean (dim = 2 )
232
235
233
- residual = hidden_states
234
- hidden_states = hidden_states .unflatten (1 , (- 1 , self .group_size ))
235
- hidden_states = hidden_states .mean (dim = 2 )
236
236
hidden_states = self .conv (hidden_states )
237
+ hidden_states = (
238
+ hidden_states .unflatten (4 , (- 1 , self .stride [2 ]))
239
+ .unflatten (3 , (- 1 , self .stride [1 ]))
240
+ .unflatten (2 , (- 1 , self .stride [0 ]))
241
+ )
242
+ hidden_states = hidden_states .permute (0 , 1 , 3 , 5 , 7 , 2 , 4 , 6 ).flatten (1 , 4 )
237
243
hidden_states = hidden_states + residual
238
244
239
245
return hidden_states
0 commit comments