convert : various script cleanups/fixes + merges and special token handling (#2842)

* convert: Fix permute calls and method/func definitions

* Cleanups for gguf-py

* Minor type cleanups.

* Initial implementation of handling merges and special tokens

* convert: Handle special tokens and merges in vocab only mode

convert: Vocab only mode no longer requires loading model tensors

* gguf: Refactor tensor name mapping

* convert: Fix type hint for special_token_types in SpecialVocab

* Use common special vocab handling in various conversion scripts

* First pass at implementing suggested changes

* Second pass

* gguf: SpecialVocab: Fix issue with special token content not in a dict

gguf: SpecialVocab: Allow skipping handling of merges

* convert-falcon-hf-to-gguf: Support --vocab-only option, bail out if no tokenizer.json

* convert-gptneox-hf-to-gguf and convert: Only handle merges for BPE tokenizer

* gguf: SpecialVocab: Actually set load_merges in object

* Uniform args parsing and vocab only mode for convert examples

* convert.py: Set gpt2 as tokenizer model when using BPE

* Squish last type warning in gguf.py - yay!
Kerfuffle 2023-08-30 02:25:50 -06:00 committed by GitHub
parent ad9ddcff6e
commit dc07dc492e
10 changed files with 728 additions and 748 deletions
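As a quick orientation before the per-file diffs: every conversion script below ends up using the same special-vocab helper instead of its own bos/eos/unk/sep/pad lookups in config.json and tokenizer_config.json. A minimal sketch of that shared pattern, assuming only the calls visible in the diffs (dir_model and gguf_writer are the per-script variables; SpecialVocab's internals live in gguf-py and are not part of this excerpt):

    import gguf

    # BPE-based models (Falcon, GPT-NeoX) also pull in the tokenizer merges;
    # the sentencepiece-based LLaMA scripts construct SpecialVocab(dir_model) without load_merges.
    special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
    special_vocab.add_to_gguf(gguf_writer)

The example scripts also gain one uniform argparse interface (positional model path and ftype, plus --vocab-only and --outfile), so a vocab-only GGUF can be written without loading any model tensors.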

convert-falcon-hf-to-gguf.py

@@ -8,6 +8,7 @@ import struct
 import json
 import numpy as np
 import torch
+import argparse

 from typing import Any, List
 from pathlib import Path
@@ -32,11 +33,10 @@ def bytes_to_unicode():
         bs.append(b)
         cs.append(2**8+n)
         n += 1
-    cs = [chr(n) for n in cs]
-    return dict(zip(bs, cs))
+    return dict(zip(bs, (chr(n) for n in cs)))


-def count_model_parts(dir_model: str) -> int:
+def count_model_parts(dir_model: Path) -> int:
     num_parts = 0
     for filename in os.listdir(dir_model):
         if filename.startswith("pytorch_model-"):
@@ -47,17 +47,22 @@ def count_model_parts(dir_model: str) -> int:
     return num_parts


-if len(sys.argv) < 3:
-    print(f"Usage: python {sys.argv[0]} dir-model ftype\n")
-    print(" ftype == 0 -> float32")
-    print(" ftype == 1 -> float16")
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Convert a Falcon model to a GGML compatible file")
+    parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
+    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
+    parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default = 1)
+    return parser.parse_args()
+
+args = parse_args()
+
+dir_model = args.model
+ftype = args.ftype
+if not dir_model.is_dir():
+    print(f'Error: {args.model} is not a directory', file = sys.stderr)
     sys.exit(1)

-# output in the same directory as the model
-dir_model = sys.argv[1]
-last_dir = os.path.basename(os.path.normpath(dir_model))
-
 # possible tensor data types
 #   ftype == 0 -> float32
 #   ftype == 1 -> float16
@@ -65,25 +70,21 @@ last_dir = os.path.basename(os.path.normpath(dir_model))
 # map from ftype to string
 ftype_str = ["f32", "f16"]

-ftype = 1
-if len(sys.argv) > 2:
-    ftype = int(sys.argv[2])
-    if ftype < 0 or ftype > 1:
-        print("Invalid ftype: " + str(ftype))
-        sys.exit(1)
-
-fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
+if args.outfile is not None:
+    fname_out = args.outfile
+else:
+    # output in the same directory as the model by default
+    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

-print("gguf: loading model "+last_dir)
+print("gguf: loading model "+dir_model.name)

-with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
+with open(dir_model / "config.json", "r", encoding="utf-8") as f:
     hparams = json.load(f)

 if hparams["architectures"][0] != "RWForCausalLM":
     print("Model architecture not supported: " + hparams["architectures"][0])
-    sys.exit()
+    sys.exit(1)

 # get number of model parts
 num_parts = count_model_parts(dir_model)
@@ -113,77 +114,58 @@ gguf_writer.add_file_type(ftype)

 print("gguf: get tokenizer metadata")

-tokens: List[str] = []
+tokens: List[bytearray] = []
 scores: List[float] = []
 toktypes: List[int] = []
-merges: List[str] = []
-
-
-if Path(dir_model + "/tokenizer.json").is_file():
-    # gpt2 tokenizer
-    gguf_writer.add_tokenizer_model("gpt2")
-
-    print("gguf: get gpt2 tokenizer merges")
-
-    with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
-        tokenizer_json = json.load(f)
-    merges = tokenizer_json["model"]["merges"]
-
-    gguf_writer.add_token_merges(merges)
-
-    print("gguf: get gpt2 tokenizer vocab")
-
-    vocab_size = len(tokenizer_json["model"]["vocab"])
-
-    # ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
-    tokenizer = AutoTokenizer.from_pretrained(dir_model)
-
-    reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
-    byte_encoder = bytes_to_unicode()
-    byte_decoder = {v: k for k, v in byte_encoder.items()}
-
-    for i in range(vocab_size):
-        if i in reverse_vocab:
-            try:
-                text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
-            except KeyError:
-                text = bytearray()
-                for c in reverse_vocab[i]:
-                    if ord(c) < 256:  # single byte character
-                        text.append(byte_decoder[ord(c)])
-                    else:  # multibyte special token character
-                        text.extend(c.encode('utf-8'))
-        else:
-            print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
-            pad_token = f"[PAD{i}]".encode("utf8")
-            text = bytearray(pad_token)
-
-        tokens.append(text)
-        scores.append(0.0)  # dymmy
-        toktypes.append(gguf.TokenType.NORMAL)  # dummy
-
-    gguf_writer.add_token_list(tokens)
-    gguf_writer.add_token_scores(scores)
-    gguf_writer.add_token_types(toktypes)
-
-print("gguf: get special token ids")
-# Look for special tokens in config.json
-
-if "bos_token_id" in hparams and hparams["bos_token_id"] != None:
-    gguf_writer.add_bos_token_id(hparams["bos_token_id"])
-
-if "eos_token_id" in hparams and hparams["eos_token_id"] != None:
-    gguf_writer.add_eos_token_id(hparams["eos_token_id"])
-
-if "unk_token_id" in hparams and hparams["unk_token_id"] != None:
-    gguf_writer.add_unk_token_id(hparams["unk_token_id"])
-
-if "sep_token_id" in hparams and hparams["sep_token_id"] != None:
-    gguf_writer.add_sep_token_id(hparams["sep_token_id"])
-
-if "pad_token_id" in hparams and hparams["pad_token_id"] != None:
-    gguf_writer.add_pad_token_id(hparams["pad_token_id"])
+tokenizer_json_file = dir_model / 'tokenizer.json'
+if not tokenizer_json_file.is_file():
+    print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr)
+    sys.exit(1)
+
+# gpt2 tokenizer
+gguf_writer.add_tokenizer_model("gpt2")
+
+with open(tokenizer_json_file, "r", encoding="utf-8") as f:
+    tokenizer_json = json.load(f)
+
+print("gguf: get gpt2 tokenizer vocab")
+
+vocab_size = len(tokenizer_json["model"]["vocab"])
+
+# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
+tokenizer = AutoTokenizer.from_pretrained(dir_model)
+
+reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v: k for k, v in byte_encoder.items()}
+
+for i in range(vocab_size):
+    if i in reverse_vocab:
+        try:
+            text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
+        except KeyError:
+            text = bytearray()
+            for c in reverse_vocab[i]:
+                if ord(c) < 256:  # single byte character
+                    text.append(byte_decoder[ord(c)])
+                else:  # multibyte special token character
+                    text.extend(c.encode('utf-8'))
+    else:
+        print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
+        pad_token = f"[PAD{i}]".encode("utf8")
+        text = bytearray(pad_token)
+
+    tokens.append(text)
+    scores.append(0.0)  # dymmy
+    toktypes.append(gguf.TokenType.NORMAL)  # dummy
+
+gguf_writer.add_token_list(tokens)
+gguf_writer.add_token_scores(scores)
+gguf_writer.add_token_types(toktypes)
+
+special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
+special_vocab.add_to_gguf(gguf_writer)

 # TENSORS

@@ -199,15 +181,17 @@ head_dim = hparams["hidden_size"] // n_head
 print("gguf: get tensor metadata")

 if num_parts == 0:
-    part_names = ("pytorch_model.bin",)
+    part_names = iter(("pytorch_model.bin",))
 else:
     part_names = (
         f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
     )

 for part_name in part_names:
+    if args.vocab_only:
+        break
     print("gguf: loading model part '" + part_name + "'")
-    model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
+    model_part = torch.load(dir_model / part_name, map_location="cpu")

     for name in model_part.keys():
         data = model_part[name]
@@ -238,11 +222,8 @@ for part_name in part_names:
         data = data.squeeze().numpy()

         # map tensor names
-        if name.endswith(".weight") and name[:-7] in tensor_map:
-            name = tensor_map[name[:-7]] + ".weight"
-        elif name.endswith(".bias") and name[:-5] in tensor_map:
-            name = tensor_map[name[:-5]] + ".bias"
-        else:
+        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+        if new_name is None:
             print("Can not map tensor '" + name + "'")
             sys.exit()
@@ -261,19 +242,20 @@ for part_name in part_names:
         if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
             data = data.astype(np.float16)

-        print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

-        gguf_writer.add_tensor(name, data)
+        gguf_writer.add_tensor(new_name, data)


 print("gguf: write header")
 gguf_writer.write_header_to_file()
 print("gguf: write metadata")
 gguf_writer.write_kv_data_to_file()
-print("gguf: write tensors")
-gguf_writer.write_tensors_to_file()
+if not args.vocab_only:
+    print("gguf: write tensors")
+    gguf_writer.write_tensors_to_file()

 gguf_writer.close()

-print("gguf: model successfully exported to '" + fname_out + "'")
+print(f"gguf: model successfully exported to '{fname_out}'")
 print("")

convert-gptneox-hf-to-gguf.py

@@ -8,6 +8,7 @@ import struct
 import json
 import numpy as np
 import torch
+import argparse

 from typing import Any, List
 from pathlib import Path
@@ -34,11 +35,10 @@ def bytes_to_unicode():
         bs.append(b)
         cs.append(2**8+n)
         n += 1
-    cs = [chr(n) for n in cs]
-    return dict(zip(bs, cs))
+    return dict(zip(bs, (chr(n) for n in cs)))


-def count_model_parts(dir_model: str) -> int:
+def count_model_parts(dir_model: Path) -> int:
     num_parts = 0
     for filename in os.listdir(dir_model):
         if filename.startswith("pytorch_model-"):
@@ -49,17 +49,22 @@ def count_model_parts(dir_model: str) -> int:
     return num_parts


-if len(sys.argv) < 3:
-    print(f"Usage: python {sys.argv[0]} dir-model ftype\n")
-    print(" ftype == 0 -> float32")
-    print(" ftype == 1 -> float16")
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Convert a GPT-NeoX model to a GGML compatible file")
+    parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
+    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
+    parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default = 1)
+    return parser.parse_args()
+
+args = parse_args()
+
+dir_model = args.model
+ftype = args.ftype
+if not dir_model.is_dir():
+    print(f'Error: {args.model} is not a directory', file = sys.stderr)
     sys.exit(1)

-# output in the same directory as the model
-dir_model = sys.argv[1]
-last_dir = os.path.basename(os.path.normpath(dir_model))
-
 # possible tensor data types
 #   ftype == 0 -> float32
 #   ftype == 1 -> float16
@@ -67,19 +72,15 @@ last_dir = os.path.basename(os.path.normpath(dir_model))
 # map from ftype to string
 ftype_str = ["f32", "f16"]

-ftype = 1
-if len(sys.argv) > 2:
-    ftype = int(sys.argv[2])
-    if ftype < 0 or ftype > 1:
-        print("Invalid ftype: " + str(ftype))
-        sys.exit(1)
-
-fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
+if args.outfile is not None:
+    fname_out = args.outfile
+else:
+    # output in the same directory as the model by default
+    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

-print("gguf: loading model "+last_dir)
+print("gguf: loading model "+dir_model.name)

-with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
+with open(dir_model / "config.json", "r", encoding="utf-8") as f:
     hparams = json.load(f)

 if hparams["architectures"][0] != "GPTNeoXForCausalLM":
@@ -97,7 +98,7 @@ print("gguf: get model metadata")

 block_count = hparams["num_hidden_layers"]

-gguf_writer.add_name(last_dir)
+gguf_writer.add_name(dir_model.name)
 gguf_writer.add_context_length(hparams["max_position_embeddings"])
 gguf_writer.add_embedding_length(hparams["hidden_size"])
 gguf_writer.add_block_count(block_count)
@@ -111,86 +112,52 @@ gguf_writer.add_layer_norm_eps(hparams["layer_norm_eps"])

 print("gguf: get tokenizer metadata")

-tokens: List[str] = []
-merges: List[str] = []
-
-
-if Path(dir_model + "/tokenizer.json").is_file():
-    # gpt2 tokenizer
-    gguf_writer.add_tokenizer_model("gpt2")
-
-    print("gguf: get gpt2 tokenizer merges")
-
-    with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
-        tokenizer_json = json.load(f)
-    merges = tokenizer_json["model"]["merges"]
-
-    gguf_writer.add_token_merges(merges)
-
-    print("gguf: get gpt2 tokenizer vocab")
-
-    vocab_size = len(tokenizer_json["model"]["vocab"])
-
-    # ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
-    tokenizer = AutoTokenizer.from_pretrained(dir_model)
-
-    reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
-    byte_encoder = bytes_to_unicode()
-    byte_decoder = {v: k for k, v in byte_encoder.items()}
-
-    for i in range(vocab_size):
-        if i in reverse_vocab:
-            try:
-                text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
-            except KeyError:
-                text = bytearray()
-                for c in reverse_vocab[i]:
-                    if ord(c) < 256:  # single byte character
-                        text.append(byte_decoder[ord(c)])
-                    else:  # multibyte special token character
-                        text.extend(c.encode('utf-8'))
-        else:
-            print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
-            pad_token = f"[PAD{i}]".encode("utf8")
-            text = bytearray(pad_token)
-
-        tokens.append(text)
-
-    gguf_writer.add_token_list(tokens)
-
-if "added_tokens" in tokenizer_json and Path(dir_model + "/tokenizer_config.json").is_file():
-
-    print("gguf: get special token ids")
-
-    with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f:
-        tokenizer_config = json.load(f)
-
-    # find special token ids
-
-    if "bos_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["bos_token"]:
-                gguf_writer.add_bos_token_id(key["id"])
-
-    if "eos_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["eos_token"]:
-                gguf_writer.add_eos_token_id(key["id"])
-
-    if "unk_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["unk_token"]:
-                gguf_writer.add_unk_token_id(key["id"])
-
-    if "sep_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["sep_token"]:
-                gguf_writer.add_sep_token_id(key["id"])
-
-    if "pad_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["pad_token"]:
-                gguf_writer.add_pad_token_id(key["id"])
+tokens: List[bytearray] = []
+
+tokenizer_json_file = dir_model / 'tokenizer.json'
+if not tokenizer_json_file.is_file():
+    print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr)
+    sys.exit(1)
+
+# gpt2 tokenizer
+gguf_writer.add_tokenizer_model("gpt2")
+
+with open(tokenizer_json_file, "r", encoding="utf-8") as f:
+    tokenizer_json = json.load(f)
+
+print("gguf: get gpt2 tokenizer vocab")
+
+vocab_size = len(tokenizer_json["model"]["vocab"])
+
+# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
+tokenizer = AutoTokenizer.from_pretrained(dir_model)
+
+reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v: k for k, v in byte_encoder.items()}
+
+for i in range(vocab_size):
+    if i in reverse_vocab:
+        try:
+            text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
+        except KeyError:
+            text = bytearray()
+            for c in reverse_vocab[i]:
+                if ord(c) < 256:  # single byte character
+                    text.append(byte_decoder[ord(c)])
+                else:  # multibyte special token character
+                    text.extend(c.encode('utf-8'))
+    else:
+        print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
+        pad_token = f"[PAD{i}]".encode("utf8")
+        text = bytearray(pad_token)
+
+    tokens.append(text)
+
+gguf_writer.add_token_list(tokens)
+
+special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
+special_vocab.add_to_gguf(gguf_writer)

 # TENSORS

@@ -200,13 +167,15 @@ tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
 print("gguf: get tensor metadata")

 if num_parts == 0:
-    part_names = ("pytorch_model.bin",)
+    part_names = iter(("pytorch_model.bin",))
 else:
     part_names = (
         f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
     )

 for part_name in part_names:
+    if args.vocab_only:
+        break
     print("gguf: loading model part '" + part_name + "'")
     model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")

@@ -226,11 +195,8 @@ for part_name in part_names:
         data = data.squeeze().numpy()

         # map tensor names
-        if name.endswith(".weight") and name[:-7] in tensor_map:
-            name = tensor_map[name[:-7]] + ".weight"
-        elif name.endswith(".bias") and name[:-5] in tensor_map:
-            name = tensor_map[name[:-5]] + ".bias"
-        else:
+        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+        if new_name is None:
             print("Can not map tensor '" + name + "'")
             sys.exit()
@@ -249,19 +215,20 @@ for part_name in part_names:
         if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
             data = data.astype(np.float16)

-        print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

-        gguf_writer.add_tensor(name, data)
+        gguf_writer.add_tensor(new_name, data)


 print("gguf: write header")
 gguf_writer.write_header_to_file()
 print("gguf: write metadata")
 gguf_writer.write_kv_data_to_file()
-print("gguf: write tensors")
-gguf_writer.write_tensors_to_file()
+if not args.vocab_only:
+    print("gguf: write tensors")
+    gguf_writer.write_tensors_to_file()

 gguf_writer.close()

-print("gguf: model successfully exported to '" + fname_out + "'")
+print(f"gguf: model successfully exported to '{fname_out}'")
 print("")

convert-llama-7b-pth-to-gguf.py

@@ -10,8 +10,9 @@ import struct
 import json
 import numpy as np
 import torch
+import argparse

-from typing import Any, List
+from typing import Any, List, TypeAlias
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
@@ -20,7 +21,7 @@ from sentencepiece import SentencePieceProcessor
 NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'

-def count_model_parts(dir_model: str) -> int:
+def count_model_parts(dir_model: Path) -> int:
     num_parts = 0
     for filename in os.listdir(dir_model):
         if filename.startswith("consolidated."):
@@ -31,19 +32,22 @@ def count_model_parts(dir_model: str) -> int:
     return num_parts


-if len(sys.argv) < 3:
-    print(f"Usage: python {sys.argv[0]} dir-model ftype\n")
-    print(" ftype == 0 -> float32")
-    print(" ftype == 1 -> float16")
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Convert a PyTorch 7B LLaMA model to a GGML compatible file")
+    parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
+    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
+    parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default = 1)
+    return parser.parse_args()
+
+args = parse_args()
+
+dir_model = args.model
+ftype = args.ftype
+if not dir_model.is_dir():
+    print(f'Error: {args.model} is not a directory', file = sys.stderr)
     sys.exit(1)

-# output in the same directory as the model
-dir_model = sys.argv[1]
-last_dir = os.path.basename(os.path.normpath(dir_model))
-
 # possible tensor data types
 #   ftype == 0 -> float32
 #   ftype == 1 -> float16
@@ -51,19 +55,15 @@ last_dir = os.path.basename(os.path.normpath(dir_model))
 # map from ftype to string
 ftype_str = ["f32", "f16"]

-ftype = 1
-if len(sys.argv) > 2:
-    ftype = int(sys.argv[2])
-    if ftype < 0 or ftype > 1:
-        print("Invalid ftype: " + str(ftype))
-        sys.exit(1)
-
-fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
+if args.outfile is not None:
+    fname_out = args.outfile
+else:
+    # output in the same directory as the model by default
+    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

-print("gguf: loading model "+last_dir)
+print("gguf: loading model "+dir_model.name)

-with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
+with open(dir_model / "config.json", "r", encoding="utf-8") as f:
     hparams = json.load(f)

 if hparams["architectures"][0] != "LlamaForCausalLM":
@@ -107,7 +107,7 @@ else:
     sys.exit()

-gguf_writer.add_name(last_dir)
+gguf_writer.add_name(dir_model.name)
 gguf_writer.add_source_hf_repo(hf_repo)
 gguf_writer.add_tensor_data_layout("Meta AI original pth")
 gguf_writer.add_context_length(ctx_length)
@@ -133,109 +133,60 @@ tokens: List[bytes] = []
 scores: List[float] = []
 toktypes: List[int] = []

-if Path(dir_model + "/tokenizer.model").is_file():
-    # vocab type sentencepiece
-    print("gguf: get sentencepiece tokenizer vocab and scores")
-
-    tokenizer = SentencePieceProcessor(dir_model + "/tokenizer.model")
-
-    for i in range(tokenizer.vocab_size()):
-        text: bytes
-        score: float
-
-        piece = tokenizer.id_to_piece(i)
-        text = piece.encode("utf-8")
-        score = tokenizer.get_score(i)
-
-        toktype = 1  # defualt to normal token type
-        if tokenizer.is_unknown(i):
-            toktype = 2
-        if tokenizer.is_control(i):
-            toktype = 3
-
-        # toktype = 4 is user-defined = tokens from added_tokens.json
-
-        if tokenizer.is_unused(i):
-            toktype = 5
-        if tokenizer.is_byte(i):
-            toktype = 6
-
-        tokens.append(text)
-        scores.append(score)
-        toktypes.append(toktype)
-
-    if Path(dir_model + "/added_tokens.json").is_file():
-        with open(dir_model + "/added_tokens.json", "r", encoding="utf-8") as f:
-            addtokens_json = json.load(f)
-
-            print("gguf: get added tokens")
-
-            for key in addtokens_json:
-                tokens.append( key.encode("utf-8") )
-                scores.append(-1000.0)
-                toktypes.append(4)  # user-defined token type
-
-    gguf_writer.add_tokenizer_model("llama")
-    gguf_writer.add_token_list(tokens)
-    gguf_writer.add_token_scores(scores)
-    gguf_writer.add_token_types(toktypes)
-
-print("gguf: get special token ids")
-
-if Path(dir_model + "/tokenizer.json").is_file():
-    # Look for special tokens in tokenizer.json if it exists
-    with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
-        tokenizer = json.load(f)
-
-    if "added_tokens" in tokenizer and Path(dir_model + "/tokenizer_config.json").is_file():
-        with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f:
-            tokenizer_config = json.load(f)
-
-        if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] != None:
-            for key in tokenizer["added_tokens"]:
-                if key["content"] == tokenizer_config["bos_token"]["content"]:
-                    gguf_writer.add_bos_token_id(key["id"])
-
-        if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] != None:
-            for key in tokenizer["added_tokens"]:
-                if key["content"] == tokenizer_config["eos_token"]["content"]:
-                    gguf_writer.add_eos_token_id(key["id"])
-
-        if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] != None:
-            for key in tokenizer["added_tokens"]:
-                if key["content"] == tokenizer_config["unk_token"]["content"]:
-                    gguf_writer.add_unk_token_id(key["id"])
-
-        if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] != None:
-            for key in tokenizer["added_tokens"]:
-                if key["content"] == tokenizer_config["sep_token"]["content"]:
-                    gguf_writer.add_sep_token_id(key["id"])
-
-        if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] != None:
-            for key in tokenizer["added_tokens"]:
-                if key["content"] == tokenizer_config["pad_token"]["content"]:
-                    gguf_writer.add_pad_token_id(key["id"])
-else:
-    # If no tokenizer.json: Look for special tokens in config.json
-    if "bos_token_id" in hparams and hparams["bos_token_id"] != None:
-        gguf_writer.add_bos_token_id(hparams["bos_token_id"])
-
-    if "eos_token_id" in hparams and hparams["eos_token_id"] != None:
-        gguf_writer.add_eos_token_id(hparams["eos_token_id"])
-
-    if "unk_token_id" in hparams and hparams["unk_token_id"] != None:
-        gguf_writer.add_unk_token_id(hparams["unk_token_id"])
-
-    if "sep_token_id" in hparams and hparams["sep_token_id"] != None:
-        gguf_writer.add_sep_token_id(hparams["sep_token_id"])
-
-    if "pad_token_id" in hparams and hparams["pad_token_id"] != None:
-        gguf_writer.add_pad_token_id(hparams["pad_token_id"])
+tokenizer_model_file = dir_model / 'tokenizer.model'
+if not tokenizer_model_file.is_file():
+    print(f'Error: Missing {tokenizer_model_file}', file = sys.stderr)
+    sys.exit(1)
+
+# vocab type sentencepiece
+print("gguf: get sentencepiece tokenizer vocab and scores")
+
+tokenizer = SentencePieceProcessor(str(tokenizer_model_file))
+
+for i in range(tokenizer.vocab_size()):
+    text: bytes
+    score: float
+
+    piece = tokenizer.id_to_piece(i)
+    text = piece.encode("utf-8")
+    score = tokenizer.get_score(i)
+
+    toktype = 1  # defualt to normal token type
+    if tokenizer.is_unknown(i):
+        toktype = 2
+    if tokenizer.is_control(i):
+        toktype = 3
+
+    # toktype = 4 is user-defined = tokens from added_tokens.json
+
+    if tokenizer.is_unused(i):
+        toktype = 5
+    if tokenizer.is_byte(i):
+        toktype = 6
+
+    tokens.append(text)
+    scores.append(score)
+    toktypes.append(toktype)
+
+added_tokens_file = dir_model / 'added_tokens.json'
+if added_tokens_file.is_file():
+    with open(added_tokens_file, "r", encoding="utf-8") as f:
+        addtokens_json = json.load(f)
+
+        print("gguf: get added tokens")
+
+        for key in addtokens_json:
+            tokens.append( key.encode("utf-8") )
+            scores.append(-1000.0)
+            toktypes.append(4)  # user-defined token type
+
+gguf_writer.add_tokenizer_model("llama")
+gguf_writer.add_token_list(tokens)
+gguf_writer.add_token_scores(scores)
+gguf_writer.add_token_types(toktypes)
+
+special_vocab = gguf.SpecialVocab(dir_model)
+special_vocab.add_to_gguf(gguf_writer)

 # TENSORS

@@ -247,6 +198,8 @@ print("gguf: get tensor metadata")
 part_names = (f"consolidated.{n:02}.pth" for n in range(0, num_parts))

 for part_name in part_names:
+    if args.vocab_only:
+        break
     print("gguf: loading model part '" + part_name + "'")
     model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")

@@ -266,11 +219,8 @@ for part_name in part_names:
         data = data.squeeze().numpy()

         # map tensor names
-        if name.endswith(".weight") and name[:-7] in tensor_map:
-            name = tensor_map[name[:-7]] + ".weight"
-        elif name.endswith(".bias") and name[:-5] in tensor_map:
-            name = tensor_map[name[:-5]] + ".bias"
-        else:
+        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+        if new_name is None:
             print("Can not map tensor '" + name + "'")
             sys.exit()
@@ -289,20 +239,20 @@ for part_name in part_names:
         if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
             data = data.astype(np.float16)

-        print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

-        gguf_writer.add_tensor(name, data)
+        gguf_writer.add_tensor(new_name, data)


 print("gguf: write header")
 gguf_writer.write_header_to_file()
 print("gguf: write metadata")
 gguf_writer.write_kv_data_to_file()
-print("gguf: write tensors")
-gguf_writer.write_tensors_to_file()
+if not args.vocab_only:
+    print("gguf: write tensors")
+    gguf_writer.write_tensors_to_file()

 gguf_writer.close()

-print("gguf: model successfully exported to '" + fname_out + "'")
+print(f"gguf: model successfully exported to '{fname_out}'")
 print("")

convert-llama-ggmlv3-to-gguf.py

@@ -75,7 +75,7 @@ class Tensor:
         self.dims = ()
         self.dtype = None
         self.start_offset = 0
-        self.len_bytes = 0
+        self.len_bytes = np.int64(0)

     def load(self, data, offset):
         orig_offset = offset
@@ -134,13 +134,14 @@ class GGMLV3Model:
         return offset

 class GGMLToGGUF:
-    def __init__(self, ggml_model, data, cfg, params_override = None, vocab_override = None):
+    def __init__(self, ggml_model, data, cfg, params_override = None, vocab_override = None, special_vocab = None):
         hp = ggml_model.hyperparameters
         self.model = ggml_model
         self.data = data
         self.cfg = cfg
         self.params_override = params_override
         self.vocab_override = vocab_override
+        self.special_vocab = special_vocab
         if params_override is not None:
             n_kv_head = params_override.n_head_kv
         else:
@@ -162,6 +163,8 @@ class GGMLToGGUF:
         gguf_writer = gguf.GGUFWriter(self.cfg.output, gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA], use_temp_file = False)
         self.add_params(gguf_writer)
         self.add_vocab(gguf_writer)
+        if self.special_vocab is not None:
+            self.special_vocab.add_to_gguf(gguf_writer)
         self.add_tensors(gguf_writer)
         print(" gguf: write header")
         gguf_writer.write_header_to_file()
@@ -259,20 +262,13 @@ class GGMLToGGUF:
         gguf_writer.add_eos_token_id(2)

     def add_tensors(self, gguf_writer):
-        nm = self.name_map
+        tensor_map = self.name_map
         data = self.data
         print(f'* Adding {len(self.model.tensors)} tensor(s)')
         for tensor in self.model.tensors:
             name = str(tensor.name, 'UTF-8')
-            if name.endswith('.weight'):
-                name = name[:-7]
-                suffix = '.weight'
-            elif name.endswith('.bias'):
-                name = name[:-5]
-                suffix = '.bias'
-            mapped_name = nm.get(name)
+            mapped_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
             assert mapped_name is not None, f'Bad name {name}'
-            mapped_name += suffix
             tempdims = list(tensor.dims[:])
             if len(tempdims) > 1:
                 temp = tempdims[1]
@@ -302,8 +298,10 @@ def handle_metadata(cfg, hp):
     else:
         raise ValueError('Unable to load metadata')
     vocab = convert.load_vocab(cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir, cfg.vocabtype)
+    # FIXME: Respect cfg.vocab_dir?
+    svocab = gguf.SpecialVocab(cfg.model_metadata_dir)
     convert.check_vocab_size(params, vocab)
-    return (params, vocab)
+    return (params, vocab, svocab)

 def handle_args():
     parser = argparse.ArgumentParser(description = 'Convert GGMLv3 models to GGUF')
@@ -330,14 +328,16 @@ def main():
     print(f'* GGML model hyperparameters: {model.hyperparameters}')
     vocab_override = None
     params_override = None
+    special_vocab = None
     if cfg.model_metadata_dir is not None:
-        (params_override, vocab_override) = handle_metadata(cfg, model.hyperparameters)
+        (params_override, vocab_override, special_vocab) = handle_metadata(cfg, model.hyperparameters)
         print('!! Note: When overriding params the --gqa, --eps and --context-length options are ignored.')
         print(f'* Overriding params: {params_override}')
         print(f'* Overriding vocab: {vocab_override}')
+        print(f'* Special vocab: {special_vocab}')
     else:
         print('\n=== WARNING === Special tokens may not be converted correctly. Use --model-metadata-dir if possible === WARNING ===\n')
-    converter = GGMLToGGUF(model, data, cfg, params_override = params_override, vocab_override = vocab_override)
+    converter = GGMLToGGUF(model, data, cfg, params_override = params_override, vocab_override = vocab_override, special_vocab = special_vocab)
     converter.save()
     print(f'* Successful completion. Output saved to: {cfg.output}')

convert-llama-hf-to-gguf.py

@@ -8,8 +8,9 @@ import struct
 import json
 import numpy as np
 import torch
+import argparse

-from typing import Any, List, Optional
+from typing import Any, List, Optional, TypeAlias
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
@@ -43,40 +44,38 @@ def count_model_parts(dir_model: str) -> int:
     return num_parts


-if len(sys.argv) < 3:
-    print(f"Usage: python {sys.argv[0]} dir-model ftype\n")
-    print(" ftype == 0 -> float32")
-    print(" ftype == 1 -> float16")
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Convert a HuggingFace LLaMA model to a GGML compatible file")
+    parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
+    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
+    parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default = 1)
+    return parser.parse_args()
+
+args = parse_args()
+
+dir_model = args.model
+ftype = args.ftype
+if not dir_model.is_dir():
+    print(f'Error: {args.model} is not a directory', file = sys.stderr)
     sys.exit(1)

-# output in the same directory as the model
-dir_model = sys.argv[1]
-last_dir = os.path.basename(os.path.normpath(dir_model))
-
 # possible tensor data types
 #   ftype == 0 -> float32
 #   ftype == 1 -> float16

 # map from ftype to string
 ftype_str = ["f32", "f16"]

-ftype = 1
-if len(sys.argv) > 2:
-    ftype = int(sys.argv[2])
-    if ftype < 0 or ftype > 1:
-        print("Invalid ftype: " + str(ftype))
-        sys.exit(1)
-
-fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
+if args.outfile is not None:
+    fname_out = args.outfile
+else:
+    # output in the same directory as the model by default
+    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

-print("gguf: loading model "+last_dir)
+print("gguf: loading model "+dir_model.name)

-with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
+with open(dir_model / "config.json", "r", encoding="utf-8") as f:
     hparams = json.load(f)

 if hparams["architectures"][0] != "LlamaForCausalLM":
@@ -115,7 +114,7 @@ else:
     sys.exit()

-gguf_writer.add_name(last_dir)
+gguf_writer.add_name(dir_model.name)
 gguf_writer.add_source_hf_repo(hf_repo)
 gguf_writer.add_tensor_data_layout("Meta AI original pth")
 gguf_writer.add_context_length(ctx_length)
@@ -141,110 +140,61 @@ tokens: List[bytes] = []
 scores: List[float] = []
 toktypes: List[int] = []

-if Path(dir_model + "/tokenizer.model").is_file():
-    # vocab type sentencepiece
-    print("gguf: get sentencepiece tokenizer vocab, scores and token types")
-
-    tokenizer = SentencePieceProcessor(dir_model + "/tokenizer.model")
-
-    for i in range(tokenizer.vocab_size()):
-        text: bytes
-        score: float
-
-        piece = tokenizer.id_to_piece(i)
-        text = piece.encode("utf-8")
-        score = tokenizer.get_score(i)
-
-        toktype = 1  # defualt to normal token type
-        if tokenizer.is_unknown(i):
-            toktype = 2
-        if tokenizer.is_control(i):
-            toktype = 3
-
-        # toktype = 4 is user-defined = tokens from added_tokens.json
-
-        if tokenizer.is_unused(i):
-            toktype = 5
-        if tokenizer.is_byte(i):
-            toktype = 6
-
-        tokens.append(text)
-        scores.append(score)
-        toktypes.append(toktype)
-
-    if Path(dir_model + "/added_tokens.json").is_file():
-        with open(dir_model + "/added_tokens.json", "r", encoding="utf-8") as f:
-            addtokens_json = json.load(f)
-
-            print("gguf: get added tokens")
-
-            for key in addtokens_json:
-                tokens.append( key.encode("utf-8") )
-                scores.append(-1000.0)
-                toktypes.append(4)  # user-defined token type
+tokenizer_model_file = dir_model / 'tokenizer.model'
+if not tokenizer_model_file.is_file():
+    print(f'Error: Missing {tokenizer_model_file}', file = sys.stderr)
+    sys.exit(1)
+
+# vocab type sentencepiece
+print("gguf: get sentencepiece tokenizer vocab, scores and token types")
+
+tokenizer = SentencePieceProcessor(str(tokenizer_model_file))
+
+for i in range(tokenizer.vocab_size()):
+    text: bytes
+    score: float
+
+    piece = tokenizer.id_to_piece(i)
+    text = piece.encode("utf-8")
+    score = tokenizer.get_score(i)
+
+    toktype = 1  # defualt to normal token type
+    if tokenizer.is_unknown(i):
+        toktype = 2
+    if tokenizer.is_control(i):
+        toktype = 3
+
+    # toktype = 4 is user-defined = tokens from added_tokens.json
+
+    if tokenizer.is_unused(i):
+        toktype = 5
+    if tokenizer.is_byte(i):
+        toktype = 6
+
+    tokens.append(text)
+    scores.append(score)
+    toktypes.append(toktype)
+
+added_tokens_file = dir_model / 'added_tokens.json'
+if added_tokens_file.is_file():
+    with open(added_tokens_file, "r", encoding="utf-8") as f:
+        addtokens_json = json.load(f)
+
+        print("gguf: get added tokens")
+
+        for key in addtokens_json:
+            tokens.append( key.encode("utf-8") )
+            scores.append(-1000.0)
+            toktypes.append(4)  # user-defined token type

 gguf_writer.add_tokenizer_model("llama")
 gguf_writer.add_token_list(tokens)
 gguf_writer.add_token_scores(scores)
 gguf_writer.add_token_types(toktypes)

-print("gguf: get special token ids")
-
-if Path(dir_model + "/tokenizer.json").is_file():
-    # Look for special tokens in tokenizer.json if it exists
-    with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
-        tokenizer = json.load(f)
-
-    if "added_tokens" in tokenizer and Path(dir_model + "/tokenizer_config.json").is_file():
-        with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f:
-            tokenizer_config = json.load(f)
-
-        if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] != None:
-            for key in tokenizer["added_tokens"]:
-                if key["content"] == tokenizer_config["bos_token"]["content"]:
-                    gguf_writer.add_bos_token_id(key["id"])
-
-        if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] != None:
-            for key in tokenizer["added_tokens"]:
-                if key["content"] == tokenizer_config["eos_token"]["content"]:
-                    gguf_writer.add_eos_token_id(key["id"])
-
-        if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] != None:
-            for key in tokenizer["added_tokens"]:
-                if key["content"] == tokenizer_config["unk_token"]["content"]:
-                    gguf_writer.add_unk_token_id(key["id"])
-
-        if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] != None:
-            for key in tokenizer["added_tokens"]:
-                if key["content"] == tokenizer_config["sep_token"]["content"]:
-                    gguf_writer.add_sep_token_id(key["id"])
-
-        if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] != None:
-            for key in tokenizer["added_tokens"]:
-                if key["content"] == tokenizer_config["pad_token"]["content"]:
-                    gguf_writer.add_pad_token_id(key["id"])
-else:
-    # If no tokenizer.json: Look for special tokens in config.json
-    if "bos_token_id" in hparams and hparams["bos_token_id"] != None:
-        gguf_writer.add_bos_token_id(hparams["bos_token_id"])
-
-    if "eos_token_id" in hparams and hparams["eos_token_id"] != None:
-        gguf_writer.add_eos_token_id(hparams["eos_token_id"])
-
-    if "unk_token_id" in hparams and hparams["unk_token_id"] != None:
-        gguf_writer.add_unk_token_id(hparams["unk_token_id"])
-
-    if "sep_token_id" in hparams and hparams["sep_token_id"] != None:
-        gguf_writer.add_sep_token_id(hparams["sep_token_id"])
-
-    if "pad_token_id" in hparams and hparams["pad_token_id"] != None:
-        gguf_writer.add_pad_token_id(hparams["pad_token_id"])
+special_vocab = gguf.SpecialVocab(dir_model)
+special_vocab.add_to_gguf(gguf_writer)

 # TENSORS

@@ -254,13 +204,15 @@ tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
 print("gguf: get tensor metadata")

 if num_parts == 0:
-    part_names = ("pytorch_model.bin",)
+    part_names = iter(("pytorch_model.bin",))
 else:
     part_names = (
         f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
     )

 for part_name in part_names:
+    if args.vocab_only:
+        break
     print("gguf: loading model part '" + part_name + "'")
     model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")

@@ -286,11 +238,8 @@ for part_name in part_names:
             data = reverse_hf_permute(data, head_count, head_count_kv)

         # map tensor names
-        if name.endswith(".weight") and name[:-7] in tensor_map:
-            name = tensor_map[name[:-7]] + ".weight"
-        elif name.endswith(".bias") and name[:-5] in tensor_map:
-            name = tensor_map[name[:-5]] + ".bias"
-        else:
+        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+        if new_name is None:
             print("Can not map tensor '" + name + "'")
             sys.exit()
@@ -309,20 +258,20 @@ for part_name in part_names:
         if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
             data = data.astype(np.float16)

-        print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

-        gguf_writer.add_tensor(name, data)
+        gguf_writer.add_tensor(new_name, data)


 print("gguf: write header")
 gguf_writer.write_header_to_file()
 print("gguf: write metadata")
 gguf_writer.write_kv_data_to_file()
-print("gguf: write tensors")
-gguf_writer.write_tensors_to_file()
+if not args.vocab_only:
+    print("gguf: write tensors")
+    gguf_writer.write_tensors_to_file()

 gguf_writer.close()

-print("gguf: model successfully exported to '" + fname_out + "'")
+print(f"gguf: model successfully exported to '{fname_out}'")
 print("")

convert-lora-to-ggml.py

@@ -4,7 +4,7 @@ import os
 import re
 import struct
 import sys
-from typing import Any, Dict, Sequence, TextIO
+from typing import Any, Dict, Sequence, BinaryIO

 import numpy as np
 import torch
@@ -46,7 +46,7 @@ def translate_tensor_name(t: str) -> str:
         sys.exit(1)

-def write_file_header(fout: TextIO, params: Dict[str, Any]) -> None:
+def write_file_header(fout: BinaryIO, params: Dict[str, Any]) -> None:
     fout.write(b"ggla"[::-1])  # magic (ggml lora)
     fout.write(struct.pack("i", 1))  # file version
     fout.write(struct.pack("i", params["r"]))
@@ -60,7 +60,7 @@ def write_file_header(fout: TextIO, params: Dict[str, Any]) -> None:

 def write_tensor_header(
-    self, name: str, shape: Sequence[int], data_type: np.dtype
+    self, name: str, shape: Sequence[int], data_type: np.dtype[Any]
 ) -> None:
     sname = name.encode("utf-8")
     fout.write(

convert.py

@ -25,7 +25,7 @@ import numpy as np
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union) from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, Type, TypeVar, Union)
from sentencepiece import SentencePieceProcessor # type: ignore from sentencepiece import SentencePieceProcessor # type: ignore
if TYPE_CHECKING: if TYPE_CHECKING:
@ -299,8 +299,10 @@ class Params:
params = Params.loadHFTransformerJson(model_plus.model, hf_config_path) params = Params.loadHFTransformerJson(model_plus.model, hf_config_path)
elif orig_config_path.exists(): elif orig_config_path.exists():
params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path) params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path)
else: elif model_plus.format != 'none':
params = Params.guessed(model_plus.model) params = Params.guessed(model_plus.model)
else:
raise ValueError('Cannot guess params when model format is none')
params.path_model = model_plus.paths[0].parent params.path_model = model_plus.paths[0].parent
@ -353,7 +355,7 @@ class BpeVocab:
yield from self.added_tokens() yield from self.added_tokens()
def __repr__(self) -> str: def __repr__(self) -> str:
return f"BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>" return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
class SentencePieceVocab: class SentencePieceVocab:
@ -416,7 +418,6 @@ class SentencePieceVocab:
Vocab = Union[BpeVocab, SentencePieceVocab] Vocab = Union[BpeVocab, SentencePieceVocab]
# #
# data loading # data loading
# TODO: reuse (probably move to gguf.py?) # TODO: reuse (probably move to gguf.py?)
@ -439,14 +440,14 @@ class Tensor(metaclass=ABCMeta):
@abstractmethod @abstractmethod
def permute(self, n_head: int, n_head_kv: int) -> 'Tensor': ... def permute(self, n_head: int, n_head_kv: int) -> 'Tensor': ...
@abstractmethod @abstractmethod
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ... def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> 'UnquantizedTensor': ...
@abstractmethod @abstractmethod
def part(self, n_part: int) -> 'UnquantizedTensor': ... def part(self, n_part: int) -> 'UnquantizedTensor': ...
@abstractmethod @abstractmethod
def to_ggml(self) -> 'GGMLCompatibleTensor': ... def to_ggml(self) -> 'GGMLCompatibleTensor': ...
def bf16_to_fp32(bf16_arr: np.ndarray) -> np.ndarray: def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray:
assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}" assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}"
fp32_arr = bf16_arr.astype(np.uint32) << 16 fp32_arr = bf16_arr.astype(np.uint32) << 16
return fp32_arr.view(np.float32) return fp32_arr.view(np.float32)
@ -467,9 +468,9 @@ class UnquantizedTensor(Tensor):
def to_ggml(self) -> 'UnquantizedTensor': def to_ggml(self) -> 'UnquantizedTensor':
return self return self
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> 'UnquantizedTensor':
r = self.ndarray.shape[0] // 3 r = self.ndarray.shape[0] // 3
return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head)) return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv))
def part(self, n_part: int) -> 'UnquantizedTensor': def part(self, n_part: int) -> 'UnquantizedTensor':
r = self.ndarray.shape[0] // 3 r = self.ndarray.shape[0] // 3
@ -531,7 +532,7 @@ LazyModel = Dict[str, LazyTensor]
class ModelPlus: class ModelPlus:
model: LazyModel model: LazyModel
paths: List[Path] # Where this was read from. paths: List[Path] # Where this was read from.
format: Literal['ggml', 'torch', 'safetensors'] format: Literal['ggml', 'torch', 'safetensors', 'none']
vocab: Optional[Vocab] # For GGML models (which have vocab built in), the vocab. vocab: Optional[Vocab] # For GGML models (which have vocab built in), the vocab.
@ -597,12 +598,12 @@ def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTe
return lazy_tensor.load().permute(n_head, n_head_kv) return lazy_tensor.load().permute(n_head, n_head_kv)
return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)
def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor: def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int) -> LazyTensor:
def load() -> Tensor: def load() -> Tensor:
return lazy_tensor.load().permute_part(n_part, n_head) return lazy_tensor.load().permute_part(n_part, n_head, n_head_kv)
s = lazy_tensor.shape.copy() s = lazy_tensor.shape.copy()
s[0] = s[0] // 3 s[0] = s[0] // 3
return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description) return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)
def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor: def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
def load() -> Tensor: def load() -> Tensor:
@ -657,7 +658,7 @@ class LazyUnpickler(pickle.Unpickler):
description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}' description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
return LazyStorage(load=load, kind=pid[1], description=description) return LazyStorage(load=load, kind=pid[1], description=description)
# @staticmethod @staticmethod
def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any, def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any,
# pyright: ignore[reportSelfClsParameterName] # pyright: ignore[reportSelfClsParameterName]
requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor: requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor:
@ -669,13 +670,15 @@ class LazyUnpickler(pickle.Unpickler):
description = f'pickled storage_offset={storage_offset} in {storage.description}' description = f'pickled storage_offset={storage_offset} in {storage.description}'
return LazyTensor(load, list(size), storage.kind.data_type, description) return LazyTensor(load, list(size), storage.kind.data_type, description)
# @staticmethod @staticmethod
def rebuild_from_type_v2(func, new_type, args, state): def rebuild_from_type_v2(func, new_type, args, state):
return func(*args) return func(*args)
CLASSES: Dict[Any, Any] = { CLASSES: Dict[Tuple[str, str], Any] = {
('torch._tensor', '_rebuild_from_type_v2'): rebuild_from_type_v2, # getattr used here as a workaround for mypy not being smart enough to detrmine
('torch._utils', '_rebuild_tensor_v2'): lazy_rebuild_tensor_v2, # the staticmethods have a __func__ attribute.
('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'),
('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16), ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16),
('torch', 'HalfStorage'): LazyStorageKind(DT_F16), ('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
('torch', 'FloatStorage'): LazyStorageKind(DT_F32), ('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
@ -751,7 +754,7 @@ def lazy_load_file(path: Path) -> ModelPlus:
In = TypeVar('In') In = TypeVar('In')
Out = TypeVar('Out') Out = TypeVar('Out')
def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]: def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, use_processpool_executor: bool = False) -> Iterable[Out]:
'''Parallel map, but with backpressure. If the caller doesn't call `next` '''Parallel map, but with backpressure. If the caller doesn't call `next`
fast enough, this will stop calling `func` at some point rather than fast enough, this will stop calling `func` at some point rather than
letting results pile up in memory. Specifically, there is a max of one letting results pile up in memory. Specifically, there is a max of one
@ -760,7 +763,12 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
yield from map(func, iterable) yield from map(func, iterable)
# Not reached. # Not reached.
iterable = iter(iterable) iterable = iter(iterable)
with factory(max_workers = max_workers) as executor: executor_class: Union[Type[ThreadPoolExecutor], Type[ProcessPoolExecutor]]
if use_processpool_executor:
executor_class = ProcessPoolExecutor
else:
executor_class = ThreadPoolExecutor
with executor_class(max_workers = max_workers) as executor:
futures: List[concurrent.futures.Future[Out]] = [] futures: List[concurrent.futures.Future[Out]] = []
done = False done = False
for _ in range(concurrency): for _ in range(concurrency):
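Editor's note: a minimal usage sketch of the reworked signature. The quantization path passes use_processpool_executor=True because Q8_0 quantization is CPU-bound and a thread pool would serialize on the GIL; the function and values below are illustrative only:

def work(x: int) -> int:
    return x * x   # stand-in for OutputFile.maybe_do_quantize; must be picklable for a process pool

for out in bounded_parallel_map(work, range(8), concurrency=2,
                                max_workers=2, use_processpool_executor=True):
    print(out)
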
@ -838,11 +846,19 @@ class OutputFile:
scores.append(score) scores.append(score)
toktypes.append(toktype) toktypes.append(toktype)
self.gguf.add_tokenizer_model("llama") if isinstance(vocab, SentencePieceVocab):
self.gguf.add_tokenizer_model("llama")
elif isinstance(vocab, BpeVocab):
self.gguf.add_tokenizer_model("gpt2")
else:
raise ValueError(f'Unknown vocab type: Not BpeVocab or SentencePieceVocab')
self.gguf.add_token_list(tokens) self.gguf.add_token_list(tokens)
self.gguf.add_token_scores(scores) self.gguf.add_token_scores(scores)
self.gguf.add_token_types(toktypes) self.gguf.add_token_types(toktypes)
def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None:
svocab.add_to_gguf(self.gguf)
def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
n_elements = int(np.prod(tensor.shape)) n_elements = int(np.prod(tensor.shape))
raw_dtype = getattr(tensor.data_type, 'ggml_type', None) raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
@ -861,7 +877,7 @@ class OutputFile:
self.gguf.close() self.gguf.close()
@staticmethod @staticmethod
def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None: def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab) -> None:
check_vocab_size(params, vocab) check_vocab_size(params, vocab)
of = OutputFile(fname_out) of = OutputFile(fname_out)
@ -869,6 +885,8 @@ class OutputFile:
# meta data # meta data
of.add_meta_arch(params) of.add_meta_arch(params)
of.add_meta_vocab(vocab) of.add_meta_vocab(vocab)
of.add_meta_special_vocab(svocab)
of.write_meta() of.write_meta()
of.close() of.close()
@ -887,7 +905,7 @@ class OutputFile:
return dt.quantize(arr) return dt.quantize(arr)
@staticmethod @staticmethod
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None: def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
check_vocab_size(params, vocab) check_vocab_size(params, vocab)
of = OutputFile(fname_out) of = OutputFile(fname_out)
@ -895,6 +913,7 @@ class OutputFile:
# meta data # meta data
of.add_meta_arch(params) of.add_meta_arch(params)
of.add_meta_vocab(vocab) of.add_meta_vocab(vocab)
of.add_meta_special_vocab(svocab)
# tensor info # tensor info
for name, lazy_tensor in model.items(): for name, lazy_tensor in model.items():
@ -906,7 +925,7 @@ class OutputFile:
# tensor data # tensor data
ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency) ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
if ftype == GGMLFileType.MostlyQ8_0: if ftype == GGMLFileType.MostlyQ8_0:
ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor) ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, use_processpool_executor = True)
else: else:
ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
@ -939,7 +958,8 @@ def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyM
for (name, tensor) in model.items()} for (name, tensor) in model.items()}
def convert_model_names(model: LazyModel, params: Params) -> LazyModel: def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
tmap = gguf.get_tensor_name_map(ARCH, params.n_layer) tmap = gguf.TensorNameMap(ARCH, params.n_layer)
should_skip: Set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
tmp = model tmp = model
@ -952,8 +972,8 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
#tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"] #tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
elif f"model.layers.{i}.self_attn.W_pack.weight" in model: elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
print(f"Unpacking and permuting layer {i}") print(f"Unpacking and permuting layer {i}")
tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head) tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head)
tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head) tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv)
tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2) tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
del tmp[f"model.layers.{i}.self_attn.W_pack.weight"] del tmp[f"model.layers.{i}.self_attn.W_pack.weight"]
else: else:
@ -961,23 +981,16 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
out: LazyModel = {} out: LazyModel = {}
for name, lazy_tensor in model.items(): for name, lazy_tensor in model.items():
name_new = name tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None)
if name_new is None:
if name in tmap:
name_new = tmap[name]
elif name.endswith(".weight") and name[:-7] in tmap:
name_new = tmap[name[:-7]] + ".weight"
elif name.endswith(".bias") and name[:-5] in tmap:
name_new = tmap[name[:-5]] + ".bias"
else:
raise Exception(f"Unexpected tensor name: {name}") raise Exception(f"Unexpected tensor name: {name}")
if gguf.should_skip_tensor_TMP(ARCH, params.n_layer, name_new): if tensor_type in should_skip:
print(f"skipping tensor {name_new}") print(f"skipping tensor {name_new}")
continue continue
else:
print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}") print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
out[name_new] = lazy_tensor out[name_new] = lazy_tensor
return out return out
@ -1117,8 +1130,16 @@ def main(args_in: Optional[List[str]] = None) -> None:
if args.dump_single: if args.dump_single:
model_plus = lazy_load_file(args.model) model_plus = lazy_load_file(args.model)
do_dump_model(model_plus) do_dump_model(model_plus)
return
model_plus = load_some_model(args.model) if not args.vocab_only:
model_plus = load_some_model(args.model)
else:
model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
if args.dump:
do_dump_model(model_plus)
return
params = Params.load(model_plus) params = Params.load(model_plus)
if params.n_ctx == -1: if params.n_ctx == -1:
@ -1140,33 +1161,34 @@ def main(args_in: Optional[List[str]] = None) -> None:
vocab: Vocab vocab: Vocab
if args.vocab_only: if args.vocab_only:
vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
assert args.outfile, "need --outfile if using --vocab-only" assert args.outfile, "need --outfile if using --vocab-only"
# FIXME: Try to respect vocab_dir somehow?
vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, load_merges = args.vocabtype == 'bpe')
outfile = args.outfile outfile = args.outfile
OutputFile.write_vocab_only(outfile, params, vocab) OutputFile.write_vocab_only(outfile, params, vocab, special_vocab)
print(f"Wrote {outfile}") print(f"Wrote {outfile}")
return
if model_plus.vocab is not None and args.vocab_dir is None:
vocab = model_plus.vocab
else: else:
if args.dump: vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
do_dump_model(model_plus) vocab = load_vocab(vocab_dir, args.vocabtype)
return # FIXME: Try to respect vocab_dir somehow?
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, load_merges = args.vocabtype == 'bpe')
if model_plus.vocab is not None and args.vocab_dir is None: model = model_plus.model
vocab = model_plus.vocab model = convert_model_names(model, params)
else: ftype = pick_output_type(model, args.outtype)
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent model = convert_to_output_type(model, ftype)
vocab = load_vocab(vocab_dir, args.vocabtype) outfile = args.outfile or default_outfile(model_plus.paths, ftype)
model = model_plus.model params.ftype = ftype
model = convert_model_names(model, params) print(f"Writing {outfile}, format {ftype}")
ftype = pick_output_type(model, args.outtype)
model = convert_to_output_type(model, ftype)
outfile = args.outfile or default_outfile(model_plus.paths, ftype)
params.ftype = ftype OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency)
print(f"Writing {outfile}, format {ftype}") print(f"Wrote {outfile}")
OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
print(f"Wrote {outfile}")
if __name__ == '__main__': if __name__ == '__main__':
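Editor's note: a hedged sketch of driving the new vocab-only path, with argument names taken from the parser and main() above; the model path is illustrative:

import convert

# Vocab-only mode no longer loads any model tensors; --outfile is mandatory here.
convert.main(["models/llama-7b", "--vocab-only",
              "--outfile", "models/llama-7b/vocab.gguf"])
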


@ -4,9 +4,13 @@ import sys
import struct import struct
import tempfile import tempfile
import numpy as np import numpy as np
import json
import os
from pathlib import Path
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import Any, IO, List, Optional from io import BufferedWriter
from typing import Any, BinaryIO, Callable, IO, Dict, List, Optional, Sequence, Tuple, Union
# #
# constants # constants
@ -71,35 +75,35 @@ KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world"
class MODEL_ARCH(IntEnum): class MODEL_ARCH(IntEnum):
LLAMA = auto() LLAMA : int = auto()
FALCON = auto() FALCON : int = auto()
GPT2 = auto() GPT2 : int = auto()
GPTJ = auto() GPTJ : int = auto()
GPTNEOX = auto() GPTNEOX: int = auto()
MPT = auto() MPT : int = auto()
class MODEL_TENSOR(IntEnum): class MODEL_TENSOR(IntEnum):
TOKEN_EMBD = auto() TOKEN_EMBD : int = auto()
POS_EMBD = auto() POS_EMBD : int = auto()
OUTPUT = auto() OUTPUT : int = auto()
OUTPUT_NORM = auto() OUTPUT_NORM : int = auto()
ROPE_FREQS = auto() ROPE_FREQS : int = auto()
ATTN_Q = auto() ATTN_Q : int = auto()
ATTN_K = auto() ATTN_K : int = auto()
ATTN_V = auto() ATTN_V : int = auto()
ATTN_QKV = auto() ATTN_QKV : int = auto()
ATTN_OUT = auto() ATTN_OUT : int = auto()
ATTN_NORM = auto() ATTN_NORM : int = auto()
ATTN_NORM_2 = auto() ATTN_NORM_2 : int = auto()
ATTN_ROT_EMBD = auto() ATTN_ROT_EMBD: int = auto()
FFN_GATE = auto() FFN_GATE : int = auto()
FFN_DOWN = auto() FFN_DOWN : int = auto()
FFN_UP = auto() FFN_UP : int = auto()
FFN_NORM = auto() FFN_NORM : int = auto()
MODEL_ARCH_NAMES = { MODEL_ARCH_NAMES: Dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama", MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.FALCON: "falcon", MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.GPT2: "gpt2", MODEL_ARCH.GPT2: "gpt2",
@ -108,7 +112,7 @@ MODEL_ARCH_NAMES = {
MODEL_ARCH.MPT: "mpt", MODEL_ARCH.MPT: "mpt",
} }
MODEL_TENSOR_NAMES = { MODEL_TENSOR_NAMES: Dict[MODEL_ARCH, Dict[MODEL_TENSOR, str]] = {
MODEL_ARCH.LLAMA: { MODEL_ARCH.LLAMA: {
MODEL_TENSOR.TOKEN_EMBD: "token_embd", MODEL_TENSOR.TOKEN_EMBD: "token_embd",
MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.OUTPUT_NORM: "output_norm",
@ -154,7 +158,7 @@ MODEL_TENSOR_NAMES = {
} }
# tensors that will not be serialized # tensors that will not be serialized
MODEL_TENSOR_SKIP = { MODEL_TENSOR_SKIP: Dict[MODEL_ARCH, List[MODEL_TENSOR]] = {
MODEL_ARCH.LLAMA: [ MODEL_ARCH.LLAMA: [
MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD, MODEL_TENSOR.ATTN_ROT_EMBD,
@ -162,167 +166,198 @@ MODEL_TENSOR_SKIP = {
} }
# TODO: the following helper functions should be removed class TensorNameMap:
# instead, get_tensor_name_map should return tuples of (name, MODEL_TENSOR) mappings_cfg: Dict[MODEL_TENSOR, Tuple[str, ...]] = {
# however, my Python is very bad, and I couldn't figure out how to do this, hence these functions # Token embeddings
# REMOVE MODEL_TENSOR.TOKEN_EMBD: (
def should_skip_tensor_TMP(arch: MODEL_ARCH, n_blocks: int, name: str) -> bool: "gpt_neox.embed_in", # gptneox
for skip in MODEL_TENSOR_SKIP.get(arch, []): "transformer.wte", # gpt2 mpt
for i in range(n_blocks): "transformer.word_embeddings", # falcon
if name == MODEL_TENSOR_NAMES[arch][skip].format(bid=i): "model.embed_tokens", # llama-hf
return True "tok_embeddings", # llama-pth
),
return False # Position embeddings
MODEL_TENSOR.POS_EMBD: (
"transformer.wpe", # gpt2
),
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf
"output", # llama-pth
),
def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> dict: # Output norm
tensor_map = {} MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 falcon
"model.norm", # llama-hf
"norm", # llama-pth
),
# Token embeddings # Rope frequencies
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.TOKEN_EMBD, None) MODEL_TENSOR.ROPE_FREQS: (
"rope.freqs", # llama-pth
),
}
tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox block_mappings_cfg: Dict[MODEL_TENSOR, Tuple[str, ...]] = {
tensor_map["transformer.wte"] = mapped_to # gpt2 mpt
tensor_map["transformer.word_embeddings"] = mapped_to # falcon
tensor_map["model.embed_tokens"] = mapped_to # llama-hf
tensor_map["tok_embeddings"] = mapped_to # llama-pth
# Position embeddings
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.POS_EMBD, None)
tensor_map["transformer.wpe"] = mapped_to # gpt2
# Output
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.OUTPUT, None)
tensor_map["embed_out"] = mapped_to # gptneox
tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf
tensor_map["output"] = mapped_to # llama-pth
# Output norm
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.OUTPUT_NORM, None)
tensor_map["gpt_neox.final_layer_norm"] = mapped_to # gptneox
tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon
tensor_map["transformer.norm_f"] = mapped_to # mpt
tensor_map["model.norm"] = mapped_to # llama-hf
tensor_map["norm"] = mapped_to # llama-pth
# Rope frequencies
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ROPE_FREQS, None)
tensor_map["rope.freqs"] = mapped_to # llama-pth
# Attention and feed-forward blocks
for i in range(0, n_blocks):
# Attention norm # Attention norm
# TODO: is there are simpler way to write these 2 lines in Python? MODEL_TENSOR.ATTN_NORM: (
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_NORM, None) "gpt_neox.layers.{bid}.input_layernorm", # gptneox
mapped_to = mapped_to.format(bid=i) if mapped_to else None "transformer.h.{bid}.ln_1", # gpt2
"transformer.blocks.{bid}.norm_1", # mpt
tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox "transformer.h.{bid}.input_layernorm", # falcon7b
tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2 "transformer.h.{bid}.ln_mlp", # falcon40b
tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt "model.layers.{bid}.input_layernorm", # llama-hf
tensor_map["transformer.h."+str(i)+".input_layernorm"] = mapped_to # falcon7b "layers.{bid}.attention_norm", # llama-pth
tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b ),
tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth
# Attention norm 2 # Attention norm 2
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_NORM_2, None) MODEL_TENSOR.ATTN_NORM_2: (
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None "transformer.h.{bid}.ln_attn", # falcon40b
),
tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b
# Attention query-key-value # Attention query-key-value
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_QKV, None) MODEL_TENSOR.ATTN_QKV: (
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
"transformer.h.{bid}.attn.c_attn", # gpt2
tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox "transformer.blocks.{bid}.attn.Wqkv", # mpt
tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2 "transformer.h.{bid}.self_attention.query_key_value", # falcon
tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt ),
tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon
# Attention query # Attention query
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_Q, None) MODEL_TENSOR.ATTN_Q: (
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None "model.layers.{bid}.self_attn.q_proj", # llama-hf
"layers.{bid}.attention.wq", # llama-pth
tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf ),
tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth
# Attention key # Attention key
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_K, None) MODEL_TENSOR.ATTN_K: (
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None "model.layers.{bid}.self_attn.k_proj", # llama-hf
"layers.{bid}.attention.wk", # llama-pth
tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf ),
tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth
# Attention value # Attention value
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_V, None) MODEL_TENSOR.ATTN_V: (
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None "model.layers.{bid}.self_attn.v_proj", # llama-hf
"layers.{bid}.attention.wv", # llama-pth
tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf ),
tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth
# Attention output # Attention output
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_OUT, None) MODEL_TENSOR.ATTN_OUT: (
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None "gpt_neox.layers.{bid}.attention.dense", # gptneox
"transformer.h.{bid}.attn.c_proj", # gpt2
tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox "transformer.blocks.{bid}.attn.out_proj", # mpt
tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2 "transformer.h.{bid}.self_attention.dense", # falcon
tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt "model.layers.{bid}.self_attn.o_proj", # llama-hf
tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon "layers.{bid}.attention.wo", # llama-pth
tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf ),
tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth
# Rotary embeddings # Rotary embeddings
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_ROT_EMBD, None) MODEL_TENSOR.ATTN_ROT_EMBD: (
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
"layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
tensor_map["model.layers."+str(i)+".self_attn.rotary_emb.inv_freq"] = mapped_to # llama-hf ),
tensor_map["layers."+str(i)+".attention.inner_attention.rope.freqs"] = mapped_to # llama-pth
# Feed-forward norm # Feed-forward norm
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_NORM, None) MODEL_TENSOR.FFN_NORM: (
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
"transformer.h.{bid}.ln_2", # gpt2
tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox "transformer.blocks.{bid}.norm_2", # mpt
tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2 "model.layers.{bid}.post_attention_layernorm", # llama-hf
tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt "layers.{bid}.ffn_norm", # llama-pth
tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf ),
tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth
# Feed-forward up # Feed-forward up
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_UP, None) MODEL_TENSOR.FFN_UP: (
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
"transformer.h.{bid}.mlp.c_fc", # gpt2
tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox "transformer.blocks.{bid}.ffn.up_proj", # mpt
tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2 "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt "model.layers.{bid}.mlp.up_proj", # llama-hf
tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon "layers.{bid}.feed_forward.w3", # llama-pth
tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf ),
tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth
# Feed-forward gate # Feed-forward gate
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_GATE, None) MODEL_TENSOR.FFN_GATE: (
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None "model.layers.{bid}.mlp.gate_proj", # llama-hf
"layers.{bid}.feed_forward.w1", # llama-pth
tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf ),
tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth
# Feed-forward down # Feed-forward down
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_DOWN, None) MODEL_TENSOR.FFN_DOWN: (
mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
"transformer.h.{bid}.mlp.c_proj", # gpt2
"transformer.blocks.{bid}.ffn.down_proj", # mpt
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
"model.layers.{bid}.mlp.down_proj", # llama-hf
"layers.{bid}.feed_forward.w2", # llama-pth
),
}
tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox mapping: Dict[str, Tuple[MODEL_TENSOR, str]]
tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2
tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt
tensor_map["transformer.h."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # falcon
tensor_map["model.layers."+str(i)+".mlp.down_proj"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".feed_forward.w2"] = mapped_to # llama-pth
return tensor_map tensor_names: Dict[MODEL_TENSOR, str]
def __init__(self, arch: MODEL_ARCH, n_blocks: int):
mapping = self.mapping = {}
tensor_names = self.tensor_names = MODEL_TENSOR_NAMES[arch]
for tensor, keys in self.mappings_cfg.items():
tensor_name = tensor_names.get(tensor)
if tensor_name is None:
continue
for key in keys:
mapping[key] = (tensor, tensor_name)
for bid in range(n_blocks):
for tensor, keys in self.block_mappings_cfg.items():
tensor_name = tensor_names.get(tensor)
if tensor_name is None:
continue
tensor_name = tensor_name.format(bid = bid)
for key in keys:
key = key.format(bid = bid)
mapping[key] = (tensor, tensor_name)
def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[Tuple[MODEL_TENSOR, str]]:
result = self.mapping.get(key)
if result is not None:
return result
for suffix in try_suffixes:
if key.endswith(suffix):
result = self.mapping.get(key[:-len(suffix)])
if result is not None:
return (result[0], result[1] + suffix)
return None
def get_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[str]:
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
if result is None:
return None
return result[1]
def get_type(self, key: str, try_suffixes: Sequence[str]) -> Optional[MODEL_TENSOR]:
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
if result is None:
return None
return result[0]
def __getitem__(self, key: str) -> str:
try:
return self.mapping[key][1]
except KeyError:
raise KeyError(key)
def __contains__(self, key: str) -> bool:
return key in self.mapping
def __repr__(self) -> str:
return repr(self.mapping)
def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap:
return TensorNameMap(arch, n_blocks)
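Editor's note: a minimal sketch of the new lookup API that replaces the old flat dict (mapped names depend on MODEL_TENSOR_NAMES above):

tmap = TensorNameMap(MODEL_ARCH.LLAMA, 32)

# Suffix-aware lookup returns the MODEL_TENSOR enum plus the GGUF-side name.
hit = tmap.get_type_and_name("model.layers.0.self_attn.q_proj.weight",
                             try_suffixes=(".weight", ".bias"))
if hit is not None:
    tensor_type, gguf_name = hit   # e.g. (MODEL_TENSOR.ATTN_Q, "<mapped name>.weight")
assert "model.layers.0.self_attn.q_proj" in tmap
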
class TokenType(IntEnum): class TokenType(IntEnum):
NORMAL = 1 NORMAL = 1
@ -388,15 +423,21 @@ class GGUFValueType(IntEnum):
class GGUFWriter: class GGUFWriter:
def __init__(self, path: str, arch: str, use_temp_file = True): fout: BufferedWriter
arch: str
offset_tensor = 0
data_alignment = GGUF_DEFAULT_ALIGNMENT
kv_data = b""
kv_data_count = 0
ti_data = b""
ti_data_count = 0
use_temp_file: bool
temp_file: Optional[tempfile.SpooledTemporaryFile[bytes]] = None
tensors: List[Tuple[np.ndarray[Any, Any], int]]
def __init__(self, path: Union[os.PathLike[str], str], arch: str, use_temp_file = True):
self.fout = open(path, "wb") self.fout = open(path, "wb")
self.arch = arch self.arch = arch
self.offset_tensor = 0
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
self.kv_data = b""
self.kv_data_count = 0
self.ti_data = b""
self.ti_data_count = 0
self.add_architecture() self.add_architecture()
self.use_temp_file = use_temp_file self.use_temp_file = use_temp_file
self.tensors = [] self.tensors = []
@ -470,14 +511,27 @@ class GGUFWriter:
self.add_key(key) self.add_key(key)
self.add_val(val, GGUFValueType.STRING) self.add_val(val, GGUFValueType.STRING)
def add_array(self, key: str, val: list): def add_array(self, key: str, val: Sequence[Any]):
if not isinstance(val, list): if not isinstance(val, Sequence):
raise ValueError("Value must be a list for array type") raise ValueError("Value must be a sequence for array type")
self.add_key(key) self.add_key(key)
self.add_val(val, GGUFValueType.ARRAY) self.add_val(val, GGUFValueType.ARRAY)
def add_val(self: str, val: Any, vtype: GGUFValueType = None, add_vtype: bool = True): _simple_value_packing = {
GGUFValueType.UINT8: "<B",
GGUFValueType.INT8: "<b",
GGUFValueType.UINT16: "<H",
GGUFValueType.INT16: "<h",
GGUFValueType.UINT32: "<I",
GGUFValueType.INT32: "<i",
GGUFValueType.FLOAT32: "<f",
GGUFValueType.UINT64: "<Q",
GGUFValueType.INT64: "<q",
GGUFValueType.FLOAT64: "<d",
GGUFValueType.BOOL: "?" ,
}
def add_val(self, val: Any, vtype: Optional[GGUFValueType] = None, add_vtype: bool = True):
if vtype is None: if vtype is None:
vtype = GGUFValueType.get_type(val) vtype = GGUFValueType.get_type(val)
@ -485,47 +539,29 @@ class GGUFWriter:
self.kv_data += struct.pack("<I", vtype) self.kv_data += struct.pack("<I", vtype)
self.kv_data_count += 1 self.kv_data_count += 1
if vtype == GGUFValueType.UINT8: pack_fmt = self._simple_value_packing.get(vtype)
self.kv_data += struct.pack("<B", val) if pack_fmt is not None:
elif vtype == GGUFValueType.INT8: self.kv_data += struct.pack(pack_fmt, val)
self.kv_data += struct.pack("<b", val)
elif vtype == GGUFValueType.UINT16:
self.kv_data += struct.pack("<H", val)
elif vtype == GGUFValueType.INT16:
self.kv_data += struct.pack("<h", val)
elif vtype == GGUFValueType.UINT32:
self.kv_data += struct.pack("<I", val)
elif vtype == GGUFValueType.INT32:
self.kv_data += struct.pack("<i", val)
elif vtype == GGUFValueType.FLOAT32:
self.kv_data += struct.pack("<f", val)
elif vtype == GGUFValueType.UINT64:
self.kv_data += struct.pack("<Q", val)
elif vtype == GGUFValueType.INT64:
self.kv_data += struct.pack("<q", val)
elif vtype == GGUFValueType.FLOAT64:
self.kv_data += struct.pack("<d", val)
elif vtype == GGUFValueType.BOOL:
self.kv_data += struct.pack("?", val)
elif vtype == GGUFValueType.STRING: elif vtype == GGUFValueType.STRING:
encoded_val = val.encode("utf8") if isinstance(val, str) else val encoded_val = val.encode("utf8") if isinstance(val, str) else val
self.kv_data += struct.pack("<Q", len(encoded_val)) self.kv_data += struct.pack("<Q", len(encoded_val))
self.kv_data += encoded_val self.kv_data += encoded_val
elif vtype == GGUFValueType.ARRAY: elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and len(val) > 0:
ltype = set([GGUFValueType.get_type(item) for item in val]) ltype = GGUFValueType.get_type(val[0])
assert len(ltype) == 1, "All items in a GGUF array should be of the same type" if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
self.kv_data += struct.pack("<I", list(ltype)[0]) raise ValueError("All items in a GGUF array should be of the same type")
self.kv_data += struct.pack("<I", ltype)
self.kv_data += struct.pack("<Q", len(val)) self.kv_data += struct.pack("<Q", len(val))
for item in val: for item in val:
self.add_val(item, add_vtype=False) self.add_val(item, add_vtype=False)
else: else:
raise ValueError("Invalid GGUF metadata value type") raise ValueError("Invalid GGUF metadata value type or value")
@staticmethod @staticmethod
def ggml_pad(x: int, n: int) -> int: def ggml_pad(x: int, n: int) -> int:
return ((x + n - 1) // n) * n return ((x + n - 1) // n) * n
def add_tensor_info(self, name: str, tensor_shape: np.ndarray, tensor_dtype: np.dtype, tensor_nbytes: int, raw_dtype: Optional[GGMLQuantizationType] = None): def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: Union[np.dtype[np.float16], np.dtype[np.float32]], tensor_nbytes: int, raw_dtype: Optional[GGMLQuantizationType] = None):
assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now" assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"
encoded_name = name.encode("utf8") encoded_name = name.encode("utf8")
@ -544,16 +580,18 @@ class GGUFWriter:
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment) self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
self.ti_data_count += 1 self.ti_data_count += 1
def add_tensor(self, name: str, tensor: np.ndarray, raw_shape: Optional[np.ndarray] = None, raw_dtype: Optional[GGMLQuantizationType] = None): def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Optional[Sequence[int]] = None, raw_dtype: Optional[GGMLQuantizationType] = None):
if self.use_temp_file and not hasattr(self, "temp_file"): if self.use_temp_file and self.temp_file is None:
self.temp_file = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024) fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
self.temp_file.seek(0) fp.seek(0)
self.temp_file = fp
self.add_tensor_info(name, raw_shape if raw_shape is not None else tensor.shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype) shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
if not self.use_temp_file: if self.temp_file is None:
self.tensors.append((tensor, pad)) self.tensors.append((tensor, pad))
return return
@ -562,25 +600,22 @@ class GGUFWriter:
if pad != 0: if pad != 0:
self.temp_file.write(bytes([0] * pad)) self.temp_file.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray): def write_padding(self, fp: BinaryIO, n: int, align: Optional[int] = None):
pad = GGUFWriter.ggml_pad(self.fout.tell(), self.data_alignment) - self.fout.tell() pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
if pad != 0: if pad != 0:
self.fout.write(bytes([0] * pad)) fp.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray[Any, Any]):
self.write_padding(self.fout, self.fout.tell())
tensor.tofile(self.fout) tensor.tofile(self.fout)
self.write_padding(self.fout, tensor.nbytes)
pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
if pad != 0:
self.fout.write(bytes([0] * pad))
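Editor's note: write_padding centralizes the alignment math that used to be repeated inline. Assuming the default 32-byte alignment (GGUF_DEFAULT_ALIGNMENT), the padding for a 100-byte tensor works out as below (illustrative values):

def ggml_pad(x: int, n: int) -> int:
    return ((x + n - 1) // n) * n

align = 32                       # GGUF_DEFAULT_ALIGNMENT, assumed here
nbytes = 100
pad = ggml_pad(nbytes, align) - nbytes
assert pad == 28                 # 28 zero bytes written after the tensor data
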
def write_tensors_to_file(self): def write_tensors_to_file(self):
self.write_ti_data_to_file() self.write_ti_data_to_file()
pad = GGUFWriter.ggml_pad(self.fout.tell(), self.data_alignment) - self.fout.tell() self.write_padding(self.fout, self.fout.tell())
if pad != 0:
self.fout.write(bytes([0] * pad))
if not self.use_temp_file: if self.temp_file is None:
for (currtensor, currpad) in self.tensors: for (currtensor, currpad) in self.tensors:
currtensor.tofile(self.fout) currtensor.tofile(self.fout)
if currpad != 0: if currpad != 0:
@ -654,10 +689,6 @@ class GGUFWriter:
self.add_bool( self.add_bool(
KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use) KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
def add_tensor_data_layout(self, layout: str):
self.add_string(
KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_head_count(self, count: int): def add_head_count(self, count: int):
self.add_uint32( self.add_uint32(
KEY_ATTENTION_HEAD_COUNT.format(arch=self.arch), count) KEY_ATTENTION_HEAD_COUNT.format(arch=self.arch), count)
@ -695,16 +726,16 @@ class GGUFWriter:
def add_tokenizer_model(self, model: str): def add_tokenizer_model(self, model: str):
self.add_string(KEY_TOKENIZER_MODEL, model) self.add_string(KEY_TOKENIZER_MODEL, model)
def add_token_list(self, tokens: List): def add_token_list(self, tokens: Union[Sequence[str], Sequence[bytes], Sequence[bytearray]]):
self.add_array(KEY_TOKENIZER_LIST, tokens) self.add_array(KEY_TOKENIZER_LIST, tokens)
def add_token_merges(self, merges: List): def add_token_merges(self, merges: Union[Sequence[str], Sequence[bytes], Sequence[bytearray]]):
self.add_array(KEY_TOKENIZER_MERGES, merges) self.add_array(KEY_TOKENIZER_MERGES, merges)
def add_token_types(self, types: List[int]): def add_token_types(self, types: Union[Sequence[TokenType], Sequence[int]]):
self.add_array(KEY_TOKENIZER_TOKEN_TYPE, types) self.add_array(KEY_TOKENIZER_TOKEN_TYPE, types)
def add_token_scores(self, scores: List[float]): def add_token_scores(self, scores: Sequence[float]):
self.add_array(KEY_TOKENIZER_SCORES, scores) self.add_array(KEY_TOKENIZER_SCORES, scores)
def add_bos_token_id(self, id: int): def add_bos_token_id(self, id: int):
@ -723,6 +754,84 @@ class GGUFWriter:
self.add_uint32(KEY_TOKENIZER_PAD_ID, id) self.add_uint32(KEY_TOKENIZER_PAD_ID, id)
class SpecialVocab:
load_merges: bool = False
merges: List[str] = []
special_token_types: Tuple[str, ...] = tuple(('bos', 'eos', 'unk', 'sep', 'pad'))
special_token_ids: Dict[str, int] = {}
def __init__(self, path: Path, load_merges: bool = False, special_token_types: Optional[Tuple[str, ...]] = None):
self.special_token_ids = {}
self.load_merges = load_merges
if special_token_types is not None:
self.special_token_types = special_token_types
self.load(path)
def load(self, path: Path):
if not self.try_load_from_tokenizer_json(path):
self.try_load_from_config_json(path)
def try_load_from_tokenizer_json(self, path: Path) -> bool:
tokenizer_file = path / 'tokenizer.json'
if not tokenizer_file.is_file():
return False
with open(tokenizer_file, 'r', encoding = 'utf-8') as f:
tokenizer = json.load(f)
if self.load_merges:
merges = tokenizer.get('model', {}).get('merges')
if isinstance(merges, list) and len(merges) > 0 and isinstance(merges[0], str):
self.merges = merges
tokenizer_config_file = path / 'tokenizer_config.json'
added_tokens = tokenizer.get('added_tokens')
if added_tokens is None or not tokenizer_config_file.is_file():
return True
with open(tokenizer_config_file, 'r', encoding = 'utf-8') as f:
tokenizer_config = json.load(f)
for typ in self.special_token_types:
entry = tokenizer_config.get(f'{typ}_token')
if isinstance(entry, str):
tc_content = entry
elif isinstance(entry, dict):
entry_content = entry.get('content')
if not isinstance(entry_content, str):
continue
tc_content = entry_content
else:
continue
for maybe_token_id in (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content):
if isinstance(maybe_token_id, int):
self.special_token_ids[typ] = maybe_token_id
break
return True
def try_load_from_config_json(self, path: Path) -> bool:
config_file = path / 'config.json'
if not config_file.is_file():
return False
with open(config_file, 'r', encoding = 'utf-8') as f:
config = json.load(f)
for typ in self.special_token_types:
maybe_token_id = config.get(f'{typ}_token_id')
if isinstance(maybe_token_id, int):
self.special_token_ids[typ] = maybe_token_id
return True
def add_to_gguf(self, gw: GGUFWriter):
if len(self.merges) > 0:
print(f'gguf: Adding {len(self.merges)} merge(s).')
gw.add_token_merges(self.merges)
for typ, tokid in self.special_token_ids.items():
handler: Optional[Callable[[int], None]] = getattr(gw, f'add_{typ}_token_id', None)
if handler is None:
print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping')
continue
print(f'gguf: Setting special token type {typ} to {tokid}')
handler(tokid)
def __repr__(self):
return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids if self.special_token_ids else "unset"}>'
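Editor's note: a minimal sketch of how the conversion scripts are expected to drive SpecialVocab (paths and file names are illustrative):

special_vocab = SpecialVocab(Path("models/falcon-7b"), load_merges=True)
gw = GGUFWriter("falcon-vocab.gguf", "falcon")
special_vocab.add_to_gguf(gw)   # writes merges and any bos/eos/unk/sep/pad token ids
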
# Example usage: # Example usage:
if __name__ == "__main__": if __name__ == "__main__":
# Example usage with a file # Example usage with a file

gguf-py/gguf/py.typed Normal file

@ -5,6 +5,7 @@ description = "Write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"] authors = ["GGML <ggml@ggml.ai>"]
packages = [ packages = [
{include = "gguf"}, {include = "gguf"},
{include = "gguf/py.typed"},
] ]
readme = "README.md" readme = "README.md"
homepage = "https://ggml.ai" homepage = "https://ggml.ai"