diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 27c9ab74be112..56203d740b1e5 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -250,22 +250,6 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-
-    const int64_t n_kv = kv_state->get_n_kv();
-
-    if (s_mask) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
-        float * data = (float *) s_mask->data;
-
-        // clear unused states
-        for (int i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->s_mask(i);
-        }
-    }
-}
-
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -986,23 +970,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_mask;
-
-    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
 
@@ -1455,43 +1422,53 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_copy_mask_state(
+ggml_tensor * llm_graph_context::build_recurrent_state(
         ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
-        ggml_tensor * state_mask,
-        int32_t n_state,
-        int32_t n_seqs) const {
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     const auto n_kv    = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
+    const auto rs_zero = kv_state->get_rs_z();
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
 
-    // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // this shrinks the tensors's ne[1] to n_kv
-    states = ggml_get_rows(ctx0, states, state_copy);
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
-    // clear states of sequences which are starting at the beginning of this batch
-    // FIXME: zero-out NANs?
-    states = ggml_mul(ctx0, states, state_mask);
+    ggml_tensor * output_states;
+
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
+    }
 
-    // copy states which won't be changed further (between n_seqs and n_kv)
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
-            ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs          )*n_state*ggml_element_size(states)),
-            ggml_view_1d(ctx0, s,      n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+            states_extra,
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
 
-    // the part of the states that will be used and modified
-    return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+    return output_states;
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_cgraph * gf,
         ggml_tensor * state_copy,
-        ggml_tensor * state_mask,
         const llama_ubatch & ubatch,
         int il) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -1502,8 +1479,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
 
     ggml_tensor * token_shift_all = kv_state->get_k_l(il);
 
-    ggml_tensor * token_shift = build_copy_mask_state(
-            gf, token_shift_all, state_copy, state_mask,
+    ggml_tensor * token_shift = build_recurrent_state(
+            gf, token_shift_all, state_copy,
             hparams.n_embd_k_s(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 28da6a5228bdc..88fb77f1ddc9a 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -200,18 +200,6 @@ class llm_graph_input_s_copy : public llm_graph_input_i {
     const llama_kv_cache_recurrent_state * kv_state;
 };
 
-class llm_graph_input_s_mask : public llm_graph_input_i {
-public:
-    llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
-    virtual ~llm_graph_input_s_mask() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * s_mask; // F32 [1, n_kv]
-
-    const llama_kv_cache_recurrent_state * kv_state;
-};
-
 class llm_graph_input_cross_embd : public llm_graph_input_i {
 public:
     llm_graph_input_cross_embd(
@@ -521,7 +509,6 @@ struct llm_graph_context {
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
     ggml_tensor * build_inp_s_copy() const;
-    ggml_tensor * build_inp_s_mask() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -606,18 +593,17 @@ struct llm_graph_context {
     // recurrent
     //
 
-    ggml_tensor * build_copy_mask_state(
+    ggml_tensor * build_recurrent_state(
             ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
-            int32_t n_state,
-            int32_t n_seqs) const;
+            int32_t state_size,
+            int32_t n_seqs,
+            bool avoid_copies = false) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
             ggml_cgraph * gf,
             ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
             const llama_ubatch & ubatch,
             int il) const;
 
diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp
index f5c6dcd66ce9e..f8cdd52808d7b 100644
--- a/src/llama-kv-cache-recurrent.cpp
+++ b/src/llama-kv-cache-recurrent.cpp
@@ -406,21 +406,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
 
     bool success = true;
 
-    // TODO: here we have to verify that all ubatches can fit in the cells
-    //       however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
-    //       during the compute of each ubatch. to reproduce, uncomment the following loop and run:
-    //
-    //       $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
-    //
-    //       recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
-    //
-    GGML_UNUSED(ubatches);
-    //for (const auto & ubatch : ubatches) {
-    //    if (!find_slot(ubatch)) {
-    //        success = false;
-    //        break;
-    //    }
-    //}
+    for (const auto & ubatch : ubatches) {
+        if (!find_slot(ubatch)) {
+            success = false;
+            break;
+        }
+    }
 
     // restore the original state
     cells = std::move(org_cells);
@@ -431,14 +422,13 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
 }
 
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
-    const uint32_t n_tokens = ubatch.n_tokens;
-    const uint32_t n_seqs   = ubatch.n_seqs;
+    const uint32_t n_seqs = ubatch.n_seqs;
 
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     // if we have enough unused cells before the current head ->
    //    better to start searching from the beginning of the cache, hoping to fill it
-    if (head > used + 2*n_tokens) {
+    if (head > used + 2*n_seqs) {
         head = 0;
     }
 
@@ -534,16 +524,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
                     empty_cell.src = orig_cell.src;
                     orig_cell.seq_id.erase(seq_id);
                     empty_cell.seq_id.insert(seq_id); // will be overwritten
+                    GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
                 }
                 seq_meta.tail = next_empty_cell;
                 // find next empty cell
                 if (s + 1 < n_seqs) {
-                    next_empty_cell += 1;
                     for (uint32_t i = 0; i < size; ++i) {
+                        next_empty_cell += 1;
                         if (next_empty_cell >= size) { next_empty_cell -= size; }
                         kv_cell & cell = cells[next_empty_cell];
                         if (cell.is_empty()) { break; }
-                        next_empty_cell += 1;
                     }
                 }
             }
@@ -553,8 +543,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     // gather and re-order
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        int32_t dst_id = s + min;
-        int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
+        const int32_t dst_id = s + min;
+        const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
         if (dst_id != src_id) {
             kv_cell & dst_cell = cells[dst_id];
             kv_cell & src_cell = cells[src_id];
@@ -563,12 +553,14 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
             std::swap(dst_cell.pos, src_cell.pos);
             std::swap(dst_cell.src, src_cell.src);
             std::swap(dst_cell.seq_id, src_cell.seq_id);
 
-            // swap tails (assuming they NEVER overlap)
-            for (const llama_seq_id seq_id : src_cell.seq_id) {
-                cells[seq_id].tail = src_id;
-            }
-            for (const llama_seq_id seq_id : dst_cell.seq_id) {
-                cells[seq_id].tail = dst_id;
+            // swap tails
+            for (uint32_t i = 0; i < size; ++i) {
+                int32_t & tail = cells[i].tail;
+                if (tail == src_id) {
+                    tail = dst_id;
+                } else if (tail == dst_id) {
+                    tail = src_id;
+                }
             }
         }
     }
@@ -576,7 +568,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     // update the pos of the used seqs
     for (uint32_t s = 0; s < n_seqs; ++s) {
         const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
-        int32_t cell_id = s + min;
+        const int32_t cell_id = s + min;
         kv_cell & cell = cells[cell_id];
 
         if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
@@ -594,6 +586,38 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
         }
     }
 
+    // Find first cell without src refs, to use as the zero-ed state
+    {
+        // TODO: bake-in src refcounts in the cell metadata
+        std::vector<int32_t> refcounts(size, 0);
+        for (size_t i = 0; i < size; ++i) {
+            const int32_t src = cells[i].src;
+            if (src >= 0) {
+                refcounts[src] += 1;
+            }
+        }
+
+        rs_z = -1;
+        for (int i = min; i <= max; ++i) {
+            if (refcounts[i] == 0) {
+                rs_z = i;
+                break;
+            }
+        }
+
+        for (int i = min; i <= max; ++i) {
+            if (cells[i].src < 0) {
+                GGML_ASSERT(rs_z >= 0);
+                cells[i].src0 = rs_z;
+            } else {
+                // Stage the source ids for all used cells to allow correct seq_* behavior
+                // and still make these values available when setting the inputs
+                cells[i].src0 = cells[i].src;
+            }
+            cells[i].src = i; // avoid moving or clearing twice
+        }
+    }
+
     // allow getting the range of used cells, from head to head + n
     head = min;
     n    = max - min + 1;
@@ -605,47 +629,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 }
 
 bool llama_kv_cache_recurrent::get_can_shift() const {
-    return false;
-}
-
-int32_t llama_kv_cache_recurrent::s_copy(int i) const {
-    const uint32_t cell_id = i + head;
-
-    //////////////////////////////////////////////
-    // TODO: this should not mutate the KV cache !
-    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
-    // prevent out-of-bound sources
-    if (cell.src < 0 || (uint32_t) cell.src >= size) {
-        cell.src = cell_id;
-    }
-
-    int32_t res = cell.src;
-
-    // TODO: do not mutate the KV cache
-    // ensure copy only happens once
-    if (cell.src != (int32_t) cell_id) {
-        cell.src = cell_id;
-    }
-
-    return res;
-}
-
-float llama_kv_cache_recurrent::s_mask(int i) const {
-    const uint32_t cell_id = i + head;
-
-    //////////////////////////////////////////////
-    // TODO: this should not mutate the KV cache !
-    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
-    float res = (float) (cell.src >= 0);
-
-    // only clear once
-    if (cell.src < 0) {
-        cell.src = cell_id;
-    }
-
-    return res;
+    // shifting the pos is trivial for recurrent models
+    return true;
 }
 
 size_t llama_kv_cache_recurrent::total_size() const {
@@ -1111,6 +1096,10 @@ uint32_t llama_kv_cache_recurrent_state::get_head() const {
     return is_full ? 0 : kv->head;
 }
 
+int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
+    return is_full ? 0 : kv->rs_z;
+}
+
 uint32_t llama_kv_cache_recurrent_state::get_size() const {
     return kv->size;
 }
@@ -1124,9 +1113,5 @@ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
 }
 
 int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
-    return kv->s_copy(i);
-}
-
-float llama_kv_cache_recurrent_state::s_mask(int i) const {
-    return kv->s_mask(i);
+    return kv->cells[i + kv->head].src0;
 }
diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h
index d1da1225655fa..4b33bafd71cca 100644
--- a/src/llama-kv-cache-recurrent.h
+++ b/src/llama-kv-cache-recurrent.h
@@ -57,10 +57,6 @@ class llama_kv_cache_recurrent : public llama_memory_i {
 
     bool get_can_shift() const override;
 
-    // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
-    int32_t s_copy(int i) const;
-    float   s_mask(int i) const;
-
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -73,10 +69,14 @@ class llama_kv_cache_recurrent : public llama_memory_i {
     // computed before each graph build
     uint32_t n = 0;
 
+    // first zero-ed state
+    int32_t rs_z = -1;
+
     // TODO: optimize for recurrent state needs
     struct kv_cell {
         llama_pos pos = -1;
-        int32_t   src = -1; // used to copy states
+        int32_t   src  = -1; // used to know where states should be copied from
+        int32_t   src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
         int32_t   tail = -1;
 
         std::set<llama_seq_id> seq_id;
@@ -157,13 +157,13 @@ class llama_kv_cache_recurrent_state : public llama_memory_state_i {
 
     uint32_t get_n_kv() const;
     uint32_t get_head() const;
+     int32_t get_rs_z() const;
     uint32_t get_size() const;
 
     ggml_tensor * get_k_l(int32_t il) const;
     ggml_tensor * get_v_l(int32_t il) const;
 
     int32_t s_copy(int i) const;
-    float   s_mask(int i) const;
 
 private:
     const llama_memory_status status;
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index 3a40463fd29ca..79994e655d0d0 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -512,8 +512,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
         head_cur = 0;
     }
 
-    // otherwise, one cell per token.
-
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
         return -1;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index c41ee24507fca..0919929309b35 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -8857,7 +8857,6 @@ struct llm_build_mamba : public llm_graph_context {
         inpL = build_inp_embd(model.tok_embd);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -8866,8 +8865,7 @@ struct llm_build_mamba : public llm_graph_context {
                     LLM_NORM_RMS, il);
             cb(cur, "attn_norm", il);
 
-            //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il);
-            cur = build_mamba_layer(gf, cur, state_copy, state_mask, ubatch, il);
+            cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
 
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
@@ -8908,7 +8906,6 @@ struct llm_build_mamba : public llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
             const llama_ubatch & ubatch,
             int il) const {
         const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -8935,12 +8932,12 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_tensor * ssm_states_all  = kv_state->get_v_l(il);
 
         // (ab)using the KV cache to store the states
-        ggml_tensor * conv = build_copy_mask_state(
-                gf, conv_states_all, state_copy, state_mask,
+        ggml_tensor * conv = build_recurrent_state(
+                gf, conv_states_all, state_copy,
                 hparams.n_embd_k_s(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-        ggml_tensor * ssm = build_copy_mask_state(
-                gf, ssm_states_all, state_copy, state_mask,
+        ggml_tensor * ssm = build_recurrent_state(
+                gf, ssm_states_all, state_copy,
                 hparams.n_embd_v_s(), n_seqs);
         ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
 
@@ -11656,7 +11653,6 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             ggml_tensor * cur,
             ggml_tensor * x_prev,
            ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
             const llama_ubatch & ubatch,
             int il) const {
         const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -11780,8 +11776,8 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
         }
 
-        ggml_tensor * wkv_state = build_copy_mask_state(
-                gf, kv_state->get_v_l(il), state_copy, state_mask,
+        ggml_tensor * wkv_state = build_recurrent_state(
+                gf, kv_state->get_v_l(il), state_copy,
                 hparams.n_embd_v_s(), n_seqs);
 
         ggml_tensor * wkv_output;
@@ -11837,7 +11833,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -11848,7 +11843,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
             ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, state_mask, ubatch, il
+                    gf, state_copy, ubatch, il
                     );
 
             ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
@@ -11864,7 +11859,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
                 1
             );
 
-            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
+            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
@@ -11935,7 +11930,6 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
         inpL = build_inp_embd(model.tok_embd);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -11946,7 +11940,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
             ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, state_mask, ubatch, il
+                    gf, state_copy, ubatch, il
                     );
 
             ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
@@ -11959,7 +11953,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
                 1
             );
 
-            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
+            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
 
             token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -12051,7 +12045,6 @@ struct llm_build_rwkv7_base : public llm_graph_context {
             ggml_tensor * cur,
             ggml_tensor * x_prev,
             ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
             ggml_tensor *& first_layer_value,
             const llama_ubatch & ubatch,
             int il) const {
@@ -12134,8 +12127,8 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
         a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
 
-        ggml_tensor * wkv_state = build_copy_mask_state(
-                gf, kv_state->get_v_l(il), state_copy, state_mask,
+        ggml_tensor * wkv_state = build_recurrent_state(
+                gf, kv_state->get_v_l(il), state_copy,
                 hparams.n_embd_v_s(), n_seqs);
 
         ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -12193,7 +12186,6 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12204,7 +12196,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
             ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, state_mask, ubatch, il
+                    gf, state_copy, ubatch, il
                     );
 
             ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
@@ -12220,7 +12212,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
                 1
             );
 
-            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
+            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
@@ -12287,7 +12279,6 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
         inpL = build_inp_embd(model.tok_embd);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12298,7 +12289,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
             ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, state_mask, ubatch, il
+                    gf, state_copy, ubatch, il
                     );
 
             ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
@@ -12311,7 +12302,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
                 1
             );
 
-            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
+            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
 
             token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
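
Note (not part of the patch above): a minimal standalone sketch of the refcount scan that the new find_slot() block uses to pick the zero-ed state rs_z, for readers following the change. The [min, max] range and the src convention follow the diff; the reduced kv_cell struct, the find_zeroed_state helper name, and main() are illustrative assumptions only, not llama.cpp code.

// illustrative sketch only -- mirrors the "Find first cell without src refs" block added above
#include <cstdint>
#include <cstdio>
#include <vector>

struct kv_cell { int32_t src = -1; }; // simplified: only the copy-source field matters here

// return the first cell in [min, max] that no other cell references as a copy source;
// that cell can safely be zeroed once and shared by all states that need clearing
static int32_t find_zeroed_state(const std::vector<kv_cell> & cells, int min, int max) {
    std::vector<int32_t> refcounts(cells.size(), 0);
    for (const kv_cell & cell : cells) {
        if (cell.src >= 0) {
            refcounts[cell.src] += 1;
        }
    }
    for (int i = min; i <= max; ++i) {
        if (refcounts[i] == 0) {
            return i;
        }
    }
    return -1; // no unreferenced cell in the range
}

int main() {
    // cells 0 and 1 are used as sources, cell 2 is not -> rs_z == 2
    const std::vector<kv_cell> cells = { {0}, {0}, {1} };
    std::printf("rs_z = %d\n", find_zeroed_state(cells, 0, 2));
    return 0;
}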