whisper.cpp/models/convert-whisper-to-openvino.py
st-gr eb23f4ef16
openvino : fix convert-whisper-to-openvino.py (#1890)
Fix issue: Conversion from Whisper to OpenVino failed #1870

convert-whisper-to-openvino.py stopped working with OpenVINO version 2023.0.0-10926-b4452d56304-releases/2023/0.

Error was: TypeError: load(): incompatible function arguments. The following argument types are supported:
    1. (self: openvino._pyopenvino.FrontEnd, path: object) -> ov::frontend::InputModel

Tested successfully with a large-v3 conversion.

Co-authored-by: Stefan Grundmann <grundmanns@sandiego.gov>
2024-02-22 15:11:35 +02:00

61 lines
1.9 KiB
Python

import argparse
import torch
from whisper import load_model
import os
from openvino.tools import mo
from openvino.frontend import FrontEndManager
from openvino.runtime import serialize
import shutil
def convert_encoder(hparams, encoder, mname):
    """Export a Whisper encoder to OpenVINO IR via an intermediate ONNX file.

    The encoder is exported to ONNX inside a scratch ``onnx_encoder`` folder
    located next to this script, loaded and converted with the OpenVINO ONNX
    frontend, and serialized as ``ggml-<mname>-encoder-openvino.xml`` next to
    this script. The scratch folder is removed afterwards, whether or not the
    conversion succeeded.

    Args:
        hparams: Object exposing ``n_mels``; used to size the example input.
        encoder: PyTorch module to export. It is switched to eval mode.
        mname: Model name embedded in the output file name.
    """
    encoder.eval()

    # Example input handed to the ONNX exporter: (1, n_mels, 3000) zeros.
    mel = torch.zeros((1, hparams.n_mels, 3000))

    script_dir = os.path.dirname(__file__)
    onnx_folder = os.path.join(script_dir, "onnx_encoder")

    # Create a directory to store the ONNX model and any other collateral
    # that is saved during the ONNX export procedure.
    os.makedirs(onnx_folder, exist_ok=True)

    try:
        onnx_path = os.path.join(onnx_folder, "whisper_encoder.onnx")

        # Export the PyTorch model to ONNX
        torch.onnx.export(
            encoder,
            mel,
            onnx_path,
            input_names=["mel"],
            output_names=["output_features"],
        )

        # Convert ONNX to OpenVINO IR format using the frontend
        fem = FrontEndManager()
        onnx_fe = fem.load_by_framework("onnx")
        onnx_model = onnx_fe.load(onnx_path)
        ov_model = onnx_fe.convert(onnx_model)

        # Serialize the OpenVINO model to XML and BIN files
        xml_path = os.path.join(
            script_dir, "ggml-" + mname + "-encoder-openvino.xml"
        )
        serialize(ov_model, xml_path=xml_path)
    finally:
        # Cleanup: runs on failure too, so a broken export does not leave
        # the scratch folder (and its possibly large contents) behind.
        if os.path.isdir(onnx_folder):
            shutil.rmtree(onnx_folder)
if __name__ == "__main__":
    # Model names accepted by --model; anything else is rejected before
    # the model is loaded.
    known_models = (
        "tiny",
        "tiny.en",
        "base",
        "base.en",
        "small",
        "small.en",
        "medium",
        "medium.en",
        "large-v1",
        "large-v2",
        "large-v3",
    )

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model",
        type=str,
        required=True,
        help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3)",
    )
    options = cli.parse_args()
    model_name = options.model

    if model_name not in known_models:
        raise ValueError("Invalid model name")

    # Load the requested model and move it to the CPU.
    model = load_model(model_name).cpu()

    # Convert encoder to onnx (and from there to OpenVINO IR)
    convert_encoder(model.dims, model.encoder, model_name)