whisper : add new-segment callback
Can be used to process new segments as they are being generated. Sample usage in main, for printing the resulting segments during the inference.pull/78/head
parent
8f95c25aed
commit
7affd309d3
95
main.cpp
95
main.cpp
|
@ -141,6 +141,55 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
|
||||||
|
const whisper_params & params = *(whisper_params *) user_data;
|
||||||
|
|
||||||
|
const int n_segments = whisper_full_n_segments(ctx);
|
||||||
|
|
||||||
|
// print the last segment
|
||||||
|
const int i = n_segments - 1;
|
||||||
|
if (i == 0) {
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.no_timestamps) {
|
||||||
|
if (params.print_colors) {
|
||||||
|
// TODO
|
||||||
|
} else {
|
||||||
|
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||||
|
printf("%s", text);
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||||
|
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||||
|
|
||||||
|
if (params.print_colors) {
|
||||||
|
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||||
|
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||||
|
if (params.print_special_tokens == false) {
|
||||||
|
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||||
|
if (id >= whisper_token_eot(ctx)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||||
|
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||||
|
|
||||||
|
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||||
|
|
||||||
|
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
} else {
|
||||||
|
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||||
|
|
||||||
|
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool output_txt(struct whisper_context * ctx, const char * fname) {
|
bool output_txt(struct whisper_context * ctx, const char * fname) {
|
||||||
std::ofstream fout(fname);
|
std::ofstream fout(fname);
|
||||||
if (!fout.is_open()) {
|
if (!fout.is_open()) {
|
||||||
|
@ -294,7 +343,7 @@ int main(int argc, char ** argv) {
|
||||||
{
|
{
|
||||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||||
|
|
||||||
wparams.print_realtime = !params.print_colors;
|
wparams.print_realtime = false;
|
||||||
wparams.print_progress = false;
|
wparams.print_progress = false;
|
||||||
wparams.print_timestamps = !params.no_timestamps;
|
wparams.print_timestamps = !params.no_timestamps;
|
||||||
wparams.print_special_tokens = params.print_special_tokens;
|
wparams.print_special_tokens = params.print_special_tokens;
|
||||||
|
@ -303,49 +352,17 @@ int main(int argc, char ** argv) {
|
||||||
wparams.n_threads = params.n_threads;
|
wparams.n_threads = params.n_threads;
|
||||||
wparams.offset_ms = params.offset_t_ms;
|
wparams.offset_ms = params.offset_t_ms;
|
||||||
|
|
||||||
|
// this callback is called on each new segment
|
||||||
|
if (!wparams.print_realtime) {
|
||||||
|
wparams.new_segment_callback = whisper_print_segment_callback;
|
||||||
|
wparams.new_segment_callback_user_data = ¶ms;
|
||||||
|
}
|
||||||
|
|
||||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||||
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
||||||
return 7;
|
return 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
// print result
|
|
||||||
if (!wparams.print_realtime) {
|
|
||||||
printf("\n");
|
|
||||||
|
|
||||||
const int n_segments = whisper_full_n_segments(ctx);
|
|
||||||
for (int i = 0; i < n_segments; ++i) {
|
|
||||||
if (params.no_timestamps) {
|
|
||||||
if (params.print_colors) {
|
|
||||||
// TODO
|
|
||||||
} else {
|
|
||||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
|
||||||
printf("%s", text);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
|
||||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
|
||||||
|
|
||||||
if (params.print_colors) {
|
|
||||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
|
||||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
|
||||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
|
||||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
|
||||||
|
|
||||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
|
||||||
|
|
||||||
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
|
||||||
}
|
|
||||||
printf("\n");
|
|
||||||
} else {
|
|
||||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
|
||||||
|
|
||||||
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
// output to text file
|
// output to text file
|
||||||
|
|
16
whisper.cpp
16
whisper.cpp
|
@ -2320,6 +2320,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
||||||
/*.beam_width =*/ -1,
|
/*.beam_width =*/ -1,
|
||||||
/*.n_best =*/ -1,
|
/*.n_best =*/ -1,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/*.new_segment_callback =*/ nullptr,
|
||||||
|
/*.new_segment_callback_user_data =*/ nullptr,
|
||||||
};
|
};
|
||||||
} break;
|
} break;
|
||||||
case WHISPER_SAMPLING_BEAM_SEARCH:
|
case WHISPER_SAMPLING_BEAM_SEARCH:
|
||||||
|
@ -2348,6 +2351,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
||||||
/*.beam_width =*/ 10,
|
/*.beam_width =*/ 10,
|
||||||
/*.n_best =*/ 5,
|
/*.n_best =*/ 5,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/*.new_segment_callback =*/ nullptr,
|
||||||
|
/*.new_segment_callback_user_data =*/ nullptr,
|
||||||
};
|
};
|
||||||
} break;
|
} break;
|
||||||
}
|
}
|
||||||
|
@ -2549,6 +2555,9 @@ int whisper_full(
|
||||||
for (int j = i0; j <= i; j++) {
|
for (int j = i0; j <= i; j++) {
|
||||||
result_all.back().tokens.push_back(tokens_cur[j]);
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
||||||
}
|
}
|
||||||
|
if (params.new_segment_callback) {
|
||||||
|
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
text = "";
|
text = "";
|
||||||
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
|
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
|
||||||
|
@ -2576,6 +2585,9 @@ int whisper_full(
|
||||||
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
||||||
result_all.back().tokens.push_back(tokens_cur[j]);
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
||||||
}
|
}
|
||||||
|
if (params.new_segment_callback) {
|
||||||
|
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2609,6 +2621,10 @@ const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_seg
|
||||||
return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
|
return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
|
||||||
|
return ctx->result_all[i_segment].tokens[i_token].id;
|
||||||
|
}
|
||||||
|
|
||||||
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
|
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
|
||||||
return ctx->result_all[i_segment].tokens[i_token].p;
|
return ctx->result_all[i_segment].tokens[i_token].p;
|
||||||
}
|
}
|
||||||
|
|
|
@ -160,6 +160,11 @@ extern "C" {
|
||||||
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
|
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Text segment callback
|
||||||
|
// Called on every newly generated text segment
|
||||||
|
// Use the whisper_full_...() functions to obtain the text segments
|
||||||
|
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
|
||||||
|
|
||||||
struct whisper_full_params {
|
struct whisper_full_params {
|
||||||
enum whisper_sampling_strategy strategy;
|
enum whisper_sampling_strategy strategy;
|
||||||
|
|
||||||
|
@ -184,6 +189,9 @@ extern "C" {
|
||||||
int beam_width;
|
int beam_width;
|
||||||
int n_best;
|
int n_best;
|
||||||
} beam_search;
|
} beam_search;
|
||||||
|
|
||||||
|
whisper_new_segment_callback new_segment_callback;
|
||||||
|
void * new_segment_callback_user_data;
|
||||||
};
|
};
|
||||||
|
|
||||||
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
||||||
|
@ -212,6 +220,7 @@ extern "C" {
|
||||||
|
|
||||||
// Get the token text of the specified token in the specified segment.
|
// Get the token text of the specified token in the specified segment.
|
||||||
WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
|
WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
|
||||||
|
WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
|
||||||
|
|
||||||
// Get the probability of the specified token in the specified segment.
|
// Get the probability of the specified token in the specified segment.
|
||||||
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
|
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
|
||||||
|
|
Loading…
Reference in New Issue