diff --git a/convert-refact-hf-to-gguf.py b/convert-refact-hf-to-gguf.py new file mode 100755 index 000000000..e0cd417db --- /dev/null +++ b/convert-refact-hf-to-gguf.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python3 +# HF refact--> gguf conversion + +from __future__ import annotations + +import argparse +import json +import os +import sys +from pathlib import Path + +import numpy as np +import torch +from transformers import AutoTokenizer # type: ignore[import] + +if "NO_LOCAL_GGUF" not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / "gguf-py" / "gguf")) +import gguf + + +def bytes_to_unicode(): + # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + return dict(zip(bs, (chr(n) for n in cs))) + + +def count_model_parts(dir_model: Path) -> int: + num_parts = 0 + for filename in os.listdir(dir_model): + if filename.startswith("pytorch_model-"): + num_parts += 1 + + if num_parts > 0: + print("gguf: found " + str(num_parts) + " model parts") + return num_parts + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert a Refact model to a GGML compatible file" + ) + parser.add_argument( + "--vocab-only", + action="store_true", + help="extract only the vocab", + ) + parser.add_argument( + "--outfile", + type=Path, + help="path to write to; default: based on input", + ) + parser.add_argument( + "model", + type=Path, + help="directory containing model file, or model file itself (*.bin)", + ) + parser.add_argument( + "ftype", + type=int, + choices=[0, 1], + default=1, + nargs="?", + help="output format - use 0 for float32, 1 for float16", + ) + return parser.parse_args() + + +args = parse_args() + +dir_model = args.model +ftype = args.ftype +if not dir_model.is_dir(): + print(f"Error: {args.model} is not a directory", file=sys.stderr) + sys.exit(1) + +# possible tensor data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 + +# map from ftype to string +ftype_str = ["f32", "f16"] + +if args.outfile is not None: + fname_out = args.outfile +else: + # output in the same directory as the model by default + fname_out = dir_model / f"ggml-model-{ftype_str[ftype]}.gguf" + +print("gguf: loading model " + dir_model.name) + +with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + +if hparams["architectures"][0] != "GPTRefactForCausalLM": + print("Model architecture not supported: " + hparams["architectures"][0]) + + sys.exit(1) + +# get number of model parts +num_parts = count_model_parts(dir_model) + +ARCH = gguf.MODEL_ARCH.REFACT +gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH]) + +print("gguf: get model metadata") + +# Get refact feed forward dimension +hidden_dim = hparams["n_embd"] +inner_dim = 4 * hidden_dim +hidden_dim = int(2 * inner_dim / 3) +multiple_of = 256 +ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + +block_count = hparams["n_layer"] + +gguf_writer.add_name("Refact") +# refact uses Alibi. So this is from config.json which might be used by training. +gguf_writer.add_context_length(hparams["n_positions"]) +gguf_writer.add_embedding_length(hparams["n_embd"]) + +gguf_writer.add_feed_forward_length(ff_dim) +gguf_writer.add_block_count(block_count) +gguf_writer.add_head_count(hparams["n_head"]) +gguf_writer.add_head_count_kv(1) +gguf_writer.add_layer_norm_rms_eps(hparams["layer_norm_epsilon"]) +gguf_writer.add_file_type(ftype) + +# TOKENIZATION + +print("gguf: get tokenizer metadata") + +tokens: list[bytearray] = [] +scores: list[float] = [] +toktypes: list[int] = [] + +tokenizer_json_file = dir_model / "tokenizer.json" +if not tokenizer_json_file.is_file(): + print(f"Error: Missing {tokenizer_json_file}", file=sys.stderr) + sys.exit(1) + +# gpt2 tokenizer +gguf_writer.add_tokenizer_model("gpt2") + +with open(tokenizer_json_file, "r", encoding="utf-8") as f: + tokenizer_json = json.load(f) + +print("gguf: get gpt2 tokenizer vocab") + +# The number of tokens in tokenizer.json can differ from the expected vocab size. +# This causes downstream issues with mismatched tensor sizes when running the inference +vocab_size = ( + hparams["vocab_size"] + if "vocab_size" in hparams + else len(tokenizer_json["model"]["vocab"]) +) + +tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) + +reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()} +byte_encoder = bytes_to_unicode() +byte_decoder = {v: k for k, v in byte_encoder.items()} + +for i in range(vocab_size): + if i in reverse_vocab: + text = reverse_vocab[i] + try: + text = bytearray([byte_decoder[c] for c in reverse_vocab[i]]) + except KeyError: + text = bytearray() + for c in reverse_vocab[i]: + if ord(c) < 256: # single byte character + text.append(byte_decoder[ord(c)]) + else: # multibyte special token character + text.extend(c.encode("utf-8")) + else: + print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.") + pad_token = f"[PAD{i}]".encode("utf8") + text = bytearray(pad_token) + + tokens.append(text) + scores.append(0.0) # dymmy + toktypes.append(gguf.TokenType.NORMAL) # dummy + +gguf_writer.add_token_list(tokens) +gguf_writer.add_token_scores(scores) +gguf_writer.add_token_types(toktypes) + +special_vocab = gguf.SpecialVocab(dir_model, load_merges=True) +special_vocab.add_to_gguf(gguf_writer) + +# TENSORS + +tensor_map = gguf.get_tensor_name_map(ARCH, block_count) + +# params for qkv transform +n_head = hparams["n_head"] +n_head_kv = 1 + +head_dim = hparams["n_embd"] // n_head + +# tensor info +print("gguf: get tensor metadata") + +if num_parts == 0: + part_names = iter(("pytorch_model.bin",)) +else: + part_names = ( + f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) + ) +for part_name in part_names: + if args.vocab_only: + break + print("gguf: loading model part '" + part_name + "'") + model_part = torch.load(dir_model / part_name, map_location="cpu") + + for i in range(block_count): + if f"transformer.h.{i}.attn.kv.weight" in model_part: + data = model_part[f"transformer.h.{i}.attn.kv.weight"] + model_part[f"model.layers.{i}.self_attn.k_proj.weight"] = data[ + : n_head_kv * head_dim + ] + model_part[f"model.layers.{i}.self_attn.v_proj.weight"] = data[ + n_head_kv * head_dim : + ] + del model_part[f"transformer.h.{i}.attn.kv.weight"] + if f"transformer.h.{i}.attn.q.weight" in model_part: + model_part[f"model.layers.{i}.self_attn.q_proj.weight"] = model_part[ + f"transformer.h.{i}.attn.q.weight" + ] + del model_part[f"transformer.h.{i}.attn.q.weight"] + if f"transformer.h.{i}.mlp.gate_up_proj.weight" in model_part: + data = model_part[f"transformer.h.{i}.mlp.gate_up_proj.weight"] + model_part[f"model.layers.{i}.mlp.gate_proj.weight"] = data[:ff_dim] + model_part[f"model.layers.{i}.mlp.up_proj.weight"] = data[ff_dim:] + del model_part[f"transformer.h.{i}.mlp.gate_up_proj.weight"] + + for name in model_part.keys(): + data = model_part[name] + + old_dtype = data.dtype + + # convert any unsupported data types to float32 + if data.dtype != torch.float16 and data.dtype != torch.float32: + data = data.to(torch.float32) + + data = data.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight",)) + if new_name is None: + print("Can not map tensor '" + name + "'") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if ( + ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): + data = data.astype(np.float16) + + print( + new_name + + ", n_dims = " + + str(n_dims) + + ", " + + str(old_dtype) + + " --> " + + str(data.dtype) + ) + + gguf_writer.add_tensor(new_name, data) + + +print("gguf: write header") +gguf_writer.write_header_to_file() +print("gguf: write metadata") +gguf_writer.write_kv_data_to_file() +if not args.vocab_only: + print("gguf: write tensors") + gguf_writer.write_tensors_to_file() + +gguf_writer.close() + +print(f"gguf: model successfully exported to '{fname_out}'") +print("") diff --git a/ggml.c b/ggml.c index 4a94b0f33..f56d6ac72 100644 --- a/ggml.c +++ b/ggml.c @@ -13082,7 +13082,6 @@ static void ggml_compute_forward_alibi_f32( return; } - const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); @@ -13103,7 +13102,6 @@ static void ggml_compute_forward_alibi_f32( //const int nb3 = src0->nb[3]; GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(ne1 + n_past == ne0); GGML_ASSERT(n_head == ne2); // add alibi to src0 (KQ_scaled) diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py index c975da0cb..a2c570d7e 100644 --- a/gguf-py/gguf/gguf.py +++ b/gguf-py/gguf/gguf.py @@ -85,6 +85,7 @@ class MODEL_ARCH(IntEnum): GPTNEOX : int = auto() MPT : int = auto() STARCODER : int = auto() + REFACT : int = auto() BERT : int = auto() @@ -118,6 +119,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.GPTNEOX: "gptneox", MODEL_ARCH.MPT: "mpt", MODEL_ARCH.STARCODER: "starcoder", + MODEL_ARCH.REFACT: "refact", MODEL_ARCH.BERT: "bert", } @@ -247,6 +249,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.REFACT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.GPT2: [ # TODO ], @@ -271,7 +287,7 @@ class TensorNameMap: # Token embeddings MODEL_TENSOR.TOKEN_EMBD: ( "gpt_neox.embed_in", # gptneox - "transformer.wte", # gpt2 gpt-j mpt + "transformer.wte", # gpt2 gpt-j mpt refact "transformer.word_embeddings", # falcon "model.embed_tokens", # llama-hf "tok_embeddings", # llama-pth @@ -304,6 +320,7 @@ class TensorNameMap: "norm", # llama-pth "embeddings.LayerNorm", # bert "transformer.norm_f", # mpt + "ln_f", # refact ), # Rope frequencies @@ -316,7 +333,7 @@ class TensorNameMap: # Attention norm MODEL_TENSOR.ATTN_NORM: ( "gpt_neox.layers.{bid}.input_layernorm", # gptneox - "transformer.h.{bid}.ln_1", # gpt2 gpt-j + "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact "transformer.blocks.{bid}.norm_1", # mpt "transformer.h.{bid}.input_layernorm", # falcon7b "transformer.h.{bid}.ln_mlp", # falcon40b @@ -365,7 +382,7 @@ class TensorNameMap: # Attention output MODEL_TENSOR.ATTN_OUT: ( "gpt_neox.layers.{bid}.attention.dense", # gptneox - "transformer.h.{bid}.attn.c_proj", # gpt2 + "transformer.h.{bid}.attn.c_proj", # gpt2 refact "transformer.blocks.{bid}.attn.out_proj", # mpt "transformer.h.{bid}.self_attention.dense", # falcon "model.layers.{bid}.self_attn.o_proj", # llama-hf @@ -383,7 +400,7 @@ class TensorNameMap: # Feed-forward norm MODEL_TENSOR.FFN_NORM: ( "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox - "transformer.h.{bid}.ln_2", # gpt2 + "transformer.h.{bid}.ln_2", # gpt2 refact "transformer.blocks.{bid}.norm_2", # mpt "model.layers.{bid}.post_attention_layernorm", # llama-hf "layers.{bid}.ffn_norm", # llama-pth @@ -396,7 +413,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.c_fc", # gpt2 "transformer.blocks.{bid}.ffn.up_proj", # mpt "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon - "model.layers.{bid}.mlp.up_proj", # llama-hf + "model.layers.{bid}.mlp.up_proj", # llama-hf refact "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert "transformer.h.{bid}.mlp.fc_in", # gpt-j @@ -404,14 +421,14 @@ class TensorNameMap: # Feed-forward gate MODEL_TENSOR.FFN_GATE: ( - "model.layers.{bid}.mlp.gate_proj", # llama-hf + "model.layers.{bid}.mlp.gate_proj", # llama-hf refact "layers.{bid}.feed_forward.w1", # llama-pth ), # Feed-forward down MODEL_TENSOR.FFN_DOWN: ( "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox - "transformer.h.{bid}.mlp.c_proj", # gpt2 + "transformer.h.{bid}.mlp.c_proj", # gpt2 refact "transformer.blocks.{bid}.ffn.down_proj", # mpt "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon "model.layers.{bid}.mlp.down_proj", # llama-hf diff --git a/llama.cpp b/llama.cpp index a40da6839..08d6c162a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -165,6 +165,7 @@ enum llm_arch { LLM_ARCH_GPTNEOX, LLM_ARCH_MPT, LLM_ARCH_STARCODER, + LLM_ARCH_REFACT, LLM_ARCH_UNKNOWN, }; @@ -177,6 +178,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_MPT, "mpt" }, { LLM_ARCH_BAICHUAN, "baichuan" }, { LLM_ARCH_STARCODER, "starcoder" }, + { LLM_ARCH_REFACT, "refact" }, }; enum llm_kv { @@ -397,6 +399,23 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, }, }, + { + LLM_ARCH_REFACT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1927,6 +1946,14 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_REFACT: + { + GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_1B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -2164,6 +2191,7 @@ static void llm_load_tensors( const auto tn = LLM_TN(model.arch); switch (model.arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_REFACT: { model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); @@ -3357,6 +3385,353 @@ static struct ggml_cgraph * llm_build_baichaun( return gf; } +static struct ggml_cgraph * llm_build_refact( + llama_context & lctx, + const llama_batch & batch) { + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + const float norm_rms_eps = hparams.f_norm_rms_eps; + + const int n_gpu_layers = model.n_gpu_layers; + + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + + // printf("n_kv = %d\n", n_kv); + + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ false, + }; + + params.no_alloc = true; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); + } + ggml_set_name(inp_tokens, "inp_tokens"); + + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inpL); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); + } + } + + const int i_gpu_start = n_layer - n_gpu_layers; + (void) i_gpu_start; + + // offload functions set the tensor output backend to GPU + // tensors are GPU-accelerated if any input or the output has been offloaded + offload_func_t offload_func_nr = llama_nop; // nr = non-repeating + offload_func_t offload_func_kq = llama_nop; + offload_func_t offload_func_v = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); + } + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); + + offload_func_t offload_func = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (il >= i_gpu_start) { + offload_func = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + struct ggml_tensor * inpSA = inpL; + + // norm + { + cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps); + offload_func(cur); + ggml_set_name(cur, "rms_norm_0"); + + // cur = cur*attn_norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); + offload_func(cur); + ggml_set_name(cur, "attention_norm_0"); + } + + // self-attention + { + // compute Q and K + struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + offload_func_kq(tmpk); + ggml_set_name(tmpk, "tmpk"); + + struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + offload_func_kq(tmpq); + ggml_set_name(tmpq, "tmpq"); + + struct ggml_tensor * Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens); + offload_func_kq(Kcur); + ggml_set_name(Kcur, "Kcur"); + + struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens); + offload_func_kq(Qcur); + ggml_set_name(Qcur, "Qcur"); + + // store key and value to memory + { + // compute the transposed [n_tokens, n_embd] V matrix + + struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + offload_func_v(tmpv); + ggml_set_name(tmpv, "tmpv"); + + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); + offload_func_v(Vcur); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); + offload_func_kq(k); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); + offload_func_v(v); + ggml_set_name(v, "v"); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + offload_func_kq(Q); + ggml_set_name(Q, "Q"); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_kv, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + offload_func_kq(K); + ggml_set_name(K, "K"); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + offload_func_kq(KQ); + ggml_set_name(KQ, "KQ"); + + // KQ_scaled = KQ / sqrt(n_embd_head) + // KQ_scaled shape [n_kv, n_tokens, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); + offload_func_kq(KQ_scaled); + ggml_set_name(KQ_scaled, "KQ_scaled"); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ 0, n_head, 8); + ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); + + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); + offload_func_kq(KQ_masked); + ggml_set_name(KQ_masked, "KQ_masked"); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + offload_func_v(KQ_soft_max); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + // split cached V into n_head heads + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + offload_func_v(V); + ggml_set_name(V, "V"); + +#if 1 + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + offload_func_v(KQV); + ggml_set_name(KQV, "KQV"); +#else + // make V contiguous in memory to speed up the matmul, however we waste time on the copy + // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation + // is there a better way? + struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head)); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); +#endif + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + offload_func_v(KQV_merged); + ggml_set_name(KQV_merged, "KQV_merged"); + + // cur = KQV_merged.contiguous().view(n_embd, n_tokens) + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); + offload_func_v(cur); + ggml_set_name(cur, "KQV_merged_contiguous"); + + // projection (no bias) + cur = ggml_mul_mat(ctx0, + model.layers[il].wo, + cur); + offload_func(cur); + ggml_set_name(cur, "result_wo"); + } + + struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); + offload_func(inpFF); + ggml_set_name(inpFF, "inpFF"); + + // feed-forward network + { + // norm + { + cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps); + offload_func(cur); + ggml_set_name(cur, "rms_norm_1"); + + // cur = cur*ffn_norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); + offload_func(cur); + ggml_set_name(cur, "ffn_norm"); + } + + struct ggml_tensor * tmp = ggml_mul_mat(ctx0, + model.layers[il].w3, + cur); + offload_func(tmp); + ggml_set_name(tmp, "result_w3"); + + cur = ggml_mul_mat(ctx0, + model.layers[il].w1, + cur); + offload_func(cur); + ggml_set_name(cur, "result_w1"); + + // SILU activation + cur = ggml_silu(ctx0, cur); + offload_func(cur); + ggml_set_name(cur, "silu"); + + cur = ggml_mul(ctx0, cur, tmp); + offload_func(cur); + ggml_set_name(cur, "silu_x_result_w3"); + + cur = ggml_mul_mat(ctx0, + model.layers[il].w2, + cur); + offload_func(cur); + ggml_set_name(cur, "result_w2"); + } + + cur = ggml_add(ctx0, cur, inpFF); + offload_func(cur); + ggml_set_name(cur, "inpFF_+_result_w2"); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + // norm + { + cur = ggml_rms_norm(ctx0, cur, norm_rms_eps); + offload_func_nr(cur); + ggml_set_name(cur, "rms_norm_2"); + + // cur = cur*norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.output_norm); + // offload_func_nr(cur); // TODO CPU + GPU mirrored backend + ggml_set_name(cur, "result_norm"); + } + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + static struct ggml_cgraph * llm_build_falcon( llama_context & lctx, const llama_batch & batch) { @@ -3997,6 +4372,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm_build_starcoder(lctx, batch); } break; + case LLM_ARCH_REFACT: + { + result = llm_build_refact(lctx, batch); + } break; default: GGML_ASSERT(false); } @@ -4130,7 +4509,8 @@ static int llama_decode_internal( // If all tensors can be run on the GPU then using more than 1 thread is detrimental. const bool full_offload_supported = model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_BAICHUAN || - model.arch == LLM_ARCH_FALCON; + model.arch == LLM_ARCH_FALCON || + model.arch == LLM_ARCH_REFACT; const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3; if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) { n_threads = 1;