Allow a regular expression to describe tokens to suppress.
Example: --suppress-tokens-re "[,\.]|[ ]?[0-9]+" will suppress commas, periods, and numeric tokens. Technique inspired by https://github.com/openai/whisper/discussions/1041
This commit is contained in:
parent
1558ec5a16
commit
326b1eed51
|
@ -52,6 +52,9 @@ struct whisper_params {
|
||||||
std::string prompt;
|
std::string prompt;
|
||||||
std::string context;
|
std::string context;
|
||||||
std::string grammar;
|
std::string grammar;
|
||||||
|
|
||||||
|
// A regular expression that matches tokens to suppress
|
||||||
|
std::string suppress_tokens_re;
|
||||||
};
|
};
|
||||||
|
|
||||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||||
|
@ -85,6 +88,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||||
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
|
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
|
||||||
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
||||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
||||||
|
else if ( arg == "--suppress-tokens-re") { params.suppress_tokens_re = argv[++i]; }
|
||||||
else {
|
else {
|
||||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||||
whisper_print_usage(argc, argv, params);
|
whisper_print_usage(argc, argv, params);
|
||||||
|
@ -122,6 +126,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||||
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
|
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
|
||||||
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
||||||
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
||||||
|
fprintf(stderr, " --suppress-tokens-re REGEX [%-7s] regular expression matching tokens to supporess\n", params.suppress_tokens_re.c_str());
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -167,6 +172,8 @@ std::string transcribe(
|
||||||
|
|
||||||
wparams.initial_prompt = params.context.data();
|
wparams.initial_prompt = params.context.data();
|
||||||
|
|
||||||
|
wparams.suppress_tokens_re = params.suppress_tokens_re.c_str();
|
||||||
|
|
||||||
const auto & grammar_parsed = params.grammar_parsed;
|
const auto & grammar_parsed = params.grammar_parsed;
|
||||||
auto grammar_rules = grammar_parsed.c_rules();
|
auto grammar_rules = grammar_parsed.c_rules();
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
|
#include <regex>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -74,6 +75,9 @@ struct whisper_params {
|
||||||
// [TDRZ] speaker turn string
|
// [TDRZ] speaker turn string
|
||||||
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
|
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
|
||||||
|
|
||||||
|
// A regular expression that matches tokens to suppress
|
||||||
|
std::string suppress_tokens_re;
|
||||||
|
|
||||||
std::string openvino_encode_device = "CPU";
|
std::string openvino_encode_device = "CPU";
|
||||||
|
|
||||||
std::string dtw = "";
|
std::string dtw = "";
|
||||||
|
@ -154,6 +158,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||||
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
||||||
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
||||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||||
|
else if ( arg == "--suppress-tokens-re") { params.suppress_tokens_re = argv[++i]; }
|
||||||
else {
|
else {
|
||||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||||
whisper_print_usage(argc, argv, params);
|
whisper_print_usage(argc, argv, params);
|
||||||
|
@ -214,6 +219,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||||
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
|
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
|
||||||
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
|
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
|
||||||
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
||||||
|
fprintf(stderr, " --suppress-tokens-re REGEX [%-7s] regular expression matching tokens to supporess\n", params.suppress_tokens_re.c_str());
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -997,6 +1003,8 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
|
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
|
||||||
|
|
||||||
|
wparams.suppress_tokens_re = params.suppress_tokens_re.c_str();
|
||||||
|
|
||||||
wparams.initial_prompt = params.prompt.c_str();
|
wparams.initial_prompt = params.prompt.c_str();
|
||||||
|
|
||||||
wparams.greedy.best_of = params.best_of;
|
wparams.greedy.best_of = params.best_of;
|
||||||
|
|
12
whisper.cpp
12
whisper.cpp
|
@ -4553,6 +4553,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
||||||
|
|
||||||
/*.tdrz_enable =*/ false,
|
/*.tdrz_enable =*/ false,
|
||||||
|
|
||||||
|
/* suppress_tokens_re */ nullptr,
|
||||||
|
|
||||||
/*.initial_prompt =*/ nullptr,
|
/*.initial_prompt =*/ nullptr,
|
||||||
/*.prompt_tokens =*/ nullptr,
|
/*.prompt_tokens =*/ nullptr,
|
||||||
/*.prompt_n_tokens =*/ 0,
|
/*.prompt_n_tokens =*/ 0,
|
||||||
|
@ -4796,6 +4798,16 @@ static void whisper_process_logits(
|
||||||
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// suppress any tokens matching a regular expression
|
||||||
|
// ref: https://github.com/openai/whisper/discussions/1041
|
||||||
|
if (params.suppress_tokens_re != nullptr)
|
||||||
|
{
|
||||||
|
std::regex re(params.suppress_tokens_re);
|
||||||
|
for (std::pair<whisper_vocab::token, whisper_vocab::id> token_id : vocab.token_to_id)
|
||||||
|
if (std::regex_match(token_id.first, re))
|
||||||
|
logits[token_id.second] = -INFINITY;
|
||||||
|
}
|
||||||
|
|
||||||
// suppress non-speech tokens
|
// suppress non-speech tokens
|
||||||
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
||||||
if (params.suppress_non_speech_tokens) {
|
if (params.suppress_non_speech_tokens) {
|
||||||
|
|
|
@ -505,6 +505,9 @@ extern "C" {
|
||||||
// [EXPERIMENTAL] [TDRZ] tinydiarize
|
// [EXPERIMENTAL] [TDRZ] tinydiarize
|
||||||
bool tdrz_enable; // enable tinydiarize speaker turn detection
|
bool tdrz_enable; // enable tinydiarize speaker turn detection
|
||||||
|
|
||||||
|
// A regular expression that matches tokens to suppress
|
||||||
|
const char * suppress_tokens_re;
|
||||||
|
|
||||||
// tokens to provide to the whisper decoder as initial prompt
|
// tokens to provide to the whisper decoder as initial prompt
|
||||||
// these are prepended to any existing text context from a previous call
|
// these are prepended to any existing text context from a previous call
|
||||||
// use whisper_tokenize() to convert text to tokens
|
// use whisper_tokenize() to convert text to tokens
|
||||||
|
|
Loading…
Reference in a new issue