#!/usr/bin/env python3 import sys, struct, math, argparse from pathlib import Path import numpy as np import gguf # Note: Does not support GGML_QKK_64 QK_K = 256 # Items here are (block size, type size) GGML_QUANT_SIZES = { gguf.GGMLQuantizationType.F32 : (1, 4), gguf.GGMLQuantizationType.F16 : (1, 2), gguf.GGMLQuantizationType.Q4_0 : (32, 2 + 16), gguf.GGMLQuantizationType.Q4_1 : (32, 2 + 2 + 16), gguf.GGMLQuantizationType.Q5_0 : (32, 2 + 4 + 16), gguf.GGMLQuantizationType.Q5_1 : (32, 2 + 2 + 4 + 16), gguf.GGMLQuantizationType.Q8_0 : (32, 2 + 32), gguf.GGMLQuantizationType.Q8_1 : (32, 4 + 4 + 32), gguf.GGMLQuantizationType.Q2_K : (256, 2 + 2 + QK_K // 16 + QK_K // 4), gguf.GGMLQuantizationType.Q3_K : (256, 2 + QK_K // 4 + QK_K // 8 + 12), gguf.GGMLQuantizationType.Q4_K : (256, 2 + 2 + QK_K // 2 + 12), gguf.GGMLQuantizationType.Q5_K : (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), gguf.GGMLQuantizationType.Q6_K : (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), gguf.GGMLQuantizationType.Q8_K : (256, 4 + QK_K + QK_K // 8), } class Hyperparameters: def __init__(self): self.n_vocab = self.n_embd = self.n_mult = self.n_head = self.n_layer = self.n_rot = self.ftype = 0 self.n_ff = 0 def set_n_ff(self, model): ff_tensor_idx = model.tensor_map.get(b'layers.0.feed_forward.w1.weight') assert ff_tensor_idx is not None, 'Missing layer 0 FF tensor' ff_tensor = model.tensors[ff_tensor_idx] self.n_ff = ff_tensor.dims[1] def load(self, data, offset): ( self.n_vocab, self.n_embd, self.n_mult, self.n_head, self.n_layer, self.n_rot, self.ftype, ) = struct.unpack('<7I', data[offset:offset + (4 * 7)]) return 4 * 7 def __str__(self): return f'' class Vocab: def __init__(self): self.items = [] def load(self, data, offset, n_vocab): orig_offset = offset for _ in range(n_vocab): itemlen = struct.unpack('= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}' assert name_len < 4096, 'Absurd tensor name length' quant = GGML_QUANT_SIZES.get(dtype) assert quant is not None, 'Unknown tensor type' 
(blksize, tysize) = quant offset += 12 self.dtype= dtype self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)]) offset += 4 * n_dims self.name = bytes(data[offset:offset + name_len]) offset += name_len pad = ((offset + 31) & ~31) - offset offset += pad n_elems = np.prod(self.dims) n_bytes = np.int64(np.int64(n_elems) * np.int64(tysize)) // np.int64(blksize) self.start_offset = offset self.len_bytes = n_bytes offset += n_bytes # print(n_dims, name_len, dtype, self.dims, self.name, pad) return offset - orig_offset class GGMLV3Model: def __init__(self): self.hyperparameters = None self.vocab = None self.tensor_map = {} self.tensors = [] def validate_header(self, data, offset): if bytes(data[offset:offset + 4]) != b'tjgg' or struct.unpack(' 0: gguf_writer.add_token_types(toktypes) return print(f'* Adding {hp.n_vocab} vocab item(s)') assert len(self.model.vocab.items) >= 3, 'Cannot handle unexpectedly short model vocab' for (tokid, (vbytes, vscore)) in enumerate(self.model.vocab.items): tt = 1 # Normal # Special handling for UNK, BOS, EOS tokens. 
if tokid <= 2: if tokid == 0: vbytes = b'' tt = 2 elif tokid == 1: vbytes = b'' tt = 3 else: vbytes = b'' tt = 3 elif len(vbytes) == 0: tt = 3 # Control elif tokid >= 3 and tokid <= 258 and len(vbytes) == 1: vbytes = bytes(f'<0x{vbytes[0]:02X}>', encoding = 'UTF-8') tt = 6 # Byte else: vbytes = vbytes.replace(b' ', b'\xe2\x96\x81') toktypes.append(tt) tokens.append(vbytes) scores.append(vscore) gguf_writer.add_token_list(tokens) gguf_writer.add_token_scores(scores) gguf_writer.add_token_types(toktypes) gguf_writer.add_unk_token_id(0) gguf_writer.add_bos_token_id(1) gguf_writer.add_eos_token_id(2) def add_tensors(self, gguf_writer): tensor_map = self.name_map data = self.data print(f'* Adding {len(self.model.tensors)} tensor(s)') for tensor in self.model.tensors: name = str(tensor.name, 'UTF-8') mapped_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias")) assert mapped_name is not None, f'Bad name {name}' tempdims = list(tensor.dims[:]) if len(tempdims) > 1: temp = tempdims[1] tempdims[1] = tempdims[0] tempdims[0] = temp # print(f'+ {tensor.name} | {mapped_name} {tensor.dims} :: {tempdims}') gguf_writer.add_tensor(mapped_name, data[tensor.start_offset:tensor.start_offset + tensor.len_bytes], raw_shape = tempdims, raw_dtype = tensor.dtype) def handle_metadata(cfg, hp): import convert assert cfg.model_metadata_dir.is_dir(), 'Metadata dir is not a directory' hf_config_path = cfg.model_metadata_dir / "config.json" orig_config_path = cfg.model_metadata_dir / "params.json" # We pass a fake model here. "original" mode will check the shapes of some # tensors if information is missing in the .json file: other than that, the # model data isn't used so this should be safe (at least for now). 
fakemodel = { 'tok_embeddings.weight': convert.LazyTensor.__new__(convert.LazyTensor), 'layers.0.feed_forward.w1.weight': convert.LazyTensor.__new__(convert.LazyTensor), } fakemodel['tok_embeddings.weight'].shape = [hp.n_vocab] fakemodel['layers.0.feed_forward.w1.weight'].shape = [hp.n_ff] if hf_config_path.exists(): params = convert.Params.loadHFTransformerJson(fakemodel, hf_config_path) elif orig_config_path.exists(): params = convert.Params.loadOriginalParamsJson(fakemodel, orig_config_path) else: raise ValueError('Unable to load metadata') vocab = convert.load_vocab(cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir, cfg.vocabtype) # FIXME: Respect cfg.vocab_dir? svocab = gguf.SpecialVocab(cfg.model_metadata_dir) convert.check_vocab_size(params, vocab) return (params, vocab, svocab) def handle_args(): parser = argparse.ArgumentParser(description = 'Convert GGMLv3 models to GGUF') parser.add_argument('--input', '-i', type = Path, help = 'Input GGMLv3 filename') parser.add_argument('--output', '-o', type = Path, help ='Output GGUF filename') parser.add_argument('--name', help = 'Set model name') parser.add_argument('--desc', help = 'Set model description') parser.add_argument('--gqa', type = int, default = 1, help = 'grouped-query attention factor (use 8 for LLaMA2 70B)') parser.add_argument('--eps', default = '5.0e-06', help = 'RMS norm eps: Use 1e-6 for LLaMA1 and OpenLLaMA, use 1e-5 for LLaMA2') parser.add_argument('--context-length', '-c', type=int, default = 2048, help = 'Default max context length: LLaMA1 is typically 2048, LLaMA2 is typically 4096') parser.add_argument('--model-metadata-dir', '-m', type = Path, help ='Load HuggingFace/.pth vocab and metadata from the specified directory') parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file - only meaningful with --model-metadata-dir") parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format - only 
meaningful with --model-metadata-dir and/or --vocab-dir (default: spm)", default="spm") return parser.parse_args() def main(): cfg = handle_args() print(f'* Using config: {cfg}') print('\n=== WARNING === Be aware that this conversion script is best-effort. Use a native GGUF model if possible. === WARNING ===\n') data = np.memmap(cfg.input, mode = 'r') model = GGMLV3Model() print('* Scanning GGML input file') offset = model.load(data, 0) print(f'* GGML model hyperparameters: {model.hyperparameters}') vocab_override = None params_override = None special_vocab = None if cfg.model_metadata_dir is not None: (params_override, vocab_override, special_vocab) = handle_metadata(cfg, model.hyperparameters) print('!! Note: When overriding params the --gqa, --eps and --context-length options are ignored.') print(f'* Overriding params: {params_override}') print(f'* Overriding vocab: {vocab_override}') print(f'* Special vocab: {special_vocab}') else: print('\n=== WARNING === Special tokens may not be converted correctly. Use --model-metadata-dir if possible === WARNING ===\n') converter = GGMLToGGUF(model, data, cfg, params_override = params_override, vocab_override = vocab_override, special_vocab = special_vocab) converter.save() print(f'* Successful completion. Output saved to: {cfg.output}') if __name__ == '__main__': main()
# NOTE(review): this file looks like a whitespace-collapsed copy of a GGMLv3 -> GGUF converter script.
# Statements from many definitions run together on a few very long physical lines, so it is NOT valid
# Python as written. On any line containing '#' (the shebang on the first line, '# Note:',
# '# print(...)', '# Normal', '# FIXME:', ...) everything after that '#' is parsed as a comment.
# The break between the last two lines also falls inside the --vocabtype help string, leaving it unterminated.
# NOTE(review): text also appears to be missing wherever a '<' followed by a letter occurred, e.g.:
#   - "struct.unpack('= 0 and n_dims ..." -- presumably a '<I'-style format string, plus the rest of
#     Vocab.load, the Tensor class header and the start of Tensor.load, were lost;
#   - "struct.unpack(' 0: gguf_writer.add_token_types(...)" -- the rest of GGMLV3Model and the start of
#     GGMLToGGUF (through the beginning of add_vocab) appear to be lost;
#   - "return f''" in Hyperparameters.__str__ and the three b'' literals assigned for token ids 0-2
#     look like emptied literals -- TODO confirm their original text.
# TODO(review): restore this file from a known-good copy instead of hand-repairing it; reconstructing
# the lost spans from this text alone would be guesswork.