
kv-cache : avoid modifying recurrent cells when setting inputs #13834


Draft · wants to merge 10 commits into master
8 changes: 5 additions & 3 deletions include/llama.h
@@ -259,9 +259,9 @@ extern "C" {
llama_token * token;
float * embd;
llama_pos * pos;
- int32_t * n_seq_id;
- llama_seq_id ** seq_id;
- int8_t * logits; // TODO: rename this to "output"
+ int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence
+ llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id;
+ int8_t * logits; // TODO: rename this to "output"
} llama_batch;

enum llama_model_kv_override_type {
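For orientation, a sketch of the direction the two TODOs above describe: a single sequence id per token and `logits` renamed to `output`. This is illustrative only, not part of this change:

    // hypothetical future shape implied by the TODOs (not current API)
    typedef struct llama_batch {
        int32_t        n_tokens;
        llama_token  * token;   // token ids (used when embd is NULL)
        float        * embd;    // token embeddings (used when token is NULL)
        llama_pos    * pos;     // positions of the respective tokens
        llama_seq_id * seq_id;  // one sequence id per token (n_seq_id removed)
        int8_t       * output;  // renamed from `logits`
    } llama_batch;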
@@ -692,12 +692,14 @@ extern "C" {
// This will be applied:
// - lazily on next llama_decode()
// - explicitly with llama_kv_self_update()
+ // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG]
LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);

// Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);

// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
+ // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG]
LLAMA_API void llama_kv_self_update(struct llama_context * ctx);

//
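To make the deprecation TODOs concrete, a minimal sketch of the explicit update path they would retire, using only the calls declared above (creation of `ctx` is assumed):

    // explicit path: schedule a defrag, then force all pending updates
    // (K-shifts, defragmentation, ...) to be applied immediately; per the
    // API_KV_NO_DEFRAG TODOs, this would instead happen lazily inside the
    // next llama_decode()
    llama_kv_self_defrag(ctx);
    llama_kv_self_update(ctx);

    // not every cache implementation supports shifting, so callers check first
    const bool can_shift = llama_kv_self_can_shift(ctx);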
31 changes: 19 additions & 12 deletions src/llama-batch.cpp
@@ -15,24 +15,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
break;
}
}
- ubatch_token.resize(!has_embd ? n_ubatch : 0);
- ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
- ubatch_pos.resize(n_ubatch);
- ubatch_n_seq_id.resize(n_ubatch);
- ubatch_seq_id.resize(n_ubatch);
- ubatch_output.resize(n_ubatch);
+
+ udatas.push_back({});
+
+ auto & udata = udatas.back();
+
+ udata.token.resize(!has_embd ? n_ubatch : 0);
+ udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
+ udata.pos.resize(n_ubatch);
+ udata.n_seq_id.resize(n_ubatch);
+ udata.seq_id.resize(n_ubatch);
+ udata.output.resize(n_ubatch);

llama_ubatch ubatch = {
/*equal_seqs =*/ true,
/*n_tokens =*/ 0,
/*n_seq_tokens =*/ 0,
/*n_seqs =*/ 0,
- /*token =*/ !has_embd ? ubatch_token.data() : nullptr,
- /*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
- /*pos =*/ ubatch_pos.data(),
- /*n_seq_id =*/ ubatch_n_seq_id.data(),
- /*seq_id =*/ ubatch_seq_id.data(),
- /*output =*/ ubatch_output.data(),
+ /*token =*/ !has_embd ? udata.token.data() : nullptr,
+ /*embd =*/ has_embd ? udata.embd.data() : nullptr,
+ /*pos =*/ udata.pos.data(),
+ /*n_seq_id =*/ udata.n_seq_id.data(),
+ /*seq_id =*/ udata.seq_id.data(),
+ /*output =*/ udata.output.data(),
};

return ubatch;
}
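The point of routing the buffers through `udatas`: pointers stored in the returned llama_ubatch must stay valid across subsequent reserve_ubatch() calls, which previously reused and resized one shared set of `ubatch_*` vectors. A self-contained sketch of the invariant this relies on, with illustrative names rather than llama.cpp API:

    #include <cassert>
    #include <vector>

    int main() {
        std::vector<std::vector<int>> udatas;

        udatas.push_back(std::vector<int>(8, 0));
        int * p0 = udatas.back().data(); // pointer handed out for "ubatch 0"

        // later reservations may reallocate the outer vector, but moving a
        // std::vector transfers ownership of its heap buffer, so p0 survives
        for (int i = 0; i < 100; ++i) {
            udatas.push_back(std::vector<int>(8, i));
        }

        assert(p0 == udatas.front().data());
        return 0;
    }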

25 changes: 15 additions & 10 deletions src/llama-batch.h
@@ -11,15 +11,15 @@ struct llama_ubatch {
bool equal_seqs;
// TODO: whole_seqs for embeddings?

uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
uint32_t n_seq_tokens; // tokens per sequence
uint32_t n_seqs;

llama_token * token; // [n_tokens]
float * embd; // [n_embd, n_tokens]
llama_pos * pos; // [n_tokens]
- int32_t * n_seq_id; // [n_seqs]
- llama_seq_id ** seq_id; // [n_seqs]
+ int32_t * n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
+ llama_seq_id ** seq_id; // [n_seqs] // TODO: become llama_seq_id * seq_id;
int8_t * output; // [n_tokens]
};
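As a reading aid, a hedged sketch of how a consumer indexes these arrays when `equal_seqs` is set; `ub` is assumed to be a llama_ubatch filled by one of the split helpers:

    // layout: n_tokens == n_seq_tokens * n_seqs, tokens grouped by sequence
    for (uint32_t s = 0; s < ub.n_seqs; ++s) {
        const llama_seq_id seq = ub.seq_id[s][0]; // primary id (see TODO above)
        for (uint32_t i = 0; i < ub.n_seq_tokens; ++i) {
            const uint32_t t = s*ub.n_seq_tokens + i; // flat index into [n_tokens]
            // exactly one of ub.token / ub.embd is non-null
            if (ub.output[t]) {
                // logits/embeddings were requested for token t of sequence seq
            }
        }
    }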

@@ -49,13 +49,18 @@ struct llama_sbatch {

const llama_batch * batch = nullptr;

- // buffers for the ubatch
- std::vector<llama_token> ubatch_token;
- std::vector<float> ubatch_embd;
- std::vector<llama_pos> ubatch_pos;
- std::vector<int32_t> ubatch_n_seq_id;
- std::vector<llama_seq_id *> ubatch_seq_id;
- std::vector<int8_t> ubatch_output;
+ // buffers for the ubatches
+ // TODO: very hacky, this needs a complete rework
+ struct ubatch_data {
+     std::vector<llama_token> token;
+     std::vector<float> embd;
+     std::vector<llama_pos> pos;
+     std::vector<int32_t> n_seq_id;
+     std::vector<llama_seq_id *> seq_id;
+     std::vector<int8_t> output;
+ };
+
+ std::vector<ubatch_data> udatas;

llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
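Downstream, split helpers declared further down in this header hand out llama_ubatch views backed by these buffers. A hypothetical driver loop, with the field and helper names treated as assumptions rather than confirmed API:

    // consume a batch in ubatch-sized chunks; with the per-ubatch udatas
    // storage, earlier ubatches keep valid pointers across further splits
    while (sbatch.n_tokens > 0) {
        llama_ubatch ubatch = sbatch.split_simple(n_ubatch); // assumed helper
        // ... evaluate ubatch ...
    }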
