diff --git a/whisper.cpp b/whisper.cpp index 67451dc..2e8ee87 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -429,6 +429,12 @@ struct whisper_context { int32_t exp_n_audio_ctx; // 0 - use default }; +template +static void read_safe(std::ifstream& fin, T& dest) +{ + fin.read((char*)& dest, sizeof(T)); +} + // load the model from a ggml file // // file format: @@ -455,7 +461,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx // verify magic { uint32_t magic; - fin.read((char *) &magic, sizeof(magic)); + read_safe(fin, magic); if (magic != 0x67676d6c) { fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); return false; @@ -466,17 +472,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx { auto & hparams = model.hparams; - fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); - fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx)); - fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state)); - fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head)); - fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer)); - fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx)); - fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state)); - fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head)); - fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer)); - fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels)); - fin.read((char *) &hparams.f16, sizeof(hparams.f16)); + read_safe(fin, hparams.n_vocab); + read_safe(fin, hparams.n_audio_ctx); + read_safe(fin, hparams.n_audio_state); + read_safe(fin, hparams.n_audio_head); + read_safe(fin, hparams.n_audio_layer); + read_safe(fin, hparams.n_text_ctx); + read_safe(fin, hparams.n_text_state); + read_safe(fin, hparams.n_text_head); + read_safe(fin, hparams.n_text_layer); + read_safe(fin, hparams.n_mels); + read_safe(fin, hparams.f16); assert(hparams.n_text_state == hparams.n_audio_state); @@ -524,8 +530,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx { auto & filters = wctx.model.filters; - fin.read((char *) &filters.n_mel, sizeof(filters.n_mel)); - fin.read((char *) &filters.n_fft, sizeof(filters.n_fft)); + read_safe(fin, filters.n_mel); + read_safe(fin, filters.n_fft); filters.data.resize(filters.n_mel * filters.n_fft); fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float)); @@ -534,7 +540,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx // load vocab { int32_t n_vocab = 0; - fin.read((char *) &n_vocab, sizeof(n_vocab)); + read_safe(fin, n_vocab); //if (n_vocab != model.hparams.n_vocab) { // fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", @@ -545,10 +551,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx std::string word; for (int i = 0; i < n_vocab; i++) { uint32_t len; - fin.read((char *) &len, sizeof(len)); + read_safe(fin, len); - word.resize(len); - fin.read((char *) word.data(), len); + std::vector tmp(len); // create a buffer + fin.read( &tmp[0], tmp.size() ); // read to buffer + word.assign(&tmp[0], tmp.size()); vocab.token_to_id[word] = i; vocab.id_to_token[i] = word; @@ -998,9 +1005,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx int32_t length; int32_t ftype; - fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); - fin.read(reinterpret_cast(&length), sizeof(length)); - fin.read(reinterpret_cast(&ftype), sizeof(ftype)); + read_safe(fin, n_dims); + read_safe(fin, length); + read_safe(fin, ftype); if (fin.eof()) { break; @@ -1009,12 +1016,14 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx int32_t nelements = 1; int32_t ne[3] = { 1, 1, 1 }; for (int i = 0; i < n_dims; ++i) { - fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + read_safe(fin, ne[i]); nelements *= ne[i]; } - std::string name(length, 0); - fin.read(&name[0], length); + std::string name; + std::vector tmp(length); // create a buffer + fin.read( &tmp[0], tmp.size() ); // read to buffer + name.assign(&tmp[0], tmp.size()); if (model.tensors.find(name.data()) == model.tensors.end()) { fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());