3 files changed, 29 insertions(+), 4 deletions(-)

src/llama-kv-cache.cpp:

```diff
@@ -136,7 +136,10 @@ class llama_kv_cache_unified_state : public llama_kv_cache_unified_state_i {
     std::vector<uint32_t> heads;
     std::vector<llama_ubatch> ubatches;
 
+    //
     // data needed for building the compute graph for the current ubatch:
+    //
+
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // as the cache gets filled, the benefit from this heuristic disappears
     int32_t n_kv;
@@ -1876,7 +1879,10 @@ class llama_kv_cache_unified_iswa_state : public llama_kv_cache_unified_iswa_state_i {
 
     std::vector<llama_ubatch> ubatches;
 
+    //
     // data needed for building the compute graph for the current ubatch:
+    //
+
     int32_t n_kv_base;
     int32_t head_base;
 
@@ -2123,7 +2129,7 @@ class llama_kv_cache_recurrent_state_t : public llama_kv_cache_recurrent_state_i {
         return kv->s_copy(i);
     }
 
-    float s_mask(int i) const override {
+    float s_mask(int i) const override {
         return kv->s_mask(i);
     }
 
@@ -2132,13 +2138,18 @@ class llama_kv_cache_recurrent_state_t : public llama_kv_cache_recurrent_state_i {
 
     llama_kv_cache_recurrent * kv;
 
-    const bool is_full = false;
-
     llama_sbatch sbatch;
 
     size_t i_next = 0;
 
     std::vector<llama_ubatch> ubatches;
+
+    //
+    // data needed for building the compute graph for the current ubatch:
+    // TODO: extract all the state like `head` and `n` here
+    //
+
+    const bool is_full = false;
 };
 
 llama_kv_cache_recurrent::llama_kv_cache_recurrent(
```
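The comment blocks added in all three state classes mark the same split: members that track processing of the whole batch versus members that only feed the compute graph of the current ubatch (the recurrent TODO proposes also moving `head` and `n` into that second group). A minimal sketch of the resulting layout, with hypothetical stand-in names rather than the actual llama.cpp members:

```cpp
#include <cstdint>
#include <vector>

struct example_ubatch {}; // stand-in for llama_ubatch

class example_cache_state {
private:
    // state for processing the whole batch: one entry per ubatch
    std::vector<uint32_t>       heads;    // write position chosen for each ubatch
    std::vector<example_ubatch> ubatches; // the batch, split into ubatches

    //
    // data needed for building the compute graph for the current ubatch:
    //

    int32_t  n_kv = 0; // how much of the cache the graph attends to
    uint32_t head = 0; // where the current ubatch is written
};
```

Keeping the per-ubatch members in one clearly marked section makes it mechanical to later extract them behind a dedicated state interface, which is what the TODOs in the header below propose.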
src/llama-kv-cache.h:

```diff
@@ -40,6 +40,9 @@ struct llama_kv_cache : public llama_memory_i {
     virtual bool update(llama_context & lctx) = 0;
 
     // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
+    // TODO: change to
+    //   llama_memory_state_ptr init_defrag(float thold) = 0;
+    //
     virtual void defrag_sched(float thold) = 0;
 
     // getters
@@ -253,7 +256,7 @@ class llama_kv_cache_unified_state_i : public llama_memory_state_i {
     virtual ggml_tensor * get_k(ggml_context * ctx, int32_t il) const = 0;
     virtual ggml_tensor * get_v(ggml_context * ctx, int32_t il) const = 0;
 
-    // store k_cur and v_cur in the cache based on the current head location
+    // store k_cur and v_cur in the cache based on the provided head location
     virtual ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const = 0;
     virtual ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const = 0;
 
@@ -359,6 +362,8 @@ class llama_kv_cache_unified_iswa_state_i : public llama_memory_state_i {
 // llama_kv_cache_recurrent
 //
 
+// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
+//       see the implementation of llama_kv_cache_unified_state_i for an example of how to do it
 class llama_kv_cache_recurrent : public llama_kv_cache {
 public:
     llama_kv_cache_recurrent(
```
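The TODO on defrag_sched() points in the same state-object direction: instead of setting a flag that a later update() acts on, init_defrag() would hand back a llama_memory_state_ptr that the caller applies explicitly. A self-contained sketch of that control flow, using simplified stand-in types (only the init_defrag signature comes from the TODO; everything else here is an assumption):

```cpp
#include <memory>

// simplified stand-ins, not the actual llama.cpp declarations
enum example_memory_status {
    EXAMPLE_MEMORY_STATUS_SUCCESS,
    EXAMPLE_MEMORY_STATUS_FAILED, // cf. LLAMA_MEMORY_STATUS_FAILED_* in llama-memory.h
};

struct example_memory_state_i {
    virtual ~example_memory_state_i() = default;
    virtual bool apply() = 0; // the single mutating entry point
    virtual example_memory_status get_status() const = 0;
};

using example_memory_state_ptr = std::unique_ptr<example_memory_state_i>;

struct example_kv_cache {
    // today: remember that a defrag is wanted; performed later during update()
    void defrag_sched(float thold) { defrag_thold = thold; }

    // proposed: return a state describing the planned defrag, applied explicitly;
    // this sketch plans nothing, so it returns nullptr
    example_memory_state_ptr init_defrag(float /*thold*/) { return nullptr; }

    float defrag_thold = -1.0f;
};

void example_usage(example_kv_cache & kv) {
    if (example_memory_state_ptr st = kv.init_defrag(0.1f)) {
        if (st->get_status() == EXAMPLE_MEMORY_STATUS_SUCCESS) {
            st->apply(); // the only step that mutates the cache
        }
    }
}
```

This would fold defragmentation into the same init-then-apply lifecycle that batch processing already uses, instead of a side channel acted on inside update().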
src/llama-memory.h:

```diff
@@ -42,6 +42,15 @@ enum llama_memory_status {
     LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
 };
 
+// the interface for managing the memory state during batch processing
+// this interface is extended per memory type with specific methods used for constructing the compute graphs. see:
+//   - llama_kv_cache_unified_state_i
+//   - llama_kv_cache_unified_iswa_state_i
+//   ...
+//
+// these extended interfaces should mutate neither the memory nor the current memory state
+// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
+//
 class llama_memory_state_i {
 public:
     virtual ~llama_memory_state_i() = default;
```
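The contract spelled out in this new comment is: the per-type extensions add const accessors for graph construction, and apply() is the single mutating call. A sketch of a consumer following that contract, again with simplified stand-in types (next() and apply() are assumptions based on the described flow; the get_k()/get_v() signatures are the ones declared in llama-kv-cache.h above):

```cpp
#include <cstdint>

struct ggml_tensor;  // opaque, as in ggml
struct ggml_context;

struct example_memory_state_i {
    virtual ~example_memory_state_i() = default;

    // the only call allowed to mutate the memory and the memory state
    virtual bool apply() = 0;

    // advance to the next ubatch of the batch, if any
    virtual bool next() = 0;
};

// per-memory-type extension: const-only methods for building compute graphs,
// mirroring llama_kv_cache_unified_state_i
struct example_unified_state_i : example_memory_state_i {
    virtual ggml_tensor * get_k(ggml_context * ctx, int32_t il) const = 0;
    virtual ggml_tensor * get_v(ggml_context * ctx, int32_t il) const = 0;
};

// a consumer mutates only through apply(), then builds the graph for the
// current ubatch through the read-only extension methods
inline void process_batch(example_unified_state_i & st, ggml_context * ctx) {
    do {
        st.apply();                     // commit state for the current ubatch
        (void) st.get_k(ctx, /*il=*/0); // read-only graph construction
        (void) st.get_v(ctx, /*il=*/0);
    } while (st.next());
}
```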