Skip to content

Commit 604b8bd

Browse files
authored
Fix unicode in grammars (fixes #2501) (#2553)
* Fix unicode in grammars (fixes #2501) * add more comments * fix test-llama-grammar
1 parent 10151be commit 604b8bd

File tree

2 files changed

+133
-26
lines changed

2 files changed

+133
-26
lines changed

llama.cpp

Lines changed: 132 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2077,37 +2077,81 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
20772077
// grammar - internal
20782078
//
20792079

2080+
// State of a UTF-8 decode that stopped part-way through a multi-byte sequence.
struct llama_partial_utf8 {
    // payload bits collected from the bytes seen so far; not yet shifted into their final
    // position within the code point
    uint32_t value;

    // number of bytes still needed to finish the sequence;
    // -1 marks an invalid sequence
    int n_remain;
};
2084+
20802085
struct llama_grammar {
20812086
const std::vector<std::vector<llama_grammar_element>> rules;
20822087
std::vector<std::vector<const llama_grammar_element *>> stacks;
2088+
2089+
// buffer for partially generated UTF-8 sequence from accepted tokens
2090+
llama_partial_utf8 partial_utf8;
20832091
};
20842092

20852093
struct llama_grammar_candidate {
2086-
size_t index;
2087-
const uint32_t * code_points;
2094+
size_t index;
2095+
const uint32_t * code_points;
2096+
llama_partial_utf8 partial_utf8;
20882097
};
20892098

2090-
// NOTE: assumes valid utf8 (but checks for overrun)
2091-
// adds a terminating 0 for use as pointer
2092-
std::vector<uint32_t> decode_utf8(const char * src) {
2093-
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
2099+
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
2100+
// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
2101+
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
2102+
const char * src,
2103+
llama_partial_utf8 partial_start) {
2104+
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
20942105
const char * pos = src;
20952106
std::vector<uint32_t> code_points;
2107+
uint32_t value = partial_start.value;
2108+
int n_remain = partial_start.n_remain;
2109+
2110+
// continue previous decode, if applicable
2111+
while (*pos != 0 && n_remain > 0) {
2112+
uint8_t next_byte = static_cast<uint8_t>(*pos);
2113+
if ((next_byte >> 6) != 2) {
2114+
// invalid sequence, abort
2115+
code_points.push_back(0);
2116+
return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
2117+
}
2118+
value = (value << 6) + (next_byte & 0x3F);
2119+
++pos;
2120+
--n_remain;
2121+
}
2122+
2123+
if (partial_start.n_remain > 0 && n_remain == 0) {
2124+
code_points.push_back(value);
2125+
}
2126+
2127+
// decode any subsequent utf-8 sequences, which may end in an incomplete one
20962128
while (*pos != 0) {
20972129
uint8_t first_byte = static_cast<uint8_t>(*pos);
20982130
uint8_t highbits = first_byte >> 4;
2099-
int len = lookup[highbits];
2100-
uint8_t mask = (1 << (8 - len)) - 1;
2101-
uint32_t value = first_byte & mask;
2102-
const char * end = pos + len; // may overrun!
2131+
n_remain = lookup[highbits] - 1;
2132+
2133+
if (n_remain < 0) {
2134+
// invalid sequence, abort
2135+
code_points.clear();
2136+
code_points.push_back(0);
2137+
return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
2138+
}
2139+
2140+
uint8_t mask = (1 << (7 - n_remain)) - 1;
2141+
value = first_byte & mask;
21032142
++pos;
2104-
for ( ; pos < end && *pos != 0; ++pos) {
2143+
while (*pos != 0 && n_remain > 0) {
21052144
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
2145+
++pos;
2146+
--n_remain;
2147+
}
2148+
if (n_remain == 0) {
2149+
code_points.push_back(value);
21062150
}
2107-
code_points.push_back(value);
21082151
}
21092152
code_points.push_back(0);
2110-
return code_points;
2153+
2154+
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
21112155
}
21122156

21132157
// returns true iff pos points to the end of one of the definitions of a rule
@@ -2144,6 +2188,56 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
21442188
return std::make_pair(found == is_positive_char, pos);
21452189
}
21462190

2191+
// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
2192+
// range at pos (regular or inverse range)
2193+
// asserts that pos is pointing to a char range element
2194+
static bool llama_grammar_match_partial_char(
2195+
const llama_grammar_element * pos,
2196+
const llama_partial_utf8 partial_utf8) {
2197+
2198+
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
2199+
LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
2200+
2201+
uint32_t partial_value = partial_utf8.value;
2202+
int n_remain = partial_utf8.n_remain;
2203+
2204+
// invalid sequence or 7-bit char split across 2 bytes (overlong)
2205+
if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
2206+
return false;
2207+
}
2208+
2209+
// range of possible code points this partial UTF-8 sequence could complete to
2210+
uint32_t low = partial_value << (n_remain * 6);
2211+
uint32_t high = low | ((1 << (n_remain * 6)) - 1);
2212+
2213+
if (low == 0) {
2214+
if (n_remain == 2) {
2215+
low = 1 << 11;
2216+
} else if (n_remain == 3) {
2217+
low = 1 << 16;
2218+
}
2219+
}
2220+
2221+
do {
2222+
if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
2223+
// inclusive range, e.g. [a-z]
2224+
if (pos->value <= high && low <= pos[1].value) {
2225+
return is_positive_char;
2226+
}
2227+
pos += 2;
2228+
} else {
2229+
// exact char match, e.g. [a] or "a"
2230+
if (low <= pos->value && pos->value <= high) {
2231+
return is_positive_char;
2232+
}
2233+
pos += 1;
2234+
}
2235+
} while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
2236+
2237+
return !is_positive_char;
2238+
}
2239+
2240+
21472241
// transforms a grammar pushdown stack into N possible stacks, all ending
21482242
// at a character range (terminal element)
21492243
static void llama_grammar_advance_stack(
@@ -2244,19 +2338,27 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
22442338
std::vector<llama_grammar_candidate> rejects;
22452339

22462340
if (stack.empty()) {
2247-
// accept nothing; EOS is handled elsewhere
2248-
rejects.insert(rejects.end(), candidates.begin(), candidates.end());
2341+
for (auto tok : candidates) {
2342+
if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
2343+
rejects.push_back(tok);
2344+
}
2345+
}
22492346
return rejects;
22502347
}
22512348

22522349
const llama_grammar_element * stack_pos = stack.back();
22532350

22542351
std::vector<llama_grammar_candidate> next_candidates;
22552352
for (auto tok : candidates) {
2256-
if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) {
2257-
if (tok.code_points[1] != 0) {
2258-
next_candidates.push_back({ tok.index, tok.code_points + 1 });
2353+
if (*tok.code_points == 0) {
2354+
// reached end of full codepoints in token, reject iff it ended in a partial sequence
2355+
// that cannot satisfy this position in grammar
2356+
if (tok.partial_utf8.n_remain != 0 &&
2357+
!llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
2358+
rejects.push_back(tok);
22592359
}
2360+
} else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
2361+
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
22602362
} else {
22612363
rejects.push_back(tok);
22622364
}
@@ -2274,7 +2376,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
22742376

22752377
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
22762378
for (auto tok : next_rejects) {
2277-
rejects.push_back({ tok.index, tok.code_points - 1 });
2379+
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
22782380
}
22792381

22802382
return rejects;
@@ -2339,7 +2441,7 @@ struct llama_grammar * llama_grammar_init(
23392441
}
23402442
} while (true);
23412443

2342-
return new llama_grammar{ std::move(vec_rules), std::move(stacks) };
2444+
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
23432445
}
23442446

23452447
void llama_grammar_free(struct llama_grammar * grammar) {
@@ -2645,8 +2747,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
26452747

26462748
const llama_token eos = llama_token_eos();
26472749

2648-
std::vector<std::vector<uint32_t>> candidates_decoded;
2649-
std::vector<llama_grammar_candidate> candidates_grammar;
2750+
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
2751+
std::vector<llama_grammar_candidate> candidates_grammar;
26502752

26512753
for (size_t i = 0; i < candidates->size; ++i) {
26522754
const llama_token id = candidates->data[i].id;
@@ -2658,8 +2760,10 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
26582760
} else if (*str == 0) {
26592761
candidates->data[i].logit = -INFINITY;
26602762
} else {
2661-
candidates_decoded.push_back(decode_utf8(str));
2662-
candidates_grammar.push_back({ i, candidates_decoded.back().data() });
2763+
candidates_decoded.push_back(decode_utf8(str, grammar->partial_utf8));
2764+
candidates_grammar.push_back({
2765+
i, candidates_decoded.back().first.data(), candidates_decoded.back().second
2766+
});
26632767
}
26642768
}
26652769

@@ -2860,11 +2964,14 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
28602964
}
28612965

28622966
const char * str = llama_token_to_str(ctx, token);
2967+
28632968
// Note terminating 0 in decoded string
2864-
auto code_points = decode_utf8(str);
2969+
const auto decoded = decode_utf8(str, grammar->partial_utf8);
2970+
const auto & code_points = decoded.first;
28652971
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
28662972
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
28672973
}
2974+
grammar->partial_utf8 = decoded.second;
28682975
LLAMA_ASSERT(!grammar->stacks.empty());
28692976

28702977
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;

tests/test-llama-grammar.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ int main()
199199
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
200200
cp[0] = 37 + i;
201201
cp[1] = 0;
202-
next_candidates[i] = {i, cp};
202+
next_candidates[i] = {i, cp, {}};
203203
}
204204

205205
std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {

0 commit comments

Comments
 (0)