llm : add MPT support (#3417)

* CUDA: added support for ggml_clamp (see also: https://github.com/ggerganov/ggml/issues/545)

* mpt : added an implementation based (mostly) on falcon integration, modified with deltas from ggml/examples/mpt

* mpt : protect against "clip_qkv": null in mpt-7b

* mpt : quick fix to avoid "Strange model" warning when quantizing MPT models

* mpt : addendum to changeset:84e30e8 - leave parameter clamp_kqv out from metadata rather than use 0.0 to indicate "no clamping" (more compliant with the current GGUF spec?)

* mpt : standardized all tensor names to follow GGUF spec

* mpt : addendum to changeset:1be89c40 - use "req" parameter of GGUF_GET_KEY macro instead of duplicate code

* mpt : fixed comment s/gptneox/mpt/

* mpt : remove tabs, trailing whitespace

* mpt : removed ne01 + n_past == ne00 assertion from alibi (cuda/f32) and rope_shift from build_mpt

* mpt : updated convert-mpt-hf-to-gguf.py to reflect changes made to convert-gptneox-hf-to-gguf.py in pr:3252

* comment out n_past instead of marking it unused

* mpt : removed hardcoded +178 from convert script in favor of utilizing hparams["vocab_size"]

* mpt : remove unused tokenizer_json in convert script

* ggml : remove obsolete n_past assert in ggml_alibi

* llama : print clam_kqv and max_alibi_bias hparams

---------

Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Jan Ploski 2023-10-10 09:50:23 +02:00 committed by GitHub
parent 11ea5c7d96
commit f5f9121de1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 685 additions and 9 deletions

216
convert-mpt-hf-to-gguf.py Executable file
View file

@ -0,0 +1,216 @@
#!/usr/bin/env python3
# HF mpt--> gguf conversion
from __future__ import annotations
import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import Any
import numpy as np
import torch
from transformers import AutoTokenizer # type: ignore[import]
if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf
def count_model_parts(dir_model: Path) -> int:
num_parts = 0
for filename in os.listdir(dir_model):
if filename.startswith("pytorch_model-"):
num_parts += 1
if num_parts > 0:
print("gguf: found " + str(num_parts) + " model parts")
return num_parts
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert an MPT model to a GGML compatible file")
parser.add_argument(
"--vocab-only", action="store_true",
help="extract only the vocab",
)
parser.add_argument(
"--outfile", type=Path,
help="path to write to; default: based on input",
)
parser.add_argument(
"model", type=Path,
help="directory containing model file, or model file itself (*.bin)",
)
parser.add_argument(
"ftype", type=int, choices=[0, 1], default=1, nargs='?',
help="output format - use 0 for float32, 1 for float16",
)
return parser.parse_args()
args = parse_args()
dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
print(f'Error: {args.model} is not a directory', file = sys.stderr)
sys.exit(1)
# possible tensor data types
# ftype == 0 -> float32
# ftype == 1 -> float16
# map from ftype to string
ftype_str = ["f32", "f16"]
if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
print("gguf: loading model "+dir_model.name)
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
if hparams["architectures"][0] != "MPTForCausalLM":
print("Model architecture not supported: " + hparams["architectures"][0])
sys.exit()
# get number of model parts
num_parts = count_model_parts(dir_model)
ARCH=gguf.MODEL_ARCH.MPT
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
print("gguf: get model metadata")
block_count = hparams["n_layers"]
gguf_writer.add_name(dir_model.name)
gguf_writer.add_context_length(hparams["max_seq_len"])
gguf_writer.add_embedding_length(hparams["d_model"])
gguf_writer.add_block_count(block_count)
gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
gguf_writer.add_head_count(hparams["n_heads"])
gguf_writer.add_layer_norm_eps(1e-05)
if hparams["attn_config"]["clip_qkv"] is not None:
gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"])
gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"])
# TOKENIZATION
print("gguf: get tokenizer metadata")
tokens: list[bytearray] = []
scores: list[float] = []
toktypes: list[int] = []
# gpt2 tokenizer
gguf_writer.add_tokenizer_model("gpt2")
print("gguf: get gpt2 tokenizer vocab")
# MPT token embedding tensors have dimension 50432 (hparams["vocab_size"]), but
# there are only 50254 (len(tokenizer.vocab)) tokens in the vocab, presumably to
# accomodate some "reserved" tokens; this is causing problems down the line in
# llama.cpp, so we pad the vocab with dummy tokens:
vocab_size = hparams["vocab_size"]
# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
tokenizer = AutoTokenizer.from_pretrained(dir_model)
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
for i in range(vocab_size):
tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
scores.append(0.0) # dummy
toktypes.append(gguf.TokenType.NORMAL)
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
special_vocab.add_to_gguf(gguf_writer)
# TENSORS
tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
# tensor info
print("gguf: get tensor metadata")
if num_parts == 0:
part_names = iter(("pytorch_model.bin",))
else:
part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
)
for part_name in part_names:
if args.vocab_only:
break
print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
for name in model_part.keys():
data = model_part[name]
old_dtype = data.dtype
# convert any unsupported data types to float32
if data.dtype != torch.float16 and data.dtype != torch.float32:
data = data.to(torch.float32)
data = data.squeeze().numpy()
# map tensor names
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
if new_name is None:
print("Cannot map tensor '" + name + "'")
continue # for the sake of compatibility with some old published models, don't quit
sys.exit()
n_dims = len(data.shape)
data_dtype = data.dtype
# if f32 desired, convert any float16 to float32
if ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)
# if f16 desired, convert any float32 2-dim weight tensors to float16
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)
print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
gguf_writer.add_tensor(new_name, data)
# note: MPT output is tied to (same as) wte in original model;
# for easier implementation in llama.cpp it's duplicated in GGUF, though :/
if new_name == "token_embd.weight":
gguf_writer.add_tensor("output.weight", data)
print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
if not args.vocab_only:
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()
gguf_writer.close()
print(f"gguf: model successfully exported to '{fname_out}'")
print("")

View file

@ -415,6 +415,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_CLAMP_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
#define CUDA_ALIBI_BLOCK_SIZE 32
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
@ -4585,6 +4586,15 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
dst[i] = scale * x[i];
}
static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
@ -5475,6 +5485,11 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
}
static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
}
template<typename T>
static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
@ -6419,12 +6434,12 @@ inline void ggml_cuda_op_alibi(
const int64_t ne02 = src0->ne[2];
const int64_t nrows = ggml_nrows(src0);
const int n_past = ((int32_t *) dst->op_params)[0];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
GGML_ASSERT(ne01 + n_past == ne00);
//GGML_ASSERT(ne01 + n_past == ne00);
GGML_ASSERT(n_head == ne02);
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@ -6500,6 +6515,24 @@ inline void ggml_cuda_op_scale(
(void) src1_dd;
}
inline void ggml_cuda_op_clamp(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const float min = ((float *) dst->op_params)[0];
const float max = ((float *) dst->op_params)[1];
clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
CUDA_CHECK(cudaGetLastError());
(void) src1;
(void) dst;
(void) src1_dd;
}
static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) {
const int64_t nrows0 = ggml_nrows(src0);
@ -7061,6 +7094,10 @@ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1,
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
}
static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp);
}
static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
@ -7470,6 +7507,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
case GGML_OP_SCALE:
func = ggml_cuda_scale;
break;
case GGML_OP_CLAMP:
if (!any_on_device) {
return false;
}
func = ggml_cuda_clamp;
break;
case GGML_OP_CPY:
func = ggml_cuda_cpy;
break;

View file

@ -1299,7 +1299,7 @@ void ggml_metal_graph_compute(
const int nth = MIN(1024, ne00);
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

4
ggml.c
View file

@ -13059,13 +13059,11 @@ static void ggml_compute_forward_alibi_f32(
return;
}
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
assert(n_past >= 0);
const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int64_t ne1 = src0->ne[1]; // seq_len_without_past
const int64_t ne2 = src0->ne[2]; // n_head -> this is k

425
llama.cpp
View file

@ -424,6 +424,14 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
LLM_ARCH_MPT,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
@ -1011,6 +1019,9 @@ struct llama_hparams {
float rope_freq_base_train;
float rope_freq_scale_train;
float f_clamp_kqv;
float f_max_alibi_bias;
bool operator!=(const llama_hparams & other) const {
if (this->vocab_only != other.vocab_only) return true;
if (this->n_vocab != other.n_vocab) return true;
@ -2060,6 +2071,20 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_MPT:
{
hparams.f_clamp_kqv = 0.0f;
GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_CLAMP_KQV));
GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS));
switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_7B; break;
case 48: model.type = e_model::MODEL_30B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0;
}
@ -2204,6 +2229,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps);
LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
@ -2649,6 +2676,73 @@ static void llm_load_tensors(
layer.attn_k_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}, backend);
}
} break;
case LLM_ARCH_MPT:
{
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
// output
{
ggml_backend_type backend_norm;
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = LLAMA_BACKEND_OFFLOAD;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
#endif // _WIN32
backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
} else {
backend_norm = GGML_BACKEND_CPU;
backend_output = GGML_BACKEND_CPU;
}
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
if (backend_norm == GGML_BACKEND_GPU) {
vram_weights += ggml_nbytes(model.output_norm);
}
if (backend_output == GGML_BACKEND_GPU_SPLIT) {
vram_weights += ggml_nbytes(model.output);
}
}
const uint32_t n_ff = hparams.n_ff;
const int i_gpu_start = n_layer - n_gpu_layers;
model.layers.resize(n_layer);
for (uint32_t i = 0; i < n_layer; ++i) {
const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
if (backend == GGML_BACKEND_GPU) {
vram_weights +=
ggml_nbytes(layer.attn_norm) +
ggml_nbytes(layer.wqkv) +
ggml_nbytes(layer.wo) +
ggml_nbytes(layer.ffn_norm) +
ggml_nbytes(layer.w2) +
ggml_nbytes(layer.w3);
}
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
@ -4505,7 +4599,6 @@ static struct ggml_cgraph * llm_build_starcoder(
return gf;
}
static struct ggml_cgraph * llm_build_persimmon(
llama_context & lctx,
const llama_batch & batch) {
@ -4903,6 +4996,326 @@ static struct ggml_cgraph * llm_build_persimmon(
return gf;
}
static struct ggml_cgraph * llm_build_mpt(
llama_context & lctx,
const llama_batch & batch) {
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;
GGML_ASSERT(!!kv_self.ctx);
const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer;
const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA
const int64_t n_embd_head = hparams.n_embd_head();
const int64_t n_embd_gqa = hparams.n_embd_gqa();
const float norm_eps = hparams.f_norm_eps;
const float clamp_kqv = hparams.f_clamp_kqv;
const float max_alibi_bias = hparams.f_max_alibi_bias;
const int n_gpu_layers = model.n_gpu_layers;
const int32_t n_tokens = batch.n_tokens;
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n;
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
//printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n",
// kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift);
auto & buf_compute = lctx.buf_compute;
struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size,
/*.mem_buffer =*/ buf_compute.data,
/*.no_alloc =*/ false,
};
params.no_alloc = true;
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
//int warmup = 0;
if (batch.token) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_allocr_alloc(lctx.alloc, inp_tokens);
if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
//warmup = ((uint32_t*) inp_tokens->data)[0] == 0;
}
ggml_set_name(inp_tokens, "inp_tokens");
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
} else {
#ifdef GGML_USE_MPI
GGML_ASSERT(false && "not implemented");
#endif
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
ggml_allocr_alloc(lctx.alloc, inpL);
if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
}
}
const int i_gpu_start = n_layer - n_gpu_layers;
(void) i_gpu_start;
// offload functions set the tensor output backend to GPU
// tensors are GPU-accelerated if any input or the output has been offloaded
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
offload_func_t offload_func_kq = llama_nop;
offload_func_t offload_func_v = llama_nop;
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > n_layer) {
offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 1) {
offload_func_v = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 2) {
offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
ggml_allocr_alloc(lctx.alloc, KQ_scale);
if (!ggml_allocr_is_measure(lctx.alloc)) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
}
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
offload_func_kq(KQ_mask);
ggml_set_name(KQ_mask, "KQ_mask");
ggml_allocr_alloc(lctx.alloc, KQ_mask);
if (!ggml_allocr_is_measure(lctx.alloc)) {
float * data = (float *) KQ_mask->data;
memset(data, 0, ggml_nbytes(KQ_mask));
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
}
}
}
}
}
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * attn_norm;
offload_func_t offload_func = llama_nop;
#ifdef GGML_USE_CUBLAS
if (il >= i_gpu_start) {
offload_func = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS
// self-attention
// TODO: refactor into common function (shared with LLaMA)
{
attn_norm = ggml_norm(ctx0, inpL, norm_eps);
offload_func(attn_norm);
attn_norm = ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm);
offload_func(attn_norm);
if (1) {
cur = attn_norm;
}
// compute QKV
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
offload_func_kq(cur);
if (clamp_kqv > 0.0f) {
cur = ggml_clamp(ctx0, cur, -clamp_kqv, clamp_kqv);
offload_func_kq(cur);
}
const size_t wsize = ggml_type_size(cur->type);
struct ggml_tensor * Qcur = ggml_view_3d(
ctx0, cur, n_embd_head, n_head, n_tokens,
wsize * n_embd_head,
wsize * n_embd_head * (n_head + 2 * n_head_kv),
0);
offload_func_kq(Qcur);
struct ggml_tensor * Kcur = ggml_view_3d(
ctx0, cur, n_embd_head, n_head_kv, n_tokens,
wsize * n_embd_head,
wsize * n_embd_head * (n_head + 2 * n_head_kv),
wsize * n_embd_head * n_head);
offload_func_kq(Kcur);
struct ggml_tensor * tmpv = ggml_view_3d(
ctx0, cur, n_embd_head, n_head_kv, n_tokens,
wsize * n_embd_head,
wsize * n_embd_head * (n_head + 2 * n_head_kv),
wsize * n_embd_head * (n_head + n_head_kv));
offload_func_kq(Kcur);
ggml_set_name(Qcur, "Qcur");
ggml_set_name(Kcur, "Kcur");
{
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens));
offload_func_v(Vcur);
offload_func_v(Vcur->src[0]->src[0]);
ggml_set_name(Vcur, "Vcur");
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
offload_func_kq(k);
ggml_set_name(k, "k");
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
offload_func_v(v);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
}
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
offload_func_kq(Q);
ggml_set_name(Q, "Q");
struct ggml_tensor * K =
ggml_view_3d(ctx0, kv_self.k,
n_embd_head, n_kv, n_head_kv,
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
offload_func_kq(K);
ggml_set_name(K, "K");
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
offload_func_kq(KQ);
ggml_set_name(KQ, "KQ");
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
offload_func_kq(KQ_scaled);
ggml_set_name(KQ_scaled, "KQ_scaled");
// TODO: replace with ggml_add()
struct ggml_tensor * KQ_scaled_alibi =
ggml_alibi(ctx0, KQ_scaled, 0, n_head, max_alibi_bias);
offload_func_kq(KQ_scaled_alibi);
ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask);
offload_func_kq(KQ_masked);
ggml_set_name(KQ_masked, "KQ_masked");
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
offload_func_v(KQ_soft_max);
ggml_set_name(KQ_soft_max, "KQ_soft_max");
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
n_kv, n_embd_head, n_head_kv,
ggml_element_size(kv_self.v)*n_ctx,
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
offload_func_v(V);
ggml_set_name(V, "V");
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
offload_func_v(KQV);
ggml_set_name(KQV, "KQV");
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
offload_func_v(KQV_merged);
ggml_set_name(KQV_merged, "KQV_merged");
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
offload_func_v(cur);
ggml_set_name(cur, "KQV_merged_contiguous");
cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
offload_func(cur);
ggml_set_name(cur, "result_wo");
}
// Add the input
cur = ggml_add(ctx0, cur, inpL);
offload_func(cur);
struct ggml_tensor * attn_out = cur;
// feed forward
{
// Norm
{
cur = ggml_norm(ctx0, attn_out, norm_eps);
offload_func(cur);
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
offload_func(cur);
}
cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
offload_func(cur);
cur = ggml_gelu(ctx0, cur);
offload_func(cur);
cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
offload_func(cur);
}
cur = ggml_add(ctx0, cur, attn_out);
offload_func(cur);
// input for next layer
inpL = cur;
}
cur = inpL;
// norm
{
cur = ggml_norm(ctx0, cur, norm_eps);
offload_func_nr(cur);
cur = ggml_mul(ctx0, cur, model.output_norm);
ggml_set_name(cur, "result_norm");
}
cur = ggml_mul_mat(ctx0, model.output, cur);
ggml_set_name(cur, "result_output");
ggml_build_forward_expand(gf, cur);
ggml_free(ctx0);
return gf;
}
static struct ggml_cgraph * llama_build_graph(
llama_context & lctx,
const llama_batch & batch) {
@ -4935,6 +5348,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm_build_refact(lctx, batch);
} break;
case LLM_ARCH_MPT:
{
result = llm_build_mpt(lctx, batch);
} break;
default:
GGML_ASSERT(false);
}
@ -5065,7 +5482,8 @@ static int llama_decode_internal(
const bool full_offload_supported = model.arch == LLM_ARCH_LLAMA ||
model.arch == LLM_ARCH_BAICHUAN ||
model.arch == LLM_ARCH_FALCON ||
model.arch == LLM_ARCH_REFACT;
model.arch == LLM_ARCH_REFACT ||
model.arch == LLM_ARCH_MPT;
const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3;
if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) {
n_threads = 1;
@ -7161,7 +7579,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
const std::string name = ggml_get_name(meta);
// TODO: avoid hardcoded tensor names - use the TN_* constants
if (name.find("attn_v.weight") != std::string::npos) {
if (name.find("attn_v.weight") != std::string::npos ||
name.find("attn_qkv.weight") != std::string::npos) {
++n_attention_wv;
}
else if (name.find("ffn_down.weight") != std::string::npos) {