|
7 | 7 | #define _USE_MATH_DEFINES // For M_PI on MSVC
|
8 | 8 |
|
9 | 9 | #include <algorithm>
|
10 |
| -#include <cstdio> |
11 | 10 | #include <cmath>
|
12 |
| -#include <string> |
13 |
| -#include <vector> |
| 11 | +#include <cstdio> |
14 | 12 | #include <fstream>
|
| 13 | +#include <map> |
| 14 | +#include <regex> |
| 15 | +#include <string> |
15 | 16 | #include <thread>
|
| 17 | +#include <vector> |
16 | 18 |
|
17 | 19 | //
|
18 | 20 | // Terminal utils
|
@@ -267,6 +269,143 @@ static std::vector<double> embd_to_audio(
|
267 | 269 | return audio;
|
268 | 270 | }
|
269 | 271 |
|
// English words for the values 0 through 19, keyed by the value itself.
// Used both for small numbers and for spelling decimal digits one at a time.
static const std::map<int, std::string> ones = {
    { 0, "zero"},
    { 1, "one"},
    { 2, "two"},
    { 3, "three"},
    { 4, "four"},
    { 5, "five"},
    { 6, "six"},
    { 7, "seven"},
    { 8, "eight"},
    { 9, "nine"},
    {10, "ten"},
    {11, "eleven"},
    {12, "twelve"},
    {13, "thirteen"},
    {14, "fourteen"},
    {15, "fifteen"},
    {16, "sixteen"},
    {17, "seventeen"},
    {18, "eighteen"},
    {19, "nineteen"},
};
| 278 | + |
// English words for the multiples of ten from 20 to 90, keyed by the tens
// digit (2 -> "twenty", ..., 9 -> "ninety"). There are no entries for 0 or 1.
static const std::map<int, std::string> tens = {
    {2, "twenty"},
    {3, "thirty"},
    {4, "forty"},
    {5, "fifty"},
    {6, "sixty"},
    {7, "seventy"},
    {8, "eighty"},
    {9, "ninety"},
};
| 283 | + |
| 284 | +// Convert a number less than 1000 to words |
| 285 | +static std::string convert_less_than_thousand(int num) { |
| 286 | + std::string result; |
| 287 | + |
| 288 | + if (num >= 100) { |
| 289 | + result += ones.at(num / 100) + " hundred "; |
| 290 | + num %= 100; |
| 291 | + } |
| 292 | + |
| 293 | + if (num >= 20) { |
| 294 | + result += tens.at(num / 10); |
| 295 | + if (num % 10 > 0) { |
| 296 | + result += "-" + ones.at(num % 10); |
| 297 | + } |
| 298 | + } else if (num > 0) { |
| 299 | + result += ones.at(num); |
| 300 | + } |
| 301 | + |
| 302 | + return result; |
| 303 | +} |
| 304 | + |
| 305 | +static std::string number_to_words(const std::string & number_str) { |
| 306 | + try { |
| 307 | + size_t decimal_pos = number_str.find('.'); |
| 308 | + std::string integer_part = number_str.substr(0, decimal_pos); |
| 309 | + |
| 310 | + int int_number = std::stoi(integer_part); |
| 311 | + std::string result; |
| 312 | + |
| 313 | + if (int_number == 0) { |
| 314 | + result = "zero"; |
| 315 | + } else { |
| 316 | + if (int_number >= 1000000000) { |
| 317 | + int billions = int_number / 1000000000; |
| 318 | + result += convert_less_than_thousand(billions) + " billion "; |
| 319 | + int_number %= 1000000000; |
| 320 | + } |
| 321 | + |
| 322 | + if (int_number >= 1000000) { |
| 323 | + int millions = int_number / 1000000; |
| 324 | + result += convert_less_than_thousand(millions) + " million "; |
| 325 | + int_number %= 1000000; |
| 326 | + } |
| 327 | + |
| 328 | + if (int_number >= 1000) { |
| 329 | + int thousands = int_number / 1000; |
| 330 | + result += convert_less_than_thousand(thousands) + " thousand "; |
| 331 | + int_number %= 1000; |
| 332 | + } |
| 333 | + |
| 334 | + if (int_number > 0) { |
| 335 | + result += convert_less_than_thousand(int_number); |
| 336 | + } |
| 337 | + } |
| 338 | + |
| 339 | + // Handle decimal part |
| 340 | + if (decimal_pos != std::string::npos) { |
| 341 | + result += " point"; |
| 342 | + std::string decimal_part = number_str.substr(decimal_pos + 1); |
| 343 | + for (char digit : decimal_part) { |
| 344 | + result += " " + ones.at(digit - '0'); |
| 345 | + } |
| 346 | + } |
| 347 | + |
| 348 | + return result; |
| 349 | + } catch (const std::exception& e) { |
| 350 | + // Skip if fails |
| 351 | + return " "; |
| 352 | + } |
| 353 | +} |
| 354 | + |
| 355 | +static std::string replace_numbers_with_words(const std::string & input_text) { |
| 356 | + std::regex number_pattern(R"(\d+(\.\d+)?)"); |
| 357 | + std::string result; |
| 358 | + auto it = std::sregex_iterator(input_text.begin(), input_text.end(), number_pattern); |
| 359 | + auto end = std::sregex_iterator(); |
| 360 | + |
| 361 | + size_t last_pos = 0; |
| 362 | + for (std::sregex_iterator i = it; i != end; ++i) { |
| 363 | + const std::smatch& match = *i; |
| 364 | + result.append(input_text, last_pos, match.position() - last_pos); |
| 365 | + result.append(number_to_words(match.str())); |
| 366 | + last_pos = match.position() + match.length(); |
| 367 | + } |
| 368 | + result.append(input_text, last_pos); |
| 369 | + |
| 370 | + return result; |
| 371 | +} |
| 372 | + |
| 373 | +// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39 |
| 374 | +static std::string process_text(const std::string & text) { |
| 375 | + |
| 376 | + // For now I skipped text romanization as I am unsure how to handle |
| 377 | + // uroman and MeCab implementations in C++ |
| 378 | + // maybe something like https://github.com/anyascii/anyascii/ could work. |
| 379 | + // currently only English would be supported in this function |
| 380 | + |
| 381 | + std::string processed_text = replace_numbers_with_words(text); |
| 382 | + |
| 383 | + std::transform(processed_text.begin(), processed_text.end(), |
| 384 | + processed_text.begin(), ::tolower); |
| 385 | + |
| 386 | + std::regex special_chars(R"([-_/,\.\\])"); |
| 387 | + processed_text = std::regex_replace(processed_text, special_chars, " "); |
| 388 | + |
| 389 | + std::regex non_alpha(R"([^a-z\s])"); |
| 390 | + processed_text = std::regex_replace(processed_text, non_alpha, ""); |
| 391 | + |
| 392 | + std::regex multiple_spaces(R"(\s+)"); |
| 393 | + processed_text = std::regex_replace(processed_text, multiple_spaces, " "); |
| 394 | + |
| 395 | + processed_text = std::regex_replace(processed_text, std::regex(R"(^\s+|\s+$)"), ""); |
| 396 | + |
| 397 | + /* |
| 398 | + Replace spaces with the separator token same as in line 365 |
| 399 | +
|
| 400 | + for (auto & c : prompt_user) { |
| 401 | + if (c == ' ') { |
| 402 | + prompt_clean += "<|text_sep|>"; |
| 403 | + */ |
| 404 | + processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>"); |
| 405 | + |
| 406 | + return processed_text; |
| 407 | +} |
| 408 | + |
270 | 409 | static void prompt_add(llama_tokens & prompt, llama_token token) {
|
271 | 410 | prompt.push_back(token);
|
272 | 411 | }
|
@@ -353,23 +492,11 @@ int main(int argc, char ** argv) {
|
353 | 492 |
|
354 | 493 | prompt_add(prompt_inp, model_ttc, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
|
355 | 494 |
|
356 |
| - // TODO: not sure if this is correct |
| 495 | + // convert the input text into the necessary format expected by OuteTTS |
357 | 496 | {
|
358 |
| - std::string prompt_clean; |
359 |
| - std::string prompt_user = params.prompt; |
360 |
| - |
361 |
| - for (auto & c : prompt_user) { |
362 |
| - if (c == ' ') { |
363 |
| - prompt_clean += "<|text_sep|>"; |
364 |
| - } else { |
365 |
| - if (isalpha(c) || isdigit(c)) { |
366 |
| - c = tolower(c); |
367 |
| - } else { |
368 |
| - continue; |
369 |
| - } |
370 |
| - prompt_clean += c; |
371 |
| - } |
372 |
| - } |
| 497 | + std::string prompt_clean = process_text(params.prompt); |
| 498 | + |
| 499 | + LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str()); |
373 | 500 |
|
374 | 501 | prompt_add(prompt_inp, model_ttc, prompt_clean, false, true);
|
375 | 502 | }
|
|
0 commit comments