whisper : add new-segment callback
Can be used to process new segments as they are being generated. Sample usage in main, for printing the resulting segments during the inference.pull/78/head
parent
8f95c25aed
commit
7affd309d3
95
main.cpp
95
main.cpp
|
@ -141,6 +141,55 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
|||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
|
||||
const whisper_params & params = *(whisper_params *) user_data;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
|
||||
// print the last segment
|
||||
const int i = n_segments - 1;
|
||||
if (i == 0) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
if (params.no_timestamps) {
|
||||
if (params.print_colors) {
|
||||
// TODO
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
printf("%s", text);
|
||||
fflush(stdout);
|
||||
}
|
||||
} else {
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
if (params.print_colors) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special_tokens == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
printf("\n");
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool output_txt(struct whisper_context * ctx, const char * fname) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
|
@ -294,7 +343,7 @@ int main(int argc, char ** argv) {
|
|||
{
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_realtime = !params.print_colors;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_progress = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.print_special_tokens = params.print_special_tokens;
|
||||
|
@ -303,49 +352,17 @@ int main(int argc, char ** argv) {
|
|||
wparams.n_threads = params.n_threads;
|
||||
wparams.offset_ms = params.offset_t_ms;
|
||||
|
||||
// this callback is called on each new segment
|
||||
if (!wparams.print_realtime) {
|
||||
wparams.new_segment_callback = whisper_print_segment_callback;
|
||||
wparams.new_segment_callback_user_data = ¶ms;
|
||||
}
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
||||
return 7;
|
||||
}
|
||||
|
||||
// print result
|
||||
if (!wparams.print_realtime) {
|
||||
printf("\n");
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
if (params.no_timestamps) {
|
||||
if (params.print_colors) {
|
||||
// TODO
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
printf("%s", text);
|
||||
fflush(stdout);
|
||||
}
|
||||
} else {
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
if (params.print_colors) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
printf("\n");
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
|
||||
// output to text file
|
||||
|
|
16
whisper.cpp
16
whisper.cpp
|
@ -2320,6 +2320,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|||
/*.beam_width =*/ -1,
|
||||
/*.n_best =*/ -1,
|
||||
},
|
||||
|
||||
/*.new_segment_callback =*/ nullptr,
|
||||
/*.new_segment_callback_user_data =*/ nullptr,
|
||||
};
|
||||
} break;
|
||||
case WHISPER_SAMPLING_BEAM_SEARCH:
|
||||
|
@ -2348,6 +2351,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|||
/*.beam_width =*/ 10,
|
||||
/*.n_best =*/ 5,
|
||||
},
|
||||
|
||||
/*.new_segment_callback =*/ nullptr,
|
||||
/*.new_segment_callback_user_data =*/ nullptr,
|
||||
};
|
||||
} break;
|
||||
}
|
||||
|
@ -2549,6 +2555,9 @@ int whisper_full(
|
|||
for (int j = i0; j <= i; j++) {
|
||||
result_all.back().tokens.push_back(tokens_cur[j]);
|
||||
}
|
||||
if (params.new_segment_callback) {
|
||||
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
|
||||
}
|
||||
}
|
||||
text = "";
|
||||
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
|
||||
|
@ -2576,6 +2585,9 @@ int whisper_full(
|
|||
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
||||
result_all.back().tokens.push_back(tokens_cur[j]);
|
||||
}
|
||||
if (params.new_segment_callback) {
|
||||
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2609,6 +2621,10 @@ const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_seg
|
|||
return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
|
||||
}
|
||||
|
||||
whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
|
||||
return ctx->result_all[i_segment].tokens[i_token].id;
|
||||
}
|
||||
|
||||
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
|
||||
return ctx->result_all[i_segment].tokens[i_token].p;
|
||||
}
|
||||
|
|
|
@ -160,6 +160,11 @@ extern "C" {
|
|||
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
|
||||
};
|
||||
|
||||
// Text segment callback
|
||||
// Called on every newly generated text segment
|
||||
// Use the whisper_full_...() functions to obtain the text segments
|
||||
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
|
||||
|
||||
struct whisper_full_params {
|
||||
enum whisper_sampling_strategy strategy;
|
||||
|
||||
|
@ -184,6 +189,9 @@ extern "C" {
|
|||
int beam_width;
|
||||
int n_best;
|
||||
} beam_search;
|
||||
|
||||
whisper_new_segment_callback new_segment_callback;
|
||||
void * new_segment_callback_user_data;
|
||||
};
|
||||
|
||||
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
||||
|
@ -212,6 +220,7 @@ extern "C" {
|
|||
|
||||
// Get the token text of the specified token in the specified segment.
|
||||
WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
|
||||
WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
|
||||
|
||||
// Get the probability of the specified token in the specified segment.
|
||||
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
|
||||
|
|
Loading…
Reference in New Issue