diff --git a/.flake8 b/.flake8 index 113ca5fd3..18fba2c15 100644 --- a/.flake8 +++ b/.flake8 @@ -1,2 +1,3 @@ [flake8] max-line-length = 125 +ignore = W503 diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 0d4ea03b4..cae1551a2 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -209,6 +209,8 @@ class Model: return InternLM2Model if model_architecture == "MiniCPMForCausalLM": return MiniCPMModel + if model_architecture == "BertModel": + return BertModel return Model def _is_model_safetensors(self) -> bool: @@ -264,6 +266,8 @@ class Model: return gguf.MODEL_ARCH.INTERNLM2 if arch == "MiniCPMForCausalLM": return gguf.MODEL_ARCH.MINICPM + if arch == "BertModel": + return gguf.MODEL_ARCH.BERT raise NotImplementedError(f'Architecture "{arch}" not supported!') @@ -1629,6 +1633,96 @@ in chat mode so that the conversation can end normally.") self.post_write_tensors(tensor_map, name, data_torch) +class BertModel(Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.block_count = self.hparams["num_hidden_layers"] + + def set_gguf_parameters(self): + # TODO(cebtenzzre): merge with parent class + self.gguf_writer.add_name(self.dir_model.name) + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) + self.gguf_writer.add_causal_attention(False) + self.gguf_writer.add_file_type(self.ftype) + + def set_vocab(self): + path = self.dir_model + added_tokens_path = self.dir_model if self.dir_model.exists() else None + + # use huggingface vocab to get all tokens + vocab = HfVocab(path, added_tokens_path) + tokens, scores, toktypes = zip(*vocab.all_tokens()) + assert len(tokens) == vocab.vocab_size + + # we need this to validate the size of the token_type embeddings + # though currently we are passing all zeros to the token_type embeddings + n_token_types = len(set(toktypes)) + self.gguf_writer.add_token_type_count(n_token_types) + + # convert to phantom space vocab + def phantom(tok, typ): + if tok.startswith(b"[") and tok.endswith(b"]"): + return tok + if tok.startswith(b"##"): + return tok[2:] + return b"\xe2\x96\x81" + tok + tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)] + + # set up bos and eos tokens (cls and sep) + self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id) + self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id) + + # add vocab to gguf + self.gguf_writer.add_tokenizer_model("bert") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + # handle special tokens + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + def write_tensors(self): + tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + tensors = dict(self.get_tensors()) + for name, data_torch in tensors.items(): + # we are only using BERT for embeddings so we don't need the pooling layer + if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"): + continue # we don't need these + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + data = data_torch.squeeze().numpy() + n_dims = len(data.shape) + new_dtype: type[np.floating[Any]] + + if ( + self.ftype == 1 and name.endswith(".weight") and n_dims == 2 + and name != "embeddings.token_type_embeddings.weight" # not used with get_rows, must be F32 + ): + # if f16 desired, convert any float32 2-dim weight tensors to float16 + new_dtype = np.float16 + else: + # if f32 desired, convert any float16 to float32 + new_dtype = np.float32 + + print(f"{new_name}, n_dims = {n_dims}, {data_torch.dtype} --> {new_dtype}") + + if data.dtype != new_dtype: + data = data.astype(new_dtype) + + self.gguf_writer.add_tensor(new_name, data) + + ###### CONVERSION LOGIC ###### diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 3295cd240..27376c8f0 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -87,7 +87,17 @@ int main(int argc, char ** argv) { } const int n_embd = llama_n_embd(model); - const auto * embeddings = llama_get_embeddings(ctx); + auto * embeddings = llama_get_embeddings(ctx); + + // l2-normalize embeddings + float norm = 0; + for (int i = 0; i < n_embd; i++) { + norm += embeddings[i] * embeddings[i]; + } + norm = sqrt(norm); + for (int i = 0; i < n_embd; i++) { + embeddings[i] /= norm; + } for (int i = 0; i < n_embd; i++) { printf("%f ", embeddings[i]); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1cfd41c0b..a9c13dd38 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -50,6 +50,7 @@ class Keys: VALUE_LENGTH = "{arch}.attention.value_length" LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" + CAUSAL = "{arch}.attention.causal" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" @@ -60,22 +61,23 @@ class Keys: SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" class Tokenizer: - MODEL = "tokenizer.ggml.model" - LIST = "tokenizer.ggml.tokens" - TOKEN_TYPE = "tokenizer.ggml.token_type" - SCORES = "tokenizer.ggml.scores" - MERGES = "tokenizer.ggml.merges" - BOS_ID = "tokenizer.ggml.bos_token_id" - EOS_ID = "tokenizer.ggml.eos_token_id" - UNK_ID = "tokenizer.ggml.unknown_token_id" - SEP_ID = "tokenizer.ggml.seperator_token_id" - PAD_ID = "tokenizer.ggml.padding_token_id" - ADD_BOS = "tokenizer.ggml.add_bos_token" - ADD_EOS = "tokenizer.ggml.add_eos_token" - ADD_PREFIX = "tokenizer.ggml.add_space_prefix" - HF_JSON = "tokenizer.huggingface.json" - RWKV = "tokenizer.rwkv.world" - CHAT_TEMPLATE = "tokenizer.chat_template" + MODEL = "tokenizer.ggml.model" + LIST = "tokenizer.ggml.tokens" + TOKEN_TYPE = "tokenizer.ggml.token_type" + TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types + SCORES = "tokenizer.ggml.scores" + MERGES = "tokenizer.ggml.merges" + BOS_ID = "tokenizer.ggml.bos_token_id" + EOS_ID = "tokenizer.ggml.eos_token_id" + UNK_ID = "tokenizer.ggml.unknown_token_id" + SEP_ID = "tokenizer.ggml.seperator_token_id" + PAD_ID = "tokenizer.ggml.padding_token_id" + ADD_BOS = "tokenizer.ggml.add_bos_token" + ADD_EOS = "tokenizer.ggml.add_eos_token" + ADD_PREFIX = "tokenizer.ggml.add_space_prefix" + HF_JSON = "tokenizer.huggingface.json" + RWKV = "tokenizer.rwkv.world" + CHAT_TEMPLATE = "tokenizer.chat_template" # @@ -122,6 +124,7 @@ class MODEL_TENSOR(IntEnum): ATTN_OUT = auto() ATTN_NORM = auto() ATTN_NORM_2 = auto() + ATTN_OUT_NORM = auto() ATTN_ROT_EMBD = auto() FFN_GATE_INP = auto() FFN_NORM = auto() @@ -134,6 +137,7 @@ class MODEL_TENSOR(IntEnum): FFN_UP_EXP = auto() ATTN_Q_NORM = auto() ATTN_K_NORM = auto() + LAYER_OUT_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -178,6 +182,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", + MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", @@ -187,6 +192,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}", + MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -262,17 +268,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { ], MODEL_ARCH.BERT: [ MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, MODEL_TENSOR.TOKEN_TYPES, MODEL_TENSOR.POS_EMBD, MODEL_TENSOR.OUTPUT_NORM, - MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_OUT_NORM, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_OUT, - MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.LAYER_OUT_NORM, ], MODEL_ARCH.MPT: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 16808196e..7af58a46c 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -357,6 +357,9 @@ class GGUFWriter: def add_layer_norm_rms_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) + def add_causal_attention(self, value: bool) -> None: + self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value) + def add_rope_dimension_count(self, count: int) -> None: self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count) @@ -387,6 +390,9 @@ class GGUFWriter: def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None: self.add_array(Keys.Tokenizer.TOKEN_TYPE, types) + def add_token_type_count(self, value: int) -> None: + self.add_uint32(Keys.Tokenizer.TOKEN_TYPE_COUNT, value) + def add_token_scores(self, scores: Sequence[float]) -> None: self.add_array(Keys.Tokenizer.SCORES, scores) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 4f16d8504..c7ba1420e 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -30,6 +30,7 @@ class TensorNameMap: # Normalization of token embeddings MODEL_TENSOR.TOKEN_EMBD_NORM: ( "word_embeddings_layernorm", # bloom + "embeddings.LayerNorm", # bert ), # Position embeddings @@ -54,7 +55,6 @@ class TensorNameMap: "transformer.ln_f", # gpt2 gpt-j falcon "model.norm", # llama-hf baichuan internlm2 "norm", # llama-pth - "embeddings.LayerNorm", # bert "transformer.norm_f", # mpt "ln_f", # refact bloom qwen gpt2 "language_model.encoder.final_layernorm", # persimmon @@ -79,7 +79,6 @@ class TensorNameMap: "transformer.h.{bid}.ln_mlp", # falcon40b "model.layers.{bid}.input_layernorm", # llama-hf "layers.{bid}.attention_norm", # llama-pth - "encoder.layer.{bid}.attention.output.LayerNorm", # bert "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi "h.{bid}.ln_1", # gpt2 @@ -155,6 +154,11 @@ class TensorNameMap: "model.layers.{bid}.attention.wo", # internlm2 ), + # Attention output norm + MODEL_TENSOR.ATTN_OUT_NORM: ( + "encoder.layer.{bid}.attention.output.LayerNorm", # bert + ), + # Rotary embeddings MODEL_TENSOR.ATTN_ROT_EMBD: ( "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf @@ -171,7 +175,6 @@ class TensorNameMap: "transformer.blocks.{bid}.norm_2", # mpt "model.layers.{bid}.post_attention_layernorm", # llama-hf "layers.{bid}.ffn_norm", # llama-pth - "encoder.layer.{bid}.output.LayerNorm", # bert "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon "model.layers.{bid}.ln2", # yi "h.{bid}.ln_2", # gpt2 @@ -266,6 +269,10 @@ class TensorNameMap: MODEL_TENSOR.ROPE_FREQS: ( "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon ), + + MODEL_TENSOR.LAYER_OUT_NORM: ( + "encoder.layer.{bid}.output.LayerNorm", # bert + ) } mapping: dict[str, tuple[MODEL_TENSOR, str]] diff --git a/llama.cpp b/llama.cpp index 3f39a67fb..d1ee26ce2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -196,6 +196,7 @@ enum llm_arch { LLM_ARCH_STARCODER, LLM_ARCH_PERSIMMON, LLM_ARCH_REFACT, + LLM_ARCH_BERT, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, @@ -220,6 +221,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_STARCODER, "starcoder" }, { LLM_ARCH_PERSIMMON, "persimmon" }, { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BERT, "bert" }, { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, @@ -261,6 +263,7 @@ enum llm_kv { LLM_KV_ATTENTION_VALUE_LENGTH, LLM_KV_ATTENTION_LAYERNORM_EPS, LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, + LLM_KV_ATTENTION_CAUSAL, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_FREQ_BASE, @@ -273,6 +276,7 @@ enum llm_kv { LLM_KV_TOKENIZER_MODEL, LLM_KV_TOKENIZER_LIST, LLM_KV_TOKENIZER_TOKEN_TYPE, + LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, LLM_KV_TOKENIZER_SCORES, LLM_KV_TOKENIZER_MERGES, LLM_KV_TOKENIZER_BOS_ID, @@ -316,6 +320,7 @@ static std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, @@ -328,6 +333,7 @@ static std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, @@ -355,6 +361,7 @@ struct LLM_KV { enum llm_tensor { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_TOKEN_TYPES, LLM_TENSOR_POS_EMBD, LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, @@ -536,6 +543,23 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_BERT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_BLOOM, { @@ -1440,6 +1464,11 @@ static llama_state g_state; // available llama models enum e_model { MODEL_UNKNOWN, + MODEL_17M, + MODEL_22M, + MODEL_33M, + MODEL_109M, + MODEL_335M, MODEL_0_5B, MODEL_1B, MODEL_2B, @@ -1481,6 +1510,7 @@ struct llama_hparams { uint32_t n_ff; uint32_t n_expert = 0; uint32_t n_expert_used = 0; + uint32_t n_vocab_type = 0; // for BERT-style token types float f_norm_eps; float f_norm_rms_eps; @@ -1493,6 +1523,8 @@ struct llama_hparams { float f_clamp_kqv; float f_max_alibi_bias; + bool causal_attn = true; + bool operator!=(const llama_hparams & other) const { if (this->vocab_only != other.vocab_only) return true; @@ -1720,6 +1752,7 @@ struct llama_model { llama_vocab vocab; struct ggml_tensor * tok_embd; + struct ggml_tensor * type_embd; struct ggml_tensor * pos_embd; struct ggml_tensor * tok_norm; struct ggml_tensor * tok_norm_b; @@ -1850,6 +1883,7 @@ struct llama_context { struct ggml_tensor * inp_pos; // I32 [n_batch] struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch] struct ggml_tensor * inp_K_shift; // I32 [n_ctx] + struct ggml_tensor * inp_sum; // F32 [1, n_batch] #ifdef GGML_USE_MPI ggml_mpi_context * ctx_mpi = NULL; @@ -2829,6 +2863,7 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){ switch (type) { case LLAMA_VOCAB_TYPE_SPM: return "SPM"; case LLAMA_VOCAB_TYPE_BPE: return "BPE"; + case LLAMA_VOCAB_TYPE_WPM: return "WPM"; default: return "unknown"; } } @@ -3000,6 +3035,26 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_BERT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); + + switch (hparams.n_layer) { + case 3: + model.type = e_model::MODEL_17M; break; // bge-micro + case 6: + model.type = e_model::MODEL_22M; break; // MiniLM-L6 + case 12: + switch (hparams.n_embd) { + case 384: model.type = e_model::MODEL_33M; break; // MiniLM-L12, bge-small + case 768: model.type = e_model::MODEL_109M; break; // bge-base + } break; + case 24: + model.type = e_model::MODEL_335M; break; // bge-large + } + } break; case LLM_ARCH_BLOOM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -3204,6 +3259,16 @@ static void llm_load_vocab( vocab.special_unk_id = -1; vocab.special_sep_id = -1; vocab.special_pad_id = -1; + } else if (tokenizer_name == "bert") { + vocab.type = LLAMA_VOCAB_TYPE_WPM; + + // default special tokens + vocab.special_bos_id = 101; + vocab.special_eos_id = 102; + vocab.special_unk_id = 100; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + vocab.add_space_prefix = false; } else { LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); @@ -3232,6 +3297,8 @@ static void llm_load_vocab( // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { vocab.linefeed_id = llama_byte_to_token(vocab, '\n'); + } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { + vocab.linefeed_id = vocab.special_pad_id; } else { const std::vector ids = llama_tokenize_internal(vocab, "\u010A", false); GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); @@ -3569,6 +3636,7 @@ static bool llm_load_tensors( const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = n_embd_v_gqa; const int64_t n_vocab = hparams.n_vocab; + const int64_t n_vocab_type = hparams.n_vocab_type; const int64_t n_ff = hparams.n_ff; GGML_ASSERT(n_embd_gqa == n_embd_k_gqa); @@ -3783,11 +3851,50 @@ static bool llm_load_tensors( layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}); } } break; - case LLM_ARCH_BLOOM: + case LLM_ARCH_BERT: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); - model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); + model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); + model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}); + model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); + model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + } + } break; + case LLM_ARCH_BLOOM: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); + model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); // output { @@ -4739,6 +4846,7 @@ struct llm_build_context { const int32_t n_orig_ctx; const bool do_rope_shift; + const bool causal_attn; const llm_build_cb & cb; @@ -4782,6 +4890,7 @@ struct llm_build_context { kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), do_rope_shift (worst_case || kv_self.has_shift), + causal_attn (hparams.causal_attn), cb (cb), buf_compute_meta (lctx.buf_compute_meta) { // all initializations should be done in init() @@ -5625,6 +5734,100 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_bert() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + // get input vectors with right size + struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); + struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0); + + // construct input embeddings (token, type, position) + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); + // token types are hardcoded to zero ("Sentence A") + struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); + inpL = ggml_add(ctx0, inpL, type_row0); + inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); + cb(inpL, "inp_embd", -1); + + // embed layer norm + inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1); + cb(inpL, "inp_norm", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + cb(KQ_mask, "KQ_mask", -1); // [n_kv, n_tokens] + + // iterate layers + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * cur = inpL; + + // self-attention + { + struct ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv); + cb(Vcur, "Vcur", il); + + // seems like we just need to do this for Q? + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cb(cur, "kqv_out", il); + } + + // re-add the layer input + cur = ggml_add(ctx0, cur, inpL); + + // attention layer norm + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, cb, il); + + struct ggml_tensor * ffn_inp = cur; + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + + // attentions bypass the intermediate layer + cur = ggml_add(ctx0, cur, ffn_inp); + + // output layer norm + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il); + + // input for next layer + inpL = cur; + } + + // final output + cur = inpL; + + // pooling + cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur))); + cb(cur, "result_embed", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_bloom() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -7060,7 +7263,8 @@ static struct ggml_cgraph * llama_build_graph( for (int i = 0; i < n_kv; ++i) { float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { + if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || + (llm.causal_attn && lctx.kv_self.cells[i].pos > pos)) { f = -INFINITY; } else { f = 0; @@ -7081,6 +7285,15 @@ static struct ggml_cgraph * llama_build_graph( data[i] = lctx.kv_self.cells[i].delta; } } + + { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_sum->buffer)); + float * data = (float *) lctx.inp_sum->data; + + for (int i = 0; i < batch.n_tokens; ++i) { + data[i] = 1.0f/float(batch.n_tokens); + } + } } llm.init(); @@ -7110,6 +7323,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_refact(); } break; + case LLM_ARCH_BERT: + { + result = llm.build_bert(); + } break; case LLM_ARCH_BLOOM: { result = llm.build_bloom(); @@ -7269,13 +7486,18 @@ static int llama_decode_internal( // the output is always the last tensor in the graph struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - GGML_ASSERT(strcmp(res->name, "result_output") == 0); - - // the embeddings could be the second to last tensor, or the third to last tensor struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; - if (strcmp(embeddings->name, "result_norm") != 0) { - embeddings = gf->nodes[gf->n_nodes - 3]; - GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + if (strcmp(res->name, "result_output") == 0) { + // the embeddings could be the second to last tensor, or the third to last tensor + if (strcmp(embeddings->name, "result_norm") != 0) { + embeddings = gf->nodes[gf->n_nodes - 3]; + GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + } + } else if (strcmp(res->name, "result_embed") == 0) { + embeddings = res; + res = nullptr; + } else { + GGML_ASSERT(false); } // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); @@ -7344,7 +7566,7 @@ static int llama_decode_internal( // extract logits // TODO: do not compute and extract logits if only embeddings are needed // need to update the graphs to skip "result_output" - { + if (res) { auto & logits_out = lctx.logits; #ifndef NDEBUG @@ -7388,9 +7610,11 @@ static int llama_decode_internal( if (!lctx.embedding.empty()) { auto & embedding_out = lctx.embedding; + const int64_t embed_pos = res ? n_embd * (n_tokens-1) : 0; + embedding_out.resize(n_embd); ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings); - ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), (n_embd*(n_tokens - 1))*sizeof(float), n_embd*sizeof(float)); + ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), n_embd*sizeof(float)); ggml_backend_synchronize(embeddings_backend); } @@ -7454,6 +7678,9 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { GGML_ASSERT(false); return unicode_to_bytes_bpe(token_data.text); } + case LLAMA_VOCAB_TYPE_WPM: { + GGML_ASSERT(false); + } default: GGML_ASSERT(false); } @@ -7466,6 +7693,7 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 }; return vocab.token_to_id.at(buf); } + case LLAMA_VOCAB_TYPE_WPM: case LLAMA_VOCAB_TYPE_BPE: { return vocab.token_to_id.at(bytes_to_unicode_bpe(ch)); } @@ -7936,12 +8164,212 @@ private: llm_bigram_bpe::queue work_queue; }; -typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{ +struct llm_tokenizer_wpm { + llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {} + + void tokenize(const std::string & text, std::vector & output) { + auto * token_map = &vocab.token_to_id; + + // normalize and split by whitespace + std::vector words = preprocess(text); + + // bos token prepended already + + // find the longest tokens that form the words + for (const std::string &word : words) { + // skip empty words + if (word.size() == 0) { + continue; + } + + // prepend phantom space + std::string word1 = "\xe2\x96\x81" + word; + int n = word1.size(); + + // we're at the start of a new word + int i = 0; + bool match_any = false; + + // move through character position in word + while (i < n) { + // loop through possible match length + bool match = false; + for (int j = n; j > i; j--) { + auto it = token_map->find(word1.substr(i, j - i)); + if (it != token_map->end()) { + output.push_back(it->second); + match = true; + match_any = true; + i = j; + break; + } + } + + // must be an unknown character + if (!match) { + i++; + } + } + + // we didn't find any matches for this word + if (!match_any) { + output.push_back(vocab.special_unk_id); + } + } + + // append eos token + output.push_back(vocab.special_eos_id); + } + + std::vector preprocess(const std::string & text) { + std::string ori_str = normalize(text); + uint64_t ori_size = ori_str.size(); + + // single punct / single symbol / single digit + // baseline: add whitespace on the left and right of punct and chinese characters + std::vector words; + std::string new_str = ""; + uint64_t i = 0; + while (i < ori_size) { + int utf_char_len = utf8_len(ori_str[i]); + if ((utf_char_len == 1) && ispunct(ori_str[i])) { + new_str += " "; + new_str += ori_str[i]; + new_str += " "; + i += 1; + } + else if ((utf_char_len == 3) && is_chinese_char(ori_str.substr(i, 3))) { + new_str += " "; + new_str += ori_str.substr(i, 3); + new_str += " "; + i += 3; + } + else { + new_str += ori_str[i]; + i += 1; + } + } + + // split by whitespace + uint64_t l = 0; + uint64_t r = 0; + while (r < new_str.size()) { + // if is whitespace + if (isspace(new_str[r])) { + if (r > l) words.push_back(new_str.substr(l, (r - l))); + l = r + 1; + r = l; + } + else { + r += 1; + } + } + if (r > l) { + words.push_back(new_str.substr(l, (r - l))); + } + return words; + } + + std::string normalize(const std::string & text) { + // TODO: handle chinese characters? https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98 + std::string text2 = strip_accents(text); + for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) { + char c = text2[i]; + if (c >= 'A' && c <= 'Z') { + text2[i] = c - 'A' + 'a'; + } + } + return text2; + } + + bool is_chinese_char(const std::string & str) { + int len = str.length(); + unsigned int codepoint = 0; + int num_bytes = 0; + int i = 0; + unsigned char ch = static_cast(str[i]); + if (ch <= 0x7f) { + codepoint = ch; + num_bytes = 1; + } else if ((ch >> 5) == 0x06) { + codepoint = ch & 0x1f; + num_bytes = 2; + } else if ((ch >> 4) == 0x0e) { + codepoint = ch & 0x0f; + num_bytes = 3; + } else if ((ch >> 3) == 0x1e) { + codepoint = ch & 0x07; + num_bytes = 4; + } + for (int j = 1; j < num_bytes; ++j) { + if (i + j >= len) { + return false; // incomplete UTF-8 character + } + unsigned char next_ch = static_cast(str[i + j]); + if ((next_ch >> 6) != 0x02) { + return false; // invalid trailing byte + } + codepoint = (codepoint << 6) | (next_ch & 0x3f); + } + if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) || + (codepoint >= 0x3400 && codepoint <= 0x4DBF) || + (codepoint >= 0x20000 && codepoint <= 0x2A6DF) || + (codepoint >= 0x2A700 && codepoint <= 0x2B73F) || + (codepoint >= 0x2B740 && codepoint <= 0x2B81F) || + (codepoint >= 0x2B920 && codepoint <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920 + (codepoint >= 0xF900 && codepoint <= 0xFAFF) || + (codepoint >= 0x2F800 && codepoint <= 0x2FA1F) || + (codepoint >= 0x3000 && codepoint <= 0x303F) || + (codepoint >= 0xFF00 && codepoint <= 0xFFEF)) { + return true; // NOLINT + } + return false; + } + + std::string strip_accents(const std::string & input_string) { + std::string resultString; + std::map accent_map = { + {"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'}, + {"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'}, + {"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'}, + {"ê", 'e'}, {"ë", 'e'}, {"Ì", 'I'}, {"Í", 'I'}, {"Î", 'I'}, {"Ï", 'I'}, + {"ì", 'i'}, {"í", 'i'}, {"î", 'i'}, {"ï", 'i'}, {"Ò", 'O'}, {"Ó", 'O'}, + {"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'}, + {"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'}, + {"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'}, + {"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'}, {"ñ", 'n'}, + }; + + for (size_t i = 0; i < input_string.length();) { + int len = utf8_len(input_string[i]); + std::string curChar = input_string.substr(i, len); + auto iter = accent_map.find(curChar); + if (iter != accent_map.end()) { + resultString += iter->second; + } else { + resultString += curChar; + } + i += len; + } + + return resultString; + } + + static size_t utf8_len(char src) { + const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4}; + uint8_t highbits = static_cast(src) >> 4; + return lookup[highbits]; + } + + const llama_vocab & vocab; +}; + +typedef enum FRAGMENT_BUFFER_VARIANT_TYPE { FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT } FRAGMENT_BUFFER_VARIANT_TYPE; -struct fragment_buffer_variant{ +struct fragment_buffer_variant { fragment_buffer_variant(llama_vocab::id _token) : type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN), @@ -7971,8 +8399,7 @@ struct fragment_buffer_variant{ // #define PRETOKENIZERDEBUG -static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) -{ +static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) { // for each special token for (const auto & st: vocab.special_tokens_cache) { const auto & special_token = st.first; @@ -8090,10 +8517,8 @@ static std::vector llama_tokenize_internal(const llama_vocab & switch (vocab.type) { case LLAMA_VOCAB_TYPE_SPM: { - for (const auto & fragment: fragment_buffer) - { - if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) - { + for (const auto & fragment: fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { // without adding this leading whitespace, we do not get the same results as the original tokenizer // TODO: It's likely possible to get rid of this string copy entirely @@ -8113,19 +8538,15 @@ static std::vector llama_tokenize_internal(const llama_vocab & llm_tokenizer_spm tokenizer(vocab); llama_escape_whitespace(raw_text); tokenizer.tokenize(raw_text, output); - } - else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - { + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); } } } break; case LLAMA_VOCAB_TYPE_BPE: { - for (const auto & fragment: fragment_buffer) - { - if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) - { + for (const auto & fragment: fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); #ifdef PRETOKENIZERDEBUG @@ -8133,9 +8554,23 @@ static std::vector llama_tokenize_internal(const llama_vocab & #endif llm_tokenizer_bpe tokenizer(vocab); tokenizer.tokenize(raw_text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); } - else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - { + } + } break; + case LLAMA_VOCAB_TYPE_WPM: + { + for (const auto & fragment: fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + llm_tokenizer_wpm tokenizer(vocab); + tokenizer.tokenize(raw_text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); } } @@ -10799,7 +11234,7 @@ struct llama_context * llama_new_context_with_model( // graph inputs { ggml_init_params init_params = { - /* .mem_size */ ggml_tensor_overhead()*5, + /* .mem_size */ ggml_tensor_overhead()*7, /* .mem_buffer */ nullptr, /* .no_alloc */ true, }; @@ -10810,12 +11245,14 @@ struct llama_context * llama_new_context_with_model( ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch); ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx); + ctx->inp_sum = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, 1, cparams.n_batch); ggml_set_name(ctx->inp_tokens, "inp_tokens"); ggml_set_name(ctx->inp_embd, "inp_embd"); ggml_set_name(ctx->inp_pos, "inp_pos"); ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask"); ggml_set_name(ctx->inp_K_shift, "inp_K_shift"); + ggml_set_name(ctx->inp_sum, "inp_sum"); ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true)); @@ -11746,6 +12183,7 @@ static std::string llama_decode_text(const std::string & text) { int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length) { if (0 <= token && token < llama_n_vocab(model)) { switch (llama_vocab_get_type(model->vocab)) { + case LLAMA_VOCAB_TYPE_WPM: case LLAMA_VOCAB_TYPE_SPM: { // NOTE: we accept all unsupported token types, // suppressing them like CONTROL tokens. diff --git a/llama.h b/llama.h index cec4158bc..367e8f1a1 100644 --- a/llama.h +++ b/llama.h @@ -61,6 +61,7 @@ extern "C" { enum llama_vocab_type { LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding + LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece }; enum llama_token_type {