@@ -2077,37 +2077,81 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
2077
2077
// grammar - internal
2078
2078
//
2079
2079
2080
// State of a UTF-8 sequence whose bytes have only partly been seen, e.g. because the
// sequence is split across tokens.
struct llama_partial_utf8 {
    // payload bits collected from the bytes seen so far; not yet shifted left to make room
    // for the bits of the continuation bytes still missing
    uint32_t value;

    // number of continuation bytes still expected:
    //    0 -> no sequence is pending
    //   -1 -> the input contained an invalid sequence
    int      n_remain;
};
2084
+
2080
2085
struct llama_grammar {
2081
2086
const std::vector<std::vector<llama_grammar_element>> rules;
2082
2087
std::vector<std::vector<const llama_grammar_element *>> stacks;
2088
+
2089
+ // buffer for partially generated UTF-8 sequence from accepted tokens
2090
+ llama_partial_utf8 partial_utf8;
2083
2091
};
2084
2092
2085
2093
struct llama_grammar_candidate {
2086
- size_t index;
2087
- const uint32_t * code_points;
2094
+ size_t index;
2095
+ const uint32_t * code_points;
2096
+ llama_partial_utf8 partial_utf8;
2088
2097
};
2089
2098
2090
- // NOTE: assumes valid utf8 (but checks for overrun)
2091
- // adds a terminating 0 for use as pointer
2092
- std::vector<uint32_t > decode_utf8 (const char * src) {
2093
- static const int lookup[] = { 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 2 , 2 , 3 , 4 };
2099
+ // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
2100
+ // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
2101
+ std::pair<std::vector<uint32_t >, llama_partial_utf8> decode_utf8 (
2102
+ const char * src,
2103
+ llama_partial_utf8 partial_start) {
2104
+ static const int lookup[] = { 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 0 , 0 , 0 , 0 , 2 , 2 , 3 , 4 };
2094
2105
const char * pos = src;
2095
2106
std::vector<uint32_t > code_points;
2107
+ uint32_t value = partial_start.value ;
2108
+ int n_remain = partial_start.n_remain ;
2109
+
2110
+ // continue previous decode, if applicable
2111
+ while (*pos != 0 && n_remain > 0 ) {
2112
+ uint8_t next_byte = static_cast <uint8_t >(*pos);
2113
+ if ((next_byte >> 6 ) != 2 ) {
2114
+ // invalid sequence, abort
2115
+ code_points.push_back (0 );
2116
+ return std::make_pair (std::move (code_points), llama_partial_utf8{ 0 , -1 });
2117
+ }
2118
+ value = (value << 6 ) + (next_byte & 0x3F );
2119
+ ++pos;
2120
+ --n_remain;
2121
+ }
2122
+
2123
+ if (partial_start.n_remain > 0 && n_remain == 0 ) {
2124
+ code_points.push_back (value);
2125
+ }
2126
+
2127
+ // decode any subsequent utf-8 sequences, which may end in an incomplete one
2096
2128
while (*pos != 0 ) {
2097
2129
uint8_t first_byte = static_cast <uint8_t >(*pos);
2098
2130
uint8_t highbits = first_byte >> 4 ;
2099
- int len = lookup[highbits];
2100
- uint8_t mask = (1 << (8 - len)) - 1 ;
2101
- uint32_t value = first_byte & mask;
2102
- const char * end = pos + len; // may overrun!
2131
+ n_remain = lookup[highbits] - 1 ;
2132
+
2133
+ if (n_remain < 0 ) {
2134
+ // invalid sequence, abort
2135
+ code_points.clear ();
2136
+ code_points.push_back (0 );
2137
+ return std::make_pair (std::move (code_points), llama_partial_utf8{ 0 , n_remain });
2138
+ }
2139
+
2140
+ uint8_t mask = (1 << (7 - n_remain)) - 1 ;
2141
+ value = first_byte & mask;
2103
2142
++pos;
2104
- for ( ; pos < end && *pos != 0 ; ++pos ) {
2143
+ while (* pos != 0 && n_remain > 0 ) {
2105
2144
value = (value << 6 ) + (static_cast <uint8_t >(*pos) & 0x3F );
2145
+ ++pos;
2146
+ --n_remain;
2147
+ }
2148
+ if (n_remain == 0 ) {
2149
+ code_points.push_back (value);
2106
2150
}
2107
- code_points.push_back (value);
2108
2151
}
2109
2152
code_points.push_back (0 );
2110
- return code_points;
2153
+
2154
+ return std::make_pair (std::move (code_points), llama_partial_utf8{ value, n_remain });
2111
2155
}
2112
2156
2113
2157
// returns true iff pos points to the end of one of the definitions of a rule
@@ -2144,6 +2188,56 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
2144
2188
return std::make_pair (found == is_positive_char, pos);
2145
2189
}
2146
2190
2191
+ // returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
2192
+ // range at pos (regular or inverse range)
2193
+ // asserts that pos is pointing to a char range element
2194
+ static bool llama_grammar_match_partial_char (
2195
+ const llama_grammar_element * pos,
2196
+ const llama_partial_utf8 partial_utf8) {
2197
+
2198
+ bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
2199
+ LLAMA_ASSERT (is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
2200
+
2201
+ uint32_t partial_value = partial_utf8.value ;
2202
+ int n_remain = partial_utf8.n_remain ;
2203
+
2204
+ // invalid sequence or 7-bit char split across 2 bytes (overlong)
2205
+ if (n_remain < 0 || (n_remain == 1 && partial_value < 2 )) {
2206
+ return false ;
2207
+ }
2208
+
2209
+ // range of possible code points this partial UTF-8 sequence could complete to
2210
+ uint32_t low = partial_value << (n_remain * 6 );
2211
+ uint32_t high = low | ((1 << (n_remain * 6 )) - 1 );
2212
+
2213
+ if (low == 0 ) {
2214
+ if (n_remain == 2 ) {
2215
+ low = 1 << 11 ;
2216
+ } else if (n_remain == 3 ) {
2217
+ low = 1 << 16 ;
2218
+ }
2219
+ }
2220
+
2221
+ do {
2222
+ if (pos[1 ].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
2223
+ // inclusive range, e.g. [a-z]
2224
+ if (pos->value <= high && low <= pos[1 ].value ) {
2225
+ return is_positive_char;
2226
+ }
2227
+ pos += 2 ;
2228
+ } else {
2229
+ // exact char match, e.g. [a] or "a"
2230
+ if (low <= pos->value && pos->value <= high) {
2231
+ return is_positive_char;
2232
+ }
2233
+ pos += 1 ;
2234
+ }
2235
+ } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
2236
+
2237
+ return !is_positive_char;
2238
+ }
2239
+
2240
+
2147
2241
// transforms a grammar pushdown stack into N possible stacks, all ending
2148
2242
// at a character range (terminal element)
2149
2243
static void llama_grammar_advance_stack (
@@ -2244,19 +2338,27 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
2244
2338
std::vector<llama_grammar_candidate> rejects;
2245
2339
2246
2340
if (stack.empty ()) {
2247
- // accept nothing; EOS is handled elsewhere
2248
- rejects.insert (rejects.end (), candidates.begin (), candidates.end ());
2341
+ for (auto tok : candidates) {
2342
+ if (*tok.code_points != 0 || tok.partial_utf8 .n_remain != 0 ) {
2343
+ rejects.push_back (tok);
2344
+ }
2345
+ }
2249
2346
return rejects;
2250
2347
}
2251
2348
2252
2349
const llama_grammar_element * stack_pos = stack.back ();
2253
2350
2254
2351
std::vector<llama_grammar_candidate> next_candidates;
2255
2352
for (auto tok : candidates) {
2256
- if (llama_grammar_match_char (stack_pos, tok.code_points [0 ]).first ) {
2257
- if (tok.code_points [1 ] != 0 ) {
2258
- next_candidates.push_back ({ tok.index , tok.code_points + 1 });
2353
+ if (*tok.code_points == 0 ) {
2354
+ // reached end of full codepoints in token, reject iff it ended in a partial sequence
2355
+ // that cannot satisfy this position in grammar
2356
+ if (tok.partial_utf8 .n_remain != 0 &&
2357
+ !llama_grammar_match_partial_char (stack_pos, tok.partial_utf8 )) {
2358
+ rejects.push_back (tok);
2259
2359
}
2360
+ } else if (llama_grammar_match_char (stack_pos, *tok.code_points ).first ) {
2361
+ next_candidates.push_back ({ tok.index , tok.code_points + 1 , tok.partial_utf8 });
2260
2362
} else {
2261
2363
rejects.push_back (tok);
2262
2364
}
@@ -2274,7 +2376,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
2274
2376
2275
2377
auto next_rejects = llama_grammar_reject_candidates (rules, next_stacks, next_candidates);
2276
2378
for (auto tok : next_rejects) {
2277
- rejects.push_back ({ tok.index , tok.code_points - 1 });
2379
+ rejects.push_back ({ tok.index , tok.code_points - 1 , tok. partial_utf8 });
2278
2380
}
2279
2381
2280
2382
return rejects;
@@ -2339,7 +2441,7 @@ struct llama_grammar * llama_grammar_init(
2339
2441
}
2340
2442
} while (true );
2341
2443
2342
- return new llama_grammar{ std::move (vec_rules), std::move (stacks) };
2444
+ return new llama_grammar{ std::move (vec_rules), std::move (stacks), {} };
2343
2445
}
2344
2446
2345
2447
void llama_grammar_free (struct llama_grammar * grammar) {
@@ -2645,8 +2747,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
2645
2747
2646
2748
const llama_token eos = llama_token_eos ();
2647
2749
2648
- std::vector<std::vector<uint32_t >> candidates_decoded;
2649
- std::vector<llama_grammar_candidate> candidates_grammar;
2750
+ std::vector<std::pair<std:: vector<uint32_t >, llama_partial_utf8>> candidates_decoded;
2751
+ std::vector<llama_grammar_candidate> candidates_grammar;
2650
2752
2651
2753
for (size_t i = 0 ; i < candidates->size ; ++i) {
2652
2754
const llama_token id = candidates->data [i].id ;
@@ -2658,8 +2760,10 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
2658
2760
} else if (*str == 0 ) {
2659
2761
candidates->data [i].logit = -INFINITY;
2660
2762
} else {
2661
- candidates_decoded.push_back (decode_utf8 (str));
2662
- candidates_grammar.push_back ({ i, candidates_decoded.back ().data () });
2763
+ candidates_decoded.push_back (decode_utf8 (str, grammar->partial_utf8 ));
2764
+ candidates_grammar.push_back ({
2765
+ i, candidates_decoded.back ().first .data (), candidates_decoded.back ().second
2766
+ });
2663
2767
}
2664
2768
}
2665
2769
@@ -2860,11 +2964,14 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
2860
2964
}
2861
2965
2862
2966
const char * str = llama_token_to_str (ctx, token);
2967
+
2863
2968
// Note terminating 0 in decoded string
2864
- auto code_points = decode_utf8 (str);
2969
+ const auto decoded = decode_utf8 (str, grammar->partial_utf8 );
2970
+ const auto & code_points = decoded.first ;
2865
2971
for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
2866
2972
grammar->stacks = llama_grammar_accept (grammar->rules , grammar->stacks , *it);
2867
2973
}
2974
+ grammar->partial_utf8 = decoded.second ;
2868
2975
LLAMA_ASSERT (!grammar->stacks .empty ());
2869
2976
2870
2977
ctx->t_sample_us += ggml_time_us () - t_start_sample_us;
0 commit comments