@@ -2331,32 +2331,177 @@ def set_gguf_parameters(self):
 class SNACDecModel(Model):
     model_arch = gguf.MODEL_ARCH.SNAC_DEC
 
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[Tuple[str, Tensor]]:
-        del bid  # unused
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._dummy_added = False
+
+    def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) -> Iterable[Tuple[str, torch.Tensor]]:
+        """Convert nested PyTorch tensor names to a flat GGUF naming scheme for decoder tensors."""
+        del bid  # Unused
+
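+        # Placeholder for loaders that expect a token_embd tensor to be present;
+        # SNAC itself has no token embedding.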
+        # Add dummy token_embd.weight only once
+        if not self._dummy_added:
+            dummy_tok_embd = torch.zeros((4096, 8), dtype=torch.float16)
+            logger.info(f"Adding dummy tensor: token_embd.weight, shape: {list(dummy_tok_embd.shape)}")
+            yield ("token_embd.weight", dummy_tok_embd)
+            self._dummy_added = True  # Mark as added
+
+        original_name = name
+
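+        # PyTorch weight-norm parametrizations store the magnitude in
+        # `...original0` and the direction in `...original1`; they are renamed
+        # to `.scale` and `.weight` here.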
+        if name.startswith("quantizer.quantizers."):
+            match = re.match(r"quantizer\.quantizers\.(\d+)\.(codebook\.weight|out_proj\.bias|out_proj\.parametrizations\.weight\.original[0-1])", name)
+            if match:
+                q_idx = int(match.group(1))
+                tensor_type = match.group(2)
+                if tensor_type == "codebook.weight":
+                    new_name = f"quantizer.{q_idx}.codebook"
+                elif tensor_type == "out_proj.parametrizations.weight.original0":
+                    new_name = f"quantizer.{q_idx}.out_proj.scale"
+                elif tensor_type == "out_proj.parametrizations.weight.original1":
+                    new_name = f"quantizer.{q_idx}.out_proj.weight"
+                elif tensor_type == "out_proj.bias":
+                    new_name = f"quantizer.{q_idx}.out_proj.bias"
+
+                logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}")
+                yield (new_name, data_torch)
+            else:
+                logger.warning(f"Could not parse quantizer tensor from: {original_name}")
+            return
 
-        logger.debug(f"Processing tensor: {name}")
+        # Skip non-decoder tensors (except quantizers, which were handled above)
+        if not name.startswith("decoder."):
+            logger.debug(f"Skipping non-decoder tensor: {original_name}")
+            return
 
-        if (name.startswith("decoder.") or
-                re.match(r"quantizer\.quantizers\.\d+\.codebook\.weight", name) or
-                re.match(r"quantizer\.quantizers\.\d+\.out_proj\..*", name)):
-            logger.info(f"{name} -> {data_torch.shape}")
-            return [(name, data_torch)]
-        else:
-            logger.debug(f"Skipping {name!r}")
-            return []
+        base = name[8:]  # Remove the 'decoder.' prefix
+        parts = base.split(".")
+
+        if base.startswith("model.0."):
+            logger.info(f"Skipping incompatible decoder layer 0 tensor: {original_name}")
+            return  # Explicitly skip this layer
+
+        # Layer 1: Second Conv
+        if base.startswith("model.1."):
+            if "bias" in name and "parametrizations" not in name:
+                new_name = "decoder.1.conv2.bias"
+            elif "parametrizations.weight.original0" in name:
+                new_name = "decoder.1.conv2.scale"
+            elif "parametrizations.weight.original1" in name:
+                new_name = "decoder.1.conv2.weight"
+            else:
+                logger.warning(f"Unhandled layer 1 tensor: {original_name}")
+                return
+            logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}")
+            yield (new_name, data_torch)
+            return
+
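+        # Each DecoderBlock flattens to block.{idx}: 0 = Snake1d, 1 = transposed
+        # conv, 2 = noise block, 3-5 = residual units.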
+        # Layers 2-5: DecoderBlocks
+        if "model." in base and "block" in base:
+            try:
+                layer_idx = int(parts[1])  # e.g. '2' from 'model.2'
+                if layer_idx not in {2, 3, 4, 5}:
+                    logger.debug(f"Skipping non-DecoderBlock layer {layer_idx}: {original_name}")
+                    return
+                block_idx = int(parts[3])  # e.g. '1' from 'block.1'
+                new_base = f"decoder.{layer_idx}.block.{block_idx}"
+
+                if block_idx == 0:  # Snake1d
+                    if "alpha" in name:
+                        new_name = f"{new_base}.alpha"
+                    else:
+                        logger.error(f"Expected 'alpha' in {original_name}")
+                        return
+                elif block_idx == 1:  # Transpose Conv
+                    if "bias" in name and "parametrizations" not in name:
+                        new_name = f"{new_base}.trans.bias"
+                    elif "parametrizations.weight.original0" in name:
+                        new_name = f"{new_base}.trans.scale"
+                    elif "parametrizations.weight.original1" in name:
+                        new_name = f"{new_base}.trans.weight"
+                    else:
+                        logger.error(f"Unhandled tensor in block 1: {original_name}")
+                        return
+                elif block_idx == 2:  # Noise Block
+                    if "linear.parametrizations.weight.original0" in name:
+                        new_name = f"{new_base}.noise.scale"
+                    elif "linear.parametrizations.weight.original1" in name:
+                        new_name = f"{new_base}.noise.weight"
+                    else:
+                        logger.error(f"Unhandled tensor in block 2: {original_name}")
+                        return
+                elif block_idx in {3, 4, 5}:  # Residual Units
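+                    # A residual unit flattens to block.{0..3}:
+                    # Snake1d, Conv1d, Snake1d, Conv1d -> snake1/conv1/snake2/conv2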
+                    res_base = f"{new_base}.res"
+                    if "block.0.alpha" in name:
+                        new_name = f"{res_base}.snake1.alpha"
+                    elif "block.1.bias" in name:
+                        new_name = f"{res_base}.conv1.bias"
+                    elif "block.1.parametrizations.weight.original0" in name:
+                        new_name = f"{res_base}.conv1.scale"
+                    elif "block.1.parametrizations.weight.original1" in name:
+                        new_name = f"{res_base}.conv1.weight"
+                    elif "block.2.alpha" in name:
+                        new_name = f"{res_base}.snake2.alpha"
+                    elif "block.3.bias" in name:
+                        new_name = f"{res_base}.conv2.bias"
+                    elif "block.3.parametrizations.weight.original0" in name:
+                        new_name = f"{res_base}.conv2.scale"
+                    elif "block.3.parametrizations.weight.original1" in name:
+                        new_name = f"{res_base}.conv2.weight"
+                    else:
+                        logger.error(f"Unhandled tensor in residual unit: {original_name}")
+                        return
+                else:
+                    logger.error(f"Unhandled block index {block_idx} in layer {layer_idx}: {original_name}")
+                    return
+
+                logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}")
+                yield (new_name, data_torch)
+                return
+
+            except (IndexError, ValueError) as e:
+                logger.error(f"Failed to parse tensor {original_name}: {e}")
+                return
+
+        # Layer 6: Snake1d
+        if base == "model.6.alpha":
+            new_name = "decoder.6.alpha"
+            logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}")
+            yield (new_name, data_torch)
+            return
+
+        # Layer 7: Final Conv
+        if base.startswith("model.7."):
+            if "bias" in name and "parametrizations" not in name:
+                new_name = "decoder.7.conv.bias"
+            elif "parametrizations.weight.original0" in name:
+                new_name = "decoder.7.conv.scale"
+            elif "parametrizations.weight.original1" in name:
+                new_name = "decoder.7.conv.weight"
+            else:
+                logger.warning(f"Unhandled layer 7 tensor: {original_name}")
+                return
+            logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}")
+            yield (new_name, data_torch)
+            return
+
+        logger.warning(f"Tensor {original_name} not mapped to any layer")
+        return
 
     def set_vocab(self):
         self._set_vocab_none()
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        self.gguf_writer.add_vocab_size(self.hparams["codebook_size"])
-        self.gguf_writer.add_quantizer_count(len(self.hparams["vq_strides"]))
-        self.gguf_writer.add_features_length(self.hparams["codebook_dim"])
-        self.gguf_writer.add_quantizer_strides(self.hparams["vq_strides"])
-        self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"])
-        self.gguf_writer.add_decoder_upsample_rates(self.hparams["decoder_rates"])
-        self.gguf_writer.add_decoder_channel_dims(self.hparams["decoder_channel_dims"])
+        self.gguf_writer.add_vocab_size(4096)  # TODO: fix hard-coded vocab size
+        self.gguf_writer.add_uint32("snac.quantizer.codebook_size", self.hparams["codebook_size"])
+        self.gguf_writer.add_uint32("snac.quantizer.codebook_dim", self.hparams["codebook_dim"])
+        self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"])  # 1024
+        self.gguf_writer.add_decoder_upsample_rates(self.hparams["decoder_rates"])  # [8, 8, 4, 2]
+        self.gguf_writer.add_uint32("n_layers", 8)
+        self.gguf_writer.add_array("decoder_channel_dims", [768, 1024, 512, 256, 128, 64, 1])
+        self.gguf_writer.add_array("vq_strides", self.hparams["vq_strides"])
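+        # NOTE: n_layers and decoder_channel_dims are hard-coded to this model's
+        # configuration rather than derived from hparams.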
 
 
 @Model.register("Qwen2MoeForCausalLM")
 class Qwen2MoeModel(Model):