diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d21edce16b71e..093e769e338f3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2327,6 +2327,182 @@ def set_gguf_parameters(self): self.gguf_writer.add_causal_attention(False) +@Model.register("SNACDec") +class SNACDecModel(Model): + model_arch = gguf.MODEL_ARCH.SNAC_DEC + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._dummy_added = False + + def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) -> Iterable[Tuple[str, torch.Tensor]]: + """Convert nested PyTorch tensor names to a flat GGUF naming scheme for decoder tensors.""" + del bid # Unused + + # Add dummy token_embd.weight only once + if not self._dummy_added: + import torch + dummy_tok_embd = torch.zeros((4096, 8), dtype=torch.float16) + dummy_tok_embd = dummy_tok_embd.view(4096, 8) + logger.info(f"Adding dummy tensor: token_embd.weight, shape: {list(dummy_tok_embd.shape)}") + yield ("token_embd.weight", dummy_tok_embd) + self._dummy_added = True # Mark as added + + original_name = name + + if name.startswith("quantizer.quantizers."): + match = re.match(r"quantizer\.quantizers\.(\d+)\.(codebook\.weight|out_proj\.bias|out_proj\.parametrizations\.weight\.original[0-1])", name) + if match: + q_idx = int(match.group(1)) + tensor_type = match.group(2) + if tensor_type == "codebook.weight": + new_name = f"quantizer.{q_idx}.codebook" + elif tensor_type == "out_proj.parametrizations.weight.original0": + new_name = f"quantizer.{q_idx}.out_proj.scale" + elif tensor_type == "out_proj.parametrizations.weight.original1": + new_name = f"quantizer.{q_idx}.out_proj.weight" + elif tensor_type == "out_proj.bias": + new_name = f"quantizer.{q_idx}.out_proj.bias" + + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + else: + logger.warning(f"Could not parse quantizer tensor from: {original_name}") + return + + # Skip non-decoder tensors (except quantizers, which were handled above) + if not name.startswith("decoder."): + logger.debug(f"Skipping non-decoder tensor: {original_name}") + return + + base = name[8:] # Remove 'decoder.' + parts = base.split(".") + + if base.startswith("model.0."): + logger.info(f"Skipping incompatible decoder layer 0 tensor: {original_name}") + return # Explicitly skip this layer + + # Layer 1: Second Conv + if base.startswith("model.1."): + if "bias" in name and "parametrizations" not in name: + new_name = "decoder.1.conv2.bias" + elif "parametrizations.weight.original0" in name: + new_name = "decoder.1.conv2.scale" + elif "parametrizations.weight.original1" in name: + new_name = "decoder.1.conv2.weight" + else: + logger.warning(f"Unhandled layer 1 tensor: {original_name}") + return + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + return + + # Layers 2–5: DecoderBlocks + if "model." 
in base and "block" in base: + try: + layer_idx = int(parts[1]) # e.g., '2' from 'model.2' + if layer_idx not in {2, 3, 4, 5}: + logger.debug(f"Skipping non-DecoderBlock layer {layer_idx}: {original_name}") + return + block_idx = int(parts[3]) # e.g., '1' from 'block.1' + new_base = f"decoder.{layer_idx}.block.{block_idx}" + + if block_idx == 0: # Snake1d + if "alpha" in name: + new_name = f"{new_base}.alpha" + else: + logger.error(f"Expected 'alpha' in {original_name}") + return + elif block_idx == 1: # Transpose Conv + if "bias" in name and "parametrizations" not in name: + new_name = f"{new_base}.trans.bias" + elif "parametrizations.weight.original0" in name: + new_name = f"{new_base}.trans.scale" + elif "parametrizations.weight.original1" in name: + new_name = f"{new_base}.trans.weight" + else: + logger.error(f"Unhandled tensor in block 1: {original_name}") + return + elif block_idx == 2: # Noise Block + if "linear.parametrizations.weight.original0" in name: + new_name = f"{new_base}.noise.scale" + elif "linear.parametrizations.weight.original1" in name: + new_name = f"{new_base}.noise.weight" + else: + logger.error(f"Unhandled tensor in block 2: {original_name}") + return + elif block_idx in {3, 4, 5}: # Residual Units + res_base = f"{new_base}.res" + if "block.0.alpha" in name: + new_name = f"{res_base}.snake1.alpha" + elif "block.1.bias" in name: + new_name = f"{res_base}.conv1.bias" + elif "block.1.parametrizations.weight.original0" in name: + new_name = f"{res_base}.conv1.scale" + elif "block.1.parametrizations.weight.original1" in name: + new_name = f"{res_base}.conv1.weight" + elif "block.2.alpha" in name: + new_name = f"{res_base}.snake2.alpha" + elif "block.3.bias" in name: + new_name = f"{res_base}.conv2.bias" + elif "block.3.parametrizations.weight.original0" in name: + new_name = f"{res_base}.conv2.scale" + elif "block.3.parametrizations.weight.original1" in name: + new_name = f"{res_base}.conv2.weight" + else: + logger.error(f"Unhandled tensor in residual unit: {original_name}") + return + else: + logger.error(f"Unhandled block index {block_idx} in layer {layer_idx}: {original_name}") + return + + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + return + + except (IndexError, ValueError) as e: + logger.error(f"Failed to parse tensor {original_name}: {e}") + return + + # Layer 6: Snake1d + if base == "model.6.alpha": + new_name = "decoder.6.alpha" + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + return + + # Layer 7: Final Conv + if base.startswith("model.7."): + if "bias" in name and "parametrizations" not in name: + new_name = "decoder.7.conv.bias" + elif "parametrizations.weight.original0" in name: + new_name = "decoder.7.conv.scale" + elif "parametrizations.weight.original1" in name: + new_name = "decoder.7.conv.weight" + else: + logger.warning(f"Unhandled layer 7 tensor: {original_name}") + return + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + return + + logger.warning(f"Tensor {original_name} not mapped to any layer") + return + + def set_vocab(self): + self._set_vocab_none() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + self.gguf_writer.add_uint32("snac.quantizer.codebook_size", self.hparams["codebook_size"]) + self.gguf_writer.add_uint32("snac.quantizer.codebook_dim", 
self.hparams["codebook_dim"]) + self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"]) # 1024 + self.gguf_writer.add_decoder_upsample_rates(self.hparams["decoder_rates"]) # [8, 8, 4, 2] + self.gguf_writer.add_uint32("n_layers", 8) + self.gguf_writer.add_array("decoder_channel_dims", [768, 1024, 512, 256, 128, 64, 1]) + self.gguf_writer.add_array("vq_strides", self.hparams["vq_strides"]) + @Model.register("Qwen2MoeForCausalLM") class Qwen2MoeModel(Model): model_arch = gguf.MODEL_ARCH.QWEN2MOE diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt index c72bd814c3b31..42f95df7387b4 100644 --- a/examples/tts/CMakeLists.txt +++ b/examples/tts/CMakeLists.txt @@ -3,3 +3,9 @@ add_executable(${TARGET} tts.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) + +set(TARGET llama-orpheus-tts) +add_executable(${TARGET} orpheus-tts.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/tts/orpheus-tts.cpp b/examples/tts/orpheus-tts.cpp new file mode 100644 index 0000000000000..a7f0e16dfa296 --- /dev/null +++ b/examples/tts/orpheus-tts.cpp @@ -0,0 +1,223 @@ +#include "common.h" +#include "llama.h" +#include "log.h" +#include "arg.h" +#include "sampling.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +struct wav_header { + char riff[4] = {'R', 'I', 'F', 'F'}; + uint32_t chunk_size; + char wave[4] = {'W', 'A', 'V', 'E'}; + char fmt[4] = {'f', 'm', 't', ' '}; + uint32_t fmt_chunk_size = 16; + uint16_t audio_format = 1; // PCM + uint16_t num_channels = 1; // Mono + uint32_t sample_rate; + uint32_t byte_rate; + uint16_t block_align; + uint16_t bits_per_sample = 16; + char data[4] = {'d', 'a', 't', 'a'}; + uint32_t data_size; +}; + +static bool save_wav16(const std::string &fname, const std::vector &data, int sample_rate) { + std::ofstream file(fname, std::ios::binary); + if (!file) { + LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str()); + return false; + } + + wav_header header; + header.sample_rate = sample_rate; + header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8); + header.block_align = header.num_channels * (header.bits_per_sample / 8); + header.data_size = data.size() * (header.bits_per_sample / 8); + header.chunk_size = 36 + header.data_size; + + file.write(reinterpret_cast(&header), sizeof(header)); + + for (const auto &sample : data) { + int16_t pcm_sample = static_cast(std::clamp(sample * 32767.0f, -32768.0f, 32767.0f)); + file.write(reinterpret_cast(&pcm_sample), sizeof(pcm_sample)); + } + + return file.good(); +} + +std::vector redistribute_codes(const std::vector& raw_codes) { + std::vector snac_codes; + for (size_t i = 0; i < raw_codes.size(); i += 7) { + if (i + 6 >= raw_codes.size()) break; + + // Subtract 128266 base and layer-specific offsets + snac_codes.push_back(raw_codes[i] - 128266); // Layer 1: offset 0 + snac_codes.push_back(raw_codes[i + 1] - 128266 - 4096); // Layer 2: offset 4096 + snac_codes.push_back(raw_codes[i + 2] - 128266 - 8192); // Layer 3: offset 8192 + snac_codes.push_back(raw_codes[i + 3] - 128266 - 12288); // Layer 3: offset 12288 + snac_codes.push_back(raw_codes[i + 4] - 128266 - 16384); // Layer 2: offset 16384 + 
snac_codes.push_back(raw_codes[i + 5] - 128266 - 20480); // Layer 3: offset 20480 + snac_codes.push_back(raw_codes[i + 6] - 128266 - 24576); // Layer 3: offset 24576 + } + return snac_codes; +} + +static void print_usage(int /*argc*/, char **argv) { + LOG("\nexample usage:\n"); + LOG("\n %s -m model.gguf -mv vocoder.gguf -p \"Hello world\"\n", argv[0]); + LOG("\n"); +} + +static void prompt_add(std::vector &prompt, const llama_vocab *vocab, const std::string &txt, bool add_special, bool parse_special) { + auto tmp = common_tokenize(vocab, txt, add_special, parse_special); + prompt.insert(prompt.end(), tmp.begin(), tmp.end()); +} + +int main(int argc, char **argv) { + common_params params; + + params.model = "models/orpheus-3b-0.1-ft-q4_k_m.gguf"; + params.vocoder.model = "models/snac-fwd-pass-devel.gguf"; + params.out_file = "output.wav"; + + params.sampling.top_k = 4; + params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; + params.n_batch = 4096; + + common_init(); + llama_backend_init(); + llama_numa_init(params.numa); + + + common_init_result orpheus_init_ttc = common_init_from_params(params); + + llama_model * model_ttc = NULL; + llama_context * ctx_ttc = NULL; + + model_ttc = orpheus_init_ttc.model.get(); + ctx_ttc = orpheus_init_ttc.context.get(); + + const llama_vocab *vocab = llama_model_get_vocab(model_ttc); + + common_sampler *sampler = common_sampler_init(model_ttc, params.sampling); + if (!sampler) { + LOG_ERR("Failed to initialize sampler\n"); + return 1; + } + + // Construct prompt: <|startofhuman|> tara: [prompt] <|eot_id|> <|endofhuman|> + std::vector tokens; + tokens.push_back(128259); // <|startofhuman|> + prompt_add(tokens, vocab, "tara: ", false, true); // Voice prefix + prompt_add(tokens, vocab, params.prompt, false, true); // User prompt + prompt_add(tokens, vocab, "", false, true); // Emotion tag + tokens.push_back(128009); // <|eot_id|> + tokens.push_back(128260); // <|endofhuman|> + + + llama_model * model_cts = NULL; + llama_context * ctx_cts = NULL; + + params.model = params.vocoder.model; + + params.embedding = true; + params.warmup = false; // SNAC doesn't care about BOS or EOS tokens + + common_init_result snac_init_cts = common_init_from_params(params); + LOG_INF("SNAC model loaded: %s\n", params.model.c_str()); + + model_cts = snac_init_cts.model.get(); + ctx_cts = snac_init_cts.context.get(); + + // TODO: Use real orpheus codes + // Just some random numbers for testing + std::vector orpheus_codes = { + // Frame 1, 7 codes per frame + 128266 + 100, // L1: 100 + 128266 + 4096 + 200, // L2: 200 + 128266 + 8192 + 300, // L3: 300 + 128266 + 12288 + 400,// L3: 400 + 128266 + 16384 + 500,// L2: 500 + 128266 + 20480 + 600,// L3: 600 + 128266 + 24576 + 700,// L3: 700 + // Frame 2 + 128266 + 150, 128266 + 4096 + 250, 128266 + 8192 + 350, 128266 + 12288 + 450, + 128266 + 16384 + 550, 128266 + 20480 + 650, 128266 + 24576 + 750, + // Frame 3 + 128266 + 110, 128266 + 4096 + 210, 128266 + 8192 + 310, 128266 + 12288 + 410, + 128266 + 16384 + 510, 128266 + 20480 + 610, 128266 + 24576 + 710, + // Frame 4 + 128266 + 120, 128266 + 4096 + 220, 128266 + 8192 + 320, 128266 + 12288 + 420, + 128266 + 16384 + 520, 128266 + 20480 + 620, 128266 + 24576 + 720, + // Frame 5 + 128266 + 130, 128266 + 4096 + 230, 128266 + 8192 + 330, 128266 + 12288 + 430, + 128266 + 16384 + 530, 128266 + 20480 + 630, 128266 + 24576 + 730, + // Frame 6 + 128266 + 140, 128266 + 4096 + 240, 128266 + 8192 + 340, 128266 + 12288 + 440, + 128266 + 16384 + 540, 128266 + 20480 + 640, 128266 + 24576 + 740, + // 
Frame 7 + 128266 + 160, 128266 + 4096 + 260, 128266 + 8192 + 360, 128266 + 12288 + 460, + 128266 + 16384 + 560, 128266 + 20480 + 660, 128266 + 24576 + 760, + // Frame 8 + 128266 + 170, 128266 + 4096 + 270, 128266 + 8192 + 370, 128266 + 12288 + 470, + 128266 + 16384 + 570, 128266 + 20480 + 670, 128266 + 24576 + 770, + // Frame 9 + 128266 + 180, 128266 + 4096 + 280, 128266 + 8192 + 380, 128266 + 12288 + 480, + 128266 + 16384 + 580, 128266 + 20480 + 680, 128266 + 24576 + 780, + // Frame 10 + 128266 + 190, 128266 + 4096 + 290, 128266 + 8192 + 390, 128266 + 12288 + 490, + 128266 + 16384 + 590, 128266 + 20480 + 690, 128266 + 24576 + 790 + }; + + std::vector snac_codes = redistribute_codes(orpheus_codes); + + const int batch_size = snac_codes.size(); + + llama_batch batch = llama_batch_init(batch_size, 0, 1); + + for (size_t i = 0; i < batch_size; ++i) { + common_batch_add(batch, snac_codes[i], i, {0}, true); + } + + LOG_INF("Batch before decode: n_tokens = %d\n", batch.n_tokens); + GGML_ASSERT(batch.n_tokens == batch_size); + + batch.logits[batch.n_tokens - 1] = true; + + if (llama_decode(ctx_cts, batch) != 0) { + LOG_ERR("Failed to decode SNAC batch\n"); + return 1; + } + LOG_INF("SNAC decode completed\n"); + llama_synchronize(ctx_cts); + + float* embd = llama_get_embeddings(ctx_cts); + if (!embd) { + LOG_ERR("No embeddings available\n"); + return 1; + } + + int n_samples = llama_get_n_outputs(ctx_cts); + std::vector audio(n_samples); + LOG_INF("n_samples: %i\n", n_samples); + memcpy(audio.data(), embd, n_samples * sizeof(float)); + + save_wav16(params.out_file, audio, 24000); + + llama_batch_free(batch); + llama_backend_free(); + return 0; +} diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index cb3edb10d4702..42e1d2506337e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -492,6 +492,7 @@ extern "C" { GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, GGML_OP_LEAKY_RELU, + GGML_OP_SNAKE, GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_ATTN_BACK, @@ -1062,6 +1063,16 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_snake( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * alpha); + + GGML_API struct ggml_tensor * ggml_snake_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * alpha); + // normalize along rows GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 75dc96b478655..7bded06f88a94 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1911,6 +1911,21 @@ inline static void ggml_vec_leaky_relu_f16 (const int n, ggml_fp16_t * y, const y[i] = GGML_FP32_TO_FP16(((v > 0.f) ? v : 0.f) + ns * ((v < 0.0f) ? 
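// A minimal standalone sketch (illustrative only, with made-up sizes and names) of how the
// new op can be driven, assuming the CPU backend and the y = x + sin^2(alpha * x) definition
// used by the kernels below. Layout follows the decoder convention: ne[0] = samples,
// ne[1] = channels, one alpha value per channel. Note that ggml_snake() stores the alpha
// tensor pointer in op_params, so alpha must live in the same context as the graph.
#include "ggml.h"
#include "ggml-cpu.h"

static void snake_example(void) {
    struct ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * x     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8); // 64 samples, 8 channels
    struct ggml_tensor * alpha = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);     // per-channel alpha

    struct ggml_tensor * y = ggml_snake(ctx, x, alpha);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    // ... fill x and alpha with data, then:
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

    ggml_free(ctx);
}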
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 75dc96b478655..7bded06f88a94 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1911,6 +1911,21 @@ inline static void ggml_vec_leaky_relu_f16 (const int n, ggml_fp16_t * y, const
         y[i] = GGML_FP32_TO_FP16(((v > 0.f) ? v : 0.f) + ns * ((v < 0.0f) ? v : 0.f));
     }
 }
+inline static void ggml_vec_snake_f32(const int n, float * y, const float * x, const float a) {
+    for (int i = 0; i < n; ++i) {
+        float x_val = x[i];
+        float sin_val = sinf(a * x_val);
+        y[i] = x_val + sin_val * sin_val;
+    }
+}
+inline static void ggml_vec_snake_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t a) {
+    for (int i = 0; i < n; ++i) {
+        float x_val = GGML_FP16_TO_FP32(x[i]); // TODO: double check this conversion
+        float a_val = GGML_FP16_TO_FP32(a);
+        float sin_val = sinf(a_val * x_val);
+        y[i] = GGML_FP32_TO_FP16(x_val + sin_val * sin_val);
+    }
+}
 inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
 inline static void ggml_vec_sigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
     for (int i = 0; i < n; ++i) {
@@ -7817,6 +7832,86 @@ static void ggml_compute_forward_leaky_relu(
     }
 }
 
+// ggml_compute_forward_snake
+
+static void ggml_compute_forward_snake_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    // Scaffold code, 1 thread for now
+    // TODO: add multithreading
+    if (params->ith != 0) {
+        return;
+    }
+
+    struct ggml_tensor * alpha = *(struct ggml_tensor **)(dst->op_params);
+    const float * x = (const float *)src0->data;
+    const float * a = (const float *)alpha->data;
+    float * y = (float *)dst->data;
+
+    const int n        = ggml_nrows(src0);
+    const int nc       = src0->ne[0];
+    const int channels = src0->ne[1];
+
+    for (int i = 0; i < n; i++) {
+        int c = i % channels;
+        ggml_vec_snake_f32(nc,
+                (float *) ((char *) y + i * dst->nb[1]),
+                (const float *) ((const char *) x + i * src0->nb[1]),
+                a[c]); // alpha tensor for this channel
+    }
+}
+
+static void ggml_compute_forward_snake_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    struct ggml_tensor * alpha = *(struct ggml_tensor **)(dst->op_params);
+    const ggml_fp16_t * x = (const ggml_fp16_t *)src0->data;
+    const ggml_fp16_t * a = (const ggml_fp16_t *)alpha->data;
+    ggml_fp16_t * y = (ggml_fp16_t *)dst->data;
+
+    const int n        = ggml_nrows(src0);
+    const int nc       = src0->ne[0];
+    const int channels = src0->ne[1];
+
+    for (int i = 0; i < n; i++) {
+        int c = i % channels;
+        ggml_vec_snake_f16(nc,
+                (ggml_fp16_t *) ((char *) y + i * dst->nb[1]),
+                (const ggml_fp16_t *) ((const char *) x + i * src0->nb[1]),
+                a[c]);
+    }
+}
+
+static void ggml_compute_forward_snake(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_snake_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_snake_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_silu_back
 
 static void ggml_compute_forward_silu_back_f32(
@@ -14555,6 +14650,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_leaky_relu(params, tensor);
             } break;
+        case GGML_OP_SNAKE:
+            {
+                ggml_compute_forward_snake(params, tensor);
+            } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
             } break;
@@ -14795,6 +14894,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int 
n_threads) { case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_LEAKY_RELU: + case GGML_OP_SNAKE: { n_tasks = 1; } break; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2e081d5910c6e..adb8c81e7b547 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -967,6 +967,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "TIMESTEP_EMBEDDING", "ARGSORT", "LEAKY_RELU", + "SNAKE", "FLASH_ATTN_EXT", "FLASH_ATTN_BACK", @@ -998,7 +999,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85"); +static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1097,7 +1098,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "adamw(x)", }; -static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85"); +static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -2474,6 +2475,35 @@ struct ggml_tensor * ggml_leaky_relu( return result; } +// ggml snake + +struct ggml_tensor * ggml_snake( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * alpha) { + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + // store ptr to alpha tensor + ggml_set_op_params(result, &alpha, sizeof(alpha)); + result->op = GGML_OP_SNAKE; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_snake_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * alpha) { + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + ggml_set_op_params(result, &alpha, sizeof(alpha)); + result->op = GGML_OP_SNAKE; + result->src[0] = a; + + return result; +} + // ggml_sigmoid struct ggml_tensor * ggml_sigmoid( diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index cc48913d9789d..19fb2319f85cc 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -173,6 +173,11 @@ class ConvNext: EMBEDDING_LENGTH = "{arch}.convnext.embedding_length" BLOCK_COUNT = "{arch}.convnext.block_count" + class AudioCodec: + #CODEBOOK_DIM = "{arch}.audio_codec.codebook_dim" + DECODER_UPSAMPLE_RATES = "{arch}.audio_codec.decoder_upsample_rates" + DECODER_CHANNEL_DIMS = "{arch}.audio_codec.decoder_channel_dims" + class Tokenizer: MODEL = "tokenizer.ggml.model" PRE = "tokenizer.ggml.pre" @@ -286,6 +291,7 @@ class MODEL_ARCH(IntEnum): GRANITE_MOE = auto() CHAMELEON = auto() WAVTOKENIZER_DEC = auto() + SNAC_DEC = auto() class MODEL_TENSOR(IntEnum): @@ -425,6 +431,7 @@ class MODEL_TENSOR(IntEnum): POSNET_ATTN_K = auto() POSNET_ATTN_V = auto() POSNET_ATTN_OUT = auto() + UPSAMPLE_CONV = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -488,6 +495,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GRANITE_MOE: "granitemoe", MODEL_ARCH.CHAMELEON: "chameleon", MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", + MODEL_ARCH.SNAC_DEC: "snac-dec", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -627,6 +635,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k", MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v", MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output", + MODEL_TENSOR.UPSAMPLE_CONV: "upsample_conv.{bid}", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -1650,6 +1659,16 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POSNET_ATTN_V, MODEL_TENSOR.POSNET_ATTN_OUT, ], + MODEL_ARCH.SNAC_DEC: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.CONV1D, + MODEL_TENSOR.CONV1D, + MODEL_TENSOR.CONVNEXT_DW, + 
MODEL_TENSOR.CONVNEXT_GAMMA, + MODEL_TENSOR.UPSAMPLE_CONV, + MODEL_TENSOR.CONV1D, + MODEL_TENSOR.OUTPUT, + ], # TODO } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index af8b388dfaba5..f1924358d9cad 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -887,6 +887,12 @@ def add_remove_extra_whitespaces(self, value: bool) -> None: def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None: self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap) + def add_decoder_upsample_rates(self, rates: Sequence[int]) -> None: + self.add_array(Keys.AudioCodec.DECODER_UPSAMPLE_RATES.format(arch=self.arch), rates) + + def add_decoder_channel_dims(self, dims: Sequence[int]) -> None: + self.add_array(Keys.AudioCodec.DECODER_CHANNEL_DIMS.format(arch=self.arch), dims) + def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: if not isinstance(value, str): template_default = None diff --git a/include/llama.h b/include/llama.h index 6a44be404d914..f98f1910bcf1c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -629,6 +629,8 @@ extern "C" { llama_seq_id * cells_sequences; }; + LLAMA_API int32_t llama_get_n_outputs(struct llama_context * ctx); + // Create an empty KV cache view. (use only for debugging purposes) LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max); diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 9debb56cc80d5..b90214d255584 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -65,6 +65,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_SNAC_DEC, "snac-dec" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1391,6 +1392,55 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, }, }, + { + LLM_ARCH_SNAC_DEC, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_CODEBOOK, "quantizer.%d.codebook" }, + { LLM_TENSOR_CODEBOOK_PROJ_B, "quantizer.%d.out_proj.bias" }, + { LLM_TENSOR_CODEBOOK_PROJ_S, "quantizer.%d.out_proj.scale" }, + { LLM_TENSOR_CODEBOOK_PROJ_W, "quantizer.%d.out_proj.weight" }, + + { LLM_TENSOR_CONV_W2, "decoder.1.conv2.weight" }, + { LLM_TENSOR_CONV_S2, "decoder.1.conv2.scale" }, + { LLM_TENSOR_CONV_B2, "decoder.1.conv2.bias" }, + { LLM_TENSOR_BLOCK_ALPHA, "decoder.%d.block.0.alpha" }, + { LLM_TENSOR_TRANS_W, "decoder.%d.block.1.trans.weight" }, + { LLM_TENSOR_TRANS_S, "decoder.%d.block.1.trans.scale" }, + { LLM_TENSOR_TRANS_B, "decoder.%d.block.1.trans.bias" }, + { LLM_TENSOR_NOISE_W, "decoder.%d.block.2.noise.weight" }, + { LLM_TENSOR_NOISE_S, "decoder.%d.block.2.noise.scale" }, + // Residual Units + { LLM_TENSOR_RES_SNAKE1_A, "decoder.%d.block.3.res.snake1.alpha" }, + { LLM_TENSOR_RES_CONV1_W, "decoder.%d.block.3.res.conv1.weight" }, + { LLM_TENSOR_RES_CONV1_S, "decoder.%d.block.3.res.conv1.scale" }, + { LLM_TENSOR_RES_CONV1_B, "decoder.%d.block.3.res.conv1.bias" }, + { LLM_TENSOR_RES_SNAKE2_A, "decoder.%d.block.3.res.snake2.alpha" }, + { LLM_TENSOR_RES_CONV2_W, "decoder.%d.block.3.res.conv2.weight" }, + { LLM_TENSOR_RES_CONV2_S, "decoder.%d.block.3.res.conv2.scale" }, + { LLM_TENSOR_RES_CONV2_B, "decoder.%d.block.3.res.conv2.bias" }, + { LLM_TENSOR_RES_SNAKE1_A_B4, "decoder.%d.block.4.res.snake1.alpha" }, + { LLM_TENSOR_RES_CONV1_W_B4, "decoder.%d.block.4.res.conv1.weight" }, + { LLM_TENSOR_RES_CONV1_S_B4, 
"decoder.%d.block.4.res.conv1.scale" }, + { LLM_TENSOR_RES_CONV1_B_B4, "decoder.%d.block.4.res.conv1.bias" }, + { LLM_TENSOR_RES_SNAKE2_A_B4, "decoder.%d.block.4.res.snake2.alpha" }, + { LLM_TENSOR_RES_CONV2_W_B4, "decoder.%d.block.4.res.conv2.weight" }, + { LLM_TENSOR_RES_CONV2_S_B4, "decoder.%d.block.4.res.conv2.scale" }, + { LLM_TENSOR_RES_CONV2_B_B4, "decoder.%d.block.4.res.conv2.bias" }, + { LLM_TENSOR_RES_SNAKE1_A_B5, "decoder.%d.block.5.res.snake1.alpha" }, + { LLM_TENSOR_RES_CONV1_W_B5, "decoder.%d.block.5.res.conv1.weight" }, + { LLM_TENSOR_RES_CONV1_S_B5, "decoder.%d.block.5.res.conv1.scale" }, + { LLM_TENSOR_RES_CONV1_B_B5, "decoder.%d.block.5.res.conv1.bias" }, + { LLM_TENSOR_RES_SNAKE2_A_B5, "decoder.%d.block.5.res.snake2.alpha" }, + { LLM_TENSOR_RES_CONV2_W_B5, "decoder.%d.block.5.res.conv2.weight" }, + { LLM_TENSOR_RES_CONV2_S_B5, "decoder.%d.block.5.res.conv2.scale" }, + { LLM_TENSOR_RES_CONV2_B_B5, "decoder.%d.block.5.res.conv2.bias" }, + { LLM_TENSOR_ALPHA, "decoder.6.alpha" }, + { LLM_TENSOR_CONV_W7, "decoder.7.conv.weight" }, + { LLM_TENSOR_CONV_S7, "decoder.7.conv.scale" }, + { LLM_TENSOR_CONV_B7, "decoder.7.conv.bias" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1552,8 +1602,53 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + + { LLM_TENSOR_CONV_B2, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_CONV_S2, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_CONV_W2, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_BLOCK_ALPHA, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_TRANS_B, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_TRANS_S, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_TRANS_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_NOISE_S, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_NOISE_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT } }, + { LLM_TENSOR_RES_SNAKE1_A, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_B, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV1_S, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE2_A, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_B, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV2_S, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE1_A_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_B_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV1_S_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_W_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE2_A_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_B_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV2_S_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_W_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE1_A_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_B_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { 
LLM_TENSOR_RES_CONV1_S_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_W_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE2_A_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_B_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV2_S_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_W_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_ALPHA, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_CONV_B7, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_CONV_S7, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_CONV_W7, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + + { LLM_TENSOR_CODEBOOK, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS } }, + { LLM_TENSOR_CODEBOOK_PROJ_B, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_CODEBOOK_PROJ_S, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_CODEBOOK_PROJ_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, }; + + LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} std::string LLM_KV::operator()(llm_kv kv) const { @@ -1563,6 +1658,7 @@ std::string LLM_KV::operator()(llm_kv kv) const { std::string LLM_TN_IMPL::str() const { if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + fprintf(stderr, "LLM_TN_IMPL::str: tensor enum %d not found in map for arch %d\n", (int)tensor, (int)arch); return "__missing__"; } diff --git a/src/llama-arch.h b/src/llama-arch.h index a28815d8a14c7..5d649d045cc78 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -69,6 +69,7 @@ enum llm_arch { LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, LLM_ARCH_WAVTOKENIZER_DEC, + LLM_ARCH_SNAC_DEC, LLM_ARCH_UNKNOWN, }; @@ -201,6 +202,11 @@ enum llm_kv { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, LLM_KV_CONVNEXT_BLOCK_COUNT, + LLM_KV_QUANTIZER_COUNT, + LLM_KV_QUANTIZER_STRIDES, + LLM_KV_DECODER_UPSAMPLE_RATES, + LLM_KV_DECODER_CHANNEL_DIMS, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, @@ -346,6 +352,52 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_K, LLM_TENSOR_POS_NET_ATTN_V, LLM_TENSOR_POS_NET_ATTN_OUT, + + LLM_TENSOR_CONV_B1, + LLM_TENSOR_CONV_S1, + LLM_TENSOR_CONV_W1, + LLM_TENSOR_CONV_B2, + LLM_TENSOR_CONV_S2, + LLM_TENSOR_CONV_W2, + LLM_TENSOR_BLOCK_ALPHA, + LLM_TENSOR_TRANS_B, + LLM_TENSOR_TRANS_S, + LLM_TENSOR_TRANS_W, + LLM_TENSOR_NOISE_S, + LLM_TENSOR_NOISE_W, + LLM_TENSOR_RES_SNAKE1_A, + LLM_TENSOR_RES_CONV1_B, + LLM_TENSOR_RES_CONV1_S, + LLM_TENSOR_RES_CONV1_W, + LLM_TENSOR_RES_SNAKE2_A, + LLM_TENSOR_RES_CONV2_B, + LLM_TENSOR_RES_CONV2_S, + LLM_TENSOR_RES_CONV2_W, + LLM_TENSOR_RES_SNAKE1_A_B4, + LLM_TENSOR_RES_CONV1_B_B4, + LLM_TENSOR_RES_CONV1_S_B4, + LLM_TENSOR_RES_CONV1_W_B4, + LLM_TENSOR_RES_SNAKE2_A_B4, + LLM_TENSOR_RES_CONV2_B_B4, + LLM_TENSOR_RES_CONV2_S_B4, + LLM_TENSOR_RES_CONV2_W_B4, + LLM_TENSOR_RES_SNAKE1_A_B5, + LLM_TENSOR_RES_CONV1_B_B5, + LLM_TENSOR_RES_CONV1_S_B5, + LLM_TENSOR_RES_CONV1_W_B5, + LLM_TENSOR_RES_SNAKE2_A_B5, + LLM_TENSOR_RES_CONV2_B_B5, + LLM_TENSOR_RES_CONV2_S_B5, + LLM_TENSOR_RES_CONV2_W_B5, + LLM_TENSOR_ALPHA, + LLM_TENSOR_CONV_B7, + LLM_TENSOR_CONV_S7, + LLM_TENSOR_CONV_W7, + + LLM_TENSOR_CODEBOOK, + LLM_TENSOR_CODEBOOK_PROJ_B, + LLM_TENSOR_CODEBOOK_PROJ_S, + LLM_TENSOR_CODEBOOK_PROJ_W, }; enum llm_tensor_layer { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5bec63e2e79ff..2958f44e0e76d 100644 --- a/src/llama-context.cpp +++ 
b/src/llama-context.cpp @@ -851,6 +851,10 @@ float * llama_context::get_logits_ith(int32_t i) { } } +int32_t llama_context::get_n_outputs() { + return n_outputs; +} + float * llama_context::get_embeddings() { // reorder embeddings for backward compatibility output_reorder(); @@ -1344,13 +1348,13 @@ int llama_context::decode(llama_batch & inp_batch) { const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1); if (compute_status != GGML_STATUS_SUCCESS) { switch (compute_status) { - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; } } @@ -1403,10 +1407,21 @@ int llama_context::decode(llama_batch & inp_batch) { GGML_ASSERT(embd != nullptr); float * embd_out = embd + n_outputs_prev*n_embd; - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + if (model.arch == LLM_ARCH_SNAC_DEC) { + // TODO: hack, SNAC outputs audio samples, not embeddings + // Rely on n_outputs for now, but perhaps add an `n_samples_snac` to + // llama_context to avoid doing these checks + int64_t n_samples = t_embd->ne[0]; + if (n_samples > 0) { + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_samples * sizeof(float)); + n_outputs = n_samples; // Update for downstream + } + } else { + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs * n_embd * sizeof(float)); + } } } break; case LLAMA_POOLING_TYPE_MEAN: @@ -1471,8 +1486,11 @@ int llama_context::decode(llama_batch & inp_batch) { } } - // set to total number of outputs in the batch, for use in llama_get_logits_ith - n_outputs = n_outputs_all; + // TODO: Hack for now to avoid overwriting n_outputs in previous step + if (model.arch != LLM_ARCH_SNAC_DEC) { + // set to total number of outputs in the batch, for use in llama_get_logits_ith + n_outputs = n_outputs_all; + } // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); @@ -2417,6 +2435,12 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) { return ctx->get_logits_ith(i); } +int32_t llama_get_n_outputs(struct llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_n_outputs(); +} + float * llama_get_embeddings(llama_context * ctx) { ctx->synchronize(); diff --git a/src/llama-context.h b/src/llama-context.h index 04facb544cb1a..ff9ad663d1fe5 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -48,6 +48,8 @@ struct llama_context { float * get_logits(); float * get_logits_ith(int32_t i); + int32_t get_n_outputs(); + float * get_embeddings(); float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0bd40174438cc..449a94b52d642 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -556,6 +556,81 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } } +void llm_graph_input_snac::set_input(const llama_ubatch * ubatch) { + + LLAMA_LOG_INFO("Setting SNAC input for layer %d\n", ilayer); + + const int 
n_tokens = ubatch->n_tokens; + if (n_tokens % frame_size != 0) { + return; // TODO: handle gracefully + } + const int n_frames = n_tokens / frame_size; + + int64_t expected_elements = 0; + int vocab_offset = 0; + int tokens_per_frame = 0; + + switch (ilayer) { + case 0: // Layer 1 + tokens_per_frame = 1; + vocab_offset = 128266; // TODO: hparams + break; + case 1: // Layer 2 + tokens_per_frame = 2; + vocab_offset = 132362; + break; + case 2: // Layer 3 + tokens_per_frame = 4; + vocab_offset = 136458; + break; + default: + LLAMA_LOG_ERROR("%s: Invalid SNAC layer index %d encountered.\n", __func__, ilayer); + GGML_ASSERT(false && "Invalid SNAC layer index"); // Should be caught by constructor assert + return; + } + expected_elements = (int64_t)n_frames * tokens_per_frame; + + std::vector indices; + indices.reserve(expected_elements); + + const llama_token * tokens_data = ubatch->token; + + for (int i_frame = 0; i_frame < n_frames; ++i_frame) { + const int frame_start_idx = i_frame * frame_size; + const llama_token * frame_tokens = tokens_data + frame_start_idx; + + switch (ilayer) { + case 0: { // L1: token 0 + int32_t index = (int32_t)(frame_tokens[0] - vocab_offset); + + indices.push_back(index); + break; + } + case 1: { // L2: tokens 1, 4 + int32_t index1 = (int32_t)(frame_tokens[1] - vocab_offset); + int32_t index4 = (int32_t)(frame_tokens[4] - vocab_offset); + + indices.push_back(index1); + indices.push_back(index4); + break; + } + case 2: { // L3: tokens 2, 3, 5, 6 + int32_t index2 = (int32_t)(frame_tokens[2] - vocab_offset); + int32_t index3 = (int32_t)(frame_tokens[3] - vocab_offset); + int32_t index5 = (int32_t)(frame_tokens[5] - vocab_offset); + int32_t index6 = (int32_t)(frame_tokens[6] - vocab_offset); + + indices.push_back(index2); + indices.push_back(index3); + indices.push_back(index5); + indices.push_back(index6); + break; + } + } + } + ggml_backend_tensor_set(target, indices.data(), 0, ggml_nbytes(target)); +} + // // llm_graph_context // diff --git a/src/llama-graph.h b/src/llama-graph.h index bdf19ed015e35..cb385c0400474 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -268,6 +268,20 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { const llama_cross * cross = nullptr; }; +class llm_graph_input_snac : public llm_graph_input_i { +public: + llm_graph_input_snac(ggml_tensor * target, int ilayer, + const llama_hparams & hparams) : target(target), ilayer(ilayer), hparams(hparams) {} + virtual ~llm_graph_input_snac() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * target; // idx tensor 1, 2, or 3 + const llama_hparams & hparams; + const int ilayer; + const int frame_size = 7; +}; + // // llm_graph_result // diff --git a/src/llama-hparams.h b/src/llama-hparams.h index bb17ba86dc2fb..ed0f9ab98d8e0 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -70,6 +70,10 @@ struct llama_hparams { float f_attn_logit_softcapping = 50.0f; float f_final_logit_softcapping = 30.0f; + // for SNAC vocoder + std::array upsample_rates; + std::array n_channels; + // for RWKV uint32_t rescale_every_n_layers = 0; uint32_t time_mix_extra_dim = 0; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 05d58ad90eba9..4730b8c6a5d22 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -745,6 +745,8 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri } } if (!is_ok) { + fprintf(stderr, "check_tensor_dims: name=%s, expected=%s, got=%s\n", + name.c_str(), 
llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(cur).c_str()); throw std::runtime_error( format("%s: tensor '%s' has wrong shape; expected %s, got %s", __func__, name.c_str(), diff --git a/src/llama-model.cpp b/src/llama-model.cpp index cd7e0a0c4dbf8..a82cfa8aaa337 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1317,6 +1317,22 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); } break; + case LLM_ARCH_SNAC_DEC: + { + // TODO: Read from GGUF + hparams.n_channels = {768, 1024, 512, 256, 128, 64, 1}; + hparams.upsample_rates = {8, 8, 4, 2}; + hparams.n_embd = 768; + hparams.n_layer = 8; + + // Dummy KV cache params to satisfy init error + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.n_head_arr[i] = 1; + hparams.n_head_kv_arr[i] = 1; + } + hparams.n_embd_head_k = 1; + hparams.n_embd_head_v = 1; + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -3686,7 +3702,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.gamma = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0); } - // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); } @@ -3694,6 +3709,136 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); } break; + case LLM_ARCH_SNAC_DEC: + { + // TODO: Magic numbers everwhere + const int64_t n_total_layers = hparams.n_layer; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {8, 4096, 1}, 0); + + // Quantizer projection tensors (0, 1, 2) + for (int qid = 0; qid < 3; ++qid) { + fprintf(stderr, "%s: Loading quantizer %d tensors\n", __func__, qid); + // Bias: [768, 1, 1, 1] -> {768} + codebook_proj_b[qid] = create_tensor(tn(LLM_TENSOR_CODEBOOK_PROJ_B, qid, -1), {768, 1, 1}, 0); + // Scale: [1, 1, 768, 1] -> {768} + codebook_proj_s[qid] = create_tensor(tn(LLM_TENSOR_CODEBOOK_PROJ_S, qid, -1), {1, 1, 768}, 0); + // Weight: [1, 8, 768, 1] -> {8, 768} + codebook_proj_w[qid] = create_tensor(tn(LLM_TENSOR_CODEBOOK_PROJ_W, qid, -1), {1, 8, 768}, 0); + + codebook[qid] = create_tensor(tn(LLM_TENSOR_CODEBOOK, qid, -1), {8, 4096, 1, 1}, 0); + } + + // Decoder tensors + for (int i = 1; i < n_total_layers; ++i) { // Loop from i = 0 to 7 + auto & layer = layers[i]; + + // Calculate n_in and n_out for the current layer i + const int64_t n_in = (i == 0) ? 1 : ((i == 7) ? hparams.n_channels[i-2] /* 64 */ : hparams.n_channels[i-1]); + const int64_t n_out = (i == 7) ? 
hparams.n_channels[i-1] /* 1 */ : hparams.n_channels[i]; + + fprintf(stderr, "%s: Layer %d: Starting (n_in=%lld, n_out=%lld)\n", __func__, i, n_in, n_out); + + if (i == 1) { // --- Layer 1: Conv2 --- + layer.conv_w = create_tensor(tn(LLM_TENSOR_CONV_W2, i, -1), {1, n_in, n_out}, 0); + layer.conv_s = create_tensor(tn(LLM_TENSOR_CONV_S2, i, -1), {1, 1, n_out}, 0); + layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV_B2, i, -1), {n_out}, 0); + } + else if (i >= 2 && i <= 5) { // --- Layers 2-5: Blocks --- + const int n_blocks = 6; + layer.decoder_blocks.resize(n_blocks); + + for (int bid = 0; bid < n_blocks; ++bid) { + LLAMA_LOG_DEBUG("%s: Layer %d, Block %d: Starting\n", __func__, i, bid); + + switch (bid) { + case 0: // Block 0: Alpha + layer.decoder_blocks[bid].alpha = create_tensor(tn(LLM_TENSOR_BLOCK_ALPHA, i, bid), {1, n_in, 1}, 0); + break; + case 1: // Block 1: Transition + { + int64_t trans_dim; + if (i == 2) trans_dim = 16; + else if (i == 3) trans_dim = 16; + else if (i == 4) trans_dim = 8; + else trans_dim = 4; // Assumed for i == 5 + LLAMA_LOG_DEBUG("%s: Layer %d, Block %d: Using trans_dim = %lld\n", __func__, i, bid, trans_dim); + layer.decoder_blocks[bid].up_weight = create_tensor(tn(LLM_TENSOR_TRANS_W, i, bid), {trans_dim, n_out, n_in}, 0); + layer.decoder_blocks[bid].up_scale = create_tensor(tn(LLM_TENSOR_TRANS_S, i, bid), {1, 1, n_in}, 0); + layer.decoder_blocks[bid].up_bias = create_tensor(tn(LLM_TENSOR_TRANS_B, i, bid), {n_out}, 0); + if (!layer.decoder_blocks[bid].up_bias) { + LLAMA_LOG_DEBUG("Failed to create decoder.%d.block.%d.trans.bias\n", i, bid); + } + } + break; + case 2: + { + LLAMA_LOG_DEBUG("%s: Layer %d, Block %d: Loading noise tensors\n", __func__, i, bid); + layer.decoder_blocks[bid].noise_w = create_tensor(tn(LLM_TENSOR_NOISE_W, i, bid), {1, n_out, n_out}, 0); + layer.decoder_blocks[bid].noise_s = create_tensor(tn(LLM_TENSOR_NOISE_S, i, bid), {1, 1, n_out}, 0); + } + break; + case 3: // Block 3: Residual Unit 1 + { + auto & ru = layer.decoder_blocks[bid].res_unit; + ru.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A, i, bid), {1, n_out, 1}, 0); + ru.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W, i, bid), {7, 1, n_out}, 0); + ru.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S, i, bid), {1, 1, n_out}, 0); + ru.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B, i, bid), {n_out}, 0); + ru.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A, i, bid), {1, n_out, 1}, 0); + ru.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W, i, bid), {1, n_out, n_out}, 0); + ru.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S, i, bid), {1, 1, n_out}, 0); + ru.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B, i, bid), {n_out}, 0); + } + break; + case 4: // Block 4: Residual Unit 2 + { + auto & ru = layer.decoder_blocks[bid].res_unit; + ru.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B4, i, bid), {1, n_out, 1}, 0); + ru.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B4, i, bid), {7, 1, n_out}, 0); + ru.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B4, i, bid), {1, 1, n_out}, 0); + ru.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B4, i, bid), {n_out}, 0); + ru.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B4, i, bid), {1, n_out, 1}, 0); + ru.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B4, i, bid), {1, n_out, n_out}, 0); + ru.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B4, i, bid), {1, 1, n_out}, 0); + ru.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B4, i, bid), {n_out}, 0); + } + break; + case 5: // Block 5: Residual Unit 3 + { + auto & ru = 
layer.decoder_blocks[bid].res_unit; + ru.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B5, i, bid), {1, n_out, 1}, 0); + ru.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B5, i, bid), {7, 1, n_out}, 0); + ru.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B5, i, bid), {1, 1, n_out}, 0); + ru.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B5, i, bid), {n_out}, 0); + ru.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B5, i, bid), {1, n_out, 1}, 0); + ru.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B5, i, bid), {1, n_out, n_out}, 0); + ru.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B5, i, bid), {1, 1, n_out}, 0); + ru.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B5, i, bid), {n_out}, 0); + } + break; + default: + fprintf(stderr, "%s: ERROR: Unexpected block id %d in layer %d\n", __func__, bid, i); + return false; + } + fprintf(stderr, "%s: Layer %d, Block %d: Finished\n", __func__, i, bid); + } + } + else if (i == 6) { // --- Layer 6: Alpha --- + layer.alpha = create_tensor(tn(LLM_TENSOR_ALPHA, i, -1), {1, n_in, 1}, 0); + } + else if (i == 7) { // --- Layer 7: Final Conv --- + layer.conv_w = create_tensor(tn(LLM_TENSOR_CONV_W7, i, -1), {7, n_in, n_out}, 0); + layer.conv_s = create_tensor(tn(LLM_TENSOR_CONV_S7, i, -1), {1, 1, n_out}, 0); + layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV_B7, i, -1), {n_out}, 0); + } + else { + fprintf(stderr, "%s: ERROR: Unexpected layer index %d\n", __func__, i); + return false; + } + fprintf(stderr, "%s: Layer %d: Finished\n", __func__, i); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -11597,6 +11742,196 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context { } }; +struct llm_build_snac_dec : public llm_graph_context { + + llm_build_snac_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * emb_layer_1, * emb_layer_2, * emb_layer_3; + + bool inputs = build_snac_inputs(model, &emb_layer_1, &emb_layer_2, &emb_layer_3); + + if (!inputs) { + // graph build is called with garbage ubatch codes during model init + // in this case, bypass normal graph construction and return a dummy + LLAMA_LOG_INFO("build_codebook_inputs returned null, using dummy tensor\n"); + cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, ubatch.n_tokens > 0 ? 
ubatch.n_tokens : 64, 1, 1); + ggml_set_input(cur); + } else { + // TODO: Upsampling is wrong + // Projections + cur = ggml_mul_mat(ctx0, ggml_reshape_2d(ctx0, model.codebook_proj_w[0], 8, 768), emb_layer_1); + cur = ggml_reshape_4d(ctx0, cur, 768, emb_layer_1->ne[1], 1, 1); + ggml_tensor * scale_1 = ggml_reshape_4d(ctx0, model.codebook_proj_s[0], 768, 1, 1, 1); + cur = ggml_mul(ctx0, cur, scale_1); + ggml_tensor * bias_1 = ggml_reshape_4d(ctx0, model.codebook_proj_b[0], 768, 1, 1, 1); // Fix here + cur = ggml_add(ctx0, cur, bias_1); + + ggml_tensor * proj_2 = ggml_mul_mat(ctx0, ggml_reshape_2d(ctx0, model.codebook_proj_w[1], 8, 768), emb_layer_2); + proj_2 = ggml_reshape_4d(ctx0, proj_2, 768, emb_layer_2->ne[1], 1, 1); + ggml_tensor * scale_2 = ggml_reshape_4d(ctx0, model.codebook_proj_s[1], 768, 1, 1, 1); + proj_2 = ggml_mul(ctx0, proj_2, scale_2); + ggml_tensor * bias_2 = ggml_reshape_4d(ctx0, model.codebook_proj_b[1], 768, 1, 1, 1); + proj_2 = ggml_add(ctx0, proj_2, bias_2); + + ggml_tensor * proj_3 = ggml_mul_mat(ctx0, ggml_reshape_2d(ctx0, model.codebook_proj_w[2], 8, 768), emb_layer_3); + proj_3 = ggml_reshape_4d(ctx0, proj_3, 768, emb_layer_3->ne[1], 1, 1); + ggml_tensor * scale_3 = ggml_reshape_4d(ctx0, model.codebook_proj_s[2], 768, 1, 1, 1); + proj_3 = ggml_mul(ctx0, proj_3, scale_3); + ggml_tensor * bias_3 = ggml_reshape_4d(ctx0, model.codebook_proj_b[2], 768, 1, 1, 1); + proj_3 = ggml_add(ctx0, proj_3, bias_3); + + cur = ggml_concat(ctx0, cur, proj_2, 1); + cur = ggml_concat(ctx0, cur, proj_3, 1); + + for (int j = 1; j <= hparams.n_layer; ++j) { + const auto & layer = model.layers[j]; + const int64_t n_in = hparams.n_channels[j-1]; + const int64_t n_out = (j < 7) ? hparams.n_channels[j] : hparams.n_channels[j-1]; + + if (j == 1) { + int64_t seq_len = cur->ne[1]; + cur = ggml_reshape_2d(ctx0, cur, 768, seq_len); // cur starts F32 (type 0) from projections + ggml_tensor * w = ggml_reshape_2d(ctx0, layer.conv_w, 768, 1024); // F16 (type 1) + ggml_tensor * s = ggml_cpy(ctx0, layer.conv_s, ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, 1, n_out)); // Cast F32 -> F16 + w = ggml_mul(ctx0, w, s); + cur = ggml_mul_mat(ctx0, w, cur); + cur = ggml_reshape_4d(ctx0, cur, seq_len, 1024, 1, 1); + ggml_tensor * b = ggml_reshape_4d(ctx0, layer.conv_b, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b); + } + // Residual Units + else if (j >= 2 && j <= 5) { + ggml_tensor * alpha = layer.decoder_blocks[0].alpha; + cur = ggml_snake(ctx0, cur, alpha); + + ggml_tensor * w = layer.decoder_blocks[1].up_weight; + ggml_tensor * s = ggml_cpy(ctx0, layer.decoder_blocks[1].up_scale, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_in, 1)); + w = ggml_mul(ctx0, w, s); + cur = ggml_conv_transpose_1d(ctx0, w, cur, hparams.upsample_rates[j-2], 0, 1); + ggml_tensor * b = ggml_reshape_4d(ctx0, layer.decoder_blocks[1].up_bias, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b); + + ggml_tensor * noise_w = layer.decoder_blocks[2].noise_w; + ggml_tensor * noise_s = ggml_cpy(ctx0, layer.decoder_blocks[2].noise_s, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_out, 1)); + noise_w = ggml_mul(ctx0, noise_w, noise_s); + cur = ggml_conv_1d(ctx0, noise_w, cur, 1, 0, 1); + + for (int r = 0; r < 3; ++r) { + int bid = 3 + r; + ggml_tensor * w1 = layer.decoder_blocks[bid].res_unit.conv1_w; + ggml_tensor * s1 = ggml_cpy(ctx0, layer.decoder_blocks[bid].res_unit.conv1_s, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_out, 1)); + w1 = ggml_mul(ctx0, w1, s1); + cur = ggml_conv_1d_dw(ctx0, w1, cur, 1, 3, 1); + ggml_tensor * b1 = 
ggml_reshape_4d(ctx0, layer.decoder_blocks[bid].res_unit.conv1_b, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b1); + + ggml_tensor * w2 = layer.decoder_blocks[bid].res_unit.conv2_w; + ggml_tensor * s2 = ggml_cpy(ctx0, layer.decoder_blocks[bid].res_unit.conv2_s, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_out, 1)); + w2 = ggml_mul(ctx0, w2, s2); + cur = ggml_conv_1d(ctx0, w2, cur, 1, 0, 1); + ggml_tensor * b2 = ggml_reshape_4d(ctx0, layer.decoder_blocks[bid].res_unit.conv2_b, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b2); + } + } + else if (j == 6) { + ggml_tensor * alpha = layer.alpha; + cur = ggml_snake(ctx0, cur, alpha); + } + else if (j == 7) { + ggml_tensor * w = layer.conv_w; + ggml_tensor * s = layer.conv_s; + + s = ggml_reshape_4d(ctx0, s, 1, 1, 1, 1); + s = ggml_cpy(ctx0, s, ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, 1, 1)); + w = ggml_mul(ctx0, w, s); + cur = ggml_conv_1d(ctx0, w, cur, 1, 3, 1); + + ggml_tensor * b = ggml_reshape_4d(ctx0, layer.conv_b, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, b); + } + } + + } + + cur = ggml_cpy(ctx0, cur, ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3])); + + LLAMA_LOG_INFO("Final shape of cur = [%ld, %ld, %ld, %ld]\n", + cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + + //cb(cur, "result_embd", -1); + res->t_embd = cur; + ggml_build_forward_expand(gf, cur); + } +private: + // Create 3 input nodes used for lookups into 3 embd matrices + bool build_snac_inputs(const llama_model & model, + ggml_tensor ** emb_layer_1_out, + ggml_tensor ** emb_layer_2_out, + ggml_tensor ** emb_layer_3_out) { + + *emb_layer_1_out = nullptr; + *emb_layer_2_out = nullptr; + *emb_layer_3_out = nullptr; + + if (this->ubatch.n_tokens <= 0 || this->ubatch.n_tokens % 7 != 0) { + LLAMA_LOG_WARN("%s: Invalid ubatch size n_tokens=%d provided for SNAC graph definition. 
Cannot define input nodes.\n", + __func__, this->ubatch.n_tokens); + return false; + } + + const int32_t n_tokens = this->ubatch.n_tokens; + const int32_t n_frames = n_tokens / 7; + + const int32_t n_indices_l1 = n_frames * 1; + const int32_t n_indices_l2 = n_frames * 2; + const int32_t n_indices_l3 = n_frames * 4; + + ggml_tensor * idx1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_indices_l1); + ggml_tensor * idx2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_indices_l2); + ggml_tensor * idx3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_indices_l3); + + if (!idx1 || !idx2 || !idx3) { + LLAMA_LOG_ERROR("%s: Failed to allocate ggml index tensors.\n", __func__); + return false; + } + + ggml_set_name(idx1, "snac_indices_L1"); + ggml_set_name(idx2, "snac_indices_L2"); + ggml_set_name(idx3, "snac_indices_L3"); + + // mark as inputs + ggml_set_input(idx1); + ggml_set_input(idx2); + ggml_set_input(idx3); + + // add to res for future access via llama_context + res->add_input(std::make_unique(idx1, 0, this->hparams)); + res->add_input(std::make_unique(idx2, 1, this->hparams)); + res->add_input(std::make_unique(idx3, 2, this->hparams)); + + // lookup + *emb_layer_1_out = ggml_get_rows(ctx0, model.codebook[0], idx1); + *emb_layer_2_out = ggml_get_rows(ctx0, model.codebook[1], idx2); + *emb_layer_3_out = ggml_get_rows(ctx0, model.codebook[2], idx3); + + if (!*emb_layer_1_out || !*emb_layer_2_out || !*emb_layer_3_out) { + LLAMA_LOG_ERROR("%s: Failed to create ggml_get_rows nodes.\n", __func__); + *emb_layer_1_out = *emb_layer_2_out = *emb_layer_3_out = nullptr; // Ensure outputs are null on failure + return false; + } + + ggml_set_name(*emb_layer_1_out, "snac_embd_L1"); + ggml_set_name(*emb_layer_2_out, "snac_embd_L2"); + ggml_set_name(*emb_layer_3_out, "snac_embd_L3"); + + return true; + } +}; + llama_memory_i * llama_model::create_memory() const { llama_memory_i * res; @@ -11868,6 +12203,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_SNAC_DEC: + { + llm = std::make_unique(*this, params, gf); + } break; default: GGML_ABORT("fatal error"); } @@ -11976,6 +12315,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_SNAC_DEC: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values diff --git a/src/llama-model.h b/src/llama-model.h index a9da1215abbfd..0524d1b82705a 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -137,6 +137,29 @@ struct llama_layer_convnext { struct ggml_tensor * gamma = nullptr; }; +struct llama_layer_snac_dec_block { + + struct ggml_tensor * alpha = nullptr; + + struct ggml_tensor * up_weight = nullptr; + struct ggml_tensor * up_scale = nullptr; + struct ggml_tensor * up_bias = nullptr; + + struct ggml_tensor * noise_w = nullptr; + struct ggml_tensor * noise_s = nullptr; + + struct { + struct ggml_tensor * alpha1 = nullptr; + struct ggml_tensor * conv1_w = nullptr; + struct ggml_tensor * conv1_s = nullptr; + struct ggml_tensor * conv1_b = nullptr; + struct ggml_tensor * alpha2 = nullptr; + struct ggml_tensor * conv2_w = nullptr; + struct ggml_tensor * conv2_s = nullptr; + struct ggml_tensor * conv2_b = nullptr; + } res_unit; +}; + struct llama_layer { // normalization struct ggml_tensor * attn_norm = nullptr; @@ -304,6 +327,13 @@ struct llama_layer { struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; + + struct ggml_tensor * 
conv_w = nullptr;
+    struct ggml_tensor * conv_s = nullptr;
+    struct ggml_tensor * conv_b = nullptr;
+    struct ggml_tensor * alpha  = nullptr;
+
+    std::vector<llama_layer_snac_dec_block> decoder_blocks;
 };
 
 struct llama_model {
@@ -336,6 +366,13 @@ struct llama_model {
     struct ggml_tensor * conv1d   = nullptr;
     struct ggml_tensor * conv1d_b = nullptr;
+
+    // TODO: structify
+    ggml_tensor * codebook[3];
+    ggml_tensor * codebook_proj_b[3]; // Array for quantizer 0, 1, 2 bias
+    ggml_tensor * codebook_proj_s[3]; // Array for quantizer 0, 1, 2 scale
+    ggml_tensor * codebook_proj_w[3];
+
     std::vector<llama_layer> layers;
 
     llama_model_params params;