@@ -4,6 +4,21 @@
 #include <cstring>
 #include <algorithm>
 
+void llama_ubatch::update() {
+    if (equal_seqs) {
+        // TODO: for now don't compute min/max for recurrent batches since we don't need this.
+        //       the batches will be refactored anyway, so we'll fix this later
+        return;
+    }
+
+    for (uint32_t i = 0; i < n_tokens; ++i) {
+        const llama_seq_id s = seq_id[i][0];
+
+        seq_pos_min[s] = seq_pos_min[s] == -1 ? pos[i] : std::min(seq_pos_min[s], pos[i]);
+        seq_pos_max[s] = seq_pos_max[s] == -1 ? pos[i] : std::max(seq_pos_max[s], pos[i]);
+    }
+}
+
 llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
     // clear empty sequences
     // the previous ubatch is assumed to be gone,
@@ -15,24 +30,33 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
             break;
         }
     }
-    ubatch_token.resize(!has_embd ? n_ubatch : 0);
-    ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
-    ubatch_pos.resize(n_ubatch);
-    ubatch_n_seq_id.resize(n_ubatch);
-    ubatch_seq_id.resize(n_ubatch);
-    ubatch_output.resize(n_ubatch);
+
+    udatas.push_back({});
+
+    auto & udata = udatas.back();
+
+    udata.token.resize(!has_embd ? n_ubatch : 0);
+    udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
+    udata.pos.resize(n_ubatch);
+    udata.n_seq_id.resize(n_ubatch);
+    udata.seq_id.resize(n_ubatch);
+    udata.output.resize(n_ubatch);
+
     llama_ubatch ubatch = {
         /*equal_seqs   =*/ true,
         /*n_tokens     =*/ 0,
         /*n_seq_tokens =*/ 0,
         /*n_seqs       =*/ 0,
-        /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
-        /*embd         =*/ has_embd ? ubatch_embd.data() : nullptr,
-        /*pos          =*/ ubatch_pos.data(),
-        /*n_seq_id     =*/ ubatch_n_seq_id.data(),
-        /*seq_id       =*/ ubatch_seq_id.data(),
-        /*output       =*/ ubatch_output.data(),
+        /*seq_pos_min  =*/ {-1},
+        /*seq_pos_max  =*/ {-1},
+        /*token        =*/ !has_embd ? udata.token.data() : nullptr,
+        /*embd         =*/ has_embd ? udata.embd.data() : nullptr,
+        /*pos          =*/ udata.pos.data(),
+        /*n_seq_id     =*/ udata.n_seq_id.data(),
+        /*seq_id       =*/ udata.seq_id.data(),
+        /*output       =*/ udata.output.data(),
     };
+
     return ubatch;
 }
 
@@ -148,6 +172,7 @@ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
         GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
         add_seq_to_ubatch(ubatch, s, length);
     }
+    ubatch.update();
     return ubatch;
 }
 
@@ -175,6 +200,7 @@ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
             if (length + n_tokens_in_ubatch > n_ubatch) { break; }
         }
     }
+    ubatch.update();
     return ubatch;
 }
 
@@ -187,6 +213,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
         GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
         add_seq_to_ubatch(ubatch, s, length);
     }
+    ubatch.update();
     return ubatch;
 }
 
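Note on the pattern above, with a standalone sketch that is not part of the patch: update() fills seq_pos_min/seq_pos_max lazily, using -1 as a "no token seen yet" sentinel that is replaced by the first position observed for a sequence and then folded with std::min/std::max on every later token. The sketch below reproduces that logic on made-up data; the capacity n_seq_max and the sample tokens are illustrative, not llama.cpp values.

// Standalone sketch of the sentinel-based min/max tracking performed by
// llama_ubatch::update(); n_seq_max and the sample data are invented here.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int n_seq_max = 4;                     // hypothetical sequence capacity
    int32_t seq_pos_min[n_seq_max];
    int32_t seq_pos_max[n_seq_max];
    std::fill_n(seq_pos_min, n_seq_max, -1);     // -1 == "no token seen yet"
    std::fill_n(seq_pos_max, n_seq_max, -1);

    // each token carries a sequence id and a position, as in a ubatch
    const int32_t seq[] = { 0, 1, 0, 1, 0 };
    const int32_t pos[] = { 5, 2, 3, 7, 9 };

    for (int i = 0; i < 5; ++i) {
        const int32_t s = seq[i];
        seq_pos_min[s] = seq_pos_min[s] == -1 ? pos[i] : std::min(seq_pos_min[s], pos[i]);
        seq_pos_max[s] = seq_pos_max[s] == -1 ? pos[i] : std::max(seq_pos_max[s], pos[i]);
    }

    printf("seq 0: [%d, %d]\n", seq_pos_min[0], seq_pos_max[0]); // prints [3, 9]
    printf("seq 1: [%d, %d]\n", seq_pos_min[1], seq_pos_max[1]); // prints [2, 7]
    return 0;
}

Running the equivalent of ubatch.update() once per split, as the patch does at the end of split_simple, split_equal, and split_seq, keeps the per-sequence position ranges consistent regardless of which splitting strategy produced the ubatch.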