diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt index c72bd814c3b31..a39d3a982961b 100644 --- a/examples/tts/CMakeLists.txt +++ b/examples/tts/CMakeLists.txt @@ -1,5 +1,11 @@ -set(TARGET llama-tts) -add_executable(${TARGET} tts.cpp) -install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_17) +set(TARGET_1 llama-tts-outetts-v1) +add_executable(${TARGET_1} tts-outetts-v1.cpp) +target_link_libraries(${TARGET_1} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET_1} PRIVATE cxx_std_17) +install(TARGETS ${TARGET_1} RUNTIME) + +set(TARGET_2 llama-tts) +add_executable(${TARGET_2} tts.cpp) +target_link_libraries(${TARGET_2} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET_2} PRIVATE cxx_std_17) +install(TARGETS ${TARGET_2} RUNTIME) \ No newline at end of file diff --git a/examples/tts/convert_pt_to_hf.py b/examples/tts/convert_pt_to_hf.py index 8909a65fd1e13..ebd55d9657b24 100644 --- a/examples/tts/convert_pt_to_hf.py +++ b/examples/tts/convert_pt_to_hf.py @@ -12,7 +12,7 @@ from safetensors.torch import save_file # default -model_path = './model.pt'; +model_path = './model.pt' # read from CLI if len(sys.argv) > 1: diff --git a/examples/tts/default_speaker.h b/examples/tts/default_speaker.h new file mode 100644 index 0000000000000..2fa4da5525c06 --- /dev/null +++ b/examples/tts/default_speaker.h @@ -0,0 +1,1885 @@ +// default_speaker.h +#ifndef DEFAULT_SPEAKER_H +#define DEFAULT_SPEAKER_H + +#include "json.hpp" + +namespace DefaultSpeaker { + +inline const char* jsonDataStr = R"( +{ + "text": "The cat watched from the windowsill, tail flicking with quiet curiosity as the first snowflakes of winter began to fall, dusting the world in fragile white.", + "words": [ + { + "word": "The", + "duration": 0.2, + "c1": [ + 720, + 720, + 474, + 691, + 607, + 126, + 597, + 607, + 897, + 288, + 362, + 903, + 333, + 1009, + 79 + ], + "c2": [ + 658, + 663, + 237, + 915, + 74, + 74, + 966, + 721, + 893, + 722, + 630, + 516, + 861, + 385, + 149 + ], + "features": { + "energy": 10, + "spectral_centroid": 15, + "pitch": 45 + } + }, + { + "word": "cat", + "duration": 0.33, + "c1": [ + 700, + 597, + 639, + 838, + 622, + 336, + 975, + 326, + 67, + 375, + 853, + 761, + 35, + 363, + 31, + 1000, + 982, + 192, + 647, + 564, + 329, + 1002, + 275, + 480, + 551 + ], + "c2": [ + 34, + 810, + 457, + 546, + 42, + 631, + 339, + 867, + 115, + 1011, + 509, + 369, + 473, + 85, + 190, + 715, + 391, + 518, + 562, + 986, + 749, + 193, + 530, + 327, + 820 + ], + "features": { + "energy": 14, + "spectral_centroid": 21, + "pitch": 35 + } + }, + { + "word": "watched", + "duration": 0.44, + "c1": [ + 625, + 668, + 168, + 524, + 462, + 151, + 549, + 951, + 597, + 820, + 489, + 329, + 377, + 144, + 112, + 16, + 481, + 133, + 195, + 744, + 144, + 750, + 288, + 500, + 1000, + 58, + 916, + 597, + 72, + 336, + 224, + 476, + 581 + ], + "c2": [ + 204, + 421, + 318, + 677, + 74, + 953, + 903, + 413, + 809, + 37, + 634, + 824, + 933, + 200, + 14, + 1007, + 111, + 17, + 435, + 718, + 559, + 783, + 415, + 821, + 958, + 247, + 14, + 721, + 158, + 235, + 276, + 875, + 683 + ], + "features": { + "energy": 19, + "spectral_centroid": 21, + "pitch": 26 + } + }, + { + "word": "from", + "duration": 0.2, + "c1": [ + 528, + 668, + 738, + 985, + 126, + 924, + 1003, + 325, + 393, + 86, + 114, + 392, + 638, + 915, + 549 + ], + "c2": [ + 929, + 872, + 332, + 296, + 983, + 
406, + 867, + 568, + 374, + 328, + 419, + 348, + 177, + 379, + 181 + ], + "features": { + "energy": 10, + "spectral_centroid": 29, + "pitch": 14 + } + }, + { + "word": "the", + "duration": 0.12, + "c1": [ + 470, + 985, + 152, + 474, + 967, + 558, + 460, + 728, + 470 + ], + "c2": [ + 596, + 246, + 314, + 246, + 756, + 238, + 606, + 262, + 499 + ], + "features": { + "energy": 23, + "spectral_centroid": 10, + "pitch": 23 + } + }, + { + "word": "windowsill,", + "duration": 0.75, + "c1": [ + 217, + 126, + 549, + 700, + 198, + 891, + 95, + 683, + 158, + 680, + 16, + 769, + 402, + 776, + 295, + 258, + 68, + 213, + 669, + 865, + 719, + 29, + 949, + 329, + 216, + 481, + 284, + 224, + 221, + 359, + 328, + 311, + 415, + 443, + 410, + 359, + 600, + 590, + 932, + 611, + 905, + 304, + 292, + 72, + 388, + 333, + 66, + 943, + 489, + 648, + 630, + 648, + 402, + 972, + 392, + 558 + ], + "c2": [ + 911, + 19, + 1007, + 169, + 185, + 182, + 399, + 849, + 656, + 963, + 265, + 80, + 453, + 768, + 919, + 1010, + 501, + 794, + 141, + 123, + 93, + 694, + 499, + 174, + 768, + 689, + 598, + 686, + 10, + 381, + 282, + 556, + 126, + 672, + 872, + 650, + 990, + 556, + 913, + 635, + 174, + 819, + 999, + 423, + 64, + 272, + 112, + 600, + 453, + 678, + 791, + 301, + 206, + 187, + 819, + 948 + ], + "features": { + "energy": 17, + "spectral_centroid": 25, + "pitch": 24 + } + }, + { + "word": "tail", + "duration": 0.6, + "c1": [ + 669, + 94, + 917, + 202, + 607, + 720, + 625, + 597, + 126, + 607, + 885, + 700, + 474, + 480, + 126, + 126, + 551, + 720, + 126, + 551, + 720, + 607, + 572, + 234, + 114, + 963, + 963, + 975, + 587, + 119, + 378, + 696, + 730, + 375, + 46, + 827, + 515, + 447, + 979, + 138, + 22, + 267, + 43, + 495, + 16 + ], + "c2": [ + 1011, + 336, + 157, + 39, + 1000, + 721, + 862, + 413, + 557, + 569, + 74, + 569, + 141, + 493, + 124, + 775, + 204, + 588, + 74, + 588, + 810, + 124, + 102, + 1021, + 83, + 848, + 297, + 339, + 335, + 684, + 400, + 905, + 909, + 710, + 460, + 115, + 81, + 628, + 224, + 663, + 892, + 247, + 392, + 234, + 132 + ], + "features": { + "energy": 15, + "spectral_centroid": 23, + "pitch": 34 + } + }, + { + "word": "flicking", + "duration": 0.45, + "c1": [ + 978, + 489, + 630, + 588, + 436, + 798, + 4, + 975, + 245, + 325, + 415, + 4, + 393, + 4, + 4, + 997, + 982, + 437, + 444, + 180, + 861, + 868, + 225, + 440, + 780, + 597, + 720, + 639, + 168, + 426, + 114, + 621, + 854, + 869 + ], + "c2": [ + 571, + 321, + 376, + 232, + 301, + 678, + 904, + 630, + 990, + 772, + 690, + 870, + 719, + 694, + 332, + 558, + 301, + 194, + 279, + 443, + 852, + 64, + 709, + 401, + 401, + 14, + 74, + 873, + 134, + 754, + 1002, + 595, + 540, + 525 + ], + "features": { + "energy": 9, + "spectral_centroid": 22, + "pitch": 23 + } + }, + { + "word": "with", + "duration": 0.23, + "c1": [ + 621, + 392, + 756, + 459, + 433, + 881, + 786, + 198, + 702, + 847, + 490, + 27, + 680, + 146, + 58, + 808, + 997 + ], + "c2": [ + 460, + 840, + 840, + 303, + 847, + 534, + 801, + 99, + 662, + 666, + 510, + 132, + 376, + 96, + 639, + 240, + 668 + ], + "features": { + "energy": 11, + "spectral_centroid": 15, + "pitch": 20 + } + }, + { + "word": "quiet", + "duration": 0.37, + "c1": [ + 969, + 291, + 572, + 720, + 625, + 85, + 698, + 478, + 811, + 956, + 232, + 85, + 962, + 817, + 986, + 483, + 835, + 526, + 77, + 187, + 178, + 50, + 440, + 16, + 198, + 237, + 418, + 862 + ], + "c2": [ + 498, + 606, + 24, + 629, + 662, + 181, + 119, + 678, + 340, + 736, + 217, + 204, + 935, + 796, + 118, + 478, + 818, + 791, + 329, + 209, + 5, + 234, 
+ 337, + 647, + 110, + 922, + 933, + 1011 + ], + "features": { + "energy": 12, + "spectral_centroid": 12, + "pitch": 43 + } + }, + { + "word": "curiosity", + "duration": 0.71, + "c1": [ + 321, + 402, + 215, + 607, + 720, + 224, + 731, + 621, + 491, + 720, + 551, + 456, + 336, + 688, + 476, + 953, + 718, + 806, + 410, + 786, + 976, + 664, + 855, + 433, + 756, + 396, + 699, + 776, + 443, + 739, + 932, + 22, + 305, + 353, + 503, + 564, + 978, + 407, + 395, + 798, + 324, + 168, + 909, + 328, + 328, + 443, + 738, + 114, + 962, + 681, + 535, + 701, + 382 + ], + "c2": [ + 777, + 665, + 629, + 327, + 831, + 764, + 162, + 725, + 810, + 170, + 629, + 774, + 108, + 948, + 972, + 449, + 600, + 905, + 81, + 765, + 601, + 422, + 820, + 746, + 450, + 346, + 733, + 77, + 733, + 81, + 722, + 576, + 286, + 271, + 714, + 95, + 346, + 133, + 514, + 799, + 122, + 900, + 568, + 666, + 209, + 668, + 558, + 630, + 165, + 587, + 423, + 904, + 629 + ], + "features": { + "energy": 10, + "spectral_centroid": 29, + "pitch": 22 + } + }, + { + "word": "as", + "duration": 0.48, + "c1": [ + 474, + 936, + 336, + 589, + 254, + 854, + 79, + 140, + 863, + 854, + 701, + 260, + 929, + 140, + 669, + 808, + 411, + 232, + 434, + 542, + 597, + 126, + 551, + 126, + 607, + 1011, + 774, + 681, + 94, + 25, + 971, + 288, + 305, + 347, + 355, + 415 + ], + "c2": [ + 267, + 813, + 232, + 361, + 77, + 607, + 252, + 933, + 508, + 658, + 846, + 849, + 873, + 496, + 832, + 167, + 440, + 124, + 557, + 124, + 736, + 588, + 569, + 983, + 497, + 360, + 810, + 274, + 588, + 365, + 517, + 934, + 957, + 839, + 646, + 720 + ], + "features": { + "energy": 7, + "spectral_centroid": 31, + "pitch": 23 + } + }, + { + "word": "the", + "duration": 0.13, + "c1": [ + 359, + 568, + 700, + 985, + 80, + 580, + 274, + 129, + 600, + 794 + ], + "c2": [ + 423, + 833, + 245, + 690, + 209, + 688, + 765, + 453, + 677, + 615 + ], + "features": { + "energy": 9, + "spectral_centroid": 26, + "pitch": 20 + } + }, + { + "word": "first", + "duration": 0.36, + "c1": [ + 997, + 325, + 147, + 4, + 780, + 669, + 621, + 896, + 30, + 686, + 526, + 399, + 210, + 783, + 216, + 144, + 329, + 448, + 481, + 288, + 132, + 600, + 168, + 221, + 415, + 415, + 528 + ], + "c2": [ + 325, + 666, + 627, + 629, + 240, + 665, + 650, + 481, + 962, + 328, + 128, + 358, + 166, + 264, + 555, + 30, + 815, + 10, + 669, + 525, + 450, + 746, + 919, + 621, + 647, + 16, + 601 + ], + "features": { + "energy": 13, + "spectral_centroid": 28, + "pitch": 22 + } + }, + { + "word": "snowflakes", + "duration": 0.76, + "c1": [ + 1003, + 680, + 607, + 720, + 126, + 668, + 336, + 224, + 114, + 997, + 426, + 997, + 147, + 221, + 359, + 328, + 1003, + 738, + 974, + 151, + 782, + 179, + 190, + 553, + 453, + 761, + 778, + 23, + 128, + 643, + 125, + 7, + 345, + 223, + 275, + 524, + 325, + 764, + 114, + 953, + 70, + 75, + 449, + 513, + 783, + 830, + 825, + 365, + 819, + 920, + 669, + 700, + 700, + 720, + 220, + 209, + 221 + ], + "c2": [ + 276, + 489, + 810, + 975, + 775, + 913, + 1022, + 818, + 340, + 481, + 690, + 366, + 924, + 782, + 366, + 481, + 400, + 998, + 872, + 556, + 688, + 719, + 78, + 952, + 119, + 412, + 286, + 847, + 60, + 381, + 86, + 694, + 779, + 55, + 246, + 374, + 143, + 91, + 209, + 640, + 313, + 873, + 295, + 355, + 333, + 705, + 468, + 1008, + 317, + 87, + 105, + 511, + 260, + 650, + 574, + 88, + 690 + ], + "features": { + "energy": 12, + "spectral_centroid": 29, + "pitch": 22 + } + }, + { + "word": "of", + "duration": 0.15, + "c1": [ + 443, + 328, + 528, + 85, + 313, + 145, + 588, + 140, + 114, + 325, 
+ 325 + ], + "c2": [ + 924, + 835, + 400, + 832, + 397, + 1011, + 695, + 716, + 366, + 489, + 487 + ], + "features": { + "energy": 7, + "spectral_centroid": 34, + "pitch": 13 + } + }, + { + "word": "winter", + "duration": 0.29, + "c1": [ + 559, + 71, + 549, + 64, + 902, + 609, + 206, + 386, + 428, + 529, + 92, + 1020, + 148, + 456, + 605, + 673, + 958, + 897, + 250, + 716, + 236, + 232 + ], + "c2": [ + 891, + 358, + 1016, + 185, + 558, + 392, + 63, + 45, + 238, + 404, + 603, + 520, + 657, + 628, + 748, + 649, + 629, + 298, + 772, + 483, + 1008, + 401 + ], + "features": { + "energy": 18, + "spectral_centroid": 16, + "pitch": 31 + } + }, + { + "word": "began", + "duration": 0.24, + "c1": [ + 490, + 6, + 596, + 669, + 1011, + 700, + 583, + 349, + 666, + 783, + 215, + 126, + 61, + 22, + 945, + 773, + 920, + 975 + ], + "c2": [ + 194, + 225, + 140, + 243, + 14, + 650, + 929, + 671, + 323, + 365, + 556, + 298, + 707, + 483, + 550, + 57, + 127, + 886 + ], + "features": { + "energy": 11, + "spectral_centroid": 12, + "pitch": 18 + } + }, + { + "word": "to", + "duration": 0.2, + "c1": [ + 265, + 1021, + 113, + 178, + 698, + 561, + 97, + 402, + 25, + 916, + 766, + 660, + 159, + 945, + 967 + ], + "c2": [ + 141, + 976, + 455, + 403, + 760, + 738, + 519, + 123, + 327, + 721, + 690, + 904, + 689, + 140, + 615 + ], + "features": { + "energy": 13, + "spectral_centroid": 19, + "pitch": 20 + } + }, + { + "word": "fall,", + "duration": 0.39, + "c1": [ + 781, + 325, + 4, + 114, + 997, + 415, + 4, + 443, + 953, + 781, + 399, + 993, + 489, + 383, + 920, + 383, + 272, + 755, + 843, + 450, + 763, + 392, + 411, + 682, + 895, + 443, + 490, + 863, + 79 + ], + "c2": [ + 143, + 990, + 209, + 990, + 990, + 556, + 462, + 952, + 914, + 702, + 301, + 833, + 779, + 982, + 26, + 458, + 519, + 9, + 264, + 74, + 304, + 110, + 646, + 905, + 185, + 959, + 53, + 543, + 909 + ], + "features": { + "energy": 13, + "spectral_centroid": 14, + "pitch": 18 + } + }, + { + "word": "dusting", + "duration": 0.89, + "c1": [ + 27, + 669, + 490, + 691, + 691, + 625, + 625, + 572, + 474, + 885, + 215, + 215, + 215, + 215, + 215, + 215, + 75, + 718, + 94, + 924, + 232, + 818, + 14, + 232, + 985, + 547, + 955, + 4, + 627, + 524, + 524, + 579, + 462, + 104, + 597, + 720, + 720, + 491, + 597, + 571, + 802, + 864, + 315, + 515, + 832, + 219, + 133, + 923, + 773, + 245, + 415, + 328, + 590, + 80, + 528, + 322, + 808, + 551, + 625, + 716, + 158, + 562, + 712, + 477, + 905, + 920, + 424 + ], + "c2": [ + 206, + 521, + 77, + 447, + 260, + 810, + 74, + 301, + 243, + 775, + 243, + 775, + 880, + 862, + 1017, + 806, + 806, + 631, + 873, + 806, + 806, + 722, + 14, + 531, + 630, + 500, + 990, + 240, + 690, + 431, + 240, + 815, + 449, + 273, + 903, + 569, + 325, + 629, + 872, + 239, + 686, + 189, + 774, + 264, + 314, + 628, + 107, + 120, + 560, + 929, + 1008, + 610, + 24, + 929, + 400, + 949, + 431, + 721, + 447, + 443, + 774, + 392, + 923, + 855, + 747, + 144, + 460 + ], + "features": { + "energy": 14, + "spectral_centroid": 28, + "pitch": 30 + } + }, + { + "word": "the", + "duration": 0.12, + "c1": [ + 396, + 433, + 276, + 530, + 316, + 117, + 112, + 7, + 531 + ], + "c2": [ + 332, + 479, + 262, + 239, + 123, + 239, + 453, + 499, + 545 + ], + "features": { + "energy": 23, + "spectral_centroid": 11, + "pitch": 30 + } + }, + { + "word": "world", + "duration": 0.32, + "c1": [ + 217, + 489, + 897, + 607, + 402, + 383, + 496, + 937, + 247, + 206, + 790, + 32, + 406, + 856, + 715, + 458, + 278, + 481, + 503, + 399, + 871, + 453, + 858, + 392 + ], + "c2": [ + 593, + 
959, + 461, + 546, + 242, + 438, + 81, + 99, + 939, + 361, + 269, + 571, + 525, + 542, + 246, + 10, + 613, + 228, + 913, + 252, + 132, + 132, + 287, + 559 + ], + "features": { + "energy": 22, + "spectral_centroid": 11, + "pitch": 31 + } + }, + { + "word": "in", + "duration": 0.23, + "c1": [ + 558, + 497, + 436, + 598, + 607, + 416, + 311, + 906, + 955, + 905, + 448, + 54, + 92, + 487, + 770, + 298, + 490 + ], + "c2": [ + 838, + 399, + 420, + 819, + 325, + 929, + 124, + 214, + 1021, + 728, + 975, + 688, + 132, + 718, + 724, + 911, + 536 + ], + "features": { + "energy": 14, + "spectral_centroid": 16, + "pitch": 22 + } + }, + { + "word": "fragile", + "duration": 0.41, + "c1": [ + 415, + 325, + 953, + 359, + 325, + 838, + 359, + 764, + 842, + 341, + 706, + 674, + 971, + 592, + 507, + 16, + 628, + 481, + 626, + 691, + 1011, + 610, + 336, + 476, + 528, + 637, + 472, + 251, + 945, + 811, + 406 + ], + "c2": [ + 126, + 990, + 374, + 143, + 629, + 868, + 338, + 91, + 346, + 393, + 407, + 987, + 987, + 1009, + 617, + 854, + 824, + 439, + 789, + 311, + 810, + 497, + 664, + 549, + 135, + 908, + 702, + 639, + 320, + 698, + 414 + ], + "features": { + "energy": 13, + "spectral_centroid": 20, + "pitch": 18 + } + }, + { + "word": "white.", + "duration": 0.75, + "c1": [ + 26, + 432, + 1, + 651, + 998, + 716, + 998, + 727, + 978, + 311, + 85, + 895, + 279, + 392, + 669, + 916, + 549, + 1011, + 97, + 597, + 296, + 392, + 526, + 998, + 835, + 468, + 871, + 405, + 26, + 759, + 524, + 107, + 77, + 22, + 260, + 682, + 621, + 79, + 682, + 411, + 701, + 972, + 691, + 720, + 551, + 597, + 660, + 224, + 236, + 70, + 652, + 215, + 126, + 474, + 597, + 625 + ], + "c2": [ + 475, + 778, + 695, + 612, + 913, + 315, + 536, + 593, + 55, + 371, + 19, + 560, + 821, + 646, + 151, + 801, + 821, + 413, + 14, + 922, + 629, + 380, + 417, + 679, + 487, + 562, + 821, + 706, + 324, + 896, + 169, + 594, + 810, + 864, + 810, + 588, + 862, + 969, + 14, + 105, + 528, + 165, + 420, + 170, + 821, + 423, + 977, + 904, + 690, + 235, + 702, + 14, + 124, + 350, + 74, + 413 + ], + "features": { + "energy": 13, + "spectral_centroid": 11, + "pitch": 23 + } + } + ], + "global_features": { + "energy": 13, + "spectral_centroid": 20, + "pitch": 28 + }, + "interface_version": 3 +})"; + +inline nlohmann::ordered_json getJsonData() { + static const nlohmann::ordered_json parsedJson = nlohmann::ordered_json::parse(jsonDataStr); + return parsedJson; +} + +} // namespace DefaultSpeaker + +#endif // DEFAULT_SPEAKER_H \ No newline at end of file diff --git a/examples/tts/tts-outetts-v1.cpp b/examples/tts/tts-outetts-v1.cpp new file mode 100644 index 0000000000000..2eb786803a999 --- /dev/null +++ b/examples/tts/tts-outetts-v1.cpp @@ -0,0 +1,1087 @@ +#include "arg.h" +#include "common.h" +#include "sampling.h" +#include "log.h" +#include "json.hpp" +#include "llama.h" +#include "default_speaker.h" + +#define _USE_MATH_DEFINES // For M_PI on MSVC + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +enum outetts_version { + OUTETTS_V0_2, + OUTETTS_V0_3, + OUTETTS_V1_0, +}; + +// Special Tokens structure +struct SpecialTokens { + std::string bos = "<|im_start|>"; + std::string eos = "<|im_end|>"; + std::string c1 = "<|c1_{}|>"; + std::string c2 = "<|c2_{}|>"; + std::string text_start = "<|text_start|>"; + std::string text_end = "<|text_end|>"; + std::string voice_characteristic_start = "<|voice_characteristic_start|>"; + std::string 
voice_characteristic_end = "<|voice_characteristic_end|>";
+ std::string emotion_start = "<|emotion_start|>";
+ std::string emotion_end = "<|emotion_end|>";
+ std::string audio_start = "<|audio_start|>";
+ std::string audio_end = "<|audio_end|>";
+ std::string time = "<|t_{:.2f}|>";
+ std::string code = "<|code|>";
+ std::string energy = "<|energy_{}|>";
+ std::string spectral_centroid = "<|spectral_centroid_{}|>";
+ std::string pitch = "<|pitch_{}|>";
+ std::string word_start = "<|word_start|>";
+ std::string word_end = "<|word_end|>";
+ std::string features = "<|features|>";
+ std::string global_features_start = "<|global_features_start|>";
+ std::string global_features_end = "<|global_features_end|>";
+};
+
+std::string text_normalization(std::string result) {
+ // Normalize whitespace characters (newlines, tabs, etc.) to single spaces
+ result = std::regex_replace(result, std::regex("\\s+"), " ");
+
+ // Strip leading/trailing whitespace
+ auto start = result.find_first_not_of(" \t\n\r\f\v");
+ if (start == std::string::npos) {
+ return ""; // String is all whitespace
+ }
+ auto end = result.find_last_not_of(" \t\n\r\f\v");
+ result = result.substr(start, end - start + 1);
+
+ // Normalize common Unicode characters to ASCII equivalents
+ result = std::regex_replace(result, std::regex("[“”]"), "\""); // Curly quotes to straight quotes
+ result = std::regex_replace(result, std::regex("[‘’]"), "'"); // Curly single quotes
+ result = std::regex_replace(result, std::regex("[–—]"), "-"); // Various dashes to hyphen
+
+ return result;
+}
+
+// Utility function to format strings (simple replacement for Python's format)
+std::string format_string(const std::string& format_str, double value) {
+ char buffer[100];
+ snprintf(buffer, sizeof(buffer), "%.2f", value);
+ std::string result = format_str;
+ size_t pos = result.find("{:.2f}");
+ if (pos != std::string::npos) {
+ result.replace(pos, 6, buffer);
+ }
+ return result;
+}
+
+std::string format_string(const std::string& format_str, int value) {
+ std::string result = format_str;
+ size_t pos = result.find("{}");
+ if (pos != std::string::npos) {
+ result.replace(pos, 2, std::to_string(value));
+ }
+ return result;
+}
+
+std::string format_string(const std::string& format_str, const std::string& value) {
+ std::string result = format_str;
+ size_t pos = result.find("{}");
+ if (pos != std::string::npos) {
+ result.replace(pos, 2, value);
+ }
+ return result;
+}
+
+// Function to get features
+std::vector<std::string> get_features(const json& f, const SpecialTokens& special_tokens) {
+ std::vector<std::string> result;
+
+ int energy = f.contains("energy") ? f["energy"].get<int>() : 0;
+ int spectral_centroid = f.contains("spectral_centroid") ? f["spectral_centroid"].get<int>() : 0;
+ int pitch = f.contains("pitch") ?
f["pitch"].get() : 0; + + result.push_back(format_string(special_tokens.energy, energy)); + result.push_back(format_string(special_tokens.spectral_centroid, spectral_centroid)); + result.push_back(format_string(special_tokens.pitch, pitch)); + + return result; +} + +// Function to get global features +std::string get_global_features(const json& f, const SpecialTokens& special_tokens, const std::string& global_features_template) { + std::vector features = get_features(f, special_tokens); + std::string codes; + for (const auto& feature : features) { + codes += feature; + } + + std::string result = global_features_template; + // Replace {fs} with global_features_start + size_t pos = result.find("{fs}"); + if (pos != std::string::npos) { + result.replace(pos, 4, special_tokens.global_features_start); + } + + // Replace {codes} with the joined features + pos = result.find("{codes}"); + if (pos != std::string::npos) { + result.replace(pos, 7, codes); + } + + // Replace {fe} with global_features_end + pos = result.find("{fe}"); + if (pos != std::string::npos) { + result.replace(pos, 4, special_tokens.global_features_end); + } + + return result; +} + +// Function to create codes +std::string create_codes(const json& words, const SpecialTokens& special_tokens) { + std::vector codes; + + for (const auto& word_item : words) { + std::string word = word_item["word"].get() + special_tokens.features; + word += format_string(special_tokens.time, word_item["duration"].get()); + + // Add features + std::vector features = get_features(word_item["features"], special_tokens); + for (const auto& feature : features) { + word += feature; + } + + // Add pairs of c1 and c2 + std::vector pairs; + for (size_t idx = 0; idx < word_item["c1"].size(); idx++) { + std::string c1 = format_string(special_tokens.c1, word_item["c1"][idx].get()); + std::string c2 = format_string(special_tokens.c2, word_item["c2"][idx].get()); + pairs.push_back(c1 + c2); + } + + word += special_tokens.code; + for (const auto& pair : pairs) { + word += pair; + } + + codes.push_back(special_tokens.word_start + word + special_tokens.word_end); + } + + // Join codes with newline + std::string result; + for (size_t i = 0; i < codes.size(); i++) { + result += codes[i]; + if (i < codes.size() - 1) { + result += "\n"; + } + } + + return result; +} + +// Function to initialize prompt +std::string init_prompt(const std::string& text, const SpecialTokens& special_tokens, const std::string& input_prompt_template) { + std::string result = input_prompt_template; + + // Replace {bos} with bos + size_t pos = result.find("{bos}"); + if (pos != std::string::npos) { + result.replace(pos, 5, special_tokens.bos); + } + + // Replace {text_start} with text_start + pos = result.find("{text_start}"); + if (pos != std::string::npos) { + result.replace(pos, 12, special_tokens.text_start); + } + + // Replace {text} with the text + pos = result.find("{text}"); + if (pos != std::string::npos) { + result.replace(pos, 6, text); + } + + // Replace {text_end} with text_end + pos = result.find("{text_end}"); + if (pos != std::string::npos) { + result.replace(pos, 10, special_tokens.text_end); + } + + // Replace {audio_start} with audio_start + pos = result.find("{audio_start}"); + if (pos != std::string::npos) { + result.replace(pos, 13, special_tokens.audio_start); + } + + return result; +} + +// Function to get separator based on text +std::string get_separator(const std::string& text) { + bool has_hiragana = false; + bool has_katakana = false; + bool has_han = false; + bool 
has_hangul = false;
+
+ for (char c : text) {
+ unsigned char uc = static_cast<unsigned char>(c);
+ if (uc >= 0xE3 && uc <= 0xE3) { // Simplified check for hiragana (U+3040-U+309F)
+ has_hiragana = true;
+ }
+ else if (uc >= 0xE3 && uc <= 0xE3) { // Simplified check for katakana (U+30A0-U+30FF)
+ has_katakana = true;
+ }
+ else if (uc >= 0xE4 && uc <= 0xE9) { // Simplified check for han (U+4E00-U+9FFF)
+ has_han = true;
+ }
+ else if (uc >= 0xEA && uc <= 0xED) { // Simplified check for hangul (U+AC00-U+D7AF)
+ has_hangul = true;
+ }
+ }
+
+ if (has_hiragana || has_katakana || has_han) {
+ return "。";
+ }
+ else if (has_hangul) {
+ return ". ";
+ }
+ else {
+ return ". ";
+ }
+}
+
+inline std::string trim(std::string_view sv) {
+ sv.remove_prefix(std::min(sv.find_first_not_of(" \t\n\r\f\v"), sv.size()));
+ auto pos = sv.find_last_not_of(" \t\n\r\f\v");
+ if (pos != std::string_view::npos) {
+ sv.remove_suffix(sv.size() - pos - 1);
+ }
+ return std::string(sv);
+}
+
+inline bool ends_with(const std::string& value, const std::string& ending) {
+ if (ending.size() > value.size()) return false;
+ return value.size() >= ending.size() &&
+ value.substr(value.size() - ending.size()) == ending;
+}
+
+std::pair<std::string, std::string> merge_speaker_text(const std::string& input_text, const std::string& speaker_text_orig) {
+ std::string speaker_text = trim(speaker_text_orig);
+ std::string separator = get_separator(speaker_text);
+
+ // Determine allowed endings based on the separator
+ std::vector<std::string> allowed_ends;
+ if (separator == "。") {
+ allowed_ends = {"。", "？", "！", "?", "!"};
+ } else {
+ allowed_ends = {".", "?", "!"};
+ }
+
+ std::string rs = ""; // This will be the separator/space to insert
+
+ if (!speaker_text.empty()) {
+ bool ends_with_allowed_char = false;
+ for (const std::string& end_char : allowed_ends) {
+ if (ends_with(speaker_text, end_char)) {
+ ends_with_allowed_char = true;
+ break;
+ }
+ }
+
+ if (!ends_with_allowed_char) {
+ rs = separator;
+ } else {
+ if (separator != "。") {
+ rs = " ";
+ }
+ }
+ }
+
+ std::string output = speaker_text + rs + trim(input_text);
+ std::string trimmed_rs = trim(rs);
+ return std::make_pair(output, trimmed_rs);
+}
+
+// Main function to get completion prompt
+std::string get_completion_prompt(const std::string& text, json& speaker) {
+ // Initialize special tokens
+ SpecialTokens special_tokens;
+
+ // Templates (would normally be passed as parameters)
+ std::string input_prompt_template = "{bos}{text_start}{text}{text_end}\n{audio_start}\n";
+ std::string global_features_template = "{fs}{codes}{fe}";
+
+ // Normalize text
+
+ std::string normalized_text = text_normalization(text);
+
+ std::string prompt;
+ if (!speaker.is_null()) {
+ // Merge speaker text
+ auto [merged_text, separator] = merge_speaker_text(normalized_text, speaker["text"]);
+ normalized_text = merged_text;
+
+ // Update last word with separator if necessary
+ if (!separator.empty()) {
+ speaker["words"].back()["word"] = speaker["words"].back()["word"].get<std::string>() + separator;
+ }
+
+ // Create codes
+ std::string codes = create_codes(speaker["words"], special_tokens);
+
+ // Initialize prompt
+ prompt = init_prompt(normalized_text, special_tokens, input_prompt_template);
+
+ // Add codes and word_start
+ prompt += codes + "\n" + special_tokens.word_start;
+ }
+ else {
+ // Initialize prompt without speaker
+ prompt = init_prompt(normalized_text, special_tokens, input_prompt_template);
+ }
+
+ return prompt;
+}
+
+static json speaker_from_file(const std::string & speaker_file) {
+ std::ifstream file(speaker_file);
+ if
(!file) {
+ LOG_ERR("%s: Failed to open file '%s' for reading\n", __func__, speaker_file.c_str());
+ return json();
+ }
+
+ json speaker = json::parse(file);
+ return speaker;
+}
+
+static outetts_version get_tts_version(llama_model *model, json speaker = json::object()) {
+ if (speaker.contains("interface_version")) {
+ int version = speaker["interface_version"].get<int>();
+ if (version == 1) {
+ return OUTETTS_V0_2;
+ } else if (version == 2) {
+ return OUTETTS_V0_3;
+ } else if (version == 3) {
+ return OUTETTS_V1_0;
+ } else {
+ LOG_ERR("%s: Unsupported speaker version '%d'\n", __func__, version);
+ }
+ }
+
+ // Also could get version from model itself
+ const char *chat_template = llama_model_chat_template(model, nullptr);
+ if (chat_template && std::string(chat_template) == "outetts-0.3") {
+ return OUTETTS_V0_3;
+ } else if (chat_template && std::string(chat_template) == "outetts-1.0") {
+ return OUTETTS_V1_0;
+ }
+
+ // Use 0.2 as the default version
+ return OUTETTS_V0_2;
+}
+
+// ------------------------
+// Helper functions for UTF-8
+// ------------------------
+
+// Return the number of bytes in the current UTF-8 character.
+int utf8_char_length(unsigned char c) {
+ if (c < 0x80) return 1;
+ else if ((c & 0xE0) == 0xC0) return 2;
+ else if ((c & 0xF0) == 0xE0) return 3;
+ else if ((c & 0xF8) == 0xF0) return 4;
+ return 1;
+}
+
+// Decode a UTF-8 string into a vector of Unicode code points.
+std::vector<uint32_t> decode_utf8(const std::string &s) {
+ std::vector<uint32_t> codepoints;
+ size_t i = 0;
+ while (i < s.size()) {
+ unsigned char c = s[i];
+ uint32_t cp = 0;
+ int len = utf8_char_length(c);
+ if (len == 1) {
+ cp = c;
+ } else if (len == 2) {
+ cp = ((c & 0x1F) << 6) | (s[i+1] & 0x3F);
+ } else if (len == 3) {
+ cp = ((c & 0x0F) << 12) | ((s[i+1] & 0x3F) << 6) | (s[i+2] & 0x3F);
+ } else if (len == 4) {
+ cp = ((c & 0x07) << 18) | ((s[i+1] & 0x3F) << 12) | ((s[i+2] & 0x3F) << 6) | (s[i+3] & 0x3F);
+ }
+ codepoints.push_back(cp);
+ i += len;
+ }
+ return codepoints;
+}
+
+// Encode a single Unicode code point into a UTF-8 string.
+std::string encode_utf8(uint32_t cp) {
+ std::string result;
+ if (cp <= 0x7F) {
+ result.push_back(static_cast<char>(cp));
+ } else if (cp <= 0x7FF) {
+ result.push_back(static_cast<char>(0xC0 | ((cp >> 6) & 0x1F)));
+ result.push_back(static_cast<char>(0x80 | (cp & 0x3F)));
+ } else if (cp <= 0xFFFF) {
+ result.push_back(static_cast<char>(0xE0 | ((cp >> 12) & 0x0F)));
+ result.push_back(static_cast<char>(0x80 | ((cp >> 6) & 0x3F)));
+ result.push_back(static_cast<char>(0x80 | (cp & 0x3F)));
+ } else {
+ result.push_back(static_cast<char>(0xF0 | ((cp >> 18) & 0x07)));
+ result.push_back(static_cast<char>(0x80 | ((cp >> 12) & 0x3F)));
+ result.push_back(static_cast<char>(0x80 | ((cp >> 6) & 0x3F)));
+ result.push_back(static_cast<char>(0x80 | (cp & 0x3F)));
+ }
+ return result;
+}
+
+// Tokenize a UTF-8 string into individual characters.
+std::vector<std::string> utf8_tokenize(const std::string &text) {
+ std::vector<std::string> tokens;
+ size_t i = 0;
+ while (i < text.size()) {
+ int len = utf8_char_length(static_cast<unsigned char>(text[i]));
+ tokens.push_back(text.substr(i, len));
+ i += len;
+ }
+ return tokens;
+}
+
+// Trim leading and trailing whitespace.
+std::string trim(const std::string &s) {
+ size_t start = s.find_first_not_of(" \t\n\r");
+ if (start == std::string::npos)
+ return "";
+ size_t end = s.find_last_not_of(" \t\n\r");
+ return s.substr(start, end - start + 1);
+}
+
+// Replace multiple whitespace characters with a single space.
+std::string removeExtraSpaces(const std::string &s) {
+ std::string result;
+ bool in_space = false;
+ for (char c : s) {
+ if (std::isspace(static_cast<unsigned char>(c))) {
+ if (!in_space) {
+ result.push_back(' ');
+ in_space = true;
+ }
+ } else {
+ result.push_back(c);
+ in_space = false;
+ }
+ }
+ return trim(result);
+}
+
+// ------------------------
+// Language detection and tokenization
+// ------------------------
+
+class LanguageDetector {
+public:
+ // Return true if any code point is in the ranges for Japanese (Hiragana, Katakana)
+ // or Chinese/Japanese ideographs.
+ static bool check(const std::string &text) {
+ std::vector<uint32_t> cps = decode_utf8(text);
+ for (uint32_t cp : cps) {
+ if ((cp >= 0x3040 && cp <= 0x309F) || // Hiragana
+ (cp >= 0x30A0 && cp <= 0x30FF) || // Katakana
+ (cp >= 0x4E00 && cp <= 0x9FFF)) { // CJK Unified Ideographs
+ return true;
+ }
+ }
+ return false;
+ }
+};
+
+// Tokenize text based on language. For zh/ja, we use our simple per-character splitting;
+// for others, we split on whitespace.
+std::vector<std::string> tokenize_text(const std::string &text) {
+ std::string t = trim(text);
+ if (t.empty())
+ return {};
+ if (LanguageDetector::check(text)) {
+ return utf8_tokenize(text);
+ } else {
+ std::vector<std::string> tokens;
+ std::istringstream iss(text);
+ std::string token;
+ while (iss >> token)
+ tokens.push_back(token);
+ return tokens;
+ }
+}
+
+// Count words (tokens) in the text.
+int count_words(const std::string &text) {
+ return tokenize_text(text).size();
+}
+
+// Join tokens into a string. If no_space is true, tokens are concatenated without a separator.
+std::string join_tokens(const std::vector<std::string> &tokens, bool no_space) {
+ std::string result;
+ if (no_space) {
+ for (const auto &token : tokens)
+ result += token;
+ } else {
+ for (size_t i = 0; i < tokens.size(); ++i) {
+ if (i > 0)
+ result += " ";
+ result += tokens[i];
+ }
+ }
+ return result;
+}
+
+// ------------------------
+// Sentence splitting
+// ------------------------
+//
+// We split the text into “sentences” by scanning Unicode code points
+// and using a set of sentence-ending characters (including punctuation
+// for both Latin and CJK texts). The algorithm accumulates code points
+// until a sentence end is found, then reassembles the sentence.
+std::vector<std::string> split_into_sentences(const std::string &text) {
+ std::vector<uint32_t> cps = decode_utf8(text);
+ std::vector<std::vector<uint32_t>> sentenceCps;
+ std::vector<uint32_t> currentSentence;
+
+ auto isSentenceEnd = [](uint32_t cp) -> bool {
+ return cp == 0x002E || cp == 0x0021 || cp == 0x003F || // . ! ?
+ cp == 0x3002 || cp == 0xFF01 || cp == 0xFF1F || // 。 ！ ？
+ cp == 0xFE56 || cp == 0xFE57; // ︕ ︖
+ };
+
+ for (size_t i = 0; i < cps.size(); ++i) {
+ currentSentence.push_back(cps[i]);
+ if (isSentenceEnd(cps[i])) {
+ // Also include following whitespace as part of the sentence delimiter.
+ while (i + 1 < cps.size() &&
+ (cps[i + 1] == 0x0020 || cps[i + 1] == 0x0009 ||
+ cps[i + 1] == 0x000A || cps[i + 1] == 0x000D)) {
+ ++i;
+ currentSentence.push_back(cps[i]);
+ }
+ sentenceCps.push_back(currentSentence);
+ currentSentence.clear();
+ }
+ }
+ if (!currentSentence.empty()) {
+ sentenceCps.push_back(currentSentence);
+ }
+
+ std::vector<std::string> sentences;
+ for (auto &sc : sentenceCps) {
+ std::string sentence;
+ for (uint32_t cp : sc)
+ sentence += encode_utf8(cp);
+ sentence = trim(sentence);
+ if (!sentence.empty())
+ sentences.push_back(sentence);
+ }
+ return sentences;
+}
+
+// ------------------------
+// Text chunking
+// ------------------------
+//
+// This function splits the text into chunks that contain between min_words and max_words.
+// It uses sentence splitting and then tokenizes each sentence. When a sentence is too long,
+// it splits it further. Note that for zh/ja the tokens are joined without spaces,
+// while for other languages spaces are inserted.
+std::vector<std::string> chunk_text(const std::string &text, int min_words = 10, int max_words = 30) {
+ std::string norm = removeExtraSpaces(text);
+ norm = trim(norm);
+ if (norm.empty())
+ return {};
+
+ std::vector<std::string> sentences = split_into_sentences(norm);
+ std::vector<std::string> chunks;
+ std::string current_chunk = "";
+ int current_word_count = 0;
+
+ for (const auto &sentence : sentences) {
+ std::string s = trim(sentence);
+ if (s.empty())
+ continue;
+
+ std::vector<std::string> sentence_tokens = tokenize_text(s);
+ int sentence_word_count = sentence_tokens.size();
+
+ // If the sentence is longer than max_words, split it into parts.
+ if (sentence_word_count > max_words) {
+ if (!current_chunk.empty()) {
+ chunks.push_back(current_chunk);
+ current_chunk = "";
+ current_word_count = 0;
+ }
+ std::vector<std::string> current_part;
+ int word_count = 0;
+ for (const auto &token : sentence_tokens) {
+ current_part.push_back(token);
+ ++word_count;
+ if (word_count >= max_words) {
+ bool isLang = LanguageDetector::check(s);
+ std::string part = isLang ? join_tokens(current_part, true)
+ : join_tokens(current_part, false);
+ chunks.push_back(part);
+ current_part.clear();
+ word_count = 0;
+ }
+ }
+ if (!current_part.empty()) {
+ bool isLang = LanguageDetector::check(s);
+ std::string part = isLang ? join_tokens(current_part, true)
+ : join_tokens(current_part, false);
+ chunks.push_back(part);
+ }
+ continue;
+ }
+
+ if (current_word_count + sentence_word_count <= max_words) {
+ if (!current_chunk.empty()) {
+ current_chunk += LanguageDetector::check(sentence) ? s : " " + s;
+ } else {
+ current_chunk = s;
+ }
+ current_word_count += sentence_word_count;
+ } else {
+ if (current_word_count >= min_words) {
+ chunks.push_back(current_chunk);
+ current_chunk = s;
+ current_word_count = sentence_word_count;
+ } else {
+ int space_left = max_words - current_word_count;
+ std::vector<std::string> current_part(sentence_tokens.begin(),
+ sentence_tokens.begin() + space_left);
+ std::vector<std::string> remaining_part(sentence_tokens.begin() + space_left,
+ sentence_tokens.end());
+ bool isLang = LanguageDetector::check(s);
+ std::string first_chunk = isLang ?
+ (current_chunk + join_tokens(current_part, true)) :
+ (current_chunk + " " + join_tokens(current_part, false));
+ chunks.push_back(first_chunk);
+ current_chunk = isLang ?
join_tokens(remaining_part, true)
+ : join_tokens(remaining_part, false);
+ current_word_count = remaining_part.size();
+ }
+ }
+ }
+ if (!current_chunk.empty())
+ chunks.push_back(current_chunk);
+
+ return chunks;
+}
+
+std::vector<std::vector<int>> extract_codebooks(const std::string& codes) {
+ std::vector<int> codebook1;
+ std::vector<int> codebook2;
+
+ // Use regex to find all matches
+ std::regex pattern1("<\\|c1_(\\d+)\\|>");
+ std::regex pattern2("<\\|c2_(\\d+)\\|>");
+
+ // Iterator for regex matches
+ std::sregex_iterator iter1(codes.begin(), codes.end(), pattern1);
+ std::sregex_iterator iter2(codes.begin(), codes.end(), pattern2);
+ std::sregex_iterator end;
+
+ // Extract codebook1 values
+ for (; iter1 != end; ++iter1) {
+ std::smatch match = *iter1;
+ codebook1.push_back(std::stoi(match[1]));
+ }
+
+ // Extract codebook2 values
+ for (; iter2 != end; ++iter2) {
+ std::smatch match = *iter2;
+ codebook2.push_back(std::stoi(match[1]));
+ }
+
+ // Truncate to the minimum size of both codebooks
+ size_t t = std::min(codebook1.size(), codebook2.size());
+ codebook1.resize(t);
+ codebook2.resize(t);
+
+ return {codebook1, codebook2};
+}
+
+//
+// Terminal utils
+//
+
+#define SQR(X) ((X) * (X))
+#define UNCUBE(x) x < 48 ? 0 : x < 115 ? 1 : (x - 35) / 40
+
+/**
+ * Quantizes 24-bit RGB to xterm256 code range [16,256).
+ */
+static int rgb2xterm256(int r, int g, int b) {
+ unsigned char cube[] = {0, 0137, 0207, 0257, 0327, 0377};
+ int av, ir, ig, ib, il, qr, qg, qb, ql;
+ av = r * .299 + g * .587 + b * .114 + .5;
+ ql = (il = av > 238 ? 23 : (av - 3) / 10) * 10 + 8;
+ qr = cube[(ir = UNCUBE(r))];
+ qg = cube[(ig = UNCUBE(g))];
+ qb = cube[(ib = UNCUBE(b))];
+ if (SQR(qr - r) + SQR(qg - g) + SQR(qb - b) <=
+ SQR(ql - r) + SQR(ql - g) + SQR(ql - b))
+ return ir * 36 + ig * 6 + ib + 020;
+ return il + 0350;
+}
+
+static std::string set_xterm256_foreground(int r, int g, int b) {
+ int x = rgb2xterm256(r, g, b);
+ std::ostringstream oss;
+ oss << "\033[38;5;" << x << "m";
+ return oss.str();
+}
+
+const std::vector<std::string> k_colors = {
+ set_xterm256_foreground(220, 5, 12),
+ set_xterm256_foreground(232, 96, 28),
+ set_xterm256_foreground(241, 147, 45),
+ set_xterm256_foreground(246, 193, 65),
+ set_xterm256_foreground(247, 240, 86),
+ set_xterm256_foreground(144, 201, 135),
+ set_xterm256_foreground( 78, 178, 101),
+};
+
+static void print_usage(int, char ** argv) {
+ LOG("\nexample usage:\n");
+ LOG("\n %s -m model.gguf -p \"Hello!\"\n", argv[0]);
+ LOG("\n");
+}
+
+static void prompt_add(llama_tokens & prompt, llama_token token) {
+ prompt.push_back(token);
+}
+
+static void prompt_add(llama_tokens & prompt, const llama_tokens & tokens) {
+ prompt.insert(prompt.end(), tokens.begin(), tokens.end());
+}
+
+static void prompt_add(llama_tokens & prompt, const llama_vocab * vocab, const std::string & txt, bool add_special, bool parse_special) {
+ auto tmp = common_tokenize(vocab, txt, add_special, parse_special);
+ prompt_add(prompt, tmp);
+}
+
+static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
+ prompt.clear();
+}
+
+int main(int argc, char ** argv) {
+ common_params params;
+
+ params.prompt = "";
+
+ params.n_predict = 8192;
+ params.n_batch = 8192;
+ params.n_ctx = 8192;
+
+ // Recommended sampling params
+ params.sampling.top_k = 40;
+ params.sampling.temp = 0.4f;
+ params.sampling.penalty_repeat = 1.1f;
+ params.sampling.penalty_last_n = 64;
+ params.sampling.min_p = 0.05f;
+
+ params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K, };
+
+ if (!common_params_parse(argc, argv, params,
LLAMA_EXAMPLE_TTS, print_usage)) {
+ return 1;
+ }
+
+ const int n_parallel = params.n_parallel;
+ const int n_predict = params.n_predict;
+
+ common_init();
+
+ // init LLM
+
+ llama_backend_init();
+ llama_numa_init(params.numa);
+
+ llama_model * model_ttc = NULL; // text-to-codes
+ llama_context * ctx_ttc = NULL;
+
+ // TODO not implemented
+ llama_model * model_cts = NULL; // codes-to-speech
+ llama_context * ctx_cts = NULL;
+
+ common_init_result llama_init_ttc = common_init_from_params(params);
+
+ model_ttc = llama_init_ttc.model.get();
+ ctx_ttc = llama_init_ttc.context.get();
+
+ const llama_vocab * vocab = llama_model_get_vocab(model_ttc);
+
+ // TODO: refactor in a common struct
+ params.model = params.vocoder.model;
+ params.model_url = params.vocoder.model_url;
+ params.hf_repo = params.vocoder.hf_repo;
+ params.hf_file = params.vocoder.hf_file;
+
+ params.embedding = true;
+
+ // TODO DAC not implemented.
+ // common_init_result llama_init_cts = common_init_from_params(params);
+ // model_cts = llama_init_cts.model.get();
+ // ctx_cts = llama_init_cts.context.get();
+
+ std::vector<common_sampler *> smpl(n_parallel);
+ for (int i = 0; i < n_parallel; ++i) {
+ params.sampling.no_perf = (i != 0);
+ params.sampling.seed = params.sampling.seed + 1;
+
+ smpl[i] = common_sampler_init(model_ttc, params.sampling);
+ }
+
+ LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl[0]));
+ LOG_INF("sampler params: \n%s\n", params.sampling.print().c_str());
+ LOG_INF("sampler chain: %s\n", common_sampler_print(smpl[0]).c_str());
+
+ LOG_INF("%s: loading done\n", __func__);
+
+ const auto t_main_start = ggml_time_us();
+
+ // process prompt and generate codes
+
+ std::vector<llama_token> codes;
+
+ std::vector<std::string> chunks;
+
+ if (params.vocoder.chunked) {
+ chunks = chunk_text(params.prompt, 10, 30);
+ } else {
+ chunks.push_back(params.prompt);
+ }
+
+ {
+ for (std::string& prompt : chunks) {
+
+ // Reset the context state before processing each new chunk
+ llama_kv_cache_clear(ctx_ttc);
+
+ LOG_INF("%s: constructing prompt ..\n", __func__);
+
+ json speaker = nullptr;
+
+ // load speaker if given
+ if (!params.vocoder.speaker_file.empty()) {
+ LOG_INF("%s: loading speaker ..\n", __func__);
+ speaker = speaker_from_file(params.vocoder.speaker_file);
+
+ if (speaker.empty()) {
+ LOG_ERR("%s: Failed to load speaker file '%s'\n", __func__, params.vocoder.speaker_file.c_str());
+ return 1;
+ }
+ } else {
+ speaker = DefaultSpeaker::getJsonData();
+ }
+
+ std::vector<llama_token> prompt_inp;
+
+ prompt_init(prompt_inp, vocab);
+
+ // convert the input text into the necessary format expected by OuteTTS
+ {
+ std::string completion_prompt = get_completion_prompt(prompt, speaker);
+
+ LOG_INF("%s: prompt: '%s'\n", __func__, completion_prompt.c_str());
+
+ prompt_add(prompt_inp, vocab, completion_prompt, false, true);
+ }
+
+ // --- generate codes --- //
+
+ // create a llama_batch
+ // we use this object to submit token data for decoding
+ llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel);
+
+ std::vector<llama_seq_id> seq_ids(n_parallel, 0);
+ for (int32_t i = 0; i < n_parallel; ++i) {
+ seq_ids[i] = i;
+ }
+
+ // evaluate the initial prompt
+ for (size_t i = 0; i < prompt_inp.size(); ++i) {
+ common_batch_add(batch, prompt_inp[i], i, seq_ids, false);
+ }
+ GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());
+
+ // llama_decode will output logits only for the last token of the prompt
+ batch.logits[batch.n_tokens - 1] = true;
+
+ if (llama_decode(ctx_ttc, batch) != 0) {
+ LOG_ERR("%s:
llama_decode() failed\n", __func__);
+ return 1;
+ }
+
+ if (n_parallel > 1) {
+ LOG_INF("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
+ }
+
+ llama_synchronize(ctx_ttc);
+
+ LOG_INF("%s: time for prompt: %.3f ms\n\n", __func__, (ggml_time_us() - t_main_start) / 1000.0f);
+
+ const auto t_dec_start = ggml_time_us();
+
+ // main loop
+
+ // remember the batch index of the last token for each parallel sequence
+ // we need this to determine which logits to sample from
+ std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
+
+ int n_past = batch.n_tokens;
+ int n_decode = 0;
+
+ bool next_token_uses_guide_token = true;
+
+ while (n_decode <= n_predict) {
+ // prepare the next batch
+ common_batch_clear(batch);
+
+ // sample the next token for each parallel sequence / stream
+ for (int32_t i = 0; i < n_parallel; ++i) {
+ if (i_batch[i] < 0) {
+ // the stream has already finished
+ continue;
+ }
+
+ llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
+
+ // Chunked text can be used instead of guide tokens
+ // TODO implement this for v1 if still needed.
+
+ //guide tokens help prevent hallucinations by forcing the TTS to use the correct word
+ // if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
+ // llama_token guide_token = guide_tokens[0];
+ // guide_tokens.erase(guide_tokens.begin());
+ // new_token_id = guide_token; //ensure correct word fragment is used
+ // }
+
+ //this is the token id that always precedes a new word
+ next_token_uses_guide_token = (new_token_id == 198);
+
+ common_sampler_accept(smpl[i], new_token_id, true);
+
+ codes.push_back(new_token_id);
+
+ const auto * cands = common_sampler_get_candidates(smpl[i]);
+
+ // is it an end of generation?
-> mark the stream as finished
+ if (llama_vocab_is_eog(vocab, new_token_id) || n_decode == n_predict) {
+ std::string reason;
+ if (llama_vocab_is_eog(vocab, new_token_id)) {
+ reason = "eos";
+ } else {
+ reason = "n_predict";
+ }
+
+ i_batch[i] = -1;
+
+ LOG("\n");
+ if (n_parallel > 1) {
+ LOG_CNT("\n");
+ LOG_INF("%s: stream %d finished at n_past = %d, reason = '%s'\n", __func__, i, n_past, reason.c_str());
+ }
+
+ continue;
+ }
+
+ {
+ const float p = cands->data[cands->selected].p;
+
+ const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) ((3*p)*float(k_colors.size()))));
+
+ LOG_CNT("%s%d%s", k_colors[col].c_str(), i, "\033[0m");
+ //LOG_CNT("%d", i);
+ }
+
+ i_batch[i] = batch.n_tokens;
+
+ // push this new token for next evaluation
+ common_batch_add(batch, new_token_id, n_past, { i }, true);
+ }
+
+ // all streams are finished
+ if (batch.n_tokens == 0) {
+ break;
+ }
+
+ n_decode += 1;
+ n_past += 1;
+
+ // evaluate the current batch with the transformer model
+ if (llama_decode(ctx_ttc, batch)) {
+ LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
+ return 1;
+ }
+ }
+
+ llama_batch_free(batch);
+
+ LOG("\n");
+ LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f);
+
+ common_perf_print(ctx_ttc, smpl[0]);
+ }
+
+ {
+ const std::string inp_txt = common_detokenize(ctx_ttc, codes, true);
+
+ // For DAC decoding
+ std::vector<std::vector<int>> codebooks = extract_codebooks(inp_txt);
+
+ // Create string representation of the codebooks for debugging
+ std::stringstream cb1_str, cb2_str;
+
+ cb1_str << "codebook1: [";
+ for (size_t i = 0; i < codebooks[0].size(); ++i) {
+ cb1_str << codebooks[0][i];
+ if (i < codebooks[0].size() - 1) cb1_str << ", ";
+ }
+ cb1_str << "]";
+
+ cb2_str << "codebook2: [";
+ for (size_t i = 0; i < codebooks[1].size(); ++i) {
+ cb2_str << codebooks[1][i];
+ if (i < codebooks[1].size() - 1) cb2_str << ", ";
+ }
+ cb2_str << "]";
+
+ LOG("\n");
+ LOG_INF("codes: '%s'\n", inp_txt.c_str());
+ LOG_INF("%s: codes size: %d\n", __func__, (int) codes.size());
+ LOG_INF("%s: codebook sizes: cb1=%d, cb2=%d\n", __func__, (int)codebooks[0].size(), (int)codebooks[1].size());
+ LOG_INF("%s: %s\n", __func__, cb1_str.str().c_str());
+ LOG_INF("%s: %s\n", __func__, cb2_str.str().c_str());
+ }
+
+
+ // --- Speech Generation --- //
+ // TODO: Functionality not yet implemented.
+ // Requires integration with the DAC
+
+ }
+
+} diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index 0f047986965f8..120eacd864488 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -1087,4 +1087,4 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 llama_backend_free(); return retval; -} +} \ No newline at end of file
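
Note: the patch stops after extracting the two DAC codebooks from the generated token string; the actual codes-to-speech step is still marked TODO. For reviewers who want to try the parsing in isolation, below is a minimal standalone sketch of the same extraction logic as extract_codebooks(). The sample string is hand-written for illustration (the code values are taken from the default speaker data above), not real model output.

// Standalone sketch: parse <|c1_N|> / <|c2_N|> tokens into the two DAC codebooks,
// mirroring extract_codebooks() from tts-outetts-v1.cpp.
#include <algorithm>
#include <iostream>
#include <regex>
#include <string>
#include <vector>

int main() {
    // Illustrative input only; in the example this comes from common_detokenize().
    const std::string codes =
        "<|word_start|>The<|features|><|t_0.20|><|energy_10|>"
        "<|code|><|c1_720|><|c2_658|><|c1_720|><|c2_663|><|word_end|>";

    std::regex pattern1("<\\|c1_(\\d+)\\|>");
    std::regex pattern2("<\\|c2_(\\d+)\\|>");

    std::vector<int> codebook1;
    std::vector<int> codebook2;

    // Collect every c1 and c2 value in order of appearance.
    for (std::sregex_iterator it(codes.begin(), codes.end(), pattern1), end; it != end; ++it) {
        codebook1.push_back(std::stoi((*it)[1]));
    }
    for (std::sregex_iterator it(codes.begin(), codes.end(), pattern2), end; it != end; ++it) {
        codebook2.push_back(std::stoi((*it)[1]));
    }

    // Keep the two codebooks aligned pair-wise, as the DAC decoder will expect.
    const size_t n = std::min(codebook1.size(), codebook2.size());
    for (size_t i = 0; i < n; ++i) {
        std::cout << codebook1[i] << " / " << codebook2[i] << "\n";
    }
    return 0;
}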