@@ -399,7 +399,7 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
                 break;
             }
 
-            // remeber the position that we found
+            // remember the position that we found
             res.push_back(head_new);
 
             // store the old state of the cells in the recovery stack
@@ -3037,11 +3037,93 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 //
 // llama_kv_cache_hybrid
 //
-llama_kv_cache_hybrid::llama_kv_cache_hybrid(
-        const llama_hparams & hparams,
-        std::vector<child_cache> children) :
-    m_hparams(hparams),
-    m_children(
+
+class llama_kv_cache_hybrid_decode_state_t : public llama_memory_decode_state_i {
+public:
+    explicit llama_kv_cache_hybrid_decode_state_t(
+            std::vector<llama_memory_decode_state_ptr> decode_states_) :
+        status([](const std::vector<llama_memory_decode_state_ptr> & decode_states) -> llama_memory_status {
+            for (const auto & decode_state : decode_states) {
+                if (!decode_state) {
+                    return LLAMA_MEMORY_STATUS_FAILED_PREPARE;
+                }
+                const auto & status = decode_state->get_status();
+                if (status != LLAMA_MEMORY_STATUS_SUCCESS) {
+                    return status;
+                }
+            }
+            return LLAMA_MEMORY_STATUS_SUCCESS;
+        }(decode_states_)),
+        decode_states(std::move(decode_states_)) {
+
+        // make sure there is at least one decode state
+        assert(!decode_states.empty());
+
+        // make sure all out_ids match across states
+        // TODO: this could be expensive, so maybe don't do it?
+        const auto & out_ids = decode_states[0]->out_ids();
+        for (size_t i = 1; i < decode_states.size(); ++i) {
+            const auto & out_ids_i = decode_states[i]->out_ids();
+            assert(out_ids.size() == out_ids_i.size());
+            for (size_t j = 0; j < out_ids.size(); ++j) {
+                assert(out_ids[j] == out_ids_i[j]);
+            }
+        }
+    }
+
+    ~llama_kv_cache_hybrid_decode_state_t() = default;
+
+    llama_ubatch * next() override {
+        assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+        // hit next on each child
+        std::vector<llama_ubatch *> next_ubatches;
+        for (const auto & decode_state : decode_states) {
+            next_ubatches.push_back(decode_state->next());
+        }
+
+        // make sure they all match
+        // TODO: unnecessary safety?
+        llama_ubatch * res = next_ubatches[0];
+        assert(res);
+        for (size_t i = 1; i < next_ubatches.size(); ++i) {
+            llama_ubatch * ubatch_i = next_ubatches[i];
+            assert(ubatch_i);
+            assert(ubatch_i->n_tokens     == res->n_tokens);
+            assert(ubatch_i->n_seq_tokens == res->n_seq_tokens);
+            assert(ubatch_i->n_seqs       == res->n_seqs);
+            for (size_t j = 0; j < res->n_tokens; ++j) {
+                assert(ubatch_i->token[j]  == res->token[j]);
+                assert(ubatch_i->pos[j]    == res->pos[j]);
+                assert(ubatch_i->output[j] == res->output[j]);
+            }
+            for (size_t j = 0; j < res->n_seqs; ++j) {
+                assert(ubatch_i->n_seq_id[j] == res->n_seq_id[j]);
+            }
+        }
+
+        // return the first ubatch since they all match
+        return res;
+    }
+
+    std::vector<int64_t> & out_ids() override {
+        assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+        return decode_states[0]->out_ids();
+    }
+
+    llama_memory_status get_status() const override {
+        return status;
+    }
+
+private:
+    const llama_memory_status status;
+    std::vector<llama_memory_decode_state_ptr> decode_states;
+};
+
+llama_kv_cache_hybrid::llama_kv_cache_hybrid(std::vector<child_cache> children_) :
+    children(
     [](std::vector<child_cache>& caches) -> std::set<std::unique_ptr<llama_kv_cache>> {
         // Sort the caches by the lowest layer ID so the order is repeatable
         for (auto & cache : caches) {
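An aside on the pattern used throughout this constructor (and reused for `children` and `has_recurrent` below): each `const` member is initialized by an immediately-invoked lambda that folds over the inputs, so the member never exists in a partially-computed state. A minimal standalone sketch of the idiom, using stand-in types rather than the real llama.cpp ones:

#include <cassert>
#include <vector>

// stand-in for the llama.cpp status enum; only the initialization idiom is the point
enum status_t { STATUS_SUCCESS, STATUS_FAILED_PREPARE };

struct aggregate_t {
    // const member initialized by an immediately-invoked lambda:
    // the first non-success child status wins, otherwise success
    explicit aggregate_t(const std::vector<status_t> & child_statuses) :
        status([](const std::vector<status_t> & ss) -> status_t {
            for (const auto s : ss) {
                if (s != STATUS_SUCCESS) {
                    return s;
                }
            }
            return STATUS_SUCCESS;
        }(child_statuses)) {}

    const status_t status;
};

int main() {
    const aggregate_t ok ({STATUS_SUCCESS, STATUS_SUCCESS});
    const aggregate_t bad({STATUS_SUCCESS, STATUS_FAILED_PREPARE});
    assert(ok.status  == STATUS_SUCCESS);
    assert(bad.status == STATUS_FAILED_PREPARE);
    return 0;
}

The first non-success status wins, which is how `llama_kv_cache_hybrid_decode_state_t` surfaces a failed child through `get_status()` before any ubatch is consumed.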
@@ -3056,26 +3138,26 @@ llama_kv_cache_hybrid::llama_kv_cache_hybrid(
             unique_caches.emplace(cache.child.release());
         }
         return unique_caches;
-    }(children)
+    }(children_)
     ),
-    m_has_recurrent(
+    has_recurrent(
         [](const std::set<std::unique_ptr<llama_kv_cache>> & caches) -> bool {
             for (const auto & cache : caches) {
                 if (dynamic_cast<llama_kv_cache_recurrent *>(cache.get())) {
                     return true;
                 }
             }
             return false;
-        }(m_children)
+        }(children)
     )
 {
     // Ensure at least one child
-    GGML_ASSERT(m_children.size() > 0);
+    GGML_ASSERT(children.size() > 0);
 
     // Ensure layers are not overlapping and are concurrent
     std::set<size_t> seen_layers;
     size_t max_layer = 0;
-    for (const auto & cache : children) {
+    for (const auto & cache : children_) {
         for (const auto & layer_id : cache.layer_ids) {
             GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end());
             seen_layers.insert(layer_id);
@@ -3089,7 +3171,7 @@ llama_kv_cache_hybrid::llama_kv_cache_hybrid(
 }
 
 void llama_kv_cache_hybrid::clear() {
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         cache->clear();
     }
 }
@@ -3102,40 +3184,40 @@ bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
     }
 
     // Do the removal from each child which should never fail
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         const bool failed = cache->seq_rm(seq_id, p0, p1);
         GGML_ASSERT(!failed);
     }
     return true;
 }
 
 void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         cache->seq_cp(seq_id_src, seq_id_dst, p0, p1);
     }
 }
 
 void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) {
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         cache->seq_keep(seq_id);
     }
 }
 
 void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         cache->seq_add(seq_id, p0, p1, delta);
     }
 }
 
 void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         cache->seq_div(seq_id, p0, p1, d);
     }
 }
 
 llama_pos llama_kv_cache_hybrid::seq_pos_min(llama_seq_id seq_id) const {
     llama_pos min_pos = -1;
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         const auto child_min_pos = cache->seq_pos_min(seq_id);
         min_pos = min_pos == -1 ? child_min_pos : std::min(min_pos, child_min_pos);
     }
@@ -3144,81 +3226,67 @@ llama_pos llama_kv_cache_hybrid::seq_pos_min(llama_seq_id seq_id) const {
 
 llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const {
     llama_pos max_pos = 0;
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         max_pos = std::max(max_pos, cache->seq_pos_max(seq_id));
     }
     return max_pos;
 }
 
-void llama_kv_cache_hybrid::restore() {
-    for (const auto & cache : m_children) {
-        cache->restore();
-    }
-}
+llama_memory_decode_state_ptr llama_kv_cache_hybrid::init(
+        const llama_batch & batch,
+        uint32_t n_ubatch,
+        bool embd_pooled,
+        bool logits_all,
+        bool split_equal) {
 
-void llama_kv_cache_hybrid::commit() {
-    for (const auto & cache : m_children) {
-        cache->commit();
+    // recurrent children require equal splits
+    // TODO: just ignore this if set incorrectly?
+    assert(!has_recurrent || split_equal);
+
+    // init all children and capture their decode states
+    std::vector<llama_memory_decode_state_ptr> decode_states;
+    for (const auto & child : children) {
+        decode_states.emplace_back(
+            child->init(batch, n_ubatch, embd_pooled, logits_all, split_equal));
     }
+
+    // return the hybrid decode state
+    return std::make_unique<llama_kv_cache_hybrid_decode_state_t>(std::move(decode_states));
 }
 
 bool llama_kv_cache_hybrid::update(llama_context & ctx) {
     bool updated = false;
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         updated = cache->update(ctx) || updated;
     }
     return updated;
 }
 
 void llama_kv_cache_hybrid::defrag_sched(float thold) {
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         cache->defrag_sched(thold);
     }
 }
 
 void llama_kv_cache_hybrid::set_full() {
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         cache->set_full();
     }
 }
 
 bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         if (!cache->can_seq_rm(seq_id, p0, p1)) {
             return false;
         }
     }
     return true;
 }
 
-llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
-    // If any of the caches are recurrent, require equal split
-    return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all);
-}
-
-llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
-    if (embd_pooled) {
-        // Pooled embeddings cannot be split across ubatches (yet)
-        return sbatch.split_seq(n_ubatch);
-    }
-    if (m_has_recurrent) {
-        return sbatch.split_equal(n_ubatch);
-    }
-    return sbatch.split_simple(n_ubatch);
-}
-
-bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {
-    bool found = true;
-    for (const auto & cache : m_children) {
-        found = cache->find_slot(batch) && found;
-    }
-    return found;
-}
-
 bool llama_kv_cache_hybrid::get_can_shift() const {
     // TODO: Is this correct?
     // If any children can shift, return true
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         if (cache->get_can_shift()) {
             return true;
         }
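For orientation, a compile-only sketch of how a caller might drive the decode state that `init()` returns. The `decode_state_i` stub below is an assumption standing in for the real `llama_memory_decode_state_i` interface shown earlier (`get_status()`, `next()`, `out_ids()`); graph execution is elided, and the nullptr-on-exhaustion convention is also an assumption:

#include <cstdint>

// assumed stand-ins for llama_memory_status / llama_ubatch /
// llama_memory_decode_state_i; names follow the diff above
enum llama_memory_status_t { MEMORY_STATUS_SUCCESS, MEMORY_STATUS_FAILED_PREPARE };

struct ubatch_t { uint32_t n_tokens; };

struct decode_state_i {
    virtual ~decode_state_i() = default;
    virtual llama_memory_status_t get_status() const = 0;
    virtual ubatch_t * next() = 0;  // assumed to return nullptr when exhausted
};

// hypothetical driver loop: check the aggregated status once, then
// consume ubatches until the state runs out of them
bool decode_all(decode_state_i & state) {
    if (state.get_status() != MEMORY_STATUS_SUCCESS) {
        return false;  // some child failed to prepare; caller should back off
    }
    while (ubatch_t * ub = state.next()) {
        (void) ub;  // ... build and run the graph for this ubatch ...
    }
    return true;
}

Because the hybrid state asserts that every child yields identical ubatches, the caller sees a single stream of ubatches, exactly as with a non-hybrid cache.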
@@ -3229,15 +3297,15 @@ bool llama_kv_cache_hybrid::get_can_shift() const {
 void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
     // Write each cache state in order. Note that order is guaranteed at
     // initialization by using an ordered set sorted by lowest layer ID
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         cache->state_write(io, seq_id);
     }
 }
 
 void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
     // Read each cache state in order. Note that order is guaranteed at
     // initialization by using an ordered set sorted by lowest layer ID
-    for (const auto & cache : m_children) {
+    for (const auto & cache : children) {
         cache->state_read(io, seq_id);
     }
 }
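Finally, the constructor's `seen_layers`/`max_layer` bookkeeping implies an invariant worth stating on its own: every layer ID belongs to exactly one child, and the children jointly cover every layer index with no gaps. A hypothetical restatement (the tail of that constructor hunk is not shown above, so the contiguity check is an assumption):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// standalone restatement of the constructor invariant: each layer ID lives in
// exactly one child, and the children jointly cover 0..max_layer with no gaps
// (the gap check is assumed; the end of that constructor is not in this diff)
bool layers_disjoint_and_contiguous(const std::vector<std::vector<size_t>> & children_layers) {
    std::set<size_t> seen_layers;
    size_t max_layer = 0;
    for (const auto & layer_ids : children_layers) {
        for (const auto layer_id : layer_ids) {
            if (!seen_layers.insert(layer_id).second) {
                return false;  // overlap: layer owned by two children
            }
            max_layer = std::max(max_layer, layer_id);
        }
    }
    return seen_layers.size() == max_layer + 1;  // contiguous coverage from 0
}

int main() {
    // e.g. attention layers on one child, recurrent layers on the other
    assert( layers_disjoint_and_contiguous({{0, 2, 4}, {1, 3, 5}}));
    assert(!layers_disjoint_and_contiguous({{0, 1}, {1, 2}}));  // overlap on 1
    assert(!layers_disjoint_and_contiguous({{0, 1}, {3}}));     // gap at 2
    return 0;
}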