Skip to content

Commit d7063d7

Browse files
committed
refactor the substring search function
1 parent cf0a43d commit d7063d7

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

examples/server/server.cpp

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2216,19 +2216,12 @@ struct server_context {
22162216
search_pos = slot.generated_text.size() - search_len;
22172217
}
22182218

2219-
auto found_pos = slot.generated_text.npos;
2220-
bool found = false;
2221-
std::string found_string;
2222-
for (auto start_string: slot.params.start_strings) {
2223-
found_pos = slot.generated_text.find(start_string,search_pos);
2224-
if (found_pos != slot.generated_text.npos) {
2225-
found = true;
2226-
found_string = start_string;
2227-
break;
2228-
}
2229-
}
2219+
std::pair<size_t, int> search_result = find_first_substring(slot.generated_text,slot.params.start_strings, search_pos);
2220+
bool found = search_result.first != std::string::npos;
22302221

2231-
if (found && slot.generated_text.size() > (found_pos + found_string.size()) ) {
2222+
if (found) {
2223+
auto found_pos = search_result.first;
2224+
std::string found_string = slot.params.start_strings[search_result.second];
22322225
slot.generated_text.erase(
22332226
slot.generated_text.begin(),
22342227
slot.generated_text.begin() + found_pos + found_string.size());

examples/server/utils.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,25 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
489489
return std::string::npos;
490490
}
491491

492+
/* returns a pair containing the position and index of first element found.
493+
returns npos and -1 if not found */
494+
static std::pair<size_t, int> find_first_substring(const std::string &haystack, const std::vector<std::string> & needles, size_t search_pos = 0)
495+
{
496+
size_t found_pos = std::string::npos;
497+
int found_idx = -1;
498+
499+
for (unsigned int i = 0; i < needles.size(); ++i) {
500+
const std::string & start_string = needles[i];
501+
found_pos = haystack.find(start_string,search_pos);
502+
if (found_pos != std::string::npos) {
503+
found_idx = i;
504+
break;
505+
}
506+
}
507+
508+
return std::pair<size_t, int>(found_pos,found_idx);
509+
}
510+
492511
// TODO: reuse llama_detokenize
493512
template <class Iter>
494513
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {

0 commit comments

Comments
 (0)