@@ -429,10 +429,12 @@ class llama_batch(ctypes.Structure):
     The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens

     Attributes:
+        n_tokens (int): number of tokens
         token (ctypes.Array[llama_token]): the token ids of the input (used when embd is NULL)
         embd (ctypes.Array[ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
         pos (ctypes.Array[ctypes.Array[llama_pos]]): the positions of the respective token in the sequence
         seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs
+        logits (ctypes.Array[ctypes.c_int8]): if zero, the logits for the respective token will not be output
     """

     _fields_ = [
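For context on the newly documented fields: `logits` is the per-token switch that decides which rows `llama_get_logits` will later contain. Below is a minimal sketch of filling a batch through these low-level bindings, assuming the usual `llama_batch_init`/`llama_batch_free` allocation helpers; the `fill_batch` wrapper itself is hypothetical:

```python
import llama_cpp

def fill_batch(tokens: list[int], seq_id: int = 0) -> llama_cpp.llama_batch:
    # Allocate room for the prompt; embd=0 means token ids (not embeddings) are used.
    batch = llama_cpp.llama_batch_init(len(tokens), 0, 1)
    batch.n_tokens = len(tokens)
    for i, tok in enumerate(tokens):
        batch.token[i] = tok
        batch.pos[i] = i              # position of the token within the sequence
        batch.n_seq_id[i] = 1
        batch.seq_id[i][0] = seq_id   # every token belongs to the same sequence
        batch.logits[i] = 0           # skip logits for prompt tokens ...
    batch.logits[batch.n_tokens - 1] = 1  # ... except the last one
    return batch
```

The batch should be released with `llama_cpp.llama_batch_free(batch)` once decoding is done.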
@@ -547,6 +549,7 @@ class llama_model_params(ctypes.Structure):
 # uint32_t seed; // RNG seed, -1 for random
 # uint32_t n_ctx; // text context, 0 = from model
 # uint32_t n_batch; // prompt processing maximum batch size
+# uint32_t n_parallel; // number of parallel sequences (i.e. distinct states for recurrent models)
 # uint32_t n_threads; // number of threads to use for generation
 # uint32_t n_threads_batch; // number of threads to use for batch processing
@@ -588,6 +591,7 @@ class llama_context_params(ctypes.Structure):
         seed (int): RNG seed, -1 for random
         n_ctx (int): text context, 0 = from model
         n_batch (int): prompt processing maximum batch size
+        n_parallel (int): number of parallel sequences (i.e. distinct states for recurrent models)
         n_threads (int): number of threads to use for generation
         n_threads_batch (int): number of threads to use for batch processing
         rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
@@ -615,6 +619,7 @@ class llama_context_params(ctypes.Structure):
         ("seed", ctypes.c_uint32),
         ("n_ctx", ctypes.c_uint32),
         ("n_batch", ctypes.c_uint32),
+        ("n_parallel", ctypes.c_uint32),
         ("n_threads", ctypes.c_uint32),
         ("n_threads_batch", ctypes.c_uint32),
         ("rope_scaling_type", ctypes.c_int),
@@ -1322,7 +1327,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
 # // seq_id < 0 : match any sequence
 # // p0 < 0 : [0, p1]
 # // p1 < 0 : [p0, inf)
-# LLAMA_API void llama_kv_cache_seq_rm(
+# LLAMA_API bool llama_kv_cache_seq_rm(
 # struct llama_context * ctx,
 # llama_seq_id seq_id,
 # llama_pos p0,
@@ -1335,15 +1340,15 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
         llama_pos,
         llama_pos,
     ],
-    None,
+    ctypes.c_bool,
 )
 def llama_kv_cache_seq_rm(
     ctx: llama_context_p,
     seq_id: Union[llama_seq_id, int],
     p0: Union[llama_pos, int],
     p1: Union[llama_pos, int],
     /,
-):
+) -> bool:
     """Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
     seq_id < 0 : match any sequence
     p0 < 0 : [0, p1]
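Because the binding now propagates the `bool` result, callers can react when a partial removal is refused, which appears to be the point of the upstream signature change for recurrent-state models. A small sketch, assuming `ctx` is an existing `llama_context_p`:

```python
import llama_cpp

# Try to drop what sequence 1 generated from position 32 onward; fall back to
# clearing the whole KV cache if the backend cannot remove a partial range.
if not llama_cpp.llama_kv_cache_seq_rm(ctx, 1, 32, -1):
    llama_cpp.llama_kv_cache_clear(ctx)
```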
@@ -1754,7 +1759,10 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
     The logits for the last token are stored in the last row
     Logits for which llama_batch.logits[i] == 0 are undefined
     Rows: n_tokens provided with llama_batch
-    Cols: n_vocab"""
+    Cols: n_vocab
+
+    Returns:
+        Pointer to the logits buffer of shape (n_tokens, n_vocab)"""
     ...
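The clarified `(n_tokens, n_vocab)` shape maps directly onto a zero-copy numpy view. A sketch assuming `model`, `ctx`, and the number of tokens submitted in the last batch (`n_tokens`) come from earlier setup, and that the final token had its `logits` flag set:

```python
import numpy as np
import llama_cpp

n_vocab = llama_cpp.llama_n_vocab(model)
logits_ptr = llama_cpp.llama_get_logits(ctx)

# View the buffer without copying and greedily pick the next token from the last row.
logits = np.ctypeslib.as_array(logits_ptr, shape=(n_tokens, n_vocab))
next_token = int(np.argmax(logits[-1]))
```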