@@ -58,17 +58,17 @@ struct naive_trie {
         auto res = children.find(c);
         if (res != children.end()) {
             return res->second.get_longest_prefix(key, len, offset + 1);
-        } else {
-            return std::make_pair(key, offset);
         }
+
+        return std::make_pair(key, offset);
     }
-    struct naive_trie * traverse(const char c) {
+    const struct naive_trie * traverse(const char c) const {
         auto res = children.find(c);
         if (res != children.end()) {
             return &res->second;
-        } else {
-            return NULL;
         }
+
+        return NULL;
     }
     std::map<char, struct naive_trie> children;
     bool has_value;
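
For orientation, the structure these methods belong to is a minimal by-value trie: each node owns its children in a std::map and, when a token ends at that node, carries has_value/value. Below is a rough standalone sketch of that shape; toy_trie is a stand-in name, and insert is written from how the last hunk of this diff calls it (token_matcher.insert(data, len, id)), so its real signature may differ. The later sketches in this section reuse it.

#include <cstddef>
#include <cstdint>
#include <map>

// Rough sketch of the naive_trie shape: one node per byte, children stored by value.
struct toy_trie {
    // mark the node reached by key[0..len) as a complete token carrying val
    void insert(const char * key, size_t len, int32_t val = 0) {
        if (len == 0) {
            has_value = true;
            value     = val;
            return;
        }
        children[key[0]].insert(key + 1, len - 1, val);
    }

    // const traversal, as in this hunk: a miss simply returns NULL after the
    // early return, no else branch needed
    const toy_trie * traverse(const char c) const {
        auto res = children.find(c);
        if (res != children.end()) {
            return &res->second;
        }

        return NULL;
    }

    std::map<char, toy_trie> children;
    bool    has_value = false;
    int32_t value     = 0;
};
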
@@ -843,7 +843,7 @@ struct llm_tokenizer_ugm {
             // traverse the token matcher trie to find a matching token
             bool single_codepoint_token_found = false;
             const struct best_tokenization & current_best = tokenization_results[input_offset];
-            struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
+            const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);

             while (prefix_offset <= input_len && node != NULL) {
                 // check if we found valid token in prefix
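
The one-line change in this hunk follows from the const qualifier added to traverse() above: the UGM matcher keeps its cursor as a const struct naive_trie * and chains further traverse() calls through that pointer, which only compiles when the member function is const and returns a const pointer. A reduced illustration, reusing the toy_trie sketch above (walk is a hypothetical helper, not part of the diff):

// Every call goes through a const reference or a const pointer, so the
// const-qualified traverse() from the first hunk is what makes this legal.
void walk(const toy_trie & trie, const char * s) {
    const toy_trie * node = trie.traverse(*s);
    while (node != NULL && *++s != '\0') {
        node = node->traverse(*s);
    }
}
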
@@ -1103,6 +1103,7 @@ struct llm_tokenizer_ugm {

 static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
     std::vector<uint8_t> output;
+    output.reserve(escaped.size());

     // Parser state
     bool escaping = false;
@@ -1158,46 +1159,47 @@ struct llm_tokenizer_rwkv {
     llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
         // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
         // For now, we decode the vocab here into the lookup we'll use for tokenization.
-        for (const auto & token : vocab.id_to_token) {
-            auto data = llama_unescape_rwkv_token(token.text);
-            tokens.push_back(data);
+
+        // build trie
+        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+            const auto & token = vocab.id_to_token[id];
+            const auto data = llama_unescape_rwkv_token(token.text);
+            token_matcher.insert((const char *) data.data(), data.size(), id);
         }
     }

     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         uint32_t position = 0;

         while (position < text.size()) {
-            // Iterate through possible tokens backwards, starting with the largest
-            for (int32_t i = (int32_t)tokens.size() - 1; i >= 0; i--) {
-                // Skip tokens that aren't normal type, we can't match on those
-                if (!(vocab.id_to_token[i].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
-                    continue;
-                }
-
-                uint32_t token_size = tokens[i].size();
-
-                // If there's not enough left for this token
-                if (text.size() - position < token_size) {
-                    continue;
-                }
+            const struct naive_trie * node = token_matcher.traverse(text[position]);
+            if (node == NULL) {
+                // no matching token found, add unknown token
+                output.push_back(vocab.special_unk_id);
+                position += 1;
+                continue;
+            }

-                // If the token doesn't match the data
-                if (std::memcmp(text.data() + position, tokens[i].data(), token_size) != 0) {
-                    continue;
+            // traverse the trie to find the longest matching token
+            uint32_t token_id = 0;
+            uint32_t token_length = 0;
+            while (node != NULL) {
+                if (node->has_value) {
+                    token_id = node->value;
+                    token_length = position + 1;
                 }
-
-                // Add the token and advance
-                output.push_back(i);
-                position += token_size;
-                break;
+                node = node->traverse(text[++position]);
             }
+
+            // add the longest matching token
+            output.push_back(token_id);
+            position = token_length;
         }
     }

     const llama_vocab & vocab;

-    std::vector<std::vector<uint8_t>> tokens;
+    struct naive_trie token_matcher;
 };

 //
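
The rewritten tokenize() is a greedy longest match over the trie: from each position it walks as deep as the input allows, remembers the last node that carried a value, emits that token id, and resumes right after the matched bytes; if even the first byte has no edge it emits the unknown token and advances by one. Below is a rough standalone sketch of the same loop over a made-up vocabulary, reusing the toy_trie from the first sketch; greedy_tokenize, the token ids, and the unknown-token fallback inside the inner walk are illustrative, not taken from the file.

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Greedy longest-match tokenization over the toy_trie sketched earlier.
std::vector<int32_t> greedy_tokenize(const toy_trie & matcher, const std::string & text, int32_t unk_id) {
    std::vector<int32_t> output;
    uint32_t position = 0;

    while (position < text.size()) {
        const toy_trie * node = matcher.traverse(text[position]);
        if (node == NULL) {
            // no token starts with this byte: emit unknown, advance one byte
            output.push_back(unk_id);
            position += 1;
            continue;
        }

        // walk as deep as possible, remembering the last complete token seen;
        // fall back to unknown + one byte if no complete token is ever passed
        int32_t  token_id     = unk_id;
        uint32_t token_length = position + 1;
        while (node != NULL) {
            if (node->has_value) {
                token_id     = node->value;
                token_length = position + 1;
            }
            ++position;
            node = position < text.size() ? node->traverse(text[position]) : NULL;
        }

        output.push_back(token_id);
        position = token_length;
    }
    return output;
}

int main() {
    toy_trie vocab;
    vocab.insert("a",   1, 10);
    vocab.insert("ab",  2, 11);
    vocab.insert("abc", 3, 12);
    vocab.insert("d",   1, 13);

    for (int32_t id : greedy_tokenize(vocab, "abd", -1)) {
        printf("%d ", id);   // prints: 11 13  (longest match "ab", then "d")
    }
    printf("\n");
}

Two intentional differences from the loop in the diff: the sketch bounds the inner walk with an explicit position < text.size() check, while the code in the diff indexes the string directly and lets a failed traverse() end the walk, and the sketch initializes its fallback to the unknown token and a one-byte advance, whereas the diff initializes token_id and token_length to 0.
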