@@ -210,6 +210,39 @@ struct whisper_print_user_data {
210
210
const std::vector<std::vector<float >> * pcmf32s;
211
211
};
212
212
213
+ std::string estimate_diarization_speaker (std::vector<std::vector<float >> pcmf32s, int64_t t0, int64_t t1, bool id_only = false ) {
214
+ std::string speaker = " " ;
215
+ const int64_t n_samples = pcmf32s[0 ].size ();
216
+
217
+ const int64_t is0 = timestamp_to_sample (t0, n_samples);
218
+ const int64_t is1 = timestamp_to_sample (t1, n_samples);
219
+
220
+ double energy0 = 0 .0f ;
221
+ double energy1 = 0 .0f ;
222
+
223
+ for (int64_t j = is0; j < is1; j++) {
224
+ energy0 += fabs (pcmf32s[0 ][j]);
225
+ energy1 += fabs (pcmf32s[1 ][j]);
226
+ }
227
+
228
+ if (energy0 > 1.1 *energy1) {
229
+ speaker = " 0" ;
230
+ } else if (energy1 > 1.1 *energy0) {
231
+ speaker = " 1" ;
232
+ } else {
233
+ speaker = " ?" ;
234
+ }
235
+
236
+ // printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str());
237
+
238
+ if (!id_only) {
239
+ speaker.insert (0 , " (speaker " );
240
+ speaker.append (" )" );
241
+ }
242
+
243
+ return speaker;
244
+ }
245
+
213
246
void whisper_print_segment_callback (struct whisper_context * ctx, struct whisper_state * /* state*/ , int n_new, void * user_data) {
214
247
const auto & params = *((whisper_print_user_data *) user_data)->params ;
215
248
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s ;
@@ -239,28 +272,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
239
272
}
240
273
241
274
if (params.diarize && pcmf32s.size () == 2 ) {
242
- const int64_t n_samples = pcmf32s[0 ].size ();
243
-
244
- const int64_t is0 = timestamp_to_sample (t0, n_samples);
245
- const int64_t is1 = timestamp_to_sample (t1, n_samples);
246
-
247
- double energy0 = 0 .0f ;
248
- double energy1 = 0 .0f ;
249
-
250
- for (int64_t j = is0; j < is1; j++) {
251
- energy0 += fabs (pcmf32s[0 ][j]);
252
- energy1 += fabs (pcmf32s[1 ][j]);
253
- }
254
-
255
- if (energy0 > 1.1 *energy1) {
256
- speaker = " (speaker 0)" ;
257
- } else if (energy1 > 1.1 *energy0) {
258
- speaker = " (speaker 1)" ;
259
- } else {
260
- speaker = " (speaker ?)" ;
261
- }
262
-
263
- // printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
275
+ speaker = estimate_diarization_speaker (pcmf32s, t0, t1);
264
276
}
265
277
266
278
if (params.print_colors ) {
@@ -294,7 +306,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
294
306
}
295
307
}
296
308
297
- bool output_txt (struct whisper_context * ctx, const char * fname) {
309
+ bool output_txt (struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector< float >> pcmf32s ) {
298
310
std::ofstream fout (fname);
299
311
if (!fout.is_open ()) {
300
312
fprintf (stderr, " %s: failed to open '%s' for writing\n " , __func__, fname);
@@ -306,13 +318,22 @@ bool output_txt(struct whisper_context * ctx, const char * fname) {
306
318
const int n_segments = whisper_full_n_segments (ctx);
307
319
for (int i = 0 ; i < n_segments; ++i) {
308
320
const char * text = whisper_full_get_segment_text (ctx, i);
309
- fout << text << " \n " ;
321
+ std::string speaker = " " ;
322
+
323
+ if (params.diarize && pcmf32s.size () == 2 )
324
+ {
325
+ const int64_t t0 = whisper_full_get_segment_t0 (ctx, i);
326
+ const int64_t t1 = whisper_full_get_segment_t1 (ctx, i);
327
+ speaker = estimate_diarization_speaker (pcmf32s, t0, t1);
328
+ }
329
+
330
+ fout << speaker << text << " \n " ;
310
331
}
311
332
312
333
return true ;
313
334
}
314
335
315
- bool output_vtt (struct whisper_context * ctx, const char * fname) {
336
+ bool output_vtt (struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector< float >> pcmf32s ) {
316
337
std::ofstream fout (fname);
317
338
if (!fout.is_open ()) {
318
339
fprintf (stderr, " %s: failed to open '%s' for writing\n " , __func__, fname);
@@ -328,15 +349,23 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
328
349
const char * text = whisper_full_get_segment_text (ctx, i);
329
350
const int64_t t0 = whisper_full_get_segment_t0 (ctx, i);
330
351
const int64_t t1 = whisper_full_get_segment_t1 (ctx, i);
352
+ std::string speaker = " " ;
353
+
354
+ if (params.diarize && pcmf32s.size () == 2 )
355
+ {
356
+ speaker = estimate_diarization_speaker (pcmf32s, t0, t1, true );
357
+ speaker.insert (0 , " <v Speaker" );
358
+ speaker.append (" >" );
359
+ }
331
360
332
361
fout << to_timestamp (t0) << " --> " << to_timestamp (t1) << " \n " ;
333
- fout << text << " \n\n " ;
362
+ fout << speaker << text << " \n\n " ;
334
363
}
335
364
336
365
return true ;
337
366
}
338
367
339
- bool output_srt (struct whisper_context * ctx, const char * fname, const whisper_params & params) {
368
+ bool output_srt (struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector< float >> pcmf32s ) {
340
369
std::ofstream fout (fname);
341
370
if (!fout.is_open ()) {
342
371
fprintf (stderr, " %s: failed to open '%s' for writing\n " , __func__, fname);
@@ -350,10 +379,16 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
350
379
const char * text = whisper_full_get_segment_text (ctx, i);
351
380
const int64_t t0 = whisper_full_get_segment_t0 (ctx, i);
352
381
const int64_t t1 = whisper_full_get_segment_t1 (ctx, i);
382
+ std::string speaker = " " ;
383
+
384
+ if (params.diarize && pcmf32s.size () == 2 )
385
+ {
386
+ speaker = estimate_diarization_speaker (pcmf32s, t0, t1);
387
+ }
353
388
354
389
fout << i + 1 + params.offset_n << " \n " ;
355
390
fout << to_timestamp (t0, true ) << " --> " << to_timestamp (t1, true ) << " \n " ;
356
- fout << text << " \n\n " ;
391
+ fout << speaker << text << " \n\n " ;
357
392
}
358
393
359
394
return true ;
@@ -390,7 +425,7 @@ char *escape_double_quotes_and_backslashes(const char *str) {
390
425
return escaped;
391
426
}
392
427
393
- bool output_csv (struct whisper_context * ctx, const char * fname) {
428
+ bool output_csv (struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector< float >> pcmf32s ) {
394
429
std::ofstream fout (fname);
395
430
if (!fout.is_open ()) {
396
431
fprintf (stderr, " %s: failed to open '%s' for writing\n " , __func__, fname);
@@ -400,21 +435,32 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
400
435
fprintf (stderr, " %s: saving output to '%s'\n " , __func__, fname);
401
436
402
437
const int n_segments = whisper_full_n_segments (ctx);
403
- fout << " start,end,text\n " ;
438
+ fout << " start,end," ;
439
+ if (params.diarize && pcmf32s.size () == 2 )
440
+ {
441
+ fout << " speaker," ;
442
+ }
443
+ fout << " text\n " ;
444
+
404
445
for (int i = 0 ; i < n_segments; ++i) {
405
446
const char * text = whisper_full_get_segment_text (ctx, i);
406
447
const int64_t t0 = whisper_full_get_segment_t0 (ctx, i);
407
448
const int64_t t1 = whisper_full_get_segment_t1 (ctx, i);
408
449
char * text_escaped = escape_double_quotes_and_backslashes (text);
409
450
410
451
// need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
411
- fout << 10 * t0 << " ," << 10 * t1 << " ,\" " << text_escaped << " \"\n " ;
452
+ fout << 10 * t0 << " ," << 10 * t1 << " ," ;
453
+ if (params.diarize && pcmf32s.size () == 2 )
454
+ {
455
+ fout << estimate_diarization_speaker (pcmf32s, t0, t1, true ) << " ," ;
456
+ }
457
+ fout << " \" " << text_escaped << " \"\n " ;
412
458
}
413
459
414
460
return true ;
415
461
}
416
462
417
- bool output_json (struct whisper_context * ctx, const char * fname, const whisper_params & params) {
463
+ bool output_json (struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector< float >> pcmf32s ) {
418
464
std::ofstream fout (fname);
419
465
int indent = 0 ;
420
466
@@ -530,7 +576,11 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
530
576
value_i (" from" , t0 * 10 , false );
531
577
value_i (" to" , t1 * 10 , true );
532
578
end_obj (false );
533
- value_s (" text" , text, true );
579
+ value_s (" text" , text, !params.diarize );
580
+
581
+ if (params.diarize && pcmf32s.size () == 2 ) {
582
+ value_s (" speaker" , estimate_diarization_speaker (pcmf32s, t0, t1, true ).c_str (), true );
583
+ }
534
584
end_obj (i == (n_segments - 1 ));
535
585
}
536
586
@@ -542,7 +592,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
542
592
// karaoke video generation
543
593
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
544
594
// TODO: font parameter adjustments
545
- bool output_wts (struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
595
+ bool output_wts (struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector<std::vector< float >> pcmf32s ) {
546
596
std::ofstream fout (fname);
547
597
548
598
fprintf (stderr, " %s: saving output to '%s'\n " , __func__, fname);
@@ -579,6 +629,11 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
579
629
fout << " drawtext=fontfile='" << font << " ':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << " ," << t0/100.0 << " )'" ;
580
630
581
631
bool is_first = true ;
632
+ std::string speaker = " " ;
633
+
634
+ if (params.diarize && pcmf32s.size () == 2 ) {
635
+ speaker = estimate_diarization_speaker (pcmf32s, t0, t1);
636
+ }
582
637
583
638
for (int j = 0 ; j < n; ++j) {
584
639
const auto & token = tokens[j];
@@ -587,13 +642,19 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
587
642
continue ;
588
643
}
589
644
590
- std::string txt_bg;
591
- std::string txt_fg; // highlight token
592
- std::string txt_ul; // underline
645
+ std::string txt_bg = " " ;
646
+ std::string txt_fg = " " ; // highlight token
647
+ std::string txt_ul = " " ; // underline
593
648
594
- txt_bg = " > " ;
595
- txt_fg = " > " ;
596
- txt_ul = " \\ \\ " ;
649
+ if (params.diarize && pcmf32s.size () == 2 ) {
650
+ txt_bg = speaker;
651
+ txt_fg = speaker;
652
+ txt_ul = " \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ " ;
653
+ }
654
+
655
+ txt_bg.append (" > " );
656
+ txt_fg.append (" > " );
657
+ txt_ul.append (" \\ \\ " );
597
658
598
659
{
599
660
for (int k = 0 ; k < n; ++k) {
@@ -656,8 +717,7 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
656
717
return true ;
657
718
}
658
719
659
- bool output_lrc (struct whisper_context * ctx, const char * fname) {
660
-
720
+ bool output_lrc (struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float >> pcmf32s) {
661
721
std::ofstream fout (fname);
662
722
if (!fout.is_open ()) {
663
723
fprintf (stderr, " %s: failed to open '%s' for writing\n " , __func__, fname);
@@ -682,8 +742,16 @@ bool output_lrc(struct whisper_context * ctx, const char * fname) {
682
742
char buf[16 ];
683
743
snprintf (buf, sizeof (buf), " %02d:%02d.%02d" , (int ) min, (int ) sec, (int ) ( msec / 10 ));
684
744
std::string timestamp_lrc = std::string (buf);
745
+ std::string speaker = " " ;
746
+
747
+ if (params.diarize && pcmf32s.size () == 2 )
748
+ {
749
+ const int64_t t0 = whisper_full_get_segment_t0 (ctx, i);
750
+ const int64_t t1 = whisper_full_get_segment_t1 (ctx, i);
751
+ speaker = estimate_diarization_speaker (pcmf32s, t0, t1);
752
+ }
685
753
686
- fout << ' [' << timestamp_lrc << ' ]' << text << " \n " ;
754
+ fout << ' [' << timestamp_lrc << ' ]' << speaker << text << " \n " ;
687
755
}
688
756
689
757
return true ;
@@ -828,43 +896,43 @@ int main(int argc, char ** argv) {
828
896
// output to text file
829
897
if (params.output_txt ) {
830
898
const auto fname_txt = fname_out + " .txt" ;
831
- output_txt (ctx, fname_txt.c_str ());
899
+ output_txt (ctx, fname_txt.c_str (), params, pcmf32s );
832
900
}
833
901
834
902
// output to VTT file
835
903
if (params.output_vtt ) {
836
904
const auto fname_vtt = fname_out + " .vtt" ;
837
- output_vtt (ctx, fname_vtt.c_str ());
905
+ output_vtt (ctx, fname_vtt.c_str (), params, pcmf32s );
838
906
}
839
907
840
908
// output to SRT file
841
909
if (params.output_srt ) {
842
910
const auto fname_srt = fname_out + " .srt" ;
843
- output_srt (ctx, fname_srt.c_str (), params);
911
+ output_srt (ctx, fname_srt.c_str (), params, pcmf32s );
844
912
}
845
913
846
914
// output to WTS file
847
915
if (params.output_wts ) {
848
916
const auto fname_wts = fname_out + " .wts" ;
849
- output_wts (ctx, fname_wts.c_str (), fname_inp.c_str (), params, float (pcmf32.size () + 1000 )/WHISPER_SAMPLE_RATE);
917
+ output_wts (ctx, fname_wts.c_str (), fname_inp.c_str (), params, float (pcmf32.size () + 1000 )/WHISPER_SAMPLE_RATE, pcmf32s );
850
918
}
851
919
852
920
// output to CSV file
853
921
if (params.output_csv ) {
854
922
const auto fname_csv = fname_out + " .csv" ;
855
- output_csv (ctx, fname_csv.c_str ());
923
+ output_csv (ctx, fname_csv.c_str (), params, pcmf32s );
856
924
}
857
925
858
926
// output to JSON file
859
927
if (params.output_jsn ) {
860
928
const auto fname_jsn = fname_out + " .json" ;
861
- output_json (ctx, fname_jsn.c_str (), params);
929
+ output_json (ctx, fname_jsn.c_str (), params, pcmf32s );
862
930
}
863
931
864
932
// output to LRC file
865
933
if (params.output_lrc ) {
866
934
const auto fname_lrc = fname_out + " .lrc" ;
867
- output_lrc (ctx, fname_lrc.c_str ());
935
+ output_lrc (ctx, fname_lrc.c_str (), params, pcmf32s );
868
936
}
869
937
}
870
938
}
0 commit comments