convert : support loading vocab from fast tokenizer config (#3633)

* Add HFVocab into convert.py

* Update convert.py

* Update convert.py

* add bytes_to_unicode function

* change add_meta_vocab fucntion

* remove debug code

* remove byte_encoder

* Add newline between classes

* Check tokenizer.json when tokenizer.model is not exist.

* Move transformers dependency to local code

* Add error context with 'raise from'

* Add fast tokenizer option to BpeVocab

* Update convert.py

* Add VocabLoader and remove *Vocab class

* Add transformers dependency

* remove added tokens and check newline token to decide spm or bpe

* Update convert.py

* Add special token type

* Update convert.py

* Update convert.py

* Update convert.py

* Fix typo in convert.py

* Fix when params.n_vocab < tokenizer vocab size

* update vocab class

* change funtion name

* Remove unused variable/functions, add types to class variable and methods, delete blank liens

* fix flake8 warnings

* code style cleanup

* make mypy happy

* change exception

---------

Co-authored-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
wonjun Jang 2023-12-14 17:09:34 +09:00 committed by GitHub
parent 0353a18401
commit 873637afc7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 168 additions and 156 deletions

View file

@ -10,6 +10,7 @@ import itertools
import json import json
import math import math
import mmap import mmap
import os
import pickle import pickle
import re import re
import signal import signal
@ -18,15 +19,15 @@ import sys
import time import time
import zipfile import zipfile
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, Optional, TypeVar, cast
import numpy as np import numpy as np
from sentencepiece import SentencePieceProcessor from sentencepiece import SentencePieceProcessor
import os
if 'NO_LOCAL_GGUF' not in os.environ: if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf import gguf
@ -327,127 +328,138 @@ class Params:
return params return params
# class VocabLoader:
# vocab def __init__(self, params: Params, fname_tokenizer: Path) -> None:
# try:
from transformers import AutoTokenizer
except ImportError as e:
raise ImportError(
"To use VocabLoader, please install the `transformers` package. "
"You can install it with `pip install transformers`."
) from e
class BpeVocab: try:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None: self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), trust_remote_code=True)
self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read()) except ValueError:
added_tokens: dict[str, int] self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), use_fast=False, trust_remote_code=True)
if fname_added_tokens is not None:
# FIXME: Verify that added tokens here _cannot_ overlap with the main vocab. self.added_tokens_dict: OrderedDict[str, int] = OrderedDict()
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
for tok, tokidx in sorted(self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]):
if tokidx >= params.n_vocab or tokidx < self.tokenizer.vocab_size:
continue
self.added_tokens_dict[tok] = tokidx
self.unk_token_id: int = self.tokenizer.unk_token_id
self.specials: dict[str, int] = {
tok: self.tokenizer.get_vocab()[tok]
for tok in self.tokenizer.all_special_tokens
}
self.special_ids: set[int] = set(self.tokenizer.all_special_ids)
self.vocab_size_base: int = self.tokenizer.vocab_size
self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_dict)
self.fname_tokenizer: Path = fname_tokenizer
vocab_file = "tokenizer.model"
path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
if path_candidate is not None:
self.spm = SentencePieceProcessor(str(path_candidate))
print(self.spm.vocab_size(), self.vocab_size_base)
else: else:
# Fall back to trying to find the added tokens in tokenizer.json self.spm = None
tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json'
if not tokenizer_json_file.is_file():
added_tokens = {}
else:
tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8"))
added_tokens = dict(
(item['content'], item['id'])
for item in tokenizer_json.get('added_tokens', [])
# Added tokens here can be duplicates of the main vocabulary.
if item['content'] not in self.bpe_tokenizer)
vocab_size: int = len(self.bpe_tokenizer) def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) tokenizer = self.tokenizer
actual_ids = sorted(added_tokens.values()) reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.get_vocab().items()}
if expected_ids != actual_ids: added_tokens_ids = set(self.added_tokens_dict.values())
expected_end_id = vocab_size + len(actual_ids) - 1
raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}")
items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) for i in range(self.vocab_size_base):
self.added_tokens_list = [text for (text, idx) in items] if i in added_tokens_ids:
self.vocab_size_base: int = vocab_size continue
self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: text = reverse_vocab[i].encode("utf-8")
tokenizer = self.bpe_tokenizer yield text, self.get_token_score(i), self.get_token_type(i)
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.items()}
for i, _ in enumerate(tokenizer): def get_token_type(self, token_id: int) -> gguf.TokenType:
yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL toktype = gguf.TokenType.NORMAL
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: if self.spm is not None and token_id < self.spm.vocab_size():
for text in self.added_tokens_list: if self.spm.is_unknown(token_id):
score = -1000.0
yield text.encode("utf-8"), score, gguf.TokenType.CONTROL
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
yield from self.bpe_tokens()
yield from self.added_tokens()
def __repr__(self) -> str:
return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
class SentencePieceVocab:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
added_tokens: dict[str, int]
if fname_added_tokens is not None:
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
else:
added_tokens = {}
vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
actual_new_ids = sorted(new_tokens.keys())
if expected_new_ids != actual_new_ids:
raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
# Token pieces that were added to the base vocabulary.
self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
self.vocab_size_base = vocab_size
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
tokenizer = self.sentencepiece_tokenizer
for i in range(tokenizer.vocab_size()):
piece = tokenizer.id_to_piece(i)
text: bytes = piece.encode("utf-8")
score: float = tokenizer.get_score(i)
toktype = gguf.TokenType.NORMAL
if tokenizer.is_unknown(i):
toktype = gguf.TokenType.UNKNOWN toktype = gguf.TokenType.UNKNOWN
if tokenizer.is_control(i): if self.spm.is_control(token_id):
toktype = gguf.TokenType.CONTROL
if self.spm.is_unused(token_id):
toktype = gguf.TokenType.UNUSED
if self.spm.is_byte(token_id):
toktype = gguf.TokenType.BYTE
else:
if token_id == self.unk_token_id:
toktype = gguf.TokenType.UNKNOWN
if token_id in self.special_ids:
toktype = gguf.TokenType.CONTROL toktype = gguf.TokenType.CONTROL
# NOTE: I think added_tokens are user defined. return toktype
# ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
# if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
if tokenizer.is_unused(i): def get_token_score(self, token_id: int) -> float:
toktype = gguf.TokenType.UNUSED if self.spm is not None and token_id < self.spm.vocab_size():
if tokenizer.is_byte(i): return cast(float, self.spm.get_score(token_id))
toktype = gguf.TokenType.BYTE return 0.0
yield text, score, toktype
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
for text in self.added_tokens_list:
score = -1000.0 for text in self.added_tokens_dict:
yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED if text in self.specials:
toktype = self.get_token_type(self.specials[text])
score = self.get_token_score(self.specials[text])
else:
toktype = gguf.TokenType.USER_DEFINED
score = -1000.0
yield text.encode("utf-8"), score, toktype
def has_newline_token(self) -> bool:
return '<0x0A>' in self.tokenizer.vocab or '\n' in self.tokenizer.vocab
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
yield from self.sentencepiece_tokens() yield from self.hf_tokens()
yield from self.added_tokens() yield from self.added_tokens()
def get_vocab_type(self) -> str:
path_candidates = []
vocab_file = "tokenizer.model"
path_candidates.append(vocab_file)
path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
if path_candidate is not None:
return "llama"
vocab_file = "vocab.json"
path_candidates.append(vocab_file)
path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
if path_candidate is not None:
return "gpt2"
vocab_file = "tokenizer.json"
path_candidates.append(vocab_file)
path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
if path_candidate:
if not self.has_newline_token():
return "gpt2"
return "llama"
raise FileNotFoundError(
f"Could not find {path_candidates} in {self.fname_tokenizer} or its parent; "
"if it's in another directory, pass the directory as --vocab-dir"
)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>" return f"<VocabLoader with {self.vocab_size_base} base tokens and {len(self.added_tokens_dict)} added tokens>"
Vocab: TypeAlias = 'BpeVocab | SentencePieceVocab' Vocab: TypeAlias = 'VocabLoader'
# #
# data loading # data loading
@ -824,20 +836,27 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
yield result yield result
def check_vocab_size(params: Params, vocab: Vocab) -> None: def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
if params.n_vocab != vocab.vocab_size: if params.n_vocab != vocab.vocab_size:
assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab) if params.n_vocab == vocab.vocab_size:
if params.n_vocab == vocab.vocab_size_base:
print("Ignoring added_tokens.json since model matches vocab size without it.") print("Ignoring added_tokens.json since model matches vocab size without it.")
vocab.added_tokens_list = [] vocab.added_tokens_dict = OrderedDict()
vocab.vocab_size = vocab.vocab_size_base vocab.vocab_size = vocab.vocab_size
return
if pad_vocab and params.n_vocab > vocab.vocab_size:
pad_count = params.n_vocab - vocab.vocab_size
print(f'Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>')
for i in range(1, (params.n_vocab - vocab.vocab_size) + 1):
vocab.added_tokens_dict[f'<dummy{i:05}>'] = -1
vocab.vocab_size = params.n_vocab
return return
msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer}" msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer}"
if vocab.fname_added_tokens is not None:
msg += f" combined with {vocab.fname_added_tokens}"
msg += f" has {vocab.vocab_size})." msg += f" has {vocab.vocab_size})."
if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20 and vocab.fname_added_tokens is None: if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20:
msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})." msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
if vocab.vocab_size < params.n_vocab:
msg += " Possibly try using the --padvocab option."
raise Exception(msg) raise Exception(msg)
@ -901,12 +920,8 @@ class OutputFile:
scores.append(score) scores.append(score)
toktypes.append(toktype) toktypes.append(toktype)
if isinstance(vocab, SentencePieceVocab): vocab_type = vocab.get_vocab_type()
self.gguf.add_tokenizer_model("llama") self.gguf.add_tokenizer_model(vocab_type)
elif isinstance(vocab, BpeVocab):
self.gguf.add_tokenizer_model("gpt2")
else:
raise ValueError('Unknown vocab type: Not BpeVocab or SentencePieceVocab')
self.gguf.add_token_list(tokens) self.gguf.add_token_list(tokens)
self.gguf.add_token_scores(scores) self.gguf.add_token_scores(scores)
self.gguf.add_token_types(toktypes) self.gguf.add_token_types(toktypes)
@ -932,8 +947,12 @@ class OutputFile:
self.gguf.close() self.gguf.close()
@staticmethod @staticmethod
def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None: def write_vocab_only(
check_vocab_size(params, vocab) fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
pad_vocab: bool = False,
) -> None:
check_vocab_size(params, vocab, pad_vocab = pad_vocab)
of = OutputFile(fname_out, endianess=endianess) of = OutputFile(fname_out, endianess=endianess)
@ -960,8 +979,13 @@ class OutputFile:
return dt.quantize(arr) return dt.quantize(arr)
@staticmethod @staticmethod
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None: def write_all(
check_vocab_size(params, vocab) fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
concurrency: int = DEFAULT_CONCURRENCY,
endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
pad_vocab: bool = False,
) -> None:
check_vocab_size(params, vocab, pad_vocab = pad_vocab)
of = OutputFile(fname_out, endianess=endianess) of = OutputFile(fname_out, endianess=endianess)
@ -1119,35 +1143,17 @@ def load_some_model(path: Path) -> ModelPlus:
return model_plus return model_plus
def load_vocab(path: Path, vocabtype: str | None) -> Vocab: def find_vocab_file_path(path: Path, vocab_file: str) -> Optional[Path]:
# Be extra-friendly and accept either a file or a directory. Also, if it's path2 = path / vocab_file
# a directory, it might be the model directory, and tokenizer.model might # Use `.parent` instead of /.. to handle the symlink case better.
# be in the parent of that. path3 = path.parent / vocab_file
if path.is_dir():
vocab_file = "tokenizer.model"
if vocabtype == 'bpe':
vocab_file = "vocab.json"
path2 = path / vocab_file
# Use `.parent` instead of /.. to handle the symlink case better.
path3 = path.parent / vocab_file
if path2.exists():
path = path2
elif path3.exists():
path = path3
else:
raise FileNotFoundError(
f"Could not find {vocab_file} in {path} or its parent; "
"if it's in another directory, pass the directory as --vocab-dir")
print(f"Loading vocab file '{path}', type '{vocabtype}'") if path2.exists():
return path2
if path3.exists():
return path3
added_tokens_path = path.parent / "added_tokens.json" return None
if vocabtype == "bpe":
return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None)
elif vocabtype == "spm":
return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
else:
raise ValueError(f"Unsupported vocabulary type {vocabtype}")
def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path: def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
@ -1185,11 +1191,11 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)") parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin, *.safetensors)") parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY) parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine") parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
parser.add_argument("--padvocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
args = parser.parse_args(args_in) args = parser.parse_args(args_in)
if args.dump_single: if args.dump_single:
@ -1232,12 +1238,13 @@ def main(args_in: list[str] | None = None) -> None:
if not args.outfile: if not args.outfile:
raise ValueError("need --outfile if using --vocab-only") raise ValueError("need --outfile if using --vocab-only")
# FIXME: Try to respect vocab_dir somehow? # FIXME: Try to respect vocab_dir somehow?
vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype) vocab = VocabLoader(params, args.vocab_dir or args.model)
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
load_merges = args.vocabtype == 'bpe', load_merges = True,
n_vocab = vocab.vocab_size) n_vocab = vocab.vocab_size)
outfile = args.outfile outfile = args.outfile
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab) OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
endianess = endianess, pad_vocab = args.padvocab)
print(f"Wrote {outfile}") print(f"Wrote {outfile}")
return return
@ -1245,12 +1252,15 @@ def main(args_in: list[str] | None = None) -> None:
vocab = model_plus.vocab vocab = model_plus.vocab
else: else:
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
vocab = load_vocab(vocab_dir, args.vocabtype) vocab = VocabLoader(params, vocab_dir)
# FIXME: Try to respect vocab_dir somehow? # FIXME: Try to respect vocab_dir somehow?
print(f"Vocab info: {vocab}")
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
load_merges = args.vocabtype == 'bpe', load_merges = True,
n_vocab = vocab.vocab_size) n_vocab = vocab.vocab_size)
print(f"Special vocab info: {special_vocab}")
model = model_plus.model model = model_plus.model
model = convert_model_names(model, params) model = convert_model_names(model, params)
ftype = pick_output_type(model, args.outtype) ftype = pick_output_type(model, args.outtype)
@ -1260,7 +1270,8 @@ def main(args_in: list[str] | None = None) -> None:
params.ftype = ftype params.ftype = ftype
print(f"Writing {outfile}, format {ftype}") print(f"Writing {outfile}, format {ftype}")
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency, endianess=endianess) OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
concurrency = args.concurrency, endianess = endianess, pad_vocab = args.padvocab)
print(f"Wrote {outfile}") print(f"Wrote {outfile}")

View file

@ -1,3 +1,4 @@
numpy==1.24.4 numpy==1.24.4
sentencepiece==0.1.98 sentencepiece==0.1.98
transformers>=4.34.0
gguf>=0.1.0 gguf>=0.1.0