@@ -63,7 +63,47 @@ static void print_usage(int, char ** argv) {
63
63
LOG (" \n " );
64
64
}
65
65
66
- static void fill_hann_window (int length, bool periodic, double * output) {
66
+ struct wav_header {
67
+ char riff[4 ] = {' R' , ' I' , ' F' , ' F' };
68
+ uint32_t chunk_size;
69
+ char wave[4 ] = {' W' , ' A' , ' V' , ' E' };
70
+ char fmt[4 ] = {' f' , ' m' , ' t' , ' ' };
71
+ uint32_t fmt_chunk_size = 16 ;
72
+ uint16_t audio_format = 1 ; // PCM
73
+ uint16_t num_channels = 1 ; // Mono
74
+ uint32_t sample_rate;
75
+ uint32_t byte_rate;
76
+ uint16_t block_align;
77
+ uint16_t bits_per_sample = 16 ;
78
+ char data[4 ] = {' d' , ' a' , ' t' , ' a' };
79
+ uint32_t data_size;
80
+ };
81
+
82
+ static void save_wav16 (const std::string & fname, const std::vector<float > & data, int sample_rate) {
83
+ std::ofstream file (fname, std::ios::binary);
84
+ if (!file) {
85
+ LOG_ERR (" %s: Failed to open file '%s' for writing" , __func__, fname.c_str ());
86
+ return ;
87
+ }
88
+
89
+ wav_header header;
90
+ header.sample_rate = sample_rate;
91
+ header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8 );
92
+ header.block_align = header.num_channels * (header.bits_per_sample / 8 );
93
+ header.data_size = data.size () * (header.bits_per_sample / 8 );
94
+ header.chunk_size = 36 + header.data_size ;
95
+
96
+ file.write (reinterpret_cast <const char *>(&header), sizeof (header));
97
+
98
+ for (const auto & sample : data) {
99
+ int16_t pcm_sample = static_cast <int16_t >(std::clamp (sample * 32767.0 , -32768.0 , 32767.0 ));
100
+ file.write (reinterpret_cast <const char *>(&pcm_sample), sizeof (pcm_sample));
101
+ }
102
+
103
+ file.close ();
104
+ }
105
+
106
+ static void fill_hann_window (int length, bool periodic, float * output) {
67
107
int offset = -1 ;
68
108
if (periodic) {
69
109
offset = 0 ;
@@ -74,31 +114,31 @@ static void fill_hann_window(int length, bool periodic, double * output) {
74
114
}
75
115
76
116
// very poor-man fft
77
- static void twiddle (double * real, double * imag, int k, int N) {
78
- double angle = 2 * M_PI * k / N;
117
+ static void twiddle (float * real, float * imag, int k, int N) {
118
+ float angle = 2 * M_PI * k / N;
79
119
*real = cos (angle);
80
120
*imag = sin (angle);
81
121
}
82
122
83
- static void irfft (int n, const double * inp_cplx, double * out_real) {
123
+ static void irfft (int n, const float * inp_cplx, float * out_real) {
84
124
int N = n / 2 + 1 ;
85
125
86
- std::vector<double > real_input (N);
87
- std::vector<double > imag_input (N);
126
+ std::vector<float > real_input (N);
127
+ std::vector<float > imag_input (N);
88
128
for (int i = 0 ; i < N; ++i) {
89
129
real_input[i] = inp_cplx[2 * i];
90
130
imag_input[i] = inp_cplx[2 * i + 1 ];
91
131
}
92
132
93
- std::vector<double > real_output (n);
94
- std::vector<double > imag_output (n);
133
+ std::vector<float > real_output (n);
134
+ std::vector<float > imag_output (n);
95
135
96
136
for (int k = 0 ; k < n; ++k) {
97
137
real_output[k] = 0 .0f ;
98
138
imag_output[k] = 0 .0f ;
99
139
for (int m = 0 ; m < N; ++m) {
100
- double twiddle_real;
101
- double twiddle_imag;
140
+ float twiddle_real;
141
+ float twiddle_imag;
102
142
103
143
twiddle (&twiddle_real, &twiddle_imag, k * m, n);
104
144
@@ -123,7 +163,7 @@ static void irfft(int n, const double * inp_cplx, double * out_real) {
123
163
// hop_length = 320
124
164
// pad = 480
125
165
//
126
- static void fold (const std::vector<double > & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<double > & output) {
166
+ static void fold (const std::vector<float > & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<float > & output) {
127
167
int64_t output_height = n_out;
128
168
int64_t kernel_w = n_win;
129
169
int64_t stride_w = n_hop;
@@ -147,103 +187,63 @@ static void fold(const std::vector<double> & data, int64_t n_out, int64_t n_win,
147
187
output.resize (n_out - 2 * n_pad);
148
188
}
149
189
150
- struct wav_header {
151
- char riff[4 ] = {' R' , ' I' , ' F' , ' F' };
152
- uint32_t chunk_size;
153
- char wave[4 ] = {' W' , ' A' , ' V' , ' E' };
154
- char fmt[4 ] = {' f' , ' m' , ' t' , ' ' };
155
- uint32_t fmt_chunk_size = 16 ;
156
- uint16_t audio_format = 1 ; // PCM
157
- uint16_t num_channels = 1 ; // Mono
158
- uint32_t sample_rate;
159
- uint32_t byte_rate;
160
- uint16_t block_align;
161
- uint16_t bits_per_sample = 16 ;
162
- char data[4 ] = {' d' , ' a' , ' t' , ' a' };
163
- uint32_t data_size;
164
- };
165
-
166
- static void save_wav16 (const std::string & fname, const std::vector<double > & data, int sample_rate) {
167
- std::ofstream file (fname, std::ios::binary);
168
- if (!file) {
169
- LOG_ERR (" %s: Failed to open file '%s' for writing" , __func__, fname.c_str ());
170
- return ;
171
- }
172
-
173
- wav_header header;
174
- header.sample_rate = sample_rate;
175
- header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8 );
176
- header.block_align = header.num_channels * (header.bits_per_sample / 8 );
177
- header.data_size = data.size () * (header.bits_per_sample / 8 );
178
- header.chunk_size = 36 + header.data_size ;
179
-
180
- file.write (reinterpret_cast <const char *>(&header), sizeof (header));
181
-
182
- for (const auto & sample : data) {
183
- int16_t pcm_sample = static_cast <int16_t >(std::clamp (sample * 32767.0 , -32768.0 , 32767.0 ));
184
- file.write (reinterpret_cast <const char *>(&pcm_sample), sizeof (pcm_sample));
185
- }
186
-
187
- file.close ();
188
- }
189
-
190
- static std::vector<double > embd_to_audio (
190
+ // TODO: not optimized at all
191
+ static std::vector<float > embd_to_audio (
191
192
const float * embd,
192
- const std::vector<llama_token> & codes ,
193
+ const int n_codes ,
193
194
const int n_embd,
194
195
const int n_thread) {
195
- const int n = codes.size ();
196
196
const int n_fft = 1280 ;
197
197
const int n_hop = 320 ;
198
198
const int n_win = 1280 ;
199
199
const int n_pad = (n_win - n_hop)/2 ;
200
- const int n_out = (n - 1 )*n_hop + n_win;
200
+ const int n_out = (n_codes - 1 )*n_hop + n_win;
201
201
202
- std::vector<double > hann (n_fft);
202
+ std::vector<float > hann (n_fft);
203
203
204
204
fill_hann_window (hann.size (), true , hann.data ());
205
205
206
- int n_spec = n_embd*n ;
206
+ int n_spec = n_embd*n_codes ;
207
207
208
- std::vector<double > E (n_spec);
209
- std::vector<double > S (n_spec);
210
- std::vector<double > ST (n_spec);
208
+ std::vector<float > E (n_spec);
209
+ std::vector<float > S (n_spec);
210
+ std::vector<float > ST (n_spec);
211
211
212
- for (int l = 0 ; l < n ; ++l) {
212
+ for (int l = 0 ; l < n_codes ; ++l) {
213
213
for (int k = 0 ; k < n_embd; ++k) {
214
- E[k*n + l] = embd[l*n_embd + k];
214
+ E[k*n_codes + l] = embd[l*n_embd + k];
215
215
}
216
216
}
217
217
218
218
for (int k = 0 ; k < n_embd/2 ; ++k) {
219
- for (int l = 0 ; l < n ; ++l) {
220
- double mag = E[(k )*n + l];
221
- double phi = E[(k + n_embd/2 )*n + l];
219
+ for (int l = 0 ; l < n_codes ; ++l) {
220
+ float mag = E[(k )*n_codes + l];
221
+ float phi = E[(k + n_embd/2 )*n_codes + l];
222
222
223
223
mag = exp (mag);
224
224
225
225
if (mag > 1e2 ) {
226
226
mag = 1e2 ;
227
227
}
228
- S[2 *(k*n + l) + 0 ] = mag*cosf (phi);
229
- S[2 *(k*n + l) + 1 ] = mag*sinf (phi);
228
+ S[2 *(k*n_codes + l) + 0 ] = mag*cosf (phi);
229
+ S[2 *(k*n_codes + l) + 1 ] = mag*sinf (phi);
230
230
}
231
231
}
232
232
233
- for (int l = 0 ; l < n ; ++l) {
233
+ for (int l = 0 ; l < n_codes ; ++l) {
234
234
for (int k = 0 ; k < n_embd/2 ; ++k) {
235
- ST[l*n_embd + 2 *k + 0 ] = S[2 *(k*n + l) + 0 ];
236
- ST[l*n_embd + 2 *k + 1 ] = S[2 *(k*n + l) + 1 ];
235
+ ST[l*n_embd + 2 *k + 0 ] = S[2 *(k*n_codes + l) + 0 ];
236
+ ST[l*n_embd + 2 *k + 1 ] = S[2 *(k*n_codes + l) + 1 ];
237
237
}
238
238
}
239
239
240
- std::vector<double > res (n *n_fft);
241
- std::vector<double > hann2 (n *n_fft);
240
+ std::vector<float > res (n_codes *n_fft);
241
+ std::vector<float > hann2 (n_codes *n_fft);
242
242
243
243
std::vector<std::thread> workers (n_thread);
244
244
for (int i = 0 ; i < n_thread; ++i) {
245
245
workers[i] = std::thread ([&, i]() {
246
- for (int l = i; l < n ; l += n_thread) {
246
+ for (int l = i; l < n_codes ; l += n_thread) {
247
247
irfft (n_fft, ST.data () + l*n_embd, res.data () + l*n_fft);
248
248
for (int j = 0 ; j < n_fft; ++j) {
249
249
res [l*n_fft + j] *= hann[j];
@@ -256,8 +256,8 @@ static std::vector<double> embd_to_audio(
256
256
workers[i].join ();
257
257
}
258
258
259
- std::vector<double > audio;
260
- std::vector<double > env;
259
+ std::vector<float > audio;
260
+ std::vector<float > env;
261
261
262
262
fold (res, n_out, n_win, n_hop, n_pad, audio);
263
263
fold (hann2, n_out, n_win, n_hop, n_pad, env); // TODO: can be done once
@@ -844,12 +844,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
844
844
845
845
const auto t_voc_start = ggml_time_us ();
846
846
847
- llama_batch batch = llama_batch_init (codes.size (), 0 , 1 );
847
+ const int n_codes = codes.size ();
848
+
849
+ llama_batch batch = llama_batch_init (n_codes, 0 , 1 );
848
850
849
851
for (size_t i = 0 ; i < codes.size (); ++i) {
850
852
common_batch_add (batch, codes[i], i, { 0 }, true ); // TODO: all logits?
851
853
}
852
- GGML_ASSERT (batch.n_tokens == ( int ) codes. size () );
854
+ GGML_ASSERT (batch.n_tokens == n_codes );
853
855
854
856
if (llama_decode (ctx_cts, batch) != 0 ) {
855
857
LOG_ERR (" %s: llama_decode() failed\n " , __func__);
@@ -862,12 +864,40 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
862
864
863
865
const auto t_spec_start = ggml_time_us ();
864
866
867
+ #if 1
865
868
// spectral operations
866
- // TODO: not optimized at all
867
869
const int n_embd = llama_n_embd (model_cts);
868
870
const float * embd = llama_get_embeddings (ctx_cts);
869
871
870
- auto audio = embd_to_audio (embd, codes, n_embd, params.cpuparams .n_threads );
872
+ auto audio = embd_to_audio (embd, n_codes, n_embd, params.cpuparams .n_threads );
873
+
874
+ #else
875
+ // read the spectrogram from a file for debugging purposes
876
+ std::vector<float> audio;
877
+ {
878
+ std::ifstream fin("out.bin", std::ios::binary);
879
+ if (!fin) {
880
+ LOG_ERR("%s: failed to open file '%s'\n", __func__, "out.bin");
881
+ return 1;
882
+ }
883
+
884
+ std::vector<float> embd;
885
+
886
+ int n_codes;
887
+ int n_embd;
888
+
889
+ fin.read(reinterpret_cast<char *>(&n_codes), sizeof(int));
890
+ fin.read(reinterpret_cast<char *>(&n_embd), sizeof(int));
891
+
892
+ embd.resize(n_codes * n_embd);
893
+ fin.read(reinterpret_cast<char *>(embd.data()), n_codes * n_embd * sizeof(float));
894
+ fin.close();
895
+
896
+ LOG_INF("%s: n_codes: %d, n_embd: %d\n", __func__, n_codes, n_embd);
897
+
898
+ audio = embd_to_audio(embd.data(), n_codes, n_embd, params.cpuparams.n_threads);
899
+ }
900
+ #endif
871
901
872
902
const std::string fname = " output.wav" ;
873
903
0 commit comments