@@ -159,11 +159,13 @@ class llama_token_data_array(Structure):
159
159
160
160
161
161
# struct llama_context_params {
162
- # uint32_t seed; // RNG seed, -1 for random
163
- # int32_t n_ctx; // text context
164
- # int32_t n_batch; // prompt processing batch size
165
- # int32_t n_gpu_layers; // number of layers to store in VRAM
166
- # int32_t main_gpu; // the GPU that is used for scratch and small tensors
162
+ # uint32_t seed; // RNG seed, -1 for random
163
+ # int32_t n_ctx; // text context
164
+ # int32_t n_batch; // prompt processing batch size
165
+ # int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams)
166
+ # int32_t n_gpu_layers; // number of layers to store in VRAM
167
+ # int32_t main_gpu; // the GPU that is used for scratch and small tensors
168
+ #
167
169
# const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
168
170
169
171
# // ref: https://github.com/ggerganov/llama.cpp/pull/2054
@@ -190,6 +192,7 @@ class llama_context_params(Structure):
190
192
("seed" , c_uint32 ),
191
193
("n_ctx" , c_int32 ),
192
194
("n_batch" , c_int32 ),
195
+ ("n_gqa" , c_int32 ),
193
196
("n_gpu_layers" , c_int32 ),
194
197
("main_gpu" , c_int32 ),
195
198
("tensor_split" , POINTER (c_float )),
@@ -265,6 +268,57 @@ class llama_model_quantize_params(Structure):
265
268
]
266
269
267
270
271
+ # // grammar types
272
+ # struct llama_grammar;
273
+ llama_grammar_p = c_void_p
274
+
275
+ # // grammar element type
276
+ # enum llama_gretype {
277
+ # // end of rule definition
278
+ # LLAMA_GRETYPE_END = 0,
279
+
280
+ # // start of alternate definition for rule
281
+ # LLAMA_GRETYPE_ALT = 1,
282
+
283
+ # // non-terminal element: reference to rule
284
+ # LLAMA_GRETYPE_RULE_REF = 2,
285
+
286
+ # // terminal element: character (code point)
287
+ # LLAMA_GRETYPE_CHAR = 3,
288
+
289
+ # // inverse char(s) ([^a], [^a-b] [^abc])
290
+ # LLAMA_GRETYPE_CHAR_NOT = 4,
291
+
292
+ # // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
293
+ # // be an inclusive range ([a-z])
294
+ # LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
295
+
296
+ # // modifies a preceding LLAMA_GRETYPE_CHAR or
297
+ # // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
298
+ # LLAMA_GRETYPE_CHAR_ALT = 6,
299
+ # };
300
+ LLAMA_GRETYPE_END = c_int (0 )
301
+ LLAMA_GRETYPE_ALT = c_int (1 )
302
+ LLAMA_GRETYPE_RULE_REF = c_int (2 )
303
+ LLAMA_GRETYPE_CHAR = c_int (3 )
304
+ LLAMA_GRETYPE_CHAR_NOT = c_int (4 )
305
+ LLAMA_GRETYPE_CHAR_RNG_UPPER = c_int (5 )
306
+ LLAMA_GRETYPE_CHAR_ALT = c_int (6 )
307
+
308
+
309
+ # typedef struct llama_grammar_element {
310
+ # enum llama_gretype type;
311
+ # uint32_t value; // Unicode code point or rule ID
312
+ # } llama_grammar_element;
313
+ class llama_grammar_element (Structure ):
314
+ _fields_ = [
315
+ ("type" , c_int ),
316
+ ("value" , c_uint32 ),
317
+ ]
318
+
319
+
320
+ llama_grammar_element_p = POINTER (llama_grammar_element )
321
+
268
322
# // performance timing information
269
323
# struct llama_timings {
270
324
# double t_start_ms;
@@ -871,6 +925,37 @@ def llama_token_nl() -> int:
871
925
_lib .llama_token_nl .restype = llama_token
872
926
873
927
928
+ # // Grammar
929
+ # //
930
+ # LLAMA_API struct llama_grammar * llama_grammar_init(
931
+ # const llama_grammar_element ** rules,
932
+ # size_t n_rules,
933
+ # size_t start_rule_index);
934
+ def llama_grammar_init (
935
+ rules , # type: Array[llama_grammar_element_p] # type: ignore
936
+ n_rules : c_size_t ,
937
+ start_rule_index : c_size_t ,
938
+ ) -> llama_grammar_p :
939
+ return _lib .llama_grammar_init (rules , n_rules , start_rule_index )
940
+
941
+
942
+ _lib .llama_grammar_init .argtypes = [
943
+ POINTER (llama_grammar_element_p ),
944
+ c_size_t ,
945
+ c_size_t ,
946
+ ]
947
+ _lib .llama_grammar_init .restype = llama_grammar_p
948
+
949
+
950
+ # LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
951
+ def llama_grammar_free (grammar : llama_grammar_p ):
952
+ return _lib .llama_grammar_free (grammar )
953
+
954
+
955
+ _lib .llama_grammar_free .argtypes = [llama_grammar_p ]
956
+ _lib .llama_grammar_free .restype = None
957
+
958
+
874
959
# Sampling functions
875
960
876
961
0 commit comments