From 873637afc7924f435ac44c067630a28e82eefa7b Mon Sep 17 00:00:00 2001
From: wonjun Jang
Date: Thu, 14 Dec 2023 17:09:34 +0900
Subject: [PATCH] convert : support loading vocab from fast tokenizer config (#3633)

* Add HFVocab into convert.py
* Update convert.py
* Update convert.py
* add bytes_to_unicode function
* change add_meta_vocab function
* remove debug code
* remove byte_encoder
* Add newline between classes
* Check tokenizer.json when tokenizer.model does not exist.
* Move transformers dependency to local code
* Add error context with 'raise from'
* Add fast tokenizer option to BpeVocab
* Update convert.py
* Add VocabLoader and remove *Vocab class
* Add transformers dependency
* remove added tokens and check newline token to decide spm or bpe
* Update convert.py
* Add special token type
* Update convert.py
* Update convert.py
* Update convert.py
* Fix typo in convert.py
* Fix when params.n_vocab < tokenizer vocab size
* update vocab class
* change function name
* Remove unused variables/functions, add types to class variables and methods, delete blank lines
* fix flake8 warnings
* code style cleanup
* make mypy happy
* change exception

---------

Co-authored-by: Jared Van Bortel
---
 convert.py       | 323 ++++++++++++++++++++++++-----------------------
 requirements.txt |   1 +
 2 files changed, 168 insertions(+), 156 deletions(-)

diff --git a/convert.py b/convert.py
index e4b69d172..7a3cd615e 100755
--- a/convert.py
+++ b/convert.py
@@ -10,6 +10,7 @@ import itertools
 import json
 import math
 import mmap
+import os
 import pickle
 import re
 import signal
@@ -18,15 +19,15 @@ import sys
 import time
 import zipfile
 from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
+from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, Optional, TypeVar, cast

 import numpy as np
 from sentencepiece import SentencePieceProcessor

-import os
 if 'NO_LOCAL_GGUF' not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
 import gguf
@@ -327,127 +328,138 @@ class Params:
         return params


-#
-# vocab
-#
+class VocabLoader:
+    def __init__(self, params: Params, fname_tokenizer: Path) -> None:
+        try:
+            from transformers import AutoTokenizer
+        except ImportError as e:
+            raise ImportError(
+                "To use VocabLoader, please install the `transformers` package. "
+                "You can install it with `pip install transformers`."
+            ) from e

-class BpeVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
-        self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
-        added_tokens: dict[str, int]
-        if fname_added_tokens is not None:
-            # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
-            added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), trust_remote_code=True)
+        except ValueError:
+            self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), use_fast=False, trust_remote_code=True)
+
+        self.added_tokens_dict: OrderedDict[str, int] = OrderedDict()
+
+        for tok, tokidx in sorted(self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]):
+            if tokidx >= params.n_vocab or tokidx < self.tokenizer.vocab_size:
+                continue
+
+            self.added_tokens_dict[tok] = tokidx
+
+        self.unk_token_id: int = self.tokenizer.unk_token_id
+        self.specials: dict[str, int] = {
+            tok: self.tokenizer.get_vocab()[tok]
+            for tok in self.tokenizer.all_special_tokens
+        }
+        self.special_ids: set[int] = set(self.tokenizer.all_special_ids)
+        self.vocab_size_base: int = self.tokenizer.vocab_size
+        self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_dict)
+        self.fname_tokenizer: Path = fname_tokenizer
+
+        vocab_file = "tokenizer.model"
+        path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
+        if path_candidate is not None:
+            self.spm = SentencePieceProcessor(str(path_candidate))
+            print(self.spm.vocab_size(), self.vocab_size_base)
         else:
-            # Fall back to trying to find the added tokens in tokenizer.json
-            tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json'
-            if not tokenizer_json_file.is_file():
-                added_tokens = {}
-            else:
-                tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8"))
-                added_tokens = dict(
-                    (item['content'], item['id'])
-                    for item in tokenizer_json.get('added_tokens', [])
-                    # Added tokens here can be duplicates of the main vocabulary.
-                    if item['content'] not in self.bpe_tokenizer)
+            self.spm = None

-        vocab_size: int = len(self.bpe_tokenizer)
-        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
-        actual_ids = sorted(added_tokens.values())
-        if expected_ids != actual_ids:
-            expected_end_id = vocab_size + len(actual_ids) - 1
-            raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}")
+    def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+        tokenizer = self.tokenizer
+        reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.get_vocab().items()}
+        added_tokens_ids = set(self.added_tokens_dict.values())

-        items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
-        self.added_tokens_list = [text for (text, idx) in items]
-        self.vocab_size_base: int = vocab_size
-        self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
-        self.fname_tokenizer = fname_tokenizer
-        self.fname_added_tokens = fname_added_tokens
+        for i in range(self.vocab_size_base):
+            if i in added_tokens_ids:
+                continue

-    def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-        tokenizer = self.bpe_tokenizer
-        reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.items()}
+            text = reverse_vocab[i].encode("utf-8")
+            yield text, self.get_token_score(i), self.get_token_type(i)

-        for i, _ in enumerate(tokenizer):
-            yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL
+    def get_token_type(self, token_id: int) -> gguf.TokenType:
+        toktype = gguf.TokenType.NORMAL

-    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-        for text in self.added_tokens_list:
-            score = -1000.0
-            yield text.encode("utf-8"), score, gguf.TokenType.CONTROL
-
-    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-        yield from self.bpe_tokens()
-        yield from self.added_tokens()
-
-    def __repr__(self) -> str:
-        return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
-
-
-class SentencePieceVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
-        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
-        added_tokens: dict[str, int]
-        if fname_added_tokens is not None:
-            added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
-        else:
-            added_tokens = {}
-
-        vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
-
-        new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
-        expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
-        actual_new_ids = sorted(new_tokens.keys())
-
-        if expected_new_ids != actual_new_ids:
-            raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
-
-        # Token pieces that were added to the base vocabulary.
-        self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
-        self.vocab_size_base = vocab_size
-        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
-        self.fname_tokenizer = fname_tokenizer
-        self.fname_added_tokens = fname_added_tokens
-
-    def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-        tokenizer = self.sentencepiece_tokenizer
-        for i in range(tokenizer.vocab_size()):
-            piece = tokenizer.id_to_piece(i)
-            text: bytes = piece.encode("utf-8")
-            score: float = tokenizer.get_score(i)
-
-            toktype = gguf.TokenType.NORMAL
-            if tokenizer.is_unknown(i):
+        if self.spm is not None and token_id < self.spm.vocab_size():
+            if self.spm.is_unknown(token_id):
                 toktype = gguf.TokenType.UNKNOWN
-            if tokenizer.is_control(i):
+            if self.spm.is_control(token_id):
+                toktype = gguf.TokenType.CONTROL
+            if self.spm.is_unused(token_id):
+                toktype = gguf.TokenType.UNUSED
+            if self.spm.is_byte(token_id):
+                toktype = gguf.TokenType.BYTE
+        else:
+            if token_id == self.unk_token_id:
+                toktype = gguf.TokenType.UNKNOWN
+            if token_id in self.special_ids:
                 toktype = gguf.TokenType.CONTROL
-            # NOTE: I think added_tokens are user defined.
-            # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
-            # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
+        return toktype

-            if tokenizer.is_unused(i):
-                toktype = gguf.TokenType.UNUSED
-            if tokenizer.is_byte(i):
-                toktype = gguf.TokenType.BYTE
-
-            yield text, score, toktype
+    def get_token_score(self, token_id: int) -> float:
+        if self.spm is not None and token_id < self.spm.vocab_size():
+            return cast(float, self.spm.get_score(token_id))
+        return 0.0

     def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-        for text in self.added_tokens_list:
-            score = -1000.0
-            yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
+        for text in self.added_tokens_dict:
+            if text in self.specials:
+                toktype = self.get_token_type(self.specials[text])
+                score = self.get_token_score(self.specials[text])
+            else:
+                toktype = gguf.TokenType.USER_DEFINED
+                score = -1000.0
+
+            yield text.encode("utf-8"), score, toktype
+
+    def has_newline_token(self) -> bool:
+        return '<0x0A>' in self.tokenizer.vocab or '\n' in self.tokenizer.vocab

     def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-        yield from self.sentencepiece_tokens()
+        yield from self.hf_tokens()
         yield from self.added_tokens()

+    def get_vocab_type(self) -> str:
+        path_candidates = []
+        vocab_file = "tokenizer.model"
+        path_candidates.append(vocab_file)
+        path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
+        if path_candidate is not None:
+            return "llama"
+
+        vocab_file = "vocab.json"
+        path_candidates.append(vocab_file)
+        path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
+        if path_candidate is not None:
+            return "gpt2"
+
+        vocab_file = "tokenizer.json"
+        path_candidates.append(vocab_file)
+        path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
+        if path_candidate:
+            if not self.has_newline_token():
+                return "gpt2"
+            return "llama"
+
+        raise FileNotFoundError(
+            f"Could not find {path_candidates} in {self.fname_tokenizer} or its parent; "
+            "if it's in another directory, pass the directory as --vocab-dir"
+        )
+
     def __repr__(self) -> str:
-        return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
+        return f"<VocabLoader with {self.vocab_size_base} base tokens and {len(self.added_tokens_dict)} added tokens>"


-Vocab: TypeAlias = 'BpeVocab | SentencePieceVocab'
+Vocab: TypeAlias = 'VocabLoader'
+

 #
 # data loading
@@ -824,20 +836,27 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
         yield result


-def check_vocab_size(params: Params, vocab: Vocab) -> None:
+def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
     if params.n_vocab != vocab.vocab_size:
-        assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
         if params.n_vocab == vocab.vocab_size_base:
             print("Ignoring added_tokens.json since model matches vocab size without it.")
-            vocab.added_tokens_list = []
-            vocab.vocab_size = vocab.vocab_size_base
+            vocab.added_tokens_dict = OrderedDict()
+            vocab.vocab_size = vocab.vocab_size_base
+            return
+
+        if pad_vocab and params.n_vocab > vocab.vocab_size:
+            pad_count = params.n_vocab - vocab.vocab_size
+            print(f'Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>')
+            for i in range(1, (params.n_vocab - vocab.vocab_size) + 1):
+                vocab.added_tokens_dict[f'<dummy{i:05}>'] = -1
+            vocab.vocab_size = params.n_vocab
+            return

     msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer}"
-    if vocab.fname_added_tokens is not None:
-        msg += f" combined with {vocab.fname_added_tokens}"
     msg += f" has {vocab.vocab_size})."
-    if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20 and vocab.fname_added_tokens is None:
+    if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20:
         msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
+    if vocab.vocab_size < params.n_vocab:
+        msg += " Possibly try using the --padvocab option."
     raise Exception(msg)


@@ -901,12 +920,8 @@ class OutputFile:
             scores.append(score)
             toktypes.append(toktype)

-        if isinstance(vocab, SentencePieceVocab):
-            self.gguf.add_tokenizer_model("llama")
-        elif isinstance(vocab, BpeVocab):
-            self.gguf.add_tokenizer_model("gpt2")
-        else:
-            raise ValueError('Unknown vocab type: Not BpeVocab or SentencePieceVocab')
+        vocab_type = vocab.get_vocab_type()
+        self.gguf.add_tokenizer_model(vocab_type)
         self.gguf.add_token_list(tokens)
         self.gguf.add_token_scores(scores)
         self.gguf.add_token_types(toktypes)
@@ -932,8 +947,12 @@ class OutputFile:
         self.gguf.close()

     @staticmethod
-    def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
-        check_vocab_size(params, vocab)
+    def write_vocab_only(
+        fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
+        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
+        pad_vocab: bool = False,
+    ) -> None:
+        check_vocab_size(params, vocab, pad_vocab = pad_vocab)

         of = OutputFile(fname_out, endianess=endianess)

@@ -960,8 +979,13 @@ class OutputFile:
         return dt.quantize(arr)

     @staticmethod
-    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
-        check_vocab_size(params, vocab)
+    def write_all(
+        fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
+        concurrency: int = DEFAULT_CONCURRENCY,
+        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
+        pad_vocab: bool = False,
+    ) -> None:
+        check_vocab_size(params, vocab, pad_vocab = pad_vocab)

         of = OutputFile(fname_out, endianess=endianess)

@@ -1119,35 +1143,17 @@ def load_some_model(path: Path) -> ModelPlus:
     return model_plus


-def load_vocab(path: Path, vocabtype: str | None) -> Vocab:
-    # Be extra-friendly and accept either a file or a directory. Also, if it's
-    # a directory, it might be the model directory, and tokenizer.model might
-    # be in the parent of that.
-    if path.is_dir():
-        vocab_file = "tokenizer.model"
-        if vocabtype == 'bpe':
-            vocab_file = "vocab.json"
-        path2 = path / vocab_file
-        # Use `.parent` instead of /.. to handle the symlink case better.
-        path3 = path.parent / vocab_file
-        if path2.exists():
-            path = path2
-        elif path3.exists():
-            path = path3
-        else:
-            raise FileNotFoundError(
-                f"Could not find {vocab_file} in {path} or its parent; "
-                "if it's in another directory, pass the directory as --vocab-dir")
+def find_vocab_file_path(path: Path, vocab_file: str) -> Optional[Path]:
+    path2 = path / vocab_file
+    # Use `.parent` instead of /.. to handle the symlink case better.
+    path3 = path.parent / vocab_file

-    print(f"Loading vocab file '{path}', type '{vocabtype}'")
+    if path2.exists():
+        return path2
+    if path3.exists():
+        return path3

-    added_tokens_path = path.parent / "added_tokens.json"
-    if vocabtype == "bpe":
-        return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None)
-    elif vocabtype == "spm":
-        return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
-    else:
-        raise ValueError(f"Unsupported vocabulary type {vocabtype}")
+    return None


 def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
@@ -1185,11 +1191,11 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
     parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
-    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin, *.safetensors)")
-    parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
+    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
     parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
     parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
     parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
+    parser.add_argument("--padvocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")

     args = parser.parse_args(args_in)
     if args.dump_single:
@@ -1232,12 +1238,13 @@ def main(args_in: list[str] | None = None) -> None:
         if not args.outfile:
             raise ValueError("need --outfile if using --vocab-only")
         # FIXME: Try to respect vocab_dir somehow?
-        vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
+        vocab = VocabLoader(params, args.vocab_dir or args.model)
         special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
-            load_merges = args.vocabtype == 'bpe',
+            load_merges = True,
            n_vocab = vocab.vocab_size)
         outfile = args.outfile
-        OutputFile.write_vocab_only(outfile, params, vocab, special_vocab)
+        OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
+                                    endianess = endianess, pad_vocab = args.padvocab)
         print(f"Wrote {outfile}")
         return

@@ -1245,12 +1252,15 @@ def main(args_in: list[str] | None = None) -> None:
         vocab = model_plus.vocab
     else:
         vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
-        vocab = load_vocab(vocab_dir, args.vocabtype)
+        vocab = VocabLoader(params, vocab_dir)
+    # FIXME: Try to respect vocab_dir somehow?
+    print(f"Vocab info: {vocab}")
     special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
-        load_merges = args.vocabtype == 'bpe',
+        load_merges = True,
         n_vocab = vocab.vocab_size)
+    print(f"Special vocab info: {special_vocab}")

     model = model_plus.model
     model = convert_model_names(model, params)
     ftype = pick_output_type(model, args.outtype)
@@ -1260,7 +1270,8 @@ def main(args_in: list[str] | None = None) -> None:
     params.ftype = ftype
     print(f"Writing {outfile}, format {ftype}")

-    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency, endianess=endianess)
+    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
+                         concurrency = args.concurrency, endianess = endianess, pad_vocab = args.padvocab)
     print(f"Wrote {outfile}")

diff --git a/requirements.txt b/requirements.txt
index 81c909d0b..badfec3be 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 numpy==1.24.4
 sentencepiece==0.1.98
+transformers>=4.34.0
 gguf>=0.1.0
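
Usage sketch (not part of the patch itself; the paths below are illustrative).
With this change the converter infers the tokenizer type from the files it
finds (tokenizer.model -> "llama", vocab.json -> "gpt2", tokenizer.json ->
decided by the newline-token heuristic), so the old --vocabtype flag is gone;
--padvocab fills a too-small vocab with <dummyNNNNN> placeholder tokens:

    python convert.py path/to/model-dir --outfile model.gguf --padvocab

The tokenizer probe that VocabLoader.__init__ performs can also be reproduced
on its own, assuming `transformers` is installed and the directory contains a
fast-tokenizer config (tokenizer.json):

    from transformers import AutoTokenizer

    model_dir = "path/to/model-dir"  # illustrative path, not a real checkout
    # Prefer the fast tokenizer; fall back to the slow implementation when
    # the fast config cannot be loaded, mirroring VocabLoader.__init__.
    try:
        tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    except ValueError:
        tok = AutoTokenizer.from_pretrained(model_dir, use_fast=False, trust_remote_code=True)
    print(tok.vocab_size, len(tok.get_added_vocab()))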