File tree: 2 files changed, +14 -5 lines changed
Original file line number | Diff line number | Diff line change
@@ -554,16 +554,18 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
554
554
555
555
bool found = true ;
556
556
for (uint32_t i = 0 ; i < n_tokens; i++) {
557
- // TODO: improve to accept cells that are masked by the SWA
558
- // if (!cells.is_empty(head_cur + i)) {
559
-
560
557
const llama_seq_id seq_id = ubatch.seq_id [i][0 ];
561
558
559
+ // can we use this cell? either:
560
+ // - the cell is empty
561
+ // - the cell is occupied only by the same sequence, and it passes the causal, sequence and SWA mask checks below
562
562
const bool can_use =
563
563
cells.is_empty (head_cur + i) ||
564
564
(
565
- cells.seq_has (head_cur + i, seq_id) && // TODO: seq_has_only
566
- is_masked_swa (cells.pos_get (head_cur + i), ubatch.seq_pos_min [seq_id])
565
+ cells.pos_get (head_cur + i) <= ubatch.pos [i] && // causal mask
566
+ cells.seq_has (head_cur + i, seq_id) && // sequence mask
567
+ cells.seq_count (head_cur + i) == 1 &&
568
+ is_masked_swa (cells.pos_get (head_cur + i), ubatch.seq_pos_min [seq_id]) // SWA mask
567
569
);
568
570
569
571
if (!can_use) {
Original file line number Diff line number Diff line change @@ -155,6 +155,13 @@ class llama_kv_cells_unified {
155
155
return false ;
156
156
}
157
157
158
// Returns seq[i].count() for cell i; presumably the number of sequences
// occupying that cell (NOTE(review): seq[i] looks like a bitset of sequence ids - confirm).
// Preconditions, checked with assert only: i is within pos, and pos[i] != -1.
+ int seq_count (uint32_t i) const {
159
+ assert (i < pos.size ()); // index must refer to an existing cell
160
+ assert (pos[i] != -1 ); // NOTE(review): presumably -1 marks an empty cell - confirm against is_empty()
161
+
162
+ return seq[i].count (); // result of count() is converted to int on return
163
+ }
164
+
158
165
bool seq_has (uint32_t i, llama_seq_id seq_id) const {
159
166
assert (i < pos.size ());
160
167
assert (seq_id >= 0 );
You can’t perform that action at this time.
0 commit comments