diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 0affdab..fa399c6 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -70,6 +70,7 @@ struct whisper_params { float logprob_thold = -1.00f; bool speed_up = false; + bool debug_mode = false; bool translate = false; bool detect_language = false; bool diarize = false; @@ -135,7 +136,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } @@ -190,7 +192,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); @@ -915,6 +918,7 @@ int main(int argc, char ** argv) { wparams.split_on_word = params.split_on_word; wparams.speed_up = params.speed_up; + wparams.debug_mode = params.debug_mode; wparams.tdrz_enable = params.tinydiarize; // [TDRZ] diff --git a/whisper.cpp b/whisper.cpp index e1cc6b7..9cdb271 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2445,41 +2445,51 @@ static void fft(const std::vector & in, std::vector & out) { } } -static void log_mel_spectrogram_worker_thread(int ith, const std::vector &hann, const float *samples, - int n_samples, int fft_size, int fft_step, int n_threads, - const whisper_filters &filters, bool speed_up, whisper_mel &mel) { - std::vector fft_in(fft_size, 0.0); - std::vector fft_out(2 * fft_size); - int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2); +static bool hann_window(int length, bool periodic, std::vector & output) { + if (output.size() < length) { + output.resize(length); + } + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset))); + } - for (int i = ith; i < mel.n_len; i += n_threads) { - const int offset = i * fft_step; + return true; +} - // apply Hanning window - for (int j = 0; j < fft_size; j++) { - if (offset + j < n_samples) { - fft_in[j] = hann[j] * samples[offset + j]; - } else { - fft_in[j] = 0.0; - } +static void log_mel_spectrogram_worker_thread(int ith, const std::vector & hann, const std::vector & samples, + int n_samples, int frame_size, int frame_step, int n_threads, + const whisper_filters & filters, whisper_mel & mel) { + std::vector fft_in(frame_size, 0.0); + std::vector fft_out(2 * frame_step); + // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist + int n_fft = 1 + (frame_size / 2); + int i = ith; + + // calculate FFT only when fft_in are not all zero + for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { + const int offset = i * frame_step; + + // apply Hanning window (~10% faster) + for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) { + fft_in[j] = hann[j] * samples[offset + j]; + } + // fill the rest with zeros + if (n_samples - offset < frame_size) { + std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0); } - // FFT -> mag^2 + // FFT fft(fft_in, fft_out); - for (int j = 0; j < fft_size; j++) { + // Calculate modulus^2 of complex numbers + // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. + for (int j = 0; j < frame_size; j++) { fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); } - for (int j = 1; j < fft_size / 2; j++) { - fft_out[j] += fft_out[fft_size - j]; - } - - if (speed_up) { - // scale down in the frequency domain results in a speed up in the time domain - for (int j = 0; j < n_fft; j++) { - fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]); - } - } // mel spectrogram for (int j = 0; j < mel.n_mel; j++) { @@ -2489,10 +2499,10 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector int k = 0; for (k = 0; k < n_fft - 3; k += 4) { sum += - fft_out[k + 0] * filters.data[j*n_fft + k + 0] + - fft_out[k + 1] * filters.data[j*n_fft + k + 1] + - fft_out[k + 2] * filters.data[j*n_fft + k + 2] + - fft_out[k + 3] * filters.data[j*n_fft + k + 3]; + fft_out[k + 0] * filters.data[j * n_fft + k + 0] + + fft_out[k + 1] * filters.data[j * n_fft + k + 1] + + fft_out[k + 2] * filters.data[j * n_fft + k + 2] + + fft_out[k + 3] * filters.data[j * n_fft + k + 3]; } // handle n_fft remainder @@ -2505,68 +2515,73 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector mel.data[j * mel.n_len + i] = sum; } } + + // Otherwise fft_out are all zero + double sum = log10(1e-10); + for (; i < mel.n_len; i += n_threads) { + for (int j = 0; j < mel.n_mel; j++) { + mel.data[j * mel.n_len + i] = sum; + } + } } -// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 +// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 static bool log_mel_spectrogram( - whisper_state & wstate, - const float * samples, + whisper_state & wstate, + const float * samples, const int n_samples, const int /*sample_rate*/, - const int fft_size, - const int fft_step, + const int frame_size, + const int frame_step, const int n_mel, const int n_threads, - const whisper_filters & filters, - const bool speed_up, - whisper_mel & mel) { + const whisper_filters & filters, + const bool debug, + whisper_mel & mel) { const int64_t t_start_us = ggml_time_us(); - // Hanning window + // Hanning window (Use cosf to eliminate difference) + // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 std::vector hann; - hann.resize(fft_size); - for (int i = 0; i < fft_size; i++) { - hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size))); - } + hann_window(frame_size, true, hann); + + + // Calculate the length of padding + int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; + int64_t stage_2_pad = frame_size / 2; + + // Initialize a vector and copy data from C array to it. + std::vector samples_padded; + samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); + std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); + + // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio + std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); + + // reflective pad 200 samples at the beginning of audio + std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); mel.n_mel = n_mel; - mel.n_len = n_samples/fft_step; - mel.n_len_org = mel.n_len; + // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936 + // Calculate number of frames + remove the last frame + mel.n_len = (samples_padded.size() - frame_size) / frame_step; + // Calculate semi-padded sample length to ensure compatibility + mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step; + mel.data.resize(mel.n_mel * mel.n_len); - std::vector samples_padded; - - // pad audio with at least one extra chunk of zeros - { - const int pad = (100*WHISPER_CHUNK_SIZE)/2; - - if (mel.n_len % pad != 0) { - mel.n_len = (mel.n_len/pad + 1)*pad; - } - mel.n_len += pad; - - samples_padded.resize(mel.n_len*fft_step); - memcpy(samples_padded.data(), samples, n_samples*sizeof(float)); - memset(samples_padded.data() + n_samples, 0, (mel.n_len*fft_step - n_samples)*sizeof(float)); - - samples = samples_padded.data(); - } - - mel.data.resize(mel.n_mel*mel.n_len); - - //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len); - //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate); { std::vector workers(n_threads - 1); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw] = std::thread( - log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples, - n_samples, fft_size, fft_step, n_threads, - std::cref(filters), speed_up, std::ref(mel)); + log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded, + n_samples + stage_2_pad, frame_size, frame_step, n_threads, + std::cref(filters), std::ref(mel)); } // main thread - log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel); + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw].join(); @@ -2580,7 +2595,6 @@ static bool log_mel_spectrogram( mmax = mel.data[i]; } } - //printf("%s: max = %f\n", __func__, mmax); mmax -= 8.0; @@ -2594,7 +2608,16 @@ static bool log_mel_spectrogram( wstate.t_mel_us += ggml_time_us() - t_start_us; - //printf("mel.n_len() = %d, divided by 1500: %f, n_samples / fft_step: %d\n", mel.n_len, mel.n_len / 1500.0, n_samples / fft_step); + // Dump log_mel_spectrogram + if (debug) { + std::ofstream outFile("log_mel_spectrogram.json"); + outFile << "["; + for (uint64_t i = 0; i < mel.data.size() - 1; i++) { + outFile << mel.data[i] << ", "; + } + outFile << mel.data[mel.data.size() - 1] << "]"; + outFile.close(); + } return true; } @@ -3026,9 +3049,9 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); } -// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 +// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good) int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, state->mel)) { + if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) { log("%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -3036,11 +3059,20 @@ int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, st return 0; } -// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 +// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good) int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads); } +// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2 +// TODO + +// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2 +// TODO + +// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2 +// TODO + int whisper_set_mel_with_state( struct whisper_context * /*ctx*/, struct whisper_state * state, @@ -3492,6 +3524,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.max_tokens =*/ 0, /*.speed_up =*/ false, + /*.debug_mode =*/ false, /*.audio_ctx =*/ 0, /*.tdrz_enable =*/ false, @@ -3653,7 +3686,7 @@ static void whisper_process_logits( WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab); // extract the logits for the last token - // we will be mutating and therefore we don't want to use the ctx.logits buffer directly + // we will be mutating, and therefore we don't want to use the ctx.logits buffer directly auto & probs = decoder.probs; auto & logits = decoder.logits; auto & logprobs = decoder.logprobs; @@ -4056,10 +4089,9 @@ int whisper_full_with_state( // compute log mel spectrogram if (params.speed_up) { - if (whisper_pcm_to_mel_phase_vocoder_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { - log("%s: failed to compute log mel spectrogram\n", __func__); - return -1; - } + // TODO: Replace PV with more advanced algorithm + log("%s: failed to compute log mel spectrogram\n", __func__); + return -1; } else { if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { log("%s: failed to compute log mel spectrogram\n", __func__); @@ -4095,8 +4127,8 @@ int whisper_full_with_state( const int seek_start = params.offset_ms/10; const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10; - // if length of spectrogram is less than 1s (100 samples), then return - // basically don't process anything that is less than 1s + // if length of spectrogram is less than 1.0s (100 frames), then return + // basically don't process anything that is less than 1.0s // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 if (seek_end < seek_start + (params.speed_up ? 50 : 100)) { return 0; diff --git a/whisper.h b/whisper.h index 588c287..73ab4d7 100644 --- a/whisper.h +++ b/whisper.h @@ -375,6 +375,7 @@ extern "C" { // [EXPERIMENTAL] speed-up techniques // note: these can significantly reduce the quality of the output bool speed_up; // speed-up the audio by 2x using Phase Vocoder + bool debug_mode; // enable debug_mode provides extra info (eg. Dump log_mel) int audio_ctx; // overwrite the audio context size (0 = use default) // [EXPERIMENTAL] [TDRZ] tinydiarize