Skip to content

Commit 7004323 — "rwkv : speed-up tokenization using trie" (1 parent: 7f2ef56).

File tree: 1 file changed (+33 lines, -31 lines).

src/llama-vocab.cpp

Lines changed: 33 additions & 31 deletions
Columns: original file line number · diff line number · diff line change
@@ -58,17 +58,17 @@ struct naive_trie {
5858
auto res = children.find(c);
5959
if (res != children.end()) {
6060
return res->second.get_longest_prefix(key, len, offset + 1);
61-
} else {
62-
return std::make_pair(key, offset);
6361
}
62+
63+
return std::make_pair(key, offset);
6464
}
65-
// Follow the edge labelled `c` one level down the trie.
// Returns the child node for `c`, or NULL when no such edge exists.
// Const-qualified so lookups can be performed on a const trie.
const struct naive_trie * traverse(const char c) const {
    auto res = children.find(c);
    if (res != children.end()) {
        return &res->second;
    }

    return NULL;
}
7373
std::map<char, struct naive_trie> children;
7474
bool has_value;
@@ -843,7 +843,7 @@ struct llm_tokenizer_ugm {
843843
// traverse the token matcher trie to find a matching token
844844
bool single_codepoint_token_found = false;
845845
const struct best_tokenization & current_best = tokenization_results[input_offset];
846-
struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
846+
const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
847847

848848
while (prefix_offset <= input_len && node != NULL) {
849849
// check if we found valid token in prefix
@@ -1103,6 +1103,7 @@ struct llm_tokenizer_ugm {
11031103

11041104
static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
11051105
std::vector<uint8_t> output;
1106+
output.reserve(escaped.size());
11061107

11071108
// Parser state
11081109
bool escaping = false;
@@ -1158,46 +1159,47 @@ struct llm_tokenizer_rwkv {
11581159
llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
    // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
    // For now, we decode the vocab here into the lookup we'll use for tokenization.

    // build trie
    for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
        const auto & token = vocab.id_to_token[id];
        // unescape the stored string form back into raw bytes before inserting:
        // the trie is keyed on the raw byte sequence of the input text
        const auto data = llama_unescape_rwkv_token(token.text);
        token_matcher.insert((const char *) data.data(), data.size(), id);
    }
}
11661170

11671171
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
11681172
uint32_t position = 0;
11691173

11701174
while (position < text.size()) {
1171-
// Iterate through possible tokens backwards, starting with the largest
1172-
for (int32_t i = (int32_t)tokens.size() - 1; i >= 0; i--) {
1173-
// Skip tokens that aren't normal type, we can't match on those
1174-
if (!(vocab.id_to_token[i].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
1175-
continue;
1176-
}
1177-
1178-
uint32_t token_size = tokens[i].size();
1179-
1180-
// If there's not enough left for this token
1181-
if (text.size() - position < token_size) {
1182-
continue;
1183-
}
1175+
const struct naive_trie * node = token_matcher.traverse(text[position]);
1176+
if (node == NULL) {
1177+
// no matching token found, add unknown token
1178+
output.push_back(vocab.special_unk_id);
1179+
position += 1;
1180+
continue;
1181+
}
11841182

1185-
// If the token doesn't match the data
1186-
if (std::memcmp(text.data() + position, tokens[i].data(), token_size) != 0) {
1187-
continue;
1183+
// traverse the trie to find the longest matching token
1184+
uint32_t token_id = 0;
1185+
uint32_t token_length = 0;
1186+
while (node != NULL) {
1187+
if (node->has_value) {
1188+
token_id = node->value;
1189+
token_length = position + 1;
11881190
}
1189-
1190-
// Add the token and advance
1191-
output.push_back(i);
1192-
position += token_size;
1193-
break;
1191+
node = node->traverse(text[++position]);
11941192
}
1193+
1194+
// add the longest matching token
1195+
output.push_back(token_id);
1196+
position = token_length;
11951197
}
11961198
}
11971199

11981200
const llama_vocab & vocab;
11991201

1200-
std::vector<std::vector<uint8_t>> tokens;
1202+
struct naive_trie token_matcher;
12011203
};
12021204

12031205
//

Commit comments: 0.