Skip to content

Commit a56f0a9

Browse files
committed
context : initial need_reserve logic
ggml-ci
1 parent 71d1169 commit a56f0a9

File tree

3 files changed

+268
-244
lines changed

3 files changed

+268
-244
lines changed

src/llama-context.cpp

Lines changed: 165 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -576,9 +576,7 @@ ggml_tensor * llama_context::build_lora_mm_id(
576576
return res;
577577
}
578578

579-
bool llama_context::kv_self_update() {
580-
bool need_reserve = false;
581-
579+
void llama_context::kv_self_update() {
582580
auto & kv = kv_self;
583581

584582
if (kv.has_shift) {
@@ -655,12 +653,14 @@ bool llama_context::kv_self_update() {
655653

656654
ggml_free(ctx0);
657655

658-
need_reserve = true;
659-
660656
kv.do_defrag = false;
657+
658+
need_reserve = true;
661659
}
660+
}
662661

663-
return need_reserve;
662+
void llama_kv_self_update(llama_context * ctx) {
663+
ctx->kv_self_update();
664664
}
665665

666666
void llama_context::build_attn_inp(
@@ -1824,6 +1824,165 @@ int32_t llama_apply_adapter_cvec(
18241824
return ctx->cvec.apply(ctx->model, data, len, n_embd, il_start, il_end);
18251825
}
18261826

1827+
//
1828+
// kv cache view
1829+
//
1830+
1831+
struct llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
1832+
return llama_kv_cache_view_init(ctx->kv_self, n_seq_max);
1833+
}
1834+
1835+
void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
1836+
llama_kv_cache_view_update(view, ctx->kv_self);
1837+
}
1838+
1839+
//
1840+
// kv cache
1841+
//
1842+
1843+
// deprecated
1844+
int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
1845+
return llama_kv_self_n_tokens(ctx);
1846+
}
1847+
1848+
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
1849+
return llama_kv_cache_n_tokens(&ctx->kv_self);
1850+
}
1851+
1852+
// deprecated
1853+
int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
1854+
return llama_kv_self_used_cells(ctx);
1855+
}
1856+
1857+
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
1858+
return llama_kv_cache_used_cells(&ctx->kv_self);
1859+
}
1860+
1861+
// deprecated
1862+
void llama_kv_cache_clear(llama_context * ctx) {
1863+
llama_kv_self_clear(ctx);
1864+
}
1865+
1866+
void llama_kv_self_clear(llama_context * ctx) {
1867+
llama_kv_cache_clear(&ctx->kv_self);
1868+
}
1869+
1870+
// deprecated
1871+
bool llama_kv_cache_seq_rm(
1872+
llama_context * ctx,
1873+
llama_seq_id seq_id,
1874+
llama_pos p0,
1875+
llama_pos p1) {
1876+
return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
1877+
}
1878+
1879+
bool llama_kv_self_seq_rm(
1880+
llama_context * ctx,
1881+
llama_seq_id seq_id,
1882+
llama_pos p0,
1883+
llama_pos p1) {
1884+
return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1);
1885+
}
1886+
1887+
// deprecated
1888+
void llama_kv_cache_seq_cp(
1889+
llama_context * ctx,
1890+
llama_seq_id seq_id_src,
1891+
llama_seq_id seq_id_dst,
1892+
llama_pos p0,
1893+
llama_pos p1) {
1894+
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
1895+
}
1896+
1897+
void llama_kv_self_seq_cp(
1898+
llama_context * ctx,
1899+
llama_seq_id seq_id_src,
1900+
llama_seq_id seq_id_dst,
1901+
llama_pos p0,
1902+
llama_pos p1) {
1903+
return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
1904+
}
1905+
1906+
// deprecated
1907+
void llama_kv_cache_seq_keep(
1908+
llama_context * ctx,
1909+
llama_seq_id seq_id) {
1910+
return llama_kv_self_seq_keep(ctx, seq_id);
1911+
}
1912+
1913+
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
1914+
return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id);
1915+
}
1916+
1917+
// deprecated
1918+
void llama_kv_cache_seq_add(
1919+
llama_context * ctx,
1920+
llama_seq_id seq_id,
1921+
llama_pos p0,
1922+
llama_pos p1,
1923+
llama_pos delta) {
1924+
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
1925+
}
1926+
1927+
void llama_kv_self_seq_add(
1928+
llama_context * ctx,
1929+
llama_seq_id seq_id,
1930+
llama_pos p0,
1931+
llama_pos p1,
1932+
llama_pos delta) {
1933+
return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta);
1934+
}
1935+
1936+
// deprecated
1937+
void llama_kv_cache_seq_div(
1938+
llama_context * ctx,
1939+
llama_seq_id seq_id,
1940+
llama_pos p0,
1941+
llama_pos p1,
1942+
int d) {
1943+
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
1944+
}
1945+
1946+
void llama_kv_self_seq_div(
1947+
llama_context * ctx,
1948+
llama_seq_id seq_id,
1949+
llama_pos p0,
1950+
llama_pos p1,
1951+
int d) {
1952+
return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d);
1953+
}
1954+
1955+
// deprecated
1956+
llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
1957+
return llama_kv_self_seq_pos_max(ctx, seq_id);
1958+
}
1959+
1960+
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
1961+
return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id);
1962+
}
1963+
1964+
// deprecated
1965+
void llama_kv_cache_defrag(llama_context * ctx) {
1966+
return llama_kv_self_defrag(ctx);
1967+
}
1968+
1969+
void llama_kv_self_defrag(llama_context * ctx) {
1970+
return llama_kv_cache_defrag(&ctx->kv_self);
1971+
}
1972+
1973+
// deprecated
1974+
bool llama_kv_cache_can_shift(const llama_context * ctx) {
1975+
return llama_kv_self_can_shift(ctx);
1976+
}
1977+
1978+
bool llama_kv_self_can_shift(const llama_context * ctx) {
1979+
return llama_kv_cache_can_shift(&ctx->kv_self);
1980+
}
1981+
1982+
// deprecated
1983+
void llama_kv_cache_update(llama_context * ctx) {
1984+
llama_kv_self_update(ctx);
1985+
}
18271986

18281987
// llama state API
18291988

src/llama-context.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ struct llama_context {
6262
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
6363

6464
bool logits_all = false;
65+
bool need_reserve = false;
6566

6667
// embeddings output (2-dimensional array: [n_outputs][n_embd])
6768
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
@@ -87,6 +88,7 @@ struct llama_context {
8788
// max token position across all sequences in the current context
8889
llama_pos pos_max() const;
8990

91+
// certain implementations could require a padding for the context size
9092
uint32_t get_ctx_padding(const llama_cparams & cparams) const;
9193

9294
void reset();
@@ -140,7 +142,7 @@ struct llama_context {
140142
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
141143

142144
// return true if need to reserve new worst-case graph
143-
bool kv_self_update();
145+
void kv_self_update();
144146

145147
void build_attn_inp(
146148
ggml_context * ctx0,

0 commit comments

Comments
 (0)