From 9906bd971677482f09f03a29b8db1eddb07f18b8 Mon Sep 17 00:00:00 2001 From: Ja Morphy Date: Thu, 20 Mar 2025 21:02:35 -0700 Subject: [PATCH 1/7] (draft) tts: Orpheus support Working on each part incrementally, added a rough draft for SNAC convertion to .gguf --- convert_hf_to_gguf.py | 32 ++++++++++++++++++++++++++++++++ gguf-py/gguf/constants.py | 21 +++++++++++++++++++++ gguf-py/gguf/gguf_writer.py | 12 ++++++++++++ 3 files changed, 65 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index d21edce16b71e..00f2a49128aa7 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2327,6 +2327,38 @@ def set_gguf_parameters(self): self.gguf_writer.add_causal_attention(False) +@Model.register("SNACDec") +class SNACDecModel(Model): + model_arch = gguf.MODEL_ARCH.SNAC_DEC # Assumes this constant is defined in gguf + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[Tuple[str, Tensor]]: + del bid # unused + + logger.debug(f"Processing tensor: {name}") + + if (name.startswith("decoder.") or + re.match(r"quantizer\.quantizers\.\d+\.codebook\.weight", name) or + re.match(r"quantizer\.quantizers\.\d+\.out_proj\..*", name)): + logger.info(f"{name} -> {data_torch.shape}") + return [(name, data_torch)] + else: + logger.debug(f"Skipping {name!r}") + return [] + + def set_vocab(self): + self._set_vocab_none() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_vocab_size(self.hparams["codebook_size"]) + self.gguf_writer.add_quantizer_count(len(self.hparams["vq_strides"])) + self.gguf_writer.add_features_length(self.hparams["codebook_dim"]) + self.gguf_writer.add_quantizer_strides(self.hparams["vq_strides"]) + self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"]) + self.gguf_writer.add_decoder_upsample_rates(self.hparams["decoder_rates"]) + self.gguf_writer.add_decoder_channel_dims(self.hparams["decoder_channel_dims"]) + self.gguf_writer.add_convnext_block_count(1) + @Model.register("Qwen2MoeForCausalLM") class Qwen2MoeModel(Model): model_arch = gguf.MODEL_ARCH.QWEN2MOE diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index cc48913d9789d..15e86da8960f4 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -173,6 +173,13 @@ class ConvNext: EMBEDDING_LENGTH = "{arch}.convnext.embedding_length" BLOCK_COUNT = "{arch}.convnext.block_count" + class AudioCodec: + QUANTIZER_COUNT = "{arch}.audio_codec.quantizer_count" + CODEBOOK_DIM = "{arch}.audio_codec.codebook_dim" + QUANTIZER_STRIDES = "{arch}.audio_codec.quantizer_strides" + DECODER_UPSAMPLE_RATES = "{arch}.audio_codec.decoder_upsample_rates" + DECODER_CHANNEL_DIMS = "{arch}.audio_codec.decoder_channel_dims" + class Tokenizer: MODEL = "tokenizer.ggml.model" PRE = "tokenizer.ggml.pre" @@ -286,6 +293,7 @@ class MODEL_ARCH(IntEnum): GRANITE_MOE = auto() CHAMELEON = auto() WAVTOKENIZER_DEC = auto() + SNAC_DEC = auto() class MODEL_TENSOR(IntEnum): @@ -425,6 +433,7 @@ class MODEL_TENSOR(IntEnum): POSNET_ATTN_K = auto() POSNET_ATTN_V = auto() POSNET_ATTN_OUT = auto() + UPSAMPLE_CONV = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -488,6 +497,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GRANITE_MOE: "granitemoe", MODEL_ARCH.CHAMELEON: "chameleon", MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", + MODEL_ARCH.SNAC_DEC: "snac-dec", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -627,6 +637,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k", MODEL_TENSOR.POSNET_ATTN_V: 
"posnet.{bid}.attn_v", MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output", + MODEL_TENSOR.UPSAMPLE_CONV: "upsample_conv.{bid}", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -1650,6 +1661,16 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POSNET_ATTN_V, MODEL_TENSOR.POSNET_ATTN_OUT, ], + MODEL_ARCH.SNAC_DEC: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.CONV1D, + MODEL_TENSOR.CONV1D, + MODEL_TENSOR.CONVNEXT_DW, + MODEL_TENSOR.CONVNEXT_GAMMA, + MODEL_TENSOR.UPSAMPLE_CONV, + MODEL_TENSOR.CONV1D, + MODEL_TENSOR.OUTPUT, + ], # TODO } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index af8b388dfaba5..9cf3a55702a5a 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -887,6 +887,18 @@ def add_remove_extra_whitespaces(self, value: bool) -> None: def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None: self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap) + def add_quantizer_count(self, count: int) -> None: + self.add_uint32(Keys.AudioCodec.QUANTIZER_COUNT.format(arch=self.arch), count) + + def add_quantizer_strides(self, strides: Sequence[int]) -> None: + self.add_array(Keys.AudioCodec.QUANTIZER_STRIDES.format(arch=self.arch), strides) + + def add_decoder_upsample_rates(self, rates: Sequence[int]) -> None: + self.add_array(Keys.AudioCodec.DECODER_UPSAMPLE_RATES.format(arch=self.arch), rates) + + def add_decoder_channel_dims(self, dims: Sequence[int]) -> None: + self.add_array(Keys.AudioCodec.DECODER_CHANNEL_DIMS.format(arch=self.arch), dims) + def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: if not isinstance(value, str): template_default = None From ad7d7ff05fe601545ee3c313f7af20dee897a707 Mon Sep 17 00:00:00 2001 From: Ja Morphy Date: Fri, 21 Mar 2025 20:48:21 -0700 Subject: [PATCH 2/7] Scaffolding for snake activation fn SNAC uses the snake activation function. Added scaffolding to include `GGML_OP_SNAKE` as a new op. Should this be a unary op? The SNAC decoder uses noise blocks to enhance outputs, its optional, so omitting it for now until the model is integrated e2e. 
Next steps: write the `llm_graph_context` for SNAC --- convert_hf_to_gguf.py | 3 +- ggml/include/ggml.h | 11 ++++ ggml/src/ggml-cpu/ggml-cpu.c | 99 ++++++++++++++++++++++++++++++++++++ ggml/src/ggml.c | 34 ++++++++++++- 4 files changed, 143 insertions(+), 4 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 00f2a49128aa7..7984834188c5e 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2329,7 +2329,7 @@ def set_gguf_parameters(self): @Model.register("SNACDec") class SNACDecModel(Model): - model_arch = gguf.MODEL_ARCH.SNAC_DEC # Assumes this constant is defined in gguf + model_arch = gguf.MODEL_ARCH.SNAC_DEC def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[Tuple[str, Tensor]]: del bid # unused @@ -2357,7 +2357,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"]) self.gguf_writer.add_decoder_upsample_rates(self.hparams["decoder_rates"]) self.gguf_writer.add_decoder_channel_dims(self.hparams["decoder_channel_dims"]) - self.gguf_writer.add_convnext_block_count(1) @Model.register("Qwen2MoeForCausalLM") class Qwen2MoeModel(Model): diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index cb3edb10d4702..42e1d2506337e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -492,6 +492,7 @@ extern "C" { GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, GGML_OP_LEAKY_RELU, + GGML_OP_SNAKE, GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_ATTN_BACK, @@ -1062,6 +1063,16 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_snake( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * alpha); + + GGML_API struct ggml_tensor * ggml_snake_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * alpha); + // normalize along rows GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 75dc96b478655..def6eb3423c61 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1911,6 +1911,21 @@ inline static void ggml_vec_leaky_relu_f16 (const int n, ggml_fp16_t * y, const y[i] = GGML_FP32_TO_FP16(((v > 0.f) ? v : 0.f) + ns * ((v < 0.0f) ? 
v : 0.f)); } } +inline static void ggml_vec_snake_f32(const int n, float * y, const float * x, const float a) { + for (int i = 0; i < n; ++i) { + float x_val = x[i]; + float sin_val = sinf(a * x_val); + y[i] = x_val + sin_val * sin_val; + } +} +inline static void ggml_vec_snake_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t a) { + for (int i = 0; i < n; ++i) { + float x_val = GGML_FP16_TO_FP32(x[i]); // TODO: double check this conversion + float a_val = GGML_FP16_TO_FP32(a); + float sin_val = sinf(a_val * x_val); + y[i] = GGML_FP32_TO_FP16(x_val + sin_val * sin_val); + } +} inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); } inline static void ggml_vec_sigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) { @@ -7817,6 +7832,86 @@ static void ggml_compute_forward_leaky_relu( } } +// ggml_compute_forward_snake + +static void ggml_compute_forward_snake_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; + + // Scaffold code, 1 thread for now + // TODO: add multithreading + if (params->ith != 0) { + return; + } + + struct ggml_tensor * alpha = *(struct ggml_tensor **)(dst->op_params); + const float * x = (const float *)src0->data; + const float * a = (const float *)alpha->data; + float * y = (float *)dst->data; + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + const int channels = src0->ne[1]; + + for (int i = 0; i < n; i++) { + int c = i % channels; + ggml_vec_snake_f32(nc, + (float *) ((char *) y + i * dst->nb[1]), + (const float *) ((const char *) x + i * src0->nb[1]), + a[c]); // alpha tensor for this channel + } +} + +static void ggml_compute_forward_snake_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + struct ggml_tensor * alpha = *(struct ggml_tensor **)(dst->op_params); + const ggml_fp16_t * x = (const ggml_fp16_t *)src0->data; + const ggml_fp16_t * a = (const ggml_fp16_t *)alpha->data; + ggml_fp16_t * y = (ggml_fp16_t *)dst->data; + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + const int channels = src0->ne[1]; + + for (int i = 0; i < n; i++) { + int c = i % channels; + ggml_vec_snake_f16(nc, + (ggml_fp16_t *) ((char *) y + i * dst->nb[1]), + (const ggml_fp16_t *) ((const char *) x + i * src0->nb[1]), + a[c]); + } +} + +static void ggml_compute_forward_snake( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_snake_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_snake_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_silu_back static void ggml_compute_forward_silu_back_f32( @@ -14555,6 +14650,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_leaky_relu(params, tensor); } break; + case GGML_OP_SNAKE: + { + ggml_compute_forward_snake(params, tensor); + } break; case GGML_OP_FLASH_ATTN_EXT: { ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 2e081d5910c6e..adb8c81e7b547 
100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -967,6 +967,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "TIMESTEP_EMBEDDING", "ARGSORT", "LEAKY_RELU", + "SNAKE", "FLASH_ATTN_EXT", "FLASH_ATTN_BACK", @@ -998,7 +999,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85"); +static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1097,7 +1098,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "adamw(x)", }; -static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85"); +static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -2474,6 +2475,35 @@ struct ggml_tensor * ggml_leaky_relu( return result; } +// ggml snake + +struct ggml_tensor * ggml_snake( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * alpha) { + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + // store ptr to alpha tensor + ggml_set_op_params(result, &alpha, sizeof(alpha)); + result->op = GGML_OP_SNAKE; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_snake_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * alpha) { + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + ggml_set_op_params(result, &alpha, sizeof(alpha)); + result->op = GGML_OP_SNAKE; + result->src[0] = a; + + return result; +} + // ggml_sigmoid struct ggml_tensor * ggml_sigmoid( From efd527644c408c81ad4ccb79d28683ceb76a8e2f Mon Sep 17 00:00:00 2001 From: Ja Morphy Date: Tue, 25 Mar 2025 10:26:01 -0700 Subject: [PATCH 3/7] (rebase me) snac graph, tensor layour and creation now, integrate the LM (seems straightforward it's llama3), rewrite/extend/add to tts.cpp, then fix bugs and optimize. 
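Note on tensor layout: most decoder tensors are stored as weight-norm parametrizations (`...original0` = per-channel magnitude g, `...original1` = direction v, as the conversion script maps them). One option is to fold them into a single weight at conversion time instead of reconstructing w inside the graph. A rough sketch, assuming PyTorch's `weight_norm` convention (dim=0, L2 norm per output channel over the remaining dims):

```python
# Rough sketch: fold a weight_norm parametrization into a plain conv weight,
# w = g * v / ||v||, assuming PyTorch's dim=0 convention (the L2 norm is
# taken per output channel over the remaining dims).
# g = ...original0, v = ...original1 in the SNAC checkpoint.
import torch

def fold_weight_norm(g: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    norm = v.norm(p=2, dim=list(range(1, v.dim())), keepdim=True)
    return g * v / norm
```

Folding in convert_hf_to_gguf.py could drop the separate `*.scale` tensors and the per-graph weight reconstruction step entirely.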
--- convert_hf_to_gguf.py | 18 ++-- gguf-py/gguf/constants.py | 4 +- gguf-py/gguf/gguf_writer.py | 6 -- src/llama-arch.cpp | 22 ++++- src/llama-arch.h | 18 ++++ src/llama-hparams.h | 4 + src/llama-model.cpp | 184 ++++++++++++++++++++++++++++++++++++ src/llama-model.h | 26 +++++ 8 files changed, 261 insertions(+), 21 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7984834188c5e..220af521f47f2 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2334,11 +2334,7 @@ class SNACDecModel(Model): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[Tuple[str, Tensor]]: del bid # unused - logger.debug(f"Processing tensor: {name}") - - if (name.startswith("decoder.") or - re.match(r"quantizer\.quantizers\.\d+\.codebook\.weight", name) or - re.match(r"quantizer\.quantizers\.\d+\.out_proj\..*", name)): + if (name.startswith("decoder.")): logger.info(f"{name} -> {data_torch.shape}") return [(name, data_torch)] else: @@ -2350,13 +2346,15 @@ def set_vocab(self): def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_vocab_size(self.hparams["codebook_size"]) - self.gguf_writer.add_quantizer_count(len(self.hparams["vq_strides"])) - self.gguf_writer.add_features_length(self.hparams["codebook_dim"]) - self.gguf_writer.add_quantizer_strides(self.hparams["vq_strides"]) + # TODO: Don't think codebook is needed if the LM is a drop in quantizer replacement + #self.gguf_writer.add_vocab_size(self.hparams["codebook_size"]) + #self.gguf_writer.add_features_length(self.hparams["codebook_dim"]) self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"]) self.gguf_writer.add_decoder_upsample_rates(self.hparams["decoder_rates"]) - self.gguf_writer.add_decoder_channel_dims(self.hparams["decoder_channel_dims"]) + self.gguf_writer.add_uint32("n_layers", len(self.hparams["decoder_rates"])) # Infer as 4 from decoder_rates + self.gguf_writer.add_array("decoder_channel_dims", [768, 1024, 512, 256, 128, 64, 1]) + # TODO: Add sampling rate? 
+ #self.gguf_writer.add_decoder_channel_dims(self.hparams["sampling_rate"]) @Model.register("Qwen2MoeForCausalLM") class Qwen2MoeModel(Model): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 15e86da8960f4..19fb2319f85cc 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -174,9 +174,7 @@ class ConvNext: BLOCK_COUNT = "{arch}.convnext.block_count" class AudioCodec: - QUANTIZER_COUNT = "{arch}.audio_codec.quantizer_count" - CODEBOOK_DIM = "{arch}.audio_codec.codebook_dim" - QUANTIZER_STRIDES = "{arch}.audio_codec.quantizer_strides" + #CODEBOOK_DIM = "{arch}.audio_codec.codebook_dim" DECODER_UPSAMPLE_RATES = "{arch}.audio_codec.decoder_upsample_rates" DECODER_CHANNEL_DIMS = "{arch}.audio_codec.decoder_channel_dims" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 9cf3a55702a5a..f1924358d9cad 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -887,12 +887,6 @@ def add_remove_extra_whitespaces(self, value: bool) -> None: def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None: self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap) - def add_quantizer_count(self, count: int) -> None: - self.add_uint32(Keys.AudioCodec.QUANTIZER_COUNT.format(arch=self.arch), count) - - def add_quantizer_strides(self, strides: Sequence[int]) -> None: - self.add_array(Keys.AudioCodec.QUANTIZER_STRIDES.format(arch=self.arch), strides) - def add_decoder_upsample_rates(self, rates: Sequence[int]) -> None: self.add_array(Keys.AudioCodec.DECODER_UPSAMPLE_RATES.format(arch=self.arch), rates) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 9debb56cc80d5..e3eb12790a5cb 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1392,11 +1392,29 @@ static const std::map> LLM_TENSOR_N }, }, { - LLM_ARCH_UNKNOWN, + LLM_ARCH_SNAC_DEC, { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_WEIGHT_NORMALIZATION_SCALE, "decoder.model.%d.parametrizations.weight.original0" }, + { LLM_TENSOR_WEIGHT_NORMALIZATION_KERNEL, "decoder.model.%d.parametrizations.weight.original1" }, + { LLM_TENSOR_TRANSPOSED_CONV_WEIGHT, "decoder.model.%d.parametrizations.weight.original1" }, + { LLM_TENSOR_CONV1D_BIAS, "decoder.model.%d.bias" }, + { LLM_TENSOR_TRANSPOSED_CONV_BIAS, "decoder.model.%d.bias" }, + { LLM_TENSOR_SNAKE_ALPHA, "decoder.model.%d.alpha" }, + { LLM_TENSOR_SNAKE_ALPHA, "decoder.model.%d.block.%d.alpha" }, + { LLM_TENSOR_SNAKE_ALPHA, "decoder.model.%d.block.%d.block.%d.alpha" }, + { LLM_TENSOR_WEIGHT_NORMALIZATION_SCALE, "decoder.model.%d.block.%d.parametrizations.weight.original0" }, + { LLM_TENSOR_WEIGHT_NORMALIZATION_KERNEL, "decoder.model.%d.block.%d.parametrizations.weight.original1" }, + { LLM_TENSOR_CONV1D_BIAS, "decoder.model.%d.block.%d.bias" }, + { LLM_TENSOR_WEIGHT_NORMALIZATION_SCALE, "decoder.model.%d.block.%d.linear.parametrizations.weight.original0" }, + { LLM_TENSOR_LINEAR_WEIGHT, "decoder.model.%d.block.%d.linear.parametrizations.weight.original1" }, }, }, + { + LLM_ARCH_UNKNOWN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, }; static const std::map LLM_TENSOR_INFOS = { diff --git a/src/llama-arch.h b/src/llama-arch.h index a28815d8a14c7..23eec12e718f1 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -69,6 +69,7 @@ enum llm_arch { LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, LLM_ARCH_WAVTOKENIZER_DEC, + LLM_ARCH_SNAC_DEC, LLM_ARCH_UNKNOWN, }; @@ -201,6 +202,11 @@ enum llm_kv { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, LLM_KV_CONVNEXT_BLOCK_COUNT, + LLM_KV_QUANTIZER_COUNT, + 
LLM_KV_QUANTIZER_STRIDES, + LLM_KV_DECODER_UPSAMPLE_RATES, + LLM_KV_DECODER_CHANNEL_DIMS, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, @@ -346,6 +352,18 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_K, LLM_TENSOR_POS_NET_ATTN_V, LLM_TENSOR_POS_NET_ATTN_OUT, + + LLM_TENSOR_CODEBOOK_WEIGHT, + LLM_TENSOR_WEIGHT_NORMALIZATION_SCALE, + LLM_TENSOR_WEIGHT_NORMALIZATION_KERNEL, + LLM_TENSOR_CONV1D_WEIGHT, + LLM_TENSOR_CONV1D_SCALE, + LLM_TENSOR_CONV1D_BIAS, + LLM_TENSOR_SNAKE_ALPHA, + LLM_TENSOR_LINEAR_WEIGHT, + LLM_TENSOR_TRANSPOSED_CONV_WEIGHT, + LLM_TENSOR_TRANSPOSED_CONV_SCALE, + LLM_TENSOR_TRANSPOSED_CONV_BIAS, }; enum llm_tensor_layer { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index bb17ba86dc2fb..ed0f9ab98d8e0 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -70,6 +70,10 @@ struct llama_hparams { float f_attn_logit_softcapping = 50.0f; float f_final_logit_softcapping = 30.0f; + // for SNAC vocoder + std::array upsample_rates; + std::array n_channels; + // for RWKV uint32_t rescale_every_n_layers = 0; uint32_t time_mix_extra_dim = 0; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index cd7e0a0c4dbf8..b3d3ab1ac7e08 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1317,6 +1317,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); } break; + case LLM_ARCH_SNAC_DEC: + { + // ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab); + // ml.get_key(LLM_KV_QUANTIZER_COUNT, hparams.n_quantizers); + // ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_codebook_dim); + // ml.get_key(LLM_KV_QUANTIZER_STRIDES, hparams.vq_strides); + // ml.get_key(LLM_KV_DECODER_UPSAMPLE_RATES, hparams.rates); + // ml.get_key(LLM_KV_DECODER_CHANNEL_DIMS, hparams.n_channels); + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -3694,6 +3703,86 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); } break; + case LLM_ARCH_SNAC_DEC: + { + const int64_t n_layers = hparams.n_channels.size(); + const int64_t n_blocks = hparams.upsample_rates.size(); + GGML_ASSERT(n_layers == n_blocks + 2); + + layers.resize(n_layers); + + for (int i = 0; i < n_layers; ++i) { + auto & layer = layers[i]; + const int64_t n_in = (i == 0) ? 
hparams.n_channels[0] : hparams.n_channels[i-1]; + const int64_t n_out = hparams.n_channels[i]; + + if (i == 0) { + layer.conv_w = create_tensor(tn(LLM_TENSOR_CONV1D_WEIGHT, "weight", i), + {7, 1, n_out}, 0); // [7, 1, 768] (original1) + layer.conv_scale = create_tensor(tn(LLM_TENSOR_CONV1D_SCALE, "scale", i), + {n_out, 1, 1}, 0); // [768, 1, 1] (original0) + layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV1D_BIAS, "bias", i), + {n_out}, 0); + } else if (i == 1) { + // Pointwise expansion + layer.conv_w = create_tensor(tn(LLM_TENSOR_CONV1D_WEIGHT, "weight", i), + {1, n_in, n_out}, 0); // [1, 768, 1024] + layer.conv_scale = create_tensor(tn(LLM_TENSOR_CONV1D_SCALE, "scale", i), + {n_out, 1, 1}, 0); // [1024, 1, 1] + layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV1D_BIAS, "bias", i), + {n_out}, 0); + } else if (i == n_layers - 1) { + // Final convolution + layer.alpha = create_tensor(tn(LLM_TENSOR_SNAKE_ALPHA, "alpha", i), + {1, n_in, 1}, 0); // [1, 64, 1] + layer.conv_w = create_tensor(tn(LLM_TENSOR_CONV1D_WEIGHT, "weight", i), + {7, n_in, n_out}, 0); // [7, 64, 1] + layer.conv_scale = create_tensor(tn(LLM_TENSOR_CONV1D_SCALE, "scale", i), + {n_out, 1, 1}, 0); // [1, 1, 1] + layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV1D_BIAS, "bias", i), + {n_out}, 0); + } else { + // Decoder blocks + const int64_t stride = hparams.upsample_rates[i-2]; + layer.decoder_blocks.resize(1); + auto & block = layer.decoder_blocks[0]; + + // Upsampling + block.alpha = create_tensor(tn(LLM_TENSOR_SNAKE_ALPHA, "alpha", i), + {1, n_in, 1}, 0); // something like [1, 1024, 1] + block.up_weight = create_tensor(tn(LLM_TENSOR_TRANSPOSED_CONV_WEIGHT, "weight", i), + {stride * 2, n_out, n_in}, 0); + block.up_scale = create_tensor(tn(LLM_TENSOR_TRANSPOSED_CONV_SCALE, "scale", i), + {n_out, 1, 1}, 0); + block.up_bias = create_tensor(tn(LLM_TENSOR_TRANSPOSED_CONV_BIAS, "bias", i), + {n_out}, 0); + + // Residual units + for (int j = 0; j < 3; ++j) { + auto & ru = block.res_units[j]; + const int dilation = (j == 0) ? 1 : (j == 1) ? 
3 : 9; + + ru.alpha1 = create_tensor(tn(LLM_TENSOR_SNAKE_ALPHA, "alpha1", i, j), + {1, n_out, 1}, 0); + ru.conv1_w = create_tensor(tn(LLM_TENSOR_CONV1D_WEIGHT, "weight", i, j), + {7, 1, n_out}, 0); + ru.conv1_scale = create_tensor(tn(LLM_TENSOR_CONV1D_SCALE, "scale", i, j), + {n_out, 1, 1}, 0); + ru.conv1_b = create_tensor(tn(LLM_TENSOR_CONV1D_BIAS, "bias", i, j), + {n_out}, 0); + + ru.alpha2 = create_tensor(tn(LLM_TENSOR_SNAKE_ALPHA, "alpha2", i, j), + {1, n_out, 1}, 0); + ru.conv2_w = create_tensor(tn(LLM_TENSOR_CONV1D_WEIGHT, "weight2", i, j), + {1, n_out, n_out}, 0); + ru.conv2_scale = create_tensor(tn(LLM_TENSOR_CONV1D_SCALE, "scale2", i, j), + {n_out, 1, 1}, 0); + ru.conv2_b = create_tensor(tn(LLM_TENSOR_CONV1D_BIAS, "bias2", i, j), + {n_out}, 0); + } + } + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -11597,6 +11686,96 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context { } }; +struct llm_build_snac_dec : public llm_graph_context { + llm_build_snac_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + ggml_tensor * cur; + + cur = build_inp_embd(model.tok_embd); + + for (uint32_t il = 0; il < model.layers.size(); ++il) { + const auto & layer = model.layers[il]; + + if (il == 0) { // depthwise + cur = apply_conv1d(cur, layer.conv_w, layer.conv_scale, layer.conv_b, 1, 3); + } else if (il == 1) { // pointwise + cur = apply_conv1d(cur, layer.conv_w, layer.conv_scale, layer.conv_b, 1, 0); + } else if (il == model.layers.size() - 1) { + cur = ggml_snake(ctx0, cur, layer.alpha); + cur = apply_conv1d(cur, layer.conv_w, layer.conv_scale, layer.conv_b, 1, 3); + cur = ggml_tanh(ctx0, cur); + } else { + // Layers 2-5: Decoder Blocks (1024 -> 512 -> 256 -> 128 -> 64) + const auto & block = layer.decoder_blocks[0]; + const int stride = hparams.upsample_rates[il - 2]; + + cur = ggml_snake(ctx0, cur, block.alpha); + cur = apply_conv1d_transpose(cur, block.up_weight, block.up_scale, block.up_bias, stride, stride); + + // Residual Units (3 per block) + for (int j = 0; j < 3; ++j) { + const auto & ru = block.res_units[j]; + ggml_tensor * inpL = cur; + + cur = ggml_snake(ctx0, cur, ru.alpha1); + int dilation = (j == 0) ? 1 : (j == 1) ? 
3 : 9; + int padding = 3 * dilation; // Kernel 7, dilated padding = (7-1)/2 * dilation + cur = apply_conv1d(cur, ru.conv1_w, ru.conv1_scale, ru.conv1_b, 1, padding); + + // pw + cur = ggml_snake(ctx0, cur, ru.alpha2); + cur = apply_conv1d(cur, ru.conv2_w, ru.conv2_scale, ru.conv2_b, 1, 0); + + // residual + cur = ggml_add(ctx0, cur, inpL); + } + } + } + + + int64_t target_samples = 24000; // TODO: magic number + if (cur->ne[0] > target_samples) { + cur = ggml_get_rows(ctx0, cur, ggml_new_i32(ctx0, target_samples)); + } + + cb(cur, "result_embd", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); + } + + // TODO: move these somewhere else +private: + ggml_tensor * apply_conv1d(ggml_tensor * input, ggml_tensor * conv_w, ggml_tensor * conv_scale, ggml_tensor * conv_b, + int stride, int padding) { + ggml_tensor * w_final = normalize_weight(conv_w, conv_scale); + ggml_tensor * cur = ggml_conv_1d_ph(ctx0, w_final, input, stride, padding); + if (conv_b) { + cur = ggml_add(ctx0, cur, conv_b); + } + return cur; + } + + ggml_tensor * apply_conv1d_transpose(ggml_tensor * input, ggml_tensor * up_weight, ggml_tensor * up_scale, + ggml_tensor * up_bias, int stride, int padding) { + ggml_tensor * w_final = normalize_weight(up_weight, up_scale); + int kernel_size = up_weight->ne[0]; + int output_padding = stride % 2; // 0 for even strides (8, 4, 2) + ggml_tensor * cur = ggml_conv_transpose_1d(ctx0, w_final, input, stride, padding / 2, output_padding); + if (up_bias) { + cur = ggml_add(ctx0, cur, up_bias); + } + return cur; + } + + // w_final = scale * (w / || w ||) + ggml_tensor * normalize_weight(ggml_tensor * w, ggml_tensor * scale) { + ggml_tensor * norm = ggml_norm(ctx0, w, 1e-5f); // 1e-8f ? + ggml_tensor * w_normalized = ggml_div(ctx0, w, norm); + ggml_tensor * w_final = ggml_mul(ctx0, w_normalized, scale); + return w_final; + } +}; + llama_memory_i * llama_model::create_memory() const { llama_memory_i * res; @@ -11868,6 +12047,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_SNAC_DEC: + { + llm = std::make_unique(*this, params, gf); + } break; default: GGML_ABORT("fatal error"); } @@ -11976,6 +12159,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_SNAC_DEC: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values diff --git a/src/llama-model.h b/src/llama-model.h index a9da1215abbfd..f9286cb48fc36 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -137,6 +137,25 @@ struct llama_layer_convnext { struct ggml_tensor * gamma = nullptr; }; +struct llama_layer_snac_dec_block { + struct ggml_tensor * alpha = nullptr; // for snake activation + + struct ggml_tensor * up_weight = nullptr; + struct ggml_tensor * up_scale = nullptr; + struct ggml_tensor * up_bias = nullptr; + + struct { + struct ggml_tensor * alpha1 = nullptr; + struct ggml_tensor * conv1_w = nullptr; + struct ggml_tensor * conv1_scale = nullptr; + struct ggml_tensor * conv1_b = nullptr; + struct ggml_tensor * alpha2 = nullptr; + struct ggml_tensor * conv2_w = nullptr; + struct ggml_tensor * conv2_scale = nullptr; + struct ggml_tensor * conv2_b = nullptr; + } res_units[3]; +}; + struct llama_layer { // normalization struct ggml_tensor * attn_norm = nullptr; @@ -304,6 +323,13 @@ struct llama_layer { struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; + + struct 
ggml_tensor * conv_w = nullptr; + struct ggml_tensor * conv_scale = nullptr; + struct ggml_tensor * conv_b = nullptr; + struct ggml_tensor * alpha = nullptr; + + std::vector decoder_blocks; }; struct llama_model { From 98e5834dc3460a105eb169c1b3ba6e1b7de60b4c Mon Sep 17 00:00:00 2001 From: Ja Morphy Date: Tue, 1 Apr 2025 21:47:29 -0700 Subject: [PATCH 4/7] codes to graph build --- convert_hf_to_gguf.py | 179 ++++++++++- examples/tts/CMakeLists.txt | 6 + examples/tts/orpheus-tts.cpp | 344 +++++++++++++++++++++ src/llama-arch.cpp | 112 +++++-- src/llama-arch.h | 56 +++- src/llama-context.cpp | 2 + src/llama-graph.cpp | 2 + src/llama-model-loader.cpp | 2 + src/llama-model.cpp | 571 +++++++++++++++++++++++++---------- src/llama-model.h | 18 +- 10 files changed, 1085 insertions(+), 207 deletions(-) create mode 100644 examples/tts/orpheus-tts.cpp diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 220af521f47f2..01ec22aa3cc28 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2331,30 +2331,177 @@ def set_gguf_parameters(self): class SNACDecModel(Model): model_arch = gguf.MODEL_ARCH.SNAC_DEC - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[Tuple[str, Tensor]]: - del bid # unused + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._dummy_added = False + + def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) -> Iterable[Tuple[str, torch.Tensor]]: + """Convert nested PyTorch tensor names to a flat GGUF naming scheme for decoder tensors.""" + del bid # Unused + + # Add dummy token_embd.weight only once + if not self._dummy_added: + import torch + dummy_tok_embd = torch.zeros((4096, 8), dtype=torch.float16) + dummy_tok_embd = dummy_tok_embd.view(4096, 8) + logger.info(f"Adding dummy tensor: token_embd.weight, shape: {list(dummy_tok_embd.shape)}") + yield ("token_embd.weight", dummy_tok_embd) + self._dummy_added = True # Mark as added + + original_name = name + + if name.startswith("quantizer.quantizers."): + match = re.match(r"quantizer\.quantizers\.(\d+)\.(codebook\.weight|out_proj\.bias|out_proj\.parametrizations\.weight\.original[0-1])", name) + if match: + q_idx = int(match.group(1)) + tensor_type = match.group(2) + if tensor_type == "codebook.weight": + new_name = f"quantizer.{q_idx}.codebook" + elif tensor_type == "out_proj.parametrizations.weight.original0": + new_name = f"quantizer.{q_idx}.out_proj.scale" + elif tensor_type == "out_proj.parametrizations.weight.original1": + new_name = f"quantizer.{q_idx}.out_proj.weight" + elif tensor_type == "out_proj.bias": + new_name = f"quantizer.{q_idx}.out_proj.bias" + + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + else: + logger.warning(f"Could not parse quantizer tensor from: {original_name}") + return - if (name.startswith("decoder.")): - logger.info(f"{name} -> {data_torch.shape}") - return [(name, data_torch)] - else: - logger.debug(f"Skipping {name!r}") - return [] + # Skip non-decoder tensors (except quantizers, which were handled above) + if not name.startswith("decoder."): + logger.debug(f"Skipping non-decoder tensor: {original_name}") + return + + base = name[8:] # Remove 'decoder.' 
+ parts = base.split(".") + + if base.startswith("model.0."): + logger.info(f"Skipping incompatible decoder layer 0 tensor: {original_name}") + return # Explicitly skip this layer + + # Layer 1: Second Conv + if base.startswith("model.1."): + if "bias" in name and "parametrizations" not in name: + new_name = "decoder.1.conv2.bias" + elif "parametrizations.weight.original0" in name: + new_name = "decoder.1.conv2.scale" + elif "parametrizations.weight.original1" in name: + new_name = "decoder.1.conv2.weight" + else: + logger.warning(f"Unhandled layer 1 tensor: {original_name}") + return + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + return + + # Layers 2–5: DecoderBlocks + if "model." in base and "block" in base: + try: + layer_idx = int(parts[1]) # e.g., '2' from 'model.2' + if layer_idx not in {2, 3, 4, 5}: + logger.debug(f"Skipping non-DecoderBlock layer {layer_idx}: {original_name}") + return + block_idx = int(parts[3]) # e.g., '1' from 'block.1' + new_base = f"decoder.{layer_idx}.block.{block_idx}" + + if block_idx == 0: # Snake1d + if "alpha" in name: + new_name = f"{new_base}.alpha" + else: + logger.error(f"Expected 'alpha' in {original_name}") + return + elif block_idx == 1: # Transpose Conv + if "bias" in name and "parametrizations" not in name: + new_name = f"{new_base}.trans.bias" + elif "parametrizations.weight.original0" in name: + new_name = f"{new_base}.trans.scale" + elif "parametrizations.weight.original1" in name: + new_name = f"{new_base}.trans.weight" + else: + logger.error(f"Unhandled tensor in block 1: {original_name}") + return + elif block_idx == 2: # Noise Block + if "linear.parametrizations.weight.original0" in name: + new_name = f"{new_base}.noise.scale" + elif "linear.parametrizations.weight.original1" in name: + new_name = f"{new_base}.noise.weight" + else: + logger.error(f"Unhandled tensor in block 2: {original_name}") + return + elif block_idx in {3, 4, 5}: # Residual Units + res_base = f"{new_base}.res" + if "block.0.alpha" in name: + new_name = f"{res_base}.snake1.alpha" + elif "block.1.bias" in name: + new_name = f"{res_base}.conv1.bias" + elif "block.1.parametrizations.weight.original0" in name: + new_name = f"{res_base}.conv1.scale" + elif "block.1.parametrizations.weight.original1" in name: + new_name = f"{res_base}.conv1.weight" + elif "block.2.alpha" in name: + new_name = f"{res_base}.snake2.alpha" + elif "block.3.bias" in name: + new_name = f"{res_base}.conv2.bias" + elif "block.3.parametrizations.weight.original0" in name: + new_name = f"{res_base}.conv2.scale" + elif "block.3.parametrizations.weight.original1" in name: + new_name = f"{res_base}.conv2.weight" + else: + logger.error(f"Unhandled tensor in residual unit: {original_name}") + return + else: + logger.error(f"Unhandled block index {block_idx} in layer {layer_idx}: {original_name}") + return + + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + return + + except (IndexError, ValueError) as e: + logger.error(f"Failed to parse tensor {original_name}: {e}") + return + + # Layer 6: Snake1d + if base == "model.6.alpha": + new_name = "decoder.6.alpha" + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + return + + # Layer 7: Final Conv + if base.startswith("model.7."): + if "bias" in name and "parametrizations" not in name: + new_name = "decoder.7.conv.bias" + elif 
"parametrizations.weight.original0" in name: + new_name = "decoder.7.conv.scale" + elif "parametrizations.weight.original1" in name: + new_name = "decoder.7.conv.weight" + else: + logger.warning(f"Unhandled layer 7 tensor: {original_name}") + return + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + return + + logger.warning(f"Tensor {original_name} not mapped to any layer") + return def set_vocab(self): self._set_vocab_none() def set_gguf_parameters(self): super().set_gguf_parameters() - # TODO: Don't think codebook is needed if the LM is a drop in quantizer replacement - #self.gguf_writer.add_vocab_size(self.hparams["codebook_size"]) - #self.gguf_writer.add_features_length(self.hparams["codebook_dim"]) - self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"]) - self.gguf_writer.add_decoder_upsample_rates(self.hparams["decoder_rates"]) - self.gguf_writer.add_uint32("n_layers", len(self.hparams["decoder_rates"])) # Infer as 4 from decoder_rates + self.gguf_writer.add_vocab_size (4096) # TODO: Fix + self.gguf_writer.add_uint32("snac.quantizer.codebook_size", self.hparams["codebook_size"]) + self.gguf_writer.add_uint32("snac.quantizer.codebook_dim", self.hparams["codebook_dim"]) + self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"]) # 1024 + self.gguf_writer.add_decoder_upsample_rates(self.hparams["decoder_rates"]) # [8, 8, 4, 2] + self.gguf_writer.add_uint32("n_layers", 8) self.gguf_writer.add_array("decoder_channel_dims", [768, 1024, 512, 256, 128, 64, 1]) - # TODO: Add sampling rate? - #self.gguf_writer.add_decoder_channel_dims(self.hparams["sampling_rate"]) + self.gguf_writer.add_array("vq_strides", self.hparams["vq_strides"]) @Model.register("Qwen2MoeForCausalLM") class Qwen2MoeModel(Model): diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt index c72bd814c3b31..42f95df7387b4 100644 --- a/examples/tts/CMakeLists.txt +++ b/examples/tts/CMakeLists.txt @@ -3,3 +3,9 @@ add_executable(${TARGET} tts.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) + +set(TARGET llama-orpheus-tts) +add_executable(${TARGET} orpheus-tts.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/tts/orpheus-tts.cpp b/examples/tts/orpheus-tts.cpp new file mode 100644 index 0000000000000..622ec46fde05a --- /dev/null +++ b/examples/tts/orpheus-tts.cpp @@ -0,0 +1,344 @@ +#include "common.h" +#include "llama.h" +#include "llama-impl.h" +#include "log.h" +#include "arg.h" +#include "sampling.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +std::vector redistribute_codes(const std::vector& raw_codes) { + std::vector snac_codes; + for (size_t i = 0; i < raw_codes.size(); i += 7) { + // Ensure we have a full frame (7 codes) + if (i + 6 >= raw_codes.size()) break; + + // Frame offsets (per notebook) + snac_codes.push_back(raw_codes[i]); // Codebook 0 (no offset) + snac_codes.push_back(raw_codes[i+1] - 4096); // Codebook 1 + snac_codes.push_back(raw_codes[i+2] - 8192); // Codebook 2 + snac_codes.push_back(raw_codes[i+3] - 12288); // Codebook 2 + snac_codes.push_back(raw_codes[i+4] - 16384); // Codebook 1 + snac_codes.push_back(raw_codes[i+5] - 20480); // Codebook 2 + 
snac_codes.push_back(raw_codes[i+6] - 24576); // Codebook 2 + } + return snac_codes; +} + +static std::vector embd_to_audio( + const float * embd, + const int n_codes, + const int n_embd, + const int n_thread); +static bool save_wav16(const std::string & fname, const std::vector & data, int sample_rate); +static void fill_hann_window(int length, bool periodic, float * output); +static void irfft(int n, const float * inp_cplx, float * out_real); +static void fold(const std::vector & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector & output); + +static void print_usage(int /*argc*/, char **argv) { + LOG("\nexample usage:\n"); + LOG("\n %s -m model.gguf -mv vocoder.gguf -p \"Hello world\"\n", argv[0]); + LOG("\n"); +} + +static void prompt_add(std::vector &prompt, const llama_vocab *vocab, const std::string &txt, bool add_special, bool parse_special) { + auto tmp = common_tokenize(vocab, txt, add_special, parse_special); + prompt.insert(prompt.end(), tmp.begin(), tmp.end()); +} + + +// // Include embd_to_audio and save_wav16 from tts.cpp (for now) +static std::vector embd_to_audio( + const float * embd, + const int n_codes, + const int n_embd, + const int n_thread) { + const int n_fft = 1280; + const int n_hop = 320; + const int n_win = 1280; + const int n_pad = (n_win - n_hop)/2; + const int n_out = (n_codes - 1)*n_hop + n_win; + + std::vector hann(n_fft); + fill_hann_window(hann.size(), true, hann.data()); + + int n_spec = n_embd*n_codes; + + std::vector E (n_spec); + std::vector S (n_spec); + std::vector ST(n_spec); + + for (int l = 0; l < n_codes; ++l) { + for (int k = 0; k < n_embd; ++k) { + E[k*n_codes + l] = embd[l*n_embd + k]; + } + } + + for (int k = 0; k < n_embd/2; ++k) { + for (int l = 0; l < n_codes; ++l) { + float mag = E[(k )*n_codes + l]; + float phi = E[(k + n_embd/2)*n_codes + l]; + mag = exp(mag); + if (mag > 1e2) { + mag = 1e2; + } + S[2*(k*n_codes + l) + 0] = mag*cosf(phi); + S[2*(k*n_codes + l) + 1] = mag*sinf(phi); + } + } + + for (int l = 0; l < n_codes; ++l) { + for (int k = 0; k < n_embd/2; ++k) { + ST[l*n_embd + 2*k + 0] = S[2*(k*n_codes + l) + 0]; + ST[l*n_embd + 2*k + 1] = S[2*(k*n_codes + l) + 1]; + } + } + + std::vector res (n_codes*n_fft); + std::vector hann2(n_codes*n_fft); + + std::vector workers(n_thread); + for (int i = 0; i < n_thread; ++i) { + workers[i] = std::thread([&, i]() { + for (int l = i; l < n_codes; l += n_thread) { + irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft); + for (int j = 0; j < n_fft; ++j) { + res [l*n_fft + j] *= hann[j]; + hann2[l*n_fft + j] = hann[j] * hann[j]; + } + } + }); + } + for (int i = 0; i < n_thread; ++i) { + workers[i].join(); + } + + std::vector audio; + std::vector env; + + fold(res, n_out, n_win, n_hop, n_pad, audio); + fold(hann2, n_out, n_win, n_hop, n_pad, env); + + for (size_t i = 0; i < audio.size(); ++i) { + audio[i] /= env[i]; + } + + return audio; +} + +static bool save_wav16(const std::string & fname, const std::vector & data, int sample_rate) { + std::ofstream file(fname, std::ios::binary); + if (!file) { + LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str()); + return false; + } + + struct wav_header { + char riff[4] = {'R', 'I', 'F', 'F'}; + uint32_t chunk_size; + char wave[4] = {'W', 'A', 'V', 'E'}; + char fmt[4] = {'f', 'm', 't', ' '}; + uint32_t fmt_chunk_size = 16; + uint16_t audio_format = 1; // PCM + uint16_t num_channels = 1; // Mono + uint32_t sample_rate; + uint32_t byte_rate; + uint16_t block_align; + uint16_t bits_per_sample = 16; + 
char data[4] = {'d', 'a', 't', 'a'}; + uint32_t data_size; + } header; + + header.sample_rate = sample_rate; + header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8); + header.block_align = header.num_channels * (header.bits_per_sample / 8); + header.data_size = data.size() * (header.bits_per_sample / 8); + header.chunk_size = 36 + header.data_size; + + file.write(reinterpret_cast(&header), sizeof(header)); + + for (const auto & sample : data) { + int16_t pcm_sample = static_cast(std::clamp(sample * 32767.0, -32768.0, 32767.0)); + file.write(reinterpret_cast(&pcm_sample), sizeof(pcm_sample)); + } + + return file.good(); +} + +// Supporting functions from tts.cpp (for embd_to_audio) +static void fill_hann_window(int length, bool periodic, float * output) { + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } +} + +static void twiddle(float * real, float * imag, int k, int N) { + float angle = 2 * M_PI * k / N; + *real = cos(angle); + *imag = sin(angle); +} + +static void irfft(int n, const float * inp_cplx, float * out_real) { + int N = n / 2 + 1; + + std::vector real_input(N); + std::vector imag_input(N); + for (int i = 0; i < N; ++i) { + real_input[i] = inp_cplx[2 * i]; + imag_input[i] = inp_cplx[2 * i + 1]; + } + + std::vector real_output(n); + std::vector imag_output(n); + + for (int k = 0; k < n; ++k) { + real_output[k] = 0.0f; + imag_output[k] = 0.0f; + for (int m = 0; m < N; ++m) { + float twiddle_real; + float twiddle_imag; + + twiddle(&twiddle_real, &twiddle_imag, k * m, n); + + real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag; + imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real; + } + } + + for (int i = 0; i < n; ++i) { + out_real[i] = real_output[i] / N; + } +} + +static void fold(const std::vector & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector & output) { + int64_t output_height = n_out; + int64_t kernel_w = n_win; + int64_t stride_w = n_hop; + int64_t width = n_out; + + output.resize(width, 0.0f); + + int64_t col_idx = 0; + for (int64_t w_col = 0; w_col < width; ++w_col) { + int64_t start = w_col * stride_w - n_pad; + int64_t end = start + kernel_w; + + for (int64_t w_im = start; w_im < end; ++w_im) { + if (w_im >= 0 && w_im < output_height && col_idx < (int64_t) data.size()) { + output[w_im] += data[col_idx]; + } + col_idx++; + } + } + + output.resize(n_out - 2 * n_pad); +} + +int main(int argc, char **argv) { + common_params params; + + params.model = "models/orpheus-3b-0.1-ft-q4_k_m.gguf"; + params.vocoder.model = "models/snac-vocab.gguf"; + params.out_file = "output.wav"; + + params.n_predict = 1200; + params.sampling.top_k = 4; + params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; + params.n_batch = 4096; + + common_init(); + llama_backend_init(); + llama_numa_init(params.numa); + + common_init_result orpheus_init_ttc = common_init_from_params(params); + + llama_model * model_ttc = NULL; + llama_context * ctx_ttc = NULL; + + model_ttc = orpheus_init_ttc.model.get(); + ctx_ttc = orpheus_init_ttc.context.get(); + + const llama_vocab *vocab = llama_model_get_vocab(model_ttc); + + common_sampler *sampler = common_sampler_init(model_ttc, params.sampling); + if (!sampler) { + LOG_ERR("Failed to initialize sampler\n"); + return 1; + } + + // Construct prompt: <|startofhuman|> tara: [prompt] <|eot_id|> <|endofhuman|> + std::vector tokens; 
+ tokens.push_back(128259); // <|startofhuman|> + prompt_add(tokens, vocab, "tara: ", false, true); // Voice prefix + prompt_add(tokens, vocab, params.prompt, false, true); // User prompt + prompt_add(tokens, vocab, "", false, true); // Emotion tag + tokens.push_back(128009); // <|eot_id|> + tokens.push_back(128260); // <|endofhuman|> + + + llama_model * model_cts = NULL; + llama_context * ctx_cts = NULL; + + params.model = params.vocoder.model; + params.n_batch = 2; + + params.embedding = true + // disable warmup, SNAC doesn't care about BOS or EOS tokens; + params.warmup = false; + + common_init_result snac_init_cts = common_init_from_params(params); + LOG_INF("SNAC model loaded: %s\n", params.model.c_str()); + + model_cts = snac_init_cts.model.get(); + ctx_cts = snac_init_cts.context.get(); + + std::vector speech_codes = {100, 4200, 8500, 12500, 16500, 21000, 25000, + 200, 4300, 8600, 12600, 16600, 21111, 25100}; + + std::vector snac_codes = redistribute_codes(speech_codes); + + const int n_codes = speech_codes.size(); + const int batch_size = n_codes; + + llama_batch batch = llama_batch_init(batch_size, 0, 1); + + for (size_t i = 0; i < n_codes; ++i) { + common_batch_add(batch, snac_codes[i], i, {0}, true); + } + + LOG_INF("Batch before decode: n_tokens = %d\n", batch.n_tokens); + if (llama_decode(ctx_cts, batch) != 0) { /* error */ } + + if (llama_decode(ctx_cts, batch) != 0) { /* error */ } + GGML_ASSERT(batch.n_tokens == n_codes); + + batch.logits[batch.n_tokens - 1] = true; + + if (llama_decode(ctx_cts, batch) != 0) { + LOG_ERR("Failed to decode SNAC batch\n"); + return 1; + } + llama_synchronize(ctx_cts); + + LOG_INF("SNAC decode completed\n"); + + llama_batch_free(batch); + llama_backend_free(); + return 0; +} diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index e3eb12790a5cb..b90214d255584 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -65,6 +65,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_SNAC_DEC, "snac-dec" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1394,27 +1395,58 @@ static const std::map> LLM_TENSOR_N { LLM_ARCH_SNAC_DEC, { - { LLM_TENSOR_WEIGHT_NORMALIZATION_SCALE, "decoder.model.%d.parametrizations.weight.original0" }, - { LLM_TENSOR_WEIGHT_NORMALIZATION_KERNEL, "decoder.model.%d.parametrizations.weight.original1" }, - { LLM_TENSOR_TRANSPOSED_CONV_WEIGHT, "decoder.model.%d.parametrizations.weight.original1" }, - { LLM_TENSOR_CONV1D_BIAS, "decoder.model.%d.bias" }, - { LLM_TENSOR_TRANSPOSED_CONV_BIAS, "decoder.model.%d.bias" }, - { LLM_TENSOR_SNAKE_ALPHA, "decoder.model.%d.alpha" }, - { LLM_TENSOR_SNAKE_ALPHA, "decoder.model.%d.block.%d.alpha" }, - { LLM_TENSOR_SNAKE_ALPHA, "decoder.model.%d.block.%d.block.%d.alpha" }, - { LLM_TENSOR_WEIGHT_NORMALIZATION_SCALE, "decoder.model.%d.block.%d.parametrizations.weight.original0" }, - { LLM_TENSOR_WEIGHT_NORMALIZATION_KERNEL, "decoder.model.%d.block.%d.parametrizations.weight.original1" }, - { LLM_TENSOR_CONV1D_BIAS, "decoder.model.%d.block.%d.bias" }, - { LLM_TENSOR_WEIGHT_NORMALIZATION_SCALE, "decoder.model.%d.block.%d.linear.parametrizations.weight.original0" }, - { LLM_TENSOR_LINEAR_WEIGHT, "decoder.model.%d.block.%d.linear.parametrizations.weight.original1" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_CODEBOOK, "quantizer.%d.codebook" }, + { LLM_TENSOR_CODEBOOK_PROJ_B, "quantizer.%d.out_proj.bias" }, + { LLM_TENSOR_CODEBOOK_PROJ_S, 
"quantizer.%d.out_proj.scale" }, + { LLM_TENSOR_CODEBOOK_PROJ_W, "quantizer.%d.out_proj.weight" }, + + { LLM_TENSOR_CONV_W2, "decoder.1.conv2.weight" }, + { LLM_TENSOR_CONV_S2, "decoder.1.conv2.scale" }, + { LLM_TENSOR_CONV_B2, "decoder.1.conv2.bias" }, + { LLM_TENSOR_BLOCK_ALPHA, "decoder.%d.block.0.alpha" }, + { LLM_TENSOR_TRANS_W, "decoder.%d.block.1.trans.weight" }, + { LLM_TENSOR_TRANS_S, "decoder.%d.block.1.trans.scale" }, + { LLM_TENSOR_TRANS_B, "decoder.%d.block.1.trans.bias" }, + { LLM_TENSOR_NOISE_W, "decoder.%d.block.2.noise.weight" }, + { LLM_TENSOR_NOISE_S, "decoder.%d.block.2.noise.scale" }, + // Residual Units + { LLM_TENSOR_RES_SNAKE1_A, "decoder.%d.block.3.res.snake1.alpha" }, + { LLM_TENSOR_RES_CONV1_W, "decoder.%d.block.3.res.conv1.weight" }, + { LLM_TENSOR_RES_CONV1_S, "decoder.%d.block.3.res.conv1.scale" }, + { LLM_TENSOR_RES_CONV1_B, "decoder.%d.block.3.res.conv1.bias" }, + { LLM_TENSOR_RES_SNAKE2_A, "decoder.%d.block.3.res.snake2.alpha" }, + { LLM_TENSOR_RES_CONV2_W, "decoder.%d.block.3.res.conv2.weight" }, + { LLM_TENSOR_RES_CONV2_S, "decoder.%d.block.3.res.conv2.scale" }, + { LLM_TENSOR_RES_CONV2_B, "decoder.%d.block.3.res.conv2.bias" }, + { LLM_TENSOR_RES_SNAKE1_A_B4, "decoder.%d.block.4.res.snake1.alpha" }, + { LLM_TENSOR_RES_CONV1_W_B4, "decoder.%d.block.4.res.conv1.weight" }, + { LLM_TENSOR_RES_CONV1_S_B4, "decoder.%d.block.4.res.conv1.scale" }, + { LLM_TENSOR_RES_CONV1_B_B4, "decoder.%d.block.4.res.conv1.bias" }, + { LLM_TENSOR_RES_SNAKE2_A_B4, "decoder.%d.block.4.res.snake2.alpha" }, + { LLM_TENSOR_RES_CONV2_W_B4, "decoder.%d.block.4.res.conv2.weight" }, + { LLM_TENSOR_RES_CONV2_S_B4, "decoder.%d.block.4.res.conv2.scale" }, + { LLM_TENSOR_RES_CONV2_B_B4, "decoder.%d.block.4.res.conv2.bias" }, + { LLM_TENSOR_RES_SNAKE1_A_B5, "decoder.%d.block.5.res.snake1.alpha" }, + { LLM_TENSOR_RES_CONV1_W_B5, "decoder.%d.block.5.res.conv1.weight" }, + { LLM_TENSOR_RES_CONV1_S_B5, "decoder.%d.block.5.res.conv1.scale" }, + { LLM_TENSOR_RES_CONV1_B_B5, "decoder.%d.block.5.res.conv1.bias" }, + { LLM_TENSOR_RES_SNAKE2_A_B5, "decoder.%d.block.5.res.snake2.alpha" }, + { LLM_TENSOR_RES_CONV2_W_B5, "decoder.%d.block.5.res.conv2.weight" }, + { LLM_TENSOR_RES_CONV2_S_B5, "decoder.%d.block.5.res.conv2.scale" }, + { LLM_TENSOR_RES_CONV2_B_B5, "decoder.%d.block.5.res.conv2.bias" }, + { LLM_TENSOR_ALPHA, "decoder.6.alpha" }, + { LLM_TENSOR_CONV_W7, "decoder.7.conv.weight" }, + { LLM_TENSOR_CONV_S7, "decoder.7.conv.scale" }, + { LLM_TENSOR_CONV_B7, "decoder.7.conv.bias" }, }, }, + { + LLM_ARCH_UNKNOWN, { - LLM_ARCH_UNKNOWN, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, }, + }, }; static const std::map LLM_TENSOR_INFOS = { @@ -1570,8 +1602,53 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + + { LLM_TENSOR_CONV_B2, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_CONV_S2, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_CONV_W2, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_BLOCK_ALPHA, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_TRANS_B, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_TRANS_S, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_TRANS_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_NOISE_S, { 
LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_NOISE_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT } }, + { LLM_TENSOR_RES_SNAKE1_A, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_B, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV1_S, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE2_A, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_B, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV2_S, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE1_A_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_B_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV1_S_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_W_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE2_A_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_B_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV2_S_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_W_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE1_A_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_B_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV1_S_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_W_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE2_A_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_B_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV2_S_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_W_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_ALPHA, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_CONV_B7, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_CONV_S7, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_CONV_W7, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + + { LLM_TENSOR_CODEBOOK, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS } }, + { LLM_TENSOR_CODEBOOK_PROJ_B, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_CODEBOOK_PROJ_S, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_CODEBOOK_PROJ_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, }; + + LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} std::string LLM_KV::operator()(llm_kv kv) const { @@ -1581,6 +1658,7 @@ std::string LLM_KV::operator()(llm_kv kv) const { std::string LLM_TN_IMPL::str() const { if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + fprintf(stderr, "LLM_TN_IMPL::str: tensor enum %d not found in map for arch %d\n", (int)tensor, (int)arch); return "__missing__"; } diff --git a/src/llama-arch.h b/src/llama-arch.h index 23eec12e718f1..5d649d045cc78 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -353,17 +353,51 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_V, LLM_TENSOR_POS_NET_ATTN_OUT, - LLM_TENSOR_CODEBOOK_WEIGHT, - LLM_TENSOR_WEIGHT_NORMALIZATION_SCALE, - LLM_TENSOR_WEIGHT_NORMALIZATION_KERNEL, - LLM_TENSOR_CONV1D_WEIGHT, - LLM_TENSOR_CONV1D_SCALE, - LLM_TENSOR_CONV1D_BIAS, - LLM_TENSOR_SNAKE_ALPHA, - LLM_TENSOR_LINEAR_WEIGHT, - 
LLM_TENSOR_TRANSPOSED_CONV_WEIGHT, - LLM_TENSOR_TRANSPOSED_CONV_SCALE, - LLM_TENSOR_TRANSPOSED_CONV_BIAS, + LLM_TENSOR_CONV_B1, + LLM_TENSOR_CONV_S1, + LLM_TENSOR_CONV_W1, + LLM_TENSOR_CONV_B2, + LLM_TENSOR_CONV_S2, + LLM_TENSOR_CONV_W2, + LLM_TENSOR_BLOCK_ALPHA, + LLM_TENSOR_TRANS_B, + LLM_TENSOR_TRANS_S, + LLM_TENSOR_TRANS_W, + LLM_TENSOR_NOISE_S, + LLM_TENSOR_NOISE_W, + LLM_TENSOR_RES_SNAKE1_A, + LLM_TENSOR_RES_CONV1_B, + LLM_TENSOR_RES_CONV1_S, + LLM_TENSOR_RES_CONV1_W, + LLM_TENSOR_RES_SNAKE2_A, + LLM_TENSOR_RES_CONV2_B, + LLM_TENSOR_RES_CONV2_S, + LLM_TENSOR_RES_CONV2_W, + LLM_TENSOR_RES_SNAKE1_A_B4, + LLM_TENSOR_RES_CONV1_B_B4, + LLM_TENSOR_RES_CONV1_S_B4, + LLM_TENSOR_RES_CONV1_W_B4, + LLM_TENSOR_RES_SNAKE2_A_B4, + LLM_TENSOR_RES_CONV2_B_B4, + LLM_TENSOR_RES_CONV2_S_B4, + LLM_TENSOR_RES_CONV2_W_B4, + LLM_TENSOR_RES_SNAKE1_A_B5, + LLM_TENSOR_RES_CONV1_B_B5, + LLM_TENSOR_RES_CONV1_S_B5, + LLM_TENSOR_RES_CONV1_W_B5, + LLM_TENSOR_RES_SNAKE2_A_B5, + LLM_TENSOR_RES_CONV2_B_B5, + LLM_TENSOR_RES_CONV2_S_B5, + LLM_TENSOR_RES_CONV2_W_B5, + LLM_TENSOR_ALPHA, + LLM_TENSOR_CONV_B7, + LLM_TENSOR_CONV_S7, + LLM_TENSOR_CONV_W7, + + LLM_TENSOR_CODEBOOK, + LLM_TENSOR_CODEBOOK_PROJ_B, + LLM_TENSOR_CODEBOOK_PROJ_S, + LLM_TENSOR_CODEBOOK_PROJ_W, }; enum llm_tensor_layer { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5bec63e2e79ff..ca4adaa781cb3 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -312,7 +312,9 @@ llama_context::llama_context( // reserve pp graph first so that buffers are only allocated once { + LLAMA_LOG_DEBUG("here 3\n"); llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + auto * gf = graph_init(); graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); if (!ggml_backend_sched_reserve(sched.get(), gf)) { diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0bd40174438cc..f0a8b1071dc3b 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -985,6 +985,8 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { } } else { inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); + LLAMA_LOG_DEBUG("build_inp_embd: inp->embd shape = [%ld, %ld, %ld, %ld]\n", + inp->embd->ne[0], inp->embd->ne[1], inp->embd->ne[2], inp->embd->ne[3]); ggml_set_input(inp->embd); cur = inp->embd; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 05d58ad90eba9..4730b8c6a5d22 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -745,6 +745,8 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri } } if (!is_ok) { + fprintf(stderr, "check_tensor_dims: name=%s, expected=%s, got=%s\n", + name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(cur).c_str()); throw std::runtime_error( format("%s: tensor '%s' has wrong shape; expected %s, got %s", __func__, name.c_str(), diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b3d3ab1ac7e08..e711b28684837 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1319,12 +1319,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_SNAC_DEC: { - // ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab); - // ml.get_key(LLM_KV_QUANTIZER_COUNT, hparams.n_quantizers); - // ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_codebook_dim); - // ml.get_key(LLM_KV_QUANTIZER_STRIDES, hparams.vq_strides); - // ml.get_key(LLM_KV_DECODER_UPSAMPLE_RATES, hparams.rates); - // 
ml.get_key(LLM_KV_DECODER_CHANNEL_DIMS, hparams.n_channels); + hparams.n_channels = {768, 1024, 512, 256, 128, 64, 1}; // From decoder_channel_dims + hparams.upsample_rates = {8, 8, 4, 2}; + hparams.n_embd = 768; + hparams.n_layer = 8; + + // Dummy KV cache params to satisfy llama.cpp + for (uint32_t i = 0; i < 7; ++i) { // n_total_layers = 8 + hparams.n_head_arr[i] = 1; + hparams.n_head_kv_arr[i] = 1; + } + hparams.n_embd_head_k = 1; + hparams.n_embd_head_v = 1; } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -1482,13 +1488,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { ggml_backend_buffer_type_t first_moved_to_buft = nullptr; auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags) -> ggml_tensor * { - ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str()); + std::string tn_str = tn.str(); + ggml_tensor * t_meta = ml.get_tensor_meta(tn_str.c_str()); if (!t_meta) { if (flags & TENSOR_NOT_REQUIRED) { return nullptr; } - throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); + return nullptr; + //throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); } // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops @@ -1583,6 +1591,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { return t; } } + fprintf(stderr, "create_tensor: Creating '%s' with ne=[%ld, %ld, %ld]\n", + tn_str.c_str(), ne.begin()[0], ne.begin()[1], ne.begin()[2]); return ml.create_tensor(ctx, tn, ne, flags); }; @@ -3695,7 +3705,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.gamma = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0); } - // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); } @@ -3705,82 +3714,134 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_SNAC_DEC: { - const int64_t n_layers = hparams.n_channels.size(); - const int64_t n_blocks = hparams.upsample_rates.size(); - GGML_ASSERT(n_layers == n_blocks + 2); + // TODO: Magic numbers everywhere + const int64_t n_total_layers = hparams.n_layer; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {8, 4096, 1}, 0); + + hparams.n_channels = {768, 1024, 512, 256, 128, 64, 1}; - layers.resize(n_layers); + // Quantizer projection tensors (0, 1, 2) + for (int qid = 0; qid < 3; ++qid) { + fprintf(stderr, "%s: Loading quantizer %d tensors\n", __func__, qid); + // Bias: [768, 1, 1, 1] -> {768} + codebook_proj_b[qid] = create_tensor(tn(LLM_TENSOR_CODEBOOK_PROJ_B, qid, -1), {768, 1, 1}, 0); + // Scale: [1, 1, 768, 1] -> {768} + codebook_proj_s[qid] = create_tensor(tn(LLM_TENSOR_CODEBOOK_PROJ_S, qid, -1), {1, 1, 768}, 0); + // Weight: [1, 8, 768, 1] -> {8, 768} + codebook_proj_w[qid] = create_tensor(tn(LLM_TENSOR_CODEBOOK_PROJ_W, qid, -1), {1, 8, 768}, 0); - for (int i = 0; i < n_layers; ++i) { + codebook[qid] = create_tensor(tn(LLM_TENSOR_CODEBOOK, qid, -1), {8, 4096, 1, 1}, 0); + } + + // Decoder tensors + for (int i = 1; i < n_total_layers; ++i) { // layers 1..7 auto & layer = layers[i]; - const int64_t n_in = (i == 0) ? 
hparams.n_channels[0] : hparams.n_channels[i-1]; - const int64_t n_out = hparams.n_channels[i]; - - if (i == 0) { - layer.conv_w = create_tensor(tn(LLM_TENSOR_CONV1D_WEIGHT, "weight", i), - {7, 1, n_out}, 0); // [7, 1, 768] (original1) - layer.conv_scale = create_tensor(tn(LLM_TENSOR_CONV1D_SCALE, "scale", i), - {n_out, 1, 1}, 0); // [768, 1, 1] (original0) - layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV1D_BIAS, "bias", i), - {n_out}, 0); - } else if (i == 1) { - // Pointwise expansion - layer.conv_w = create_tensor(tn(LLM_TENSOR_CONV1D_WEIGHT, "weight", i), - {1, n_in, n_out}, 0); // [1, 768, 1024] - layer.conv_scale = create_tensor(tn(LLM_TENSOR_CONV1D_SCALE, "scale", i), - {n_out, 1, 1}, 0); // [1024, 1, 1] - layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV1D_BIAS, "bias", i), - {n_out}, 0); - } else if (i == n_layers - 1) { - // Final convolution - layer.alpha = create_tensor(tn(LLM_TENSOR_SNAKE_ALPHA, "alpha", i), - {1, n_in, 1}, 0); // [1, 64, 1] - layer.conv_w = create_tensor(tn(LLM_TENSOR_CONV1D_WEIGHT, "weight", i), - {7, n_in, n_out}, 0); // [7, 64, 1] - layer.conv_scale = create_tensor(tn(LLM_TENSOR_CONV1D_SCALE, "scale", i), - {n_out, 1, 1}, 0); // [1, 1, 1] - layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV1D_BIAS, "bias", i), - {n_out}, 0); - } else { - // Decoder blocks - const int64_t stride = hparams.upsample_rates[i-2]; - layer.decoder_blocks.resize(1); - auto & block = layer.decoder_blocks[0]; - - // Upsampling - block.alpha = create_tensor(tn(LLM_TENSOR_SNAKE_ALPHA, "alpha", i), - {1, n_in, 1}, 0); // something like [1, 1024, 1] - block.up_weight = create_tensor(tn(LLM_TENSOR_TRANSPOSED_CONV_WEIGHT, "weight", i), - {stride * 2, n_out, n_in}, 0); - block.up_scale = create_tensor(tn(LLM_TENSOR_TRANSPOSED_CONV_SCALE, "scale", i), - {n_out, 1, 1}, 0); - block.up_bias = create_tensor(tn(LLM_TENSOR_TRANSPOSED_CONV_BIAS, "bias", i), - {n_out}, 0); - - // Residual units - for (int j = 0; j < 3; ++j) { - auto & ru = block.res_units[j]; - const int dilation = (j == 0) ? 1 : (j == 1) ? 3 : 9; - - ru.alpha1 = create_tensor(tn(LLM_TENSOR_SNAKE_ALPHA, "alpha1", i, j), - {1, n_out, 1}, 0); - ru.conv1_w = create_tensor(tn(LLM_TENSOR_CONV1D_WEIGHT, "weight", i, j), - {7, 1, n_out}, 0); - ru.conv1_scale = create_tensor(tn(LLM_TENSOR_CONV1D_SCALE, "scale", i, j), - {n_out, 1, 1}, 0); - ru.conv1_b = create_tensor(tn(LLM_TENSOR_CONV1D_BIAS, "bias", i, j), - {n_out}, 0); - - ru.alpha2 = create_tensor(tn(LLM_TENSOR_SNAKE_ALPHA, "alpha2", i, j), - {1, n_out, 1}, 0); - ru.conv2_w = create_tensor(tn(LLM_TENSOR_CONV1D_WEIGHT, "weight2", i, j), - {1, n_out, n_out}, 0); - ru.conv2_scale = create_tensor(tn(LLM_TENSOR_CONV1D_SCALE, "scale2", i, j), - {n_out, 1, 1}, 0); - ru.conv2_b = create_tensor(tn(LLM_TENSOR_CONV1D_BIAS, "bias2", i, j), - {n_out}, 0); - } + + // Calculate n_in and n_out for the current layer i + const int64_t n_in = (i == 0) ? 1 : ((i == 7) ? hparams.n_channels[i-2] /* 64 */ : hparams.n_channels[i-1]); + const int64_t n_out = (i == 7) ? 
hparams.n_channels[i-1] /* 1 */ : hparams.n_channels[i]; + + fprintf(stderr, "%s: Layer %d: Starting (n_in=%lld, n_out=%lld)\n", __func__, i, n_in, n_out); + + if (i == 1) { // --- Layer 1: Conv2 --- + layer.conv_w = create_tensor(tn(LLM_TENSOR_CONV_W2, i, -1), {1, n_in, n_out}, 0); + layer.conv_s = create_tensor(tn(LLM_TENSOR_CONV_S2, i, -1), {1, 1, n_out}, 0); + layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV_B2, i, -1), {n_out}, 0); } + else if (i >= 2 && i <= 5) { // --- Layers 2-5: Blocks --- + const int n_blocks = 6; + layer.decoder_blocks.resize(n_blocks); + + for (int bid = 0; bid < n_blocks; ++bid) { + LLAMA_LOG_DEBUG("%s: Layer %d, Block %d: Starting\n", __func__, i, bid); + + switch (bid) { + case 0: // Block 0: Alpha + layer.decoder_blocks[bid].alpha = create_tensor(tn(LLM_TENSOR_BLOCK_ALPHA, i, bid), {1, n_in, 1}, 0); + break; + case 1: // Block 1: Transition + { + int64_t trans_dim; + if (i == 2) trans_dim = 16; + else if (i == 3) trans_dim = 16; + else if (i == 4) trans_dim = 8; + else trans_dim = 4; // Assumed for i == 5 + LLAMA_LOG_DEBUG("%s: Layer %d, Block %d: Using trans_dim = %lld\n", __func__, i, bid, trans_dim); + layer.decoder_blocks[bid].up_weight = create_tensor(tn(LLM_TENSOR_TRANS_W, i, bid), {trans_dim, n_out, n_in}, 0); + layer.decoder_blocks[bid].up_scale = create_tensor(tn(LLM_TENSOR_TRANS_S, i, bid), {1, 1, n_in}, 0); + layer.decoder_blocks[bid].up_bias = create_tensor(tn(LLM_TENSOR_TRANS_B, i, bid), {n_out}, 0); + if (!layer.decoder_blocks[bid].up_bias) { + LLAMA_LOG_DEBUG("Failed to create decoder.%d.block.%d.trans.bias\n", i, bid); + } + } + break; + case 2: + { + LLAMA_LOG_DEBUG("%s: Layer %d, Block %d: Loading noise tensors\n", __func__, i, bid); + layer.decoder_blocks[bid].noise_w = create_tensor(tn(LLM_TENSOR_NOISE_W, i, bid), {1, n_out, n_out}, 0); + layer.decoder_blocks[bid].noise_s = create_tensor(tn(LLM_TENSOR_NOISE_S, i, bid), {1, 1, n_out}, 0); + } + break; + case 3: // Block 3: Residual Unit 1 + { + int res_unit_idx = 0; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; + res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A, i, bid), {1, n_out, 1}, 0); + res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W, i, bid), {7, 1, n_out}, 0); + res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S, i, bid), {1, 1, n_out}, 0); + res_unit.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B, i, bid), {n_out}, 0); + res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A, i, bid), {1, n_out, 1}, 0); + res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W, i, bid), {1, n_out, n_out}, 0); + res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S, i, bid), {1, 1, n_out}, 0); + res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B, i, bid), {n_out}, 0); + } + break; + case 4: // Block 4: Residual Unit 2 + { + int res_unit_idx = 1; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; + res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B4, i, bid), {1, n_out, 1}, 0); + res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B4, i, bid), {7, 1, n_out}, 0); + res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B4, i, bid), {1, 1, n_out}, 0); + res_unit.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B4, i, bid), {n_out}, 0); + res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B4, i, bid), {1, n_out, 1}, 0); + res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B4, i, bid), {1, n_out, n_out}, 0); + res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B4, i, 
bid), {1, 1, n_out}, 0); + res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B4, i, bid), {n_out}, 0); + } + break; + case 5: // Block 5: Residual Unit 3 + { + int res_unit_idx = 2; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; + res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B5, i, bid), {1, n_out, 1}, 0); + res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B5, i, bid), {7, 1, n_out}, 0); + res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B5, i, bid), {1, 1, n_out}, 0); + res_unit.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B5, i, bid), {n_out}, 0); + res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B5, i, bid), {1, n_out, 1}, 0); + res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B5, i, bid), {1, n_out, n_out}, 0); + res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B5, i, bid), {1, 1, n_out}, 0); + res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B5, i, bid), {n_out}, 0); + } + break; + default: + fprintf(stderr, "%s: ERROR: Unexpected block id %d in layer %d\n", __func__, bid, i); + return false; // Or handle error appropriately + } + fprintf(stderr, "%s: Layer %d, Block %d: Finished\n", __func__, i, bid); + } // End block loop + } + else if (i == 6) { // --- Layer 6: Alpha --- + layer.alpha = create_tensor(tn(LLM_TENSOR_ALPHA, i, -1), {1, n_in, 1}, 0); + } + else if (i == 7) { // --- Layer 7: Final Conv --- + layer.conv_w = create_tensor(tn(LLM_TENSOR_CONV_W7, i, -1), {7, n_in, n_out}, 0); + layer.conv_s = create_tensor(tn(LLM_TENSOR_CONV_S7, i, -1), {1, 1, n_out}, 0); + layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV_B7, i, -1), {n_out}, 0); + } + else { // Should not happen + fprintf(stderr, "%s: ERROR: Unexpected layer index %d\n", __func__, i); + return false; // Or handle error appropriately + } + fprintf(stderr, "%s: Layer %d: Finished\n", __func__, i); } } break; default: @@ -11686,94 +11747,286 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context { } }; +// struct llm_build_snac_dec : public llm_graph_context { + +// llm_build_snac_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +// LLAMA_LOG_INFO("Raw ubatch.n_tokens = %d\n", ubatch.n_tokens); +// for (int i = 0; i < std::min(20, (int)ubatch.n_tokens); ++i) { +// LLAMA_LOG_INFO("%d ", ubatch.token[i]); +// } +// LLAMA_LOG("\n"); +// LLAMA_LOG_DEBUG("%s: Entering constructor, model.layers.size() = %zu\n", __func__, model.layers.size()); +// ggml_tensor * cur; +// ggml_tensor * inpL; + +// // TODO: probalby just get raw codes +// //cur = build_inp_embd(model.tok_embd); +// //LLAMA_LOG_INFO("After build_inp_embd: shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + +// // hack, hardcode expected SNAC input at first conv layer +// cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, 64, 1, 1); // [channels, seq_len, 1, 1] +// ggml_set_input(cur); +// LLAMA_LOG_INFO("hardcoded shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + +// // end hack + +// // Log input tokens before processing +// LLAMA_LOG_INFO("%s: ubatch.n_tokens = %u\n", __func__, ubatch.n_tokens); +// LLAMA_LOG_WARN("%s: Input tokens from ubatch = ", __func__); +// for (uint32_t i = 0; i < ubatch.n_tokens && i < 20; ++i) { +// LLAMA_LOG_INFO("%d ", ubatch.token[i]); +// } +// if (ubatch.n_tokens > 20) LLAMA_LOG_INFO("..."); +// LLAMA_LOG("\n"); + +// // ggml_tensor * layer_1; +// // ggml_tensor * layer_2; +// // ggml_tensor * 
layer_3; +// //redistribute_codes(cur, &layer_1, &layer_2, &layer_3); + +// // Log the redistributed layers +// //log_tensor("Layer 1", layer_1); +// //log_tensor("Layer 2", layer_2); +// //log_tensor("Layer 3", layer_3); + +// for (uint32_t il = 1; il < model.layers.size(); ++il) { +// const auto & layer = model.layers[il]; + +// LLAMA_LOG_DEBUG("%s: Layer %u: Starting, cur = %p\n", __func__, il, cur); + +// if (il == 1) { // pointwise +// LLAMA_LOG_INFO("%s: Layer %u: Pointwise conv, conv_w = %p, conv_s = %p, conv_b = %p\n", +// __func__, il, layer.conv_w, layer.conv_s, layer.conv_b); +// LLAMA_LOG_INFO("Before transpose, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); +// cur = ggml_transpose(ctx0, cur); // [768, 512] -> [512, 768] +// LLAMA_LOG_INFO("After transpose, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); +// cur = apply_conv1d(cur, layer.conv_w, layer.conv_s, layer.conv_b, 1, 0); +// LLAMA_LOG_INFO("%s: Layer %u: After pointwise conv, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); +// } else if (il == model.layers.size() - 1) { +// LLAMA_LOG_INFO("%s: Layer %u: Final layer, alpha = %p, conv_w = %p, conv_s = %p, conv_b = %p\n", +// __func__, il, layer.alpha, layer.conv_w, layer.conv_s, layer.conv_b); +// cur = ggml_snake(ctx0, cur, layer.alpha); +// LLAMA_LOG_INFO("%s: Layer %u: After ggml_snake, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); +// cur = apply_conv1d(cur, layer.conv_w, layer.conv_s, layer.conv_b, 1, 3); +// LLAMA_LOG_INFO("%s: Layer %u: After final conv, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); +// cur = ggml_tanh(ctx0, cur); +// LLAMA_LOG_INFO("%s: Layer %u: After ggml_tanh, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? 
cur->ne[3] : -1); +// } else { +// // Layers 2-5: Decoder Blocks (1024 -> 512 -> 256 -> 128 -> 64) +// const int stride = hparams.upsample_rates[il - 2]; // 8 for il = 2 +// const int padding = stride; + +// // Block 0: Snake activation +// const auto & block0 = layer.decoder_blocks[0]; +// LLAMA_LOG_DEBUG("%s: Layer %u: Block 0, alpha = %p\n", __func__, il, block0.alpha); +// cur = ggml_snake(ctx0, cur, block0.alpha); +// LLAMA_LOG_DEBUG("%s: Layer %u: After ggml_snake, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, cur, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + +// // Block 1: Transposed convolution +// const auto & block1 = layer.decoder_blocks[1]; +// LLAMA_LOG_DEBUG("%s: Layer %u: Block 1, stride = %d, up_weight = %p, up_scale = %p, up_bias = %p\n", +// __func__, il, stride, block1.up_weight, block1.up_scale, block1.up_bias); + +// cur = apply_conv1d_transpose(cur, block1.up_weight, block1.up_scale, block1.up_bias, stride, padding); +// LLAMA_LOG_DEBUG("%s: Layer %u: After conv1d_transpose, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, cur, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + +// // Residual Units (3 per block) +// for (int j = 0; j < 3; ++j) { +// const auto & ru = block1.res_units[j]; +// ggml_tensor * inpL = cur; +// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: Starting, inpL = %p, alpha1 = %p, conv1_w = %p, conv1_s = %p, conv1_b = %p\n", +// __func__, il, j, inpL, ru.alpha1, ru.conv1_w, ru.conv1_s, ru.conv1_b); + +// cur = ggml_snake(ctx0, cur, ru.alpha1); +// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_snake (alpha1), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); +// int dilation = (j == 0) ? 1 : (j == 1) ? 3 : 9; +// int padding = 3 * dilation; // Kernel 7, dilated padding = (7-1)/2 * dilation +// cur = apply_conv1d(cur, ru.conv1_w, ru.conv1_s, ru.conv1_b, 1, padding); +// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After conv1d (conv1), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); + +// // pw +// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: Pointwise, alpha2 = %p, conv2_w = %p, conv2_s = %p, conv2_b = %p\n", +// __func__, il, j, ru.alpha2, ru.conv2_w, ru.conv2_s, ru.conv2_b); +// cur = ggml_snake(ctx0, cur, ru.alpha2); +// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_snake (alpha2), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); +// cur = apply_conv1d(cur, ru.conv2_w, ru.conv2_s, ru.conv2_b, 1, 0); +// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After conv1d (conv2), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); + +// // residual +// cur = ggml_add(ctx0, cur, inpL); +// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_add, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? 
cur->ne[3] : -1); +// } +// } +// LLAMA_LOG_DEBUG("%s: Layer %u: Finished, cur = %p\n", __func__, il, cur); +// } + +// int64_t target_samples = 24000; // TODO: magic number +// LLAMA_LOG_DEBUG("%s: Trimming output, cur = %p, target_samples = %ld, cur->ne[0] = %ld\n", +// __func__, cur, target_samples, cur ? cur->ne[0] : -1); +// if (cur->ne[0] > target_samples) { +// cur = ggml_get_rows(ctx0, cur, ggml_new_i32(ctx0, target_samples)); +// LLAMA_LOG_DEBUG("%s: After ggml_get_rows, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); +// } + +// LLAMA_LOG_DEBUG("%s: Setting result_embd, cur = %p\n", __func__, cur); +// cb(cur, "result_embd", -1); +// res->t_embd = cur; + +// LLAMA_LOG_DEBUG("%s: Building forward graph, cur = %p\n", __func__, cur); +// ggml_build_forward_expand(gf, cur); +// LLAMA_LOG_DEBUG("%s: Graph build completed\n", __func__); +// } + +// // TODO: move these somewhere else +// private: +// // Helper to log tensor contents +// void log_tensor(const char * name, ggml_tensor * tensor) { +// if (!tensor) { +// LLAMA_LOG_INFO("%s: %s is null\n", __func__, name); +// return; +// } +// LLAMA_LOG_DEBUG("%s: %s shape = [%ld, %ld, %ld, %ld], first 20 elements = ", +// __func__, name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); +// int n_elements = ggml_nelements(tensor); +// float * data = (float *)tensor->data; +// for (int i = 0; i < std::min(20, n_elements); ++i) { +// LLAMA_LOG_DEBUG("%.2f ", data[i]); +// } +// if (n_elements > 20) LLAMA_LOG_DEBUG("..."); +// LLAMA_LOG_DEBUG("\n"); +// } + +// void redistribute_codes(ggml_tensor * input, ggml_tensor ** layer_1, ggml_tensor ** layer_2, ggml_tensor ** layer_3) { +// int64_t n_codes = input->ne[1]; // Assuming input is [n_embd, n_tokens, 1, 1] +// int64_t n_frames = n_codes / 7; +// if (n_codes % 7 != 0) { +// LLAMA_LOG_ERROR("%s: Input codes length %ld is not a multiple of 7\n", __func__, n_codes); +// *layer_1 = *layer_2 = *layer_3 = nullptr; +// return; +// } + +// int64_t n_layer_1 = n_frames; // 1 code per frame +// int64_t n_layer_2 = n_frames * 2; // 2 codes per frame +// int64_t n_layer_3 = n_frames * 4; // 4 codes per frame + +// // Indices for each layer +// std::vector idx_layer_1(n_layer_1); +// std::vector idx_layer_2(n_layer_2); +// std::vector idx_layer_3(n_layer_3); + +// for (int64_t i = 0; i < n_frames; ++i) { +// int64_t base_idx = i * 7; +// idx_layer_1[i] = base_idx + 0; // No offset +// idx_layer_2[i * 2] = base_idx + 1; // Offset -4096 +// idx_layer_2[i * 2 + 1] = base_idx + 4; // Offset -16384 +// idx_layer_3[i * 4] = base_idx + 2; // Offset -8192 +// idx_layer_3[i * 4 + 1] = base_idx + 3; // Offset -12288 +// idx_layer_3[i * 4 + 2] = base_idx + 5; // Offset -20480 +// idx_layer_3[i * 4 + 3] = base_idx + 6; // Offset -24576 +// } + +// // Create index tensors +// ggml_tensor * idx_1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_1); +// ggml_tensor * idx_2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_2); +// ggml_tensor * idx_3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_3); + +// memcpy(idx_1->data, idx_layer_1.data(), n_layer_1 * sizeof(int32_t)); +// memcpy(idx_2->data, idx_layer_2.data(), n_layer_2 * sizeof(int32_t)); +// memcpy(idx_3->data, idx_layer_3.data(), n_layer_3 * sizeof(int32_t)); + +// // Extract layers using ggml_get_rows +// *layer_1 = ggml_get_rows(ctx0, input, idx_1); +// *layer_2 = ggml_get_rows(ctx0, input, idx_2); +// *layer_3 = ggml_get_rows(ctx0, 
input, idx_3); + +// // Apply offsets +// *layer_2 = ggml_add(ctx0, *layer_2, ggml_new_f32(ctx0, -4096.0f)); // Simplified; we'll refine offsets later +// *layer_3 = ggml_add(ctx0, *layer_3, ggml_new_f32(ctx0, -8192.0f)); // Simplified for now +// } + +// ggml_tensor * apply_conv1d(ggml_tensor * input, ggml_tensor * conv_w, ggml_tensor * conv_scale, ggml_tensor * conv_b, +// int stride, int padding) { +// ggml_tensor * w_final = normalize_weight(conv_w, conv_scale); +// ggml_tensor * cur = ggml_conv_1d_ph(ctx0, w_final, input, stride, padding); +// if (conv_b) { +// ggml_tensor* bias_reshaped = ggml_reshape_3d(ctx0, conv_b, 1, 1024, 1); +// cur = ggml_add(ctx0, cur, bias_reshaped); +// } +// return cur; +// } + +// ggml_tensor * apply_conv1d_transpose(ggml_tensor * input, ggml_tensor * up_weight, ggml_tensor * up_scale, ggml_tensor * up_bias, int stride, int padding) { +// // Normalize weights (temporary fix for up_scale shape mismatch) +// if (up_scale->ne[2] != up_weight->ne[1]) { // 1024 != 512 +// LLAMA_LOG_WARN("up_scale channels (%ld) don’t match output channels (%ld), expected behavior may vary\n", up_scale->ne[2], up_weight->ne[1]); +// // Ideally reshape up_scale to [1, 1, 512, 1], but no reshape; proceed with warning +// } +// ggml_tensor * w_final = normalize_weight(up_weight, up_scale); +// LLAMA_LOG_INFO("After normalize weight: w_final shape = [%ld, %ld, %ld, %ld]\n", +// w_final->ne[0], w_final->ne[1], w_final->ne[2], w_final->ne[3]); + +// ggml_tensor * cur = ggml_conv_transpose_1d(ctx0, w_final, input, stride, 0, 1); +// LLAMA_LOG_INFO("After ggml_conv_transpose_1d = [%ld, %ld, %ld, %ld]\n", +// cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + +// if (up_bias) { +// // up_bias is [512, 1, 1, 1]; need [4104, 512, 1, 1] for ggml_add +// LLAMA_LOG_INFO("entering up_bias block. Before ggml_repeat, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); +// LLAMA_LOG_INFO("Before ggml_repeat, up_bias shape = [%ld, %ld, %ld, %ld]\n", up_bias->ne[0], up_bias->ne[1], up_bias->ne[2], up_bias->ne[3]); +// ggml_tensor * bias_repeated = ggml_repeat(ctx0, up_bias, cur); +// LLAMA_LOG_DEBUG("Repeated up_bias to shape = [%ld, %ld, %ld, %ld]\n", +// bias_repeated->ne[0], bias_repeated->ne[1], bias_repeated->ne[2], bias_repeated->ne[3]); +// cur = ggml_add(ctx0, cur, bias_repeated); +// LLAMA_LOG_DEBUG("After bias add: cur shape = [%ld, %ld, %ld, %ld]\n", +// cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); +// } +// return cur; +// } + +// // w_final = scale * (w / || w ||) +// ggml_tensor * normalize_weight(ggml_tensor * w, ggml_tensor * scale) { +// ggml_tensor * norm = ggml_norm(ctx0, w, 1e-5f); // 1e-8f ? 
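+// NOTE: ggml_norm() returns a mean/variance-normalized tensor along ne[0]
+// (layer-norm style), not the L2 norm of w, so ggml_div(w, ggml_norm(w)) is
+// probably not the w / ||w|| the comment above describes; weight norm likely
+// needs a per-output-channel sqrt(sum(w^2)) here instead.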
+// ggml_tensor * w_normalized = ggml_div(ctx0, w, norm); +// ggml_tensor * w_final = ggml_mul(ctx0, w_normalized, scale); +// return w_final; +// } +// }; + +// TODO: Placeholder struct llm_build_snac_dec : public llm_graph_context { - llm_build_snac_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { - ggml_tensor * cur; - - cur = build_inp_embd(model.tok_embd); - for (uint32_t il = 0; il < model.layers.size(); ++il) { - const auto & layer = model.layers[il]; - - if (il == 0) { // depthwise - cur = apply_conv1d(cur, layer.conv_w, layer.conv_scale, layer.conv_b, 1, 3); - } else if (il == 1) { // pointwise - cur = apply_conv1d(cur, layer.conv_w, layer.conv_scale, layer.conv_b, 1, 0); - } else if (il == model.layers.size() - 1) { - cur = ggml_snake(ctx0, cur, layer.alpha); - cur = apply_conv1d(cur, layer.conv_w, layer.conv_scale, layer.conv_b, 1, 3); - cur = ggml_tanh(ctx0, cur); - } else { - // Layers 2-5: Decoder Blocks (1024 -> 512 -> 256 -> 128 -> 64) - const auto & block = layer.decoder_blocks[0]; - const int stride = hparams.upsample_rates[il - 2]; - - cur = ggml_snake(ctx0, cur, block.alpha); - cur = apply_conv1d_transpose(cur, block.up_weight, block.up_scale, block.up_bias, stride, stride); - - // Residual Units (3 per block) - for (int j = 0; j < 3; ++j) { - const auto & ru = block.res_units[j]; - ggml_tensor * inpL = cur; - - cur = ggml_snake(ctx0, cur, ru.alpha1); - int dilation = (j == 0) ? 1 : (j == 1) ? 3 : 9; - int padding = 3 * dilation; // Kernel 7, dilated padding = (7-1)/2 * dilation - cur = apply_conv1d(cur, ru.conv1_w, ru.conv1_scale, ru.conv1_b, 1, padding); - - // pw - cur = ggml_snake(ctx0, cur, ru.alpha2); - cur = apply_conv1d(cur, ru.conv2_w, ru.conv2_scale, ru.conv2_b, 1, 0); + llm_build_snac_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { - // residual - cur = ggml_add(ctx0, cur, inpL); - } - } + // TODO: Remove + LLAMA_LOG_INFO("Raw ubatch.n_tokens = %d\n", ubatch.n_tokens); + for (int i = 0; i < std::min(20, (int)ubatch.n_tokens); ++i) { + LLAMA_LOG_INFO("%d ", ubatch.token[i]); } + LLAMA_LOG("\n"); + ggml_tensor * cur; + // TODO: Hack. 
Implement codebook lookups and out_proj + cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, 64, 1, 1); + ggml_set_input(cur); + // end hack - int64_t target_samples = 24000; // TODO: magic number - if (cur->ne[0] > target_samples) { - cur = ggml_get_rows(ctx0, cur, ggml_new_i32(ctx0, target_samples)); - } - + LLAMA_LOG_DEBUG("%s: Setting result_embd, cur = %p\n", __func__, cur); cb(cur, "result_embd", -1); res->t_embd = cur; - ggml_build_forward_expand(gf, cur); } - - // TODO: move these somewhere else -private: - ggml_tensor * apply_conv1d(ggml_tensor * input, ggml_tensor * conv_w, ggml_tensor * conv_scale, ggml_tensor * conv_b, - int stride, int padding) { - ggml_tensor * w_final = normalize_weight(conv_w, conv_scale); - ggml_tensor * cur = ggml_conv_1d_ph(ctx0, w_final, input, stride, padding); - if (conv_b) { - cur = ggml_add(ctx0, cur, conv_b); - } - return cur; - } - - ggml_tensor * apply_conv1d_transpose(ggml_tensor * input, ggml_tensor * up_weight, ggml_tensor * up_scale, - ggml_tensor * up_bias, int stride, int padding) { - ggml_tensor * w_final = normalize_weight(up_weight, up_scale); - int kernel_size = up_weight->ne[0]; - int output_padding = stride % 2; // 0 for even strides (8, 4, 2) - ggml_tensor * cur = ggml_conv_transpose_1d(ctx0, w_final, input, stride, padding / 2, output_padding); - if (up_bias) { - cur = ggml_add(ctx0, cur, up_bias); - } - return cur; - } - - // w_final = scale * (w / || w ||) - ggml_tensor * normalize_weight(ggml_tensor * w, ggml_tensor * scale) { - ggml_tensor * norm = ggml_norm(ctx0, w, 1e-5f); // 1e-8f ? - ggml_tensor * w_normalized = ggml_div(ctx0, w, norm); - ggml_tensor * w_final = ggml_mul(ctx0, w_normalized, scale); - return w_final; - } }; llama_memory_i * llama_model::create_memory() const { diff --git a/src/llama-model.h b/src/llama-model.h index f9286cb48fc36..5e636b0b3b3f3 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -138,20 +138,23 @@ struct llama_layer_convnext { }; struct llama_layer_snac_dec_block { - struct ggml_tensor * alpha = nullptr; // for snake activation + struct ggml_tensor * alpha = nullptr; struct ggml_tensor * up_weight = nullptr; struct ggml_tensor * up_scale = nullptr; struct ggml_tensor * up_bias = nullptr; + struct ggml_tensor * noise_w = nullptr; + struct ggml_tensor * noise_s = nullptr; + struct { struct ggml_tensor * alpha1 = nullptr; struct ggml_tensor * conv1_w = nullptr; - struct ggml_tensor * conv1_scale = nullptr; + struct ggml_tensor * conv1_s = nullptr; struct ggml_tensor * conv1_b = nullptr; struct ggml_tensor * alpha2 = nullptr; struct ggml_tensor * conv2_w = nullptr; - struct ggml_tensor * conv2_scale = nullptr; + struct ggml_tensor * conv2_s = nullptr; struct ggml_tensor * conv2_b = nullptr; } res_units[3]; }; @@ -325,7 +328,7 @@ struct llama_layer { struct llama_layer_convnext convnext; struct ggml_tensor * conv_w = nullptr; - struct ggml_tensor * conv_scale = nullptr; + struct ggml_tensor * conv_s = nullptr; struct ggml_tensor * conv_b = nullptr; struct ggml_tensor * alpha = nullptr; @@ -362,6 +365,13 @@ struct llama_model { struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; + + // TODO: structify + ggml_tensor * codebook[3]; + ggml_tensor * codebook_proj_b[3]; // Array for quantizer 0, 1, 2 bias + ggml_tensor * codebook_proj_s[3]; // Array for quantizer 0, 1, 2 scale + ggml_tensor * codebook_proj_w[3]; + std::vector layers; llama_model_params params; From 1a6fa9865915e1cabed17f74ad02cafe08eda137 Mon Sep 17 00:00:00 2001 From: Ja Morphy Date: Tue, 1 Apr 2025 
21:55:25 -0700 Subject: [PATCH 5/7] cleanup --- examples/tts/orpheus-tts.cpp | 2 +- src/llama-context.cpp | 2 -- src/llama-model.cpp | 5 +---- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/examples/tts/orpheus-tts.cpp b/examples/tts/orpheus-tts.cpp index 622ec46fde05a..45595e9552fc0 100644 --- a/examples/tts/orpheus-tts.cpp +++ b/examples/tts/orpheus-tts.cpp @@ -298,7 +298,7 @@ int main(int argc, char **argv) { params.model = params.vocoder.model; params.n_batch = 2; - params.embedding = true + params.embedding = true; // disable warmup, SNAC doesn't care about BOS or EOS tokens; params.warmup = false; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ca4adaa781cb3..5bec63e2e79ff 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -312,9 +312,7 @@ llama_context::llama_context( // reserve pp graph first so that buffers are only allocated once { - LLAMA_LOG_DEBUG("here 3\n"); llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - auto * gf = graph_init(); graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); if (!ggml_backend_sched_reserve(sched.get(), gf)) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e711b28684837..bee6e6bd359b4 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1495,8 +1495,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (flags & TENSOR_NOT_REQUIRED) { return nullptr; } - return nullptr; - //throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); + throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); } // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops @@ -1591,8 +1590,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { return t; } } - fprintf(stderr, "create_tensor: Creating '%s' with ne=[%ld, %ld, %ld]\n", - tn_str.c_str(), ne.begin()[0], ne.begin()[1], ne.begin()[2]); return ml.create_tensor(ctx, tn, ne, flags); }; From 9d25ca1d21768cd36926bbbc1aa7eaaa70f87646 Mon Sep 17 00:00:00 2001 From: Ja Morphy Date: Sun, 6 Apr 2025 01:06:59 -0700 Subject: [PATCH 6/7] A forward pass Run forward passes with dummy codes. Output tensor shapes (raw audio samples) seem to match expected shape given number of input frames. Attempts with Orpheus to be done soon. 
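A rough sanity check on the "expected shape" (back-of-the-envelope only,
assuming the output length is driven purely by the transposed-conv strides and
ignoring padding from the kernel-7 convs): each decoder input frame should
expand by the product of the upsample rates, 8 * 8 * 4 * 2 = 512 samples at
24 kHz. Sketch only, the helper name is made up:

    // expected sample count for a given number of decoder input frames,
    // assuming pure upsampling by the stride product (no padding effects)
    static int64_t snac_expected_samples(int64_t n_dec_frames) {
        const int64_t rates[] = {8, 8, 4, 2}; // hparams.upsample_rates
        int64_t up = 1;
        for (int64_t r : rates) up *= r;      // 512
        return n_dec_frames * up;             // e.g. 64 frames -> 32768 samples (~1.4 s)
    }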
The gguf used in this commit is at: https://huggingface.co/jamorphy/snac-fwd-pass-devel-gguf --- convert_hf_to_gguf.py | 2 +- examples/tts/orpheus-tts.cpp | 343 +++++++--------------- ggml/src/ggml-cpu/ggml-cpu.c | 1 + include/llama.h | 2 + src/llama-context.cpp | 36 ++- src/llama-context.h | 2 + src/llama-model.cpp | 551 ++++++++++++++++------------------- src/llama-model.h | 4 +- 8 files changed, 396 insertions(+), 545 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 01ec22aa3cc28..093e769e338f3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2494,7 +2494,7 @@ def set_vocab(self): def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_vocab_size (4096) # TODO: Fix + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) self.gguf_writer.add_uint32("snac.quantizer.codebook_size", self.hparams["codebook_size"]) self.gguf_writer.add_uint32("snac.quantizer.codebook_dim", self.hparams["codebook_dim"]) self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"]) # 1024 diff --git a/examples/tts/orpheus-tts.cpp b/examples/tts/orpheus-tts.cpp index 45595e9552fc0..a7f0e16dfa296 100644 --- a/examples/tts/orpheus-tts.cpp +++ b/examples/tts/orpheus-tts.cpp @@ -1,6 +1,5 @@ #include "common.h" #include "llama.h" -#include "llama-impl.h" #include "log.h" #include "arg.h" #include "sampling.h" @@ -19,148 +18,30 @@ #include #include -std::vector redistribute_codes(const std::vector& raw_codes) { - std::vector snac_codes; - for (size_t i = 0; i < raw_codes.size(); i += 7) { - // Ensure we have a full frame (7 codes) - if (i + 6 >= raw_codes.size()) break; - - // Frame offsets (per notebook) - snac_codes.push_back(raw_codes[i]); // Codebook 0 (no offset) - snac_codes.push_back(raw_codes[i+1] - 4096); // Codebook 1 - snac_codes.push_back(raw_codes[i+2] - 8192); // Codebook 2 - snac_codes.push_back(raw_codes[i+3] - 12288); // Codebook 2 - snac_codes.push_back(raw_codes[i+4] - 16384); // Codebook 1 - snac_codes.push_back(raw_codes[i+5] - 20480); // Codebook 2 - snac_codes.push_back(raw_codes[i+6] - 24576); // Codebook 2 - } - return snac_codes; -} - -static std::vector embd_to_audio( - const float * embd, - const int n_codes, - const int n_embd, - const int n_thread); -static bool save_wav16(const std::string & fname, const std::vector & data, int sample_rate); -static void fill_hann_window(int length, bool periodic, float * output); -static void irfft(int n, const float * inp_cplx, float * out_real); -static void fold(const std::vector & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector & output); - -static void print_usage(int /*argc*/, char **argv) { - LOG("\nexample usage:\n"); - LOG("\n %s -m model.gguf -mv vocoder.gguf -p \"Hello world\"\n", argv[0]); - LOG("\n"); -} - -static void prompt_add(std::vector &prompt, const llama_vocab *vocab, const std::string &txt, bool add_special, bool parse_special) { - auto tmp = common_tokenize(vocab, txt, add_special, parse_special); - prompt.insert(prompt.end(), tmp.begin(), tmp.end()); -} - - -// // Include embd_to_audio and save_wav16 from tts.cpp (for now) -static std::vector embd_to_audio( - const float * embd, - const int n_codes, - const int n_embd, - const int n_thread) { - const int n_fft = 1280; - const int n_hop = 320; - const int n_win = 1280; - const int n_pad = (n_win - n_hop)/2; - const int n_out = (n_codes - 1)*n_hop + n_win; - - std::vector hann(n_fft); - fill_hann_window(hann.size(), true, hann.data()); - - int n_spec = n_embd*n_codes; 
- - std::vector E (n_spec); - std::vector S (n_spec); - std::vector ST(n_spec); - - for (int l = 0; l < n_codes; ++l) { - for (int k = 0; k < n_embd; ++k) { - E[k*n_codes + l] = embd[l*n_embd + k]; - } - } - - for (int k = 0; k < n_embd/2; ++k) { - for (int l = 0; l < n_codes; ++l) { - float mag = E[(k )*n_codes + l]; - float phi = E[(k + n_embd/2)*n_codes + l]; - mag = exp(mag); - if (mag > 1e2) { - mag = 1e2; - } - S[2*(k*n_codes + l) + 0] = mag*cosf(phi); - S[2*(k*n_codes + l) + 1] = mag*sinf(phi); - } - } - - for (int l = 0; l < n_codes; ++l) { - for (int k = 0; k < n_embd/2; ++k) { - ST[l*n_embd + 2*k + 0] = S[2*(k*n_codes + l) + 0]; - ST[l*n_embd + 2*k + 1] = S[2*(k*n_codes + l) + 1]; - } - } - - std::vector res (n_codes*n_fft); - std::vector hann2(n_codes*n_fft); - - std::vector workers(n_thread); - for (int i = 0; i < n_thread; ++i) { - workers[i] = std::thread([&, i]() { - for (int l = i; l < n_codes; l += n_thread) { - irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft); - for (int j = 0; j < n_fft; ++j) { - res [l*n_fft + j] *= hann[j]; - hann2[l*n_fft + j] = hann[j] * hann[j]; - } - } - }); - } - for (int i = 0; i < n_thread; ++i) { - workers[i].join(); - } - - std::vector audio; - std::vector env; - - fold(res, n_out, n_win, n_hop, n_pad, audio); - fold(hann2, n_out, n_win, n_hop, n_pad, env); - - for (size_t i = 0; i < audio.size(); ++i) { - audio[i] /= env[i]; - } - - return audio; -} - -static bool save_wav16(const std::string & fname, const std::vector & data, int sample_rate) { +struct wav_header { + char riff[4] = {'R', 'I', 'F', 'F'}; + uint32_t chunk_size; + char wave[4] = {'W', 'A', 'V', 'E'}; + char fmt[4] = {'f', 'm', 't', ' '}; + uint32_t fmt_chunk_size = 16; + uint16_t audio_format = 1; // PCM + uint16_t num_channels = 1; // Mono + uint32_t sample_rate; + uint32_t byte_rate; + uint16_t block_align; + uint16_t bits_per_sample = 16; + char data[4] = {'d', 'a', 't', 'a'}; + uint32_t data_size; +}; + +static bool save_wav16(const std::string &fname, const std::vector &data, int sample_rate) { std::ofstream file(fname, std::ios::binary); if (!file) { LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str()); return false; } - struct wav_header { - char riff[4] = {'R', 'I', 'F', 'F'}; - uint32_t chunk_size; - char wave[4] = {'W', 'A', 'V', 'E'}; - char fmt[4] = {'f', 'm', 't', ' '}; - uint32_t fmt_chunk_size = 16; - uint16_t audio_format = 1; // PCM - uint16_t num_channels = 1; // Mono - uint32_t sample_rate; - uint32_t byte_rate; - uint16_t block_align; - uint16_t bits_per_sample = 16; - char data[4] = {'d', 'a', 't', 'a'}; - uint32_t data_size; - } header; - + wav_header header; header.sample_rate = sample_rate; header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8); header.block_align = header.num_channels * (header.bits_per_sample / 8); @@ -169,95 +50,49 @@ static bool save_wav16(const std::string & fname, const std::vector & dat file.write(reinterpret_cast(&header), sizeof(header)); - for (const auto & sample : data) { - int16_t pcm_sample = static_cast(std::clamp(sample * 32767.0, -32768.0, 32767.0)); + for (const auto &sample : data) { + int16_t pcm_sample = static_cast(std::clamp(sample * 32767.0f, -32768.0f, 32767.0f)); file.write(reinterpret_cast(&pcm_sample), sizeof(pcm_sample)); } return file.good(); } -// Supporting functions from tts.cpp (for embd_to_audio) -static void fill_hann_window(int length, bool periodic, float * output) { - int offset = -1; - if (periodic) { - offset = 0; - } - for 
(int i = 0; i < length; i++) { - output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); - } +std::vector redistribute_codes(const std::vector& raw_codes) { + std::vector snac_codes; + for (size_t i = 0; i < raw_codes.size(); i += 7) { + if (i + 6 >= raw_codes.size()) break; + + // Subtract 128266 base and layer-specific offsets + snac_codes.push_back(raw_codes[i] - 128266); // Layer 1: offset 0 + snac_codes.push_back(raw_codes[i + 1] - 128266 - 4096); // Layer 2: offset 4096 + snac_codes.push_back(raw_codes[i + 2] - 128266 - 8192); // Layer 3: offset 8192 + snac_codes.push_back(raw_codes[i + 3] - 128266 - 12288); // Layer 3: offset 12288 + snac_codes.push_back(raw_codes[i + 4] - 128266 - 16384); // Layer 2: offset 16384 + snac_codes.push_back(raw_codes[i + 5] - 128266 - 20480); // Layer 3: offset 20480 + snac_codes.push_back(raw_codes[i + 6] - 128266 - 24576); // Layer 3: offset 24576 + } + return snac_codes; } -static void twiddle(float * real, float * imag, int k, int N) { - float angle = 2 * M_PI * k / N; - *real = cos(angle); - *imag = sin(angle); -} - -static void irfft(int n, const float * inp_cplx, float * out_real) { - int N = n / 2 + 1; - - std::vector real_input(N); - std::vector imag_input(N); - for (int i = 0; i < N; ++i) { - real_input[i] = inp_cplx[2 * i]; - imag_input[i] = inp_cplx[2 * i + 1]; - } - - std::vector real_output(n); - std::vector imag_output(n); - - for (int k = 0; k < n; ++k) { - real_output[k] = 0.0f; - imag_output[k] = 0.0f; - for (int m = 0; m < N; ++m) { - float twiddle_real; - float twiddle_imag; - - twiddle(&twiddle_real, &twiddle_imag, k * m, n); - - real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag; - imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real; - } - } - - for (int i = 0; i < n; ++i) { - out_real[i] = real_output[i] / N; - } +static void print_usage(int /*argc*/, char **argv) { + LOG("\nexample usage:\n"); + LOG("\n %s -m model.gguf -mv vocoder.gguf -p \"Hello world\"\n", argv[0]); + LOG("\n"); } -static void fold(const std::vector & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector & output) { - int64_t output_height = n_out; - int64_t kernel_w = n_win; - int64_t stride_w = n_hop; - int64_t width = n_out; - - output.resize(width, 0.0f); - - int64_t col_idx = 0; - for (int64_t w_col = 0; w_col < width; ++w_col) { - int64_t start = w_col * stride_w - n_pad; - int64_t end = start + kernel_w; - - for (int64_t w_im = start; w_im < end; ++w_im) { - if (w_im >= 0 && w_im < output_height && col_idx < (int64_t) data.size()) { - output[w_im] += data[col_idx]; - } - col_idx++; - } - } - - output.resize(n_out - 2 * n_pad); +static void prompt_add(std::vector &prompt, const llama_vocab *vocab, const std::string &txt, bool add_special, bool parse_special) { + auto tmp = common_tokenize(vocab, txt, add_special, parse_special); + prompt.insert(prompt.end(), tmp.begin(), tmp.end()); } int main(int argc, char **argv) { common_params params; - + params.model = "models/orpheus-3b-0.1-ft-q4_k_m.gguf"; - params.vocoder.model = "models/snac-vocab.gguf"; + params.vocoder.model = "models/snac-fwd-pass-devel.gguf"; params.out_file = "output.wav"; - params.n_predict = 1200; params.sampling.top_k = 4; params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; params.n_batch = 4096; @@ -265,7 +100,8 @@ int main(int argc, char **argv) { common_init(); llama_backend_init(); llama_numa_init(params.numa); - + + common_init_result orpheus_init_ttc = common_init_from_params(params); 
llama_model * model_ttc = NULL; @@ -290,17 +126,15 @@ int main(int argc, char **argv) { prompt_add(tokens, vocab, "", false, true); // Emotion tag tokens.push_back(128009); // <|eot_id|> tokens.push_back(128260); // <|endofhuman|> - + llama_model * model_cts = NULL; llama_context * ctx_cts = NULL; params.model = params.vocoder.model; - params.n_batch = 2; params.embedding = true; - // disable warmup, SNAC doesn't care about BOS or EOS tokens; - params.warmup = false; + params.warmup = false; // SNAC doesn't care about BOS or EOS tokens common_init_result snac_init_cts = common_init_from_params(params); LOG_INF("SNAC model loaded: %s\n", params.model.c_str()); @@ -308,35 +142,80 @@ int main(int argc, char **argv) { model_cts = snac_init_cts.model.get(); ctx_cts = snac_init_cts.context.get(); - std::vector speech_codes = {100, 4200, 8500, 12500, 16500, 21000, 25000, - 200, 4300, 8600, 12600, 16600, 21111, 25100}; - - std::vector snac_codes = redistribute_codes(speech_codes); - - const int n_codes = speech_codes.size(); - const int batch_size = n_codes; - - llama_batch batch = llama_batch_init(batch_size, 0, 1); - - for (size_t i = 0; i < n_codes; ++i) { + // TODO: Use real orpheus codes + // Just some random numbers for testing + std::vector orpheus_codes = { + // Frame 1, 7 codes per frame + 128266 + 100, // L1: 100 + 128266 + 4096 + 200, // L2: 200 + 128266 + 8192 + 300, // L3: 300 + 128266 + 12288 + 400,// L3: 400 + 128266 + 16384 + 500,// L2: 500 + 128266 + 20480 + 600,// L3: 600 + 128266 + 24576 + 700,// L3: 700 + // Frame 2 + 128266 + 150, 128266 + 4096 + 250, 128266 + 8192 + 350, 128266 + 12288 + 450, + 128266 + 16384 + 550, 128266 + 20480 + 650, 128266 + 24576 + 750, + // Frame 3 + 128266 + 110, 128266 + 4096 + 210, 128266 + 8192 + 310, 128266 + 12288 + 410, + 128266 + 16384 + 510, 128266 + 20480 + 610, 128266 + 24576 + 710, + // Frame 4 + 128266 + 120, 128266 + 4096 + 220, 128266 + 8192 + 320, 128266 + 12288 + 420, + 128266 + 16384 + 520, 128266 + 20480 + 620, 128266 + 24576 + 720, + // Frame 5 + 128266 + 130, 128266 + 4096 + 230, 128266 + 8192 + 330, 128266 + 12288 + 430, + 128266 + 16384 + 530, 128266 + 20480 + 630, 128266 + 24576 + 730, + // Frame 6 + 128266 + 140, 128266 + 4096 + 240, 128266 + 8192 + 340, 128266 + 12288 + 440, + 128266 + 16384 + 540, 128266 + 20480 + 640, 128266 + 24576 + 740, + // Frame 7 + 128266 + 160, 128266 + 4096 + 260, 128266 + 8192 + 360, 128266 + 12288 + 460, + 128266 + 16384 + 560, 128266 + 20480 + 660, 128266 + 24576 + 760, + // Frame 8 + 128266 + 170, 128266 + 4096 + 270, 128266 + 8192 + 370, 128266 + 12288 + 470, + 128266 + 16384 + 570, 128266 + 20480 + 670, 128266 + 24576 + 770, + // Frame 9 + 128266 + 180, 128266 + 4096 + 280, 128266 + 8192 + 380, 128266 + 12288 + 480, + 128266 + 16384 + 580, 128266 + 20480 + 680, 128266 + 24576 + 780, + // Frame 10 + 128266 + 190, 128266 + 4096 + 290, 128266 + 8192 + 390, 128266 + 12288 + 490, + 128266 + 16384 + 590, 128266 + 20480 + 690, 128266 + 24576 + 790 + }; + + std::vector snac_codes = redistribute_codes(orpheus_codes); + + const int batch_size = snac_codes.size(); + + llama_batch batch = llama_batch_init(batch_size, 0, 1); + + for (size_t i = 0; i < batch_size; ++i) { common_batch_add(batch, snac_codes[i], i, {0}, true); } LOG_INF("Batch before decode: n_tokens = %d\n", batch.n_tokens); - if (llama_decode(ctx_cts, batch) != 0) { /* error */ } - - if (llama_decode(ctx_cts, batch) != 0) { /* error */ } - GGML_ASSERT(batch.n_tokens == n_codes); + GGML_ASSERT(batch.n_tokens == batch_size); 
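+    // note: every position was already flagged as an output via common_batch_add
+    // above; the SNAC branch in llama_context::decode overwrites n_outputs with the
+    // number of audio samples anyway, so the per-token output flags are mostly moot here.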
batch.logits[batch.n_tokens - 1] = true; - + if (llama_decode(ctx_cts, batch) != 0) { LOG_ERR("Failed to decode SNAC batch\n"); return 1; } - llama_synchronize(ctx_cts); - LOG_INF("SNAC decode completed\n"); + llama_synchronize(ctx_cts); + + float* embd = llama_get_embeddings(ctx_cts); + if (!embd) { + LOG_ERR("No embeddings available\n"); + return 1; + } + + int n_samples = llama_get_n_outputs(ctx_cts); + std::vector audio(n_samples); + LOG_INF("n_samples: %i\n", n_samples); + memcpy(audio.data(), embd, n_samples * sizeof(float)); + + save_wav16(params.out_file, audio, 24000); llama_batch_free(batch); llama_backend_free(); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index def6eb3423c61..7bded06f88a94 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -14894,6 +14894,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_LEAKY_RELU: + case GGML_OP_SNAKE: { n_tasks = 1; } break; diff --git a/include/llama.h b/include/llama.h index 6a44be404d914..f98f1910bcf1c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -629,6 +629,8 @@ extern "C" { llama_seq_id * cells_sequences; }; + LLAMA_API int32_t llama_get_n_outputs(struct llama_context * ctx); + // Create an empty KV cache view. (use only for debugging purposes) LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5bec63e2e79ff..d15061655da39 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -851,6 +851,10 @@ float * llama_context::get_logits_ith(int32_t i) { } } +int32_t llama_context::get_n_outputs() { + return n_outputs; +} + float * llama_context::get_embeddings() { // reorder embeddings for backward compatibility output_reorder(); @@ -1403,10 +1407,21 @@ int llama_context::decode(llama_batch & inp_batch) { GGML_ASSERT(embd != nullptr); float * embd_out = embd + n_outputs_prev*n_embd; - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + if (model.arch == LLM_ARCH_SNAC_DEC) { + // TODO: hack, SNAC outputs audio samples, not embeddings + // Rely on n_outputs for now, but perhaps add an `n_samples_snac` to + // llama_context to avoid doing these checks + int64_t n_samples = t_embd->ne[0]; + if (n_samples > 0) { + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_samples * sizeof(float)); + n_outputs = n_samples; // Update for downstream + } + } else { + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs * n_embd * sizeof(float)); + } } } break; case LLAMA_POOLING_TYPE_MEAN: @@ -1471,8 +1486,11 @@ int llama_context::decode(llama_batch & inp_batch) { } } - // set to total number of outputs in the batch, for use in llama_get_logits_ith - n_outputs = n_outputs_all; + // TODO: Hack for now to avoid overwriting n_outputs in previous step + if (model.arch != LLM_ARCH_SNAC_DEC) { + // set to total number of outputs in the batch, for use in llama_get_logits_ith + n_outputs = n_outputs_all; + } // wait for the computation to finish (automatically done when obtaining the model 
output) //synchronize(); @@ -2417,6 +2435,12 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) { return ctx->get_logits_ith(i); } +int32_t llama_get_n_outputs(struct llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_n_outputs(); +} + float * llama_get_embeddings(llama_context * ctx) { ctx->synchronize(); diff --git a/src/llama-context.h b/src/llama-context.h index 04facb544cb1a..ff9ad663d1fe5 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -48,6 +48,8 @@ struct llama_context { float * get_logits(); float * get_logits_ith(int32_t i); + int32_t get_n_outputs(); + float * get_embeddings(); float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index bee6e6bd359b4..4051c42852039 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1319,13 +1319,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_SNAC_DEC: { - hparams.n_channels = {768, 1024, 512, 256, 128, 64, 1}; // From decoder_channel_dims + // TODO: Read from GGUF + hparams.n_channels = {768, 1024, 512, 256, 128, 64, 1}; hparams.upsample_rates = {8, 8, 4, 2}; hparams.n_embd = 768; hparams.n_layer = 8; - // Dummy KV cache params to satisfy llama.cpp - for (uint32_t i = 0; i < 7; ++i) { // n_total_layers = 8 + // Dummy KV cache params to satisfy init error + for (uint32_t i = 0; i < hparams.n_layer; ++i) { hparams.n_head_arr[i] = 1; hparams.n_head_kv_arr[i] = 1; } @@ -3716,8 +3717,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {8, 4096, 1}, 0); - hparams.n_channels = {768, 1024, 512, 256, 128, 64, 1}; - // Quantizer projection tensors (0, 1, 2) for (int qid = 0; qid < 3; ++qid) { fprintf(stderr, "%s: Loading quantizer %d tensors\n", __func__, qid); @@ -3782,49 +3781,49 @@ bool llama_model::load_tensors(llama_model_loader & ml) { break; case 3: // Block 3: Residual Unit 1 { - int res_unit_idx = 0; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; - res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A, i, bid), {1, n_out, 1}, 0); - res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W, i, bid), {7, 1, n_out}, 0); - res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S, i, bid), {1, 1, n_out}, 0); - res_unit.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B, i, bid), {n_out}, 0); - res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A, i, bid), {1, n_out, 1}, 0); - res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W, i, bid), {1, n_out, n_out}, 0); - res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S, i, bid), {1, 1, n_out}, 0); - res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B, i, bid), {n_out}, 0); + auto & ru = layer.decoder_blocks[bid].res_unit; + ru.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A, i, bid), {1, n_out, 1}, 0); + ru.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W, i, bid), {7, 1, n_out}, 0); + ru.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S, i, bid), {1, 1, n_out}, 0); + ru.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B, i, bid), {n_out}, 0); + ru.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A, i, bid), {1, n_out, 1}, 0); + ru.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W, i, bid), {1, n_out, n_out}, 0); + ru.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S, i, bid), {1, 1, n_out}, 0); + ru.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B, i, bid), {n_out}, 0); } break; case 4: // Block 4: 
Residual Unit 2 { - int res_unit_idx = 1; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; - res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B4, i, bid), {1, n_out, 1}, 0); - res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B4, i, bid), {7, 1, n_out}, 0); - res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B4, i, bid), {1, 1, n_out}, 0); - res_unit.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B4, i, bid), {n_out}, 0); - res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B4, i, bid), {1, n_out, 1}, 0); - res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B4, i, bid), {1, n_out, n_out}, 0); - res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B4, i, bid), {1, 1, n_out}, 0); - res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B4, i, bid), {n_out}, 0); + auto & ru = layer.decoder_blocks[bid].res_unit; + ru.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B4, i, bid), {1, n_out, 1}, 0); + ru.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B4, i, bid), {7, 1, n_out}, 0); + ru.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B4, i, bid), {1, 1, n_out}, 0); + ru.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B4, i, bid), {n_out}, 0); + ru.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B4, i, bid), {1, n_out, 1}, 0); + ru.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B4, i, bid), {1, n_out, n_out}, 0); + ru.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B4, i, bid), {1, 1, n_out}, 0); + ru.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B4, i, bid), {n_out}, 0); } break; case 5: // Block 5: Residual Unit 3 { - int res_unit_idx = 2; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; - res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B5, i, bid), {1, n_out, 1}, 0); - res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B5, i, bid), {7, 1, n_out}, 0); - res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B5, i, bid), {1, 1, n_out}, 0); - res_unit.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B5, i, bid), {n_out}, 0); - res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B5, i, bid), {1, n_out, 1}, 0); - res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B5, i, bid), {1, n_out, n_out}, 0); - res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B5, i, bid), {1, 1, n_out}, 0); - res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B5, i, bid), {n_out}, 0); + auto & ru = layer.decoder_blocks[bid].res_unit; + ru.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B5, i, bid), {1, n_out, 1}, 0); + ru.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B5, i, bid), {7, 1, n_out}, 0); + ru.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B5, i, bid), {1, 1, n_out}, 0); + ru.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B5, i, bid), {n_out}, 0); + ru.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B5, i, bid), {1, n_out, 1}, 0); + ru.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B5, i, bid), {1, n_out, n_out}, 0); + ru.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B5, i, bid), {1, 1, n_out}, 0); + ru.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B5, i, bid), {n_out}, 0); } break; default: fprintf(stderr, "%s: ERROR: Unexpected block id %d in layer %d\n", __func__, bid, i); - return false; // Or handle error appropriately + return false; } fprintf(stderr, "%s: Layer %d, Block %d: Finished\n", __func__, i, bid); - } // End block loop + } } else if (i == 6) { // --- Layer 6: Alpha --- layer.alpha = 
create_tensor(tn(LLM_TENSOR_ALPHA, i, -1), {1, n_in, 1}, 0); @@ -3834,9 +3833,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.conv_s = create_tensor(tn(LLM_TENSOR_CONV_S7, i, -1), {1, 1, n_out}, 0); layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV_B7, i, -1), {n_out}, 0); } - else { // Should not happen + else { fprintf(stderr, "%s: ERROR: Unexpected layer index %d\n", __func__, i); - return false; // Or handle error appropriately + return false; } fprintf(stderr, "%s: Layer %d: Finished\n", __func__, i); } @@ -11744,286 +11743,230 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context { } }; -// struct llm_build_snac_dec : public llm_graph_context { - -// llm_build_snac_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { -// LLAMA_LOG_INFO("Raw ubatch.n_tokens = %d\n", ubatch.n_tokens); -// for (int i = 0; i < std::min(20, (int)ubatch.n_tokens); ++i) { -// LLAMA_LOG_INFO("%d ", ubatch.token[i]); -// } -// LLAMA_LOG("\n"); -// LLAMA_LOG_DEBUG("%s: Entering constructor, model.layers.size() = %zu\n", __func__, model.layers.size()); -// ggml_tensor * cur; -// ggml_tensor * inpL; - -// // TODO: probalby just get raw codes -// //cur = build_inp_embd(model.tok_embd); -// //LLAMA_LOG_INFO("After build_inp_embd: shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// // hack, hardcode expected SNAC input at first conv layer -// cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, 64, 1, 1); // [channels, seq_len, 1, 1] -// ggml_set_input(cur); -// LLAMA_LOG_INFO("hardcoded shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// // end hack - -// // Log input tokens before processing -// LLAMA_LOG_INFO("%s: ubatch.n_tokens = %u\n", __func__, ubatch.n_tokens); -// LLAMA_LOG_WARN("%s: Input tokens from ubatch = ", __func__); -// for (uint32_t i = 0; i < ubatch.n_tokens && i < 20; ++i) { -// LLAMA_LOG_INFO("%d ", ubatch.token[i]); -// } -// if (ubatch.n_tokens > 20) LLAMA_LOG_INFO("..."); -// LLAMA_LOG("\n"); - -// // ggml_tensor * layer_1; -// // ggml_tensor * layer_2; -// // ggml_tensor * layer_3; -// //redistribute_codes(cur, &layer_1, &layer_2, &layer_3); - -// // Log the redistributed layers -// //log_tensor("Layer 1", layer_1); -// //log_tensor("Layer 2", layer_2); -// //log_tensor("Layer 3", layer_3); - -// for (uint32_t il = 1; il < model.layers.size(); ++il) { -// const auto & layer = model.layers[il]; - -// LLAMA_LOG_DEBUG("%s: Layer %u: Starting, cur = %p\n", __func__, il, cur); - -// if (il == 1) { // pointwise -// LLAMA_LOG_INFO("%s: Layer %u: Pointwise conv, conv_w = %p, conv_s = %p, conv_b = %p\n", -// __func__, il, layer.conv_w, layer.conv_s, layer.conv_b); -// LLAMA_LOG_INFO("Before transpose, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); -// cur = ggml_transpose(ctx0, cur); // [768, 512] -> [512, 768] -// LLAMA_LOG_INFO("After transpose, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); -// cur = apply_conv1d(cur, layer.conv_w, layer.conv_s, layer.conv_b, 1, 0); -// LLAMA_LOG_INFO("%s: Layer %u: After pointwise conv, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? 
cur->ne[3] : -1); -// } else if (il == model.layers.size() - 1) { -// LLAMA_LOG_INFO("%s: Layer %u: Final layer, alpha = %p, conv_w = %p, conv_s = %p, conv_b = %p\n", -// __func__, il, layer.alpha, layer.conv_w, layer.conv_s, layer.conv_b); -// cur = ggml_snake(ctx0, cur, layer.alpha); -// LLAMA_LOG_INFO("%s: Layer %u: After ggml_snake, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// cur = apply_conv1d(cur, layer.conv_w, layer.conv_s, layer.conv_b, 1, 3); -// LLAMA_LOG_INFO("%s: Layer %u: After final conv, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// cur = ggml_tanh(ctx0, cur); -// LLAMA_LOG_INFO("%s: Layer %u: After ggml_tanh, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// } else { -// // Layers 2-5: Decoder Blocks (1024 -> 512 -> 256 -> 128 -> 64) -// const int stride = hparams.upsample_rates[il - 2]; // 8 for il = 2 -// const int padding = stride; - -// // Block 0: Snake activation -// const auto & block0 = layer.decoder_blocks[0]; -// LLAMA_LOG_DEBUG("%s: Layer %u: Block 0, alpha = %p\n", __func__, il, block0.alpha); -// cur = ggml_snake(ctx0, cur, block0.alpha); -// LLAMA_LOG_DEBUG("%s: Layer %u: After ggml_snake, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// // Block 1: Transposed convolution -// const auto & block1 = layer.decoder_blocks[1]; -// LLAMA_LOG_DEBUG("%s: Layer %u: Block 1, stride = %d, up_weight = %p, up_scale = %p, up_bias = %p\n", -// __func__, il, stride, block1.up_weight, block1.up_scale, block1.up_bias); - -// cur = apply_conv1d_transpose(cur, block1.up_weight, block1.up_scale, block1.up_bias, stride, padding); -// LLAMA_LOG_DEBUG("%s: Layer %u: After conv1d_transpose, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, cur, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// // Residual Units (3 per block) -// for (int j = 0; j < 3; ++j) { -// const auto & ru = block1.res_units[j]; -// ggml_tensor * inpL = cur; -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: Starting, inpL = %p, alpha1 = %p, conv1_w = %p, conv1_s = %p, conv1_b = %p\n", -// __func__, il, j, inpL, ru.alpha1, ru.conv1_w, ru.conv1_s, ru.conv1_b); - -// cur = ggml_snake(ctx0, cur, ru.alpha1); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_snake (alpha1), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// int dilation = (j == 0) ? 1 : (j == 1) ? 3 : 9; -// int padding = 3 * dilation; // Kernel 7, dilated padding = (7-1)/2 * dilation -// cur = apply_conv1d(cur, ru.conv1_w, ru.conv1_s, ru.conv1_b, 1, padding); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After conv1d (conv1), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? 
cur->ne[3] : -1); - -// // pw -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: Pointwise, alpha2 = %p, conv2_w = %p, conv2_s = %p, conv2_b = %p\n", -// __func__, il, j, ru.alpha2, ru.conv2_w, ru.conv2_s, ru.conv2_b); -// cur = ggml_snake(ctx0, cur, ru.alpha2); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_snake (alpha2), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// cur = apply_conv1d(cur, ru.conv2_w, ru.conv2_s, ru.conv2_b, 1, 0); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After conv1d (conv2), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); - -// // residual -// cur = ggml_add(ctx0, cur, inpL); -// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_add, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// } -// } -// LLAMA_LOG_DEBUG("%s: Layer %u: Finished, cur = %p\n", __func__, il, cur); -// } - -// int64_t target_samples = 24000; // TODO: magic number -// LLAMA_LOG_DEBUG("%s: Trimming output, cur = %p, target_samples = %ld, cur->ne[0] = %ld\n", -// __func__, cur, target_samples, cur ? cur->ne[0] : -1); -// if (cur->ne[0] > target_samples) { -// cur = ggml_get_rows(ctx0, cur, ggml_new_i32(ctx0, target_samples)); -// LLAMA_LOG_DEBUG("%s: After ggml_get_rows, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", -// __func__, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); -// } - -// LLAMA_LOG_DEBUG("%s: Setting result_embd, cur = %p\n", __func__, cur); -// cb(cur, "result_embd", -1); -// res->t_embd = cur; - -// LLAMA_LOG_DEBUG("%s: Building forward graph, cur = %p\n", __func__, cur); -// ggml_build_forward_expand(gf, cur); -// LLAMA_LOG_DEBUG("%s: Graph build completed\n", __func__); -// } - -// // TODO: move these somewhere else -// private: -// // Helper to log tensor contents -// void log_tensor(const char * name, ggml_tensor * tensor) { -// if (!tensor) { -// LLAMA_LOG_INFO("%s: %s is null\n", __func__, name); -// return; -// } -// LLAMA_LOG_DEBUG("%s: %s shape = [%ld, %ld, %ld, %ld], first 20 elements = ", -// __func__, name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); -// int n_elements = ggml_nelements(tensor); -// float * data = (float *)tensor->data; -// for (int i = 0; i < std::min(20, n_elements); ++i) { -// LLAMA_LOG_DEBUG("%.2f ", data[i]); -// } -// if (n_elements > 20) LLAMA_LOG_DEBUG("..."); -// LLAMA_LOG_DEBUG("\n"); -// } - -// void redistribute_codes(ggml_tensor * input, ggml_tensor ** layer_1, ggml_tensor ** layer_2, ggml_tensor ** layer_3) { -// int64_t n_codes = input->ne[1]; // Assuming input is [n_embd, n_tokens, 1, 1] -// int64_t n_frames = n_codes / 7; -// if (n_codes % 7 != 0) { -// LLAMA_LOG_ERROR("%s: Input codes length %ld is not a multiple of 7\n", __func__, n_codes); -// *layer_1 = *layer_2 = *layer_3 = nullptr; -// return; -// } - -// int64_t n_layer_1 = n_frames; // 1 code per frame -// int64_t n_layer_2 = n_frames * 2; // 2 codes per frame -// int64_t n_layer_3 = n_frames * 4; // 4 codes per frame - -// // Indices for each layer -// std::vector idx_layer_1(n_layer_1); -// std::vector idx_layer_2(n_layer_2); -// std::vector idx_layer_3(n_layer_3); - -// for (int64_t i = 0; i < n_frames; ++i) { -// int64_t base_idx = i * 7; -// 
idx_layer_1[i] = base_idx + 0; // No offset -// idx_layer_2[i * 2] = base_idx + 1; // Offset -4096 -// idx_layer_2[i * 2 + 1] = base_idx + 4; // Offset -16384 -// idx_layer_3[i * 4] = base_idx + 2; // Offset -8192 -// idx_layer_3[i * 4 + 1] = base_idx + 3; // Offset -12288 -// idx_layer_3[i * 4 + 2] = base_idx + 5; // Offset -20480 -// idx_layer_3[i * 4 + 3] = base_idx + 6; // Offset -24576 -// } - -// // Create index tensors -// ggml_tensor * idx_1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_1); -// ggml_tensor * idx_2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_2); -// ggml_tensor * idx_3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_3); - -// memcpy(idx_1->data, idx_layer_1.data(), n_layer_1 * sizeof(int32_t)); -// memcpy(idx_2->data, idx_layer_2.data(), n_layer_2 * sizeof(int32_t)); -// memcpy(idx_3->data, idx_layer_3.data(), n_layer_3 * sizeof(int32_t)); - -// // Extract layers using ggml_get_rows -// *layer_1 = ggml_get_rows(ctx0, input, idx_1); -// *layer_2 = ggml_get_rows(ctx0, input, idx_2); -// *layer_3 = ggml_get_rows(ctx0, input, idx_3); - -// // Apply offsets -// *layer_2 = ggml_add(ctx0, *layer_2, ggml_new_f32(ctx0, -4096.0f)); // Simplified; we'll refine offsets later -// *layer_3 = ggml_add(ctx0, *layer_3, ggml_new_f32(ctx0, -8192.0f)); // Simplified for now -// } - -// ggml_tensor * apply_conv1d(ggml_tensor * input, ggml_tensor * conv_w, ggml_tensor * conv_scale, ggml_tensor * conv_b, -// int stride, int padding) { -// ggml_tensor * w_final = normalize_weight(conv_w, conv_scale); -// ggml_tensor * cur = ggml_conv_1d_ph(ctx0, w_final, input, stride, padding); -// if (conv_b) { -// ggml_tensor* bias_reshaped = ggml_reshape_3d(ctx0, conv_b, 1, 1024, 1); -// cur = ggml_add(ctx0, cur, bias_reshaped); -// } -// return cur; -// } - -// ggml_tensor * apply_conv1d_transpose(ggml_tensor * input, ggml_tensor * up_weight, ggml_tensor * up_scale, ggml_tensor * up_bias, int stride, int padding) { -// // Normalize weights (temporary fix for up_scale shape mismatch) -// if (up_scale->ne[2] != up_weight->ne[1]) { // 1024 != 512 -// LLAMA_LOG_WARN("up_scale channels (%ld) don’t match output channels (%ld), expected behavior may vary\n", up_scale->ne[2], up_weight->ne[1]); -// // Ideally reshape up_scale to [1, 1, 512, 1], but no reshape; proceed with warning -// } -// ggml_tensor * w_final = normalize_weight(up_weight, up_scale); -// LLAMA_LOG_INFO("After normalize weight: w_final shape = [%ld, %ld, %ld, %ld]\n", -// w_final->ne[0], w_final->ne[1], w_final->ne[2], w_final->ne[3]); - -// ggml_tensor * cur = ggml_conv_transpose_1d(ctx0, w_final, input, stride, 0, 1); -// LLAMA_LOG_INFO("After ggml_conv_transpose_1d = [%ld, %ld, %ld, %ld]\n", -// cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); - -// if (up_bias) { -// // up_bias is [512, 1, 1, 1]; need [4104, 512, 1, 1] for ggml_add -// LLAMA_LOG_INFO("entering up_bias block. 
Before ggml_repeat, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); -// LLAMA_LOG_INFO("Before ggml_repeat, up_bias shape = [%ld, %ld, %ld, %ld]\n", up_bias->ne[0], up_bias->ne[1], up_bias->ne[2], up_bias->ne[3]); -// ggml_tensor * bias_repeated = ggml_repeat(ctx0, up_bias, cur); -// LLAMA_LOG_DEBUG("Repeated up_bias to shape = [%ld, %ld, %ld, %ld]\n", -// bias_repeated->ne[0], bias_repeated->ne[1], bias_repeated->ne[2], bias_repeated->ne[3]); -// cur = ggml_add(ctx0, cur, bias_repeated); -// LLAMA_LOG_DEBUG("After bias add: cur shape = [%ld, %ld, %ld, %ld]\n", -// cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); -// } -// return cur; -// } - -// // w_final = scale * (w / || w ||) -// ggml_tensor * normalize_weight(ggml_tensor * w, ggml_tensor * scale) { -// ggml_tensor * norm = ggml_norm(ctx0, w, 1e-5f); // 1e-8f ? -// ggml_tensor * w_normalized = ggml_div(ctx0, w, norm); -// ggml_tensor * w_final = ggml_mul(ctx0, w_normalized, scale); -// return w_final; -// } -// }; - // TODO: Placeholder struct llm_build_snac_dec : public llm_graph_context { llm_build_snac_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * emb_layer_1, * emb_layer_2, * emb_layer_3; + build_codebook_embd(model, &emb_layer_1, &emb_layer_2, &emb_layer_3); + + if (emb_layer_1 == nullptr || emb_layer_2 == nullptr || emb_layer_3 == nullptr) { + // graph build is called with garbage ubatch codes during model init + // in this case, bypass normal graph construction and return a dummy + LLAMA_LOG_INFO("build_codebook_inputs returned null, using dummy tensor\n"); + cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, ubatch.n_tokens > 0 ? ubatch.n_tokens : 64, 1, 1); + ggml_set_input(cur); + } else { + // Projections + cur = ggml_mul_mat(ctx0, ggml_reshape_2d(ctx0, model.codebook_proj_w[0], 8, 768), emb_layer_1); + cur = ggml_reshape_4d(ctx0, cur, 768, emb_layer_1->ne[1], 1, 1); + ggml_tensor * scale_1 = ggml_reshape_4d(ctx0, model.codebook_proj_s[0], 768, 1, 1, 1); + cur = ggml_mul(ctx0, cur, scale_1); + ggml_tensor * bias_1 = ggml_reshape_4d(ctx0, model.codebook_proj_b[0], 768, 1, 1, 1); // Fix here + cur = ggml_add(ctx0, cur, bias_1); + + ggml_tensor * proj_2 = ggml_mul_mat(ctx0, ggml_reshape_2d(ctx0, model.codebook_proj_w[1], 8, 768), emb_layer_2); + proj_2 = ggml_reshape_4d(ctx0, proj_2, 768, emb_layer_2->ne[1], 1, 1); + ggml_tensor * scale_2 = ggml_reshape_4d(ctx0, model.codebook_proj_s[1], 768, 1, 1, 1); + proj_2 = ggml_mul(ctx0, proj_2, scale_2); + ggml_tensor * bias_2 = ggml_reshape_4d(ctx0, model.codebook_proj_b[1], 768, 1, 1, 1); + proj_2 = ggml_add(ctx0, proj_2, bias_2); + + ggml_tensor * proj_3 = ggml_mul_mat(ctx0, ggml_reshape_2d(ctx0, model.codebook_proj_w[2], 8, 768), emb_layer_3); + proj_3 = ggml_reshape_4d(ctx0, proj_3, 768, emb_layer_3->ne[1], 1, 1); + ggml_tensor * scale_3 = ggml_reshape_4d(ctx0, model.codebook_proj_s[2], 768, 1, 1, 1); + proj_3 = ggml_mul(ctx0, proj_3, scale_3); + ggml_tensor * bias_3 = ggml_reshape_4d(ctx0, model.codebook_proj_b[2], 768, 1, 1, 1); + proj_3 = ggml_add(ctx0, proj_3, bias_3); + + cur = ggml_concat(ctx0, cur, proj_2, 1); + cur = ggml_concat(ctx0, cur, proj_3, 1); + + for (int j = 1; j <= hparams.n_layer; ++j) { + const auto & layer = model.layers[j]; + const int64_t n_in = hparams.n_channels[j-1]; + const int64_t n_out = (j < 7) ? 
hparams.n_channels[j] : hparams.n_channels[j-1]; + + if (j == 1) { + int64_t seq_len = cur->ne[1]; + cur = ggml_reshape_2d(ctx0, cur, 768, seq_len); // cur starts F32 (type 0) from projections + ggml_tensor * w = ggml_reshape_2d(ctx0, layer.conv_w, 768, 1024); // F16 (type 1) + ggml_tensor * s = ggml_cpy(ctx0, layer.conv_s, ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, 1, n_out)); // Cast F32 -> F16 + w = ggml_mul(ctx0, w, s); + cur = ggml_mul_mat(ctx0, w, cur); + cur = ggml_reshape_4d(ctx0, cur, seq_len, 1024, 1, 1); + ggml_tensor * b = ggml_reshape_4d(ctx0, layer.conv_b, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b); + } + // Residual Units + else if (j >= 2 && j <= 5) { + ggml_tensor * alpha = layer.decoder_blocks[0].alpha; + cur = ggml_snake(ctx0, cur, alpha); + + ggml_tensor * w = layer.decoder_blocks[1].up_weight; + ggml_tensor * s = ggml_cpy(ctx0, layer.decoder_blocks[1].up_scale, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_in, 1)); + w = ggml_mul(ctx0, w, s); + cur = ggml_conv_transpose_1d(ctx0, w, cur, hparams.upsample_rates[j-2], 0, 1); + ggml_tensor * b = ggml_reshape_4d(ctx0, layer.decoder_blocks[1].up_bias, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b); + + ggml_tensor * noise_w = layer.decoder_blocks[2].noise_w; + ggml_tensor * noise_s = ggml_cpy(ctx0, layer.decoder_blocks[2].noise_s, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_out, 1)); + noise_w = ggml_mul(ctx0, noise_w, noise_s); + cur = ggml_conv_1d(ctx0, noise_w, cur, 1, 0, 1); + + for (int r = 0; r < 3; ++r) { + int bid = 3 + r; + ggml_tensor * w1 = layer.decoder_blocks[bid].res_unit.conv1_w; + ggml_tensor * s1 = ggml_cpy(ctx0, layer.decoder_blocks[bid].res_unit.conv1_s, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_out, 1)); + w1 = ggml_mul(ctx0, w1, s1); + cur = ggml_conv_1d_dw(ctx0, w1, cur, 1, 3, 1); + ggml_tensor * b1 = ggml_reshape_4d(ctx0, layer.decoder_blocks[bid].res_unit.conv1_b, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b1); + + ggml_tensor * w2 = layer.decoder_blocks[bid].res_unit.conv2_w; + ggml_tensor * s2 = ggml_cpy(ctx0, layer.decoder_blocks[bid].res_unit.conv2_s, + ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, n_out, 1)); + w2 = ggml_mul(ctx0, w2, s2); + cur = ggml_conv_1d(ctx0, w2, cur, 1, 0, 1); + ggml_tensor * b2 = ggml_reshape_4d(ctx0, layer.decoder_blocks[bid].res_unit.conv2_b, 1, n_out, 1, 1); + cur = ggml_add(ctx0, cur, b2); + } + } + else if (j == 6) { + ggml_tensor * alpha = layer.alpha; + cur = ggml_snake(ctx0, cur, alpha); + } + else if (j == 7) { + ggml_tensor * w = layer.conv_w; + ggml_tensor * s = layer.conv_s; + + s = ggml_reshape_4d(ctx0, s, 1, 1, 1, 1); + s = ggml_cpy(ctx0, s, ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, 1, 1, 1, 1)); + w = ggml_mul(ctx0, w, s); + cur = ggml_conv_1d(ctx0, w, cur, 1, 3, 1); + + ggml_tensor * b = ggml_reshape_4d(ctx0, layer.conv_b, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, b); + } + } - // TODO: Remove - LLAMA_LOG_INFO("Raw ubatch.n_tokens = %d\n", ubatch.n_tokens); - for (int i = 0; i < std::min(20, (int)ubatch.n_tokens); ++i) { - LLAMA_LOG_INFO("%d ", ubatch.token[i]); } - LLAMA_LOG("\n"); - ggml_tensor * cur; - // TODO: Hack. 
Implement codebook lookups and out_proj - cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, 64, 1, 1); - ggml_set_input(cur); - // end hack + cur = ggml_cpy(ctx0, cur, ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3])); - LLAMA_LOG_DEBUG("%s: Setting result_embd, cur = %p\n", __func__, cur); cb(cur, "result_embd", -1); res->t_embd = cur; ggml_build_forward_expand(gf, cur); } +private: + // TODO: SNAC expects a multilayered input from 3 different embedding matrices + void build_codebook_embd(const llama_model & model, + ggml_tensor ** emb_layer_1, + ggml_tensor ** emb_layer_2, + ggml_tensor ** emb_layer_3) { + + *emb_layer_1 = nullptr; + *emb_layer_2 = nullptr; + *emb_layer_3 = nullptr; + + + + bool is_initialized = (ubatch.token != nullptr && ubatch.n_tokens > 0); + if (is_initialized) { + for (int i = 0; i < ubatch.n_tokens; ++i) { + if (ubatch.token[i] < 0 || ubatch.token[i] >= 4096) { + is_initialized = false; + break; + } + } + } + + if (!is_initialized) { + return; + } + + int32_t n_tokens = ubatch.n_tokens; + int32_t n_frames = n_tokens / 7; + if (n_tokens % 7 != 0) { + LLAMA_LOG_INFO("build_codebook_embd: n_tokens (%d) not a multiple of 7, truncating\n", n_tokens); + n_frames = n_tokens / 7; + } + + // TODO: read from vq_strides + int32_t n_layer_1 = n_frames; + int32_t n_layer_2 = n_frames * 2; + int32_t n_layer_3 = n_frames * 4; + + LLAMA_LOG_INFO("build_codebook_embd: n_frames = %d, n_layer_1 = %d, n_layer_2 = %d, n_layer_3 = %d\n", + n_frames, n_layer_1, n_layer_2, n_layer_3); + + std::vector idx_1_data(n_layer_1); + std::vector idx_2_data(n_layer_2); + std::vector idx_3_data(n_layer_3); + + // map codes to respective codebook + for (int32_t i = 0; i < n_frames; ++i) { + int32_t base_idx = i * 7; + idx_1_data[i] = ubatch.token[base_idx + 0]; + idx_2_data[i * 2] = ubatch.token[base_idx + 1]; + idx_2_data[i * 2 + 1] = ubatch.token[base_idx + 4]; + idx_3_data[i * 4] = ubatch.token[base_idx + 2]; + idx_3_data[i * 4 + 1] = ubatch.token[base_idx + 3]; + idx_3_data[i * 4 + 2] = ubatch.token[base_idx + 5]; + idx_3_data[i * 4 + 3] = ubatch.token[base_idx + 6]; + } + + // Tensors used for codebook lookups + ggml_tensor * idx_layer_1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_1); + ggml_tensor * idx_layer_2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_2); + ggml_tensor * idx_layer_3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_3); + + if (!idx_layer_1 || !idx_layer_2 || !idx_layer_3) { + LLAMA_LOG_INFO("build_codebook_embd: Failed to allocate index tensors\n"); + return; + } + + // ggml is lazy, so explicitly create buffers for codes to be placed in idx_layer_N + ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type(); + if (!cpu_buft) { + LLAMA_LOG_ERROR("build_codebook_embd: Failed to get CPU buffer type\n"); + return; + } + + ggml_backend_buffer_t buffer_1 = ggml_backend_buft_alloc_buffer(cpu_buft, n_layer_1 * sizeof(int32_t)); + ggml_backend_buffer_t buffer_2 = ggml_backend_buft_alloc_buffer(cpu_buft, n_layer_2 * sizeof(int32_t)); + ggml_backend_buffer_t buffer_3 = ggml_backend_buft_alloc_buffer(cpu_buft, n_layer_3 * sizeof(int32_t)); + + if (!buffer_1 || !buffer_2 || !buffer_3) { + LLAMA_LOG_ERROR("build_codebook_embd: Failed to allocate backend buffers\n"); + if (buffer_1) ggml_backend_buffer_free(buffer_1); + if (buffer_2) ggml_backend_buffer_free(buffer_2); + if (buffer_3) ggml_backend_buffer_free(buffer_3); + return; + } + + // move codes to idx_layer_N + idx_layer_1->buffer = buffer_1; + idx_layer_2->buffer = 
buffer_2; + idx_layer_3->buffer = buffer_3; + + idx_layer_1->data = ggml_backend_buffer_get_base(buffer_1); + idx_layer_2->data = ggml_backend_buffer_get_base(buffer_2); + idx_layer_3->data = ggml_backend_buffer_get_base(buffer_3); + + ggml_backend_tensor_set(idx_layer_1, idx_1_data.data(), 0, n_layer_1 * sizeof(int32_t)); + ggml_backend_tensor_set(idx_layer_2, idx_2_data.data(), 0, n_layer_2 * sizeof(int32_t)); + ggml_backend_tensor_set(idx_layer_3, idx_3_data.data(), 0, n_layer_3 * sizeof(int32_t)); + + *emb_layer_1 = ggml_get_rows(ctx0, model.codebook[0], idx_layer_1); + *emb_layer_2 = ggml_get_rows(ctx0, model.codebook[1], idx_layer_2); + *emb_layer_3 = ggml_get_rows(ctx0, model.codebook[2], idx_layer_3); + } }; llama_memory_i * llama_model::create_memory() const { diff --git a/src/llama-model.h b/src/llama-model.h index 5e636b0b3b3f3..e75bcf1ed8887 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -156,7 +156,7 @@ struct llama_layer_snac_dec_block { struct ggml_tensor * conv2_w = nullptr; struct ggml_tensor * conv2_s = nullptr; struct ggml_tensor * conv2_b = nullptr; - } res_units[3]; + } res_unit; }; struct llama_layer { @@ -328,7 +328,7 @@ struct llama_layer { struct llama_layer_convnext convnext; struct ggml_tensor * conv_w = nullptr; - struct ggml_tensor * conv_s = nullptr; + struct ggml_tensor * conv_s = nullptr; struct ggml_tensor * conv_b = nullptr; struct ggml_tensor * alpha = nullptr; From b7d0456c0fc21753186858ec4df7832a46a8ab4f Mon Sep 17 00:00:00 2001 From: Ja Morphy Date: Tue, 8 Apr 2025 15:36:32 -0700 Subject: [PATCH 7/7] lazy inputs for snac codes tensors instead of cpu buffers --- src/llama-context.cpp | 14 ++-- src/llama-graph.cpp | 77 ++++++++++++++++++++- src/llama-graph.h | 14 ++++ src/llama-model.cpp | 153 ++++++++++++++++-------------------------- src/llama-model.h | 1 + 5 files changed, 155 insertions(+), 104 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index d15061655da39..2958f44e0e76d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1348,13 +1348,13 @@ int llama_context::decode(llama_batch & inp_batch) { const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1); if (compute_status != GGML_STATUS_SUCCESS) { switch (compute_status) { - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; } } diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index f0a8b1071dc3b..449a94b52d642 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -556,6 +556,81 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } } +void llm_graph_input_snac::set_input(const llama_ubatch * ubatch) { + + LLAMA_LOG_INFO("Setting SNAC input for layer %d\n", ilayer); + + const int n_tokens = ubatch->n_tokens; + if (n_tokens % frame_size != 0) { + return; // TODO: handle gracefully + } + const int n_frames = n_tokens / frame_size; + + int64_t expected_elements = 0; + int vocab_offset = 0; + int tokens_per_frame = 0; + + switch (ilayer) { + case 0: // Layer 1 + tokens_per_frame = 1; + vocab_offset = 128266; // TODO: hparams + break; + case 1: // Layer 2 + tokens_per_frame = 2; + vocab_offset = 132362; + break; + case 2: // Layer 3 + tokens_per_frame = 4; + vocab_offset = 136458; + break; + default: + LLAMA_LOG_ERROR("%s: Invalid SNAC layer index %d encountered.\n", __func__, ilayer); + 
GGML_ASSERT(false && "Invalid SNAC layer index"); // Should be caught by constructor assert + return; + } + expected_elements = (int64_t)n_frames * tokens_per_frame; + + std::vector indices; + indices.reserve(expected_elements); + + const llama_token * tokens_data = ubatch->token; + + for (int i_frame = 0; i_frame < n_frames; ++i_frame) { + const int frame_start_idx = i_frame * frame_size; + const llama_token * frame_tokens = tokens_data + frame_start_idx; + + switch (ilayer) { + case 0: { // L1: token 0 + int32_t index = (int32_t)(frame_tokens[0] - vocab_offset); + + indices.push_back(index); + break; + } + case 1: { // L2: tokens 1, 4 + int32_t index1 = (int32_t)(frame_tokens[1] - vocab_offset); + int32_t index4 = (int32_t)(frame_tokens[4] - vocab_offset); + + indices.push_back(index1); + indices.push_back(index4); + break; + } + case 2: { // L3: tokens 2, 3, 5, 6 + int32_t index2 = (int32_t)(frame_tokens[2] - vocab_offset); + int32_t index3 = (int32_t)(frame_tokens[3] - vocab_offset); + int32_t index5 = (int32_t)(frame_tokens[5] - vocab_offset); + int32_t index6 = (int32_t)(frame_tokens[6] - vocab_offset); + + indices.push_back(index2); + indices.push_back(index3); + indices.push_back(index5); + indices.push_back(index6); + break; + } + } + } + ggml_backend_tensor_set(target, indices.data(), 0, ggml_nbytes(target)); +} + // // llm_graph_context // @@ -985,8 +1060,6 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { } } else { inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); - LLAMA_LOG_DEBUG("build_inp_embd: inp->embd shape = [%ld, %ld, %ld, %ld]\n", - inp->embd->ne[0], inp->embd->ne[1], inp->embd->ne[2], inp->embd->ne[3]); ggml_set_input(inp->embd); cur = inp->embd; diff --git a/src/llama-graph.h b/src/llama-graph.h index bdf19ed015e35..cb385c0400474 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -268,6 +268,20 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { const llama_cross * cross = nullptr; }; +class llm_graph_input_snac : public llm_graph_input_i { +public: + llm_graph_input_snac(ggml_tensor * target, int ilayer, + const llama_hparams & hparams) : target(target), ilayer(ilayer), hparams(hparams) {} + virtual ~llm_graph_input_snac() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * target; // idx tensor 1, 2, or 3 + const llama_hparams & hparams; + const int ilayer; + const int frame_size = 7; +}; + // // llm_graph_result // diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4051c42852039..a82cfa8aaa337 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1489,8 +1489,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { ggml_backend_buffer_type_t first_moved_to_buft = nullptr; auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) -> ggml_tensor * { - std::string tn_str = tn.str(); - ggml_tensor * t_meta = ml.get_tensor_meta(tn_str.c_str()); + ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str()); if (!t_meta) { if (flags & TENSOR_NOT_REQUIRED) { @@ -11743,21 +11742,22 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context { } }; -// TODO: Placeholder struct llm_build_snac_dec : public llm_graph_context { llm_build_snac_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * emb_layer_1, * emb_layer_2, * emb_layer_3; - build_codebook_embd(model, &emb_layer_1, &emb_layer_2, &emb_layer_3); 
- if (emb_layer_1 == nullptr || emb_layer_2 == nullptr || emb_layer_3 == nullptr) { + bool inputs = build_snac_inputs(model, &emb_layer_1, &emb_layer_2, &emb_layer_3); + + if (!inputs) { // graph build is called with garbage ubatch codes during model init // in this case, bypass normal graph construction and return a dummy LLAMA_LOG_INFO("build_codebook_inputs returned null, using dummy tensor\n"); cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, ubatch.n_tokens > 0 ? ubatch.n_tokens : 64, 1, 1); ggml_set_input(cur); } else { + // TODO: Upsampling is wrong // Projections cur = ggml_mul_mat(ctx0, ggml_reshape_2d(ctx0, model.codebook_proj_w[0], 8, 768), emb_layer_1); cur = ggml_reshape_4d(ctx0, cur, 768, emb_layer_1->ne[1], 1, 1); @@ -11859,113 +11859,76 @@ struct llm_build_snac_dec : public llm_graph_context { cur = ggml_cpy(ctx0, cur, ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3])); - cb(cur, "result_embd", -1); + LLAMA_LOG_INFO("Final shape of cur = [%ld, %ld, %ld, %ld]\n", + cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + + //cb(cur, "result_embd", -1); res->t_embd = cur; ggml_build_forward_expand(gf, cur); } private: - // TODO: SNAC expects a multilayered input from 3 different embedding matrices - void build_codebook_embd(const llama_model & model, - ggml_tensor ** emb_layer_1, - ggml_tensor ** emb_layer_2, - ggml_tensor ** emb_layer_3) { - - *emb_layer_1 = nullptr; - *emb_layer_2 = nullptr; - *emb_layer_3 = nullptr; - - - - bool is_initialized = (ubatch.token != nullptr && ubatch.n_tokens > 0); - if (is_initialized) { - for (int i = 0; i < ubatch.n_tokens; ++i) { - if (ubatch.token[i] < 0 || ubatch.token[i] >= 4096) { - is_initialized = false; - break; - } - } - } - - if (!is_initialized) { - return; + // Create 3 input nodes used for lookups into 3 embd matrices + bool build_snac_inputs(const llama_model & model, + ggml_tensor ** emb_layer_1_out, + ggml_tensor ** emb_layer_2_out, + ggml_tensor ** emb_layer_3_out) { + + *emb_layer_1_out = nullptr; + *emb_layer_2_out = nullptr; + *emb_layer_3_out = nullptr; + + if (this->ubatch.n_tokens <= 0 || this->ubatch.n_tokens % 7 != 0) { + LLAMA_LOG_WARN("%s: Invalid ubatch size n_tokens=%d provided for SNAC graph definition. 
Cannot define input nodes.\n", + __func__, this->ubatch.n_tokens); + return false; } - int32_t n_tokens = ubatch.n_tokens; - int32_t n_frames = n_tokens / 7; - if (n_tokens % 7 != 0) { - LLAMA_LOG_INFO("build_codebook_embd: n_tokens (%d) not a multiple of 7, truncating\n", n_tokens); - n_frames = n_tokens / 7; - } + const int32_t n_tokens = this->ubatch.n_tokens; + const int32_t n_frames = n_tokens / 7; - // TODO: read from vq_strides - int32_t n_layer_1 = n_frames; - int32_t n_layer_2 = n_frames * 2; - int32_t n_layer_3 = n_frames * 4; - - LLAMA_LOG_INFO("build_codebook_embd: n_frames = %d, n_layer_1 = %d, n_layer_2 = %d, n_layer_3 = %d\n", - n_frames, n_layer_1, n_layer_2, n_layer_3); - - std::vector idx_1_data(n_layer_1); - std::vector idx_2_data(n_layer_2); - std::vector idx_3_data(n_layer_3); - - // map codes to respective codebook - for (int32_t i = 0; i < n_frames; ++i) { - int32_t base_idx = i * 7; - idx_1_data[i] = ubatch.token[base_idx + 0]; - idx_2_data[i * 2] = ubatch.token[base_idx + 1]; - idx_2_data[i * 2 + 1] = ubatch.token[base_idx + 4]; - idx_3_data[i * 4] = ubatch.token[base_idx + 2]; - idx_3_data[i * 4 + 1] = ubatch.token[base_idx + 3]; - idx_3_data[i * 4 + 2] = ubatch.token[base_idx + 5]; - idx_3_data[i * 4 + 3] = ubatch.token[base_idx + 6]; - } + const int32_t n_indices_l1 = n_frames * 1; + const int32_t n_indices_l2 = n_frames * 2; + const int32_t n_indices_l3 = n_frames * 4; - // Tensors used for codebook lookups - ggml_tensor * idx_layer_1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_1); - ggml_tensor * idx_layer_2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_2); - ggml_tensor * idx_layer_3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_3); + ggml_tensor * idx1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_indices_l1); + ggml_tensor * idx2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_indices_l2); + ggml_tensor * idx3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_indices_l3); - if (!idx_layer_1 || !idx_layer_2 || !idx_layer_3) { - LLAMA_LOG_INFO("build_codebook_embd: Failed to allocate index tensors\n"); - return; + if (!idx1 || !idx2 || !idx3) { + LLAMA_LOG_ERROR("%s: Failed to allocate ggml index tensors.\n", __func__); + return false; } - // ggml is lazy, so explicitly create buffers for codes to be placed in idx_layer_N - ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type(); - if (!cpu_buft) { - LLAMA_LOG_ERROR("build_codebook_embd: Failed to get CPU buffer type\n"); - return; - } + ggml_set_name(idx1, "snac_indices_L1"); + ggml_set_name(idx2, "snac_indices_L2"); + ggml_set_name(idx3, "snac_indices_L3"); - ggml_backend_buffer_t buffer_1 = ggml_backend_buft_alloc_buffer(cpu_buft, n_layer_1 * sizeof(int32_t)); - ggml_backend_buffer_t buffer_2 = ggml_backend_buft_alloc_buffer(cpu_buft, n_layer_2 * sizeof(int32_t)); - ggml_backend_buffer_t buffer_3 = ggml_backend_buft_alloc_buffer(cpu_buft, n_layer_3 * sizeof(int32_t)); + // mark as inputs + ggml_set_input(idx1); + ggml_set_input(idx2); + ggml_set_input(idx3); - if (!buffer_1 || !buffer_2 || !buffer_3) { - LLAMA_LOG_ERROR("build_codebook_embd: Failed to allocate backend buffers\n"); - if (buffer_1) ggml_backend_buffer_free(buffer_1); - if (buffer_2) ggml_backend_buffer_free(buffer_2); - if (buffer_3) ggml_backend_buffer_free(buffer_3); - return; - } + // add to res for future access via llama_context + res->add_input(std::make_unique(idx1, 0, this->hparams)); + res->add_input(std::make_unique(idx2, 1, this->hparams)); + res->add_input(std::make_unique(idx3, 2, this->hparams)); - // move codes 
to idx_layer_N - idx_layer_1->buffer = buffer_1; - idx_layer_2->buffer = buffer_2; - idx_layer_3->buffer = buffer_3; + // lookup + *emb_layer_1_out = ggml_get_rows(ctx0, model.codebook[0], idx1); + *emb_layer_2_out = ggml_get_rows(ctx0, model.codebook[1], idx2); + *emb_layer_3_out = ggml_get_rows(ctx0, model.codebook[2], idx3); - idx_layer_1->data = ggml_backend_buffer_get_base(buffer_1); - idx_layer_2->data = ggml_backend_buffer_get_base(buffer_2); - idx_layer_3->data = ggml_backend_buffer_get_base(buffer_3); + if (!*emb_layer_1_out || !*emb_layer_2_out || !*emb_layer_3_out) { + LLAMA_LOG_ERROR("%s: Failed to create ggml_get_rows nodes.\n", __func__); + *emb_layer_1_out = *emb_layer_2_out = *emb_layer_3_out = nullptr; // Ensure outputs are null on failure + return false; + } - ggml_backend_tensor_set(idx_layer_1, idx_1_data.data(), 0, n_layer_1 * sizeof(int32_t)); - ggml_backend_tensor_set(idx_layer_2, idx_2_data.data(), 0, n_layer_2 * sizeof(int32_t)); - ggml_backend_tensor_set(idx_layer_3, idx_3_data.data(), 0, n_layer_3 * sizeof(int32_t)); + ggml_set_name(*emb_layer_1_out, "snac_embd_L1"); + ggml_set_name(*emb_layer_2_out, "snac_embd_L2"); + ggml_set_name(*emb_layer_3_out, "snac_embd_L3"); - *emb_layer_1 = ggml_get_rows(ctx0, model.codebook[0], idx_layer_1); - *emb_layer_2 = ggml_get_rows(ctx0, model.codebook[1], idx_layer_2); - *emb_layer_3 = ggml_get_rows(ctx0, model.codebook[2], idx_layer_3); + return true; } }; diff --git a/src/llama-model.h b/src/llama-model.h index e75bcf1ed8887..0524d1b82705a 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -138,6 +138,7 @@ struct llama_layer_convnext { }; struct llama_layer_snac_dec_block { + struct ggml_tensor * alpha = nullptr; struct ggml_tensor * up_weight = nullptr;
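
Note for reviewers: the ggml_snake call sites above rely on SNAC's snake activation, snake(x) = x + (1/alpha) * sin^2(alpha * x), applied per channel with a learned alpha (the alpha/alpha1/alpha2 tensors loaded in load_tensors). The following is a minimal CPU reference sketch, independent of ggml and not the kernel added by this series; the [channels x samples] row-major layout and the epsilon guard are assumptions made only for illustration.

// Reference sketch of the snake activation used by the SNAC decoder:
//   snake(x) = x + (1/alpha) * sin^2(alpha * x), applied per channel.
// Standalone illustration only; layout and epsilon are assumptions.
#include <cmath>
#include <cstddef>
#include <vector>

static void snake_reference(std::vector<float> & x, const std::vector<float> & alpha,
                            size_t n_channels, size_t n_samples) {
    for (size_t c = 0; c < n_channels; ++c) {
        const float a     = alpha[c];
        const float inv_a = 1.0f / (a + 1e-9f); // guard against a tiny alpha
        for (size_t t = 0; t < n_samples; ++t) {
            float & v = x[c * n_samples + t];
            const float s = std::sin(a * v);
            v = v + inv_a * s * s;
        }
    }
}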
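
The llm_graph_input_snac::set_input logic above splits each 7-token Orpheus frame across the three SNAC codebooks (1, 2, and 4 codes per frame, taken from frame positions 0; 1 and 4; 2, 3, 5 and 6) and subtracts the per-layer vocab offsets that are currently hard-coded (128266, 132362, 136458, marked TODO: hparams). A standalone sketch of that mapping follows; the helper name and the vector-based interface are illustrative only and not part of the patch.

// Sketch of the frame-to-codebook de-interleave performed in
// llm_graph_input_snac::set_input. Offsets mirror the hard-coded values
// in the patch and are expected to move into hparams later.
#include <cstdint>
#include <vector>

struct snac_codes {
    std::vector<int32_t> l1, l2, l3; // indices into codebooks 0, 1, 2
};

static bool deinterleave_frames(const std::vector<int32_t> & tokens, snac_codes & out) {
    constexpr int     frame_size = 7;
    constexpr int32_t off1 = 128266, off2 = 132362, off3 = 136458; // per-layer vocab offsets
    if (tokens.empty() || tokens.size() % frame_size != 0) {
        return false; // the graph builder rejects batches that are not whole frames
    }
    const size_t n_frames = tokens.size() / frame_size;
    out.l1.reserve(n_frames);
    out.l2.reserve(n_frames * 2);
    out.l3.reserve(n_frames * 4);
    for (size_t f = 0; f < n_frames; ++f) {
        const int32_t * t = tokens.data() + f * frame_size;
        out.l1.push_back(t[0] - off1);
        out.l2.push_back(t[1] - off2);
        out.l2.push_back(t[4] - off2);
        out.l3.push_back(t[2] - off3);
        out.l3.push_back(t[3] - off3);
        out.l3.push_back(t[5] - off3);
        out.l3.push_back(t[6] - off3);
    }
    return true;
}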
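
Every convolution in the builder multiplies the stored weight by its scale tensor, following the weight-normalization identity w_final = scale * (w / ||w||) noted next to the removed normalize_weight helper. Below is a rough sketch under the assumption that the norm is an L2 norm taken per output channel; the builder currently approximates this with ggml_norm and a plain ggml_mul, so treat this as a description of the intended math rather than of the present code path.

// Sketch of weight normalization: w_final = scale * (w / ||w||),
// assuming one L2 norm and one scale per output channel (an assumption,
// not something the patch currently guarantees).
#include <cmath>
#include <cstddef>
#include <vector>

static void weight_norm_apply(std::vector<float> & w,            // [n_out * k] flattened weights
                              const std::vector<float> & scale,  // one scale per output channel
                              size_t n_out, size_t k) {
    for (size_t o = 0; o < n_out; ++o) {
        float sumsq = 0.0f;
        for (size_t i = 0; i < k; ++i) {
            sumsq += w[o * k + i] * w[o * k + i];
        }
        const float inv_norm = 1.0f / (std::sqrt(sumsq) + 1e-12f);
        for (size_t i = 0; i < k; ++i) {
            w[o * k + i] *= scale[o] * inv_norm;
        }
    }
}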