Implemented command-style grammar in the main example.
Mostly just copied the relevant parts from the command example.pull/1998/head
parent
1558ec5a16
commit
54d3707b17
|
@ -1,6 +1,7 @@
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
||||||
#include "whisper.h"
|
#include "whisper.h"
|
||||||
|
#include "grammar-parser.h"
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
@ -41,6 +42,7 @@ struct whisper_params {
|
||||||
float word_thold = 0.01f;
|
float word_thold = 0.01f;
|
||||||
float entropy_thold = 2.40f;
|
float entropy_thold = 2.40f;
|
||||||
float logprob_thold = -1.00f;
|
float logprob_thold = -1.00f;
|
||||||
|
float grammar_penalty = 100.0f;
|
||||||
|
|
||||||
bool speed_up = false;
|
bool speed_up = false;
|
||||||
bool debug_mode = false;
|
bool debug_mode = false;
|
||||||
|
@ -70,6 +72,8 @@ struct whisper_params {
|
||||||
std::string prompt;
|
std::string prompt;
|
||||||
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
|
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
|
||||||
std::string model = "models/ggml-base.en.bin";
|
std::string model = "models/ggml-base.en.bin";
|
||||||
|
std::string grammar;
|
||||||
|
std::string grammar_rule;
|
||||||
|
|
||||||
// [TDRZ] speaker turn string
|
// [TDRZ] speaker turn string
|
||||||
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
|
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
|
||||||
|
@ -80,6 +84,8 @@ struct whisper_params {
|
||||||
|
|
||||||
std::vector<std::string> fname_inp = {};
|
std::vector<std::string> fname_inp = {};
|
||||||
std::vector<std::string> fname_out = {};
|
std::vector<std::string> fname_out = {};
|
||||||
|
|
||||||
|
grammar_parser::parse_state grammar_parsed;
|
||||||
};
|
};
|
||||||
|
|
||||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||||
|
@ -154,6 +160,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||||
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
||||||
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
||||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||||
|
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
||||||
|
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
|
||||||
|
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
||||||
else {
|
else {
|
||||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||||
whisper_print_usage(argc, argv, params);
|
whisper_print_usage(argc, argv, params);
|
||||||
|
@ -214,6 +223,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||||
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
|
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
|
||||||
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
|
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
|
||||||
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
||||||
|
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
||||||
|
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
|
||||||
|
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -926,6 +938,31 @@ int main(int argc, char ** argv) {
|
||||||
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
|
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
|
||||||
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
|
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
|
||||||
|
|
||||||
|
if (!params.grammar.empty()) {
|
||||||
|
auto& grammar = params.grammar_parsed;
|
||||||
|
if (is_file_exist(params.grammar.c_str())) {
|
||||||
|
// read grammar from file
|
||||||
|
std::ifstream ifs(params.grammar.c_str());
|
||||||
|
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
|
||||||
|
grammar = grammar_parser::parse(txt.c_str());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// read grammar from string
|
||||||
|
grammar = grammar_parser::parse(params.grammar.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// will be empty (default) if there are parse errors
|
||||||
|
if (grammar.rules.empty()) {
|
||||||
|
fprintf(stderr, "error: failed to parse grammar \"%s\"\n", params.grammar.c_str());
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
fprintf(stderr, "%s: grammar:\n", __func__);
|
||||||
|
grammar_parser::print_grammar(stderr, grammar);
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
||||||
const auto fname_inp = params.fname_inp[f];
|
const auto fname_inp = params.fname_inp[f];
|
||||||
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
|
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
|
||||||
|
@ -972,7 +1009,8 @@ int main(int argc, char ** argv) {
|
||||||
{
|
{
|
||||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||||
|
|
||||||
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
|
bool bUseGrammar = (!params.grammar_parsed.rules.empty() && !params.grammar_rule.empty());
|
||||||
|
wparams.strategy = (params.beam_size > 1 || bUseGrammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
|
||||||
|
|
||||||
wparams.print_realtime = false;
|
wparams.print_realtime = false;
|
||||||
wparams.print_progress = params.print_progress;
|
wparams.print_progress = params.print_progress;
|
||||||
|
@ -1010,6 +1048,21 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 };
|
whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 };
|
||||||
|
|
||||||
|
const auto& grammar_parsed = params.grammar_parsed;
|
||||||
|
auto grammar_rules = grammar_parsed.c_rules();
|
||||||
|
|
||||||
|
if (bUseGrammar) {
|
||||||
|
if (grammar_parsed.symbol_ids.find(params.grammar_rule) == grammar_parsed.symbol_ids.end()) {
|
||||||
|
fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, params.grammar_rule.c_str());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
wparams.grammar_rules = grammar_rules.data();
|
||||||
|
wparams.n_grammar_rules = grammar_rules.size();
|
||||||
|
wparams.i_start_rule = grammar_parsed.symbol_ids.at(params.grammar_rule);
|
||||||
|
wparams.grammar_penalty = params.grammar_penalty;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// this callback is called on each new segment
|
// this callback is called on each new segment
|
||||||
if (!wparams.print_realtime) {
|
if (!wparams.print_realtime) {
|
||||||
wparams.new_segment_callback = whisper_print_segment_callback;
|
wparams.new_segment_callback = whisper_print_segment_callback;
|
||||||
|
|
Loading…
Reference in New Issue