diff --git a/whisper.cpp b/whisper.cpp
index da35456..1f64b35 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -14,6 +14,7 @@
 #include <string>
 #include <thread>
 #include <vector>
+#include <regex>
 
 #define USE_FLASH_ATTN
 //#define USE_FLASH_FF
@@ -2161,6 +2162,71 @@ static bool log_mel_spectrogram(
     return true;
 }
 
+// split text into tokens
+//
+// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
+//
+// Regex (Python):
+// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+//
+// Regex (C++):
+// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
+//
+static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
+    std::vector<std::string> words;
+
+    // first split the text into words
+    {
+        std::string str = text;
+        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+
+        std::regex re(pat);
+        std::smatch m;
+
+        while (std::regex_search(str, m, re)) {
+            for (auto x : m) {
+                words.push_back(x);
+            }
+            str = m.suffix();
+        }
+    }
+
+    // then find the longest tokens that form the words (greedy longest-match)
+    std::vector<whisper_vocab::id> tokens;
+    for (const auto & word : words) {
+        if (word.size() == 0) continue;
+
+        int i = 0;
+        int n = word.size();
+        while (i < n) {
+            int j = n;
+            bool found = false;
+            while (j > i) {
+                auto it = vocab.token_to_id.find(word.substr(i, j-i));
+                if (it != vocab.token_to_id.end()) {
+                    tokens.push_back(it->second);
+                    i = j;
+                    found = true;
+                    break;
+                }
+                --j;
+            }
+            if (!found) {
+                // no multi-character token matches - fall back to a single character
+                auto sub = word.substr(i, 1);
+                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
+                    tokens.push_back(vocab.token_to_id.at(sub));
+                } else {
+                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
+                }
+                ++i;
+            }
+        }
+    }
+
+    return tokens;
+}
+
 //
 // interface implementation
 //
@@ -2291,6 +2357,21 @@ struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx,
     return res;
 }
 
+int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
+    const auto res = tokenize(ctx->vocab, text);
+
+    if ((int) res.size() > n_max_tokens) {
+        fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
+        return -1;
+    }
+
+    for (int i = 0; i < (int) res.size(); i++) {
+        tokens[i] = res[i];
+    }
+
+    return (int) res.size();
+}
+
 int whisper_lang_id(const char * lang) {
     if (!g_lang.count(lang)) {
         fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
diff --git a/whisper.h b/whisper.h
index def77d4..a28a3b7 100644
--- a/whisper.h
+++ b/whisper.h
@@ -139,6 +139,17 @@ extern "C" {
     WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
     WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
 
+    // Convert the provided text into tokens.
+    // The tokens pointer must be large enough to hold the resulting tokens.
+    // Returns the number of tokens on success, no more than n_max_tokens.
+    // Returns -1 on failure.
+    // TODO: not sure if correct
+    WHISPER_API int whisper_tokenize(
+            struct whisper_context * ctx,
+            const char * text,
+            whisper_token * tokens,
+            int n_max_tokens);
+
     // Return the id of the specified language, returns -1 if not found
     WHISPER_API int whisper_lang_id(const char * lang);
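
The matching loop in tokenize() is a greedy longest-match: at each position it takes the longest substring of the word that exists in the vocabulary, falling back to a single character when nothing longer matches. A minimal, self-contained sketch of the same idea on a toy vocabulary follows; the vocabulary entries and input word are made up for illustration and do not come from whisper.cpp.

// greedy longest-match tokenization on a toy vocabulary (illustrative only)
#include <cstdio>
#include <map>
#include <string>
#include <vector>

int main() {
    const std::map<std::string, int> vocab = {
        {"a", 0}, {"b", 1}, {"c", 2}, {"ab", 3}, {"abc", 4},
    };

    const std::string word = "abcb";
    std::vector<int> tokens;

    int i = 0;
    const int n = (int) word.size();
    while (i < n) {
        int j = n;
        bool found = false;
        while (j > i) {
            // try the longest remaining substring first, then shrink
            auto it = vocab.find(word.substr(i, j - i));
            if (it != vocab.end()) {
                tokens.push_back(it->second);
                i = j;
                found = true;
                break;
            }
            --j;
        }
        if (!found) {
            ++i; // skip characters not in the vocabulary
        }
    }

    // "abcb" -> "abc" (4), "b" (1)
    for (int t : tokens) printf("%d ", t);
    printf("\n");
    return 0;
}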
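
For reference, a sketch of how a caller might use the new whisper_tokenize() API. It assumes the whisper_init()/whisper_free() entry points that whisper.h exposes at this point in the project; the model path and the 32-token buffer size are placeholders, not values from this patch.

// usage sketch for whisper_tokenize() - assumes whisper.h as patched above
#include <cstdio>

#include "whisper.h"

int main() {
    // placeholder model path - any ggml Whisper model file works here
    struct whisper_context * ctx = whisper_init("models/ggml-base.en.bin");
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // tokenize a short string; 32 is an arbitrary upper bound
    whisper_token tokens[32];
    const int n = whisper_tokenize(ctx, " hello world", tokens, 32);
    if (n < 0) {
        fprintf(stderr, "text produced more than 32 tokens\n");
    } else {
        for (int i = 0; i < n; i++) {
            printf("%d ", tokens[i]);
        }
        printf("\n");
    }

    whisper_free(ctx);
    return 0;
}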