llama : allow overriding GGUF metadata when loading model (#4092)

* feat: Allow overriding GGUF metadata when loading model

* Fix the one time GCC is stricter than clang about something

* Step1

* Refactor... basically everything!

* Nuke obsolete GetArrayLen struct

* simplify std::string specialization

* Various cleanups

Add informational output when overrides are applied

Warn user when an override with the wrong type is specified

* Fix broken logic for parsing bool KV overrides
Fix issue where overrides didn't apply when key missing in GGUF metadata
Resolve merge changes

* llama : rearrange model params

* Update new GET_KEY call

Add note that metadata KV overrides aren't reflected in initial metadata KV info dump

---------

Co-authored-by: cebtenzzre <cebtenzzre@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Kerfuffle 2023-12-05 10:19:18 -07:00 committed by GitHub
parent 52c8bc3cf3
commit 5aa365d88f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 361 additions and 86 deletions

View file

@ -690,6 +690,47 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
std::istreambuf_iterator<char>(),
std::back_inserter(sparams.grammar)
);
} else if (arg == "--override-kv") {
if (++i >= argc) {
invalid_param = true;
break;
}
char * sep = strchr(argv[i], '=');
if (sep == nullptr || sep - argv[i] >= 128) {
fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]);
invalid_param = true;
break;
}
struct llama_model_kv_override kvo;
std::strncpy(kvo.key, argv[i], sep - argv[i]);
kvo.key[sep - argv[i]] = 0;
sep++;
if (strncmp(sep, "int:", 4) == 0) {
sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_INT;
kvo.int_value = std::atol(sep);
} else if (strncmp(sep, "float:", 6) == 0) {
sep += 6;
kvo.tag = LLAMA_KV_OVERRIDE_FLOAT;
kvo.float_value = std::atof(sep);
} else if (strncmp(sep, "bool:", 5) == 0) {
sep += 5;
kvo.tag = LLAMA_KV_OVERRIDE_BOOL;
if (std::strcmp(sep, "true") == 0) {
kvo.bool_value = true;
} else if (std::strcmp(sep, "false") == 0) {
kvo.bool_value = false;
} else {
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
invalid_param = true;
break;
}
} else {
fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
invalid_param = true;
break;
}
params.kv_overrides.push_back(kvo);
#ifndef LOG_DISABLE_LOGS
// Parse args for logging parameters
} else if ( log_param_single_parse( argv[i] ) ) {
@ -733,6 +774,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
}
}
if (!params.kv_overrides.empty()) {
params.kv_overrides.emplace_back(llama_model_kv_override());
params.kv_overrides.back().key[0] = 0;
}
return true;
}
@ -864,6 +910,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" draft model for speculative decoding (default: %s)\n", params.model.c_str());
printf(" -ld LOGDIR, --logdir LOGDIR\n");
printf(" path under which to save YAML logs (no logging if unset)\n");
printf(" --override-kv KEY=TYPE:VALUE\n");
printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf("\n");
#ifndef LOG_DISABLE_LOGS
log_print_usage();
@ -956,6 +1005,12 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL;
} else {
GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
mparams.kv_overrides = params.kv_overrides.data();
}
return mparams;
}

View file

@ -86,6 +86,8 @@ struct gpt_params {
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files
std::vector<llama_model_kv_override> kv_overrides;
// TODO: avoid tuple, use struct
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
std::string lora_base = ""; // base model path for the lora adapter

370
llama.cpp
View file

@ -74,6 +74,7 @@
#include <set>
#include <sstream>
#include <thread>
#include <type_traits>
#include <unordered_map>
#if defined(_MSC_VER)
@ -590,21 +591,6 @@ struct LLM_TN {
// gguf helpers
//
#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
do { \
const std::string skey(key); \
const int kid = gguf_find_key(ctx, skey.c_str()); \
if (kid >= 0) { \
enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
if (ktype != (type)) { \
throw std::runtime_error(format("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype))); \
} \
(dst) = func(ctx, kid); \
} else if (req) { \
throw std::runtime_error(format("key not found in model: %s", skey.c_str())); \
} \
} while (0)
static std::map<int8_t, std::string> LLAMA_ROPE_SCALING_TYPES = {
{ LLAMA_ROPE_SCALING_NONE, "none" },
{ LLAMA_ROPE_SCALING_LINEAR, "linear" },
@ -638,7 +624,7 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int
}
}
static std::string gguf_kv_to_str(struct gguf_context * ctx_gguf, int i) {
static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
switch (type) {
@ -1797,6 +1783,169 @@ static std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
return buf;
}
namespace GGUFMeta {
template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int)>
struct GKV_Base_Type {
static constexpr gguf_type gt = gt_;
static T getter(const gguf_context * ctx, const int kid) {
return gfun(ctx, kid);
}
};
template<typename T> struct GKV_Base;
template<> struct GKV_Base<bool >: GKV_Base_Type<bool, GGUF_TYPE_BOOL, gguf_get_val_bool> {};
template<> struct GKV_Base<uint8_t >: GKV_Base_Type<uint8_t, GGUF_TYPE_UINT8, gguf_get_val_u8 > {};
template<> struct GKV_Base<uint16_t >: GKV_Base_Type<uint16_t, GGUF_TYPE_UINT16, gguf_get_val_u16 > {};
template<> struct GKV_Base<uint32_t >: GKV_Base_Type<uint32_t, GGUF_TYPE_UINT32, gguf_get_val_u32 > {};
template<> struct GKV_Base<uint64_t >: GKV_Base_Type<uint64_t, GGUF_TYPE_UINT64, gguf_get_val_u64 > {};
template<> struct GKV_Base<int8_t >: GKV_Base_Type<int8_t, GGUF_TYPE_INT8, gguf_get_val_i8 > {};
template<> struct GKV_Base<int16_t >: GKV_Base_Type<int16_t, GGUF_TYPE_INT16, gguf_get_val_i16 > {};
template<> struct GKV_Base<int32_t >: GKV_Base_Type<int32_t, GGUF_TYPE_INT32, gguf_get_val_i32 > {};
template<> struct GKV_Base<int64_t >: GKV_Base_Type<int64_t, GGUF_TYPE_INT64, gguf_get_val_i64 > {};
template<> struct GKV_Base<float >: GKV_Base_Type<float, GGUF_TYPE_FLOAT32, gguf_get_val_f32 > {};
template<> struct GKV_Base<double >: GKV_Base_Type<double, GGUF_TYPE_FLOAT64, gguf_get_val_f64 > {};
template<> struct GKV_Base<const char *>: GKV_Base_Type<const char *, GGUF_TYPE_STRING, gguf_get_val_str > {};
template<> struct GKV_Base<std::string> {
static constexpr gguf_type gt = GGUF_TYPE_STRING;
static std::string getter(const gguf_context * ctx, const int kid) {
return gguf_get_val_str(ctx, kid);
}
};
struct ArrayInfo{
const gguf_type gt;
const size_t length;
const void * data;
};
template<> struct GKV_Base<ArrayInfo> {
public:
static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
static ArrayInfo getter(const gguf_context *ctx, const int k) {
return ArrayInfo {
gguf_get_arr_type(ctx, k),
size_t(gguf_get_arr_n(ctx, k)),
gguf_get_arr_data(ctx, k),
};
}
};
template<typename T>
class GKV: public GKV_Base<T> {
GKV() = delete;
public:
static T get_kv(const gguf_context * ctx, const int k) {
const enum gguf_type kt = gguf_get_kv_type(ctx, k);
if (kt != GKV::gt) {
throw std::runtime_error(format("key %s has wrong type %s but expected type %s",
gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt)));
}
return GKV::getter(ctx, k);
}
static const char * override_type_to_str(const llama_model_kv_override_type ty) {
switch (ty) {
case LLAMA_KV_OVERRIDE_BOOL: return "bool";
case LLAMA_KV_OVERRIDE_INT: return "int";
case LLAMA_KV_OVERRIDE_FLOAT: return "float";
}
return "unknown";
}
static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override *override) {
if (!override) { return false; }
if (override->tag == expected_type) {
LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ",
__func__, override_type_to_str(override->tag), override->key);
switch (override->tag) {
case LLAMA_KV_OVERRIDE_BOOL: {
printf("%s\n", override->bool_value ? "true" : "false");
} break;
case LLAMA_KV_OVERRIDE_INT: {
printf("%" PRId64 "\n", override->int_value);
} break;
case LLAMA_KV_OVERRIDE_FLOAT: {
printf("%.6f\n", override->float_value);
} break;
default:
// Shouldn't be possible to end up here, but just in case...
throw std::runtime_error(
format("Unsupported attempt to override %s type for metadata key %s\n",
override_type_to_str(override->tag), override->key));
}
return true;
}
LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n",
__func__, override->key, override_type_to_str(expected_type), override_type_to_str(override->tag));
return false;
}
template<typename OT>
static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
try_override(OT & target, const struct llama_model_kv_override *override) {
if (validate_override(LLAMA_KV_OVERRIDE_BOOL, override)) {
target = override->bool_value;
return true;
}
return true;
}
template<typename OT>
static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
try_override(OT & target, const struct llama_model_kv_override *override) {
if (validate_override(LLAMA_KV_OVERRIDE_INT, override)) {
target = override->int_value;
return true;
}
return false;
}
template<typename OT>
static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
try_override(T & target, const struct llama_model_kv_override *override) {
if (validate_override(LLAMA_KV_OVERRIDE_FLOAT, override)) {
target = override->float_value;
return true;
}
return false;
}
template<typename OT>
static typename std::enable_if<std::is_same<OT, std::string>::value, bool>::type
try_override(T & target, const struct llama_model_kv_override *override) {
(void)target;
(void)override;
if (!override) { return false; }
// Currently, we should never end up here so it would be a bug if we do.
throw std::runtime_error(format("Unsupported attempt to override string type for metadata key %s\n",
override ? override->key : "NULL"));
}
static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override *override = nullptr) {
if (try_override<T>(target, override)) {
return true;
}
if (k < 0) { return false; }
target = get_kv(ctx, k);
return true;
}
static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override *override = nullptr) {
return set(ctx, gguf_find_key(ctx, key), target, override);
}
static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override *override = nullptr) {
return set(ctx, key.c_str(), target, override);
}
};
}
struct llama_model_loader {
int n_kv = 0;
int n_tensors = 0;
@ -1812,21 +1961,34 @@ struct llama_model_loader {
llama_fver fver;
std::unique_ptr<llama_mmap> mapping;
std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
struct gguf_context * ctx_gguf = NULL;
struct ggml_context * ctx_meta = NULL;
llama_model_loader(const std::string & fname, bool use_mmap) : file(fname.c_str(), "rb") {
std::string arch_name;
LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
llama_model_loader(const std::string & fname, bool use_mmap, const struct llama_model_kv_override * param_overrides_p) : file(fname.c_str(), "rb") {
struct gguf_init_params params = {
/*.no_alloc = */ true,
/*.ctx = */ &ctx_meta,
};
if (param_overrides_p != nullptr) {
for (const struct llama_model_kv_override *p = param_overrides_p; p->key[0] != 0; p++) {
kv_overrides.insert({std::string(p->key), *p});
}
}
ctx_gguf = gguf_init_from_file(fname.c_str(), params);
if (!ctx_gguf) {
throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str()));
}
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
llm_kv = LLM_KV(llm_arch_from_string(arch_name));
n_kv = gguf_get_n_kv(ctx_gguf);
n_tensors = gguf_get_n_tensors(ctx_gguf);
@ -1894,6 +2056,7 @@ struct llama_model_loader {
}
}
LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__);
for (int i = 0; i < n_kv; i++) {
const char * name = gguf_get_key(ctx_gguf, i);
const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
@ -1939,19 +2102,59 @@ struct llama_model_loader {
}
}
template<typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type
get_arr_n(const std::string & key, T & result, const bool required = true) {
const int kid = gguf_find_key(ctx_gguf, key.c_str());
if (kid < 0) {
if (required) {
throw std::runtime_error(format("key not found in model: %s", key.c_str()));
}
return false;
}
struct GGUFMeta::ArrayInfo arr_info =
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx_gguf, kid);
result = arr_info.length;
return true;
}
template<typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type
get_arr_n(const enum llm_kv kid, T & result, const bool required = true) {
return get_arr_n(llm_kv(kid), result, required);
}
template<typename T>
bool get_key(const std::string & key, T & result, const bool required = true) {
auto it = kv_overrides.find(key);
const struct llama_model_kv_override * override =
it != kv_overrides.end() ? &it->second : nullptr;
const bool found = GGUFMeta::GKV<T>::set(ctx_gguf, key, result, override);
if (required && !found) {
throw std::runtime_error(format("key not found in model: %s", key.c_str()));
}
return found;
}
template<typename T>
bool get_key(const enum llm_kv kid, T & result, const bool required = true) {
return get_key(llm_kv(kid), result, required);
}
std::string get_arch_name() const {
const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
std::string arch_name;
GGUF_GET_KEY(ctx_gguf, arch_name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE));
return arch_name;
}
enum llm_arch get_arch() const {
const std::string arch_name = get_arch_name();
return llm_arch_from_string(arch_name);
return llm_kv.arch;
}
const char * get_tensor_name(int i) const {
@ -2201,11 +2404,8 @@ static void llm_load_arch(llama_model_loader & ml, llama_model & model) {
static void llm_load_hparams(
llama_model_loader & ml,
llama_model & model) {
struct gguf_context * ctx = ml.ctx_gguf;
const auto kv = LLM_KV(model.arch);
auto & hparams = model.hparams;
const gguf_context * ctx = ml.ctx_gguf;
// get metadata as string
for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
@ -2219,42 +2419,41 @@ static void llm_load_hparams(
}
// get general kv
GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME));
ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
// get hparams kv
GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST));
GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH));
GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
ml.get_key (LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
ml.get_key (LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
ml.get_key (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff);
ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
ml.get_key (LLM_KV_BLOCK_COUNT, hparams.n_layer);
// n_head_kv is optional, default to n_head
hparams.n_head_kv = hparams.n_head;
GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false);
hparams.rope_finetuned = false;
GGUF_GET_KEY(ctx, hparams.rope_finetuned, gguf_get_val_bool, GGUF_TYPE_BOOL, false,
kv(LLM_KV_ROPE_SCALING_FINETUNED));
bool rope_finetuned = false;
ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
hparams.rope_finetuned = rope_finetuned;
hparams.n_yarn_orig_ctx = hparams.n_ctx_train;
GGUF_GET_KEY(ctx, hparams.n_yarn_orig_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false,
kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN));
ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_yarn_orig_ctx, false);
// rope_freq_base (optional)
hparams.rope_freq_base_train = 10000.0f;
GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
ml.get_key(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train, false);
std::string rope_scaling("linear");
GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_ROPE_SCALING_TYPE));
ml.get_key(LLM_KV_ROPE_SCALING_TYPE, rope_scaling, false);
hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED);
// rope_freq_scale (inverse of the kv) is optional
float ropescale = 0.0f;
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALING_FACTOR));
if (ropescale == 0.0f) { // try the old key name
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) {
// try the old key name
ml.get_key(LLM_KV_ROPE_SCALE_LINEAR, ropescale, false);
}
hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
@ -2262,7 +2461,7 @@ static void llm_load_hparams(
{
hparams.n_rot = hparams.n_embd / hparams.n_head;
GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
if (hparams.n_rot != hparams.n_embd / hparams.n_head) {
@ -2277,7 +2476,7 @@ static void llm_load_hparams(
switch (model.arch) {
case LLM_ARCH_LLAMA:
{
GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 26: model.type = e_model::MODEL_3B; break;
@ -2291,7 +2490,7 @@ static void llm_load_hparams(
} break;
case LLM_ARCH_FALCON:
{
GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_7B; break;
@ -2301,7 +2500,7 @@ static void llm_load_hparams(
} break;
case LLM_ARCH_BAICHUAN:
{
GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_7B; break;
case 40: model.type = e_model::MODEL_13B; break;
@ -2310,7 +2509,7 @@ static void llm_load_hparams(
} break;
case LLM_ARCH_STARCODER:
{
GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
switch (hparams.n_layer) {
case 24: model.type = e_model::MODEL_1B; break;
case 36: model.type = e_model::MODEL_3B; break;
@ -2321,7 +2520,7 @@ static void llm_load_hparams(
} break;
case LLM_ARCH_PERSIMMON:
{
GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
switch (hparams.n_layer) {
case 36: model.type = e_model::MODEL_8B; break;
default: model.type = e_model::MODEL_UNKNOWN;
@ -2329,7 +2528,7 @@ static void llm_load_hparams(
} break;
case LLM_ARCH_REFACT:
{
GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_1B; break;
default: model.type = e_model::MODEL_UNKNOWN;
@ -2337,7 +2536,7 @@ static void llm_load_hparams(
} break;
case LLM_ARCH_BLOOM:
{
GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
switch (hparams.n_layer) {
case 24: model.type = e_model::MODEL_1B; break;
@ -2352,9 +2551,9 @@ static void llm_load_hparams(
{
hparams.f_clamp_kqv = 0.0f;
GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_CLAMP_KQV));
GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS));
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false);
ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_7B; break;
@ -2364,7 +2563,7 @@ static void llm_load_hparams(
} break;
case LLM_ARCH_STABLELM:
{
GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_3B; break;
@ -2373,7 +2572,8 @@ static void llm_load_hparams(
} break;
case LLM_ARCH_QWEN:
{
GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_7B; break;
case 40: model.type = e_model::MODEL_13B; break;
@ -2421,7 +2621,7 @@ static void llm_load_vocab(
{
std::string tokenizer_name;
GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL));
ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name);
if (tokenizer_name == "llama") {
vocab.type = LLAMA_VOCAB_TYPE_SPM;
@ -2511,34 +2711,31 @@ static void llm_load_vocab(
};
for (const auto & it : special_token_types) {
const std::string & key = kv(std::get<0>(it));
int32_t & id = std::get<1>(it), old_id = id;
int32_t & id = std::get<1>(it);
GGUF_GET_KEY(ctx, id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, key);
// Must be >= -1 and < vocab size. Since the key is unsigned, -1
// can only come from the default value, so there's no point in
// validating that.
if (size_t(id + 1) > vocab.id_to_token.size()) {
LLAMA_LOG_WARN("%s: bad special token: '%s' = %d, using default id %d\n",
__func__, key.c_str(), id, old_id);
id = old_id;
uint32_t new_id;
if (!ml.get_key(std::get<0>(it), new_id, false)) {
continue;
}
if (new_id >= vocab.id_to_token.size()) {
LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n",
__func__, key.c_str(), new_id, id);
} else {
id = new_id;
}
}
// Handle add_bos_token and add_eos_token
std::string key = kv(LLM_KV_TOKENIZER_ADD_BOS);
int kid = gguf_find_key(ctx, key.c_str());
enum gguf_type ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
vocab.special_add_bos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
}
key = kv(LLM_KV_TOKENIZER_ADD_EOS);
kid = gguf_find_key(ctx, key.c_str());
ktype = kid < 0 ? GGUF_TYPE_COUNT : gguf_get_kv_type(ctx, kid);
vocab.special_add_eos = ktype == GGUF_TYPE_BOOL ? gguf_get_val_bool(ctx, kid) : -1;
if (ktype != GGUF_TYPE_BOOL && ktype != GGUF_TYPE_COUNT) {
LLAMA_LOG_WARN("%s: bad field type %d for '%s' - ignoring\n", __func__, ktype, key.c_str());
{
bool temp = true;
if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
vocab.special_add_bos = int(temp);
}
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
vocab.special_add_eos = int(temp);
}
}
}
@ -3487,7 +3684,7 @@ static void llm_load_tensors(
static bool llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
try {
llama_model_loader ml(fname, params.use_mmap);
llama_model_loader ml(fname, params.use_mmap, params.kv_overrides);
model.hparams.vocab_only = params.vocab_only;
@ -8078,7 +8275,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
constexpr bool use_mmap = false;
#endif
llama_model_loader ml(fname_inp, use_mmap);
llama_model_loader ml(fname_inp, use_mmap, NULL);
if (ml.use_mmap) {
ml.mapping.reset(new llama_mmap(&ml.file, /* prefetch */ 0, ggml_is_numa()));
}
@ -8374,7 +8571,7 @@ static int llama_apply_lora_from_file_internal(
std::vector<uint8_t> base_buf;
if (path_base_model) {
LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true));
ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/ NULL));
size_t ctx_size;
size_t mmapped_size;
@ -8602,6 +8799,7 @@ struct llama_model_params llama_model_default_params() {
/*.tensor_split =*/ nullptr,
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
/*.kv_overrides =*/ nullptr,
/*.vocab_only =*/ false,
/*.use_mmap =*/ true,
/*.use_mlock =*/ false,

20
llama.h
View file

@ -158,6 +158,22 @@ extern "C" {
llama_seq_id all_seq_id; // used if seq_id == NULL
} llama_batch;
enum llama_model_kv_override_type {
LLAMA_KV_OVERRIDE_INT,
LLAMA_KV_OVERRIDE_FLOAT,
LLAMA_KV_OVERRIDE_BOOL,
};
struct llama_model_kv_override {
char key[128];
enum llama_model_kv_override_type tag;
union {
int64_t int_value;
double float_value;
bool bool_value;
};
};
struct llama_model_params {
int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t main_gpu; // the GPU that is used for scratch and small tensors
@ -165,9 +181,13 @@ extern "C" {
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
// context pointer passed to the progress callback
void * progress_callback_user_data;
// override key-value pairs of the model meta data
const struct llama_model_kv_override * kv_overrides;
// Keep the booleans together to avoid misalignment during copy-by-value.
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible