Commit 40c6b54

feat: Update llama.cpp
1 parent 93dc56a

File tree

2 files changed, +13 -5 lines changed

llama_cpp/llama_cpp.py

Lines changed: 12 additions & 4 deletions
@@ -429,10 +429,12 @@ class llama_batch(ctypes.Structure):
     The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens

     Attributes:
+        n_tokens (int): number of tokens
         token (ctypes.Array[llama_token]): the token ids of the input (used when embd is NULL)
         embd (ctypes.Array[ctypes.ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
         pos (ctypes.Array[ctypes.Array[llama_pos]]): the positions of the respective token in the sequence
         seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs
+        logits (ctypes.Array[ctypes.ctypes.c_int8]): if zero, the logits for the respective token will not be output
     """

     _fields_ = [
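
For orientation, here is a minimal sketch of how these llama_batch fields are typically populated through the bindings; the token ids, sizes, and the commented-out decode call are illustrative placeholders, not part of this commit:

import llama_cpp

# Room for up to 8 tokens, no embedding input (embd=0), one seq id per token.
batch = llama_cpp.llama_batch_init(8, 0, 1)

tokens = [1, 15043, 3186]  # placeholder token ids
batch.n_tokens = len(tokens)
for i, tok in enumerate(tokens):
    batch.token[i] = tok
    batch.pos[i] = i            # position of the token in its sequence
    batch.n_seq_id[i] = 1
    batch.seq_id[i][0] = 0      # every token belongs to sequence 0
    batch.logits[i] = 0         # suppress logits for this row
batch.logits[batch.n_tokens - 1] = 1  # request logits only for the last token

# llama_cpp.llama_decode(ctx, batch) would consume the batch here.
llama_cpp.llama_batch_free(batch)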
@@ -547,6 +549,7 @@ class llama_model_params(ctypes.Structure):
 # uint32_t seed;              // RNG seed, -1 for random
 # uint32_t n_ctx;             // text context, 0 = from model
 # uint32_t n_batch;           // prompt processing maximum batch size
+# uint32_t n_parallel;        // number of parallel sequences (i.e. distinct states for recurrent models)
 # uint32_t n_threads;         // number of threads to use for generation
 # uint32_t n_threads_batch;   // number of threads to use for batch processing
@@ -588,6 +591,7 @@ class llama_context_params(ctypes.Structure):
         seed (int): RNG seed, -1 for random
         n_ctx (int): text context, 0 = from model
         n_batch (int): prompt processing maximum batch size
+        n_parallel (int): number of parallel sequences (i.e. distinct states for recurrent models)
         n_threads (int): number of threads to use for generation
         n_threads_batch (int): number of threads to use for batch processing
         rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
@@ -615,6 +619,7 @@ class llama_context_params(ctypes.Structure):
         ("seed", ctypes.c_uint32),
         ("n_ctx", ctypes.c_uint32),
         ("n_batch", ctypes.c_uint32),
+        ("n_parallel", ctypes.c_uint32),
         ("n_threads", ctypes.c_uint32),
         ("n_threads_batch", ctypes.c_uint32),
         ("rope_scaling_type", ctypes.c_int),
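
With n_parallel now surfaced in llama_context_params, a caller can reserve several distinct sequence states up front. A minimal sketch, assuming a GGUF model; the path and sizes are placeholders:

import llama_cpp

params = llama_cpp.llama_context_default_params()
params.n_ctx = 2048
params.n_batch = 512
params.n_parallel = 4  # four distinct sequence states, e.g. for recurrent models

model = llama_cpp.llama_load_model_from_file(
    b"/path/to/model.gguf", llama_cpp.llama_model_default_params()
)
ctx = llama_cpp.llama_new_context_with_model(model, params)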
@@ -1322,7 +1327,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
 # //   seq_id < 0 : match any sequence
 # //   p0 < 0     : [0, p1]
 # //   p1 < 0     : [p0, inf)
-# LLAMA_API void llama_kv_cache_seq_rm(
+# LLAMA_API bool llama_kv_cache_seq_rm(
 #         struct llama_context * ctx,
 #                 llama_seq_id   seq_id,
 #                    llama_pos   p0,
@@ -1335,15 +1340,15 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
         llama_pos,
         llama_pos,
     ],
-    None,
+    ctypes.c_bool,
 )
 def llama_kv_cache_seq_rm(
     ctx: llama_context_p,
     seq_id: Union[llama_seq_id, int],
     p0: Union[llama_pos, int],
     p1: Union[llama_pos, int],
     /,
-):
+) -> bool:
     """Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
     seq_id < 0 : match any sequence
     p0 < 0     : [0, p1]
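
Because the binding now returns a bool, callers can detect when a removal was refused rather than silently assuming success. A hedged sketch, reusing the ctx from the earlier sketch; the fallback policy is illustrative only:

import llama_cpp

# Drop positions [0, 16) from sequence 2; the return value reports success.
if not llama_cpp.llama_kv_cache_seq_rm(ctx, 2, 0, 16):
    # A partial removal can be refused (e.g. by caches that cannot split
    # their state); clearing the whole cache is one possible fallback.
    llama_cpp.llama_kv_cache_clear(ctx)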
@@ -1754,7 +1759,10 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
     The logits for the last token are stored in the last row
     Logits for which llama_batch.logits[i] == 0 are undefined
     Rows: n_tokens provided with llama_batch
-    Cols: n_vocab"""
+    Cols: n_vocab
+
+    Returns:
+        Pointer to the logits buffer of shape (n_tokens, n_vocab)"""
     ...
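
The clarified docstring makes the buffer layout explicit: n_tokens rows of n_vocab floats, row-major. A minimal sketch of reading the last row; ctx, model, and batch are assumed to exist from the earlier sketches:

import llama_cpp

n_vocab = llama_cpp.llama_n_vocab(model)
logits = llama_cpp.llama_get_logits(ctx)  # flat row-major float buffer

# Logits for the final token live in the last row; take a plain argmax.
offset = (batch.n_tokens - 1) * n_vocab
last_row = [logits[offset + j] for j in range(n_vocab)]
next_token = max(range(n_vocab), key=last_row.__getitem__)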
vendor/llama.cpp

Lines changed: 1 addition & 1 deletion (submodule commit pointer updated)
