Commit 985d559

Update llama.cpp
1 parent 231123e commit 985d559

File tree

2 files changed: +91 additions, -6 deletions

llama_cpp/llama_cpp.py

Lines changed: 90 additions & 5 deletions
@@ -159,11 +159,13 @@ class llama_token_data_array(Structure):
 
 
 # struct llama_context_params {
-#     uint32_t seed; // RNG seed, -1 for random
-#     int32_t n_ctx; // text context
-#     int32_t n_batch; // prompt processing batch size
-#     int32_t n_gpu_layers; // number of layers to store in VRAM
-#     int32_t main_gpu; // the GPU that is used for scratch and small tensors
+#     uint32_t seed; // RNG seed, -1 for random
+#     int32_t n_ctx; // text context
+#     int32_t n_batch; // prompt processing batch size
+#     int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams)
+#     int32_t n_gpu_layers; // number of layers to store in VRAM
+#     int32_t main_gpu; // the GPU that is used for scratch and small tensors
+#
 #     const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
 
 #     // ref: https://github.com/ggerganov/llama.cpp/pull/2054
@@ -190,6 +192,7 @@ class llama_context_params(Structure):
         ("seed", c_uint32),
         ("n_ctx", c_int32),
         ("n_batch", c_int32),
+        ("n_gqa", c_int32),
         ("n_gpu_layers", c_int32),
         ("main_gpu", c_int32),
         ("tensor_split", POINTER(c_float)),
@@ -265,6 +268,57 @@ class llama_model_quantize_params(Structure):
     ]
 
 
+# // grammar types
+# struct llama_grammar;
+llama_grammar_p = c_void_p
+
+# // grammar element type
+# enum llama_gretype {
+#     // end of rule definition
+#     LLAMA_GRETYPE_END = 0,
+
+#     // start of alternate definition for rule
+#     LLAMA_GRETYPE_ALT = 1,
+
+#     // non-terminal element: reference to rule
+#     LLAMA_GRETYPE_RULE_REF = 2,
+
+#     // terminal element: character (code point)
+#     LLAMA_GRETYPE_CHAR = 3,
+
+#     // inverse char(s) ([^a], [^a-b] [^abc])
+#     LLAMA_GRETYPE_CHAR_NOT = 4,
+
+#     // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
+#     // be an inclusive range ([a-z])
+#     LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
+
+#     // modifies a preceding LLAMA_GRETYPE_CHAR or
+#     // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
+#     LLAMA_GRETYPE_CHAR_ALT = 6,
+# };
+LLAMA_GRETYPE_END = c_int(0)
+LLAMA_GRETYPE_ALT = c_int(1)
+LLAMA_GRETYPE_RULE_REF = c_int(2)
+LLAMA_GRETYPE_CHAR = c_int(3)
+LLAMA_GRETYPE_CHAR_NOT = c_int(4)
+LLAMA_GRETYPE_CHAR_RNG_UPPER = c_int(5)
+LLAMA_GRETYPE_CHAR_ALT = c_int(6)
+
+
+# typedef struct llama_grammar_element {
+#     enum llama_gretype type;
+#     uint32_t value; // Unicode code point or rule ID
+# } llama_grammar_element;
+class llama_grammar_element(Structure):
+    _fields_ = [
+        ("type", c_int),
+        ("value", c_uint32),
+    ]
+
+
+llama_grammar_element_p = POINTER(llama_grammar_element)
+
 # // performance timing information
 # struct llama_timings {
 #     double t_start_ms;
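
To make the new types concrete, a small sketch (not part of this commit) encoding the one-rule grammar root ::= "hi" as llama_grammar_element values; per the enum comments above, CHAR elements carry a Unicode code point in value and END terminates the rule:

# Hypothetical example: the rule root ::= "hi" as three grammar elements.
hi_rule = (llama_grammar_element * 3)(
    llama_grammar_element(LLAMA_GRETYPE_CHAR.value, ord("h")),
    llama_grammar_element(LLAMA_GRETYPE_CHAR.value, ord("i")),
    llama_grammar_element(LLAMA_GRETYPE_END.value, 0),
)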
@@ -871,6 +925,37 @@ def llama_token_nl() -> int:
 _lib.llama_token_nl.restype = llama_token
 
 
+# // Grammar
+# //
+# LLAMA_API struct llama_grammar * llama_grammar_init(
+#         const llama_grammar_element ** rules,
+#         size_t n_rules,
+#         size_t start_rule_index);
+def llama_grammar_init(
+    rules,  # type: Array[llama_grammar_element_p] # type: ignore
+    n_rules: c_size_t,
+    start_rule_index: c_size_t,
+) -> llama_grammar_p:
+    return _lib.llama_grammar_init(rules, n_rules, start_rule_index)
+
+
+_lib.llama_grammar_init.argtypes = [
+    POINTER(llama_grammar_element_p),
+    c_size_t,
+    c_size_t,
+]
+_lib.llama_grammar_init.restype = llama_grammar_p
+
+
+# LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
+def llama_grammar_free(grammar: llama_grammar_p):
+    return _lib.llama_grammar_free(grammar)
+
+
+_lib.llama_grammar_free.argtypes = [llama_grammar_p]
+_lib.llama_grammar_free.restype = None
+
+
 # Sampling functions
876961

vendor/llama.cpp

Lines changed: 1 addition & 1 deletion (submodule pointer update)
