Skip to content

Commit f1b5b6b

Browse files
committed
tts : text pre-processing
1 parent 8f34d0d commit f1b5b6b

File tree

1 file changed

+146
-19
lines changed

1 file changed

+146
-19
lines changed

examples/tts/tts.cpp

Lines changed: 146 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
#define _USE_MATH_DEFINES // For M_PI on MSVC
88

99
#include <algorithm>
10-
#include <cstdio>
1110
#include <cmath>
12-
#include <string>
13-
#include <vector>
11+
#include <cstdio>
1412
#include <fstream>
13+
#include <map>
14+
#include <regex>
15+
#include <string>
1516
#include <thread>
17+
#include <vector>
1618

1719
//
1820
// Terminal utils
@@ -267,6 +269,143 @@ static std::vector<double> embd_to_audio(
267269
return audio;
268270
}
269271

272+
// English words for the integers 0 through 19, keyed by value.
// The number-to-words helpers read this table with .at(), so asking for a key
// outside 0..19 throws std::out_of_range instead of inserting an entry.
static const std::map<int, std::string> ones = {
    {0, "zero"}, {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"},
    {5, "five"}, {6, "six"}, {7, "seven"}, {8, "eight"}, {9, "nine"},
    {10, "ten"}, {11, "eleven"}, {12, "twelve"}, {13, "thirteen"}, {14, "fourteen"},
    {15, "fifteen"}, {16, "sixteen"}, {17, "seventeen"}, {18, "eighteen"}, {19, "nineteen"}
};
278+
279+
// English words for the multiples of ten from 20 to 90, keyed by the tens digit
// (2 -> "twenty", ..., 9 -> "ninety"). There are no entries for 0 and 1: values
// below 20 are spelled through the `ones` table instead.
static const std::map<int, std::string> tens = {
    {2, "twenty"}, {3, "thirty"}, {4, "forty"}, {5, "fifty"},
    {6, "sixty"}, {7, "seventy"}, {8, "eighty"}, {9, "ninety"}
};
283+
284+
// Convert a number less than 1000 to words
285+
static std::string convert_less_than_thousand(int num) {
286+
std::string result;
287+
288+
if (num >= 100) {
289+
result += ones.at(num / 100) + " hundred ";
290+
num %= 100;
291+
}
292+
293+
if (num >= 20) {
294+
result += tens.at(num / 10);
295+
if (num % 10 > 0) {
296+
result += "-" + ones.at(num % 10);
297+
}
298+
} else if (num > 0) {
299+
result += ones.at(num);
300+
}
301+
302+
return result;
303+
}
304+
305+
static std::string number_to_words(const std::string & number_str) {
306+
try {
307+
size_t decimal_pos = number_str.find('.');
308+
std::string integer_part = number_str.substr(0, decimal_pos);
309+
310+
int int_number = std::stoi(integer_part);
311+
std::string result;
312+
313+
if (int_number == 0) {
314+
result = "zero";
315+
} else {
316+
if (int_number >= 1000000000) {
317+
int billions = int_number / 1000000000;
318+
result += convert_less_than_thousand(billions) + " billion ";
319+
int_number %= 1000000000;
320+
}
321+
322+
if (int_number >= 1000000) {
323+
int millions = int_number / 1000000;
324+
result += convert_less_than_thousand(millions) + " million ";
325+
int_number %= 1000000;
326+
}
327+
328+
if (int_number >= 1000) {
329+
int thousands = int_number / 1000;
330+
result += convert_less_than_thousand(thousands) + " thousand ";
331+
int_number %= 1000;
332+
}
333+
334+
if (int_number > 0) {
335+
result += convert_less_than_thousand(int_number);
336+
}
337+
}
338+
339+
// Handle decimal part
340+
if (decimal_pos != std::string::npos) {
341+
result += " point";
342+
std::string decimal_part = number_str.substr(decimal_pos + 1);
343+
for (char digit : decimal_part) {
344+
result += " " + ones.at(digit - '0');
345+
}
346+
}
347+
348+
return result;
349+
} catch (const std::exception& e) {
350+
// Skip if fails
351+
return " ";
352+
}
353+
}
354+
355+
static std::string replace_numbers_with_words(const std::string & input_text) {
356+
std::regex number_pattern(R"(\d+(\.\d+)?)");
357+
std::string result;
358+
auto it = std::sregex_iterator(input_text.begin(), input_text.end(), number_pattern);
359+
auto end = std::sregex_iterator();
360+
361+
size_t last_pos = 0;
362+
for (std::sregex_iterator i = it; i != end; ++i) {
363+
const std::smatch& match = *i;
364+
result.append(input_text, last_pos, match.position() - last_pos);
365+
result.append(number_to_words(match.str()));
366+
last_pos = match.position() + match.length();
367+
}
368+
result.append(input_text, last_pos);
369+
370+
return result;
371+
}
372+
373+
// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
374+
static std::string process_text(const std::string & text) {
375+
376+
// For now I skipped text romanization as I am unsure how to handle
377+
// uroman and MeCab implementations in C++
378+
// maybe something like https://github.com/anyascii/anyascii/ could work.
379+
// currently only English would be supported in this function
380+
381+
std::string processed_text = replace_numbers_with_words(text);
382+
383+
std::transform(processed_text.begin(), processed_text.end(),
384+
processed_text.begin(), ::tolower);
385+
386+
std::regex special_chars(R"([-_/,\.\\])");
387+
processed_text = std::regex_replace(processed_text, special_chars, " ");
388+
389+
std::regex non_alpha(R"([^a-z\s])");
390+
processed_text = std::regex_replace(processed_text, non_alpha, "");
391+
392+
std::regex multiple_spaces(R"(\s+)");
393+
processed_text = std::regex_replace(processed_text, multiple_spaces, " ");
394+
395+
processed_text = std::regex_replace(processed_text, std::regex(R"(^\s+|\s+$)"), "");
396+
397+
/*
398+
Replace spaces with the separator token same as in line 365
399+
400+
for (auto & c : prompt_user) {
401+
if (c == ' ') {
402+
prompt_clean += "<|text_sep|>";
403+
*/
404+
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
405+
406+
return processed_text;
407+
}
408+
270409
// Append a single token to the end of `prompt`.
static void prompt_add(llama_tokens & prompt, llama_token token) {
    prompt.push_back(token);
}
@@ -353,23 +492,11 @@ int main(int argc, char ** argv) {
353492

354493
prompt_add(prompt_inp, model_ttc, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
355494

356-
// TODO: not sure if this is correct
495+
// convert the input text into the necessary format expected by OuteTTS
357496
{
358-
std::string prompt_clean;
359-
std::string prompt_user = params.prompt;
360-
361-
for (auto & c : prompt_user) {
362-
if (c == ' ') {
363-
prompt_clean += "<|text_sep|>";
364-
} else {
365-
if (isalpha(c) || isdigit(c)) {
366-
c = tolower(c);
367-
} else {
368-
continue;
369-
}
370-
prompt_clean += c;
371-
}
372-
}
497+
std::string prompt_clean = process_text(params.prompt);
498+
499+
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
373500

374501
prompt_add(prompt_inp, model_ttc, prompt_clean, false, true);
375502
}

0 commit comments

Comments
 (0)