Skip to content

Commit 14baf2e

Browse files
colinc and ggerganov authored
main : add diarization support for all current output types (ggml-org#1031)
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent bc2dcf8 commit 14baf2e

File tree

1 file changed

+118
-50
lines changed

1 file changed

+118
-50
lines changed

examples/main/main.cpp

Lines changed: 118 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,39 @@ struct whisper_print_user_data {
210210
const std::vector<std::vector<float>> * pcmf32s;
211211
};
212212

213+
std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
214+
std::string speaker = "";
215+
const int64_t n_samples = pcmf32s[0].size();
216+
217+
const int64_t is0 = timestamp_to_sample(t0, n_samples);
218+
const int64_t is1 = timestamp_to_sample(t1, n_samples);
219+
220+
double energy0 = 0.0f;
221+
double energy1 = 0.0f;
222+
223+
for (int64_t j = is0; j < is1; j++) {
224+
energy0 += fabs(pcmf32s[0][j]);
225+
energy1 += fabs(pcmf32s[1][j]);
226+
}
227+
228+
if (energy0 > 1.1*energy1) {
229+
speaker = "0";
230+
} else if (energy1 > 1.1*energy0) {
231+
speaker = "1";
232+
} else {
233+
speaker = "?";
234+
}
235+
236+
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str());
237+
238+
if (!id_only) {
239+
speaker.insert(0, "(speaker ");
240+
speaker.append(")");
241+
}
242+
243+
return speaker;
244+
}
245+
213246
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
214247
const auto & params = *((whisper_print_user_data *) user_data)->params;
215248
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
@@ -239,28 +272,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
239272
}
240273

241274
if (params.diarize && pcmf32s.size() == 2) {
242-
const int64_t n_samples = pcmf32s[0].size();
243-
244-
const int64_t is0 = timestamp_to_sample(t0, n_samples);
245-
const int64_t is1 = timestamp_to_sample(t1, n_samples);
246-
247-
double energy0 = 0.0f;
248-
double energy1 = 0.0f;
249-
250-
for (int64_t j = is0; j < is1; j++) {
251-
energy0 += fabs(pcmf32s[0][j]);
252-
energy1 += fabs(pcmf32s[1][j]);
253-
}
254-
255-
if (energy0 > 1.1*energy1) {
256-
speaker = "(speaker 0)";
257-
} else if (energy1 > 1.1*energy0) {
258-
speaker = "(speaker 1)";
259-
} else {
260-
speaker = "(speaker ?)";
261-
}
262-
263-
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
275+
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
264276
}
265277

266278
if (params.print_colors) {
@@ -294,7 +306,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
294306
}
295307
}
296308

297-
bool output_txt(struct whisper_context * ctx, const char * fname) {
309+
bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
298310
std::ofstream fout(fname);
299311
if (!fout.is_open()) {
300312
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -306,13 +318,22 @@ bool output_txt(struct whisper_context * ctx, const char * fname) {
306318
const int n_segments = whisper_full_n_segments(ctx);
307319
for (int i = 0; i < n_segments; ++i) {
308320
const char * text = whisper_full_get_segment_text(ctx, i);
309-
fout << text << "\n";
321+
std::string speaker = "";
322+
323+
if (params.diarize && pcmf32s.size() == 2)
324+
{
325+
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
326+
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
327+
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
328+
}
329+
330+
fout << speaker << text << "\n";
310331
}
311332

312333
return true;
313334
}
314335

315-
bool output_vtt(struct whisper_context * ctx, const char * fname) {
336+
bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
316337
std::ofstream fout(fname);
317338
if (!fout.is_open()) {
318339
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -328,15 +349,23 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
328349
const char * text = whisper_full_get_segment_text(ctx, i);
329350
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
330351
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
352+
std::string speaker = "";
353+
354+
if (params.diarize && pcmf32s.size() == 2)
355+
{
356+
speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true);
357+
speaker.insert(0, "<v Speaker");
358+
speaker.append(">");
359+
}
331360

332361
fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
333-
fout << text << "\n\n";
362+
fout << speaker << text << "\n\n";
334363
}
335364

336365
return true;
337366
}
338367

339-
bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
368+
bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
340369
std::ofstream fout(fname);
341370
if (!fout.is_open()) {
342371
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -350,10 +379,16 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
350379
const char * text = whisper_full_get_segment_text(ctx, i);
351380
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
352381
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
382+
std::string speaker = "";
383+
384+
if (params.diarize && pcmf32s.size() == 2)
385+
{
386+
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
387+
}
353388

354389
fout << i + 1 + params.offset_n << "\n";
355390
fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
356-
fout << text << "\n\n";
391+
fout << speaker << text << "\n\n";
357392
}
358393

359394
return true;
@@ -390,7 +425,7 @@ char *escape_double_quotes_and_backslashes(const char *str) {
390425
return escaped;
391426
}
392427

393-
bool output_csv(struct whisper_context * ctx, const char * fname) {
428+
bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
394429
std::ofstream fout(fname);
395430
if (!fout.is_open()) {
396431
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -400,21 +435,32 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
400435
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
401436

402437
const int n_segments = whisper_full_n_segments(ctx);
403-
fout << "start,end,text\n";
438+
fout << "start,end,";
439+
if (params.diarize && pcmf32s.size() == 2)
440+
{
441+
fout << "speaker,";
442+
}
443+
fout << "text\n";
444+
404445
for (int i = 0; i < n_segments; ++i) {
405446
const char * text = whisper_full_get_segment_text(ctx, i);
406447
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
407448
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
408449
char * text_escaped = escape_double_quotes_and_backslashes(text);
409450

410451
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
411-
fout << 10 * t0 << "," << 10 * t1 << ",\"" << text_escaped << "\"\n";
452+
fout << 10 * t0 << "," << 10 * t1 << ",";
453+
if (params.diarize && pcmf32s.size() == 2)
454+
{
455+
fout << estimate_diarization_speaker(pcmf32s, t0, t1, true) << ",";
456+
}
457+
fout << "\"" << text_escaped << "\"\n";
412458
}
413459

414460
return true;
415461
}
416462

417-
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
463+
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
418464
std::ofstream fout(fname);
419465
int indent = 0;
420466

@@ -530,7 +576,11 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
530576
value_i("from", t0 * 10, false);
531577
value_i("to", t1 * 10, true);
532578
end_obj(false);
533-
value_s("text", text, true);
579+
value_s("text", text, !params.diarize);
580+
581+
if (params.diarize && pcmf32s.size() == 2) {
582+
value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
583+
}
534584
end_obj(i == (n_segments - 1));
535585
}
536586

@@ -542,7 +592,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
542592
// karaoke video generation
543593
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
544594
// TODO: font parameter adjustments
545-
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
595+
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector<std::vector<float>> pcmf32s) {
546596
std::ofstream fout(fname);
547597

548598
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@@ -579,6 +629,11 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
579629
fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
580630

581631
bool is_first = true;
632+
std::string speaker = "";
633+
634+
if (params.diarize && pcmf32s.size() == 2) {
635+
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
636+
}
582637

583638
for (int j = 0; j < n; ++j) {
584639
const auto & token = tokens[j];
@@ -587,13 +642,19 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
587642
continue;
588643
}
589644

590-
std::string txt_bg;
591-
std::string txt_fg; // highlight token
592-
std::string txt_ul; // underline
645+
std::string txt_bg = "";
646+
std::string txt_fg = ""; // highlight token
647+
std::string txt_ul = ""; // underline
593648

594-
txt_bg = "> ";
595-
txt_fg = "> ";
596-
txt_ul = "\\ \\ ";
649+
if (params.diarize && pcmf32s.size() == 2) {
650+
txt_bg = speaker;
651+
txt_fg = speaker;
652+
txt_ul = "\\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ ";
653+
}
654+
655+
txt_bg.append("> ");
656+
txt_fg.append("> ");
657+
txt_ul.append("\\ \\ ");
597658

598659
{
599660
for (int k = 0; k < n; ++k) {
@@ -656,8 +717,7 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
656717
return true;
657718
}
658719

659-
bool output_lrc(struct whisper_context * ctx, const char * fname) {
660-
720+
bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
661721
std::ofstream fout(fname);
662722
if (!fout.is_open()) {
663723
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@@ -682,8 +742,16 @@ bool output_lrc(struct whisper_context * ctx, const char * fname) {
682742
char buf[16];
683743
snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10));
684744
std::string timestamp_lrc = std::string(buf);
745+
std::string speaker = "";
746+
747+
if (params.diarize && pcmf32s.size() == 2)
748+
{
749+
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
750+
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
751+
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
752+
}
685753

686-
fout << '[' << timestamp_lrc << ']' << text << "\n";
754+
fout << '[' << timestamp_lrc << ']' << speaker << text << "\n";
687755
}
688756

689757
return true;
@@ -828,43 +896,43 @@ int main(int argc, char ** argv) {
828896
// output to text file
829897
if (params.output_txt) {
830898
const auto fname_txt = fname_out + ".txt";
831-
output_txt(ctx, fname_txt.c_str());
899+
output_txt(ctx, fname_txt.c_str(), params, pcmf32s);
832900
}
833901

834902
// output to VTT file
835903
if (params.output_vtt) {
836904
const auto fname_vtt = fname_out + ".vtt";
837-
output_vtt(ctx, fname_vtt.c_str());
905+
output_vtt(ctx, fname_vtt.c_str(), params, pcmf32s);
838906
}
839907

840908
// output to SRT file
841909
if (params.output_srt) {
842910
const auto fname_srt = fname_out + ".srt";
843-
output_srt(ctx, fname_srt.c_str(), params);
911+
output_srt(ctx, fname_srt.c_str(), params, pcmf32s);
844912
}
845913

846914
// output to WTS file
847915
if (params.output_wts) {
848916
const auto fname_wts = fname_out + ".wts";
849-
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
917+
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, pcmf32s);
850918
}
851919

852920
// output to CSV file
853921
if (params.output_csv) {
854922
const auto fname_csv = fname_out + ".csv";
855-
output_csv(ctx, fname_csv.c_str());
923+
output_csv(ctx, fname_csv.c_str(), params, pcmf32s);
856924
}
857925

858926
// output to JSON file
859927
if (params.output_jsn) {
860928
const auto fname_jsn = fname_out + ".json";
861-
output_json(ctx, fname_jsn.c_str(), params);
929+
output_json(ctx, fname_jsn.c_str(), params, pcmf32s);
862930
}
863931

864932
// output to LRC file
865933
if (params.output_lrc) {
866934
const auto fname_lrc = fname_out + ".lrc";
867-
output_lrc(ctx, fname_lrc.c_str());
935+
output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s);
868936
}
869937
}
870938
}

0 commit comments

Comments
 (0)