From 8c7c0188934be233fc6732106e66313778b0e144 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 8 Oct 2022 17:22:22 +0300 Subject: [PATCH] ref #17 : add options to output result to file Support for: - plain text - VTT - SRT --- main.cpp | 98 +++++++++++++++++++++++++++++++++++++++++++++++++---- whisper.cpp | 2 +- 2 files changed, 92 insertions(+), 8 deletions(-) diff --git a/main.cpp b/main.cpp index 9769b7f..728ab6f 100644 --- a/main.cpp +++ b/main.cpp @@ -5,6 +5,7 @@ #define DR_WAV_IMPLEMENTATION #include "dr_wav.h" +#include <fstream> #include #include #include @@ -32,6 +33,9 @@ struct whisper_params { bool verbose = false; bool translate = false; + bool output_txt = false; + bool output_vtt = false; + bool output_srt = false; bool print_special_tokens = false; bool no_timestamps = false; @@ -69,6 +73,12 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { whisper_print_usage(argc, argv, params); exit(0); } + } else if (arg == "-otxt" || arg == "--output-txt") { + params.output_txt = true; + } else if (arg == "-ovtt" || arg == "--output-vtt") { + params.output_vtt = true; + } else if (arg == "-osrt" || arg == "--output-srt") { + params.output_srt = true; } else if (arg == "-ps" || arg == "--print_special") { params.print_special_tokens = true; } else if (arg == "-nt" || arg == "--no_timestamps") { @@ -101,6 +111,8 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " -o N, --offset N offset in milliseconds (default: %d)\n", params.offset_ms); fprintf(stderr, " -v, --verbose verbose output\n"); fprintf(stderr, " --translate translate from source language to english\n"); + fprintf(stderr, " -otxt, --output-txt output result in a text file\n"); + fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n"); fprintf(stderr, " -ps, --print_special print special tokens\n"); fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n"); fprintf(stderr, " -l LANG, --language LANG 
spoken language (default: %s)\n", params.language.c_str()); @@ -123,7 +135,7 @@ int main(int argc, char ** argv) { if (params.fname_inp.empty()) { fprintf(stderr, "error: no input files specified\n"); whisper_print_usage(argc, argv, params); - return 1; + return 2; } // whisper init @@ -140,22 +152,22 @@ int main(int argc, char ** argv) { if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) { fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str()); whisper_print_usage(argc, argv, {}); - return 2; + return 3; } if (wav.channels != 1 && wav.channels != 2) { fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str()); - return 3; + return 4; } if (wav.sampleRate != WHISPER_SAMPLE_RATE) { fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str()); - return 4; + return 5; } if (wav.bitsPerSample != 16) { fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str()); - return 5; + return 6; } int n = wav.totalPCMFrameCount; @@ -193,9 +205,11 @@ int main(int argc, char ** argv) { params.language.c_str(), params.translate ? "translate" : "transcribe", params.no_timestamps ? 
0 : 1); + printf("\n"); } + // run the inference { whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY); @@ -211,10 +225,10 @@ int main(int argc, char ** argv) { if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); - return 6; + return 7; } - // print result; + // print result if (!wparams.print_realtime) { printf("\n"); @@ -233,6 +247,76 @@ int main(int argc, char ** argv) { } } } + + printf("\n"); + + // output to text file + if (params.output_txt) { + + const auto fname_txt = fname_inp + ".txt"; + std::ofstream fout_txt(fname_txt); + if (!fout_txt.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_txt.c_str()); + return 8; + } + + printf("%s: saving output to '%s.txt'\n", __func__, fname_inp.c_str()); + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + fout_txt << text; + } + } + + // output to VTT file + if (params.output_vtt) { + + const auto fname_vtt = fname_inp + ".vtt"; + std::ofstream fout_vtt(fname_vtt); + if (!fout_vtt.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_vtt.c_str()); + return 9; + } + + printf("%s: saving output to '%s.vtt'\n", __func__, fname_inp.c_str()); + + fout_vtt << "WEBVTT\n\n"; + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + fout_vtt << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; + fout_vtt << text << "\n\n"; + } + } + + // output to SRT file + if (params.output_srt) { + + const auto fname_srt = fname_inp + ".srt"; + std::ofstream fout_srt(fname_srt); + if (!fout_srt.is_open()) { + fprintf(stderr, "%s: failed 
to open '%s' for writing\n", __func__, fname_srt.c_str()); + return 10; + } + + printf("%s: saving output to '%s.srt'\n", __func__, fname_inp.c_str()); + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + fout_srt << i + 1 << "\n"; + fout_srt << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; + fout_srt << text << "\n\n"; + } + } } } diff --git a/whisper.cpp b/whisper.cpp index af89815..b59cfd7 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2242,7 +2242,7 @@ whisper_token whisper_token_transcribe() { void whisper_print_timings(struct whisper_context * ctx) { const int64_t t_end_us = ggml_time_us(); - printf("\n\n"); + printf("\n"); printf("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us/1000.0f); printf("%s: mel time = %8.2f ms\n", __func__, ctx->t_mel_us/1000.0f); printf("%s: sample time = %8.2f ms\n", __func__, ctx->t_sample_us/1000.0f);