diff --git a/common/common.cpp b/common/common.cpp index 4a0d43c13..90fe2e84e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #if defined(__APPLE__) && defined(__MACH__) #include @@ -938,8 +939,8 @@ std::string get_sortable_timestamp() { const int64_t ns = std::chrono::duration_cast( current_time.time_since_epoch() % 1000000000).count(); - char timestamp_ns[10]; - snprintf(timestamp_ns, 11, "%09ld", ns); + char timestamp_ns[11]; + snprintf(timestamp_ns, 11, "%09" PRId64, ns); return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns); } diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index 51d90ea6a..e9e070b1f 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -681,7 +681,6 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod // for rms-att-weight int row_length = model->hparams.n_embd; - const auto & hparams = model->hparams; int n_ff = model->hparams.n_ff; for (uint32_t i = 0; i < model->hparams.n_layer; ++i){ diff --git a/examples/gguf/CMakeLists.txt b/examples/gguf/CMakeLists.txt new file mode 100644 index 000000000..7d1806af3 --- /dev/null +++ b/examples/gguf/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET gguf) +add_executable(${TARGET} gguf.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/train-text-from-scratch/README.md b/examples/train-text-from-scratch/README.md index 726ec47c0..f4ffcd987 100644 --- a/examples/train-text-from-scratch/README.md +++ b/examples/train-text-from-scratch/README.md @@ -8,15 +8,15 @@ wget https://raw.githubusercontent.com/brunoklein99/deep-learning-notes/master/s # train ./bin/train-text-from-scratch \ - --vocab-model ../models/ggml-vocab.bin \ + --vocab-model ../models/ggml-vocab-llama.gguf \ --ctx 64 --embd 256 --head 8 --layer 16 \ - --checkpoint-in chk-shakespeare-256x16.bin \ - --checkpoint-out chk-shakespeare-256x16.bin \ - --model-out ggml-shakespeare-256x16-f32.bin \ + --checkpoint-in chk-shakespeare-256x16.gguf \ + --checkpoint-out chk-shakespeare-256x16.gguf \ + --model-out ggml-shakespeare-256x16-f32.gguf \ --train-data "shakespeare.txt" \ - -t 6 -b 16 -n 32 --seed 1 --adam-iter 16 \ - --print-details-interval 0 --predict 16 --use-flash + -t 6 -b 16 --seed 1 --adam-iter 256 \ + --no-checkpointing # predict -./bin/main -m ggml-shakespeare-256x16-f32.bin +./bin/main -m ggml-shakespeare-256x16-f32.gguf ``` diff --git a/examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py b/examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py new file mode 100644 index 000000000..01b3ee92a --- /dev/null +++ b/examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py @@ -0,0 +1,492 @@ +#!/usr/bin/env python3 +# train-text-from-scratch checkpoint --> gguf conversion + +import argparse +import gguf +import os +import struct +import sys +import numpy as np +from pathlib import Path + +# gguf constants +LLM_KV_OPTIMIZER_TYPE = "optimizer.type" +LLM_KV_OPTIMIZER_TYPE_ADAM = "adam" +LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs" +LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version" +LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count" +LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count" +LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count" +LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized" +LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss" +LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss" +LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count" +LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count" +LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss" +LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step" +LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j" +LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k" +LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end" +LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count" + +LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments" +LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments" +LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values" + +LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters" +LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters" +LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients" +LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients" +LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction" +LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values" +LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = "optimizer.lbfgs.memory_alpha" +LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys" +LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s" +LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y" + +LLM_KV_TRAINING_FILE_VERSION = "training.file_version" +LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count" +LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count" +LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count" + +class Tensor: + def __init__(self, dtype='f', ne=None): + if ne is None: + ne = [] + self.dtype = dtype + self.ne = ne + self.nbytes = 0 + if self.dtype == 'f': + if len(self.ne) == 0: + self.nbytes = 0 + else: + self.nbytes = int(np.product(self.ne)) * 4 + else: + raise ValueError(f"Unhandled data type '{self.dtype}'") + + def load(self, data, offset): + nd = struct.unpack(' 0 else []) + + self.lbfgs_x = Tensor('f', [self.nx]) + self.lbfgs_xp = Tensor('f', [self.nx]) + self.lbfgs_g = Tensor('f', [self.nx]) + self.lbfgs_gp = Tensor('f', [self.nx]) + self.lbfgs_d = Tensor('f', [self.nx]) + self.lbfgs_pf = Tensor('f', [self.past] if self.past > 0 else []) + self.lbfgs_lmal = Tensor('f', [self.lbfgs_m]) + self.lbfgs_lmys = Tensor('f', [self.lbfgs_m]) + self.lbfgs_lms = Tensor('f', [self.nx, self.lbfgs_m]) + self.lbfgs_lmy = Tensor('f', [self.nx, self.lbfgs_m]) + + if self.type == 0: + # these tensors are stored, but we don't need their data + x = Tensor('f', [self.nx]) + g = Tensor('f', [self.nx]) + g2 = Tensor('f', [self.nx]) + mh = Tensor('f', [self.nx]) + vh = Tensor('f', [self.nx]) + + offset = x.load(data, offset) + offset = g.load(data, offset) + offset = g2.load(data, offset) + offset = self.adam_m.load(data, offset) + offset = self.adam_v.load(data, offset) + offset = mh.load(data, offset) + offset = vh.load(data, offset) + offset = self.adam_pf.load(data, offset) + + self.adam_fx_best = struct.unpack(' 0 else []) + + self.lbfgs_x = Tensor('f', [self.nx]) + self.lbfgs_xp = Tensor('f', [self.nx]) + self.lbfgs_g = Tensor('f', [self.nx]) + self.lbfgs_gp = Tensor('f', [self.nx]) + self.lbfgs_d = Tensor('f', [self.nx]) + self.lbfgs_pf = Tensor('f', [self.past] if self.past > 0 else []) + self.lbfgs_lmal = Tensor('f', [self.lbfgs_m]) + self.lbfgs_lmys = Tensor('f', [self.lbfgs_m]) + self.lbfgs_lms = Tensor('f', [self.nx, self.lbfgs_m]) + self.lbfgs_lmy = Tensor('f', [self.nx, self.lbfgs_m]) + + # forgot to save type in version 1: + # guess self.type from number of remaining bytes + size_type_0 = 12 + sum([t.max_storage_size() for t in + [self.adam_m, self.adam_v] + +([self.adam_pf] if (self.past > 0) else [])]) + size_type_1 = 24 + sum([t.max_storage_size() for t in + [self.lbfgs_x, self.lbfgs_xp, self.lbfgs_g, + self.lbfgs_gp, self.lbfgs_d, self.lbfgs_pf, + self.lbfgs_lmal, self.lbfgs_lmys, + self.lbfgs_lms, self.lbfgs_lmy] + +([self.lbfgs_pf] if (self.past > 0) else [])]) + # due to alignment padding the size might not by exact + # but the difference in size for both types is significant, + # so we can just use whichever is closest + remaining = len(data) - offset + if abs(remaining - size_type_0) < abs(remaining - size_type_1): + self.type = 0 + else: + self.type = 1 + + if self.type == 0: + offset = self.adam_m.load(data, offset) + offset = self.adam_v.load(data, offset) + offset = self.adam_pf.load(data,offset) + + self.adam_fx_best = struct.unpack(' 0: + self.adam_pf.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES) + + elif self.type == 1: + gguf_writer.add_string(LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS) + gguf_writer.add_uint32(LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, self.lbfgs_m) + gguf_writer.add_float32(LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, self.lbfgs_fx_best) + gguf_writer.add_float32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, self.lbfgs_step) + gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, self.lbfgs_j) + gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, self.lbfgs_k) + gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, self.lbfgs_end) + gguf_writer.add_uint32(LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, self.lbfgs_n_no_improvement) + + self.lbfgs_x.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS) + self.lbfgs_xp.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS) + self.lbfgs_g.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS) + self.lbfgs_gp.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS) + self.lbfgs_d.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION) + if self.past > 0: + self.lbfgs_pf.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES) + self.lbfgs_lmal.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA) + self.lbfgs_lmys.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS) + self.lbfgs_lms.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S) + self.lbfgs_lmy.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y) + else: + raise ValueError('Unknown optimizer type') + +class ModelParams: + def __init__(self): + pass + + def load(self, data, offset): + self.n_vocab = struct.unpack(' @@ -17,8 +18,6 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -static const float rms_norm_eps = 1e-5f; - struct random_normal_distribution { std::mt19937 gen; std::normal_distribution rd; @@ -63,17 +62,6 @@ float frand_uniform(struct random_uniform_distribution * rnd) { return rnd->rd(rnd->gen); } -void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { - struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); - - if (plan.work_size > 0) { - buf.resize(plan.work_size); - plan.work_data = buf.data(); - } - - ggml_graph_compute(graph, &plan); -} - struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) { float scale = 1.0f; // xavier switch (tensor->n_dims) { @@ -167,29 +155,20 @@ struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struc return tensor; } -struct llama_vocab { - using id = int32_t; - using token = std::string; - using ttype = llama_token_type; - - struct token_data { - token text; - float score; - ttype type; - }; - - std::unordered_map token_to_id; - std::vector id_to_token; -}; - struct my_llama_hparams { uint32_t n_vocab = 32000; - uint32_t n_ctx = 512; // this is provided as user input? + uint32_t n_ctx = 512; uint32_t n_embd = 4096; - uint32_t n_mult = 4; uint32_t n_head = 32; uint32_t n_layer = 32; uint32_t n_rot = 64; + uint32_t n_ff = 11008; + + // float f_norm_eps = 1e-5; // falcon + float f_norm_rms_eps = 1e-5; // llama + + float rope_freq_base = 10000.0f; + float rope_freq_scale = 1.0f; bool operator!=(const my_llama_hparams& other) const { return memcmp(this, &other, sizeof(my_llama_hparams)); @@ -215,17 +194,6 @@ struct my_llama_layer { struct ggml_tensor * w3; }; -struct my_llama_kv_cache { - struct ggml_context * ctx = NULL; - - struct ggml_tensor * k; - struct ggml_tensor * v; - - // llama_ctx_buffer buf; - - int n; // number of tokens currently in the cache -}; - struct my_llama_model { struct ggml_context * ctx = NULL; @@ -243,18 +211,91 @@ struct my_llama_model { uint32_t train_tokens = 0; }; -uint32_t get_n_ff(const struct my_llama_hparams* hparams) { - const uint32_t n_ff = ((2*(4*hparams->n_embd)/3 + hparams->n_mult - 1)/hparams->n_mult)*hparams->n_mult; - return n_ff; -} +// gguf constants +const char * LLM_KV_OPTIMIZER_TYPE = "optimizer.type"; +const char * LLM_KV_OPTIMIZER_TYPE_ADAM = "adam"; +const char * LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs"; +const char * LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version"; +const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count"; +const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count"; +const char * LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count"; +const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized"; +const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss"; +const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss"; +const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count"; +const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count"; +const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss"; +const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step"; +const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j"; +const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k"; +const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end"; +const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count"; + +const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments"; +const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments"; +const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values"; + +const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters"; +const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters"; +const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients"; +const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients"; +const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction"; +const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values"; +const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = "optimizer.lbfgs.memory_alpha"; +const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys"; +const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s"; +const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y"; + +const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version"; +const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"; +const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count"; +const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count"; + +// gguf constants (sync with gguf.py) + +const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture"; +const char * LLM_KV_GENERAL_FILE_TYPE = "general.file_type"; + +const char * LLM_KV_CONTEXT_LENGTH = "%s.context_length"; +const char * LLM_KV_EMBEDDING_LENGTH = "%s.embedding_length"; +const char * LLM_KV_BLOCK_COUNT = "%s.block_count"; +const char * LLM_KV_FEED_FORWARD_LENGTH = "%s.feed_forward_length"; +const char * LLM_KV_ATTENTION_HEAD_COUNT = "%s.attention.head_count"; +const char * LLM_KV_ATTENTION_LAYERNORM_RMS_EPS = "%s.attention.layer_norm_rms_epsilon"; +const char * LLM_KV_ROPE_DIMENSION_COUNT = "%s.rope.dimension_count"; +const char * LLM_KV_ROPE_FREQ_BASE = "%s.rope.freq_base"; // TODO load in llama.cpp +const char * LLM_KV_ROPE_SCALE_LINEAR = "%s.rope.scale_linear"; + +const char * LLM_KV_TOKENIZER_MODEL = "tokenizer.ggml.model"; +const char * LLM_KV_TOKENIZER_LIST = "tokenizer.ggml.tokens"; +const char * LLM_KV_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type"; +const char * LLM_KV_TOKENIZER_SCORES = "tokenizer.ggml.scores"; +const char * LLM_KV_TOKENIZER_MERGES = "tokenizer.ggml.merges"; +const char * LLM_KV_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id"; +const char * LLM_KV_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id"; +const char * LLM_KV_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id"; +const char * LLM_KV_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id"; +const char * LLM_KV_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id"; + +const char * LLM_TENSOR_TOKEN_EMBD = "token_embd"; +const char * LLM_TENSOR_OUTPUT_NORM = "output_norm"; +const char * LLM_TENSOR_OUTPUT = "output"; +const char * LLM_TENSOR_ATTN_NORM = "blk.%d.attn_norm"; +const char * LLM_TENSOR_ATTN_Q = "blk.%d.attn_q"; +const char * LLM_TENSOR_ATTN_K = "blk.%d.attn_k"; +const char * LLM_TENSOR_ATTN_V = "blk.%d.attn_v"; +const char * LLM_TENSOR_ATTN_OUT = "blk.%d.attn_output"; +const char * LLM_TENSOR_FFN_NORM = "blk.%d.ffn_norm"; +const char * LLM_TENSOR_FFN_GATE = "blk.%d.ffn_gate"; +const char * LLM_TENSOR_FFN_DOWN = "blk.%d.ffn_down"; +const char * LLM_TENSOR_FFN_UP = "blk.%d.ffn_up"; void print_params(struct my_llama_hparams * params) { printf("%s: n_vocab: %d\n", __func__, params->n_vocab); printf("%s: n_ctx: %d\n", __func__, params->n_ctx); printf("%s: n_embd: %d\n", __func__, params->n_embd); - printf("%s: n_mult: %d\n", __func__, params->n_mult); printf("%s: n_head: %d\n", __func__, params->n_head); - printf("%s: n_ff: %d\n", __func__, get_n_ff(params)); + printf("%s: n_ff: %d\n", __func__, params->n_ff); printf("%s: n_layer: %d\n", __func__, params->n_layer); printf("%s: n_rot: %d\n", __func__, params->n_rot); } @@ -265,8 +306,7 @@ void init_model(struct my_llama_model * model) { const uint32_t n_embd = hparams.n_embd; const uint32_t n_layer = hparams.n_layer; const uint32_t n_vocab = hparams.n_vocab; - - const uint32_t n_ff = get_n_ff(&hparams); + const uint32_t n_ff = hparams.n_ff; struct ggml_context * ctx = model->ctx; @@ -274,20 +314,31 @@ void init_model(struct my_llama_model * model) { model->train_samples = 0; model->train_tokens = 0; + std::vector tn_buf; + tn_buf.resize(GGML_MAX_NAME); + auto tn = [&tn_buf](const char * key) -> const char * { + snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", key); + return tn_buf.data(); + }; + auto tni = [&tn_buf](const char * key, int bid) -> const char * { + snprintf(tn_buf.data(), tn_buf.size(), key, bid); + std::string s = tn_buf.data(); + snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", s.c_str()); + return tn_buf.data(); + }; + model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); - ggml_set_name(model->tok_embeddings, "tok_embeddings.weight"); - ggml_set_name(model->norm, "norm.weight"); - ggml_set_name(model->output, "output.weight"); + ggml_set_name(model->tok_embeddings, tn(LLM_TENSOR_TOKEN_EMBD)); + ggml_set_name(model->norm, tn(LLM_TENSOR_OUTPUT_NORM)); + ggml_set_name(model->output, tn(LLM_TENSOR_OUTPUT)); model->layers.resize(n_layer); for (uint32_t i = 0; i < n_layer; ++i) { auto & layer = model->layers[i]; - std::string layers_i = "layers." + std::to_string(i); - layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); @@ -301,18 +352,18 @@ void init_model(struct my_llama_model * model) { layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd); layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); - ggml_set_name(layer.attention_norm, (layers_i + ".attention_norm.weight").c_str()); + ggml_set_name(layer.attention_norm, tni(LLM_TENSOR_ATTN_NORM, i)); - ggml_set_name(layer.wq, (layers_i + ".attention.wq.weight").c_str()); - ggml_set_name(layer.wk, (layers_i + ".attention.wk.weight").c_str()); - ggml_set_name(layer.wv, (layers_i + ".attention.wv.weight").c_str()); - ggml_set_name(layer.wo, (layers_i + ".attention.wo.weight").c_str()); + ggml_set_name(layer.wq, tni(LLM_TENSOR_ATTN_Q, i)); + ggml_set_name(layer.wk, tni(LLM_TENSOR_ATTN_K, i)); + ggml_set_name(layer.wv, tni(LLM_TENSOR_ATTN_V, i)); + ggml_set_name(layer.wo, tni(LLM_TENSOR_ATTN_OUT, i)); - ggml_set_name(layer.ffn_norm, (layers_i + ".ffn_norm.weight").c_str()); + ggml_set_name(layer.ffn_norm, tni(LLM_TENSOR_FFN_NORM, i)); - ggml_format_name(layer.w1, "%s.feed_forward.w1.weight", layers_i.c_str()); - ggml_format_name(layer.w2, "%s.feed_forward.w2.weight", layers_i.c_str()); - ggml_format_name(layer.w3, "%s.feed_forward.w3.weight", layers_i.c_str()); + ggml_set_name(layer.w1, tni(LLM_TENSOR_FFN_GATE, i)); + ggml_set_name(layer.w2, tni(LLM_TENSOR_FFN_DOWN, i)); + ggml_set_name(layer.w3, tni(LLM_TENSOR_FFN_UP, i)); } } @@ -371,267 +422,6 @@ void randomize_model(struct my_llama_model * model, int seed, float mean, float } } -bool init_kv_cache(struct my_llama_kv_cache* cache, struct my_llama_model * model, int n_batch) { - const auto & hparams = model->hparams; - - const uint32_t n_ctx = hparams.n_ctx; - const uint32_t n_embd = hparams.n_embd; - const uint32_t n_layer = hparams.n_layer; - - const int64_t n_mem = n_layer*n_ctx*n_batch; - const int64_t n_elements = n_embd*n_mem; - - // cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); - - // struct ggml_init_params params; - // params.mem_size = cache.buf.size; - // params.mem_buffer = cache.buf.addr; - // params.no_alloc = false; - if (!cache->ctx) { - struct ggml_init_params params; - params.mem_size = 2u*n_elements*ggml_type_size(GGML_TYPE_F32) + 2u*1024*1024; - params.mem_buffer = NULL; - params.no_alloc = false; - - cache->ctx = ggml_init(params); - - if (!cache->ctx) { - fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); - return false; - } - } - - cache->k = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements); - cache->v = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements); - - return true; -} - -struct ggml_tensor * forward( - struct my_llama_model * model, - struct my_llama_kv_cache * cache, - struct ggml_context * ctx0, - struct ggml_cgraph * gf, - struct ggml_tensor * tokens_input, - const int n_tokens, - const int n_past) { - - const int N = n_tokens; - - struct my_llama_kv_cache& kv_self = *cache; - const auto & hparams = model->hparams; - const int n_ctx = hparams.n_ctx; - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_head = hparams.n_head; - const int n_rot = hparams.n_rot; - - struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - memcpy(tokens->data, tokens_input->data, N*ggml_element_size(tokens)); - - struct ggml_tensor * kc = kv_self.k; - struct ggml_tensor * vc = kv_self.v; - - // inpL shape [n_embd,N,1,1] - struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; - - struct ggml_tensor * cur; - - // lctx.use_buf(ctx0, 0); - - // norm - { - // cur shape [n_embd,N,1,1] - cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); - - // cur = attention_norm*cur - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model->layers[il].attention_norm, cur), - cur); - } - - // self-attention - { - // compute Q and K and RoPE them - // wq shape [n_embd, n_embd, 1, 1] - // wk shape [n_embd, n_embd, 1, 1] - // Qcur shape [n_embd/n_head, n_head, N, 1] - // Kcur shape [n_embd/n_head, n_head, N, 1] - struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0); - struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0); - - // store key and value to memory - { - // compute the transposed [N, n_embd] V matrix - // wv shape [n_embd, n_embd, 1, 1] - // Vcur shape [n_embd, N, 1, 1] - struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wv, cur), n_embd, N))); - - // kv_self.k shape [n_embd * n_ctx * n_layer, 1] - // kv_self.v shape [n_embd * n_ctx * n_layer, 1] - // k shape [n_embd * N, 1] == kv_self.k[:,n_past:n_past+N,il,0] - // v shape [N, n_embd, 1, 1] == kv_self.v[:,n_past:n_past+N,il,0] - - /* { - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); - - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } //*/ - - kc = ggml_set_1d_inplace(ctx0, kc, ggml_reshape_1d(ctx0, Kcur, n_embd*N), (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); - vc = ggml_set_2d_inplace(ctx0, vc, Vcur, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); - } - - // Qcur shape [n_embd/n_head, n_head, N, 1] - // Q shape [n_embd/n_head, N, n_head, 1] - struct ggml_tensor * Q = - ggml_permute(ctx0, - Qcur, - 0, 2, 1, 3); - - // kv_self.k shape [n_embd * n_ctx * n_layer, 1] - // K shape [n_embd/n_head, n_past + N, n_head, 1] - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, kc, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kc)*n_embd), - n_embd/n_head, n_head, n_past + N), - 0, 2, 1, 3); - - // K * Q - // KQ shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_scaled = - ggml_scale(ctx0, - KQ, - ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head))); - - // KQ_masked = mask_past(KQ_scaled) - // KQ_masked shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); - - // KQ = soft_max(KQ_masked) - // KQ_soft_max shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - - // split cached V into n_head heads - //// V shape [n_past + N, n_embd/n_head, n_head, 1] - // V shape [n_past + N, n_embd/n_head, n_head, 1] == kv_self.v[:,:(n_past+N),il,1] - struct ggml_tensor * V = - ggml_view_3d(ctx0, vc, - n_past + N, n_embd/n_head, n_head, - n_ctx*ggml_element_size(vc), - n_ctx*ggml_element_size(vc)*n_embd/n_head, - il*n_ctx*ggml_element_size(vc)*n_embd); - - // KQV shape [n_embd/n_head, N, n_head, 1] - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - // KQV_merged shape [n_embd/n_head, n_head, N, 1] - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - // KQV_merged shape - - // cur = KQV_merged.contiguous().view(n_embd, N) - // cur shape [n_embd,N,1,1] - cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N); - // cur = ggml_cpy(ctx0, - // KQV_merged, - // ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); - - // projection (no bias) - // cur shape [n_embd,N,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].wo, - cur); - } - - // lctx.use_buf(ctx0, 1); - - // inpFF shape [n_embd,N,1,1] - struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); - - // feed-forward network - { - // norm - { - // cur shape [n_embd,N,1,1] - cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); - - // cur = ffn_norm*cur - // cur shape [n_embd,N,1,1] - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model->layers[il].ffn_norm, cur), - cur); - } - - // tmp shape [n_ff,N,1,1] - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model->layers[il].w3, - cur); - - // cur shape [n_ff,N,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].w1, - cur); - - // SILU activation - // cur shape [n_ff,N,1,1] - cur = ggml_silu(ctx0, cur); - - // cur shape [n_ff,N,1,1] - cur = ggml_mul(ctx0, cur, tmp); - - // cur shape [n_embd,N,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].w2, - cur); - } - - // cur shape [n_embd,N,1,1] - cur = ggml_add(ctx0, cur, inpFF); - - // input for next layer - // inpL shape [n_embd,N,1,1] - inpL = cur; - } - - // norm - { - - // inpL shape [n_embd,N,1,1] - inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps); - - // inpL = norm*inpL - // inpL shape [n_embd,N,1,1] - inpL = ggml_mul(ctx0, - ggml_repeat(ctx0, model->norm, inpL), - inpL); - - //embeddings = inpL; - } - - // lm_head - // inpL shape [n_vocab,N,1,1] - inpL = ggml_mul_mat(ctx0, model->output, inpL); - - // run the computation - ggml_build_forward_expand(gf, inpL); - - return inpL; -} - void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) { GGML_ASSERT(tensor->n_dims == 1); GGML_ASSERT(tensor->ne[0] == ne0); @@ -658,786 +448,222 @@ void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int6 GGML_ASSERT(tensor->ne[3] == ne3); } -struct ggml_tensor * forward_batch( - struct my_llama_model * model, - struct my_llama_kv_cache * cache, - struct ggml_context * ctx0, - struct ggml_cgraph * gf, - struct ggml_tensor * tokens_input, - const int n_tokens, - const int n_past, - const int n_batch) { - - const int N = n_tokens; - - struct my_llama_kv_cache& kv_self = *cache; - const auto & hparams = model->hparams; - const int n_ctx = hparams.n_ctx; - const int n_vocab = hparams.n_vocab; - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_head = hparams.n_head; - const int n_rot = hparams.n_rot; - const int n_ff = get_n_ff(&hparams); - - struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); - memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch); - - struct ggml_tensor * kc = kv_self.k; - struct ggml_tensor * vc = kv_self.v; - - // inpL shape [n_embd,N*n_batch,1] - struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); - assert_shape_2d(inpL, n_embd, N*n_batch); - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; - - struct ggml_tensor * cur; - - // lctx.use_buf(ctx0, 0); - - // norm - { - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); - assert_shape_2d(cur, n_embd, N*n_batch); - - // cur = attention_norm*cur - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model->layers[il].attention_norm, cur), - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - // self-attention - { - // compute Q and K and RoPE them - // wq shape [n_embd, n_embd, 1, 1] - // wk shape [n_embd, n_embd, 1, 1] - // Qcur shape [n_embd/n_head, n_head, N, n_batch] - // Kcur shape [n_embd/n_head, n_head, N, n_batch] - struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0); - struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0); - assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch); - assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch); - - // store key and value to memory - { - // compute the transposed [N, n_embd] V matrix - // wv shape [n_embd, n_embd, 1, 1] - // Vcur shape [N, n_embd, n_batch, 1] - struct ggml_tensor * Vcur = ggml_cont(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_mul_mat(ctx0, - model->layers[il].wv, - cur), - n_embd, N, n_batch), - 1, 0, 2, 3)); - assert_shape_3d(Vcur, N, n_embd, n_batch); - - // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer] - // kv_self.v shape [n_ctx * n_embd * n_batch * n_layer] - // k shape [n_embd * N, n_batch] == kv_self.k[:,n_past:n_past+N,:,il] - // v shape [N, n_embd, n_batch, 1] == kv_self.v[:,n_past:n_past+N,:,il] - - /* { - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); - - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } //*/ - - kc = ggml_set_2d_inplace(ctx0, kc, - ggml_reshape_2d(ctx0, Kcur, n_embd*N, n_batch), - ggml_element_size(kc)*n_embd*n_ctx, - (ggml_element_size(kc)*n_embd)*(il*n_batch*n_ctx + n_past)); - vc = ggml_set_2d_inplace(ctx0, vc, - ggml_reshape_2d(ctx0, Vcur, N*n_embd, n_batch), - ggml_element_size(vc)*n_ctx*n_embd, - ggml_element_size(vc)*(n_past + il*n_embd*n_batch*n_ctx)); - - assert_shape_1d(kc, n_embd * n_ctx * n_batch * n_layer); - assert_shape_1d(vc, n_embd * n_ctx * n_batch * n_layer); - } - - // Qcur shape [n_embd/n_head, n_head, N, n_batch] - // Q shape [n_embd/n_head, N, n_head, n_batch] - struct ggml_tensor * Q = - ggml_permute(ctx0, - Qcur, - 0, 2, 1, 3); - assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch); - - // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer] - // K shape [n_embd/n_head, n_past + N, n_head, n_batch] - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_reshape_4d(ctx0, - ggml_view_3d(ctx0, - kc, - n_embd, - (n_past + N), - n_batch, - n_embd*ggml_element_size(kc), - n_ctx*n_embd*ggml_element_size(kc), - il*n_batch*n_ctx*n_embd*ggml_element_size(kc)), - n_embd/n_head, n_head, n_past + N, n_batch), - 0, 2, 1, 3); - assert_shape_4d(K, n_embd/n_head, n_past + N, n_head, n_batch); - - // K * Q - // KQ shape [n_past + N, N, n_head, n_batch] - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - assert_shape_4d(KQ, n_past + N, N, n_head, n_batch); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - // KQ_scaled shape [n_past + N, N, n_head, n_batch] - struct ggml_tensor * KQ_scaled = - ggml_scale_inplace(ctx0, - KQ, - ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head))); - assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch); - - // KQ_masked = mask_past(KQ_scaled) - // KQ_masked shape [n_past + N, N, n_head, n_batch] - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); - assert_shape_4d(KQ_masked, n_past + N, N, n_head, n_batch); - - // KQ = soft_max(KQ_masked) - // KQ_soft_max shape [n_past + N, N, n_head, n_batch] - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); - assert_shape_4d(KQ_soft_max, n_past + N, N, n_head, n_batch); - - // split cached V into n_head heads - // kv_self.v shape [n_ctx * n_embd * n_batch * n_layer] - // V shape [n_past + N, n_embd/n_head, n_head, n_batch] == kv_self.v[:(n_past+N),:,:,il] - struct ggml_tensor * V = - ggml_view_4d(ctx0, vc, - n_past + N, n_embd/n_head, n_head, n_batch, - ggml_element_size(vc)*n_ctx, - ggml_element_size(vc)*n_ctx*n_embd/n_head, - ggml_element_size(vc)*n_ctx*n_embd, - il*n_batch*n_ctx*n_embd*ggml_element_size(vc)); - assert_shape_4d(V, n_past + N, n_embd/n_head, n_head, n_batch); - - // KQV shape [n_embd/n_head, N, n_head, n_batch] - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - // KQV_merged shape [n_embd/n_head, n_head, N, n_batch] - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch); - // KQV_merged shape - - // cur = KQV_merged.contiguous().view(n_embd, N) - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch); - assert_shape_2d(cur, n_embd, N*n_batch); - // cur = ggml_cpy(ctx0, - // KQV_merged, - // ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); - - // projection (no bias) - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].wo, - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - // lctx.use_buf(ctx0, 1); - - // inpFF shape [n_embd,N*n_batch,1,1] - struct ggml_tensor * inpFF = ggml_add_inplace(ctx0, cur, inpSA); - assert_shape_2d(inpFF, n_embd, N*n_batch); - - // feed-forward network - { - // norm - { - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); - assert_shape_2d(cur, n_embd, N*n_batch); - - // cur = ffn_norm*cur - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model->layers[il].ffn_norm, cur), - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - // tmp shape [n_ff,N*n_batch,1,1] - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model->layers[il].w3, - cur); - assert_shape_2d(tmp, n_ff, N*n_batch); - - // cur shape [n_ff,N*n_batch,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].w1, - cur); - assert_shape_2d(cur, n_ff, N*n_batch); - - // SILU activation - // cur shape [n_ff,N*n_batch,1,1] - cur = ggml_silu(ctx0, cur); - assert_shape_2d(cur, n_ff, N*n_batch); - - // cur shape [n_ff,N*n_batch,1,1] - cur = ggml_mul(ctx0, cur, tmp); - assert_shape_2d(cur, n_ff, N*n_batch); - - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].w2, - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_add_inplace(ctx0, cur, inpFF); - assert_shape_2d(cur, n_embd, N*n_batch); - - // input for next layer - // inpL shape [n_embd,N*n_batch,1,1] - inpL = cur; - assert_shape_2d(inpL, n_embd, N*n_batch); - } - - // norm - { - - // inpL shape [n_embd,N*n_batch,1,1] - inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps); - assert_shape_2d(inpL, n_embd, N*n_batch); - - // inpL = norm*inpL - // inpL shape [n_embd,N*n_batch,1,1] - inpL = ggml_mul(ctx0, - ggml_repeat(ctx0, model->norm, inpL), - inpL); - - assert_shape_2d(inpL, n_embd, N*n_batch); - - //embeddings = inpL; - } - - // lm_head - // inpL shape [n_vocab,N*n_batch,1,1] - inpL = ggml_mul_mat(ctx0, model->output, inpL); - assert_shape_2d(inpL, n_vocab, N*n_batch); - - { - // inpL shape [n_vocab,N,n_batch,1] - inpL = ggml_reshape_3d(ctx0, - inpL, - n_vocab, N, n_batch); - assert_shape_3d(inpL, n_vocab, N, n_batch); - } - - // run the computation - ggml_build_forward_expand(gf, inpL); - - return inpL; +static size_t hash(void * p) { + return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE; } -struct ggml_tensor * forward_batch_wo_cache( - struct my_llama_model * model, - struct ggml_context * ctx0, - struct ggml_cgraph * gf, - struct ggml_tensor * tokens_input, - const int n_tokens, - const int n_batch) { +static size_t hash_find(void * hash_table[], void * p) { + size_t h = hash(p); - const int n_past = 0; - const int N = n_tokens; - - const auto & hparams = model->hparams; - //const int n_ctx = hparams.n_ctx; - const int n_vocab = hparams.n_vocab; - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_head = hparams.n_head; - const int n_rot = hparams.n_rot; - const int n_ff = get_n_ff(&hparams); - - struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); - memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch); - - // inpL shape [n_embd,N*n_batch,1] - struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); - assert_shape_2d(inpL, n_embd, N*n_batch); - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; - - struct ggml_tensor * cur; - - // lctx.use_buf(ctx0, 0); - - // norm - { - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); - assert_shape_2d(cur, n_embd, N*n_batch); - - // cur = attention_norm*cur - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model->layers[il].attention_norm, cur), - cur); - assert_shape_2d(cur, n_embd, N*n_batch); + // linear probing + size_t i = h; + while (hash_table[i] != NULL && hash_table[i] != p) { + i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE; + if (i == h) { + // visited all hash table entries -> not found + return GGML_GRAPH_HASHTABLE_SIZE; } - - // self-attention - { - // compute Q and K and RoPE them - // wq shape [n_embd, n_embd, 1, 1] - // wk shape [n_embd, n_embd, 1, 1] - // Qcur shape [n_embd/n_head, n_head, N, n_batch] - // Kcur shape [n_embd/n_head, n_head, N, n_batch] - struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0); - struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0); - assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch); - assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch); - - // Vcur shape [N, n_batch, n_embd/n_head, n_head] - struct ggml_tensor * Vcur = ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, cur, model->layers[il].wv), N, n_batch, n_embd/n_head, n_head); - assert_shape_4d(Vcur, N, n_batch, n_embd/n_head, n_head); - - // Qcur shape [n_embd/n_head, n_head, N, n_batch] - // Q shape [n_embd/n_head, N, n_head, n_batch] - struct ggml_tensor * Q = - ggml_permute(ctx0, - Qcur, - 0, 2, 1, 3); - assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch); - - // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer] - // K shape [n_embd/n_head, N, n_head, n_batch] - struct ggml_tensor * K = - ggml_permute(ctx0, - Kcur, - 0, 2, 1, 3); - assert_shape_4d(K, n_embd/n_head, N, n_head, n_batch); - - // K * Q - // KQ shape [N, N, n_head, n_batch] - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - assert_shape_4d(KQ, N, N, n_head, n_batch); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - // KQ_scaled shape [N, N, n_head, n_batch] - struct ggml_tensor * KQ_scaled = - ggml_scale_inplace(ctx0, - KQ, - ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head))); - assert_shape_4d(KQ_scaled, N, N, n_head, n_batch); - - // KQ_masked = mask_past(KQ_scaled) - // KQ_masked shape [N, N, n_head, n_batch] - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); - assert_shape_4d(KQ_masked, N, N, n_head, n_batch); - - // KQ = soft_max(KQ_masked) - // KQ_soft_max shape [N, N, n_head, n_batch] - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); - assert_shape_4d(KQ_soft_max, N, N, n_head, n_batch); - - // Vcur shape [N, n_batch, n_embd/n_head, n_head] - // V shape [N, n_embd/n_head, n_head, n_batch] - struct ggml_tensor * V = - ggml_permute(ctx0, - Vcur, - 0, 3, 1, 2); - assert_shape_4d(V, N, n_embd/n_head, n_head, n_batch); - - // KQV shape [n_embd/n_head, N, n_head, n_batch] - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - // KQV_merged shape [n_embd/n_head, n_head, N, n_batch] - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch); - // KQV_merged shape - - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch); - assert_shape_2d(cur, n_embd, N*n_batch); - - // projection (no bias) - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].wo, - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - // lctx.use_buf(ctx0, 1); - - // inpFF shape [n_embd,N*n_batch,1,1] - struct ggml_tensor * inpFF = ggml_add_inplace(ctx0, cur, inpSA); - assert_shape_2d(inpFF, n_embd, N*n_batch); - - // feed-forward network - { - // norm - { - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); - assert_shape_2d(cur, n_embd, N*n_batch); - - // cur = ffn_norm*cur - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model->layers[il].ffn_norm, cur), - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - // tmp shape [n_ff,N*n_batch,1,1] - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model->layers[il].w3, - cur); - assert_shape_2d(tmp, n_ff, N*n_batch); - - // cur shape [n_ff,N*n_batch,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].w1, - cur); - assert_shape_2d(cur, n_ff, N*n_batch); - - // SILU activation - // cur shape [n_ff,N*n_batch,1,1] - cur = ggml_silu(ctx0, cur); - assert_shape_2d(cur, n_ff, N*n_batch); - - // cur shape [n_ff,N*n_batch,1,1] - cur = ggml_mul(ctx0, cur, tmp); - assert_shape_2d(cur, n_ff, N*n_batch); - - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].w2, - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_add_inplace(ctx0, cur, inpFF); - assert_shape_2d(cur, n_embd, N*n_batch); - - // input for next layer - // inpL shape [n_embd,N*n_batch,1,1] - inpL = cur; - assert_shape_2d(inpL, n_embd, N*n_batch); } - - // norm - { - - // inpL shape [n_embd,N*n_batch,1,1] - inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps); - assert_shape_2d(inpL, n_embd, N*n_batch); - - // inpL = norm*inpL - // inpL shape [n_embd,N*n_batch,1,1] - inpL = ggml_mul(ctx0, - ggml_repeat(ctx0, model->norm, inpL), - inpL); - - assert_shape_2d(inpL, n_embd, N*n_batch); - - //embeddings = inpL; - } - - // lm_head - // inpL shape [n_vocab,N*n_batch,1,1] - inpL = ggml_mul_mat(ctx0, model->output, inpL); - assert_shape_2d(inpL, n_vocab, N*n_batch); - - { - // inpL shape [n_vocab,N,n_batch,1] - inpL = ggml_reshape_3d(ctx0, - inpL, - n_vocab, N, n_batch); - assert_shape_3d(inpL, n_vocab, N, n_batch); - } - - // run the computation - ggml_build_forward_expand(gf, inpL); - - return inpL; + return i; } -struct ggml_tensor * forward_batch_wo_cache_flash_attn( - struct my_llama_model * model, - struct ggml_context * ctx0, - struct ggml_cgraph * gf, - struct ggml_tensor * tokens_input, - const int n_tokens, - const int n_batch) { +static bool hash_insert(void * hash_table[], void * p) { + //size_t h = hash(p); + size_t i = hash_find(hash_table, p); - const int n_past = 0; - const int N = n_tokens; + GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full - const auto & hparams = model->hparams; - //const int n_ctx = hparams.n_ctx; - const int n_vocab = hparams.n_vocab; - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_head = hparams.n_head; - const int n_rot = hparams.n_rot; - const int n_ff = get_n_ff(&hparams); - - struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); - memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch); - - struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); - assert_shape_2d(inpL, n_embd, N*n_batch); - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; - - struct ggml_tensor * cur; - - // norm - { - cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); - assert_shape_2d(cur, n_embd, N*n_batch); - - // cur = attention_norm*cur - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model->layers[il].attention_norm, cur), - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - // self-attention - { - // compute Q and K and RoPE them - // wq shape [n_embd, n_embd, 1, 1] - // wk shape [n_embd, n_embd, 1, 1] - struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0); - struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0); - assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch); - assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch); - - struct ggml_tensor * Vcur = ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, cur, model->layers[il].wv), N, n_batch, n_embd/n_head, n_head); - assert_shape_4d(Vcur, N, n_batch, n_embd/n_head, n_head); - - struct ggml_tensor * Q = - ggml_permute(ctx0, - Qcur, - 0, 2, 1, 3); - assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch); - - struct ggml_tensor * K = - ggml_permute(ctx0, - Kcur, - 0, 2, 1, 3); - assert_shape_4d(K, n_embd/n_head, N, n_head, n_batch); - - struct ggml_tensor * V = - ggml_permute(ctx0, - Vcur, - 0, 3, 1, 2); - assert_shape_4d(V, N, n_embd/n_head, n_head, n_batch); - - bool masked = true; - struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, masked); - assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch); - - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch); - cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch); - assert_shape_2d(cur, n_embd, N*n_batch); - - // projection (no bias) - cur = ggml_mul_mat(ctx0, - model->layers[il].wo, - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - struct ggml_tensor * inpFF = ggml_add_inplace(ctx0, cur, inpSA); - assert_shape_2d(inpFF, n_embd, N*n_batch); - - // feed-forward network - { - // norm - { - cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); - assert_shape_2d(cur, n_embd, N*n_batch); - - // cur = ffn_norm*cur - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model->layers[il].ffn_norm, cur), - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model->layers[il].w3, - cur); - assert_shape_2d(tmp, n_ff, N*n_batch); - - cur = ggml_mul_mat(ctx0, - model->layers[il].w1, - cur); - assert_shape_2d(cur, n_ff, N*n_batch); - - // SILU activation - cur = ggml_silu(ctx0, cur); - assert_shape_2d(cur, n_ff, N*n_batch); - - cur = ggml_mul(ctx0, cur, tmp); - assert_shape_2d(cur, n_ff, N*n_batch); - - cur = ggml_mul_mat(ctx0, - model->layers[il].w2, - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - cur = ggml_add_inplace(ctx0, cur, inpFF); - assert_shape_2d(cur, n_embd, N*n_batch); - - // input for next layer - inpL = cur; - assert_shape_2d(inpL, n_embd, N*n_batch); + if (hash_table[i] == p) { + return true; } - // norm - { - - inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps); - assert_shape_2d(inpL, n_embd, N*n_batch); - - // inpL = norm*inpL - inpL = ggml_mul(ctx0, - ggml_repeat(ctx0, model->norm, inpL), - inpL); - - assert_shape_2d(inpL, n_embd, N*n_batch); - } - - // lm_head - inpL = ggml_mul_mat(ctx0, model->output, inpL); - assert_shape_2d(inpL, n_vocab, N*n_batch); - - { - inpL = ggml_reshape_3d(ctx0, - inpL, - n_vocab, N, n_batch); - assert_shape_3d(inpL, n_vocab, N, n_batch); - } - - // run the computation - ggml_build_forward_expand(gf, inpL); - - return inpL; + // insert + GGML_ASSERT(hash_table[i] == NULL); + hash_table[i] = p; + return false; } -// expand the graph nodes without creating leafs. -struct ggml_tensor * expand(struct ggml_cgraph * g, struct ggml_tensor * t) { - // check if already visited - for (int i = 0; i < g->n_nodes; i++) { - if (g->nodes[i] == t) { - return t; - } - } - - for (int i = 0; i < g->n_leafs; i++) { - if (g->leafs[i] == t) { - return t; - } - } - - for (int i = 0; i < GGML_MAX_SRC; ++i) { - if (t->src[i]) { - expand(g, t->src[i]); - } - } - - GGML_ASSERT(g->n_nodes < GGML_MAX_NODES); - - if (strlen(t->name) == 0) { - snprintf(t->name, sizeof(t->name), "node_%d", g->n_nodes); - } - - g->nodes[g->n_nodes] = t; - g->grads[g->n_nodes] = t->grad; - g->n_nodes++; - return t; +static bool hash_contains(void * hash_table[], void * p) { + size_t i = hash_find(hash_table, p); + return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p); } -void graph_set_leafs_grads(struct ggml_cgraph * g) { - // moves leaf nodes to g->leafs. - // i.e. g->n_nodes might change. - int n_nodes = 0; - for (int i = 0; i < g->n_nodes; ++i) { - struct ggml_tensor * node = g->nodes[i]; - const bool is_leaf = node->op == GGML_OP_NONE && node->grad == NULL; - if (is_leaf) { - GGML_ASSERT(g->n_leafs < GGML_MAX_NODES); +struct hash_map { + void * keys[GGML_GRAPH_HASHTABLE_SIZE]; + void * vals[GGML_GRAPH_HASHTABLE_SIZE]; +}; +//static const size_t HASH_MAP_SIZE = sizeof(struct hash_map); - if (strlen(node->name) == 0) { - snprintf(node->name, sizeof(node->name), "leaf_%d", g->n_leafs); - } - - g->leafs[g->n_leafs] = node; - g->n_leafs++; - } else { - GGML_ASSERT(n_nodes < GGML_MAX_NODES); - - if (strlen(node->name) == 0) { - snprintf(node->name, sizeof(node->name), "node_%d", n_nodes); - } - - g->nodes[n_nodes] = node; - g->grads[n_nodes] = node->grad; - n_nodes++; - } +struct hash_map * new_hash_map() { + struct hash_map * result = new struct hash_map; + for (int i=0; ikeys[i] = NULL; + result->vals[i] = NULL; } - for (int i=n_nodes; i < g->n_nodes; ++i) { - g->nodes[n_nodes] = NULL; - g->grads[n_nodes] = NULL; - } - g->n_nodes = n_nodes; + return result; +}; + +void free_hash_map(struct hash_map * map) { + delete map; } -struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( - struct my_llama_model * model, - struct ggml_context * ctx0, +static bool ggml_is_view(struct ggml_tensor * t) { + return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE || + t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY; +} + +static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) { + switch (t->op) { + case GGML_OP_PERMUTE: + case GGML_OP_RESHAPE: + case GGML_OP_TRANSPOSE: + case GGML_OP_VIEW: + return t->src[0]; + case GGML_OP_CPY: + return t->src[1]; + default: + return NULL; + } +} + +static struct ggml_tensor * get_view_source(struct ggml_tensor * t) { + struct ggml_tensor * parent = t; + do { + parent = get_view_parent(parent); + } while (ggml_is_view(parent)); + return parent; +} + +struct ggml_tensor * ggml_recompute_graph_node( + struct ggml_context * ctx, + struct ggml_cgraph * graph, + struct hash_map * replacements, + struct ggml_tensor * node) { + + if (node == NULL) { + return NULL; + } + + if (node->is_param) { + return node; + } + + if (!hash_contains(graph->visited_hash_table, node)) { + return node; + } + + int count_children = 0; + for (int k = 0; k < GGML_MAX_SRC; ++k) { + if (node->src[k]) { + ++count_children; + } + } + + if (count_children == 0) { + return node; + } + + size_t i = hash_find(replacements->keys, node); + GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full + if (replacements->keys[i] == node) { + return (struct ggml_tensor *) replacements->vals[i]; + } + + struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne); + + // insert clone into replacements + GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite + replacements->keys[i] = node; + replacements->vals[i] = clone; + + clone->op = node->op; + clone->grad = node->grad; + clone->is_param = node->is_param; + clone->extra = node->extra; + for (int k = 0; k < GGML_MAX_DIMS; ++k) { + clone->nb[k] = node->nb[k]; + } + for (int k = 0; k < GGML_MAX_SRC; ++k) { + clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]); + } + if (ggml_is_view(clone)) { + struct ggml_tensor * source = get_view_source(clone); + GGML_ASSERT(source != NULL); + clone->data = source->data; + } + + GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t))); + GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME); + memcpy(clone->op_params, node->op_params, sizeof(node->op_params)); + ggml_format_name(clone, "%s (clone)", ggml_get_name(node)); + + return clone; +}; + +void ggml_build_backward_gradient_checkpointing( + struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, + struct ggml_cgraph * gb_tmp, + struct ggml_tensor * * checkpoints, + int n_checkpoints) { + *gb_tmp = *gf; + ggml_build_backward_expand(ctx, gf, gb_tmp, true); + + if (n_checkpoints <= 0) { + *gb = *gb_tmp; + return; + } + + struct hash_map * replacements = new_hash_map(); + + // insert checkpoints in replacements + for (int i = 0; i < n_checkpoints; ++i) { + size_t k = hash_find(replacements->keys, checkpoints[i]); + GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full + GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite + replacements->keys[k] = checkpoints[i]; + replacements->vals[k] = checkpoints[i]; + } + + *gb = *gf; + // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes], + // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]), + // by recomputing them from checkpoints + for (int i = gf->n_nodes; in_nodes; ++i) { + struct ggml_tensor * node = gb_tmp->nodes[i]; + for (int k = 0; k < GGML_MAX_SRC; ++k) { + // insert new tensors recomputing src, reusing already made replacements, + // remember replacements: remember new tensors with mapping from corresponding gf nodes + // recurse for input tensors, + // unless (i.e. terminating when) input tensors are checkpoints + node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]); + } + // insert rewritten backward node with replacements made into resulting backward graph gb + ggml_build_forward_expand(gb, node); + } + + free_hash_map(replacements); +} + +struct ggml_tensor * llama_build_train_graphs( + struct my_llama_model * model, + struct ggml_allocr * alloc, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + struct ggml_cgraph * gb_tmp, struct ggml_tensor * * logits, struct ggml_tensor * tokens_input, struct ggml_tensor * targets, - void * compute_buf_0, - void * compute_buf_1, - size_t size_buf_0, - size_t size_buf_1, const int n_tokens, - const int n_batch) { - - ggml_set_scratch(ctx0, { 0, 0, nullptr, }); + const int n_batch, + const bool enable_flash_attn, + const bool enable_checkpointing) { + ggml_set_scratch(ctx, { 0, 0, nullptr, }); const int n_past = 0; const int N = n_tokens; - - gf->n_nodes = 0; - gf->n_leafs = 0; - gf->perf_runs = 0; - gf->perf_cycles = 0; - gf->perf_time_us = 0; - const auto & hparams = model->hparams; const int n_ctx = hparams.n_ctx; const int n_vocab = hparams.n_vocab; @@ -1445,476 +671,162 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( const int n_layer = hparams.n_layer; const int n_head = hparams.n_head; const int n_rot = hparams.n_rot; - const int n_ff = get_n_ff(&hparams); - const int rope_mode = 0; + const int n_ff = hparams.n_ff; + const float f_norm_rms_eps = hparams.f_norm_rms_eps; + const float rope_freq_base = hparams.rope_freq_base; + const float rope_freq_scale = hparams.rope_freq_scale; - int last_buf = -1; - size_t buf_offs[2] = { 0, 0 }; - size_t buf_size[2] = { size_buf_0, - size_buf_1 }; - void * buf_data[2] = { compute_buf_0, - compute_buf_1 }; - auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data] (int buf) { - size_t last_offs = 0; - last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, }); - if (last_buf >= 0) { - buf_offs[last_buf] = last_offs; - } - if (buf >= 0) { - size_t offs = buf_offs[buf]; - size_t size = buf_size[buf]; - void * data = buf_data[buf]; - ggml_set_scratch(ctx0, { offs, size, data, }); - } - last_buf = buf; - }; - - bool track_max_mem = false; - size_t buf_maxs[2] = { 0, 0 }; - - auto clr_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs, track_max_mem] (int buf) { - if (buf < 0) return; - if (track_max_mem) { - size_t last_offs = 0; - last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, }); - if (last_buf >= 0) { - buf_offs[last_buf] = last_offs; - buf_maxs[last_buf] = std::max(buf_maxs[last_buf], buf_offs[last_buf]); - } - } - buf_offs[buf] = 0; - if (track_max_mem && last_buf >= 0) { - size_t offs = buf_offs[last_buf]; - size_t size = buf_size[last_buf]; - void * data = buf_data[last_buf]; - ggml_set_scratch(ctx0, { offs, size, data, }); + auto set_name = [](struct ggml_tensor * t, const char * n) { + ggml_set_name(t, n); + if (t->grad) { + ggml_format_name(t->grad, "%s->grad", n); } }; + // rope has so much parameters that we make a custom function for it + auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale] + (struct ggml_tensor * t) -> struct ggml_tensor * { + // not capturing these, to silcence warnings + const int n_past = 0; + const int rope_mode = 0; - auto view__q = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * { - int64_t ne0 = n_embd/n_head; - int64_t ne1 = N; - int64_t ne2 = n_head; - int64_t ne3 = n_batch; - size_t nb0 = ggml_element_size(t); - size_t nb1 = nb0*ne0; - size_t nb2 = nb1*ne1; - size_t nb3 = nb2*ne2; - size_t offset = 0; - return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset); + return ggml_rope_custom(ctx, + t, n_past, n_rot, rope_mode, n_ctx, + rope_freq_base, rope_freq_scale); }; - auto view__k = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * { - int64_t ne0 = n_embd/n_head; - int64_t ne1 = N; - int64_t ne2 = n_head; - int64_t ne3 = n_batch; - size_t nb0 = ggml_element_size(t); - size_t nb1 = nb0*ne0; - size_t nb2 = nb1*ne1; - size_t nb3 = nb2*ne2; - size_t offset = nb3*ne3; - return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset); - }; + set_name(tokens_input, "tokens_input"); + set_name(targets, "targets"); - auto view__v = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * { - int64_t ne0 = N; - int64_t ne1 = n_embd/n_head; - int64_t ne2 = n_head; - int64_t ne3 = n_batch; - size_t nb0 = ggml_element_size(t); - size_t nb1 = nb0*ne0; - size_t nb2 = nb1*ne1; - size_t nb3 = nb2*ne2; - size_t offset = 2*nb3*ne3; - return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset); - }; - - auto add_or_set = [ctx0] (struct ggml_tensor * a, struct ggml_tensor * b) -> struct ggml_tensor * { - if (a == NULL) { - return b; - } else { - return ggml_add_inplace(ctx0, a, b); - } - }; - - use_buf(-1); - - model->tok_embeddings->grad = NULL; - model->norm->grad = NULL; - model->output->grad = NULL; - - for (int il = 0; il < n_layer; ++il) { - struct my_llama_layer & layer = model->layers[il]; - layer.attention_norm->grad = NULL; - layer.wq->grad = NULL; - layer.wk->grad = NULL; - layer.wv->grad = NULL; - layer.wo->grad = NULL; - layer.ffn_norm->grad = NULL; - layer.w1->grad = NULL; - layer.w2->grad = NULL; - layer.w3->grad = NULL; - } - - clr_buf(0); - clr_buf(1); - - use_buf(-1); - - struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch); - memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch); - - use_buf(-1); - - struct ggml_tensor * t01 = expand(gf, ggml_get_rows(ctx0, model->tok_embeddings, t00)); assert_shape_2d(t01, n_embd, N*n_batch); - - // need to remember these for the backward pass - std::vector t02L; t02L.resize(n_layer, NULL); - std::vector t03L; t03L.resize(n_layer, NULL); - std::vector t04L; t04L.resize(n_layer, NULL); - std::vector t05L; t05L.resize(n_layer, NULL); - std::vector t06L; t06L.resize(n_layer, NULL); - std::vector t07L; t07L.resize(n_layer, NULL); - std::vector t08L; t08L.resize(n_layer, NULL); - std::vector t09L; t09L.resize(n_layer, NULL); - std::vector t10L; t10L.resize(n_layer, NULL); - std::vector t11L; t11L.resize(n_layer, NULL); - std::vector t12L; t12L.resize(n_layer, NULL); - std::vector t13L; t13L.resize(n_layer, NULL); - std::vector t14L; t14L.resize(n_layer, NULL); - std::vector t15L; t15L.resize(n_layer, NULL); - std::vector t16L; t16L.resize(n_layer, NULL); - std::vector t17L; t17L.resize(n_layer, NULL); - std::vector t18L; t18L.resize(n_layer, NULL); - std::vector t19L; t19L.resize(n_layer, NULL); - std::vector t20L; t20L.resize(n_layer, NULL); - std::vector t21L; t21L.resize(n_layer, NULL); - std::vector t22L; t22L.resize(n_layer, NULL); - std::vector t23L; t23L.resize(n_layer, NULL); - std::vector t24L; t24L.resize(n_layer, NULL); - std::vector t25L; t25L.resize(n_layer, NULL); - std::vector t26L; t26L.resize(n_layer, NULL); - std::vector t27L; t27L.resize(n_layer, NULL); - std::vector t28L; t28L.resize(n_layer, NULL); - std::vector t29L; t29L.resize(n_layer, NULL); - std::vector t30L; t30L.resize(n_layer, NULL); + GGML_ASSERT(tokens_input->type == GGML_TYPE_I32); + struct ggml_tensor * t00 = ggml_reshape_1d(ctx, tokens_input, N*n_batch); set_name(t00, "t00"); assert_shape_1d(t00, N*n_batch); + struct ggml_tensor * t01 = ggml_get_rows(ctx, model->tok_embeddings, t00); set_name(t01, "t01"); assert_shape_2d(t01, n_embd, N*n_batch); struct ggml_tensor * cur = t01; + std::vector checkpoints; + checkpoints.push_back(tokens_input); + checkpoints.push_back(targets); + checkpoints.push_back(t00); + checkpoints.push_back(t01); + + struct ggml_tensor * kv_scale; + if (!enable_flash_attn) { + kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head)); + } + for (int il = 0; il < n_layer; ++il) { - clr_buf(0); struct my_llama_layer & layer = model->layers[il]; - // tensors with values necessary for backward pass are in persistent buf(-1) - // other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed. - use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch); - use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch); - use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch); - use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch); - use_buf(-1); struct ggml_tensor * t06 = expand(gf, ggml_reshape_4d (ctx0, t05, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch); - use_buf(-1); struct ggml_tensor * t07 = expand(gf, ggml_rope_inplace (ctx0, t06, n_past, n_rot, rope_mode, 0)); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch); - use_buf(-1); struct ggml_tensor * t08 = expand(gf, ggml_mul_mat (ctx0, layer.wk, t04)); assert_shape_2d(t08, n_embd, N*n_batch); - use_buf(-1); struct ggml_tensor * t09 = expand(gf, ggml_reshape_4d (ctx0, t08, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch); - use_buf(-1); struct ggml_tensor * t10 = expand(gf, ggml_rope_inplace (ctx0, t09, n_past, n_rot, rope_mode, 0)); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch); - use_buf(-1); struct ggml_tensor * t11 = expand(gf, ggml_mul_mat (ctx0, t04, layer.wv)); assert_shape_2d(t11, N*n_batch, n_embd); - use_buf(-1); struct ggml_tensor * t12 = expand(gf, ggml_reshape_4d (ctx0, t11, N, n_batch, n_embd/n_head, n_head)); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head); - use_buf(-1); struct ggml_tensor * t13 = expand(gf, ggml_permute (ctx0, t07, 0, 2, 1, 3)); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch); - use_buf(-1); struct ggml_tensor * t14 = expand(gf, ggml_permute (ctx0, t10, 0, 2, 1, 3)); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch); - use_buf(-1); struct ggml_tensor * t15 = expand(gf, ggml_permute (ctx0, t12, 0, 3, 1, 2)); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch); - use_buf(-1); struct ggml_tensor * t16 = expand(gf, ggml_flash_attn (ctx0, t13, t14, t15, true)); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch); - use_buf( 0); struct ggml_tensor * t17 = expand(gf, ggml_permute (ctx0, t16, 0, 2, 1, 3)); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch); - use_buf(-1); struct ggml_tensor * t18 = expand(gf, ggml_cont (ctx0, t17)); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch); - use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch); - use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch); - use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch); - use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch); - use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch); - use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch); - use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch); - use_buf(-1); struct ggml_tensor * t26 = expand(gf, ggml_mul_mat (ctx0, layer.w1, t24)); assert_shape_2d(t26, n_ff, N*n_batch); - use_buf(-1); struct ggml_tensor * t27 = expand(gf, ggml_silu (ctx0, t26)); assert_shape_2d(t27, n_ff, N*n_batch); - use_buf(-1); struct ggml_tensor * t28 = expand(gf, ggml_mul (ctx0, t27, t25)); assert_shape_2d(t28, n_ff, N*n_batch); - use_buf( 0); struct ggml_tensor * t29 = expand(gf, ggml_mul_mat (ctx0, layer.w2, t28)); assert_shape_2d(t29, n_embd, N*n_batch); - use_buf(-1); struct ggml_tensor * t30 = expand(gf, ggml_add (ctx0, t21, t29)); assert_shape_2d(t30, n_embd, N*n_batch); - t02L[il] = t02; - t03L[il] = t03; - t04L[il] = t04; - t05L[il] = t05; - t06L[il] = t06; - t07L[il] = t07; - t08L[il] = t08; - t09L[il] = t09; - t10L[il] = t10; - t11L[il] = t11; - t12L[il] = t12; - t13L[il] = t13; - t14L[il] = t14; - t15L[il] = t15; - t16L[il] = t16; - t17L[il] = t17; - t18L[il] = t18; - t19L[il] = t19; - t20L[il] = t20; - t21L[il] = t21; - t22L[il] = t22; - t23L[il] = t23; - t24L[il] = t24; - t25L[il] = t25; - t26L[il] = t26; - t27L[il] = t27; - t28L[il] = t28; - t29L[il] = t29; - t30L[il] = t30; - - cur = t30; - } - clr_buf(0); - use_buf(0); - struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t31, n_embd, N*n_batch); - struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch); - struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch); - use_buf(-1); - struct ggml_tensor * t34 = expand(gf, ggml_mul_mat (ctx0, model->output, t33)); assert_shape_2d(t34, n_vocab, N*n_batch); - struct ggml_tensor * t35 = expand(gf, ggml_reshape_3d(ctx0, t34, n_vocab, N, n_batch)); assert_shape_3d(t35, n_vocab, N, n_batch); - struct ggml_tensor * t36 = expand(gf, ggml_cross_entropy_loss(ctx0, t35, targets)); assert_shape_1d(t36, 1); - - { - /* - tok_embeddings | grad_tok_embeddings = ggml_get_rows_back(grad_t01, t00) - L0_att_norm | grad_L0_att_norm = ggml_repeat_back(grad_t03L0, L0_att_norm.shape) - L0_wq | grad_L0_wq = ggml_out_prod(t04L0, grad_t05L0) - L0_wk | grad_L0_wk = ggml_out_prod(t04L0, grad_t08L0) - L0_wv | grad_L0_wv = ggml_out_prod(t04L0, ggml_transpose(grad_t11L0)) - L0_wo | grad_L0_wo = ggml_out_prod(t19L0, grad_t20L0) - L0_ffn_norm | grad_L0_ffn_norm = ggml_repeat_back(grad_t23L0, L0_ffn_norm.shape) - L0_w1 | grad_L0_w1 = ggml_out_prod(t24L0, grad_t26L0) - L0_w2 | grad_L0_w2 = ggml_out_prod(t28L0, grad_t29L0) - L0_w3 | grad_L0_w3 = ggml_out_prod(t24L0, grad_t25L0) - L1_att_norm | grad_L1_att_norm = ggml_repeat_back(grad_t03L1, L1_att_norm.shape) - L1_wq | grad_L1_wq = ggml_out_prod(t04L1, grad_t05L1) - L1_wk | grad_L1_wk = ggml_out_prod(t04L1, grad_t08L1) - L1_wv | grad_L1_wv = ggml_out_prod(t04L1, ggml_transpose(grad_t11L1)) - L1_wo | grad_L1_wo = ggml_out_prod(t19L1, grad_t20L1) - L1_ffn_norm | grad_L1_ffn_norm = ggml_repeat_back(grad_t23L1, L1_ffn_norm.shape) - L1_w1 | grad_L1_w1 = ggml_out_prod(t24L1, grad_t26L1) - L1_w2 | grad_L1_w2 = ggml_out_prod(t28L1, grad_t29L1) - L1_w3 | grad_L1_w3 = ggml_out_prod(t24L1, grad_t25L1) - norm | grad_norm = ggml_repeat_back(grad_t32, norm.shape) - output | grad_output = ggml_out_prod(t33, grad_t34) - | - t01 = ggml_get_rows(tok_embeddings, t00) | grad_t01 = grad_t21L0 + ggml_rms_norm_back(t01, grad_t02L0) - for layer: | - t02L0*= ggml_rms_norm (t01) | grad_t02L0 = ggml_mul(grad_t04L0, t03L0) - t03L0 = ggml_repeat (L0_att_norm, t02L0_shape) | grad_t03L0 = ggml_mul(grad_t04L0, t02L0) - t04L0*= ggml_mul (t02L0, t03L0) | grad_t04L0 = ggml_out_prod(L0_wv, grad_t11L0) + ggml_out_prod(L0_wk, ggml_transpose(grad_t08L0)) + ggml_out_prod(L0_wq, ggml_transpose(grad_t05L0)) - t05L0 = ggml_mul_mat (L0_wq, t04L0) | grad_t05L0 = ggml_reshape(grad_t06L0, t05L0_shape) - t06L0 = ggml_reshape_4d (t05L0, n_embd/n_head, n_head, N, n_batch) | grad_t06L0 = ggml_rope_back(grad_t07L0) - t07L0 = ggml_rope_inplace (t06L0) | grad_t07L0 = ggml_permute_back(grad_t13L0, 0, 2, 1, 3) = ggml_permute(grad_t13L0, 0, 2, 1, 3) - t08L0 = ggml_mul_mat (L0_wk, t04L0) | grad_t08L0 = ggml_reshape(grad_t09L0, t08L0_shape) - t09L0 = ggml_reshape_4d (t08L0, n_embd/n_head, n_head, N, n_batch) | grad_t09L0 = ggml_rope_back(grad_t10L0) - t10L0 = ggml_rope_inplace (t09L0) | grad_t10L0 = ggml_permute_back(grad_t14L0, 0, 2, 1, 3) = ggml_permute(grad_t14L0, 0, 2, 1, 3) - t11L0 = ggml_mul_mat (t04L0, L0_wv) | grad_t11L0 = ggml_reshape(grad_t12L0, t11L0_shape) - t12L0 = ggml_reshape_4d (t11L0, N, n_batch, n_embd/n_head, n_head) | grad_t12L0 = ggml_permute_back(grad_t15L0, 0, 3, 1, 2) = ggml_permute(grad_t15L0, 0, 2, 3, 1) - t13L0*= ggml_permute (t07L0, 0, 2, 1, 3) | grad_t13L0 = view__q(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0)) - t14L0*= ggml_permute (t10L0, 0, 2, 1, 3) | grad_t14L0 = view__k(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0)) - t15L0*= ggml_permute (t12L0, 0, 3, 1, 2) | grad_t15L0 = view__v(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0)) - t16L0 = ggml_flash_attn (t13L0, t14L0, t15L0) | grad_t16L0 = ggml_permute_back(grad_t17L0, 0, 2, 1, 3) = ggml_permute(grad_t17L0, 0, 2, 1, 3) - t17L0 = ggml_permute (t16L0, 0, 2, 1, 3) | grad_t17L0 = grad_t18L0 - t18L0 = ggml_cont (t17L0) | grad_t18L0 = ggml_reshape(grad_t19L0, t18L0_shape) - t19L0*= ggml_reshape_2d (t18L0, n_embd, N*n_batch) | grad_t19L0 = ggml_out_prod(L0_wo, ggml_transpose(grad_t20L0)) - t20L0 = ggml_mul_mat (L0_wo, t19L0) | grad_t20L0 = grad_t21L0 - t21L0*= ggml_add (t20L0, t01) | grad_t21L0 = grad_t30L0 + ggml_rms_norm_back(t21L0, grad_t22L0) - t22L0*= ggml_rms_norm (t21L0) | grad_t22L0 = ggml_mul(grad_t24L0, t23L0) - t23L0 = ggml_repeat (L0_ffn_norm, t22L0_shape) | grad_t23L0 = ggml_mul(grad_t24L0, t22L0) - t24L0*= ggml_mul (t23L0, t22L0) | grad_t24L0 = ggml_out_prod(L0_w1, ggml_transpose(grad_t26L0)) + ggml_out_prod(L0_w3, ggml_transpose(grad_t25L0)) - t25L0*= ggml_mul_mat (L0_w3, t24L0) | grad_t25L0 = ggml_mul(grad_t28L0, t27L0) - t26L0*= ggml_mul_mat (L0_w1, t24L0) | grad_t26L0 = ggml_silu_back(t26L0, grad_t27L0) - t27L0*= ggml_silu (t26L0) | grad_t27L0 = ggml_mul(grad_t28L0, t25L0) - t28L0*= ggml_mul (t27L0, t25L0) | grad_t28L0 = ggml_out_prod(L0_w2, ggml_transpose(grad_t29L0)) - t29L0 = ggml_mul_mat (L0_w2, t28L0) | grad_t29L0 = grad_t30L0 - t30L0*= ggml_add (t21L0, t29L0) | grad_t30L0 = ggml_rms_norm_back(t30L0, grad_t02L1) + grad_t21L1 - ^ - t02L1*= ggml_rms_norm (t30L0) | grad_t02L1 = ggml_mul(grad_t04L1, t03L1) - t03L1 = ggml_repeat (L1_att_norm, t02L1_shape) | grad_t03L1 = ggml_mul(grad_t04L1, t02L1) - t04L1*= ggml_mul (t02L1, t03L1) | grad_t04L1 = ggml_out_prod(L1_wv, grad_t11L1) + ggml_out_prod(L1_wk, ggml_transpose(grad_t08L1)) + ggml_out_prod(L1_wq, ggml_transpose(grad_t05L1)) - t05L1 = ggml_mul_mat (L1_wq, t04L1) | grad_t05L1 = ggml_reshape(grad_t06L1, t05L1_shape) - t06L1 = ggml_reshape_4d (t05L1, n_embd/n_head, n_head, N, n_batch) | grad_t06L1 = ggml_rope_back(grad_t07L1) - t07L1 = ggml_rope_inplace (t06L1) | grad_t07L1 = ggml_permute_back(grad_t13L1, 0, 2, 1, 3) = ggml_permute(grad_t13L1, 0, 2, 1, 3) - t08L1 = ggml_mul_mat (L1_wk, t04L1) | grad_t08L1 = ggml_reshape(grad_t09L1, t08L1_shape) - t09L1 = ggml_reshape_4d (t08L1, n_embd/n_head, n_head, N, n_batch) | grad_t09L1 = ggml_rope_back(grad_t10L1) - t10L1 = ggml_rope_inplace (t09L1) | grad_t10L1 = ggml_permute_back(grad_t14L1, 0, 2, 1, 3) = ggml_permute(grad_t14L1, 0, 2, 1, 3) - t11L1 = ggml_mul_mat (t04L1, L1_wv) | grad_t11L1 = ggml_reshape(grad_t12L1, t11L1_shape) - t12L1 = ggml_reshape_4d (t11L1, N, n_batch, n_embd/n_head, n_head) | grad_t12L1 = ggml_permute_back(grad_t15L1, 0, 3, 1, 2) = ggml_permute(grad_t15L1, 0, 2, 3, 1) - t13L1*= ggml_permute (t07L1, 0, 2, 1, 3) | grad_t13L1 = view__q(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1)) - t14L1*= ggml_permute (t10L1, 0, 2, 1, 3) | grad_t14L1 = view__k(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1)) - t15L1*= ggml_permute (t12L1, 0, 3, 1, 2) | grad_t15L1 = view__v(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1)) - t16L1 = ggml_flash_attn (t13L1, t14L1, t15L1) | grad_t16L1 = ggml_permute_back(grad_t17L1, 0, 2, 1, 3) = ggml_permute(grad_t17L1, 0, 2, 1, 3) - t17L1 = ggml_permute (t16L1, 0, 2, 1, 3) | grad_t17L1 = grad_t18L1 - t18L1 = ggml_cont (t17L1) | grad_t18L1 = ggml_reshape(grad_t19L1, t18L1_shape) - t19L1*= ggml_reshape_2d (t18L1, n_embd, N*n_batch) | grad_t19L1 = ggml_out_prod(L1_wo, ggml_transpose(grad_t20L1)) - t20L1 = ggml_mul_mat (L1_wo, t19L1) | grad_t20L1 = grad_t21L1 - t21L1*= ggml_add (t20L1, t30L0) | grad_t21L1 = grad_t30L1 + ggml_rms_norm_back(t21L1, grad_t22L1) - t22L1*= ggml_rms_norm (t21L1) | grad_t22L1 = ggml_mul(grad_t24L1, t23L1) - t23L1 = ggml_repeat (L1_ffn_norm, t22L1_shape) | grad_t23L1 = ggml_mul(grad_t24L1, t22L1) - t24L1*= ggml_mul (t23L1, t22L1) | grad_t24L1 = ggml_out_prod(L1_w1, ggml_transpose(grad_t26L1)) + ggml_out_prod(L1_w3, ggml_transpose(grad_t25L1)) - t25L1*= ggml_mul_mat (L1_w3, t24L1) | grad_t25L1 = ggml_mul(grad_t28L1, t27L1) - t26L1*= ggml_mul_mat (L1_w1, t24L1) | grad_t26L1 = ggml_silu_back(t26L1, grad_t27L1) - t27L1*= ggml_silu (t26L1) | grad_t27L1 = ggml_mul(grad_t28L1, t25L1) - t28L1*= ggml_mul (t27L1, t25L1) | grad_t28L1 = ggml_out_prod(L1_w2, ggml_transpose(grad_t29L1)) - t29L1 = ggml_mul_mat (L1_w2, t28L1) | grad_t29L1 = grad_t30L1 - t30L1*= ggml_add (t21L1, t29L1) | grad_t30L1 = ggml_rms_norm_back(t30L1, grad_t31) - ^ - t31 = ggml_rms_norm (t30L1) | grad_t31 = ggml_mul(grad_t33, t32) - t32 = ggml_repeat (norm, t31.shape) | grad_t32 = ggml_mul(grad_t33, t31) - t33 = ggml_mul (t32, t31) | grad_t33 = ggml_out_prod(output, ggml_transpose(grad_t34)) - t34 = ggml_mul_mat (output, t33) | grad_t34 = ggml_reshape(grad_t35, t34.shape) - t35 = ggml_reshape_3d (t34, n_vocab, N, n_batch) | grad_t35 = ggml_cross_entropy_loss_back(t35, targets, grad_t36) - t36 = ggml_cross_entropy_loss(t35, targets) | grad_t36 = 1 (optimizer) - tensors marked with * need to be stored until grad computation - tensors during grad computation are all temporary - */ - } - - *gb = *gf; - - // t36->grad gets set to one by optimizer, so we need the tensor. - // initialize it with 1.0f to make sure. - use_buf(-1); - t36->grad = expand(gb, ggml_new_f32(ctx0, 1.0f)); - - use_buf(0); - t35->grad = expand(gb, ggml_cross_entropy_loss_back(ctx0, t35, targets, t36->grad)); assert_shape_3d(t35->grad, n_vocab, N, n_batch); - t34->grad = expand(gb, ggml_reshape_2d (ctx0, t35->grad, n_vocab, N*n_batch)); assert_shape_2d(t34->grad, n_vocab, N*n_batch); - t33->grad = expand(gb, ggml_out_prod (ctx0, model->output, ggml_transpose(ctx0, t34->grad))); assert_shape_2d(t33->grad, n_embd, N*n_batch); - t32->grad = expand(gb, ggml_mul (ctx0, t33->grad, t31)); assert_shape_2d(t32->grad, n_embd, N*n_batch); - - use_buf(-1); - - model->norm->grad = expand(gb, add_or_set(model->norm->grad, ggml_repeat_back(ctx0, t32->grad, model->norm))); assert_shape_1d(model->norm->grad, n_embd); - model->output->grad = expand(gb, add_or_set(model->output->grad, ggml_out_prod(ctx0, t33, t34->grad))); assert_shape_2d(model->output->grad, n_embd, n_vocab); - - clr_buf(1); - use_buf(1); - t31->grad = expand(gb, ggml_mul(ctx0, t33->grad, t32)); assert_shape_2d(t31->grad, n_embd, N*n_batch); - - struct ggml_tensor * back_layer_inp = t31; - struct ggml_tensor * grad_layer_inp = NULL; - - for (int k = 0; k < n_layer; ++k) { - int il = n_layer-1-k; - struct my_llama_layer & layer = model->layers[il]; - - struct ggml_tensor * t02 = t02L[il]; - struct ggml_tensor * t03 = t03L[il]; - struct ggml_tensor * t04 = t04L[il]; - struct ggml_tensor * t05 = t05L[il]; - struct ggml_tensor * t06 = t06L[il]; - struct ggml_tensor * t07 = t07L[il]; - struct ggml_tensor * t08 = t08L[il]; - struct ggml_tensor * t09 = t09L[il]; - struct ggml_tensor * t10 = t10L[il]; - struct ggml_tensor * t11 = t11L[il]; - struct ggml_tensor * t12 = t12L[il]; - struct ggml_tensor * t13 = t13L[il]; - struct ggml_tensor * t14 = t14L[il]; - struct ggml_tensor * t15 = t15L[il]; - struct ggml_tensor * t16 = t16L[il]; - struct ggml_tensor * t17 = t17L[il]; - struct ggml_tensor * t18 = t18L[il]; - struct ggml_tensor * t19 = t19L[il]; - struct ggml_tensor * t20 = t20L[il]; - struct ggml_tensor * t21 = t21L[il]; - struct ggml_tensor * t22 = t22L[il]; - struct ggml_tensor * t23 = t23L[il]; - struct ggml_tensor * t24 = t24L[il]; - struct ggml_tensor * t25 = t25L[il]; - struct ggml_tensor * t26 = t26L[il]; - struct ggml_tensor * t27 = t27L[il]; - struct ggml_tensor * t28 = t28L[il]; - struct ggml_tensor * t29 = t29L[il]; - struct ggml_tensor * t30 = t30L[il]; - - clr_buf(0); - use_buf(0); - t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch); - if (grad_layer_inp) { - t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch); + struct ggml_tensor * t02 = ggml_rms_norm (ctx, cur, f_norm_rms_eps); set_name(t02, "t02"); assert_shape_2d(t02, n_embd, N*n_batch); + struct ggml_tensor * t03 = ggml_repeat (ctx, layer.attention_norm, t02); set_name(t03, "t03"); assert_shape_2d(t03, n_embd, N*n_batch); + struct ggml_tensor * t04 = ggml_mul (ctx, t03, t02); set_name(t04, "t04"); assert_shape_2d(t04, n_embd, N*n_batch); + struct ggml_tensor * t05 = ggml_mul_mat (ctx, layer.wq, t04); set_name(t05, "t05"); assert_shape_2d(t05, n_embd, N*n_batch); + struct ggml_tensor * t06 = ggml_reshape_4d (ctx, t05, n_embd/n_head, n_head, N, n_batch); set_name(t06, "t06"); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch); + struct ggml_tensor * t07 = rope (t06); set_name(t07, "t07"); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch); + struct ggml_tensor * t08 = ggml_mul_mat (ctx, layer.wk, t04); set_name(t08, "t08"); assert_shape_2d(t08, n_embd, N*n_batch); + struct ggml_tensor * t09 = ggml_reshape_4d (ctx, t08, n_embd/n_head, n_head, N, n_batch); set_name(t09, "t09"); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch); + struct ggml_tensor * t10 = rope (t09); set_name(t10, "t10"); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch); + struct ggml_tensor * t11 = ggml_mul_mat (ctx, t04, layer.wv); set_name(t11, "t11"); assert_shape_2d(t11, N*n_batch, n_embd); + struct ggml_tensor * t12 = ggml_reshape_4d (ctx, t11, N, n_batch, n_embd/n_head, n_head); set_name(t12, "t12"); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head); + struct ggml_tensor * t13 = ggml_permute (ctx, t07, 0, 2, 1, 3); set_name(t13, "t13"); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch); + struct ggml_tensor * t14 = ggml_permute (ctx, t10, 0, 2, 1, 3); set_name(t14, "t14"); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch); + struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch); + struct ggml_tensor * t16; + if (enable_flash_attn) { + t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch); + } else { + struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch); + struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch); + struct ggml_tensor * t16_2 = ggml_diag_mask_inf_inplace(ctx, t16_1, n_past); set_name(t16_2, "t16_2"); assert_shape_4d(t16_2, N, N, n_head, n_batch); + struct ggml_tensor * t16_3 = ggml_soft_max_inplace (ctx, t16_2); set_name(t16_3, "t16_3"); assert_shape_4d(t16_3, N, N, n_head, n_batch); + t16 = ggml_mul_mat(ctx, t15, t16_3); set_name(t16, "t16"); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch); } - clr_buf(1); - t29->grad = t30->grad; assert_shape_2d(t29->grad, n_embd, N*n_batch); - t28->grad = expand(gb, ggml_out_prod(ctx0, layer.w2, ggml_transpose(ctx0, t29->grad))); assert_shape_2d(t28->grad, n_ff, N*n_batch); - t27->grad = expand(gb, ggml_mul(ctx0, t28->grad, t25)); assert_shape_2d(t27->grad, n_ff, N*n_batch); - t26->grad = expand(gb, ggml_silu_back(ctx0, t26, t27->grad)); assert_shape_2d(t26->grad, n_ff, N*n_batch); - t25->grad = expand(gb, ggml_mul(ctx0, t28->grad, t27)); assert_shape_2d(t25->grad, n_ff, N*n_batch); - t24->grad = expand(gb, ggml_add_inplace(ctx0, - ggml_out_prod(ctx0, layer.w1, ggml_transpose(ctx0, t26->grad)), - ggml_out_prod(ctx0, layer.w3, ggml_transpose(ctx0, t25->grad)))); assert_shape_2d(t24->grad, n_embd, N*n_batch); - t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch); - t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch); - use_buf(1); - t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch); - grad_layer_inp = t21; - use_buf(0); - t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch); - t19->grad = expand(gb, ggml_out_prod(ctx0, layer.wo, ggml_transpose(ctx0, t20->grad))); assert_shape_2d(t19->grad, n_embd, N*n_batch); - t18->grad = expand(gb, ggml_reshape_4d(ctx0, t19->grad, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t18->grad, n_embd/n_head, n_head, N, n_batch); - t17->grad = t18->grad; assert_shape_4d(t17->grad, n_embd/n_head, n_head, N, n_batch); - t16->grad = expand(gb, ggml_permute(ctx0, t17->grad, 0, 2, 1, 3)); assert_shape_4d(t16->grad, n_embd/n_head, N, n_head, n_batch); - struct ggml_tensor * flash_attn = expand(gb, ggml_flash_attn_back(ctx0, t13, t14, t15, t16->grad, true)); assert_shape_4d(flash_attn, n_embd/n_head, N*3, n_head, n_batch); - t15->grad = expand(gb, view__v(flash_attn)); assert_shape_4d(t15->grad, N, n_embd/n_head, n_head, n_batch); - t14->grad = expand(gb, view__k(flash_attn)); assert_shape_4d(t14->grad, n_embd/n_head, N, n_head, n_batch); - t13->grad = expand(gb, view__q(flash_attn)); assert_shape_4d(t13->grad, n_embd/n_head, N, n_head, n_batch); - t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head); - t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd); - t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch); - t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch); - t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch); - t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch); - t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch); - t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch); - t04->grad = expand(gb, ggml_add_inplace(ctx0, - ggml_add_inplace(ctx0, - ggml_out_prod(ctx0, layer.wv, t11->grad), - ggml_out_prod(ctx0, layer.wk, ggml_transpose(ctx0, t08->grad))), - ggml_out_prod(ctx0, layer.wq, ggml_transpose(ctx0, t05->grad)))); assert_shape_2d(t04->grad, n_embd, N*n_batch); - t03->grad = expand(gb, ggml_mul(ctx0, t04->grad, t02)); assert_shape_2d(t04->grad, n_embd, N*n_batch); - use_buf(1); - t02->grad = expand(gb, ggml_mul(ctx0, t04->grad, ggml_repeat(ctx0, layer.attention_norm, t02))); assert_shape_2d(t02->grad, n_embd, N*n_batch); - back_layer_inp = t02; - // use_buf(0); - - use_buf(-1); - layer.attention_norm->grad = expand(gb, add_or_set(layer.attention_norm->grad, ggml_repeat_back(ctx0, t03->grad, layer.attention_norm))); assert_shape_1d(layer.attention_norm->grad, n_embd); - layer.wq->grad = expand(gb, add_or_set(layer.wq->grad, ggml_out_prod(ctx0, t04, t05->grad))); assert_shape_2d(layer.wq->grad, n_embd, n_embd); - layer.wk->grad = expand(gb, add_or_set(layer.wk->grad, ggml_out_prod(ctx0, t04, t08->grad))); assert_shape_2d(layer.wk->grad, n_embd, n_embd); - layer.wv->grad = expand(gb, add_or_set(layer.wv->grad, ggml_out_prod(ctx0, t04, ggml_transpose(ctx0, t11->grad)))); assert_shape_2d(layer.wv->grad, n_embd, n_embd); - layer.wo->grad = expand(gb, add_or_set(layer.wo->grad, ggml_out_prod(ctx0, t19, t20->grad))); assert_shape_2d(layer.wo->grad, n_embd, n_embd); - layer.ffn_norm->grad = expand(gb, add_or_set(layer.ffn_norm->grad, ggml_repeat_back(ctx0, t23->grad, layer.ffn_norm))); assert_shape_1d(layer.ffn_norm->grad, n_embd); - layer.w1->grad = expand(gb, add_or_set(layer.w1->grad, ggml_out_prod(ctx0, t24, t26->grad))); assert_shape_2d(layer.w1->grad, n_embd, n_ff); - layer.w2->grad = expand(gb, add_or_set(layer.w2->grad, ggml_out_prod(ctx0, t28, t29->grad))); assert_shape_2d(layer.w2->grad, n_ff, n_embd); - layer.w3->grad = expand(gb, add_or_set(layer.w3->grad, ggml_out_prod(ctx0, t24, t25->grad))); assert_shape_2d(layer.w3->grad, n_embd, n_ff); - // use_buf(0); + struct ggml_tensor * t17 = ggml_permute (ctx, t16, 0, 2, 1, 3); set_name(t17, "t17"); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch); + struct ggml_tensor * t18 = ggml_cont (ctx, t17); set_name(t18, "t18"); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch); + struct ggml_tensor * t19 = ggml_reshape_2d (ctx, t18, n_embd, N*n_batch); set_name(t19, "t19"); assert_shape_2d(t19, n_embd, N*n_batch); + struct ggml_tensor * t20 = ggml_mul_mat (ctx, layer.wo, t19); set_name(t20, "t20"); assert_shape_2d(t20, n_embd, N*n_batch); + struct ggml_tensor * t21 = ggml_add (ctx, t20, cur); set_name(t21, "t21"); assert_shape_2d(t21, n_embd, N*n_batch); + struct ggml_tensor * t22 = ggml_rms_norm (ctx, t21, f_norm_rms_eps); set_name(t22, "t22"); assert_shape_2d(t22, n_embd, N*n_batch); + struct ggml_tensor * t23 = ggml_repeat (ctx, layer.ffn_norm, t22); set_name(t23, "t23"); assert_shape_2d(t23, n_embd, N*n_batch); + struct ggml_tensor * t24 = ggml_mul (ctx, t23, t22); set_name(t24, "t24"); assert_shape_2d(t24, n_embd, N*n_batch); + struct ggml_tensor * t25 = ggml_mul_mat (ctx, layer.w3, t24); set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch); + struct ggml_tensor * t26 = ggml_mul_mat (ctx, layer.w1, t24); set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch); + struct ggml_tensor * t27 = ggml_silu (ctx, t26); set_name(t27, "t27"); assert_shape_2d(t27, n_ff, N*n_batch); + struct ggml_tensor * t28 = ggml_mul (ctx, t27, t25); set_name(t28, "t28"); assert_shape_2d(t28, n_ff, N*n_batch); + struct ggml_tensor * t29 = ggml_mul_mat (ctx, layer.w2, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch); + struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); set_name(t30, "t30"); assert_shape_2d(t30, n_embd, N*n_batch); + cur = t30; + checkpoints.push_back(cur); + } + struct ggml_tensor * t31 = ggml_rms_norm (ctx, cur, f_norm_rms_eps); set_name(t31, "t31"); assert_shape_2d(t31, n_embd, N*n_batch); + struct ggml_tensor * t32 = ggml_repeat (ctx, model->norm, t31); set_name(t32, "t32"); assert_shape_2d(t32, n_embd, N*n_batch); + struct ggml_tensor * t33 = ggml_mul (ctx, t32, t31); set_name(t33, "t33"); assert_shape_2d(t33, n_embd, N*n_batch); + struct ggml_tensor * t34 = ggml_mul_mat (ctx, model->output, t33); set_name(t34, "t34"); assert_shape_2d(t34, n_vocab, N*n_batch); + struct ggml_tensor * t35 = ggml_reshape_3d (ctx, t34, n_vocab, N, n_batch); set_name(t35, "t35"); assert_shape_3d(t35, n_vocab, N, n_batch); + struct ggml_tensor * t36 = ggml_cross_entropy_loss(ctx, t35, targets); set_name(t36, "t36"); assert_shape_1d(t36, 1); + + checkpoints.push_back(t31); + checkpoints.push_back(t32); + checkpoints.push_back(t33); + checkpoints.push_back(t34); + checkpoints.push_back(t35); + checkpoints.push_back(t36); + + ggml_build_forward_expand(gf, t36); + + if (enable_checkpointing) { + ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size()); + } else { + *gb = *gf; + ggml_build_backward_expand(ctx, gf, gb, true); + } + + if (alloc) { + // make sure some tensors are not reallocated by inserting new temporary nodes depending on them + int n_leafs_before = gb->n_leafs; + int n_nodes_before = gb->n_nodes; + struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f); + // output tensors + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one)); + // input gradient + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one)); + GGML_ASSERT(t36->grad->data == NULL && !ggml_is_view(t36->grad)); + ggml_allocr_alloc(alloc, t36->grad); + // gradient tensors (will be set to zero by ggml_graph_reset) + // pinning these produces large unnecessary memory overhead, which will be resolved by PR 2632 + for (int i = 0; i < gf->n_nodes; ++i) { + if (!gf->grads[i]) continue; + if (gf->grads[i]->data == NULL && !ggml_is_view(gf->grads[i])) { + ggml_allocr_alloc(alloc, gf->grads[i]); + } + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, gf->grads[i], one)); + } + // allocating checkpoints in one block to reduce memory fragmentation + // note: they will be freed in reverse order + for (int i = 0; i < (int) checkpoints.size(); ++i) { + if (checkpoints[i]->data == NULL && !ggml_is_view(checkpoints[i])) { + ggml_allocr_alloc(alloc, checkpoints[i]); + } + } + + //int n_leafs_after = gb->n_leafs; + //int n_nodes_after = gb->n_nodes; + + ggml_allocr_alloc_graph(alloc, gb); + + // remove the additional nodes and leafs + for (int i = n_leafs_before; i < gb->n_leafs; ++i) { + gb->leafs[i] = NULL; + } + for (int i = n_nodes_before; i < gb->n_nodes; ++i) { + gb->nodes[i] = NULL; + } + gb->n_leafs = n_leafs_before; + gb->n_nodes = n_nodes_before; } - clr_buf(0); - use_buf(0); - t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch); - use_buf(-1); - model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab); - // clr_buf(1); - // clr_buf(0); *logits = t35; - - if (track_max_mem) { - printf("%s: max size compute buf0: %zu\n", __func__, buf_maxs[0]); - printf("%s: max size compute buf1: %zu\n", __func__, buf_maxs[1]); - } - - // now that all grads are created, set the graph leafs and grads - graph_set_leafs_grads(gf); - graph_set_leafs_grads(gb); - return t36; } @@ -1962,42 +874,6 @@ void print_matrix(struct ggml_tensor * probs) { } } - -void print_token(struct llama_context * ctx, llama_token token) { - printf("%s", llama_token_to_piece(ctx, token).c_str()); -} - -void print_tokens(struct llama_context* ctx, struct ggml_tensor * tokens) { - for (int i=0; ine[0]; ++i) { - int token = ggml_get_i32_1d(tokens, i); - print_token(ctx, token); - } -} - -void print_tokens_batch(struct llama_context* ctx, struct ggml_tensor * tokens) { - for (int i1=0; i1ne[1]; ++i1) { - //int num_newline = 0; - for (int i0=0; i0ne[0]; ++i0) { - int token = get_i32_2d(tokens, i0, i1); - print_token(ctx, token); - // bool isnl = (token == llama_token_nl()); - // if (isnl) { - // ++num_newline; - // } - // if (isnl) { - // if (num_newline < 2) { - // print_token(ctx, token); - // } else { - // printf("\\n"); - // } - // } else { - // print_token(ctx, token); - // } - } - printf("\n--\n"); - } -} - void get_example_targets(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) { int n_tokens = tokens_input->ne[0]; int n_vocab = target_logits->ne[0]; @@ -2033,51 +909,27 @@ void get_example_targets_batch(struct llama_context * lctx, const int * train_sa ggml_set_f32(target_logits, -1.0f/n_vocab); ggml_set_f32(target_probs, 0.0f); + // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples); for (int k=0; kne[0]; - int n_vocab = target_logits->ne[0]; - for (int i=0; i chars(len); - read_raw(chars.data(), len); - return std::string(chars.data(), len); - } - - void write_raw(const void * ptr, size_t size) { - if (size == 0) { - return; - } - errno = 0; - size_t ret = std::fwrite(ptr, size, 1, fp); - if (ret != 1) { - throw std::runtime_error(format("write error: %s", strerror(errno))); - } - } - - void write_u32(std::uint32_t val) { - write_raw(&val, sizeof(val)); - } - - ~llama_file() { - if (fp) { - std::fclose(fp); - } - } -}; - int tokenize_file(struct llama_context * lctx, const char * filename, std::vector& out) { - struct llama_file f(filename, "rb"); + FILE * fp = std::fopen(filename, "rb"); + if (fp == NULL) { + return 0; + } + +#ifdef _WIN32 + GGML_ASSERT(_fseeki64(fp, (__int64) 0, SEEK_END) == 0); +#else + GGML_ASSERT(std::fseek(fp, (long) 0, SEEK_END) == 0); +#endif + + size_t size = 0; +#ifdef _WIN32 + __int64 ret = _ftelli64(fp); + size = ret; +#else + long ret = std::ftell(fp); + size = ret; +#endif + +#ifdef _WIN32 + GGML_ASSERT(_fseeki64(fp, (__int64) 0, SEEK_SET) == 0); +#else + GGML_ASSERT(std::fseek(fp, (long) 0, SEEK_SET) == 0); +#endif std::vector buf; - buf.resize(f.size+1); + buf.resize(size+1); + out.resize(size+1); - f.read_raw(buf.data(), f.size); - buf[f.size] = '\0'; + if (std::fread(buf.data(), size, 1, fp) != 1) { + throw std::runtime_error(std::string("unexpectedly reached end of file")); + } + if (ferror(fp)) { + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + + buf[size] = '\0'; int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false); if (n_tokens < 0) { out.resize(-n_tokens); - llama_tokenize(lctx, buf.data(), out.data(), out.size(), false); + n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false); } + GGML_ASSERT(n_tokens >= 0); + out.resize(n_tokens); bool verify = false; if (verify) { @@ -2238,438 +1040,466 @@ void shuffle_ints(int * begin, int * end) { }); } -struct my_llama_sampler_params { - float temp = 0.0f; // <= 0.0 disabled - int top_k = 20; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typical_p = 1.00f; // 1.0 = disabled - int repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float repeat_penalty = 1.0f; // 1.0 = disabled - float alpha_presence = 0.0f; // 0.0 = disabled - float alpha_frequency = 0.0f; // 0.0 = disabled - int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = true; // consider newlines as a repeatable token -}; - -struct my_llama_sampler { - struct llama_context * ctx = NULL; - my_llama_sampler_params params; - - int n_vocab = 0; - int n_ctx = 0; - - float mirostat_mu; - - std::vector candidates; - llama_token_data_array candidates_p; - -}; - -void init_sampler(struct my_llama_sampler * sampler, struct llama_context * ctx) { - sampler->ctx = ctx; - sampler->n_vocab = llama_n_vocab(sampler->ctx); - sampler->n_ctx = llama_n_ctx(sampler->ctx); - sampler->mirostat_mu = 2.0f * sampler->params.mirostat_tau; +#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \ +{ \ + const std::string skey(key); \ + const int kid = gguf_find_key(ctx, skey.c_str()); \ + if (kid >= 0) { \ + enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \ + if (ktype != (type)) { \ + throw std::runtime_error(format("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype))); \ + } \ + (dst) = func(ctx, kid); \ + } else if (req) { \ + throw std::runtime_error(format("key not found in model: %s", skey.c_str())); \ + } \ } -llama_token sample(struct my_llama_sampler * sampler, float * logits, const llama_token * last_tokens, int n_last_tokens) { - GGML_ASSERT(sampler->ctx != NULL); - struct llama_context * ctx = sampler->ctx; +bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) { + GGML_ASSERT(a != NULL); + GGML_ASSERT(b != NULL); + GGML_ASSERT(a->type == b->type); + GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b)); - sampler->candidates.resize(sampler->n_vocab); - for (llama_token token_id = 0; token_id < sampler->n_vocab; ++token_id) { - sampler->candidates[token_id].id = token_id; - sampler->candidates[token_id].logit = logits[token_id]; - sampler->candidates[token_id].p = 0.0; + return true; +} + +void read_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) { + if (dst == NULL) { + return; } + struct ggml_tensor * t = ggml_get_tensor(ctx, name); + GGML_ASSERT(are_same_layout(dst, t)); + memcpy(dst->data, t->data, ggml_nbytes(t)); - llama_token_data_array * candidates_p = & sampler->candidates_p; - - candidates_p->data = sampler->candidates.data(); - candidates_p->size = sampler->candidates.size(); - candidates_p->sorted = false; - - const auto params = sampler->params; - - // Apply penalties - const float nl_logit = logits[llama_token_nl(ctx)]; - - const int n_last = std::min(std::min(n_last_tokens, params.repeat_last_n), sampler->n_ctx); - - llama_sample_repetition_penalty( - ctx, - candidates_p, - last_tokens + n_last_tokens - n_last, - n_last, - params.repeat_penalty); - llama_sample_frequency_and_presence_penalties( - ctx, - candidates_p, - last_tokens + n_last_tokens - n_last, - n_last, - params.alpha_frequency, - params.alpha_presence); - - if (!params.penalize_nl) { - logits[llama_token_nl(ctx)] = nl_logit; + if (strlen(ggml_get_name(dst)) == 0) { + ggml_set_name(dst, name); } +} - llama_token token = 0; - if (params.temp <= 0) { - // Greedy sampling - token = llama_sample_token_greedy(ctx, candidates_p); +void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) { + // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read + + uint32_t file_version; + GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION); + GGML_ASSERT(file_version == 0); + + GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT); + GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT); + GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED); + + uint64_t nx; + GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT); + opt->nx = (size_t) nx; + + // don't call ggml_opt_init until optimizer type and optimizer specific parameters are know + + std::string opt_type; + GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE); + if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) { + opt->params.type = GGML_OPT_ADAM; + + GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS); + GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS); + GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT); + + GGML_ASSERT(opt->ctx != NULL); + ggml_opt_init(opt->ctx, opt, opt->params, opt->nx); + + read_tensor_by_name(opt->adam.m, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS); + read_tensor_by_name(opt->adam.v, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS); + read_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES); + } else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) { + opt->params.type = GGML_OPT_LBFGS; + + GGUF_GET_KEY(fctx, opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT); + GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS); + GGUF_GET_KEY(fctx, opt->lbfgs.step, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP); + GGUF_GET_KEY(fctx, opt->lbfgs.j, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J); + GGUF_GET_KEY(fctx, opt->lbfgs.k, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K); + GGUF_GET_KEY(fctx, opt->lbfgs.end, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END); + GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT); + + GGML_ASSERT(opt->ctx != NULL); + ggml_opt_init(opt->ctx, opt, opt->params, opt->nx); + + read_tensor_by_name(opt->lbfgs.x, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS); + read_tensor_by_name(opt->lbfgs.xp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS); + read_tensor_by_name(opt->lbfgs.g, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS); + read_tensor_by_name(opt->lbfgs.gp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS); + read_tensor_by_name(opt->lbfgs.d, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION); + read_tensor_by_name(opt->lbfgs.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES); + read_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA); + read_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS); + read_tensor_by_name(opt->lbfgs.lms, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S); + read_tensor_by_name(opt->lbfgs.lmy, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y); } else { - if (params.mirostat == 1) { - int mirostat_m = 100; - llama_sample_temperature(ctx, candidates_p, params.temp); - token = llama_sample_token_mirostat(ctx, candidates_p, params.mirostat_tau, params.mirostat_eta, mirostat_m, &sampler->mirostat_mu); - } else if (params.mirostat == 2) { - llama_sample_temperature(ctx, candidates_p, params.temp); - token = llama_sample_token_mirostat_v2(ctx, candidates_p, params.mirostat_tau, params.mirostat_eta, &sampler->mirostat_mu); - } else { - // Temperature sampling - llama_sample_top_k (ctx, candidates_p, params.top_k, 1); - llama_sample_tail_free (ctx, candidates_p, params.tfs_z, 1); - llama_sample_typical (ctx, candidates_p, params.typical_p, 1); - - llama_sample_top_p (ctx, candidates_p, params.top_p, 1); - llama_sample_temperature (ctx, candidates_p, params.temp); - token = llama_sample_token(ctx, candidates_p); - } - } - return token; -} - -void set_logits_masked(struct ggml_tensor * logits, std::vector& mask, float value) { - GGML_ASSERT(logits->ne[0] == (int64_t) mask.size()); - for (int i2 = 0; i2 < logits->ne[2]; ++i2) { - for (int i1 = 0; i1 < logits->ne[1]; ++i1) { - for (int i0 = 0; i0 < logits->ne[0]; ++i0) { - if (!mask[i0]) continue; - float * ptr = (float *) ((char *) logits->data + i2*logits->nb[2] + i1*logits->nb[1] + i0*logits->nb[0]); - *ptr = value; - } - } + throw std::runtime_error("unknown optimizer type\n"); } } -void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) { - if (tensor == NULL) { - file->write_u32(0); - file->write_u32(0); - file->write_u32(GGML_TYPE_F32); - file->seek((0-file->tell()) & 31, SEEK_CUR); - return; - } - const char * name = ggml_get_name(tensor); - uint32_t name_len = strlen(name); - uint32_t nd = tensor->n_dims; - uint32_t ne[4] = { (uint32_t)tensor->ne[0], - (uint32_t)tensor->ne[1], - (uint32_t)tensor->ne[2], - (uint32_t)tensor->ne[3] }; - file->write_u32(nd); - file->write_u32(name_len); - file->write_u32(tensor->type); - file->write_raw(ne, sizeof(ne[0]) * nd); - file->write_raw(name, name_len); - file->seek((0-file->tell()) & 31, SEEK_CUR); - file->write_raw(tensor->data, ggml_nbytes(tensor)); -} - -void read_tensor(struct llama_file * file, struct ggml_tensor * tensor) { - int32_t nd = file->read_u32(); - GGML_ASSERT(nd == tensor->n_dims); - - uint32_t name_len = file->read_u32(); - enum ggml_type type = (enum ggml_type) file->read_u32(); - GGML_ASSERT(type == tensor->type); - - uint32_t ne[4]; - file->read_raw(ne, sizeof(ne[0]) * nd); - for (int i=0; ine[i]); - } - - std::string name = file->read_string(name_len); - GGML_ASSERT(strncmp(ggml_get_name(tensor), name.c_str(), sizeof(tensor->name)-1) == 0); - - file->seek((0-file->tell()) & 31, SEEK_CUR); - file->read_raw(tensor->data, ggml_nbytes(tensor)); -} - -void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt) { - const uint32_t version = 0; - GGML_ASSERT(opt->nx >= 0); - GGML_ASSERT(opt->iter >= 0); - file->write_u32(version); - file->write_raw(&opt->params, sizeof(opt->params)); - file->write_raw(&opt->nx, sizeof(opt->nx)); - file->write_raw(&opt->iter, sizeof(opt->iter)); - file->write_u32((uint32_t) opt->just_initialized); - switch (opt->params.type) { - case GGML_OPT_ADAM: - { - GGML_ASSERT(opt->adam.x != NULL); - write_tensor(file, opt->adam.x); - write_tensor(file, opt->adam.g1); - write_tensor(file, opt->adam.g2); - write_tensor(file, opt->adam.m); - write_tensor(file, opt->adam.v); - write_tensor(file, opt->adam.mh); - write_tensor(file, opt->adam.vh); - write_tensor(file, opt->adam.pf); - file->write_raw(&opt->adam.fx_best, sizeof(opt->adam.fx_best)); - file->write_raw(&opt->adam.fx_prev, sizeof(opt->adam.fx_prev)); - file->write_raw(&opt->adam.n_no_improvement, sizeof(opt->adam.n_no_improvement)); - } break; - case GGML_OPT_LBFGS: - { - GGML_ASSERT(opt->adam.x != NULL); - write_tensor(file, opt->lbfgs.x); - write_tensor(file, opt->lbfgs.xp); - write_tensor(file, opt->lbfgs.g); - write_tensor(file, opt->lbfgs.gp); - write_tensor(file, opt->lbfgs.d); - write_tensor(file, opt->lbfgs.pf); - write_tensor(file, opt->lbfgs.lmal); - write_tensor(file, opt->lbfgs.lmys); - write_tensor(file, opt->lbfgs.lms); - write_tensor(file, opt->lbfgs.lmy); - file->write_raw(&opt->lbfgs.fx_best, sizeof(opt->lbfgs.fx_best)); - file->write_raw(&opt->lbfgs.step, sizeof(opt->lbfgs.step)); - file->write_raw(&opt->lbfgs.j, sizeof(opt->lbfgs.j)); - file->write_raw(&opt->lbfgs.k, sizeof(opt->lbfgs.k)); - file->write_raw(&opt->lbfgs.end, sizeof(opt->lbfgs.end)); - file->write_raw(&opt->lbfgs.n_no_improvement, sizeof(opt->lbfgs.n_no_improvement)); - } break; - } -} - -void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) { - uint32_t version = file->read_u32(); - GGML_ASSERT(version == 0); - - file->read_raw(&opt->params, sizeof(opt->params)); - file->read_raw(&opt->nx, sizeof(opt->nx)); - ggml_opt_init(ctx, opt, opt->params, opt->nx); - - file->read_raw(&opt->iter, sizeof(opt->iter)); - opt->just_initialized = (bool) file->read_u32(); +void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) { + gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0); + gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past); + gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx); + gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter); + gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized); switch (opt->params.type) { case GGML_OPT_ADAM: { - read_tensor(file, opt->adam.x); - read_tensor(file, opt->adam.g1); - read_tensor(file, opt->adam.g2); - read_tensor(file, opt->adam.m); - read_tensor(file, opt->adam.v); - read_tensor(file, opt->adam.mh); - read_tensor(file, opt->adam.vh); - if (opt->adam.pf) { read_tensor(file, opt->adam.pf); } - file->read_raw(&opt->adam.fx_best, sizeof(opt->adam.fx_best)); - file->read_raw(&opt->adam.fx_prev, sizeof(opt->adam.fx_prev)); - file->read_raw(&opt->adam.n_no_improvement, sizeof(opt->adam.n_no_improvement)); + gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM); + gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best); + gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS, opt->adam.fx_prev); + gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement); + + ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS); + ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS); + if (opt->adam.pf) { + ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES); + } + + gguf_add_tensor(fctx, opt->adam.m); + gguf_add_tensor(fctx, opt->adam.v); + if (opt->adam.pf) { + gguf_add_tensor(fctx, opt->adam.pf); + } } break; case GGML_OPT_LBFGS: { - GGML_ASSERT(opt->adam.x != NULL); - read_tensor(file, opt->lbfgs.x); - read_tensor(file, opt->lbfgs.xp); - read_tensor(file, opt->lbfgs.g); - read_tensor(file, opt->lbfgs.gp); - read_tensor(file, opt->lbfgs.d); - if (opt->lbfgs.pf) { read_tensor(file, opt->lbfgs.pf); } - read_tensor(file, opt->lbfgs.lmal); - read_tensor(file, opt->lbfgs.lmys); - read_tensor(file, opt->lbfgs.lms); - read_tensor(file, opt->lbfgs.lmy); - file->read_raw(&opt->lbfgs.fx_best, sizeof(opt->lbfgs.fx_best)); - file->read_raw(&opt->lbfgs.step, sizeof(opt->lbfgs.step)); - file->read_raw(&opt->lbfgs.j, sizeof(opt->lbfgs.j)); - file->read_raw(&opt->lbfgs.k, sizeof(opt->lbfgs.k)); - file->read_raw(&opt->lbfgs.end, sizeof(opt->lbfgs.end)); - file->read_raw(&opt->lbfgs.n_no_improvement, sizeof(opt->lbfgs.n_no_improvement)); + gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS); + gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m); + gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, opt->lbfgs.fx_best); + gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, opt->lbfgs.step); + gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, opt->lbfgs.j); + gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, opt->lbfgs.k); + gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, opt->lbfgs.end); + gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement); + + ggml_set_name(opt->lbfgs.x, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS); + ggml_set_name(opt->lbfgs.xp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS); + ggml_set_name(opt->lbfgs.g, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS); + ggml_set_name(opt->lbfgs.gp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS); + ggml_set_name(opt->lbfgs.d, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION); + if (opt->lbfgs.pf) { + ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES); + } + ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA); + ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS); + ggml_set_name(opt->lbfgs.lms, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S); + ggml_set_name(opt->lbfgs.lmy, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y); + + gguf_add_tensor(fctx, opt->lbfgs.x); + gguf_add_tensor(fctx, opt->lbfgs.xp); + gguf_add_tensor(fctx, opt->lbfgs.g); + gguf_add_tensor(fctx, opt->lbfgs.gp); + gguf_add_tensor(fctx, opt->lbfgs.d); + if (opt->lbfgs.pf) { + gguf_add_tensor(fctx, opt->lbfgs.pf); + } + gguf_add_tensor(fctx, opt->lbfgs.lmal); + gguf_add_tensor(fctx, opt->lbfgs.lmys); + gguf_add_tensor(fctx, opt->lbfgs.lms); + gguf_add_tensor(fctx, opt->lbfgs.lmy); } break; } } -void save_checkpoint(struct my_llama_model * model, struct ggml_opt_context * opt, const char * filename) { - struct llama_file file(filename, "wb"); - if (file.fp == NULL) { - return; +void load_llama_model_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model) { + // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read + std::string arch; + + std::vector keybuf; + keybuf.resize(512); + auto kv = [&arch, &keybuf](const char * key) -> const char * { + snprintf(keybuf.data(), keybuf.size(), key, arch.c_str()); + return keybuf.data(); + }; + + std::vector tn_buf; + tn_buf.resize(GGML_MAX_NAME); + auto tn = [&tn_buf](const char * key) -> const char * { + snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", key); + return tn_buf.data(); + }; + auto tni = [&tn_buf](const char * key, int bid) -> const char * { + snprintf(tn_buf.data(), tn_buf.size(), key, bid); + std::string s = tn_buf.data(); + snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", s.c_str()); + return tn_buf.data(); + }; + + GGUF_GET_KEY(fctx, arch, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_GENERAL_ARCHITECTURE); + GGML_ASSERT(arch == "llama"); + + uint32_t ftype_u; + GGUF_GET_KEY(fctx, ftype_u, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_GENERAL_FILE_TYPE); + GGML_ASSERT((enum llama_ftype) ftype_u == LLAMA_FTYPE_ALL_F32); + + // n_ctx was not saved in earlier checkpoint file versions, so we make it optional here + GGUF_GET_KEY(fctx, model->hparams.n_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH)); + + GGUF_GET_KEY(fctx, model->hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH)); + GGUF_GET_KEY(fctx, model->hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH)); + GGUF_GET_KEY(fctx, model->hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT)); + GGUF_GET_KEY(fctx, model->hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT)); + + model->hparams.n_rot = model->hparams.n_embd / model->hparams.n_head; + GGUF_GET_KEY(fctx, model->hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT)); + + float rope_freq_scale = 1.0f; + GGUF_GET_KEY(fctx, model->hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); + GGUF_GET_KEY(fctx, model->hparams.rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); + GGUF_GET_KEY(fctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); + if (rope_freq_scale != 1.0f) { + model->hparams.rope_freq_scale = 1.0f / rope_freq_scale; } - const uint32_t magic = 'ggcp'; - const uint32_t version = 0; + init_model(model); - file.write_u32(magic); - file.write_u32(version); - file.write_u32(model->train_its); - file.write_u32(model->train_samples); - file.write_u32(model->train_tokens); - file.write_u32(model->hparams.n_vocab); - file.write_u32(model->hparams.n_embd); - file.write_u32(model->hparams.n_mult); - file.write_u32(model->hparams.n_head); - file.write_u32(model->hparams.n_layer); - file.write_u32(model->hparams.n_rot); - - write_tensor(&file, model->tok_embeddings); - write_tensor(&file, model->norm); - write_tensor(&file, model->output); + read_tensor_by_name(model->tok_embeddings, f_ggml_ctx, tn(LLM_TENSOR_TOKEN_EMBD)); + read_tensor_by_name(model->norm, f_ggml_ctx, tn(LLM_TENSOR_OUTPUT_NORM)); + read_tensor_by_name(model->output, f_ggml_ctx, tn(LLM_TENSOR_OUTPUT)); for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { auto & layer = model->layers[i]; - write_tensor(&file, layer.attention_norm); - write_tensor(&file, layer.wq); - write_tensor(&file, layer.wk); - write_tensor(&file, layer.wv); - write_tensor(&file, layer.wo); - write_tensor(&file, layer.ffn_norm); - write_tensor(&file, layer.w1); - write_tensor(&file, layer.w2); - write_tensor(&file, layer.w3); + read_tensor_by_name(layer.attention_norm, f_ggml_ctx, tni(LLM_TENSOR_ATTN_NORM, i)); + read_tensor_by_name(layer.wq, f_ggml_ctx, tni(LLM_TENSOR_ATTN_Q, i)); + read_tensor_by_name(layer.wk, f_ggml_ctx, tni(LLM_TENSOR_ATTN_K, i)); + read_tensor_by_name(layer.wv, f_ggml_ctx, tni(LLM_TENSOR_ATTN_V, i)); + read_tensor_by_name(layer.wo, f_ggml_ctx, tni(LLM_TENSOR_ATTN_OUT, i)); + read_tensor_by_name(layer.ffn_norm, f_ggml_ctx, tni(LLM_TENSOR_FFN_NORM, i)); + read_tensor_by_name(layer.w1, f_ggml_ctx, tni(LLM_TENSOR_FFN_GATE, i)); + read_tensor_by_name(layer.w2, f_ggml_ctx, tni(LLM_TENSOR_FFN_DOWN, i)); + read_tensor_by_name(layer.w3, f_ggml_ctx, tni(LLM_TENSOR_FFN_UP, i)); } - - write_opt_context(&file, opt); } -bool load_checkpoint(struct my_llama_model * model, struct ggml_opt_context * opt, const char * filename, bool init) { - struct llama_file file(filename, "rb"); +void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model) { + const char * arch = "llama"; + enum llama_ftype ftype = LLAMA_FTYPE_ALL_F32; - uint32_t magic; - uint32_t version; + std::vector keybuf; + keybuf.resize(512); + auto kv = [arch, &keybuf](const char * key) -> const char * { + snprintf(keybuf.data(), keybuf.size(), key, arch); + return keybuf.data(); + }; - uint32_t train_its = 0; - uint32_t train_samples = 0; - uint32_t train_tokens = 0; + // set arch + gguf_set_val_str(fctx, LLM_KV_GENERAL_ARCHITECTURE, arch); + gguf_set_val_u32(fctx, LLM_KV_GENERAL_FILE_TYPE, ftype); - if (file.fp) { - printf("%s: Loading model from '%s'.\n", __func__, filename); - magic = file.read_u32(); - GGML_ASSERT(magic == 'ggcp'); - version = file.read_u32(); - GGML_ASSERT(version == 0); - train_its = file.read_u32(); - train_samples = file.read_u32(); - train_tokens = file.read_u32(); - model->hparams.n_vocab = file.read_u32(); - model->hparams.n_embd = file.read_u32(); - model->hparams.n_mult = file.read_u32(); - model->hparams.n_head = file.read_u32(); - model->hparams.n_layer = file.read_u32(); - model->hparams.n_rot = file.read_u32(); - print_params(&model->hparams); - } + // set hparams + gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH), model->hparams.n_ctx ); + gguf_set_val_u32(fctx, kv(LLM_KV_EMBEDDING_LENGTH), model->hparams.n_embd ); + gguf_set_val_u32(fctx, kv(LLM_KV_FEED_FORWARD_LENGTH), model->hparams.n_ff ); + gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT), model->hparams.n_head ); + gguf_set_val_u32(fctx, kv(LLM_KV_BLOCK_COUNT), model->hparams.n_layer ); + gguf_set_val_u32(fctx, kv(LLM_KV_ROPE_DIMENSION_COUNT), model->hparams.n_rot ); - if (init) { - init_model(model); - } + gguf_set_val_f32(fctx, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS), model->hparams.f_norm_rms_eps ); + gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_FREQ_BASE), model->hparams.rope_freq_base ); // TODO load in llama.cpp + gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_SCALE_LINEAR), 1.0f / model->hparams.rope_freq_scale ); - if (file.fp) { - model->train_its = train_its; - model->train_samples = train_samples; - model->train_tokens = train_tokens; - } + // set vocab by copying from vocab_model gguf file + { + struct gguf_init_params params = { + /*.no_alloc = */ false, + /*.ctx = */ NULL, + }; + struct gguf_context * vctx = gguf_init_from_file(fn_vocab_model, params); - printf("%s: Training iterations: %u.\n", __func__, model->train_its); - printf("%s: Training samples: %u.\n", __func__, model->train_samples); - printf("%s: Training tokens: %u.\n", __func__, model->train_tokens); + const int token_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_LIST)); + if (token_idx == -1) { + throw std::runtime_error("cannot find tokenizer vocab in model file\n"); + } + const uint32_t n_vocab = gguf_get_arr_n(vctx, token_idx); - if (file.fp) { - read_tensor(&file, model->tok_embeddings); - read_tensor(&file, model->norm); - read_tensor(&file, model->output); - - for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { - auto & layer = model->layers[i]; - - read_tensor(&file, layer.attention_norm); - read_tensor(&file, layer.wq); - read_tensor(&file, layer.wk); - read_tensor(&file, layer.wv); - read_tensor(&file, layer.wo); - read_tensor(&file, layer.ffn_norm); - read_tensor(&file, layer.w1); - read_tensor(&file, layer.w2); - read_tensor(&file, layer.w3); + const int score_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_SCORES)); + if (score_idx == -1) { + throw std::runtime_error("cannot find tokenizer scores in model file\n"); } - read_opt_context(&file, model->ctx, opt); + const float * scores = (const float * ) gguf_get_arr_data(vctx, score_idx); + + const int toktype_idx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE)); + if (toktype_idx == -1) { + throw std::runtime_error("cannot find token type list in GGUF file\n"); + } + + const int * toktypes = (const int * ) gguf_get_arr_data(vctx, toktype_idx); + + std::string tokenizer_name; + GGUF_GET_KEY(vctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL)); + + gguf_set_val_str(fctx, kv(LLM_KV_TOKENIZER_MODEL), tokenizer_name.c_str()); + gguf_set_arr_data(fctx, kv(LLM_KV_TOKENIZER_SCORES), GGUF_TYPE_FLOAT32, scores, n_vocab); + gguf_set_arr_data(fctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE), GGUF_TYPE_INT32, toktypes, n_vocab); + + int32_t special_bos_id = 1; + int32_t special_eos_id = 2; + int32_t special_unk_id = 0; + int32_t special_sep_id = -1; + int32_t special_pad_id = -1; + if (tokenizer_name == "llama") { + // default special tokens + special_bos_id = 1; + special_eos_id = 2; + special_unk_id = 0; + special_sep_id = -1; + special_pad_id = -1; + } else if (tokenizer_name == "gpt2") { + // read and copy bpe merges + const int merges_keyidx = gguf_find_key(vctx, kv(LLM_KV_TOKENIZER_MERGES)); + if (merges_keyidx == -1) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + + const int n_merges = gguf_get_arr_n(vctx, merges_keyidx); + + std::vector merges; + merges.resize(n_merges); + for (int i = 0; i < n_merges; i++) { + merges[i] = gguf_get_arr_str(vctx, merges_keyidx, i); + } + gguf_set_arr_str(fctx, kv(LLM_KV_TOKENIZER_MERGES), merges.data(), n_merges); + + // default special tokens + special_bos_id = 11; + special_eos_id = 11; + special_unk_id = -1; + special_sep_id = -1; + special_pad_id = -1; + } else { + fprintf(stderr, "%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); + fprintf(stderr, "%s: using default tokenizer: 'llama'", __func__); + } + + std::vector tokens; + tokens.resize(n_vocab); + for (uint32_t i = 0; i < n_vocab; i++) { + tokens[i] = gguf_get_arr_str(vctx, token_idx, i); + } + gguf_set_arr_str(fctx, kv(LLM_KV_TOKENIZER_LIST), tokens.data(), n_vocab); + + GGUF_GET_KEY(vctx, special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID)); + GGUF_GET_KEY(vctx, special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID)); + GGUF_GET_KEY(vctx, special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID)); + GGUF_GET_KEY(vctx, special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID)); + GGUF_GET_KEY(vctx, special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID)); + + gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_BOS_ID), special_bos_id); + gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_EOS_ID), special_eos_id); + gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_UNK_ID), special_unk_id); + gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_SEP_ID), special_sep_id); + gguf_set_val_u32(fctx, kv(LLM_KV_TOKENIZER_PAD_ID), special_pad_id); + + gguf_free(vctx); } - return (file.fp != NULL); + // add tensors + gguf_add_tensor(fctx, model->tok_embeddings); + gguf_add_tensor(fctx, model->norm); + gguf_add_tensor(fctx, model->output); + for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { + auto & layer = model->layers[i]; + + + gguf_add_tensor(fctx, layer.attention_norm); + gguf_add_tensor(fctx, layer.wq); + gguf_add_tensor(fctx, layer.wk); + gguf_add_tensor(fctx, layer.wv); + gguf_add_tensor(fctx, layer.wo); + gguf_add_tensor(fctx, layer.ffn_norm); + gguf_add_tensor(fctx, layer.w1); + gguf_add_tensor(fctx, layer.w2); + gguf_add_tensor(fctx, layer.w3); + } } -void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * model, const char * filename) { - struct llama_file file(filename, "wb"); - if (file.fp == NULL) { - return; +void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model) { + struct gguf_context * fctx = gguf_init_empty(); + + save_llama_model_gguf(fctx, fn_vocab_model, model); + + // write file + const bool only_meta = false; + gguf_write_to_file(fctx, filename, only_meta); + gguf_free(fctx); +} + +void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct ggml_opt_context * opt) { + load_llama_model_gguf(fctx, f_ggml_ctx, model); + + uint32_t file_version; + GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION); + GGML_ASSERT(file_version == 0); + + GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT); + GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT); + GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT); + + load_opt_context_gguf(fctx, f_ggml_ctx, opt); +} + +void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) { + save_llama_model_gguf(fctx, fn_vocab_model, model); + + gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 0); + gguf_set_val_u32(fctx, LLM_KV_TRAINING_ITERATION_COUNT, model->train_its); + gguf_set_val_u32(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, model->train_samples); + gguf_set_val_u32(fctx, LLM_KV_TRAINING_TOKEN_COUNT, model->train_tokens); + + save_opt_context_gguf(fctx, opt); +} + +bool load_checkpoint_file(const char * filename, struct my_llama_model * model, struct ggml_opt_context * opt) { + struct ggml_context * f_ggml_ctx; + struct gguf_init_params params; + params.no_alloc = false; + params.ctx = &f_ggml_ctx; + struct gguf_context * fctx = gguf_init_from_file(filename, params); + if (fctx == NULL) { + return false; } -#pragma message("TODO: implement file saving using gguf") - (void) vocab; - (void) model; -// // write_magic -// file.write_u32(LLAMA_FILE_MAGIC); // magic -// file.write_u32(LLAMA_FILE_VERSION); // version -// // write_hparams -// file.write_u32(model->hparams.n_vocab); -// file.write_u32(model->hparams.n_embd); -// file.write_u32(model->hparams.n_mult); -// file.write_u32(model->hparams.n_head); -// file.write_u32(model->hparams.n_layer); -// file.write_u32(model->hparams.n_rot); -// file.write_u32(LLAMA_FTYPE_ALL_F32); -// // write_vocab -// uint32_t n_vocab = model->hparams.n_vocab; -// for (uint32_t i = 0; i < n_vocab; i++) { -// const auto & token_data = vocab->id_to_token.at(i); -// file.write_u32((uint32_t) token_data.tok.size()); -// file.write_raw(token_data.tok.data(), token_data.tok.size()); -// file.write_raw(&token_data.score, sizeof(token_data.score)); -// } -// // write tensors -// write_tensor(&file, model->tok_embeddings); -// write_tensor(&file, model->norm); -// write_tensor(&file, model->output); -// for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { -// auto & layer = model->layers[i]; -// -// write_tensor(&file, layer.attention_norm); -// write_tensor(&file, layer.wq); -// write_tensor(&file, layer.wk); -// write_tensor(&file, layer.wv); -// write_tensor(&file, layer.wo); -// write_tensor(&file, layer.ffn_norm); -// write_tensor(&file, layer.w1); -// write_tensor(&file, layer.w2); -// write_tensor(&file, layer.w3); -// } + load_checkpoint_gguf(fctx, f_ggml_ctx, model, opt); + + return true; } -float cosine_decay(const int decay_steps, const float alpha, int step) { +void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) { + struct gguf_context * fctx = gguf_init_empty(); + + save_checkpoint_gguf(fctx, fn_vocab_model, model, opt); + + // write file + const bool only_meta = false; + gguf_write_to_file(fctx, filename, only_meta); + gguf_free(fctx); +} + +float cosine_decay(const int decay_steps, const float minimum, int step) { if (step > decay_steps) { step = decay_steps; } const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps)); - const float decay = (1 - alpha)*cosine_decay + alpha; + const float decay = (1 - minimum)*cosine_decay + minimum; return decay; } -float cosine_decay_restart(int decay_steps, const float alpha, int step, float restart_step_mult) { - while (step > decay_steps) { - step -= decay_steps; - decay_steps = (int) restart_step_mult * decay_steps; +float cosine_decay_restart(int decay_steps, const float minimum, int step, float restart_step_mult, bool enable_restart) { + if (enable_restart) { + while (step > decay_steps) { + step -= decay_steps; + decay_steps = (int) restart_step_mult * decay_steps; + } } - return cosine_decay(decay_steps, alpha, step); + return cosine_decay(decay_steps, minimum, step); } struct train_params { @@ -2683,39 +1513,51 @@ struct train_params { int n_ctx; int n_embd; - int n_mult; int n_head; int n_layer; - int n_rotmax; + int n_ff; int n_threads; int n_batch; int n_examples; - int n_predict; + + float f_norm_rms_eps; + float rope_freq_base; + float rope_freq_scale; int print_info_interval; - int print_details_interval; bool samples_start_after_nl; bool use_adam; bool use_flash; - bool use_scratch; + bool use_checkpointing; + bool use_alloc; // only adam int warmup; int cos_decay_steps; float cos_decay_restart; - float cos_decay_alpha; + float cos_decay_min; + bool enable_restart; + + int opt_past; + float opt_delta; + int opt_max_no_improvement; int lbfgs_n_iter; int adam_n_iter; float adam_alpha; + float adam_min_alpha; float adam_decay; + int adam_decay_min_ndim; + float adam_beta1; + float adam_beta2; + float adam_gclip; + float adam_eps_f; int mem_model_gb; int mem_compute_gb; int mem_compute0_gb; - int mem_compute1_gb; }; struct train_params get_default_train_params() { @@ -2730,40 +1572,51 @@ struct train_params get_default_train_params() { params.n_ctx = 128; params.n_embd = 256; - params.n_mult = 256; params.n_head = 8; params.n_layer = 16; - params.n_rotmax = 64; + params.n_ff = 768; params.n_threads = 6; params.n_batch = 8; - params.n_examples = 8; - params.n_predict = 1024; + params.n_examples = 1; + + params.f_norm_rms_eps = 1e-5; + params.rope_freq_base = 10000.0f; + params.rope_freq_scale = 1.0f; params.print_info_interval = 1; - params.print_details_interval = 2; params.samples_start_after_nl = false; params.use_adam = true; params.use_flash = true; - params.use_scratch = true; + params.use_checkpointing = true; + params.use_alloc = true; + + params.opt_past = 0; + params.opt_delta = 1e-5f; + params.opt_max_no_improvement = 0; // only adam params.warmup = 100; params.cos_decay_steps = 1000; params.cos_decay_restart = 1.1f; - params.cos_decay_alpha = 0.0f; + params.cos_decay_min = 0.1f; + params.enable_restart = false; - params.lbfgs_n_iter = 16; - params.adam_n_iter = 16; - params.adam_alpha = 1e-3f; - params.adam_decay = 1e-3f; + params.lbfgs_n_iter = 256; + params.adam_n_iter = 256; + params.adam_alpha = 1e-3f; + params.adam_min_alpha = 0; + params.adam_decay = 1e-1f; + params.adam_decay_min_ndim = 2; + params.adam_beta1 = 0.9f; + params.adam_beta2 = 0.999f; + params.adam_gclip = 1.0f; + params.adam_eps_f = 0.0f; - params.mem_model_gb = 2; + params.mem_model_gb = 2; params.mem_compute_gb = 24; params.mem_compute0_gb = 8; - params.mem_compute1_gb = 2; - return params; } @@ -2780,35 +1633,47 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n"); fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx); fprintf(stderr, " --embd N Embedding size used for new models (default %d)\n", params->n_embd); - fprintf(stderr, " --mult N Mult size used for new models, influences feedforward size. (default %d)\n", params->n_mult); + fprintf(stderr, " --ff N Feedforward size used for new models. (default %d)\n", params->n_ff); fprintf(stderr, " --head N Number of heads for new models (default %d)\n", params->n_head); fprintf(stderr, " --layer N Number of layers for new models (default %d)\n", params->n_layer); - fprintf(stderr, " --rotmax N Maximal number Rope dimensions for new models (default %d)\n", params->n_rotmax); + fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps); + fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base); + fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale); fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads); fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch); fprintf(stderr, " -n N, --examples N Number of examples to train (default %d)\n", params->n_examples); - fprintf(stderr, " --predict N Number of tokens to generate after training (default %d)\n", params->n_predict); fprintf(stderr, " --print-info-interval N Print infos during training each N examples (default %d)\n", params->print_info_interval); - fprintf(stderr, " --print-details-interval N Print details during training each N examples (default %d)\n", params->print_details_interval); fprintf(stderr, " --samples-after-nl Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off"); fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n"); fprintf(stderr, " --use-adam Use Adam optimizer (default)\n"); - fprintf(stderr, " --no-flash Don't use flash attention.\n"); + fprintf(stderr, " --no-flash Don't use flash attention \n"); fprintf(stderr, " --use-flash Use flash attention (default)\n"); - fprintf(stderr, " --no-scratch Don't use scratch buffers\n"); - fprintf(stderr, " --use-scratch Use scratch buffers (default)\n"); - fprintf(stderr, " --warmup N Number of warmup steps (default %d)\n", params->warmup); - fprintf(stderr, " --cos-decay-steps N Number of cosine decay steps (default %d)\n", params->cos_decay_steps); - fprintf(stderr, " --cos-decay-restart N Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart); - fprintf(stderr, " --cos-decay-alpha N Cosine decay alpha (default %f)\n", params->cos_decay_alpha); - fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter); + fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n"); + fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n"); + fprintf(stderr, " --no-alloc Don't use allocator\n"); + fprintf(stderr, " --use-alloc Use allocator (default)\n"); + fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup); + fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps); + fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart); + fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min); + fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : ""); + fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : ""); + fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past); + fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta); + fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement); + fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f); fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter); fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha); + fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha); fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay); + fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim); + fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1); + fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2); + fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip); + fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter); fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb); fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb); - fprintf(stderr, " --mem-compute0 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb); - fprintf(stderr, " --mem-compute1 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute1_gb); + fprintf(stderr, " --mem-compute0 N Memory to allocate for automatic memory allocator in gigabytes. (default %d)\n", params->mem_compute0_gb); fprintf(stderr, "\n"); } @@ -2872,12 +1737,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { break; } params->n_embd = std::stoi(argv[i]); - } else if (arg == "--mult") { + } else if (arg == "--ff") { if (++i >= argc) { invalid_param = true; break; } - params->n_mult = std::stoi(argv[i]); + params->n_ff = std::stoi(argv[i]); } else if (arg == "--head") { if (++i >= argc) { invalid_param = true; @@ -2890,12 +1755,24 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { break; } params->n_layer = std::stoi(argv[i]); - } else if (arg == "--rotmax") { + } else if (arg == "--norm-rms-eps") { if (++i >= argc) { invalid_param = true; break; } - params->n_rotmax = std::stoi(argv[i]); + params->f_norm_rms_eps = std::stof(argv[i]); + } else if (arg == "--rope-freq-base") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->rope_freq_base = std::stof(argv[i]); + } else if (arg == "--rope-freq-scale") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->rope_freq_scale = std::stof(argv[i]); } else if (arg == "-t" || arg == "--threads") { if (++i >= argc) { invalid_param = true; @@ -2914,24 +1791,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { break; } params->n_examples = std::stoi(argv[i]); - } else if (arg == "--predict") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->n_predict = std::stoi(argv[i]); } else if (arg == "--print-info-interval") { if (++i >= argc) { invalid_param = true; break; } params->print_info_interval = std::stoi(argv[i]); - } else if (arg == "--print-details-interval") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->print_details_interval = std::stoi(argv[i]); } else if (arg == "--samples-after-nl") { params->samples_start_after_nl = true; } else if (arg == "--use-lbfgs") { @@ -2942,10 +1807,14 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { params->use_flash = false; } else if (arg == "--use-flash") { params->use_flash = true; - } else if (arg == "--no-scratch") { - params->use_scratch = false; - } else if (arg == "--use-scratch") { - params->use_scratch = true; + } else if (arg == "--no-checkpointing") { + params->use_checkpointing = false; + } else if (arg == "--use-checkpointing") { + params->use_checkpointing = true; + } else if (arg == "--no-alloc") { + params->use_alloc = false; + } else if (arg == "--use-alloc") { + params->use_alloc = true; } else if (arg == "--warmup") { if (++i >= argc) { invalid_param = true; @@ -2964,18 +1833,40 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { break; } params->cos_decay_restart = std::stof(argv[i]); - } else if (arg == "--cos-decay-alpha") { + } else if (arg == "--cos-decay-min") { if (++i >= argc) { invalid_param = true; break; } - params->cos_decay_alpha = std::stof(argv[i]); - } else if (arg == "--lbfgs-iter") { + params->cos_decay_min = std::stof(argv[i]); + } else if (arg == "--enable-restart") { + params->enable_restart = true; + } else if (arg == "--disable-restart") { + params->enable_restart = false; + } else if (arg == "--opt-past") { if (++i >= argc) { invalid_param = true; break; } - params->lbfgs_n_iter = std::stoi(argv[i]); + params->opt_past = std::stoi(argv[i]); + } else if (arg == "--opt-delta") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->opt_delta = std::stof(argv[i]); + } else if (arg == "--opt-max-no-improvement") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->opt_max_no_improvement = std::stoi(argv[i]); + } else if (arg == "--adam-epsf") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->adam_eps_f = std::stof(argv[i]); } else if (arg == "--adam-iter") { if (++i >= argc) { invalid_param = true; @@ -2988,12 +1879,48 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { break; } params->adam_alpha = std::stof(argv[i]); + } else if (arg == "--adam-min-alpha") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->adam_min_alpha = std::stof(argv[i]); } else if (arg == "--adam-decay") { if (++i >= argc) { invalid_param = true; break; } params->adam_decay = std::stof(argv[i]); + } else if (arg == "--adam-decay-min-ndim") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->adam_decay_min_ndim = std::stoi(argv[i]); + } else if (arg == "--adam-beta1") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->adam_beta1 = std::stof(argv[i]); + } else if (arg == "--adam-beta2") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->adam_beta2 = std::stof(argv[i]); + } else if (arg == "--adam-gclip") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->adam_gclip = std::stof(argv[i]); + } else if (arg == "--lbfgs-iter") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->lbfgs_n_iter = std::stoi(argv[i]); } else if (arg == "--mem-model") { if (++i >= argc) { invalid_param = true; @@ -3012,12 +1939,6 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { break; } params->mem_compute0_gb = std::stoi(argv[i]); - } else if (arg == "--mem-compute1") { - if (++i >= argc) { - invalid_param = true; - break; - } - params->mem_compute1_gb = std::stoi(argv[i]); } else if (arg == "-h" || arg == "--help") { train_print_usage(argc, argv, &default_params); exit(0); @@ -3036,6 +1957,63 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { return true; } +struct opt_callback_data { + struct train_params * params; + struct ggml_opt_context * opt; + struct llama_context * lctx; + llama_token * tokens_data; + size_t tokens_size; + int * samples_data; + size_t samples_size; + int shuffle_countdown; + struct ggml_tensor * tokens_input; + struct ggml_tensor * target_logits; + struct ggml_tensor * target_probs; +}; + +void opt_callback(void * vdata, float * sched) { + struct opt_callback_data * data = (struct opt_callback_data *) vdata; + struct train_params * params = data->params; + struct ggml_opt_context * opt = data->opt; + int n_batch = params->n_batch; + + *sched = (opt->iter < params->warmup) + ? (float) opt->iter / (float) params->warmup + : cosine_decay_restart( + params->cos_decay_steps, + params->cos_decay_min, + opt->iter - params->warmup, + params->cos_decay_restart, + params->enable_restart); + float min_sched = params->adam_min_alpha / params->adam_alpha; + *sched = min_sched + *sched * (1.0f - min_sched); + + int impr_plot = std::isnan(opt->loss_after) ? 0 : -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f); + printf("%s: iter=%*d, sched=%f loss0=%f loss=%f | improvement: %*d>\n", __func__, 6, opt->iter, *sched, opt->loss_before, opt->loss_after, impr_plot, (int)0); + + if (data->shuffle_countdown < n_batch) { + printf("%s: reshuffle samples\n", __func__); + shuffle_ints(data->samples_data, data->samples_data + data->samples_size); + for (int i = 0; i < (int) data->samples_size; ++i) { + GGML_ASSERT(data->samples_data[i]+params->n_ctx-1 < (int) data->tokens_size); + } + data->shuffle_countdown = data->samples_size; + } + + get_example_targets_batch( + data->lctx, + data->samples_data, + data->samples_size, + data->tokens_data, + data->tokens_size, + opt->iter, + data->tokens_input, + data->target_logits, + data->target_probs); + + data->shuffle_countdown -= n_batch; +} + int main(int argc, char ** argv) { struct train_params params = get_default_train_params(); @@ -3055,18 +2033,6 @@ int main(int argc, char ** argv) { struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params); struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params); - struct llama_vocab vocab; - { - const int n_vocab = llama_n_vocab(lctx); - vocab.id_to_token.resize(n_vocab); - for (int i=0; i train_tokens; if (tokenize_file(lctx, params.fn_train_data, train_tokens) < 0) { @@ -3078,10 +2044,14 @@ int main(int argc, char ** argv) { model.hparams.n_vocab = llama_n_vocab(lctx); model.hparams.n_ctx = params.n_ctx; model.hparams.n_embd = params.n_embd; - model.hparams.n_mult = params.n_mult; model.hparams.n_head = params.n_head; model.hparams.n_layer = params.n_layer; - model.hparams.n_rot = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head); + model.hparams.n_ff = params.n_ff; + // llama.cpp requires n_rot to be exactly n_embd / n_head + model.hparams.n_rot = model.hparams.n_embd / model.hparams.n_head; + model.hparams.f_norm_rms_eps = params.f_norm_rms_eps; + model.hparams.rope_freq_base = params.rope_freq_base; + model.hparams.rope_freq_scale = params.rope_freq_scale; print_params(&model.hparams); @@ -3103,19 +2073,12 @@ int main(int argc, char ** argv) { } printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens); - struct my_llama_kv_cache kv_self; - - struct ggml_init_params lcparams; lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb); lcparams.mem_buffer = NULL; lcparams.no_alloc = false; model.ctx = ggml_init(lcparams); - kv_self.ctx = model.ctx; - - my_llama_sampler sampler; - int n_tokens = model.hparams.n_ctx; int n_vocab = model.hparams.n_vocab; @@ -3126,24 +2089,38 @@ int main(int argc, char ** argv) { struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM); struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS); - opt_params_adam.print_forward_graph = false; + opt_params_adam.print_forward_graph = false; opt_params_adam.print_backward_graph = false; - opt_params_adam.n_threads = params.n_threads; - opt_params_adam.adam.n_iter = params.adam_n_iter; - opt_params_adam.adam.sched = 1.0f; - opt_params_adam.adam.alpha = params.adam_alpha; - opt_params_adam.adam.decay = params.adam_decay; + opt_params_adam.n_threads = params.n_threads; + opt_params_adam.past = params.opt_past; + opt_params_adam.delta = params.opt_delta; + opt_params_adam.max_no_improvement = params.opt_max_no_improvement; + opt_params_adam.adam.n_iter = params.adam_n_iter; + opt_params_adam.adam.sched = 1.0f; + opt_params_adam.adam.alpha = params.adam_alpha; + opt_params_adam.adam.decay = params.adam_decay; + opt_params_adam.adam.decay_min_ndim = params.adam_decay_min_ndim; + opt_params_adam.adam.beta1 = params.adam_beta1; + opt_params_adam.adam.beta2 = params.adam_beta2; + opt_params_adam.adam.gclip = params.adam_gclip; + opt_params_adam.adam.eps_f = params.adam_eps_f; - opt_params_lbfgs.print_forward_graph = false; + opt_params_lbfgs.print_forward_graph = false; opt_params_lbfgs.print_backward_graph = false; - opt_params_lbfgs.n_threads = params.n_threads; - opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter; + opt_params_lbfgs.n_threads = params.n_threads; + opt_params_adam.past = params.opt_past; + opt_params_adam.delta = params.opt_delta; + opt_params_adam.max_no_improvement = params.opt_max_no_improvement; + opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter; opt->ctx = model.ctx; opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs; printf("%s: init model\n", __func__); - bool existed = load_checkpoint(&model, opt, params.fn_checkpoint_in, true); + bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, opt); + if (!existed) { + init_model(&model); + } set_param_model(&model); opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs; @@ -3156,11 +2133,7 @@ int main(int argc, char ** argv) { randomize_model(&model, params.seed, 0.0f, 1.0f, -1.0f, +1.0f); } - init_kv_cache(&kv_self, &model, 1); - // init_kv_cache(&kv_self, &model, n_batch); - init_sampler(&sampler, lctx); - - printf("used_mem model+cache: %zu bytes\n", ggml_used_mem(model.ctx)); + printf("used_mem model: %zu bytes\n", ggml_used_mem(model.ctx)); // ggml_print_tensor_objects(model.ctx); // TODO: use std::vector intead of "new" @@ -3168,9 +2141,13 @@ int main(int argc, char ** argv) { uint8_t * compute_addr = new uint8_t[compute_size]; size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb); - size_t size_buf_1 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute1_gb); uint8_t * compute_buf_0 = new uint8_t[size_buf_0]; - uint8_t * compute_buf_1 = new uint8_t[size_buf_1]; + + ggml_allocr * alloc = NULL; + if (params.use_alloc) { + static const size_t tensor_alignment = 32; + alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment); + } GGML_ASSERT(n_tokens < (int) train_tokens.size()); std::vector train_samples; @@ -3185,10 +2162,23 @@ int main(int argc, char ** argv) { GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size()); } - std::vector work_buffer; - printf("%s: begin training\n", __func__); + struct opt_callback_data opt_cb_data; + opt_cb_data.params = ¶ms; + opt_cb_data.opt = opt; + opt_cb_data.lctx = lctx; + opt_cb_data.tokens_data = train_tokens.data(); + opt_cb_data.tokens_size = train_tokens.size(); + opt_cb_data.samples_data = train_samples.data(); + opt_cb_data.samples_size = train_samples.size(); + opt_cb_data.shuffle_countdown = train_samples.size(); + opt_cb_data.tokens_input = NULL; + opt_cb_data.target_logits = NULL; + opt_cb_data.target_probs = NULL; + + int64_t t0 = ggml_time_ms(); + for (int ex = 0; ex < params.n_examples; ++ex) { if (ex*n_batch >= (int) train_samples.size()) { shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size()); @@ -3198,198 +2188,110 @@ int main(int argc, char ** argv) { } struct ggml_init_params cparams = { - /*.mem_size =*/ compute_size, - /*.mem_buffer =*/ compute_addr, - /*.no_alloc =*/ false, + compute_size, // mem_size + compute_addr, // mem_buffer + false, // no_alloc }; struct ggml_context * ctx0 = ggml_init(cparams); - struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch); + ggml_set_no_alloc(ctx0, false); + + // don't use alloc for input tensors, so we can safely fill them with data + //struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch); //struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch); struct ggml_tensor * target_logits = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); + ggml_set_no_alloc(ctx0, (alloc != NULL)); + + if (alloc) { + ggml_allocr_reset(alloc); + } + + opt_cb_data.tokens_input = tokens_input; + opt_cb_data.target_logits = target_logits; + opt_cb_data.target_probs = target_probs; + int n_past = 0; - struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0)); - struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0)); - - memset(gfbuf->data, 0, ggml_nbytes(gfbuf)); - memset(gbbuf->data, 0, ggml_nbytes(gbbuf)); - - struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data; - struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data; - - - get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + struct ggml_cgraph * gb = ggml_new_graph(ctx0); + struct ggml_cgraph * gb_tmp = params.use_checkpointing + ? ggml_new_graph(ctx0) + : NULL; GGML_ASSERT(n_past == 0); struct ggml_tensor * loss = NULL; struct ggml_tensor * logits = NULL; - if (params.use_scratch) { - loss = forward_batch_wo_cache_flash_attn_train( - &model, ctx0, - gf, gb, - &logits, tokens_input, target_probs, - compute_buf_0, compute_buf_1, - size_buf_0, size_buf_1, - n_tokens, n_batch); - } else if (params.use_flash) { - logits = forward_batch_wo_cache_flash_attn(&model, ctx0, gf, tokens_input, n_tokens, n_batch); - loss = cross_entropy_loss(ctx0, logits, target_probs); - ggml_build_forward_expand(gf, loss); - *gb = ggml_build_backward(ctx0, gf, true); - } else { - logits = forward_batch_wo_cache(&model, ctx0, gf, tokens_input, n_tokens, n_batch); - loss = cross_entropy_loss(ctx0, logits, target_probs); - ggml_build_forward_expand(gf, loss); - *gb = ggml_build_backward(ctx0, gf, true); - } - - ggml_graph_compute_helper(work_buffer, gf, params.n_threads); + loss = llama_build_train_graphs( + &model, alloc, ctx0, + gf, gb, gb_tmp, + &logits, tokens_input, target_probs, + n_tokens, n_batch, + params.use_flash, + params.use_checkpointing + ); size_t used_mem_before_opt = ggml_used_mem(ctx0); - float error_before_opt = ggml_get_f32_1d(loss, 0); - opt->params.adam.sched = (opt->iter < params.warmup) ? (float) opt->iter / (float) params.warmup : cosine_decay_restart( params.cos_decay_steps, - params.cos_decay_alpha, + params.cos_decay_min, opt->iter - params.warmup, - params.cos_decay_restart); + params.cos_decay_restart, + params.enable_restart); + + float min_sched = params.adam_min_alpha / params.adam_alpha; + opt->params.adam.sched = min_sched + opt->params.adam.sched * (1.0f - min_sched); printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched); - ggml_opt_resume_g(ctx0, opt, loss, gf, gb); + ggml_opt_resume_g(ctx0, opt, loss, gf, gb, &opt_callback, (void *) &opt_cb_data); size_t used_mem_after_opt = ggml_used_mem(ctx0); + int n_iter = params.use_adam ? params.adam_n_iter : params.lbfgs_n_iter; model.train_its = opt->iter; - model.train_samples += n_batch; - model.train_tokens += n_batch * n_tokens; - - ggml_graph_compute_helper(work_buffer, gf, params.n_threads); - - float error_after_opt = ggml_get_f32_1d(loss, 0); + model.train_samples += n_batch * n_iter; + model.train_tokens += n_batch * n_tokens * n_iter; if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) { printf("Example %d, opt iter %d\n", ex, opt->iter); - printf("error_before_opt: %.6f\n", error_before_opt); - printf("error_after_opt: %.6f\n", error_after_opt); + printf("error_before_opt: %.6f\n", opt->loss_before); + printf("error_after_opt: %.6f\n", opt->loss_after); printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt); printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt); } - if (params.print_details_interval > 0 && ex % params.print_details_interval == 0) { - // set_logits_masked(logits, token_notavail, -1e9); - for (int i=0; idata + i*logits->nb[2] + k*logits->nb[1]), - (llama_token *) ((char *) tokens_input->data + i*tokens_input->nb[1]), - k); - * ((int32_t *) ((char *) after_opt_best_samples->data + i*after_opt_best_samples->nb[1] + k*after_opt_best_samples->nb[0])) = token; - } - } - - // printf("probabilities after optimization:\n"); - // print_matrix(after_opt_probs); - printf("Example:\n---\n"); - print_tokens_batch(lctx, tokens_input); - printf("\n---\n"); - - // printf("best samples after optimization:\n---\n"); - printf("samples after optimization:\n---\n"); - print_tokens_batch(lctx, after_opt_best_samples); - printf("\n---\n"); - } - ggml_free(ctx0); } + int64_t t1 = ggml_time_ms(); + int64_t d = t1-t0; + double dd = (double) d * 1e-3; + printf("%s: total training time=%f seconds\n", __func__, dd); + if (params.n_examples > 0) { - save_checkpoint(&model, opt, params.fn_checkpoint_out); + save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt); } if (strlen(params.fn_model_out) > 0) { - save_as_llama_model(&vocab, &model, params.fn_model_out); + save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model); } - { - int n_gen = params.n_predict; - int sample_ctx = n_tokens - n_tokens/8; - - sampler.params.temp = 0.2f; - sampler.params.repeat_penalty = 1.1f; - sampler.params.mirostat = 2; - init_sampler(&sampler, lctx); - - printf("Generating %d tokens.\n", n_gen); - - struct ggml_tensor * tokens_input = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, n_tokens); - struct ggml_tensor * target_logits = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens); - struct ggml_tensor * target_probs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens); - - get_example_targets(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), rand()%train_samples.size(), tokens_input, target_logits, target_probs); - for (int i=sample_ctx; idata + (sample_ctx-1)*logits->nb[1]), - (llama_token *) tokens_input->data, - sample_ctx-1); - //int token = ggml_get_i32_1d(best_samples, sample_ctx-1); - - // print_row(probs, sample_at); - print_token(lctx, token); - - lshift_examples(tokens_input, target_logits, target_probs, 1); - ggml_set_i32_1d(tokens_input, 0, 0); - ggml_set_i32_1d(tokens_input, sample_ctx-1, token); - - ggml_free(ctx0); - } + if (alloc) { + ggml_allocr_free(alloc); } delete[] compute_addr; delete[] compute_buf_0; - delete[] compute_buf_1; - + ggml_free(model.ctx); llama_free(lctx); llama_free_model(lmodel); - ggml_free(model.ctx); - return 0; } diff --git a/ggml-alloc.c b/ggml-alloc.c index 140e9a2a7..63beb1d4e 100644 --- a/ggml-alloc.c +++ b/ggml-alloc.c @@ -107,6 +107,10 @@ static size_t ggml_allocator_get_alloc_size(struct ggml_allocr * alloc, struct g } void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { +#ifdef GGML_ALLOCATOR_DEBUG + GGML_ASSERT(ggml_is_view(tensor) == false); // views generally get data pointer from one of their sources + GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated +#endif size_t size = ggml_allocator_get_alloc_size(alloc, tensor); size = aligned_offset(NULL, size, alloc->alignment); diff --git a/ggml.c b/ggml.c index dadb30757..9a787863d 100644 --- a/ggml.c +++ b/ggml.c @@ -123,6 +123,8 @@ typedef void * thread_ret_t; #define GGML_GELU_FP16 #define GGML_GELU_QUICK_FP16 #define GGML_SILU_FP16 +// #define GGML_CROSS_ENTROPY_EXP_FP16 +// #define GGML_FLASH_ATTN_EXP_FP16 #define GGML_SOFT_MAX_UNROLL 4 #define GGML_VEC_DOT_UNROLL 2 @@ -186,8 +188,8 @@ typedef void * thread_ret_t; // #if defined(_MSC_VER) || defined(__MINGW32__) -#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN) -#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr) +#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN) +#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr) #else inline static void * ggml_aligned_malloc(size_t size) { void * aligned_memory = NULL; @@ -212,8 +214,8 @@ inline static void * ggml_aligned_malloc(size_t size) { } return aligned_memory; } -#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size) -#define GGML_ALIGNED_FREE(ptr) free(ptr) +#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size) +#define GGML_ALIGNED_FREE(ptr) free(ptr) #endif #define UNUSED GGML_UNUSED @@ -5857,7 +5859,8 @@ struct ggml_tensor * ggml_rms_norm_inplace( struct ggml_tensor * ggml_rms_norm_back( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * b, + float eps) { bool is_node = false; if (a->grad) { @@ -5867,6 +5870,8 @@ struct ggml_tensor * ggml_rms_norm_back( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + ggml_set_op_params(result, &eps, sizeof(eps)); + result->op = GGML_OP_RMS_NORM_BACK; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; @@ -9443,6 +9448,8 @@ static void ggml_compute_forward_div_f32( #ifdef GGML_USE_ACCELERATE + UNUSED(ggml_vec_div_f32); + vDSP_vdiv( (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, @@ -10749,7 +10756,8 @@ static void ggml_compute_forward_rms_norm_back_f32( GGML_TENSOR_BINARY_OP_LOCALS; - const float eps = 1e-6f; // TODO: make this a parameter + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { @@ -12139,6 +12147,7 @@ static void ggml_compute_forward_soft_max_back_f32( // dx = J * dy // dxk = sum_i(Jki * dyi) // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk + // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk // dxk = sum_i(-yk*yi * dyi) + yk*dyk // dxk = -yk * sum_i(yi * dyi) + yk*dyk // dxk = -yk * dot(y, dy) + yk*dyk @@ -13929,7 +13938,7 @@ static void ggml_compute_forward_flash_attn_f32( vvexpf(S, S, &Mup); ggml_vec_sum_f32(Mup, &sum, S); #else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt); ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { @@ -13939,9 +13948,13 @@ static void ggml_compute_forward_flash_attn_f32( if (SS[j] == -INFINITY) { SS[j] = 0.0f; } else { +#ifndef GGML_FLASH_ATTN_EXP_FP16 + const float val = expf(SS[j] - max); +#else ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); memcpy(&scvt[j], &s, sizeof(uint16_t)); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); +#endif sump[j] += (ggml_float)val; SS[j] = val; } @@ -14519,7 +14532,7 @@ static void ggml_compute_forward_flash_attn_back_f32( vvexpf(SM, SM, &Mup); ggml_vec_sum_f32(Mup, &sum, SM); #else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt); ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { @@ -14530,9 +14543,13 @@ static void ggml_compute_forward_flash_attn_back_f32( if (SR[j] == -INFINITY) { SW[j] = 0.0f; } else { +#ifndef GGML_FLASH_ATTN_EXP_FP16 + const float val = expf(SR[j] - max); +#else ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max); memcpy(&scvt[j], &s, sizeof(uint16_t)); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); +#endif sump[j] += (ggml_float)val; SW[j] = val; } @@ -15270,6 +15287,8 @@ static void ggml_compute_forward_cross_entropy_loss_f32( const int nc = src0->ne[0]; const int nr = ggml_nrows(src0); + GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc)); + if (params->type == GGML_TASK_INIT) { if (ith == 0) { memset(sums, 0, sizeof(float) * (nth + nth * nc)); @@ -15281,7 +15300,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32( if (ith == 0) { float * dp = (float *) dst->data; ggml_vec_sum_f32(nth, dp, sums); - dp[0] *= -1.0f; + dp[0] *= -1.0f / (float) nr; } return; } @@ -15298,7 +15317,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32( for (int i1 = ir0; i1 < ir1; i1++) { float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); - float * st = (float *) params->wdata + nth + ith*nc; + float * st = ((float *) params->wdata) + nth + ith*nc; #ifndef NDEBUG for (int i = 0; i < nc; ++i) { @@ -15313,15 +15332,19 @@ static void ggml_compute_forward_cross_entropy_loss_f32( float max = -INFINITY; ggml_vec_max_f32(nc, &max, s0); - uint16_t scvt; + uint16_t scvt; UNUSED(scvt); for (int i = 0; i < nc; i++) { if (s0[i] == -INFINITY) { st[i] = 0.0f; } else { - // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max); +#ifndef GGML_CROSS_ENTROPY_EXP_FP16 + const float s = s0[i] - max; + const float val = expf(s); +#else ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max); memcpy(&scvt, &s, sizeof(scvt)); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); +#endif sum += (ggml_float)val; st[i] = val; } @@ -15337,7 +15360,9 @@ static void ggml_compute_forward_cross_entropy_loss_f32( ggml_vec_log_f32(nc, st, st); ggml_vec_mul_f32(nc, st, st, s1); - ggml_vec_sum_f32(nc, sums + ith, st); + float st_sum = 0; + ggml_vec_sum_f32(nc, &st_sum, st); + sums[ith] += st_sum; #ifndef NDEBUG for (int i = 0; i < nc; ++i) { @@ -15387,7 +15412,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( return; } - const float eps = 1e-9f; + const double eps = 1e-9; // TODO: handle transposed/permuted matrices const int64_t nc = src0->ne[0]; @@ -15406,7 +15431,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); - float * sm = (float *) params->wdata + ith*nc; #ifndef NDEBUG for (int i = 0; i < nc; ++i) { @@ -15415,54 +15439,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( assert(!isnan(s1[i])); } #endif - // step by step explanation: - { - //float * sums = (float *) params->wdata; - - // forward pass with annotated gradients from backward pass - // (built by going in reverse operation order, adding to gradients of current operation args) - // st0 = exp(s0-max(s0)) grad[st0] = grad[st1]*(1.0 - eps)/sum - // from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1])) - // ggml_vec_scale_f32(nc, st, sum); // st1 = st0*/sum = softmax(s0) grad[st1] = grad[st2]*(1.0 - eps) - // ggml_vec_scale_f32(nc, st, (1.0f - eps)); // st2 = st1*(1.0 - eps) grad[st2] = grad[st3] - // ggml_vec_add1_f32(nc, st, st, eps); // st3 = st2 + eps grad[st3] = grad[st4]/st3 - // ggml_vec_log_f32(nc, st, st); // st4 = log(st3) grad[st4] = grad[st5] * s1 - // ggml_vec_mul_f32(nc, st, st, s1); // st5 = st4 * s1 grad[st5] = grad[sums[ith]] - // ggml_vec_sum_f32(nc, sums + ith, st); // sums[ith] = st5 grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel] - - // substitute into grad[st1], because we can reuse softmax_back from this point on - // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps)) - // postorder: - // grad[st1] := softmax(s0) - // grad[st1] := grad[st1]*(1.0 - eps) - // grad[st1] := grad[st1] + eps - // grad[st1] := s1 / grad[st1] - // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel] - - // src0 gradients by going through softmax_back - // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1])) - // from softmax_back: - // dxk = yk * (dyk - dot(y, dy)) - // dot_y_dy := dot(y, dy) - // dx := dy - // dx := dx - dot_y_dy - // dx := dx * y - // postorder: - // dot_st1_dst1 := dot(st1, grad[st1]) - // grad[s0] := grad[st1] - // grad[s0] := grad[s0] - dot_st1_dst1 - // grad[s0] := grad[s0] * st1 - - // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1] - // sm := softmax(s0) - // grad[s0] := sm*(1.0 - eps) - // grad[s0] := grad[s0] + eps - // grad[s0] := s1 / grad[s0] - // grad[s0] := grad[s0]*(1.0-eps)*-grad[cel] - // dot_st1_dst1 := dot(sm, grad[s0]) - // grad[s0] := grad[s0] - dot_st1_dst1 - // grad[s0] := grad[s0] * sm - } // soft_max ggml_float sum = 0.0; @@ -15470,39 +15446,37 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( float max = -INFINITY; ggml_vec_max_f32(nc, &max, s0); - uint16_t scvt; + uint16_t scvt; UNUSED(scvt); for (int i = 0; i < nc; i++) { if (s0[i] == -INFINITY) { - sm[i] = 0.0f; + ds0[i] = 0.0f; } else { - // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max); +#ifndef GGML_CROSS_ENTROPY_EXP_FP16 + const float s = s0[i] - max; + const float val = expf(s); +#else ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max); memcpy(&scvt, &s, sizeof(scvt)); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); +#endif sum += (ggml_float)val; - sm[i] = val; + ds0[i] = val; } } assert(sum > 0.0); - sum = 1.0/sum; + sum = (1.0 - eps)/sum; } - float dot_st1_dst1 = 0; - ggml_vec_scale_f32(nc, sm, sum); - ggml_vec_cpy_f32 (nc, ds0, sm); - ggml_vec_scale_f32(nc, ds0, (1.0f - eps)); - ggml_vec_add1_f32 (nc, ds0, ds0, eps); - ggml_vec_div_f32 (nc, ds0, s1, ds0); - ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]); - ggml_vec_dot_f32 (nc, &dot_st1_dst1, sm, ds0); - ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1); - ggml_vec_mul_f32 (nc, ds0, ds0, sm); + // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr + ggml_vec_scale_f32(nc, ds0, sum); + ggml_vec_add1_f32(nc, ds0, ds0, eps); + ggml_vec_sub_f32(nc, ds0, ds0, s1); + ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr); + #ifndef NDEBUG for (int i = 0; i < nc; ++i) { - assert(!isnan(sm[i])); - assert(!isinf(sm[i])); assert(!isnan(ds0[i])); assert(!isinf(ds0[i])); } @@ -16057,9 +16031,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { + float eps; + memcpy(&eps, tensor->op_params, sizeof(float)); + src0->grad = ggml_add_impl(ctx, src0->grad, - ggml_rms_norm_back(ctx, src0, tensor->grad), + ggml_rms_norm_back(ctx, src0, tensor->grad, eps), inplace); } } break; @@ -16827,9 +16804,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) { return result; } -struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) { - struct ggml_cgraph result = *gf; - +void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) { GGML_ASSERT(gf->n_nodes > 0); // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph @@ -16853,15 +16828,19 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg } } - for (int i = gf->n_nodes - 1; i >= 0; i--) { + for (int i = 0; i < gf->n_nodes; i++) { struct ggml_tensor * node = gf->nodes[i]; if (node->is_param) { GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - ggml_build_forward_expand(&result, node->grad); + ggml_build_forward_expand(gb, node->grad); } } +} +struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) { + struct ggml_cgraph result = *gf; + ggml_build_backward_expand(ctx, gf, &result, keep); return result; } @@ -17537,10 +17516,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { case GGML_OP_CROSS_ENTROPY_LOSS_BACK: { n_tasks = n_threads; - - size_t cur = ggml_type_size(node->type)*node->src[0]->ne[0]*n_tasks; - - work_size = MAX(work_size, cur); } break; case GGML_OP_NONE: { @@ -18418,14 +18393,16 @@ static enum ggml_opt_result ggml_opt_adam( struct ggml_opt_params params, struct ggml_tensor * f, struct ggml_cgraph * gf, - struct ggml_cgraph * gb) { + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data) { GGML_ASSERT(ggml_is_scalar(f)); // these will store the parameters we want to optimize struct ggml_tensor * ps[GGML_MAX_PARAMS]; int np = 0; - int nx = 0; + int64_t nx = 0; for (int i = 0; i < gf->n_nodes; ++i) { if (gf->nodes[i]->is_param) { GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); @@ -18444,31 +18421,32 @@ static enum ggml_opt_result ggml_opt_adam( } // constants - const float sched = params.adam.sched; - const float decay = params.adam.decay * sched; - const float alpha = params.adam.alpha * sched; + float sched = params.adam.sched; + const float alpha = params.adam.alpha; + const float decay = params.adam.decay * alpha; const float beta1 = params.adam.beta1; const float beta2 = params.adam.beta2; const float eps = params.adam.eps; + const float gclip = params.adam.gclip; + const int decay_min_ndim = params.adam.decay_min_ndim; - float * x = opt->adam.x->data; // view of the parameters - float * g1 = opt->adam.g1->data; // gradient - float * g2 = opt->adam.g2->data; // gradient squared float * m = opt->adam.m->data; // first moment float * v = opt->adam.v->data; // second moment - float * mh = opt->adam.mh->data; // first moment hat - float * vh = opt->adam.vh->data; // second moment hat float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values - // update view - ggml_opt_get_params(np, ps, x); + if (callback) { + callback(callback_data, &sched); + } // compute the function value ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute_with_ctx(ctx, gb, params.n_threads); + struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); + cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; + ggml_graph_compute(gb, &cplan); opt->adam.fx_prev = ggml_get_f32_1d(f, 0); opt->adam.fx_best = opt->adam.fx_prev; @@ -18476,6 +18454,9 @@ static enum ggml_opt_result ggml_opt_adam( pf[opt->iter % params.past] = opt->adam.fx_prev; } + opt->loss_before = opt->adam.fx_prev; + opt->loss_after = opt->adam.fx_prev; + // initialize if (opt->just_initialized) { opt->adam.n_no_improvement = 0; @@ -18508,50 +18489,55 @@ static enum ggml_opt_result ggml_opt_adam( UNUSED(t_start_cpu); { - // update the gradient - ggml_opt_get_grad(np, ps, g1); + float gnorm = 1.0f; + if (gclip > 0.0f) { + // gradient clipping + ggml_float sum = 0.0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]); + for (int64_t j = 0; j < ne; ++j) { + float g = ggml_get_f32_1d(ps[p]->grad, j); + sum += (ggml_float)(g*g); + } + } + ggml_float norm = sqrt(sum); + if (norm > (ggml_float) gclip) { + gnorm = (float) ((ggml_float) gclip / norm); + } + } + const float beta1h = alpha*sched/(1.0f - powf(beta1, opt->iter)); + const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter)); + int64_t i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]); + const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched; + for (int64_t j = 0; j < ne; ++j) { + float x = ggml_get_f32_1d(ps[p], j); + float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm; + m[i] = m[i]*beta1 + g*(1.0f - beta1); + v[i] = v[i]*beta2 + g*g*(1.0f - beta2); + float mh = m[i]*beta1h; + float vh = v[i]*beta2h; + vh = sqrtf(vh) + eps; + x = x*(1.0f - p_decay) - mh/vh; + ggml_set_f32_1d(ps[p], j, x); + ++i; + } + } + } - // m_t = beta1*m_t-1 + (1 - beta1)*g_t - ggml_vec_scale_f32(nx, m, beta1); - ggml_vec_mad_f32 (nx, m, g1, 1.0f - beta1); - - // g2 = g1^2 - ggml_vec_sqr_f32 (nx, g2, g1); - - // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2 - ggml_vec_scale_f32(nx, v, beta2); - ggml_vec_mad_f32 (nx, v, g2, 1.0f - beta2); - - // m^hat = m_t / (1 - beta1^t) - // v^hat = v_t / (1 - beta2^t) - // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1) - // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1 - // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps) - // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps) - // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay) - ggml_vec_cpy_f32 (nx, mh, m); - ggml_vec_cpy_f32 (nx, vh, v); - - ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter))); - ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter))); - - ggml_vec_sqrt_f32 (nx, vh, vh); - ggml_vec_acc1_f32 (nx, vh, eps); - - ggml_vec_div_f32 (nx, mh, mh, vh); - ggml_vec_scale_f32(nx, x, 1.0f - decay); - ggml_vec_sub_f32 (nx, x, x, mh); - - // update the parameters - ggml_opt_set_params(np, ps, x); + if (callback) { + callback(callback_data, &sched); } ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute_with_ctx(ctx, gb, params.n_threads); + ggml_graph_compute(gb, &cplan); const float fx = ggml_get_f32_1d(f, 0); + opt->loss_after = fx; + // check convergence if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) { @@ -18620,7 +18606,6 @@ struct ggml_lbfgs_iteration_data { }; static enum ggml_opt_result linesearch_backtracking( - struct ggml_context * ctx, const struct ggml_opt_params * params, int nx, float * x, @@ -18632,8 +18617,11 @@ static enum ggml_opt_result linesearch_backtracking( struct ggml_tensor * f, struct ggml_cgraph * gf, struct ggml_cgraph * gb, + struct ggml_cplan * cplan, const int np, - struct ggml_tensor * ps[]) { + struct ggml_tensor * ps[], + ggml_opt_callback callback, + void * callback_data) { int count = 0; float width = 0.0f; @@ -18662,6 +18650,12 @@ static enum ggml_opt_result linesearch_backtracking( dgtest = params->lbfgs.ftol*dginit; while (true) { + if (callback) { + // LBFG-S does not support learning rate -> ignore learning schedule + float sched = 0; + callback(callback_data, &sched); + } + ggml_vec_cpy_f32(nx, x, xp); ggml_vec_mad_f32(nx, x, d, *step); @@ -18672,7 +18666,7 @@ static enum ggml_opt_result linesearch_backtracking( ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute_with_ctx(ctx, gb, params->n_threads); + ggml_graph_compute(gb, cplan); ggml_opt_get_grad(np, ps, g); @@ -18732,7 +18726,9 @@ static enum ggml_opt_result ggml_opt_lbfgs( struct ggml_opt_params params, struct ggml_tensor * f, struct ggml_cgraph * gf, - struct ggml_cgraph * gb) { + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data) { if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE || params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) { @@ -18764,6 +18760,10 @@ static enum ggml_opt_result ggml_opt_lbfgs( opt->iter = iter; } + struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); + cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; + float * x = opt->lbfgs.x->data; // current parameters float * xp = opt->lbfgs.xp->data; // previous parameters float * g = opt->lbfgs.g->data; // current gradient @@ -18785,6 +18785,12 @@ static enum ggml_opt_result ggml_opt_lbfgs( float * lm_s = opt->lbfgs.lms->data; float * lm_y = opt->lbfgs.lmy->data; + if (callback) { + // LBFG-S does not support learning rate -> ignore learning schedule + float sched = 0; + callback(callback_data, &sched); + } + // evaluate the function value and its gradient { ggml_opt_set_params(np, ps, x); @@ -18792,11 +18798,14 @@ static enum ggml_opt_result ggml_opt_lbfgs( ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute_with_ctx(ctx, gb, params.n_threads); + ggml_graph_compute(gb, &cplan); ggml_opt_get_grad(np, ps, g); fx = ggml_get_f32_1d(f, 0); + + opt->loss_before = fx; + opt->loss_after = fx; } // search direction = -gradient @@ -18851,7 +18860,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( ggml_vec_cpy_f32(nx, xp, x); ggml_vec_cpy_f32(nx, gp, g); - ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps); + ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, &cplan, np, ps, callback, callback_data); if (ls < 0) { // linesearch failed - go back to the previous point and return @@ -18861,6 +18870,8 @@ static enum ggml_opt_result ggml_opt_lbfgs( return ls; } + opt->loss_after = fx; + ggml_vec_norm_f32(nx, &xnorm, x); ggml_vec_norm_f32(nx, &gnorm, g); @@ -18918,7 +18929,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( // ys = y^t \cdot s -> 1 / \rho. // yy = y^t \cdot y. // - ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]); + ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]); ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]); lm_ys[end[0]] = ys; @@ -18981,13 +18992,15 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { .adam = { .n_iter = 10000, .sched = 1.000f, - .decay = 0.001f, + .decay = 0.0f, + .decay_min_ndim = 2, .alpha = 0.001f, .beta1 = 0.9f, .beta2 = 0.999f, .eps = 1e-8f, .eps_f = 1e-5f, .eps_g = 1e-3f, + .gclip = 0.0f, }, }; } break; @@ -19037,23 +19050,13 @@ GGML_API void ggml_opt_init( switch (opt->params.type) { case GGML_OPT_ADAM: { - opt->adam.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); opt->adam.pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past) : NULL; - ggml_set_zero(opt->adam.x); - ggml_set_zero(opt->adam.g1); - ggml_set_zero(opt->adam.g2); ggml_set_zero(opt->adam.m); ggml_set_zero(opt->adam.v); - ggml_set_zero(opt->adam.mh); - ggml_set_zero(opt->adam.vh); if (opt->adam.pf) { ggml_set_zero(opt->adam.pf); } @@ -19137,7 +19140,7 @@ enum ggml_opt_result ggml_opt_resume( *gf = ggml_build_forward (f); *gb = ggml_build_backward(ctx, gf, true); - return ggml_opt_resume_g(ctx, opt, f, gf, gb); + return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL); } enum ggml_opt_result ggml_opt_resume_g( @@ -19145,7 +19148,9 @@ enum ggml_opt_result ggml_opt_resume_g( struct ggml_opt_context * opt, struct ggml_tensor * f, struct ggml_cgraph * gf, - struct ggml_cgraph * gb) { + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data) { // build forward + backward compute graphs enum ggml_opt_result result = GGML_OPT_OK; @@ -19153,11 +19158,11 @@ enum ggml_opt_result ggml_opt_resume_g( switch (opt->params.type) { case GGML_OPT_ADAM: { - result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb); + result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb, callback, callback_data); } break; case GGML_OPT_LBFGS: { - result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb); + result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb, callback, callback_data); } break; } @@ -19612,7 +19617,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p // read the kv pairs { - ctx->kv = GGML_ALIGNED_MALLOC(ctx->header.n_kv * sizeof(struct gguf_kv)); + ctx->kv = malloc(ctx->header.n_kv * sizeof(struct gguf_kv)); for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { struct gguf_kv * kv = &ctx->kv[i]; @@ -19695,7 +19700,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p // read the tensor infos { - ctx->infos = GGML_ALIGNED_MALLOC(ctx->header.n_tensors * sizeof(struct gguf_tensor_info)); + ctx->infos = malloc(ctx->header.n_tensors * sizeof(struct gguf_tensor_info)); for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { struct gguf_tensor_info * info = &ctx->infos[i]; @@ -19896,7 +19901,7 @@ void gguf_free(struct gguf_context * ctx) { } } - GGML_ALIGNED_FREE(ctx->kv); + free(ctx->kv); } if (ctx->infos) { @@ -19908,7 +19913,7 @@ void gguf_free(struct gguf_context * ctx) { } } - GGML_ALIGNED_FREE(ctx->infos); + free(ctx->infos); } GGML_ALIGNED_FREE(ctx); diff --git a/ggml.h b/ggml.h index 4ef3d5253..8b410cc85 100644 --- a/ggml.h +++ b/ggml.h @@ -952,11 +952,11 @@ extern "C" { // a - x // b - dy - // TODO: update with configurable eps GGML_API struct ggml_tensor * ggml_rms_norm_back( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b, + float eps); // A: n columns, m rows // B: n columns, p rows (i.e. we transpose it internally) @@ -1612,7 +1612,8 @@ extern "C" { struct ggml_tensor * tensor); - GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); + GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); + GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep); GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor); GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep); @@ -1677,6 +1678,8 @@ extern "C" { GGML_LINESEARCH_INVALID_PARAMETERS, }; + typedef void (*ggml_opt_callback)(void * data, float * sched); + // optimization parameters // // see ggml.c (ggml_opt_default_params) for default values @@ -1712,12 +1715,14 @@ extern "C" { float sched; // schedule multiplier (fixed, decay or warmup) float decay; // weight decay for AdamW, use 0.0f to disable + int decay_min_ndim; // minimum number of tensor dimension to apply weight decay float alpha; // learning rate float beta1; float beta2; float eps; // epsilon for numerical stability float eps_f; // epsilon for convergence test float eps_g; // epsilon for convergence test + float gclip; // gradient clipping } adam; // LBFGS parameters @@ -1745,14 +1750,12 @@ extern "C" { bool just_initialized; + float loss_before; + float loss_after; + struct { - struct ggml_tensor * x; // view of the parameters - struct ggml_tensor * g1; // gradient - struct ggml_tensor * g2; // gradient squared struct ggml_tensor * m; // first moment struct ggml_tensor * v; // second moment - struct ggml_tensor * mh; // first moment hat - struct ggml_tensor * vh; // second moment hat struct ggml_tensor * pf; // past function values float fx_best; float fx_prev; @@ -1789,10 +1792,10 @@ extern "C" { // initialize optimizer context GGML_API void ggml_opt_init( - struct ggml_context * ctx, + struct ggml_context * ctx, struct ggml_opt_context * opt, - struct ggml_opt_params params, - int64_t nx); + struct ggml_opt_params params, + int64_t nx); // continue optimizing the function defined by the tensor f GGML_API enum ggml_opt_result ggml_opt_resume( @@ -1806,7 +1809,9 @@ extern "C" { struct ggml_opt_context * opt, struct ggml_tensor * f, struct ggml_cgraph * gf, - struct ggml_cgraph * gb); + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data); // // quantization diff --git a/llama.cpp b/llama.cpp index 11697ee65..7cb468538 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6248,7 +6248,6 @@ const char * llama_print_system_info(void) { } void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) { - fprintf(stream, "\n"); fprintf(stream, "###########\n"); fprintf(stream, "# Timings #\n"); @@ -6264,10 +6263,10 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) { fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->n_sample); - fprintf(stream, "t_eval_us: %ld # total microseconds spent generating tokens\n", ctx->t_eval_us); - fprintf(stream, "t_load_us: %ld # total microseconds spent loading the model\n", ctx->t_load_us); - fprintf(stream, "t_p_eval_us: %ld # total microseconds spent prompt processing\n", ctx->t_p_eval_us); - fprintf(stream, "t_sample_us: %ld # total microseconds spent sampling\n", ctx->t_sample_us); + fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); + fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); + fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); + fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->t_sample_us); fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n", 1.0e6 * ctx->n_eval / ctx->t_eval_us); fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp index 75a698d73..468cde66a 100644 --- a/tests/test-grad0.cpp +++ b/tests/test-grad0.cpp @@ -275,14 +275,14 @@ static bool check_gradient( ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); - const float f0 = ggml_get_f32_1d(f, 0); + const double f0 = ggml_get_f32_1d(f, 0); ggml_set_f32_1d(x[i], k, xm); ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); - const float f1 = ggml_get_f32_1d(f, 0); - const float g0 = (f0 - f1)/(2.0f*eps); + const double f1 = ggml_get_f32_1d(f, 0); + const double g0 = (f0 - f1)/(2.0*(double) eps); ggml_set_f32_1d(x[i], k, x0); @@ -292,10 +292,10 @@ static bool check_gradient( ggml_graph_compute_with_ctx(ctx0, &gb, n_threads); - const float g1 = ggml_get_f32_1d(x[i]->grad, k); + const double g1 = ggml_get_f32_1d(x[i]->grad, k); - const float error_abs = fabsf(g0 - g1); - const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0; + const double error_abs = fabs(g0 - g1); + const double error_rel = g0 != 0 ? fabs(g0 - g1)/fabs(g0) : 0; if (error_abs > max_error_abs || error_rel > max_error_rel) { printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n", @@ -531,7 +531,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0])); - check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f); + check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f); } } @@ -1345,9 +1345,18 @@ int main(int argc, const char ** argv) { x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); ggml_set_param(ctx0, x[0]); - struct ggml_tensor * f = ggml_sum(ctx0, ggml_soft_max(ctx0, x[0])); + float eps = 1e-6f; + // dont use only sum as aggregation, because sum of softmax is always 1 -> finite differences should not work + // instead use sum(log(soft_max()*(1-eps)+eps)); use eps to avoid log(0) + struct ggml_tensor * f = ggml_sum(ctx0, + ggml_log(ctx0, + ggml_add1(ctx0, + ggml_scale(ctx0, + ggml_soft_max(ctx0, x[0]), + ggml_new_f32(ctx0, 1.0f - eps)), + ggml_new_f32(ctx0, eps)))); - check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); + check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY); } } @@ -1358,15 +1367,26 @@ int main(int argc, const char ** argv) { int64_t ne2[4]; get_random_dims(ne2, 4); - for (int ndims = 1; ndims <= 3; ++ndims) { - x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); + for (int ndims = 1; ndims <= 4; ++ndims) { + x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -0.1f, 0.1f); x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f); + // the second argument to cross_entropy_loss must sum up to 1 for each row + int nr = ggml_nrows(x[1]); + int nc = ggml_nelements(x[1]) / nr; + for (int ir = 0; ir < nr; ++ir) { + float sum = 0; + for (int ic = 0; ic < nc; ++ic) { + sum += ((float *) x[1]->data)[ic + ir*nc]; + } + for (int ic = 0; ic < nc; ++ic) { + ((float *) x[1]->data)[ic + ir*nc] /= sum; + } + } ggml_set_param(ctx0, x[0]); - struct ggml_tensor * f = ggml_sum(ctx0, ggml_cross_entropy_loss(ctx0, x[0], x[1])); + struct ggml_tensor * f = ggml_cross_entropy_loss(ctx0, x[0], x[1]); - check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-1f, 1e-2f, INFINITY); - // finite differences regularly fails! + check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-4f, 1e-3f, INFINITY); } } @@ -1473,7 +1493,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); - check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f); + check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY); } } } @@ -1514,7 +1534,7 @@ int main(int argc, const char ** argv) { struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); - check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f); + check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY); } } }