convert : fix python 3.8 support, modernize type annotations (#2916)

* convert : fix python 3.8 support

* convert : sort imports

* convert : fix required parameters in convert-llama-ggmlv3-to-gguf

* convert : fix mypy errors in convert-llama-ggmlv3-to-gguf

* convert : use PEP 585 generics and PEP 604 unions

Now that we have `from __future__ import annotations`, we can use this
modern syntax on Python 3.7 as well, instead of restricting support to
Python 3.9 (PEP 585) or 3.10 (PEP 604) respectively; a brief sketch of the
resulting style follows the commit metadata below.

* gguf.py : a tuple is already a tuple

* add mypy.ini

* convert : add necessary `type: ignore` comments

* gguf-py: bump version

Cebtenzzre authored 2023-08-31 01:02:23 -04:00, committed by GitHub
commit 92d0b751a7 (parent 8afe228000)
10 changed files with 193 additions and 168 deletions
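
A minimal sketch of the annotation style this commit adopts (the `scale` function is invented for illustration; the surrounding pattern mirrors the converted scripts). With `from __future__ import annotations`, the PEP 585 / PEP 604 syntax below is stored as strings and never evaluated, so the file still parses and runs on Python 3.7 / 3.8:

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
    # typing.TypeAlias only exists on 3.10+, so import it for the type checker only
    from typing import TypeAlias

# runtime-safe alias: the right-hand side stays a plain string
NDArray: TypeAlias = 'np.ndarray[Any, Any]'

# untyped third-party imports get a narrow ignore, as in the diff:
# from transformers import AutoTokenizer  # type: ignore[import]

def scale(values: list[float], factor: float | None = None) -> dict[str, float]:
    # list[...] / dict[...] (PEP 585) and `X | None` (PEP 604) appear only in
    # annotations, which the __future__ import keeps unevaluated at runtime
    f = 1.0 if factor is None else factor
    return {str(i): v * f for i, v in enumerate(values)}

print(scale([1.0, 2.0], factor=2.0))
```

The same future import is what allows the `Optional[...]`, `Union[...]`, `List[...]`, and `Dict[...]` imports to be dropped throughout the diff below.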

@ -1,18 +1,21 @@
#!/usr/bin/env python3
# HF falcon--> gguf conversion
import gguf
import os
import sys
import struct
from __future__ import annotations
import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import Any
import gguf
import numpy as np
import torch
import argparse
from transformers import AutoTokenizer # type: ignore[import]
from typing import Any, List
from pathlib import Path
from transformers import AutoTokenizer
def bytes_to_unicode():
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
@ -114,9 +117,9 @@ gguf_writer.add_file_type(ftype)
print("gguf: get tokenizer metadata")
tokens: List[bytearray] = []
scores: List[float] = []
toktypes: List[int] = []
tokens: list[bytearray] = []
scores: list[float] = []
toktypes: list[int] = []
tokenizer_json_file = dir_model / 'tokenizer.json'
if not tokenizer_json_file.is_file():

@ -1,18 +1,20 @@
#!/usr/bin/env python3
# HF gptneox--> gguf conversion
import gguf
import os
import sys
import struct
from __future__ import annotations
import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import Any
import gguf
import numpy as np
import torch
import argparse
from typing import Any, List
from pathlib import Path
from transformers import AutoTokenizer
from transformers import AutoTokenizer # type: ignore[import]
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
@ -112,7 +114,7 @@ gguf_writer.add_layer_norm_eps(hparams["layer_norm_eps"])
print("gguf: get tokenizer metadata")
tokens: List[bytearray] = []
tokens: list[bytearray] = []
tokenizer_json_file = dir_model / 'tokenizer.json'
if not tokenizer_json_file.is_file():

@ -3,22 +3,25 @@
# Only models with a single datafile are supported, like 7B
# HF files required in the model dir: config.json tokenizer_config.json tokenizer.json tokenizer.model
import gguf
import os
import sys
import struct
from __future__ import annotations
import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any
import gguf
import numpy as np
import torch
import argparse
from sentencepiece import SentencePieceProcessor # type: ignore[import]
from typing import Any, List, TypeAlias
from pathlib import Path
from sentencepiece import SentencePieceProcessor
if TYPE_CHECKING:
from typing import TypeAlias
#NDArray = np.ndarray[Any, Any]
# compatible with python < 3.9
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
NDArray: TypeAlias = 'np.ndarray[Any, Any]'
def count_model_parts(dir_model: Path) -> int:
@ -129,9 +132,9 @@ if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in
print("gguf: get tokenizer metadata")
tokens: List[bytes] = []
scores: List[float] = []
toktypes: List[int] = []
tokens: list[bytes] = []
scores: list[float] = []
toktypes: list[int] = []
tokenizer_model_file = dir_model / 'tokenizer.model'
if not tokenizer_model_file.is_file():

@ -1,10 +1,14 @@
#!/usr/bin/env python3
import sys, struct, math, argparse
from __future__ import annotations
import argparse
import math
import struct
import sys
from pathlib import Path
import numpy as np
import gguf
import numpy as np
# Note: Does not support GGML_QKK_64
QK_K = 256
@ -72,7 +76,7 @@ class Vocab:
class Tensor:
def __init__(self):
self.name = None
self.dims = ()
self.dims: tuple[int, ...] = ()
self.dtype = None
self.start_offset = 0
self.len_bytes = np.int64(0)
@ -119,7 +123,7 @@ class GGMLV3Model:
offset += hp.load(data, offset)
vocab = Vocab()
offset += vocab.load(data, offset, hp.n_vocab)
tensors = []
tensors: list[Tensor] = []
tensor_map = {}
while offset < len(data):
tensor = Tensor()
@ -305,8 +309,8 @@ def handle_metadata(cfg, hp):
def handle_args():
parser = argparse.ArgumentParser(description = 'Convert GGMLv3 models to GGUF')
parser.add_argument('--input', '-i', type = Path, help = 'Input GGMLv3 filename')
parser.add_argument('--output', '-o', type = Path, help ='Output GGUF filename')
parser.add_argument('--input', '-i', type = Path, required = True, help = 'Input GGMLv3 filename')
parser.add_argument('--output', '-o', type = Path, required = True, help ='Output GGUF filename')
parser.add_argument('--name', help = 'Set model name')
parser.add_argument('--desc', help = 'Set model description')
parser.add_argument('--gqa', type = int, default = 1, help = 'grouped-query attention factor (use 8 for LLaMA2 70B)')

@ -1,28 +1,31 @@
#!/usr/bin/env python3
# HF llama --> gguf conversion
import gguf
import os
import sys
import struct
from __future__ import annotations
import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any
import gguf
import numpy as np
import torch
import argparse
from sentencepiece import SentencePieceProcessor # type: ignore[import]
from typing import Any, List, Optional, TypeAlias
from pathlib import Path
from sentencepiece import SentencePieceProcessor
if TYPE_CHECKING:
from typing import TypeAlias
#NDArray = np.ndarray[Any, Any]
# compatible with python < 3.9
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
NDArray: TypeAlias = 'np.ndarray[Any, Any]'
# reverse HF permute back to original pth layout
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: int | None = None) -> NDArray:
if n_kv_head is not None and n_head != n_kv_head:
n_head //= n_kv_head
@ -136,9 +139,9 @@ if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in
print("gguf: get tokenizer metadata")
tokens: List[bytes] = []
scores: List[float] = []
toktypes: List[int] = []
tokens: list[bytes] = []
scores: list[float] = []
toktypes: list[int] = []
tokenizer_model_file = dir_model / 'tokenizer.model'
if not tokenizer_model_file.is_file():

@ -1,15 +1,17 @@
#!/usr/bin/env python3
from __future__ import annotations
import json
import os
import re
import struct
import sys
from typing import Any, Dict, Sequence, BinaryIO
from typing import Any, BinaryIO, Sequence
import numpy as np
import torch
NUMPY_TYPE_TO_FTYPE: Dict[str, int] = {"float32": 0, "float16": 1}
NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}
HF_SUBLAYER_TO_GGML = {
@ -46,7 +48,7 @@ def translate_tensor_name(t: str) -> str:
sys.exit(1)
def write_file_header(fout: BinaryIO, params: Dict[str, Any]) -> None:
def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
fout.write(b"ggla"[::-1]) # magic (ggml lora)
fout.write(struct.pack("i", 1)) # file version
fout.write(struct.pack("i", params["r"]))

@ -1,9 +1,8 @@
#!/usr/bin/env python3
from __future__ import annotations
import gguf
import argparse
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import copy
import enum
import faulthandler
@ -20,21 +19,23 @@ import struct
import sys
import time
import zipfile
import numpy as np
from abc import ABCMeta, abstractmethod
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, Type, TypeVar, Union)
from sentencepiece import SentencePieceProcessor # type: ignore
from typing import IO, TYPE_CHECKING, Any, Callable, Generator, Iterable, Literal, Sequence, TypeVar
import gguf
import numpy as np
from sentencepiece import SentencePieceProcessor # type: ignore[import]
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from typing import TypeAlias
if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
faulthandler.register(signal.SIGUSR1)
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
NDArray: TypeAlias = 'np.ndarray[Any, Any]'
ARCH=gguf.MODEL_ARCH.LLAMA
NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
@ -47,8 +48,8 @@ DEFAULT_CONCURRENCY = 8
@dataclass(frozen=True)
class DataType:
name: str
dtype: 'np.dtype[Any]'
valid_conversions: List[str]
dtype: np.dtype[Any]
valid_conversions: list[str]
def elements_to_bytes(self, n_elements: int) -> int:
return n_elements * self.dtype.itemsize
@ -65,7 +66,7 @@ DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_convers
@dataclass(frozen=True)
class QuantizedDataType(DataType):
block_size: int
quantized_dtype: 'np.dtype[Any]'
quantized_dtype: np.dtype[Any]
ggml_type: gguf.GGMLQuantizationType
def quantize(self, arr: NDArray) -> NDArray:
@ -84,7 +85,7 @@ class Q8_0QuantizedDataType(QuantizedDataType):
n_blocks = arr.size // self.block_size
blocks = arr.reshape((n_blocks, self.block_size))
# Much faster implementation of block quantization contributed by @Cebtenzzre
def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[Tuple[Any, Any]]:
def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[tuple[Any, Any]]:
d = abs(blocks).max(axis = 1) / np.float32(127)
with np.errstate(divide = 'ignore'):
qs = (blocks / d[:, None]).round()
@ -98,13 +99,13 @@ DT_Q8_0 = Q8_0QuantizedDataType('Q8_0',
quantized_dtype = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))
# Quantized types skipped here because they may also map to np.float32
NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = {}
NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {}
for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
raise ValueError(f'Invalid duplicate data type {dt}')
NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt
SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
SAFETENSORS_DATA_TYPES: dict[str, DataType] = {
'BF16': DT_BF16,
'F16': DT_F16,
'F32': DT_F32,
@ -119,14 +120,14 @@ class GGMLFileType(enum.IntEnum):
MostlyF16 = 1 # except 1d tensors
MostlyQ8_0 = 7 # except 1d tensors
def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType:
dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
if dt is None:
raise ValueError(self)
# 1D tensors are always F32.
return dt if len(tensor.shape) > 1 else DT_F32
GGML_FILE_TYPE_TO_DATA_TYPE: Dict[GGMLFileType, DataType] = {
GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = {
GGMLFileType.AllF32 : DT_F32,
GGMLFileType.MostlyF16 : DT_F16,
GGMLFileType.MostlyQ8_0: DT_Q8_0,
@ -148,13 +149,13 @@ class Params:
n_head_kv: int
f_norm_eps: float
f_rope_freq_base: Optional[float] = None
f_rope_scale: Optional[float] = None
f_rope_freq_base: float | None = None
f_rope_scale: float | None = None
ftype: Optional[GGMLFileType] = None
ftype: GGMLFileType | None = None
# path to the directory containing the model files
path_model: Optional['Path'] = None
path_model: Path | None = None
@staticmethod
def find_n_mult(n_ff: int, n_embd: int) -> int:
@ -166,7 +167,7 @@ class Params:
raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
@staticmethod
def guessed(model: 'LazyModel') -> 'Params':
def guessed(model: LazyModel) -> Params:
# try transformer naming first
n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape
@ -202,7 +203,7 @@ class Params:
)
@staticmethod
def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
config = json.load(open(config_path))
n_vocab = config["vocab_size"]
@ -247,7 +248,7 @@ class Params:
# LLaMA v2 70B params.json
# {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1
@staticmethod
def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
config = json.load(open(config_path))
n_vocab = config["vocab_size"] if "vocab_size" in config else -1
@ -291,7 +292,7 @@ class Params:
)
@staticmethod
def load(model_plus: 'ModelPlus') -> 'Params':
def load(model_plus: ModelPlus) -> Params:
hf_config_path = model_plus.paths[0].parent / "config.json"
orig_config_path = model_plus.paths[0].parent / "params.json"
@ -314,9 +315,9 @@ class Params:
#
class BpeVocab:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
added_tokens: Dict[str, int]
added_tokens: dict[str, int]
if fname_added_tokens is not None:
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
else:
@ -335,9 +336,9 @@ class BpeVocab:
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
def bpe_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]:
def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
tokenizer = self.bpe_tokenizer
from transformers.models.gpt2 import tokenization_gpt2
from transformers.models.gpt2 import tokenization_gpt2 # type: ignore[import]
byte_encoder = tokenization_gpt2.bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
for i, item in enumerate(tokenizer):
@ -345,12 +346,12 @@ class BpeVocab:
score: float = -i
yield text, score, gguf.TokenType.USER_DEFINED
def added_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]:
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
for text in self.added_tokens_list:
score = -1000.0
yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
def all_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]:
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
yield from self.bpe_tokens()
yield from self.added_tokens()
@ -359,9 +360,9 @@ class BpeVocab:
class SentencePieceVocab:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
added_tokens: Dict[str, int]
added_tokens: dict[str, int]
if fname_added_tokens is not None:
added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
else:
@ -380,7 +381,7 @@ class SentencePieceVocab:
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]:
def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
tokenizer = self.sentencepiece_tokenizer
for i in range(tokenizer.vocab_size()):
piece = tokenizer.id_to_piece(i)
@ -404,19 +405,19 @@ class SentencePieceVocab:
yield text, score, toktype
def added_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]:
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
for text in self.added_tokens_list:
score = -1000.0
yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
def all_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]:
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
yield from self.sentencepiece_tokens()
yield from self.added_tokens()
def __repr__(self) -> str:
return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
Vocab = Union[BpeVocab, SentencePieceVocab]
Vocab: TypeAlias = 'BpeVocab | SentencePieceVocab'
#
# data loading
@ -436,15 +437,15 @@ class Tensor(metaclass=ABCMeta):
data_type: DataType
@abstractmethod
def astype(self, data_type: DataType) -> 'Tensor': ...
def astype(self, data_type: DataType) -> Tensor: ...
@abstractmethod
def permute(self, n_head: int, n_head_kv: int) -> 'Tensor': ...
def permute(self, n_head: int, n_head_kv: int) -> Tensor: ...
@abstractmethod
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> 'UnquantizedTensor': ...
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: ...
@abstractmethod
def part(self, n_part: int) -> 'UnquantizedTensor': ...
def part(self, n_part: int) -> UnquantizedTensor: ...
@abstractmethod
def to_ggml(self) -> 'GGMLCompatibleTensor': ...
def to_ggml(self) -> GGMLCompatibleTensor: ...
def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray:
@ -465,22 +466,22 @@ class UnquantizedTensor(Tensor):
self.ndarray = bf16_to_fp32(self.ndarray)
return UnquantizedTensor(self.ndarray.astype(dtype))
def to_ggml(self) -> 'UnquantizedTensor':
def to_ggml(self) -> UnquantizedTensor:
return self
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> 'UnquantizedTensor':
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor:
r = self.ndarray.shape[0] // 3
return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv))
def part(self, n_part: int) -> 'UnquantizedTensor':
def part(self, n_part: int) -> UnquantizedTensor:
r = self.ndarray.shape[0] // 3
return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
def permute(self, n_head: int, n_head_kv: int) -> 'UnquantizedTensor':
def permute(self, n_head: int, n_head_kv: int) -> UnquantizedTensor:
return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv))
def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray:
def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False) -> NDArray:
tensor = lazy_tensor.load()
assert isinstance(tensor, UnquantizedTensor)
@ -496,13 +497,13 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv
return tensor.ndarray
GGMLCompatibleTensor = Union[UnquantizedTensor]
GGMLCompatibleTensor = UnquantizedTensor
@dataclass
class LazyTensor:
_load: Callable[[], Tensor]
shape: List[int]
shape: list[int]
data_type: DataType
description: str
@ -513,7 +514,7 @@ class LazyTensor:
(self.data_type, ret.data_type, self.description)
return ret
def astype(self, data_type: DataType) -> 'LazyTensor':
def astype(self, data_type: DataType) -> LazyTensor:
self.validate_conversion_to(data_type)
def load() -> Tensor:
@ -525,24 +526,24 @@ class LazyTensor:
raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.')
LazyModel = Dict[str, LazyTensor]
LazyModel = dict[str, LazyTensor]
@dataclass
class ModelPlus:
model: LazyModel
paths: List[Path] # Where this was read from.
paths: list[Path] # Where this was read from.
format: Literal['ggml', 'torch', 'safetensors', 'none']
vocab: Optional[Vocab] # For GGML models (which have vocab built in), the vocab.
vocab: Vocab | None # For GGML models (which have vocab built in), the vocab.
def merge_sharded(models: List[LazyModel]) -> LazyModel:
def merge_sharded(models: list[LazyModel]) -> LazyModel:
# Original LLaMA models have each file contain one part of each tensor.
# Use a dict instead of a set to preserve order.
names = {name: None for model in models for name in model}
def convert(name: str) -> LazyTensor:
lazy_tensors: List[LazyTensor] = [model[name] for model in models]
lazy_tensors: list[LazyTensor] = [model[name] for model in models]
if len(lazy_tensors) == 1:
# only one file; don't go through this procedure since there might
# be quantized tensors
@ -570,7 +571,7 @@ def merge_sharded(models: List[LazyModel]) -> LazyModel:
return {name: convert(name) for name in names}
def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus:
def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
formats = set(mp.format for mp in models_plus)
assert len(formats) == 1, "different formats?"
format = formats.pop()
@ -674,7 +675,7 @@ class LazyUnpickler(pickle.Unpickler):
def rebuild_from_type_v2(func, new_type, args, state):
return func(*args)
CLASSES: Dict[Tuple[str, str], Any] = {
CLASSES: dict[tuple[str, str], Any] = {
# getattr used here as a workaround for mypy not being smart enough to detrmine
# the staticmethods have a __func__ attribute.
('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
@ -707,15 +708,15 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
header_size, = struct.unpack('<Q', fp.read(8))
header: Dict[str, Dict[str, Any]] = json.loads(fp.read(header_size))
header: dict[str, dict[str, Any]] = json.loads(fp.read(header_size))
# Use mmap for the actual data to avoid race conditions with the file offset.
mapped = memoryview(mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ))
byte_buf = mapped[8 + header_size:]
def convert(info: Dict[str, Any]) -> LazyTensor:
def convert(info: dict[str, Any]) -> LazyTensor:
data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
numpy_dtype = data_type.dtype
shape: List[int] = info['shape']
shape: list[int] = info['shape']
begin, end = info['data_offsets']
assert 0 <= begin <= end <= len(byte_buf)
assert end - begin == math.prod(shape) * numpy_dtype.itemsize
@ -754,7 +755,7 @@ def lazy_load_file(path: Path) -> ModelPlus:
In = TypeVar('In')
Out = TypeVar('Out')
def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, use_processpool_executor: bool = False) -> Iterable[Out]:
def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: int | None = None, use_processpool_executor: bool = False) -> Iterable[Out]:
'''Parallel map, but with backpressure. If the caller doesn't call `next`
fast enough, this will stop calling `func` at some point rather than
letting results pile up in memory. Specifically, there is a max of one
@ -763,13 +764,13 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
yield from map(func, iterable)
# Not reached.
iterable = iter(iterable)
executor_class: Union[Type[ThreadPoolExecutor], Type[ProcessPoolExecutor]]
executor_class: type[ThreadPoolExecutor] | type[ProcessPoolExecutor]
if use_processpool_executor:
executor_class = ProcessPoolExecutor
else:
executor_class = ThreadPoolExecutor
with executor_class(max_workers = max_workers) as executor:
futures: List[concurrent.futures.Future[Out]] = []
futures: list[concurrent.futures.Future[Out]] = []
done = False
for _ in range(concurrency):
try:
@ -893,13 +894,13 @@ class OutputFile:
of.close()
@staticmethod
def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]:
def do_item(item: tuple[str, LazyTensor]) -> tuple[DataType, NDArray]:
name, lazy_tensor = item
tensor = lazy_tensor.load().to_ggml()
return (lazy_tensor.data_type, tensor.ndarray)
@staticmethod
def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray:
dt, arr = item
if not isinstance(dt, QuantizedDataType):
return arr
@ -940,7 +941,7 @@ class OutputFile:
of.close()
def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
wq_type = model[NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
@ -960,7 +961,7 @@ def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyM
def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
tmap = gguf.TensorNameMap(ARCH, params.n_layer)
should_skip: Set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
should_skip: set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
tmp = model
@ -995,12 +996,12 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
return out
def nth_multifile_path(path: Path, n: int) -> Optional[Path]:
def nth_multifile_path(path: Path, n: int) -> Path | None:
'''Given any path belonging to a multi-file model (e.g. foo.bin.1), return
the nth path in the model.
'''
# Support the following patterns:
patterns: List[Tuple[str, str]] = [
patterns: list[tuple[str, str]] = [
# - x.00.pth, x.01.pth, etc.
(r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'),
# - x-00001-of-00002.bin, x-00002-of-00002.bin, etc.
@ -1016,11 +1017,11 @@ def nth_multifile_path(path: Path, n: int) -> Optional[Path]:
return None
def find_multifile_paths(path: Path) -> List[Path]:
def find_multifile_paths(path: Path) -> list[Path]:
'''Given any path belonging to a multi-file model (e.g. foo.bin.1), return
the whole list of paths in the model.
'''
ret: List[Path] = []
ret: list[Path] = []
for i in itertools.count():
nth_path = nth_multifile_path(path, i)
if nth_path is None:
@ -1051,7 +1052,7 @@ def load_some_model(path: Path) -> ModelPlus:
path = files[0]
paths = find_multifile_paths(path)
models_plus: List[ModelPlus] = []
models_plus: list[ModelPlus] = []
for path in paths:
print(f"Loading model file {path}")
models_plus.append(lazy_load_file(path))
@ -1060,7 +1061,7 @@ def load_some_model(path: Path) -> ModelPlus:
return model_plus
def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, SentencePieceVocab]:
def load_vocab(path: Path, vocabtype: str | None) -> Vocab:
# Be extra-friendly and accept either a file or a directory. Also, if it's
# a directory, it might be the model directory, and tokenizer.model might
# be in the parent of that.
@ -1091,7 +1092,7 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, Sentence
raise ValueError(f"Unsupported vocabulary type {vocabtype}")
def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
namestr = {
GGMLFileType.AllF32: "f32",
GGMLFileType.MostlyF16: "f16",
@ -1114,7 +1115,7 @@ def do_dump_model(model_plus: ModelPlus) -> None:
print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}")
def main(args_in: Optional[List[str]] = None) -> None:
def main(args_in: list[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")

@ -1,16 +1,18 @@
#!/usr/bin/env python3
import shutil
import sys
import struct
import tempfile
import numpy as np
from __future__ import annotations
import json
import os
from pathlib import Path
import shutil
import struct
import sys
import tempfile
from enum import IntEnum, auto
from io import BufferedWriter
from typing import Any, BinaryIO, Callable, IO, Dict, List, Optional, Sequence, Tuple, Union
from pathlib import Path
from typing import IO, Any, BinaryIO, Callable, Sequence
import numpy as np
#
# constants
@ -103,7 +105,7 @@ class MODEL_TENSOR(IntEnum):
FFN_NORM : int = auto()
MODEL_ARCH_NAMES: Dict[MODEL_ARCH, str] = {
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.GPT2: "gpt2",
@ -112,7 +114,7 @@ MODEL_ARCH_NAMES: Dict[MODEL_ARCH, str] = {
MODEL_ARCH.MPT: "mpt",
}
MODEL_TENSOR_NAMES: Dict[MODEL_ARCH, Dict[MODEL_TENSOR, str]] = {
MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
MODEL_ARCH.LLAMA: {
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
@ -158,7 +160,7 @@ MODEL_TENSOR_NAMES: Dict[MODEL_ARCH, Dict[MODEL_TENSOR, str]] = {
}
# tensors that will not be serialized
MODEL_TENSOR_SKIP: Dict[MODEL_ARCH, List[MODEL_TENSOR]] = {
MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_ARCH.LLAMA: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
@ -167,7 +169,7 @@ MODEL_TENSOR_SKIP: Dict[MODEL_ARCH, List[MODEL_TENSOR]] = {
class TensorNameMap:
mappings_cfg: Dict[MODEL_TENSOR, Tuple[str, ...]] = {
mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
@ -203,7 +205,7 @@ class TensorNameMap:
),
}
block_mappings_cfg: Dict[MODEL_TENSOR, Tuple[str, ...]] = {
block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
# Attention norm
MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
@ -298,9 +300,9 @@ class TensorNameMap:
),
}
mapping: Dict[str, Tuple[MODEL_TENSOR, str]]
mapping: dict[str, tuple[MODEL_TENSOR, str]]
tensor_names: Dict[MODEL_TENSOR, str]
tensor_names: dict[MODEL_TENSOR, str]
def __init__(self, arch: MODEL_ARCH, n_blocks: int):
mapping = self.mapping = {}
@ -321,7 +323,7 @@ class TensorNameMap:
key = key.format(bid = bid)
mapping[key] = (tensor, tensor_name)
def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[Tuple[MODEL_TENSOR, str]]:
def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> tuple[MODEL_TENSOR, str] | None:
result = self.mapping.get(key)
if result is not None:
return result
@ -332,13 +334,13 @@ class TensorNameMap:
return (result[0], result[1] + suffix)
return None
def get_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[str]:
def get_name(self, key: str, try_suffixes: Sequence[str]) -> str | None:
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
if result is None:
return None
return result[1]
def get_type(self, key: str, try_suffixes: Sequence[str]) -> Optional[MODEL_TENSOR]:
def get_type(self, key: str, try_suffixes: Sequence[str]) -> MODEL_TENSOR | None:
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
if result is None:
return None
@ -432,10 +434,10 @@ class GGUFWriter:
ti_data = b""
ti_data_count = 0
use_temp_file: bool
temp_file: Optional[tempfile.SpooledTemporaryFile[bytes]] = None
tensors: List[Tuple[np.ndarray[Any, Any], int]]
temp_file: tempfile.SpooledTemporaryFile[bytes] | None = None
tensors: list[tuple[np.ndarray[Any, Any], int]]
def __init__(self, path: Union[os.PathLike[str], str], arch: str, use_temp_file = True):
def __init__(self, path: os.PathLike[str] | str, arch: str, use_temp_file = True):
self.fout = open(path, "wb")
self.arch = arch
self.add_architecture()
@ -531,7 +533,7 @@ class GGUFWriter:
GGUFValueType.FLOAT64: "<d",
GGUFValueType.BOOL: "?" ,
}
def add_val(self, val: Any, vtype: Optional[GGUFValueType] = None, add_vtype: bool = True):
def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True):
if vtype is None:
vtype = GGUFValueType.get_type(val)
@ -561,7 +563,7 @@ class GGUFWriter:
def ggml_pad(x: int, n: int) -> int:
return ((x + n - 1) // n) * n
def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: Union[np.dtype[np.float16], np.dtype[np.float32]], tensor_nbytes: int, raw_dtype: Optional[GGMLQuantizationType] = None):
def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32], tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None):
assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"
encoded_name = name.encode("utf8")
@ -580,7 +582,7 @@ class GGUFWriter:
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
self.ti_data_count += 1
def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Optional[Sequence[int]] = None, raw_dtype: Optional[GGMLQuantizationType] = None):
def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None):
if self.use_temp_file and self.temp_file is None:
fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
fp.seek(0)
@ -600,7 +602,7 @@ class GGUFWriter:
if pad != 0:
self.temp_file.write(bytes([0] * pad))
def write_padding(self, fp: BinaryIO, n: int, align: Optional[int] = None):
def write_padding(self, fp: BinaryIO, n: int, align: int | None = None):
pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
if pad != 0:
fp.write(bytes([0] * pad))
@ -726,13 +728,13 @@ class GGUFWriter:
def add_tokenizer_model(self, model: str):
self.add_string(KEY_TOKENIZER_MODEL, model)
def add_token_list(self, tokens: Union[Sequence[str], Sequence[bytes], Sequence[bytearray]]):
def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]):
self.add_array(KEY_TOKENIZER_LIST, tokens)
def add_token_merges(self, merges: Union[Sequence[str], Sequence[bytes], Sequence[bytearray]]):
def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]):
self.add_array(KEY_TOKENIZER_MERGES, merges)
def add_token_types(self, types: Union[Sequence[TokenType], Sequence[int]]):
def add_token_types(self, types: Sequence[TokenType] | Sequence[int]):
self.add_array(KEY_TOKENIZER_TOKEN_TYPE, types)
def add_token_scores(self, scores: Sequence[float]):
@ -756,11 +758,11 @@ class GGUFWriter:
class SpecialVocab:
load_merges: bool = False
merges: List[str] = []
special_token_types: Tuple[str, ...] = tuple(('bos', 'eos', 'unk', 'sep', 'pad'))
special_token_ids: Dict[str, int] = {}
merges: list[str] = []
special_token_types: tuple[str, ...] = ('bos', 'eos', 'unk', 'sep', 'pad')
special_token_ids: dict[str, int] = {}
def __init__(self, path: Path, load_merges: bool = False, special_token_types: Optional[Tuple[str, ...]] = None):
def __init__(self, path: Path, load_merges: bool = False, special_token_types: tuple[str, ...] | None = None):
self.special_token_ids = {}
self.load_merges = load_merges
if special_token_types is not None:
@ -821,7 +823,7 @@ class SpecialVocab:
print(f'gguf: Adding {len(self.merges)} merge(s).')
gw.add_token_merges(self.merges)
for typ, tokid in self.special_token_ids.items():
handler: Optional[Callable[[int], None]] = getattr(gw, f'add_{typ}_token_id', None)
handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
if handler is None:
print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping')
continue

@ -1,6 +1,6 @@
[tool.poetry]
name = "gguf"
version = "0.2.1"
version = "0.3.1"
description = "Write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"]
packages = [

mypy.ini (new file, 5 lines)

@ -0,0 +1,5 @@
[mypy]
strict = true
allow_untyped_calls = true
allow_untyped_defs = true
allow_incomplete_defs = true
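
A note on the new mypy.ini, assuming standard mypy flag semantics: `strict = true` turns on the full strict flag set, and the three `allow_*` overrides then re-permit untyped or partially annotated function definitions and calls to them, so the fully annotated conversion code is checked strictly while leftover untyped helpers still pass. A hypothetical example (not from this commit) that this configuration accepts:

```python
def legacy_helper(x):            # no annotations: tolerated via allow_untyped_defs
    return x + 1

def partially_typed(x, flag: bool = False):   # tolerated via allow_incomplete_defs
    return x if flag else None

def typed_entry(value: int) -> int:
    legacy_helper(value)         # untyped call: tolerated via allow_untyped_calls
    return value + 1
```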