@@ -1276,3 +1276,74 @@ def remap_single_transformer_blocks_(key, state_dict):
        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

    return converted_state_dict
+
+
+def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
+    # Remove "diffusion_model." prefix from keys.
+    state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
+    converted_state_dict = {}
+
+    def get_num_layers(keys, pattern):
+        layers = set()
+        for key in keys:
+            match = re.search(pattern, key)
+            if match:
+                layers.add(int(match.group(1)))
+        return len(layers)
+
+    def process_block(prefix, index, convert_norm):
+        # Process attention qkv: pop lora_A and lora_B weights.
+        lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
+        lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
+        for attn_key in ["to_q", "to_k", "to_v"]:
+            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
+        for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
+            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight
+
+        # Process attention out weights.
+        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
+            f"{prefix}.{index}.attention.out.lora_A.weight"
+        )
+        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
+            f"{prefix}.{index}.attention.out.lora_B.weight"
+        )
+
+        # Process feed-forward weights for layers 1, 2, and 3.
+        for layer in range(1, 4):
+            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
+                f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
+            )
+            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
+                f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
+            )
+
+        if convert_norm:
+            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
+                f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
+            )
+            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
+                f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
+            )
+
+    noise_refiner_pattern = r"noise_refiner\.(\d+)\."
+    num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
+    for i in range(num_noise_refiner_layers):
+        process_block("noise_refiner", i, convert_norm=True)
+
+    context_refiner_pattern = r"context_refiner\.(\d+)\."
+    num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
+    for i in range(num_context_refiner_layers):
+        process_block("context_refiner", i, convert_norm=False)
+
+    core_transformer_pattern = r"layers\.(\d+)\."
+    num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
+    for i in range(num_core_transformer_layers):
+        process_block("layers", i, convert_norm=True)
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
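
For context, a minimal sketch of how the new converter could be exercised on its own. The module path and checkpoint filename below are assumptions (the diff does not show the file it lands in, and the loader integration is not shown here); only _convert_non_diffusers_lumina2_lora_to_diffusers itself comes from this change.

# Sketch only: module path and file name are assumptions, not part of this diff.
from safetensors.torch import load_file

from diffusers.loaders.lora_conversion_utils import (  # assumed location of this helper
    _convert_non_diffusers_lumina2_lora_to_diffusers,
)

# A Lumina2 LoRA checkpoint in the original layout: keys prefixed with
# "diffusion_model." and using attention.qkv / feed_forward.w{1,2,3} naming.
raw_state_dict = load_file("lumina2_lora.safetensors")  # hypothetical path

converted = _convert_non_diffusers_lumina2_lora_to_diffusers(raw_state_dict)

# Keys are now in diffusers naming and prefixed with "transformer.",
# e.g. "transformer.layers.0.attn.to_q.lora_A.weight".
print(sorted(converted)[:5])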