Skip to content

Commit 0a8cdc3

Browse files
committed
llama : add llama_max_parallel_sequences()
ggml-ci
1 parent 44856a7 commit 0a8cdc3

File tree

5 files changed

+16
-4
lines changed

5 files changed

+16
-4
lines changed

include/llama.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ extern "C" {
471471
LLAMA_API int64_t llama_time_us(void);
472472

473473
LLAMA_API size_t llama_max_devices(void);
474+
LLAMA_API size_t llama_max_parallel_sequences(void);
474475

475476
LLAMA_API bool llama_supports_mmap (void);
476477
LLAMA_API bool llama_supports_mlock (void);

src/llama-context.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@ llama_context::llama_context(
2525

2626
const auto & hparams = model.hparams;
2727

28-
cparams.n_seq_max = std::max(1u, params.n_seq_max);
28+
cparams.n_seq_max = std::max(1u, params.n_seq_max);
29+
if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
30+
LLAMA_LOG_WARN("%s: n_seq_max (%d) is larger than the maximum supported (%d) - clamping\n", __func__, cparams.n_seq_max, LLAMA_MAX_PARALLEL_SEQUENCES);
31+
cparams.n_seq_max = LLAMA_MAX_PARALLEL_SEQUENCES;
32+
}
33+
2934
cparams.n_threads = params.n_threads;
3035
cparams.n_threads_batch = params.n_threads_batch;
3136
cparams.yarn_ext_factor = params.yarn_ext_factor;

src/llama-impl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <string>
66
#include <vector>
77

8+
#define LLAMA_MAX_PARALLEL_SEQUENCES 64
9+
810
#ifdef __GNUC__
911
# if defined(__MINGW32__) && !defined(__clang__)
1012
# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))

src/llama-kv-cells.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include "llama.h"
3+
#include "llama-impl.h"
44

55
#include <bitset>
66
#include <cassert>
@@ -119,7 +119,7 @@ class llama_kv_cells_unified {
119119
seq[i].reset(seq_id);
120120

121121
if (seq[i].none()) {
122-
pos[i]= -1;
122+
pos[i] = -1;
123123

124124
used--;
125125

@@ -267,6 +267,6 @@ class llama_kv_cells_unified {
267267
std::vector<llama_pos> shift;
268268

269269
// TODO: assert n_seq_max <= 64
270-
std::vector<std::bitset<64>> seq;
270+
std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
271271
};
272272

src/llama.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ size_t llama_max_devices(void) {
3737
return 16;
3838
}
3939

40+
size_t llama_max_parallel_sequences(void) {
41+
return LLAMA_MAX_PARALLEL_SEQUENCES;
42+
}
43+
4044
bool llama_supports_mmap(void) {
4145
return llama_mmap::SUPPORTED;
4246
}

0 commit comments

Comments (0)