diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 3deff4077..de005f3e3 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -37,6 +37,7 @@ else() add_subdirectory(save-load-state) add_subdirectory(benchmark) add_subdirectory(baby-llama) + add_subdirectory(train-text-from-scratch) if (LLAMA_METAL) add_subdirectory(metal) endif() diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index 5573c154b..e5639da37 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -79,34 +79,39 @@ struct ggml_tensor * randomize_tensor_normal( int ndims, const int64_t ne[], struct random_normal_distribution * rnd) { + float scale = 1.0; // xavier switch (ndims) { case 1: + scale /= sqrtf(ne[0]); for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i0] = frand_normal(rnd); + ((float *)tensor->data)[i0] = scale * frand_normal(rnd); } break; case 2: + scale /= sqrtf(ne[0]+ne[1]); for (int i1 = 0; i1 < ne[1]; i1++) { for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i1*ne[0] + i0] = frand_normal(rnd); + ((float *)tensor->data)[i1*ne[0] + i0] = scale * frand_normal(rnd); } } break; case 3: + scale /= sqrtf(ne[0]+ne[1]); for (int i2 = 0; i2 < ne[2]; i2++) { for (int i1 = 0; i1 < ne[1]; i1++) { for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand_normal(rnd); + ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd); } } } break; case 4: + scale /= sqrtf(ne[0]+ne[1]); for (int i3 = 0; i3 < ne[3]; i3++) { for (int i2 = 0; i2 < ne[2]; i2++) { for (int i1 = 0; i1 < ne[1]; i1++) { for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand_normal(rnd); + ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = scale * frand_normal(rnd); } } } diff --git a/examples/train-text-from-scratch/CMakeLists.txt b/examples/train-text-from-scratch/CMakeLists.txt new file mode 100644 index 000000000..1a44c4961 --- /dev/null +++ b/examples/train-text-from-scratch/CMakeLists.txt @@ -0,0 +1,4 @@ +set(TARGET train-text-from-scratch) +add_executable(${TARGET} train-text-from-scratch.cpp) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/train-text-from-scratch/README.md b/examples/train-text-from-scratch/README.md new file mode 100644 index 000000000..5344d1f52 --- /dev/null +++ b/examples/train-text-from-scratch/README.md @@ -0,0 +1,22 @@ +# train-text-from-scratch + +Basic usage instructions: + +```bash +# get training data +wget https://github.com/brunoklein99/deep-learning-notes/blob/master/shakespeare.txt + +# train +./bin/train-text-from-scratch \ + --vocab-model ../models/ggml-vocab.bin \ + --ctx 64 --embd 256 --head 8 --layer 16 \ + --checkpoint-in chk-shakespeare-256x16.bin \ + --checkpoint-out chk-shakespeare-256x16.bin \ + --model-out ggml-shakespeare-256x16-f32.bin \ + --train-data "shakespeare.txt" \ + -t 6 -b 16 -n 32 --seed 1 --adam-iter 16 \ + --print-details-interval 0 --predict 16 --use-flash + +# predict +./bin/main -m ggml-shakespeare-256x16-f32.bin +``` diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp new file mode 100644 index 000000000..51271b497 --- /dev/null +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -0,0 +1,3399 @@ +#include "ggml.h" +#include "llama.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +struct random_normal_distribution { + std::mt19937 gen; + std::normal_distribution rd; + float min; + float max; +}; + + +struct random_uniform_distribution { + std::mt19937 gen; + std::uniform_real_distribution rd; +}; + +void init_random_normal_distribution(struct random_normal_distribution * rnd, int seed, float mean, float std, float min, float max) { + rnd->gen = std::mt19937(seed); + rnd->rd = std::normal_distribution{mean, std}; + rnd->min = min; + rnd->max = max; +} + +void init_random_uniform_distribution(struct random_uniform_distribution * rnd, int seed, float min, float max) { + rnd->gen = std::mt19937(seed); + rnd->rd = std::uniform_real_distribution{min, max}; +} + +int clamp(const int v, const int min, const int max) { + return ((v < min) ? (min) : (v > max) ? (max) : v); +} + +float fclamp(const float v, const float min, const float max) { + return ((v < min) ? (min) : (v > max) ? (max) : v); +} + +float frand() { + return (float)rand()/(float)RAND_MAX; +} + +float frand_normal(struct random_normal_distribution * rnd) { + return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max); +} + +float frand_uniform(struct random_uniform_distribution * rnd) { + return rnd->rd(rnd->gen); +} + +struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) { + float scale = 1.0f; // xavier + switch (tensor->n_dims) { + case 1: + scale /= sqrtf(tensor->ne[0]); + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]); + *dst = scale * frand_normal(rnd); + } + break; + case 2: + scale /= sqrtf(tensor->ne[0]+tensor->ne[1]); + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); + *dst = scale * frand_normal(rnd); + } + } + break; + case 3: + scale /= sqrtf(tensor->ne[0]+tensor->ne[1]); + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]); + *dst = scale * frand_normal(rnd); + } + } + } + break; + case 4: + scale /= sqrtf(tensor->ne[0]+tensor->ne[1]); + for (int i3 = 0; i3 < tensor->ne[3]; i3++) { + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]); + *dst = scale * frand_normal(rnd); + } + } + } + } + break; + default: + assert(false); + }; + return tensor; +} + +struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) { + switch (tensor->n_dims) { + case 1: + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]); + *dst = frand_uniform(rnd); + } + break; + case 2: + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); + *dst = frand_uniform(rnd); + } + } + break; + case 3: + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]); + *dst = frand_uniform(rnd); + } + } + } + break; + case 4: + for (int i3 = 0; i3 < tensor->ne[3]; i3++) { + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]); + *dst = frand_uniform(rnd); + } + } + } + } + break; + default: + assert(false); + }; + return tensor; +} + +struct llama_vocab { + using id = int32_t; + using token = std::string; + + struct token_score { + token tok; + float score; + }; + + std::unordered_map token_to_id; + std::vector id_to_token; +}; + +struct my_llama_hparams { + uint32_t n_vocab = 32000; + uint32_t n_ctx = 512; // this is provided as user input? + uint32_t n_embd = 4096; + uint32_t n_mult = 4; + uint32_t n_head = 32; + uint32_t n_layer = 32; + uint32_t n_rot = 64; + + bool operator!=(const my_llama_hparams& other) const { + return memcmp(this, &other, sizeof(my_llama_hparams)); + } +}; + +struct my_llama_layer { + // normalization + struct ggml_tensor * attention_norm; + + // attention + struct ggml_tensor * wq; + struct ggml_tensor * wk; + struct ggml_tensor * wv; + struct ggml_tensor * wo; + + // normalization + struct ggml_tensor * ffn_norm; + + // ff + struct ggml_tensor * w1; + struct ggml_tensor * w2; + struct ggml_tensor * w3; +}; + +struct my_llama_kv_cache { + struct ggml_context * ctx = NULL; + + struct ggml_tensor * k; + struct ggml_tensor * v; + + // llama_ctx_buffer buf; + + int n; // number of tokens currently in the cache +}; + +struct my_llama_model { + struct ggml_context * ctx = NULL; + + my_llama_hparams hparams; + + struct ggml_tensor * tok_embeddings; + + struct ggml_tensor * norm; + struct ggml_tensor * output; + + std::vector layers; + + uint32_t train_its = 0; + uint32_t train_samples = 0; + uint32_t train_tokens = 0; +}; + +uint32_t get_n_ff(const struct my_llama_hparams* hparams) { + const uint32_t n_ff = ((2*(4*hparams->n_embd)/3 + hparams->n_mult - 1)/hparams->n_mult)*hparams->n_mult; + return n_ff; +} + +void print_params(struct my_llama_hparams * params) { + printf("%s: n_vocab: %d\n", __func__, params->n_vocab); + printf("%s: n_ctx: %d\n", __func__, params->n_ctx); + printf("%s: n_embd: %d\n", __func__, params->n_embd); + printf("%s: n_mult: %d\n", __func__, params->n_mult); + printf("%s: n_head: %d\n", __func__, params->n_head); + printf("%s: n_ff: %d\n", __func__, get_n_ff(params)); + printf("%s: n_layer: %d\n", __func__, params->n_layer); + printf("%s: n_rot: %d\n", __func__, params->n_rot); +} + +void init_model(struct my_llama_model * model) { + const auto & hparams = model->hparams; + + const uint32_t n_embd = hparams.n_embd; + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_vocab = hparams.n_vocab; + + const uint32_t n_ff = get_n_ff(&hparams); + + struct ggml_context * ctx = model->ctx; + + model->train_its = 0; + model->train_samples = 0; + model->train_tokens = 0; + + model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); + model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); + + ggml_set_name(model->tok_embeddings, "tok_embeddings.weight"); + ggml_set_name(model->norm, "norm.weight"); + ggml_set_name(model->output, "output.weight"); + + model->layers.resize(n_layer); + for (uint32_t i = 0; i < n_layer; ++i) { + auto & layer = model->layers[i]; + + std::string layers_i = "layers." + std::to_string(i); + + layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); + layer.wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); + layer.wv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); + layer.wo = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); + + layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); + layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd); + layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); + + ggml_set_name(layer.attention_norm, (layers_i + ".attention_norm.weight").c_str()); + + ggml_set_name(layer.wq, (layers_i + ".attention.wq.weight").c_str()); + ggml_set_name(layer.wk, (layers_i + ".attention.wk.weight").c_str()); + ggml_set_name(layer.wv, (layers_i + ".attention.wv.weight").c_str()); + ggml_set_name(layer.wo, (layers_i + ".attention.wo.weight").c_str()); + + ggml_set_name(layer.ffn_norm, (layers_i + ".ffn_norm.weight").c_str()); + + // 'layers.10.feed_forward.w1.weight' has length of 32. + // ggml_tensor->name only has 32 characters, but we need one more for the '\0' terminator. + // ggml_set_name will set the last character to '\0', so we can only store 'layers.10.feed_forward.w1.weigh'. + // when saving llama compatible model the tensors names will miss a character. + // ggml_set_name(layer.w1, (layers_i + ".feed_forward.w1.weight").c_str()); + // ggml_set_name(layer.w2, (layers_i + ".feed_forward.w2.weight").c_str()); + // ggml_set_name(layer.w3, (layers_i + ".feed_forward.w3.weight").c_str()); + + strncpy(layer.w1->name, (layers_i + ".feed_forward.w1.weight").c_str(), sizeof(layer.w1->name)); + strncpy(layer.w2->name, (layers_i + ".feed_forward.w2.weight").c_str(), sizeof(layer.w2->name)); + strncpy(layer.w3->name, (layers_i + ".feed_forward.w3.weight").c_str(), sizeof(layer.w3->name)); + layer.w1->padding[0] = 0; + layer.w2->padding[0] = 0; + layer.w3->padding[0] = 0; + } +} + +void set_param_model(struct my_llama_model * model) { + const auto& hparams = model->hparams; + + const uint32_t n_layer = hparams.n_layer; + + struct ggml_context* ctx = model->ctx; + + ggml_set_param(ctx, model->tok_embeddings); + ggml_set_param(ctx, model->norm); + ggml_set_param(ctx, model->output); + + for (uint32_t i = 0; i < n_layer; ++i) { + auto & layer = model->layers[i]; + + ggml_set_param(ctx, layer.attention_norm); + ggml_set_param(ctx, layer.wq); + ggml_set_param(ctx, layer.wk); + ggml_set_param(ctx, layer.wv); + ggml_set_param(ctx, layer.wo); + ggml_set_param(ctx, layer.ffn_norm); + ggml_set_param(ctx, layer.w1); + ggml_set_param(ctx, layer.w2); + ggml_set_param(ctx, layer.w3); + } +} + +void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) { + const auto & hparams = model->hparams; + + const uint32_t n_layer = hparams.n_layer; + + struct random_normal_distribution rnd; + init_random_normal_distribution(&rnd, seed, mean, std, min, max); + + randomize_tensor_normal(model->tok_embeddings, &rnd); + randomize_tensor_normal(model->norm, &rnd); + randomize_tensor_normal(model->output, &rnd); + + for (uint32_t i = 0; i < n_layer; ++i) { + auto & layer = model->layers[i]; + randomize_tensor_normal(layer.attention_norm, &rnd); + + randomize_tensor_normal(layer.wq, &rnd); + randomize_tensor_normal(layer.wk, &rnd); + randomize_tensor_normal(layer.wv, &rnd); + randomize_tensor_normal(layer.wo, &rnd); + + randomize_tensor_normal(layer.ffn_norm, &rnd); + + randomize_tensor_normal(layer.w1, &rnd); + randomize_tensor_normal(layer.w2, &rnd); + randomize_tensor_normal(layer.w3, &rnd); + } +} + +bool init_kv_cache(struct my_llama_kv_cache* cache, struct my_llama_model * model, int n_batch) { + const auto & hparams = model->hparams; + + const uint32_t n_ctx = hparams.n_ctx; + const uint32_t n_embd = hparams.n_embd; + const uint32_t n_layer = hparams.n_layer; + + const int64_t n_mem = n_layer*n_ctx*n_batch; + const int64_t n_elements = n_embd*n_mem; + + // cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); + + // struct ggml_init_params params; + // params.mem_size = cache.buf.size; + // params.mem_buffer = cache.buf.addr; + // params.no_alloc = false; + if (!cache->ctx) { + struct ggml_init_params params; + params.mem_size = 2u*n_elements*ggml_type_size(GGML_TYPE_F32) + 2u*1024*1024; + params.mem_buffer = NULL; + params.no_alloc = false; + + cache->ctx = ggml_init(params); + + if (!cache->ctx) { + fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + } + + cache->k = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements); + cache->v = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements); + + return true; +} + +struct ggml_tensor * forward( + struct my_llama_model * model, + struct my_llama_kv_cache * cache, + struct ggml_context * ctx0, + struct ggml_cgraph * gf, + struct ggml_tensor * tokens_input, + const int n_tokens, + const int n_past) { + + const int N = n_tokens; + + struct my_llama_kv_cache& kv_self = *cache; + const auto & hparams = model->hparams; + const int n_ctx = hparams.n_ctx; + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_head = hparams.n_head; + const int n_rot = hparams.n_rot; + + struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(tokens->data, tokens_input->data, N*ggml_element_size(tokens)); + + struct ggml_tensor * kc = kv_self.k; + struct ggml_tensor * vc = kv_self.v; + + // inpL shape [n_embd,N,1,1] + struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + struct ggml_tensor * cur; + + // lctx.use_buf(ctx0, 0); + + // norm + { + // cur shape [n_embd,N,1,1] + cur = ggml_rms_norm(ctx0, inpL); + + // cur = attention_norm*cur + cur = ggml_mul(ctx0, + ggml_repeat(ctx0, model->layers[il].attention_norm, cur), + cur); + } + + // self-attention + { + // compute Q and K and RoPE them + // wq shape [n_embd, n_embd, 1, 1] + // wk shape [n_embd, n_embd, 1, 1] + // Qcur shape [n_embd/n_head, n_head, N, 1] + // Kcur shape [n_embd/n_head, n_head, N, 1] + struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0); + struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0); + + // store key and value to memory + { + // compute the transposed [N, n_embd] V matrix + // wv shape [n_embd, n_embd, 1, 1] + // Vcur shape [n_embd, N, 1, 1] + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wv, cur), n_embd, N))); + + // kv_self.k shape [n_embd * n_ctx * n_layer, 1] + // kv_self.v shape [n_embd * n_ctx * n_layer, 1] + // k shape [n_embd * N, 1] == kv_self.k[:,n_past:n_past+N,il,0] + // v shape [N, n_embd, 1, 1] == kv_self.v[:,n_past:n_past+N,il,0] + + /* { + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); + + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } //*/ + + kc = ggml_set_1d_inplace(ctx0, kc, ggml_reshape_1d(ctx0, Kcur, n_embd*N), (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); + vc = ggml_set_2d_inplace(ctx0, vc, Vcur, ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); + } + + // Qcur shape [n_embd/n_head, n_head, N, 1] + // Q shape [n_embd/n_head, N, n_head, 1] + struct ggml_tensor * Q = + ggml_permute(ctx0, + Qcur, + 0, 2, 1, 3); + + // kv_self.k shape [n_embd * n_ctx * n_layer, 1] + // K shape [n_embd/n_head, n_past + N, n_head, 1] + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, kc, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kc)*n_embd), + n_embd/n_head, n_head, n_past + N), + 0, 2, 1, 3); + + // K * Q + // KQ shape [n_past + N, N, n_head, 1] + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + // KQ_scaled shape [n_past + N, N, n_head, 1] + struct ggml_tensor * KQ_scaled = + ggml_scale(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head))); + + // KQ_masked = mask_past(KQ_scaled) + // KQ_masked shape [n_past + N, N, n_head, 1] + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + // KQ_soft_max shape [n_past + N, N, n_head, 1] + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + + // split cached V into n_head heads + //// V shape [n_past + N, n_embd/n_head, n_head, 1] + // V shape [n_past + N, n_embd/n_head, n_head, 1] == kv_self.v[:,:(n_past+N),il,1] + struct ggml_tensor * V = + ggml_view_3d(ctx0, vc, + n_past + N, n_embd/n_head, n_head, + n_ctx*ggml_element_size(vc), + n_ctx*ggml_element_size(vc)*n_embd/n_head, + il*n_ctx*ggml_element_size(vc)*n_embd); + + // KQV shape [n_embd/n_head, N, n_head, 1] + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + // KQV_merged shape [n_embd/n_head, n_head, N, 1] + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + // KQV_merged shape + + // cur = KQV_merged.contiguous().view(n_embd, N) + // cur shape [n_embd,N,1,1] + cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N); + // cur = ggml_cpy(ctx0, + // KQV_merged, + // ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + // projection (no bias) + // cur shape [n_embd,N,1,1] + cur = ggml_mul_mat(ctx0, + model->layers[il].wo, + cur); + } + + // lctx.use_buf(ctx0, 1); + + // inpFF shape [n_embd,N,1,1] + struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); + + // feed-forward network + { + // norm + { + // cur shape [n_embd,N,1,1] + cur = ggml_rms_norm(ctx0, inpFF); + + // cur = ffn_norm*cur + // cur shape [n_embd,N,1,1] + cur = ggml_mul(ctx0, + ggml_repeat(ctx0, model->layers[il].ffn_norm, cur), + cur); + } + + // tmp shape [n_ff,N,1,1] + struct ggml_tensor * tmp = ggml_mul_mat(ctx0, + model->layers[il].w3, + cur); + + // cur shape [n_ff,N,1,1] + cur = ggml_mul_mat(ctx0, + model->layers[il].w1, + cur); + + // SILU activation + // cur shape [n_ff,N,1,1] + cur = ggml_silu(ctx0, cur); + + // cur shape [n_ff,N,1,1] + cur = ggml_mul(ctx0, cur, tmp); + + // cur shape [n_embd,N,1,1] + cur = ggml_mul_mat(ctx0, + model->layers[il].w2, + cur); + } + + // cur shape [n_embd,N,1,1] + cur = ggml_add(ctx0, cur, inpFF); + + // input for next layer + // inpL shape [n_embd,N,1,1] + inpL = cur; + } + + // norm + { + + // inpL shape [n_embd,N,1,1] + inpL = ggml_rms_norm(ctx0, inpL); + + // inpL = norm*inpL + // inpL shape [n_embd,N,1,1] + inpL = ggml_mul(ctx0, + ggml_repeat(ctx0, model->norm, inpL), + inpL); + + //embeddings = inpL; + } + + // lm_head + // inpL shape [n_vocab,N,1,1] + inpL = ggml_mul_mat(ctx0, model->output, inpL); + + // run the computation + ggml_build_forward_expand(gf, inpL); + + return inpL; +} + +void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) { + GGML_ASSERT(tensor->n_dims == 1); + GGML_ASSERT(tensor->ne[0] == ne0); +} + +void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) { + GGML_ASSERT(tensor->n_dims == 2); + GGML_ASSERT(tensor->ne[0] == ne0); + GGML_ASSERT(tensor->ne[1] == ne1); +} + +void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) { + GGML_ASSERT(tensor->n_dims == 3); + GGML_ASSERT(tensor->ne[0] == ne0); + GGML_ASSERT(tensor->ne[1] == ne1); + GGML_ASSERT(tensor->ne[2] == ne2); +} + +void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) { + GGML_ASSERT(tensor->n_dims == 4); + GGML_ASSERT(tensor->ne[0] == ne0); + GGML_ASSERT(tensor->ne[1] == ne1); + GGML_ASSERT(tensor->ne[2] == ne2); + GGML_ASSERT(tensor->ne[3] == ne3); +} + +struct ggml_tensor * forward_batch( + struct my_llama_model * model, + struct my_llama_kv_cache * cache, + struct ggml_context * ctx0, + struct ggml_cgraph * gf, + struct ggml_tensor * tokens_input, + const int n_tokens, + const int n_past, + const int n_batch) { + + const int N = n_tokens; + + struct my_llama_kv_cache& kv_self = *cache; + const auto & hparams = model->hparams; + const int n_ctx = hparams.n_ctx; + const int n_vocab = hparams.n_vocab; + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_head = hparams.n_head; + const int n_rot = hparams.n_rot; + const int n_ff = get_n_ff(&hparams); + + struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); + memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch); + + struct ggml_tensor * kc = kv_self.k; + struct ggml_tensor * vc = kv_self.v; + + // inpL shape [n_embd,N*n_batch,1] + struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); + assert_shape_2d(inpL, n_embd, N*n_batch); + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + struct ggml_tensor * cur; + + // lctx.use_buf(ctx0, 0); + + // norm + { + // cur shape [n_embd,N*n_batch,1,1] + cur = ggml_rms_norm(ctx0, inpL); + assert_shape_2d(cur, n_embd, N*n_batch); + + // cur = attention_norm*cur + cur = ggml_mul(ctx0, + ggml_repeat(ctx0, model->layers[il].attention_norm, cur), + cur); + assert_shape_2d(cur, n_embd, N*n_batch); + } + + // self-attention + { + // compute Q and K and RoPE them + // wq shape [n_embd, n_embd, 1, 1] + // wk shape [n_embd, n_embd, 1, 1] + // Qcur shape [n_embd/n_head, n_head, N, n_batch] + // Kcur shape [n_embd/n_head, n_head, N, n_batch] + struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0); + struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0); + assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch); + assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch); + + // store key and value to memory + { + // compute the transposed [N, n_embd] V matrix + // wv shape [n_embd, n_embd, 1, 1] + // Vcur shape [N, n_embd, n_batch, 1] + struct ggml_tensor * Vcur = ggml_cont(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_mul_mat(ctx0, + model->layers[il].wv, + cur), + n_embd, N, n_batch), + 1, 0, 2, 3)); + assert_shape_3d(Vcur, N, n_embd, n_batch); + + // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer] + // kv_self.v shape [n_ctx * n_embd * n_batch * n_layer] + // k shape [n_embd * N, n_batch] == kv_self.k[:,n_past:n_past+N,:,il] + // v shape [N, n_embd, n_batch, 1] == kv_self.v[:,n_past:n_past+N,:,il] + + /* { + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); + + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } //*/ + + kc = ggml_set_2d_inplace(ctx0, kc, + ggml_reshape_2d(ctx0, Kcur, n_embd*N, n_batch), + ggml_element_size(kc)*n_embd*n_ctx, + (ggml_element_size(kc)*n_embd)*(il*n_batch*n_ctx + n_past)); + vc = ggml_set_2d_inplace(ctx0, vc, + ggml_reshape_2d(ctx0, Vcur, N*n_embd, n_batch), + ggml_element_size(vc)*n_ctx*n_embd, + ggml_element_size(vc)*(n_past + il*n_embd*n_batch*n_ctx)); + + assert_shape_1d(kc, n_embd * n_ctx * n_batch * n_layer); + assert_shape_1d(vc, n_embd * n_ctx * n_batch * n_layer); + } + + // Qcur shape [n_embd/n_head, n_head, N, n_batch] + // Q shape [n_embd/n_head, N, n_head, n_batch] + struct ggml_tensor * Q = + ggml_permute(ctx0, + Qcur, + 0, 2, 1, 3); + assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch); + + // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer] + // K shape [n_embd/n_head, n_past + N, n_head, n_batch] + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_reshape_4d(ctx0, + ggml_view_3d(ctx0, + kc, + n_embd, + (n_past + N), + n_batch, + n_embd*ggml_element_size(kc), + n_ctx*n_embd*ggml_element_size(kc), + il*n_batch*n_ctx*n_embd*ggml_element_size(kc)), + n_embd/n_head, n_head, n_past + N, n_batch), + 0, 2, 1, 3); + assert_shape_4d(K, n_embd/n_head, n_past + N, n_head, n_batch); + + // K * Q + // KQ shape [n_past + N, N, n_head, n_batch] + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + assert_shape_4d(KQ, n_past + N, N, n_head, n_batch); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + // KQ_scaled shape [n_past + N, N, n_head, n_batch] + struct ggml_tensor * KQ_scaled = + ggml_scale_inplace(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head))); + assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch); + + // KQ_masked = mask_past(KQ_scaled) + // KQ_masked shape [n_past + N, N, n_head, n_batch] + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + assert_shape_4d(KQ_masked, n_past + N, N, n_head, n_batch); + + // KQ = soft_max(KQ_masked) + // KQ_soft_max shape [n_past + N, N, n_head, n_batch] + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + assert_shape_4d(KQ_soft_max, n_past + N, N, n_head, n_batch); + + // split cached V into n_head heads + // kv_self.v shape [n_ctx * n_embd * n_batch * n_layer] + // V shape [n_past + N, n_embd/n_head, n_head, n_batch] == kv_self.v[:(n_past+N),:,:,il] + struct ggml_tensor * V = + ggml_view_4d(ctx0, vc, + n_past + N, n_embd/n_head, n_head, n_batch, + ggml_element_size(vc)*n_ctx, + ggml_element_size(vc)*n_ctx*n_embd/n_head, + ggml_element_size(vc)*n_ctx*n_embd, + il*n_batch*n_ctx*n_embd*ggml_element_size(vc)); + assert_shape_4d(V, n_past + N, n_embd/n_head, n_head, n_batch); + + // KQV shape [n_embd/n_head, N, n_head, n_batch] + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + // KQV_merged shape [n_embd/n_head, n_head, N, n_batch] + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch); + // KQV_merged shape + + // cur = KQV_merged.contiguous().view(n_embd, N) + // cur shape [n_embd,N*n_batch,1,1] + cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch); + assert_shape_2d(cur, n_embd, N*n_batch); + // cur = ggml_cpy(ctx0, + // KQV_merged, + // ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + // projection (no bias) + // cur shape [n_embd,N*n_batch,1,1] + cur = ggml_mul_mat(ctx0, + model->layers[il].wo, + cur); + assert_shape_2d(cur, n_embd, N*n_batch); + } + + // lctx.use_buf(ctx0, 1); + + // inpFF shape [n_embd,N*n_batch,1,1] + struct ggml_tensor * inpFF = ggml_add_inplace(ctx0, cur, inpSA); + assert_shape_2d(inpFF, n_embd, N*n_batch); + + // feed-forward network + { + // norm + { + // cur shape [n_embd,N*n_batch,1,1] + cur = ggml_rms_norm(ctx0, inpFF); + assert_shape_2d(cur, n_embd, N*n_batch); + + // cur = ffn_norm*cur + // cur shape [n_embd,N*n_batch,1,1] + cur = ggml_mul(ctx0, + ggml_repeat(ctx0, model->layers[il].ffn_norm, cur), + cur); + assert_shape_2d(cur, n_embd, N*n_batch); + } + + // tmp shape [n_ff,N*n_batch,1,1] + struct ggml_tensor * tmp = ggml_mul_mat(ctx0, + model->layers[il].w3, + cur); + assert_shape_2d(tmp, n_ff, N*n_batch); + + // cur shape [n_ff,N*n_batch,1,1] + cur = ggml_mul_mat(ctx0, + model->layers[il].w1, + cur); + assert_shape_2d(cur, n_ff, N*n_batch); + + // SILU activation + // cur shape [n_ff,N*n_batch,1,1] + cur = ggml_silu(ctx0, cur); + assert_shape_2d(cur, n_ff, N*n_batch); + + // cur shape [n_ff,N*n_batch,1,1] + cur = ggml_mul(ctx0, cur, tmp); + assert_shape_2d(cur, n_ff, N*n_batch); + + // cur shape [n_embd,N*n_batch,1,1] + cur = ggml_mul_mat(ctx0, + model->layers[il].w2, + cur); + assert_shape_2d(cur, n_embd, N*n_batch); + } + + // cur shape [n_embd,N*n_batch,1,1] + cur = ggml_add_inplace(ctx0, cur, inpFF); + assert_shape_2d(cur, n_embd, N*n_batch); + + // input for next layer + // inpL shape [n_embd,N*n_batch,1,1] + inpL = cur; + assert_shape_2d(inpL, n_embd, N*n_batch); + } + + // norm + { + + // inpL shape [n_embd,N*n_batch,1,1] + inpL = ggml_rms_norm(ctx0, inpL); + assert_shape_2d(inpL, n_embd, N*n_batch); + + // inpL = norm*inpL + // inpL shape [n_embd,N*n_batch,1,1] + inpL = ggml_mul(ctx0, + ggml_repeat(ctx0, model->norm, inpL), + inpL); + + assert_shape_2d(inpL, n_embd, N*n_batch); + + //embeddings = inpL; + } + + // lm_head + // inpL shape [n_vocab,N*n_batch,1,1] + inpL = ggml_mul_mat(ctx0, model->output, inpL); + assert_shape_2d(inpL, n_vocab, N*n_batch); + + { + // inpL shape [n_vocab,N,n_batch,1] + inpL = ggml_reshape_3d(ctx0, + inpL, + n_vocab, N, n_batch); + assert_shape_3d(inpL, n_vocab, N, n_batch); + } + + // run the computation + ggml_build_forward_expand(gf, inpL); + + return inpL; +} + +struct ggml_tensor * forward_batch_wo_cache( + struct my_llama_model * model, + struct ggml_context * ctx0, + struct ggml_cgraph * gf, + struct ggml_tensor * tokens_input, + const int n_tokens, + const int n_batch) { + + const int n_past = 0; + const int N = n_tokens; + + const auto & hparams = model->hparams; + //const int n_ctx = hparams.n_ctx; + const int n_vocab = hparams.n_vocab; + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_head = hparams.n_head; + const int n_rot = hparams.n_rot; + const int n_ff = get_n_ff(&hparams); + + struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); + memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch); + + // inpL shape [n_embd,N*n_batch,1] + struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); + assert_shape_2d(inpL, n_embd, N*n_batch); + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + struct ggml_tensor * cur; + + // lctx.use_buf(ctx0, 0); + + // norm + { + // cur shape [n_embd,N*n_batch,1,1] + cur = ggml_rms_norm(ctx0, inpL); + assert_shape_2d(cur, n_embd, N*n_batch); + + // cur = attention_norm*cur + cur = ggml_mul(ctx0, + ggml_repeat(ctx0, model->layers[il].attention_norm, cur), + cur); + assert_shape_2d(cur, n_embd, N*n_batch); + } + + // self-attention + { + // compute Q and K and RoPE them + // wq shape [n_embd, n_embd, 1, 1] + // wk shape [n_embd, n_embd, 1, 1] + // Qcur shape [n_embd/n_head, n_head, N, n_batch] + // Kcur shape [n_embd/n_head, n_head, N, n_batch] + struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0); + struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0); + assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch); + assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch); + + // Vcur shape [N, n_batch, n_embd/n_head, n_head] + struct ggml_tensor * Vcur = ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, cur, model->layers[il].wv), N, n_batch, n_embd/n_head, n_head); + assert_shape_4d(Vcur, N, n_batch, n_embd/n_head, n_head); + + // Qcur shape [n_embd/n_head, n_head, N, n_batch] + // Q shape [n_embd/n_head, N, n_head, n_batch] + struct ggml_tensor * Q = + ggml_permute(ctx0, + Qcur, + 0, 2, 1, 3); + assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch); + + // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer] + // K shape [n_embd/n_head, N, n_head, n_batch] + struct ggml_tensor * K = + ggml_permute(ctx0, + Kcur, + 0, 2, 1, 3); + assert_shape_4d(K, n_embd/n_head, N, n_head, n_batch); + + // K * Q + // KQ shape [N, N, n_head, n_batch] + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + assert_shape_4d(KQ, N, N, n_head, n_batch); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + // KQ_scaled shape [N, N, n_head, n_batch] + struct ggml_tensor * KQ_scaled = + ggml_scale_inplace(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head))); + assert_shape_4d(KQ_scaled, N, N, n_head, n_batch); + + // KQ_masked = mask_past(KQ_scaled) + // KQ_masked shape [N, N, n_head, n_batch] + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + assert_shape_4d(KQ_masked, N, N, n_head, n_batch); + + // KQ = soft_max(KQ_masked) + // KQ_soft_max shape [N, N, n_head, n_batch] + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + assert_shape_4d(KQ_soft_max, N, N, n_head, n_batch); + + // Vcur shape [N, n_batch, n_embd/n_head, n_head] + // V shape [N, n_embd/n_head, n_head, n_batch] + struct ggml_tensor * V = + ggml_permute(ctx0, + Vcur, + 0, 3, 1, 2); + assert_shape_4d(V, N, n_embd/n_head, n_head, n_batch); + + // KQV shape [n_embd/n_head, N, n_head, n_batch] + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + // KQV_merged shape [n_embd/n_head, n_head, N, n_batch] + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch); + // KQV_merged shape + + // cur shape [n_embd,N*n_batch,1,1] + cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch); + assert_shape_2d(cur, n_embd, N*n_batch); + + // projection (no bias) + // cur shape [n_embd,N*n_batch,1,1] + cur = ggml_mul_mat(ctx0, + model->layers[il].wo, + cur); + assert_shape_2d(cur, n_embd, N*n_batch); + } + + // lctx.use_buf(ctx0, 1); + + // inpFF shape [n_embd,N*n_batch,1,1] + struct ggml_tensor * inpFF = ggml_add_inplace(ctx0, cur, inpSA); + assert_shape_2d(inpFF, n_embd, N*n_batch); + + // feed-forward network + { + // norm + { + // cur shape [n_embd,N*n_batch,1,1] + cur = ggml_rms_norm(ctx0, inpFF); + assert_shape_2d(cur, n_embd, N*n_batch); + + // cur = ffn_norm*cur + // cur shape [n_embd,N*n_batch,1,1] + cur = ggml_mul(ctx0, + ggml_repeat(ctx0, model->layers[il].ffn_norm, cur), + cur); + assert_shape_2d(cur, n_embd, N*n_batch); + } + + // tmp shape [n_ff,N*n_batch,1,1] + struct ggml_tensor * tmp = ggml_mul_mat(ctx0, + model->layers[il].w3, + cur); + assert_shape_2d(tmp, n_ff, N*n_batch); + + // cur shape [n_ff,N*n_batch,1,1] + cur = ggml_mul_mat(ctx0, + model->layers[il].w1, + cur); + assert_shape_2d(cur, n_ff, N*n_batch); + + // SILU activation + // cur shape [n_ff,N*n_batch,1,1] + cur = ggml_silu(ctx0, cur); + assert_shape_2d(cur, n_ff, N*n_batch); + + // cur shape [n_ff,N*n_batch,1,1] + cur = ggml_mul(ctx0, cur, tmp); + assert_shape_2d(cur, n_ff, N*n_batch); + + // cur shape [n_embd,N*n_batch,1,1] + cur = ggml_mul_mat(ctx0, + model->layers[il].w2, + cur); + assert_shape_2d(cur, n_embd, N*n_batch); + } + + // cur shape [n_embd,N*n_batch,1,1] + cur = ggml_add_inplace(ctx0, cur, inpFF); + assert_shape_2d(cur, n_embd, N*n_batch); + + // input for next layer + // inpL shape [n_embd,N*n_batch,1,1] + inpL = cur; + assert_shape_2d(inpL, n_embd, N*n_batch); + } + + // norm + { + + // inpL shape [n_embd,N*n_batch,1,1] + inpL = ggml_rms_norm(ctx0, inpL); + assert_shape_2d(inpL, n_embd, N*n_batch); + + // inpL = norm*inpL + // inpL shape [n_embd,N*n_batch,1,1] + inpL = ggml_mul(ctx0, + ggml_repeat(ctx0, model->norm, inpL), + inpL); + + assert_shape_2d(inpL, n_embd, N*n_batch); + + //embeddings = inpL; + } + + // lm_head + // inpL shape [n_vocab,N*n_batch,1,1] + inpL = ggml_mul_mat(ctx0, model->output, inpL); + assert_shape_2d(inpL, n_vocab, N*n_batch); + + { + // inpL shape [n_vocab,N,n_batch,1] + inpL = ggml_reshape_3d(ctx0, + inpL, + n_vocab, N, n_batch); + assert_shape_3d(inpL, n_vocab, N, n_batch); + } + + // run the computation + ggml_build_forward_expand(gf, inpL); + + return inpL; +} + +struct ggml_tensor * forward_batch_wo_cache_flash_attn( + struct my_llama_model * model, + struct ggml_context * ctx0, + struct ggml_cgraph * gf, + struct ggml_tensor * tokens_input, + const int n_tokens, + const int n_batch) { + + const int n_past = 0; + const int N = n_tokens; + + const auto & hparams = model->hparams; + //const int n_ctx = hparams.n_ctx; + const int n_vocab = hparams.n_vocab; + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_head = hparams.n_head; + const int n_rot = hparams.n_rot; + const int n_ff = get_n_ff(&hparams); + + struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); + memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch); + + struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); + assert_shape_2d(inpL, n_embd, N*n_batch); + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + struct ggml_tensor * cur; + + // norm + { + cur = ggml_rms_norm(ctx0, inpL); + assert_shape_2d(cur, n_embd, N*n_batch); + + // cur = attention_norm*cur + cur = ggml_mul(ctx0, + ggml_repeat(ctx0, model->layers[il].attention_norm, cur), + cur); + assert_shape_2d(cur, n_embd, N*n_batch); + } + + // self-attention + { + // compute Q and K and RoPE them + // wq shape [n_embd, n_embd, 1, 1] + // wk shape [n_embd, n_embd, 1, 1] + struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0); + struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0); + assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch); + assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch); + + struct ggml_tensor * Vcur = ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, cur, model->layers[il].wv), N, n_batch, n_embd/n_head, n_head); + assert_shape_4d(Vcur, N, n_batch, n_embd/n_head, n_head); + + struct ggml_tensor * Q = + ggml_permute(ctx0, + Qcur, + 0, 2, 1, 3); + assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch); + + struct ggml_tensor * K = + ggml_permute(ctx0, + Kcur, + 0, 2, 1, 3); + assert_shape_4d(K, n_embd/n_head, N, n_head, n_batch); + + struct ggml_tensor * V = + ggml_permute(ctx0, + Vcur, + 0, 3, 1, 2); + assert_shape_4d(V, N, n_embd/n_head, n_head, n_batch); + + bool masked = true; + struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, masked); + assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch); + cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch); + assert_shape_2d(cur, n_embd, N*n_batch); + + // projection (no bias) + cur = ggml_mul_mat(ctx0, + model->layers[il].wo, + cur); + assert_shape_2d(cur, n_embd, N*n_batch); + } + + struct ggml_tensor * inpFF = ggml_add_inplace(ctx0, cur, inpSA); + assert_shape_2d(inpFF, n_embd, N*n_batch); + + // feed-forward network + { + // norm + { + cur = ggml_rms_norm(ctx0, inpFF); + assert_shape_2d(cur, n_embd, N*n_batch); + + // cur = ffn_norm*cur + cur = ggml_mul(ctx0, + ggml_repeat(ctx0, model->layers[il].ffn_norm, cur), + cur); + assert_shape_2d(cur, n_embd, N*n_batch); + } + + struct ggml_tensor * tmp = ggml_mul_mat(ctx0, + model->layers[il].w3, + cur); + assert_shape_2d(tmp, n_ff, N*n_batch); + + cur = ggml_mul_mat(ctx0, + model->layers[il].w1, + cur); + assert_shape_2d(cur, n_ff, N*n_batch); + + // SILU activation + cur = ggml_silu(ctx0, cur); + assert_shape_2d(cur, n_ff, N*n_batch); + + cur = ggml_mul(ctx0, cur, tmp); + assert_shape_2d(cur, n_ff, N*n_batch); + + cur = ggml_mul_mat(ctx0, + model->layers[il].w2, + cur); + assert_shape_2d(cur, n_embd, N*n_batch); + } + + cur = ggml_add_inplace(ctx0, cur, inpFF); + assert_shape_2d(cur, n_embd, N*n_batch); + + // input for next layer + inpL = cur; + assert_shape_2d(inpL, n_embd, N*n_batch); + } + + // norm + { + + inpL = ggml_rms_norm(ctx0, inpL); + assert_shape_2d(inpL, n_embd, N*n_batch); + + // inpL = norm*inpL + inpL = ggml_mul(ctx0, + ggml_repeat(ctx0, model->norm, inpL), + inpL); + + assert_shape_2d(inpL, n_embd, N*n_batch); + } + + // lm_head + inpL = ggml_mul_mat(ctx0, model->output, inpL); + assert_shape_2d(inpL, n_vocab, N*n_batch); + + { + inpL = ggml_reshape_3d(ctx0, + inpL, + n_vocab, N, n_batch); + assert_shape_3d(inpL, n_vocab, N, n_batch); + } + + // run the computation + ggml_build_forward_expand(gf, inpL); + + return inpL; +} + +// expand the graph nodes without creating leafs. +struct ggml_tensor * expand(struct ggml_cgraph * g, struct ggml_tensor * t) { + // check if already visited + for (int i = 0; i < g->n_nodes; i++) { + if (g->nodes[i] == t) { + return t; + } + } + + for (int i = 0; i < g->n_leafs; i++) { + if (g->leafs[i] == t) { + return t; + } + } + + if (t->src0) { + expand(g, t->src0); + } + + if (t->src1) { + expand(g, t->src1); + } + + for (int i = 0; i < GGML_MAX_OPT; ++i) { + if (t->opt[i]) { + expand(g, t->opt[i]); + } + } + + GGML_ASSERT(g->n_nodes < GGML_MAX_NODES); + + if (strlen(t->name) == 0) { + snprintf(t->name, sizeof(t->name), "node_%d", g->n_nodes); + } + + g->nodes[g->n_nodes] = t; + g->grads[g->n_nodes] = t->grad; + g->n_nodes++; + return t; +} + +void graph_set_leafs_grads(struct ggml_cgraph * g) { + // moves leaf nodes to g->leafs. + // i.e. g->n_nodes might change. + int n_nodes = 0; + for (int i = 0; i < g->n_nodes; ++i) { + struct ggml_tensor * node = g->nodes[i]; + const bool is_leaf = node->op == GGML_OP_NONE && node->grad == NULL; + if (is_leaf) { + GGML_ASSERT(g->n_leafs < GGML_MAX_NODES); + + if (strlen(node->name) == 0) { + snprintf(node->name, sizeof(node->name), "leaf_%d", g->n_leafs); + } + + g->leafs[g->n_leafs] = node; + g->n_leafs++; + } else { + GGML_ASSERT(n_nodes < GGML_MAX_NODES); + + if (strlen(node->name) == 0) { + snprintf(node->name, sizeof(node->name), "node_%d", n_nodes); + } + + g->nodes[n_nodes] = node; + g->grads[n_nodes] = node->grad; + n_nodes++; + } + } + for (int i=n_nodes; i < g->n_nodes; ++i) { + g->nodes[n_nodes] = NULL; + g->grads[n_nodes] = NULL; + } + g->n_nodes = n_nodes; +} + +struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( + struct my_llama_model * model, + struct ggml_context * ctx0, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + struct ggml_tensor * * logits, + struct ggml_tensor * tokens_input, + struct ggml_tensor * targets, + void * compute_buf_0, + void * compute_buf_1, + size_t size_buf_0, + size_t size_buf_1, + const int n_tokens, + const int n_batch) { + + ggml_set_scratch(ctx0, { 0, 0, nullptr, }); + + const int n_past = 0; + const int N = n_tokens; + + gf->n_nodes = 0; + gf->n_leafs = 0; + gf->work_size = 0; + gf->perf_runs = 0; + gf->perf_cycles = 0; + gf->perf_time_us = 0; + gf->work = NULL; + + const auto & hparams = model->hparams; + //const int n_ctx = hparams.n_ctx; + const int n_vocab = hparams.n_vocab; + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_head = hparams.n_head; + const int n_rot = hparams.n_rot; + const int n_ff = get_n_ff(&hparams); + const int rope_mode = 0; + + int last_buf = -1; + size_t buf_offs[2] = { 0, 0 }; + size_t buf_size[2] = { size_buf_0, + size_buf_1 }; + void * buf_data[2] = { compute_buf_0, + compute_buf_1 }; + auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data] (int buf) { + size_t last_offs = 0; + last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, }); + if (last_buf >= 0) { + buf_offs[last_buf] = last_offs; + } + if (buf >= 0) { + size_t offs = buf_offs[buf]; + size_t size = buf_size[buf]; + void * data = buf_data[buf]; + ggml_set_scratch(ctx0, { offs, size, data, }); + } + last_buf = buf; + }; + + bool track_max_mem = false; + size_t buf_maxs[2] = { 0, 0 }; + + auto clr_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs, track_max_mem] (int buf) { + if (buf < 0) return; + if (track_max_mem) { + size_t last_offs = 0; + last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, }); + if (last_buf >= 0) { + buf_offs[last_buf] = last_offs; + buf_maxs[last_buf] = std::max(buf_maxs[last_buf], buf_offs[last_buf]); + } + } + buf_offs[buf] = 0; + if (track_max_mem && last_buf >= 0) { + size_t offs = buf_offs[last_buf]; + size_t size = buf_size[last_buf]; + void * data = buf_data[last_buf]; + ggml_set_scratch(ctx0, { offs, size, data, }); + } + }; + + + auto view__q = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * { + int64_t ne0 = n_embd/n_head; + int64_t ne1 = N; + int64_t ne2 = n_head; + int64_t ne3 = n_batch; + size_t nb0 = ggml_element_size(t); + size_t nb1 = nb0*ne0; + size_t nb2 = nb1*ne1; + size_t nb3 = nb2*ne2; + size_t offset = 0; + return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset); + }; + + auto view__k = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * { + int64_t ne0 = n_embd/n_head; + int64_t ne1 = N; + int64_t ne2 = n_head; + int64_t ne3 = n_batch; + size_t nb0 = ggml_element_size(t); + size_t nb1 = nb0*ne0; + size_t nb2 = nb1*ne1; + size_t nb3 = nb2*ne2; + size_t offset = nb3*ne3; + return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset); + }; + + auto view__v = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * { + int64_t ne0 = N; + int64_t ne1 = n_embd/n_head; + int64_t ne2 = n_head; + int64_t ne3 = n_batch; + size_t nb0 = ggml_element_size(t); + size_t nb1 = nb0*ne0; + size_t nb2 = nb1*ne1; + size_t nb3 = nb2*ne2; + size_t offset = 2*nb3*ne3; + return ggml_view_4d(ctx0, t, ne0, ne1, ne2, ne3, nb1, nb2, nb3, offset); + }; + + auto add_or_set = [ctx0] (struct ggml_tensor * a, struct ggml_tensor * b) -> struct ggml_tensor * { + if (a == NULL) { + return b; + } else { + return ggml_add_inplace(ctx0, a, b); + } + }; + + use_buf(-1); + + model->tok_embeddings->grad = NULL; + model->norm->grad = NULL; + model->output->grad = NULL; + + for (int il = 0; il < n_layer; ++il) { + struct my_llama_layer & layer = model->layers[il]; + layer.attention_norm->grad = NULL; + layer.wq->grad = NULL; + layer.wk->grad = NULL; + layer.wv->grad = NULL; + layer.wo->grad = NULL; + layer.ffn_norm->grad = NULL; + layer.w1->grad = NULL; + layer.w2->grad = NULL; + layer.w3->grad = NULL; + } + + clr_buf(0); + clr_buf(1); + + use_buf(-1); + + struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch); + memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch); + + use_buf(-1); + + struct ggml_tensor * t01 = expand(gf, ggml_get_rows(ctx0, model->tok_embeddings, t00)); assert_shape_2d(t01, n_embd, N*n_batch); + + // need to remember these for the backward pass + std::vector t02L; t02L.resize(n_layer, NULL); + std::vector t03L; t03L.resize(n_layer, NULL); + std::vector t04L; t04L.resize(n_layer, NULL); + std::vector t05L; t05L.resize(n_layer, NULL); + std::vector t06L; t06L.resize(n_layer, NULL); + std::vector t07L; t07L.resize(n_layer, NULL); + std::vector t08L; t08L.resize(n_layer, NULL); + std::vector t09L; t09L.resize(n_layer, NULL); + std::vector t10L; t10L.resize(n_layer, NULL); + std::vector t11L; t11L.resize(n_layer, NULL); + std::vector t12L; t12L.resize(n_layer, NULL); + std::vector t13L; t13L.resize(n_layer, NULL); + std::vector t14L; t14L.resize(n_layer, NULL); + std::vector t15L; t15L.resize(n_layer, NULL); + std::vector t16L; t16L.resize(n_layer, NULL); + std::vector t17L; t17L.resize(n_layer, NULL); + std::vector t18L; t18L.resize(n_layer, NULL); + std::vector t19L; t19L.resize(n_layer, NULL); + std::vector t20L; t20L.resize(n_layer, NULL); + std::vector t21L; t21L.resize(n_layer, NULL); + std::vector t22L; t22L.resize(n_layer, NULL); + std::vector t23L; t23L.resize(n_layer, NULL); + std::vector t24L; t24L.resize(n_layer, NULL); + std::vector t25L; t25L.resize(n_layer, NULL); + std::vector t26L; t26L.resize(n_layer, NULL); + std::vector t27L; t27L.resize(n_layer, NULL); + std::vector t28L; t28L.resize(n_layer, NULL); + std::vector t29L; t29L.resize(n_layer, NULL); + std::vector t30L; t30L.resize(n_layer, NULL); + + struct ggml_tensor * cur = t01; + + for (int il = 0; il < n_layer; ++il) { + clr_buf(0); + struct my_llama_layer & layer = model->layers[il]; + // tensors with values necessary for backward pass are in persistent buf(-1) + // other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed. + use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch); + use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch); + use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch); + use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch); + use_buf(-1); struct ggml_tensor * t06 = expand(gf, ggml_reshape_4d (ctx0, t05, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch); + use_buf(-1); struct ggml_tensor * t07 = expand(gf, ggml_rope_inplace (ctx0, t06, n_past, n_rot, rope_mode)); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch); + use_buf(-1); struct ggml_tensor * t08 = expand(gf, ggml_mul_mat (ctx0, layer.wk, t04)); assert_shape_2d(t08, n_embd, N*n_batch); + use_buf(-1); struct ggml_tensor * t09 = expand(gf, ggml_reshape_4d (ctx0, t08, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch); + use_buf(-1); struct ggml_tensor * t10 = expand(gf, ggml_rope_inplace (ctx0, t09, n_past, n_rot, rope_mode)); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch); + use_buf(-1); struct ggml_tensor * t11 = expand(gf, ggml_mul_mat (ctx0, t04, layer.wv)); assert_shape_2d(t11, N*n_batch, n_embd); + use_buf(-1); struct ggml_tensor * t12 = expand(gf, ggml_reshape_4d (ctx0, t11, N, n_batch, n_embd/n_head, n_head)); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head); + use_buf(-1); struct ggml_tensor * t13 = expand(gf, ggml_permute (ctx0, t07, 0, 2, 1, 3)); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch); + use_buf(-1); struct ggml_tensor * t14 = expand(gf, ggml_permute (ctx0, t10, 0, 2, 1, 3)); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch); + use_buf(-1); struct ggml_tensor * t15 = expand(gf, ggml_permute (ctx0, t12, 0, 3, 1, 2)); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch); + use_buf(-1); struct ggml_tensor * t16 = expand(gf, ggml_flash_attn (ctx0, t13, t14, t15, true)); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch); + use_buf( 0); struct ggml_tensor * t17 = expand(gf, ggml_permute (ctx0, t16, 0, 2, 1, 3)); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch); + use_buf(-1); struct ggml_tensor * t18 = expand(gf, ggml_cont (ctx0, t17)); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch); + use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch); + use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch); + use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch); + use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch); + use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch); + use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch); + use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch); + use_buf(-1); struct ggml_tensor * t26 = expand(gf, ggml_mul_mat (ctx0, layer.w1, t24)); assert_shape_2d(t26, n_ff, N*n_batch); + use_buf(-1); struct ggml_tensor * t27 = expand(gf, ggml_silu (ctx0, t26)); assert_shape_2d(t27, n_ff, N*n_batch); + use_buf(-1); struct ggml_tensor * t28 = expand(gf, ggml_mul (ctx0, t27, t25)); assert_shape_2d(t28, n_ff, N*n_batch); + use_buf( 0); struct ggml_tensor * t29 = expand(gf, ggml_mul_mat (ctx0, layer.w2, t28)); assert_shape_2d(t29, n_embd, N*n_batch); + use_buf(-1); struct ggml_tensor * t30 = expand(gf, ggml_add (ctx0, t21, t29)); assert_shape_2d(t30, n_embd, N*n_batch); + t02L[il] = t02; + t03L[il] = t03; + t04L[il] = t04; + t05L[il] = t05; + t06L[il] = t06; + t07L[il] = t07; + t08L[il] = t08; + t09L[il] = t09; + t10L[il] = t10; + t11L[il] = t11; + t12L[il] = t12; + t13L[il] = t13; + t14L[il] = t14; + t15L[il] = t15; + t16L[il] = t16; + t17L[il] = t17; + t18L[il] = t18; + t19L[il] = t19; + t20L[il] = t20; + t21L[il] = t21; + t22L[il] = t22; + t23L[il] = t23; + t24L[il] = t24; + t25L[il] = t25; + t26L[il] = t26; + t27L[il] = t27; + t28L[il] = t28; + t29L[il] = t29; + t30L[il] = t30; + + cur = t30; + } + clr_buf(0); + use_buf(0); + struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t31, n_embd, N*n_batch); + struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch); + struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch); + use_buf(-1); + struct ggml_tensor * t34 = expand(gf, ggml_mul_mat (ctx0, model->output, t33)); assert_shape_2d(t34, n_vocab, N*n_batch); + struct ggml_tensor * t35 = expand(gf, ggml_reshape_3d(ctx0, t34, n_vocab, N, n_batch)); assert_shape_3d(t35, n_vocab, N, n_batch); + struct ggml_tensor * t36 = expand(gf, ggml_cross_entropy_loss(ctx0, t35, targets)); assert_shape_1d(t36, 1); + + { + /* + tok_embeddings | grad_tok_embeddings = ggml_get_rows_back(grad_t01, t00) + L0_att_norm | grad_L0_att_norm = ggml_repeat_back(grad_t03L0, L0_att_norm.shape) + L0_wq | grad_L0_wq = ggml_out_prod(t04L0, grad_t05L0) + L0_wk | grad_L0_wk = ggml_out_prod(t04L0, grad_t08L0) + L0_wv | grad_L0_wv = ggml_out_prod(t04L0, ggml_transpose(grad_t11L0)) + L0_wo | grad_L0_wo = ggml_out_prod(t19L0, grad_t20L0) + L0_ffn_norm | grad_L0_ffn_norm = ggml_repeat_back(grad_t23L0, L0_ffn_norm.shape) + L0_w1 | grad_L0_w1 = ggml_out_prod(t24L0, grad_t26L0) + L0_w2 | grad_L0_w2 = ggml_out_prod(t28L0, grad_t29L0) + L0_w3 | grad_L0_w3 = ggml_out_prod(t24L0, grad_t25L0) + L1_att_norm | grad_L1_att_norm = ggml_repeat_back(grad_t03L1, L1_att_norm.shape) + L1_wq | grad_L1_wq = ggml_out_prod(t04L1, grad_t05L1) + L1_wk | grad_L1_wk = ggml_out_prod(t04L1, grad_t08L1) + L1_wv | grad_L1_wv = ggml_out_prod(t04L1, ggml_transpose(grad_t11L1)) + L1_wo | grad_L1_wo = ggml_out_prod(t19L1, grad_t20L1) + L1_ffn_norm | grad_L1_ffn_norm = ggml_repeat_back(grad_t23L1, L1_ffn_norm.shape) + L1_w1 | grad_L1_w1 = ggml_out_prod(t24L1, grad_t26L1) + L1_w2 | grad_L1_w2 = ggml_out_prod(t28L1, grad_t29L1) + L1_w3 | grad_L1_w3 = ggml_out_prod(t24L1, grad_t25L1) + norm | grad_norm = ggml_repeat_back(grad_t32, norm.shape) + output | grad_output = ggml_out_prod(t33, grad_t34) + | + t01 = ggml_get_rows(tok_embeddings, t00) | grad_t01 = grad_t21L0 + ggml_rms_norm_back(t01, grad_t02L0) + for layer: | + t02L0*= ggml_rms_norm (t01) | grad_t02L0 = ggml_mul(grad_t04L0, t03L0) + t03L0 = ggml_repeat (L0_att_norm, t02L0_shape) | grad_t03L0 = ggml_mul(grad_t04L0, t02L0) + t04L0*= ggml_mul (t02L0, t03L0) | grad_t04L0 = ggml_out_prod(L0_wv, grad_t11L0) + ggml_out_prod(L0_wk, ggml_transpose(grad_t08L0)) + ggml_out_prod(L0_wq, ggml_transpose(grad_t05L0)) + t05L0 = ggml_mul_mat (L0_wq, t04L0) | grad_t05L0 = ggml_reshape(grad_t06L0, t05L0_shape) + t06L0 = ggml_reshape_4d (t05L0, n_embd/n_head, n_head, N, n_batch) | grad_t06L0 = ggml_rope_back(grad_t07L0) + t07L0 = ggml_rope_inplace (t06L0) | grad_t07L0 = ggml_permute_back(grad_t13L0, 0, 2, 1, 3) = ggml_permute(grad_t13L0, 0, 2, 1, 3) + t08L0 = ggml_mul_mat (L0_wk, t04L0) | grad_t08L0 = ggml_reshape(grad_t09L0, t08L0_shape) + t09L0 = ggml_reshape_4d (t08L0, n_embd/n_head, n_head, N, n_batch) | grad_t09L0 = ggml_rope_back(grad_t10L0) + t10L0 = ggml_rope_inplace (t09L0) | grad_t10L0 = ggml_permute_back(grad_t14L0, 0, 2, 1, 3) = ggml_permute(grad_t14L0, 0, 2, 1, 3) + t11L0 = ggml_mul_mat (t04L0, L0_wv) | grad_t11L0 = ggml_reshape(grad_t12L0, t11L0_shape) + t12L0 = ggml_reshape_4d (t11L0, N, n_batch, n_embd/n_head, n_head) | grad_t12L0 = ggml_permute_back(grad_t15L0, 0, 3, 1, 2) = ggml_permute(grad_t15L0, 0, 2, 3, 1) + t13L0*= ggml_permute (t07L0, 0, 2, 1, 3) | grad_t13L0 = view__q(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0)) + t14L0*= ggml_permute (t10L0, 0, 2, 1, 3) | grad_t14L0 = view__k(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0)) + t15L0*= ggml_permute (t12L0, 0, 3, 1, 2) | grad_t15L0 = view__v(ggml_flash_attn_back(t13L0, t14L0, t15L0, grad_t16L0)) + t16L0 = ggml_flash_attn (t13L0, t14L0, t15L0) | grad_t16L0 = ggml_permute_back(grad_t17L0, 0, 2, 1, 3) = ggml_permute(grad_t17L0, 0, 2, 1, 3) + t17L0 = ggml_permute (t16L0, 0, 2, 1, 3) | grad_t17L0 = grad_t18L0 + t18L0 = ggml_cont (t17L0) | grad_t18L0 = ggml_reshape(grad_t19L0, t18L0_shape) + t19L0*= ggml_reshape_2d (t18L0, n_embd, N*n_batch) | grad_t19L0 = ggml_out_prod(L0_wo, ggml_transpose(grad_t20L0)) + t20L0 = ggml_mul_mat (L0_wo, t19L0) | grad_t20L0 = grad_t21L0 + t21L0*= ggml_add (t20L0, t01) | grad_t21L0 = grad_t30L0 + ggml_rms_norm_back(t21L0, grad_t22L0) + t22L0*= ggml_rms_norm (t21L0) | grad_t22L0 = ggml_mul(grad_t24L0, t23L0) + t23L0 = ggml_repeat (L0_ffn_norm, t22L0_shape) | grad_t23L0 = ggml_mul(grad_t24L0, t22L0) + t24L0*= ggml_mul (t23L0, t22L0) | grad_t24L0 = ggml_out_prod(L0_w1, ggml_transpose(grad_t26L0)) + ggml_out_prod(L0_w3, ggml_transpose(grad_t25L0)) + t25L0*= ggml_mul_mat (L0_w3, t24L0) | grad_t25L0 = ggml_mul(grad_t28L0, t27L0) + t26L0*= ggml_mul_mat (L0_w1, t24L0) | grad_t26L0 = ggml_silu_back(t26L0, grad_t27L0) + t27L0*= ggml_silu (t26L0) | grad_t27L0 = ggml_mul(grad_t28L0, t25L0) + t28L0*= ggml_mul (t27L0, t25L0) | grad_t28L0 = ggml_out_prod(L0_w2, ggml_transpose(grad_t29L0)) + t29L0 = ggml_mul_mat (L0_w2, t28L0) | grad_t29L0 = grad_t30L0 + t30L0*= ggml_add (t21L0, t29L0) | grad_t30L0 = ggml_rms_norm_back(t30L0, grad_t02L1) + grad_t21L1 + ^ + t02L1*= ggml_rms_norm (t30L0) | grad_t02L1 = ggml_mul(grad_t04L1, t03L1) + t03L1 = ggml_repeat (L1_att_norm, t02L1_shape) | grad_t03L1 = ggml_mul(grad_t04L1, t02L1) + t04L1*= ggml_mul (t02L1, t03L1) | grad_t04L1 = ggml_out_prod(L1_wv, grad_t11L1) + ggml_out_prod(L1_wk, ggml_transpose(grad_t08L1)) + ggml_out_prod(L1_wq, ggml_transpose(grad_t05L1)) + t05L1 = ggml_mul_mat (L1_wq, t04L1) | grad_t05L1 = ggml_reshape(grad_t06L1, t05L1_shape) + t06L1 = ggml_reshape_4d (t05L1, n_embd/n_head, n_head, N, n_batch) | grad_t06L1 = ggml_rope_back(grad_t07L1) + t07L1 = ggml_rope_inplace (t06L1) | grad_t07L1 = ggml_permute_back(grad_t13L1, 0, 2, 1, 3) = ggml_permute(grad_t13L1, 0, 2, 1, 3) + t08L1 = ggml_mul_mat (L1_wk, t04L1) | grad_t08L1 = ggml_reshape(grad_t09L1, t08L1_shape) + t09L1 = ggml_reshape_4d (t08L1, n_embd/n_head, n_head, N, n_batch) | grad_t09L1 = ggml_rope_back(grad_t10L1) + t10L1 = ggml_rope_inplace (t09L1) | grad_t10L1 = ggml_permute_back(grad_t14L1, 0, 2, 1, 3) = ggml_permute(grad_t14L1, 0, 2, 1, 3) + t11L1 = ggml_mul_mat (t04L1, L1_wv) | grad_t11L1 = ggml_reshape(grad_t12L1, t11L1_shape) + t12L1 = ggml_reshape_4d (t11L1, N, n_batch, n_embd/n_head, n_head) | grad_t12L1 = ggml_permute_back(grad_t15L1, 0, 3, 1, 2) = ggml_permute(grad_t15L1, 0, 2, 3, 1) + t13L1*= ggml_permute (t07L1, 0, 2, 1, 3) | grad_t13L1 = view__q(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1)) + t14L1*= ggml_permute (t10L1, 0, 2, 1, 3) | grad_t14L1 = view__k(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1)) + t15L1*= ggml_permute (t12L1, 0, 3, 1, 2) | grad_t15L1 = view__v(ggml_flash_attn_back(t13L1, t14L1, t15L1, grad_t16L1)) + t16L1 = ggml_flash_attn (t13L1, t14L1, t15L1) | grad_t16L1 = ggml_permute_back(grad_t17L1, 0, 2, 1, 3) = ggml_permute(grad_t17L1, 0, 2, 1, 3) + t17L1 = ggml_permute (t16L1, 0, 2, 1, 3) | grad_t17L1 = grad_t18L1 + t18L1 = ggml_cont (t17L1) | grad_t18L1 = ggml_reshape(grad_t19L1, t18L1_shape) + t19L1*= ggml_reshape_2d (t18L1, n_embd, N*n_batch) | grad_t19L1 = ggml_out_prod(L1_wo, ggml_transpose(grad_t20L1)) + t20L1 = ggml_mul_mat (L1_wo, t19L1) | grad_t20L1 = grad_t21L1 + t21L1*= ggml_add (t20L1, t30L0) | grad_t21L1 = grad_t30L1 + ggml_rms_norm_back(t21L1, grad_t22L1) + t22L1*= ggml_rms_norm (t21L1) | grad_t22L1 = ggml_mul(grad_t24L1, t23L1) + t23L1 = ggml_repeat (L1_ffn_norm, t22L1_shape) | grad_t23L1 = ggml_mul(grad_t24L1, t22L1) + t24L1*= ggml_mul (t23L1, t22L1) | grad_t24L1 = ggml_out_prod(L1_w1, ggml_transpose(grad_t26L1)) + ggml_out_prod(L1_w3, ggml_transpose(grad_t25L1)) + t25L1*= ggml_mul_mat (L1_w3, t24L1) | grad_t25L1 = ggml_mul(grad_t28L1, t27L1) + t26L1*= ggml_mul_mat (L1_w1, t24L1) | grad_t26L1 = ggml_silu_back(t26L1, grad_t27L1) + t27L1*= ggml_silu (t26L1) | grad_t27L1 = ggml_mul(grad_t28L1, t25L1) + t28L1*= ggml_mul (t27L1, t25L1) | grad_t28L1 = ggml_out_prod(L1_w2, ggml_transpose(grad_t29L1)) + t29L1 = ggml_mul_mat (L1_w2, t28L1) | grad_t29L1 = grad_t30L1 + t30L1*= ggml_add (t21L1, t29L1) | grad_t30L1 = ggml_rms_norm_back(t30L1, grad_t31) + ^ + t31 = ggml_rms_norm (t30L1) | grad_t31 = ggml_mul(grad_t33, t32) + t32 = ggml_repeat (norm, t31.shape) | grad_t32 = ggml_mul(grad_t33, t31) + t33 = ggml_mul (t32, t31) | grad_t33 = ggml_out_prod(output, ggml_transpose(grad_t34)) + t34 = ggml_mul_mat (output, t33) | grad_t34 = ggml_reshape(grad_t35, t34.shape) + t35 = ggml_reshape_3d (t34, n_vocab, N, n_batch) | grad_t35 = ggml_cross_entropy_loss_back(t35, targets, grad_t36) + t36 = ggml_cross_entropy_loss(t35, targets) | grad_t36 = 1 (optimizer) + tensors marked with * need to be stored until grad computation + tensors during grad computation are all temporary + */ + } + + *gb = *gf; + + // t36->grad gets set to one by optimizer, so we need the tensor. + // initialize it with 1.0f to make sure. + use_buf(-1); + t36->grad = expand(gb, ggml_new_f32(ctx0, 1.0f)); + + use_buf(0); + t35->grad = expand(gb, ggml_cross_entropy_loss_back(ctx0, t35, targets, t36->grad)); assert_shape_3d(t35->grad, n_vocab, N, n_batch); + t34->grad = expand(gb, ggml_reshape_2d (ctx0, t35->grad, n_vocab, N*n_batch)); assert_shape_2d(t34->grad, n_vocab, N*n_batch); + t33->grad = expand(gb, ggml_out_prod (ctx0, model->output, ggml_transpose(ctx0, t34->grad))); assert_shape_2d(t33->grad, n_embd, N*n_batch); + t32->grad = expand(gb, ggml_mul (ctx0, t33->grad, t31)); assert_shape_2d(t32->grad, n_embd, N*n_batch); + + use_buf(-1); + + model->norm->grad = expand(gb, add_or_set(model->norm->grad, ggml_repeat_back(ctx0, t32->grad, model->norm))); assert_shape_1d(model->norm->grad, n_embd); + model->output->grad = expand(gb, add_or_set(model->output->grad, ggml_out_prod(ctx0, t33, t34->grad))); assert_shape_2d(model->output->grad, n_embd, n_vocab); + + clr_buf(1); + use_buf(1); + t31->grad = expand(gb, ggml_mul(ctx0, t33->grad, t32)); assert_shape_2d(t31->grad, n_embd, N*n_batch); + + struct ggml_tensor * back_layer_inp = t31; + struct ggml_tensor * grad_layer_inp = NULL; + + for (int k = 0; k < n_layer; ++k) { + int il = n_layer-1-k; + struct my_llama_layer & layer = model->layers[il]; + + struct ggml_tensor * t02 = t02L[il]; + struct ggml_tensor * t03 = t03L[il]; + struct ggml_tensor * t04 = t04L[il]; + struct ggml_tensor * t05 = t05L[il]; + struct ggml_tensor * t06 = t06L[il]; + struct ggml_tensor * t07 = t07L[il]; + struct ggml_tensor * t08 = t08L[il]; + struct ggml_tensor * t09 = t09L[il]; + struct ggml_tensor * t10 = t10L[il]; + struct ggml_tensor * t11 = t11L[il]; + struct ggml_tensor * t12 = t12L[il]; + struct ggml_tensor * t13 = t13L[il]; + struct ggml_tensor * t14 = t14L[il]; + struct ggml_tensor * t15 = t15L[il]; + struct ggml_tensor * t16 = t16L[il]; + struct ggml_tensor * t17 = t17L[il]; + struct ggml_tensor * t18 = t18L[il]; + struct ggml_tensor * t19 = t19L[il]; + struct ggml_tensor * t20 = t20L[il]; + struct ggml_tensor * t21 = t21L[il]; + struct ggml_tensor * t22 = t22L[il]; + struct ggml_tensor * t23 = t23L[il]; + struct ggml_tensor * t24 = t24L[il]; + struct ggml_tensor * t25 = t25L[il]; + struct ggml_tensor * t26 = t26L[il]; + struct ggml_tensor * t27 = t27L[il]; + struct ggml_tensor * t28 = t28L[il]; + struct ggml_tensor * t29 = t29L[il]; + struct ggml_tensor * t30 = t30L[il]; + + clr_buf(0); + use_buf(0); + t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch); + if (grad_layer_inp) { + t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch); + } + clr_buf(1); + t29->grad = t30->grad; assert_shape_2d(t29->grad, n_embd, N*n_batch); + t28->grad = expand(gb, ggml_out_prod(ctx0, layer.w2, ggml_transpose(ctx0, t29->grad))); assert_shape_2d(t28->grad, n_ff, N*n_batch); + t27->grad = expand(gb, ggml_mul(ctx0, t28->grad, t25)); assert_shape_2d(t27->grad, n_ff, N*n_batch); + t26->grad = expand(gb, ggml_silu_back(ctx0, t26, t27->grad)); assert_shape_2d(t26->grad, n_ff, N*n_batch); + t25->grad = expand(gb, ggml_mul(ctx0, t28->grad, t27)); assert_shape_2d(t25->grad, n_ff, N*n_batch); + t24->grad = expand(gb, ggml_add_inplace(ctx0, + ggml_out_prod(ctx0, layer.w1, ggml_transpose(ctx0, t26->grad)), + ggml_out_prod(ctx0, layer.w3, ggml_transpose(ctx0, t25->grad)))); assert_shape_2d(t24->grad, n_embd, N*n_batch); + t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch); + t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch); + use_buf(1); + t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch); + grad_layer_inp = t21; + use_buf(0); + t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch); + t19->grad = expand(gb, ggml_out_prod(ctx0, layer.wo, ggml_transpose(ctx0, t20->grad))); assert_shape_2d(t19->grad, n_embd, N*n_batch); + t18->grad = expand(gb, ggml_reshape_4d(ctx0, t19->grad, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t18->grad, n_embd/n_head, n_head, N, n_batch); + t17->grad = t18->grad; assert_shape_4d(t17->grad, n_embd/n_head, n_head, N, n_batch); + t16->grad = expand(gb, ggml_permute(ctx0, t17->grad, 0, 2, 1, 3)); assert_shape_4d(t16->grad, n_embd/n_head, N, n_head, n_batch); + struct ggml_tensor * flash_attn = expand(gb, ggml_flash_attn_back(ctx0, t13, t14, t15, t16->grad, true)); assert_shape_4d(flash_attn, n_embd/n_head, N*3, n_head, n_batch); + t15->grad = expand(gb, view__v(flash_attn)); assert_shape_4d(t15->grad, N, n_embd/n_head, n_head, n_batch); + t14->grad = expand(gb, view__k(flash_attn)); assert_shape_4d(t14->grad, n_embd/n_head, N, n_head, n_batch); + t13->grad = expand(gb, view__q(flash_attn)); assert_shape_4d(t13->grad, n_embd/n_head, N, n_head, n_batch); + t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head); + t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd); + t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch); + t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch); + t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch); + t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch); + t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch); + t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch); + t04->grad = expand(gb, ggml_add_inplace(ctx0, + ggml_add_inplace(ctx0, + ggml_out_prod(ctx0, layer.wv, t11->grad), + ggml_out_prod(ctx0, layer.wk, ggml_transpose(ctx0, t08->grad))), + ggml_out_prod(ctx0, layer.wq, ggml_transpose(ctx0, t05->grad)))); assert_shape_2d(t04->grad, n_embd, N*n_batch); + t03->grad = expand(gb, ggml_mul(ctx0, t04->grad, t02)); assert_shape_2d(t04->grad, n_embd, N*n_batch); + use_buf(1); + t02->grad = expand(gb, ggml_mul(ctx0, t04->grad, ggml_repeat(ctx0, layer.attention_norm, t02))); assert_shape_2d(t02->grad, n_embd, N*n_batch); + back_layer_inp = t02; + // use_buf(0); + + use_buf(-1); + layer.attention_norm->grad = expand(gb, add_or_set(layer.attention_norm->grad, ggml_repeat_back(ctx0, t03->grad, layer.attention_norm))); assert_shape_1d(layer.attention_norm->grad, n_embd); + layer.wq->grad = expand(gb, add_or_set(layer.wq->grad, ggml_out_prod(ctx0, t04, t05->grad))); assert_shape_2d(layer.wq->grad, n_embd, n_embd); + layer.wk->grad = expand(gb, add_or_set(layer.wk->grad, ggml_out_prod(ctx0, t04, t08->grad))); assert_shape_2d(layer.wk->grad, n_embd, n_embd); + layer.wv->grad = expand(gb, add_or_set(layer.wv->grad, ggml_out_prod(ctx0, t04, ggml_transpose(ctx0, t11->grad)))); assert_shape_2d(layer.wv->grad, n_embd, n_embd); + layer.wo->grad = expand(gb, add_or_set(layer.wo->grad, ggml_out_prod(ctx0, t19, t20->grad))); assert_shape_2d(layer.wo->grad, n_embd, n_embd); + layer.ffn_norm->grad = expand(gb, add_or_set(layer.ffn_norm->grad, ggml_repeat_back(ctx0, t23->grad, layer.ffn_norm))); assert_shape_1d(layer.ffn_norm->grad, n_embd); + layer.w1->grad = expand(gb, add_or_set(layer.w1->grad, ggml_out_prod(ctx0, t24, t26->grad))); assert_shape_2d(layer.w1->grad, n_embd, n_ff); + layer.w2->grad = expand(gb, add_or_set(layer.w2->grad, ggml_out_prod(ctx0, t28, t29->grad))); assert_shape_2d(layer.w2->grad, n_ff, n_embd); + layer.w3->grad = expand(gb, add_or_set(layer.w3->grad, ggml_out_prod(ctx0, t24, t25->grad))); assert_shape_2d(layer.w3->grad, n_embd, n_ff); + // use_buf(0); + } + clr_buf(0); + use_buf(0); + t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch); + use_buf(-1); + model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab); + // clr_buf(1); + // clr_buf(0); + + *logits = t35; + + if (track_max_mem) { + printf("%s: max size compute buf0: %zu\n", __func__, buf_maxs[0]); + printf("%s: max size compute buf1: %zu\n", __func__, buf_maxs[1]); + } + + // now that all grads are created, set the graph leafs and grads + graph_set_leafs_grads(gf); + graph_set_leafs_grads(gb); + + return t36; +} + +void set_f32_3d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int64_t i2, float value) { + float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]); + *ptr = value; +} + +void set_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, float value) { + float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); + *ptr = value; +} + +void set_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1, int32_t value) { + int32_t * ptr = (int32_t *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); + *ptr = value; +} + +float get_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) { + float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); + return *ptr; +} + +int32_t get_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) { + int32_t * ptr = (int32_t *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); + return *ptr; +} + +void print_row(struct ggml_tensor * probs, int i) { + for (int k = 0; k < probs->ne[0]; ++k) { + float p = get_f32_2d(probs, k, i); + printf(" %.2f", p); + } + printf("\n"); +} + +void print_matrix(struct ggml_tensor * probs) { + assert(probs->n_dims == 2); + for (int i = 0; i < probs->ne[1]; ++i) { + for (int k = 0; k < probs->ne[0]; ++k) { + float p = get_f32_2d(probs, k, i); + printf(" %.2f", p); + } + printf("\n"); + } +} + + +void print_token(struct llama_context * ctx, llama_token token) { + printf("%s", llama_token_to_str(ctx, token)); +} + +void print_tokens(struct llama_context* ctx, struct ggml_tensor * tokens) { + for (int i=0; ine[0]; ++i) { + int token = ggml_get_i32_1d(tokens, i); + print_token(ctx, token); + } +} + +void print_tokens_batch(struct llama_context* ctx, struct ggml_tensor * tokens) { + for (int i1=0; i1ne[1]; ++i1) { + //int num_newline = 0; + for (int i0=0; i0ne[0]; ++i0) { + int token = get_i32_2d(tokens, i0, i1); + print_token(ctx, token); + // bool isnl = (token == llama_token_nl()); + // if (isnl) { + // ++num_newline; + // } + // if (isnl) { + // if (num_newline < 2) { + // print_token(ctx, token); + // } else { + // printf("\\n"); + // } + // } else { + // print_token(ctx, token); + // } + } + printf("\n--\n"); + } +} + +void get_example_targets(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) { + int n_tokens = tokens_input->ne[0]; + int n_vocab = target_logits->ne[0]; + + size_t sample = train_samples[example_id % n_train_samples]; + GGML_ASSERT(sample+n_tokens-1 < n_train_data); + + ggml_set_f32(target_logits, -1.0f/n_vocab); + ggml_set_f32(target_probs, 0.0f); + ggml_set_i32_1d(tokens_input, 0, llama_token_bos()); + for (int i=1; in_dims == 2); + GGML_ASSERT(target_logits->n_dims == 3); + GGML_ASSERT(target_probs->n_dims == 3); + int n_vocab = target_logits->ne[0]; + int n_tokens = tokens_input->ne[0]; + int n_batch = tokens_input->ne[1]; + GGML_ASSERT(n_tokens == target_logits->ne[1]); + GGML_ASSERT(n_batch == target_logits->ne[2]); + GGML_ASSERT(n_vocab == target_probs->ne[0]); + GGML_ASSERT(n_tokens == target_probs->ne[1]); + GGML_ASSERT(n_batch == target_probs->ne[2]); + + ggml_set_f32(target_logits, -1.0f/n_vocab); + ggml_set_f32(target_probs, 0.0f); + for (int k=0; kne[0]; + int n_vocab = target_logits->ne[0]; + for (int i=0; i= 0 && size < INT_MAX); + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +struct llama_file { + // use FILE * so we don't have to re-open the file to mmap + FILE * fp; + size_t size; + + llama_file(const char * fname, const char * mode) { + fp = std::fopen(fname, mode); + if (fp == NULL) { + size = 0; + } else { + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + } + + size_t tell() const { +#ifdef _WIN32 + __int64 ret = _ftelli64(fp); +#else + long ret = std::ftell(fp); +#endif + GGML_ASSERT(ret != -1); // this really shouldn't fail + return (size_t) ret; + } + + void seek(size_t offset, int whence) { +#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64) offset, whence); +#else + int ret = std::fseek(fp, (long) offset, whence); +#endif + GGML_ASSERT(ret == 0); // same + } + + void read_raw(void * ptr, size_t size) { + if (size == 0) { + return; + } + errno = 0; + std::size_t ret = std::fread(ptr, size, 1, fp); + if (ferror(fp)) { + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (ret != 1) { + throw std::runtime_error(std::string("unexpectedly reached end of file")); + } + } + + std::uint32_t read_u32() { + std::uint32_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + + std::string read_string(std::uint32_t len) { + std::vector chars(len); + read_raw(chars.data(), len); + return std::string(chars.data(), len); + } + + void write_raw(const void * ptr, size_t size) { + if (size == 0) { + return; + } + errno = 0; + size_t ret = std::fwrite(ptr, size, 1, fp); + if (ret != 1) { + throw std::runtime_error(format("write error: %s", strerror(errno))); + } + } + + void write_u32(std::uint32_t val) { + write_raw(&val, sizeof(val)); + } + + ~llama_file() { + if (fp) { + std::fclose(fp); + } + } +}; + +int tokenize_file(struct llama_context * lctx, const char * filename, std::vector& out) { + struct llama_file f(filename, "rb"); + + std::vector buf; + buf.resize(f.size+1); + + f.read_raw(buf.data(), f.size); + buf[f.size] = '\0'; + + out.resize(buf.size()); + + int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), buf.size(), false); + if (n_tokens >= 0) { + out.resize(n_tokens); + } + + bool verify = false; + if (verify) { + const char * in = buf.data(); + const char * end = buf.data() + buf.size(); + for (int i = 0; i < (int) out.size(); ++i) { + const char * s = llama_token_to_str(lctx, out[i]); + int len = strlen(s); + if (in >= end) { + printf("%s: unexpected end of original text.\n", __func__); + break; + } + const bool matches = (strncmp(in, s, len) == 0); + if (matches) { + in += len; + } else { + printf("%s: mismatch: expected '%s', but got '%s'\n", __func__, std::string(in, len).c_str(), s); + } + } + } + + return n_tokens; +} + +void shuffle_ints(int * begin, int * end) { + if (end <= begin) return; + int max=begin[0]; + for (int i=1; i max) { + max = begin[i]; + } + } + std::vector vals; + vals.resize(max+1); + for (int i=0; i candidates; + llama_token_data_array candidates_p; + +}; + +void init_sampler(struct my_llama_sampler * sampler, struct llama_context * ctx) { + sampler->ctx = ctx; + sampler->n_vocab = llama_n_vocab(sampler->ctx); + sampler->n_ctx = llama_n_ctx(sampler->ctx); + sampler->mirostat_mu = 2.0f * sampler->params.mirostat_tau; +} + +llama_token sample(struct my_llama_sampler * sampler, float * logits, const llama_token * last_tokens, int n_last_tokens) { + GGML_ASSERT(sampler->ctx != NULL); + + struct llama_context * ctx = sampler->ctx; + + sampler->candidates.resize(sampler->n_vocab); + for (llama_token token_id = 0; token_id < sampler->n_vocab; ++token_id) { + sampler->candidates[token_id].id = token_id; + sampler->candidates[token_id].logit = logits[token_id]; + sampler->candidates[token_id].p = 0.0; + } + + llama_token_data_array * candidates_p = & sampler->candidates_p; + + candidates_p->data = sampler->candidates.data(); + candidates_p->size = sampler->candidates.size(); + candidates_p->sorted = false; + + const auto params = sampler->params; + + // Apply penalties + const float nl_logit = logits[llama_token_nl()]; + + const int n_last = std::min(std::min(n_last_tokens, params.repeat_last_n), sampler->n_ctx); + + llama_sample_repetition_penalty( + ctx, + candidates_p, + last_tokens + n_last_tokens - n_last, + n_last, + params.repeat_penalty); + llama_sample_frequency_and_presence_penalties( + ctx, + candidates_p, + last_tokens + n_last_tokens - n_last, + n_last, + params.alpha_frequency, + params.alpha_presence); + + if (!params.penalize_nl) { + logits[llama_token_nl()] = nl_logit; + } + + llama_token token = 0; + if (params.temp <= 0) { + // Greedy sampling + token = llama_sample_token_greedy(ctx, candidates_p); + } else { + if (params.mirostat == 1) { + int mirostat_m = 100; + llama_sample_temperature(ctx, candidates_p, params.temp); + token = llama_sample_token_mirostat(ctx, candidates_p, params.mirostat_tau, params.mirostat_eta, mirostat_m, &sampler->mirostat_mu); + } else if (params.mirostat == 2) { + llama_sample_temperature(ctx, candidates_p, params.temp); + token = llama_sample_token_mirostat_v2(ctx, candidates_p, params.mirostat_tau, params.mirostat_eta, &sampler->mirostat_mu); + } else { + // Temperature sampling + llama_sample_top_k (ctx, candidates_p, params.top_k, 1); + llama_sample_tail_free (ctx, candidates_p, params.tfs_z, 1); + llama_sample_typical (ctx, candidates_p, params.typical_p, 1); + + llama_sample_top_p (ctx, candidates_p, params.top_p, 1); + llama_sample_temperature (ctx, candidates_p, params.temp); + token = llama_sample_token(ctx, candidates_p); + } + } + return token; +} + +void set_logits_masked(struct ggml_tensor * logits, std::vector& mask, float value) { + GGML_ASSERT(logits->ne[0] == (int64_t) mask.size()); + for (int i2 = 0; i2 < logits->ne[2]; ++i2) { + for (int i1 = 0; i1 < logits->ne[1]; ++i1) { + for (int i0 = 0; i0 < logits->ne[0]; ++i0) { + if (!mask[i0]) continue; + float * ptr = (float *) ((char *) logits->data + i2*logits->nb[2] + i1*logits->nb[1] + i0*logits->nb[0]); + *ptr = value; + } + } + } +} + +void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) { + if (tensor == NULL) { + file->write_u32(0); + file->write_u32(0); + file->write_u32(GGML_TYPE_F32); + file->seek(-file->tell() & 31, SEEK_CUR); + return; + } + const char * name = ggml_get_name(tensor); + uint32_t name_len = strlen(name); + uint32_t nd = tensor->n_dims; + uint32_t ne[4] = { (uint32_t)tensor->ne[0], + (uint32_t)tensor->ne[1], + (uint32_t)tensor->ne[2], + (uint32_t)tensor->ne[3] }; + file->write_u32(nd); + file->write_u32(name_len); + file->write_u32(tensor->type); + file->write_raw(ne, sizeof(ne[0]) * nd); + file->write_raw(name, name_len); + file->seek(-file->tell() & 31, SEEK_CUR); + file->write_raw(tensor->data, ggml_nbytes(tensor)); +} + +void read_tensor(struct llama_file * file, struct ggml_tensor * tensor) { + int32_t nd = file->read_u32(); + GGML_ASSERT(nd == tensor->n_dims); + + uint32_t name_len = file->read_u32(); + enum ggml_type type = (enum ggml_type) file->read_u32(); + GGML_ASSERT(type == tensor->type); + + uint32_t ne[4]; + file->read_raw(ne, sizeof(ne[0]) * nd); + for (int i=0; ine[i]); + } + + std::string name = file->read_string(name_len); + GGML_ASSERT(strncmp(ggml_get_name(tensor), name.c_str(), sizeof(tensor->name)-1) == 0); + + file->seek(-file->tell() & 31, SEEK_CUR); + file->read_raw(tensor->data, ggml_nbytes(tensor)); +} + +void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt) { + const uint32_t version = 0; + GGML_ASSERT(opt->nx >= 0); + GGML_ASSERT(opt->iter >= 0); + file->write_u32(version); + file->write_raw(&opt->params, sizeof(opt->params)); + file->write_raw(&opt->nx, sizeof(opt->nx)); + file->write_raw(&opt->iter, sizeof(opt->iter)); + file->write_u32((uint32_t) opt->just_initialized); + switch (opt->params.type) { + case GGML_OPT_ADAM: + { + GGML_ASSERT(opt->adam.x != NULL); + write_tensor(file, opt->adam.x); + write_tensor(file, opt->adam.g1); + write_tensor(file, opt->adam.g2); + write_tensor(file, opt->adam.m); + write_tensor(file, opt->adam.v); + write_tensor(file, opt->adam.mh); + write_tensor(file, opt->adam.vh); + write_tensor(file, opt->adam.pf); + file->write_raw(&opt->adam.fx_best, sizeof(opt->adam.fx_best)); + file->write_raw(&opt->adam.fx_prev, sizeof(opt->adam.fx_prev)); + file->write_raw(&opt->adam.n_no_improvement, sizeof(opt->adam.n_no_improvement)); + } break; + case GGML_OPT_LBFGS: + { + GGML_ASSERT(opt->adam.x != NULL); + write_tensor(file, opt->lbfgs.x); + write_tensor(file, opt->lbfgs.xp); + write_tensor(file, opt->lbfgs.g); + write_tensor(file, opt->lbfgs.gp); + write_tensor(file, opt->lbfgs.d); + write_tensor(file, opt->lbfgs.pf); + write_tensor(file, opt->lbfgs.lmal); + write_tensor(file, opt->lbfgs.lmys); + write_tensor(file, opt->lbfgs.lms); + write_tensor(file, opt->lbfgs.lmy); + file->write_raw(&opt->lbfgs.fx_best, sizeof(opt->lbfgs.fx_best)); + file->write_raw(&opt->lbfgs.step, sizeof(opt->lbfgs.step)); + file->write_raw(&opt->lbfgs.j, sizeof(opt->lbfgs.j)); + file->write_raw(&opt->lbfgs.k, sizeof(opt->lbfgs.k)); + file->write_raw(&opt->lbfgs.end, sizeof(opt->lbfgs.end)); + file->write_raw(&opt->lbfgs.n_no_improvement, sizeof(opt->lbfgs.n_no_improvement)); + } break; + } +} + +void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) { + uint32_t version = file->read_u32(); + GGML_ASSERT(version == 0); + + file->read_raw(&opt->params, sizeof(opt->params)); + file->read_raw(&opt->nx, sizeof(opt->nx)); + ggml_opt_init(ctx, opt, opt->params, opt->nx); + + file->read_raw(&opt->iter, sizeof(opt->iter)); + opt->just_initialized = (bool) file->read_u32(); + + switch (opt->params.type) { + case GGML_OPT_ADAM: + { + read_tensor(file, opt->adam.x); + read_tensor(file, opt->adam.g1); + read_tensor(file, opt->adam.g2); + read_tensor(file, opt->adam.m); + read_tensor(file, opt->adam.v); + read_tensor(file, opt->adam.mh); + read_tensor(file, opt->adam.vh); + if (opt->adam.pf) { read_tensor(file, opt->adam.pf); } + file->read_raw(&opt->adam.fx_best, sizeof(opt->adam.fx_best)); + file->read_raw(&opt->adam.fx_prev, sizeof(opt->adam.fx_prev)); + file->read_raw(&opt->adam.n_no_improvement, sizeof(opt->adam.n_no_improvement)); + } break; + case GGML_OPT_LBFGS: + { + GGML_ASSERT(opt->adam.x != NULL); + read_tensor(file, opt->lbfgs.x); + read_tensor(file, opt->lbfgs.xp); + read_tensor(file, opt->lbfgs.g); + read_tensor(file, opt->lbfgs.gp); + read_tensor(file, opt->lbfgs.d); + if (opt->lbfgs.pf) { read_tensor(file, opt->lbfgs.pf); } + read_tensor(file, opt->lbfgs.lmal); + read_tensor(file, opt->lbfgs.lmys); + read_tensor(file, opt->lbfgs.lms); + read_tensor(file, opt->lbfgs.lmy); + file->read_raw(&opt->lbfgs.fx_best, sizeof(opt->lbfgs.fx_best)); + file->read_raw(&opt->lbfgs.step, sizeof(opt->lbfgs.step)); + file->read_raw(&opt->lbfgs.j, sizeof(opt->lbfgs.j)); + file->read_raw(&opt->lbfgs.k, sizeof(opt->lbfgs.k)); + file->read_raw(&opt->lbfgs.end, sizeof(opt->lbfgs.end)); + file->read_raw(&opt->lbfgs.n_no_improvement, sizeof(opt->lbfgs.n_no_improvement)); + } break; + } +} + +void save_checkpoint(struct my_llama_model * model, struct ggml_opt_context * opt, const char * filename) { + struct llama_file file(filename, "wb"); + if (file.fp == NULL) { + return; + } + + const uint32_t magic = 'ggcp'; + const uint32_t version = 0; + + file.write_u32(magic); + file.write_u32(version); + file.write_u32(model->train_its); + file.write_u32(model->train_samples); + file.write_u32(model->train_tokens); + file.write_u32(model->hparams.n_vocab); + file.write_u32(model->hparams.n_embd); + file.write_u32(model->hparams.n_mult); + file.write_u32(model->hparams.n_head); + file.write_u32(model->hparams.n_layer); + file.write_u32(model->hparams.n_rot); + + write_tensor(&file, model->tok_embeddings); + write_tensor(&file, model->norm); + write_tensor(&file, model->output); + + for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { + auto & layer = model->layers[i]; + + write_tensor(&file, layer.attention_norm); + write_tensor(&file, layer.wq); + write_tensor(&file, layer.wk); + write_tensor(&file, layer.wv); + write_tensor(&file, layer.wo); + write_tensor(&file, layer.ffn_norm); + write_tensor(&file, layer.w1); + write_tensor(&file, layer.w2); + write_tensor(&file, layer.w3); + } + + write_opt_context(&file, opt); +} + +bool load_checkpoint(struct my_llama_model * model, struct ggml_opt_context * opt, const char * filename, bool init) { + struct llama_file file(filename, "rb"); + + uint32_t magic; + uint32_t version; + + uint32_t train_its = 0; + uint32_t train_samples = 0; + uint32_t train_tokens = 0; + + if (file.fp) { + printf("%s: Loading model from '%s'.\n", __func__, filename); + magic = file.read_u32(); + GGML_ASSERT(magic == 'ggcp'); + version = file.read_u32(); + GGML_ASSERT(version == 0); + train_its = file.read_u32(); + train_samples = file.read_u32(); + train_tokens = file.read_u32(); + model->hparams.n_vocab = file.read_u32(); + model->hparams.n_embd = file.read_u32(); + model->hparams.n_mult = file.read_u32(); + model->hparams.n_head = file.read_u32(); + model->hparams.n_layer = file.read_u32(); + model->hparams.n_rot = file.read_u32(); + print_params(&model->hparams); + } + + if (init) { + init_model(model); + } + + if (file.fp) { + model->train_its = train_its; + model->train_samples = train_samples; + model->train_tokens = train_tokens; + } + + printf("%s: Training iterations: %u.\n", __func__, model->train_its); + printf("%s: Training samples: %u.\n", __func__, model->train_samples); + printf("%s: Training tokens: %u.\n", __func__, model->train_tokens); + + if (file.fp) { + read_tensor(&file, model->tok_embeddings); + read_tensor(&file, model->norm); + read_tensor(&file, model->output); + + for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { + auto & layer = model->layers[i]; + + read_tensor(&file, layer.attention_norm); + read_tensor(&file, layer.wq); + read_tensor(&file, layer.wk); + read_tensor(&file, layer.wv); + read_tensor(&file, layer.wo); + read_tensor(&file, layer.ffn_norm); + read_tensor(&file, layer.w1); + read_tensor(&file, layer.w2); + read_tensor(&file, layer.w3); + } + + read_opt_context(&file, model->ctx, opt); + } + + return (file.fp != NULL); +} + +void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * model, const char * filename) { + struct llama_file file(filename, "wb"); + if (file.fp == NULL) { + return; + } + + // write_magic + file.write_u32(LLAMA_FILE_MAGIC); // magic + file.write_u32(LLAMA_FILE_VERSION); // version + // write_hparams + file.write_u32(model->hparams.n_vocab); + file.write_u32(model->hparams.n_embd); + file.write_u32(model->hparams.n_mult); + file.write_u32(model->hparams.n_head); + file.write_u32(model->hparams.n_layer); + file.write_u32(model->hparams.n_rot); + file.write_u32(LLAMA_FTYPE_ALL_F32); + // write_vocab + uint32_t n_vocab = model->hparams.n_vocab; + for (uint32_t i = 0; i < n_vocab; i++) { + const auto & token_score = vocab->id_to_token.at(i); + file.write_u32((uint32_t) token_score.tok.size()); + file.write_raw(token_score.tok.data(), token_score.tok.size()); + file.write_raw(&token_score.score, sizeof(token_score.score)); + } + // write tensors + write_tensor(&file, model->tok_embeddings); + write_tensor(&file, model->norm); + write_tensor(&file, model->output); + for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { + auto & layer = model->layers[i]; + + write_tensor(&file, layer.attention_norm); + write_tensor(&file, layer.wq); + write_tensor(&file, layer.wk); + write_tensor(&file, layer.wv); + write_tensor(&file, layer.wo); + write_tensor(&file, layer.ffn_norm); + write_tensor(&file, layer.w1); + write_tensor(&file, layer.w2); + write_tensor(&file, layer.w3); + } +} + +float cosine_decay(const int decay_steps, const float alpha, int step) { + if (step > decay_steps) { + step = decay_steps; + } + const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps)); + const float decay = (1 - alpha)*cosine_decay + alpha; + return decay; +} + +float cosine_decay_restart(int decay_steps, const float alpha, int step, float restart_step_mult) { + while (step > decay_steps) { + step -= decay_steps; + decay_steps = (int) restart_step_mult * decay_steps; + } + return cosine_decay(decay_steps, alpha, step); +} + +struct train_params { + const char * fn_vocab_model; + const char * fn_train_data; + const char * fn_checkpoint_in; + const char * fn_checkpoint_out; + const char * fn_model_out; + + int seed; + int n_ctx; + int n_embd; + int n_mult; + int n_head; + int n_layer; + int n_rotmax; + + int n_threads; + int n_batch; + int n_examples; + int n_predict; + + int print_info_interval; + int print_details_interval; + + bool samples_start_after_nl; + bool use_adam; + bool use_flash; + bool use_scratch; + + // only adam + int warmup; + int cos_decay_steps; + float cos_decay_restart; + float cos_decay_alpha; + + int lbfgs_n_iter; + int adam_n_iter; + float adam_alpha; + float adam_decay; + + int mem_model_gb; + int mem_compute_gb; + int mem_compute0_gb; + int mem_compute1_gb; +}; + +struct train_params get_default_train_params() { + struct train_params params; + params.fn_vocab_model = "ggml-vic7b-uncensored-q4_0.bin"; + params.fn_train_data = "shakespeare.txt"; + params.fn_checkpoint_in = "checkpoint.bin"; + params.fn_checkpoint_out = "checkpoint.bin"; + params.fn_model_out = "ggml-checkpoint-f32.bin"; + + params.seed = -1; + + params.n_ctx = 128; + params.n_embd = 256; + params.n_mult = 256; + params.n_head = 8; + params.n_layer = 16; + params.n_rotmax = 64; + + params.n_threads = 6; + params.n_batch = 8; + params.n_examples = 8; + params.n_predict = 1024; + + params.print_info_interval = 1; + params.print_details_interval = 2; + + params.samples_start_after_nl = false; + params.use_adam = true; + params.use_flash = true; + params.use_scratch = true; + + // only adam + params.warmup = 100; + params.cos_decay_steps = 1000; + params.cos_decay_restart = 1.1f; + params.cos_decay_alpha = 0.0f; + + params.lbfgs_n_iter = 16; + params.adam_n_iter = 16; + params.adam_alpha = 1e-3; + params.adam_decay = 1e-3; + + params.mem_model_gb = 2; + params.mem_compute_gb = 24; + params.mem_compute0_gb = 8; + params.mem_compute1_gb = 2; + + return params; +} + +void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) { + fprintf(stderr, "usage: %s [options]\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " --vocab-model FNAME model path from which to load vocab (default '%s')\n", params->fn_vocab_model); + fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data); + fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in); + fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out); + fprintf(stderr, " --model-out FNAME path to save ggml model (default '%s')\n", params->fn_model_out); + fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); + fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx); + fprintf(stderr, " --embd N Embedding size used for new models (default %d)\n", params->n_embd); + fprintf(stderr, " --mult N Mult size used for new models, influences feedforward size. (default %d)\n", params->n_mult); + fprintf(stderr, " --head N Number of heads for new models (default %d)\n", params->n_head); + fprintf(stderr, " --layer N Number of layers for new models (default %d)\n", params->n_layer); + fprintf(stderr, " --rotmax N Maximal number Rope dimensions for new models (default %d)\n", params->n_rotmax); + fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads); + fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch); + fprintf(stderr, " -n N, --examples N Number of examples to train (default %d)\n", params->n_examples); + fprintf(stderr, " --predict N Number of tokens to generate after training (default %d)\n", params->n_predict); + fprintf(stderr, " --print-info-interval N Print infos during training each N examples (default %d)\n", params->print_info_interval); + fprintf(stderr, " --print-details-interval N Print details during training each N examples (default %d)\n", params->print_details_interval); + fprintf(stderr, " --samples-after-nl Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off"); + fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n"); + fprintf(stderr, " --use-adam Use Adam optimizer (default)\n"); + fprintf(stderr, " --no-flash Don't use flash attention.\n"); + fprintf(stderr, " --use-flash Use flash attention (default)\n"); + fprintf(stderr, " --no-scratch Don't use scratch buffers\n"); + fprintf(stderr, " --use-scratch Use scratch buffers (default)\n"); + fprintf(stderr, " --warmup N Number of warmup steps (default %d)\n", params->warmup); + fprintf(stderr, " --cos-decay-steps N Number of cosine decay steps (default %d)\n", params->cos_decay_steps); + fprintf(stderr, " --cos-decay-restart N Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart); + fprintf(stderr, " --cos-decay-alpha N Cosine decay alpha (default %f)\n", params->cos_decay_alpha); + fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter); + fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter); + fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha); + fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay); + fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb); + fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb); + fprintf(stderr, " --mem-compute0 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb); + fprintf(stderr, " --mem-compute1 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute1_gb); + fprintf(stderr, "\n"); +} + +bool train_params_parse(int argc, char ** argv, struct train_params * params) { + bool invalid_param = false; + std::string arg; + struct train_params default_params = get_default_train_params(); + const std::string arg_prefix = "--"; + + for (int i = 1; i < argc; i++) { + arg = argv[i]; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); + } + + if (arg == "--vocab-model") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->fn_vocab_model = argv[i]; + } else if (arg == "--train-data") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->fn_train_data = argv[i]; + } else if (arg == "--checkpoint-in") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->fn_checkpoint_in = argv[i]; + } else if (arg == "--checkpoint-out") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->fn_checkpoint_out = argv[i]; + } else if (arg == "--model-out") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->fn_model_out = argv[i]; + } else if (arg == "-s" || arg == "--seed") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->seed = std::stoi(argv[i]); + } else if (arg == "-c" || arg == "--ctx") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->n_ctx = std::stoi(argv[i]); + } else if (arg == "--embd") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->n_embd = std::stoi(argv[i]); + } else if (arg == "--mult") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->n_mult = std::stoi(argv[i]); + } else if (arg == "--head") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->n_head = std::stoi(argv[i]); + } else if (arg == "--layer") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->n_layer = std::stoi(argv[i]); + } else if (arg == "--rotmax") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->n_rotmax = std::stoi(argv[i]); + } else if (arg == "-t" || arg == "--threads") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->n_threads = std::stoi(argv[i]); + } else if (arg == "-b" || arg == "--batch") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->n_batch = std::stoi(argv[i]); + } else if (arg == "-n" || arg == "--examples") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->n_examples = std::stoi(argv[i]); + } else if (arg == "--predict") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->n_predict = std::stoi(argv[i]); + } else if (arg == "--print-info-interval") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->print_info_interval = std::stoi(argv[i]); + } else if (arg == "--print-details-interval") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->print_details_interval = std::stoi(argv[i]); + } else if (arg == "--samples-after-nl") { + params->samples_start_after_nl = true; + } else if (arg == "--use-lbfgs") { + params->use_adam = false; + } else if (arg == "--use-adam") { + params->use_adam = true; + } else if (arg == "--no-flash") { + params->use_flash = false; + } else if (arg == "--use-flash") { + params->use_flash = true; + } else if (arg == "--no-scratch") { + params->use_scratch = false; + } else if (arg == "--use-scratch") { + params->use_scratch = true; + } else if (arg == "--warmup") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->warmup = std::stoi(argv[i]); + } else if (arg == "--cos-decay-steps") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->cos_decay_steps = std::stof(argv[i]); + } else if (arg == "--cos-decay-restart") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->cos_decay_restart = std::stof(argv[i]); + } else if (arg == "--cos-decay-alpha") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->cos_decay_alpha = std::stof(argv[i]); + } else if (arg == "--lbfgs-iter") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->lbfgs_n_iter = std::stoi(argv[i]); + } else if (arg == "--adam-iter") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->adam_n_iter = std::stoi(argv[i]); + } else if (arg == "--adam-alpha") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->adam_alpha = std::stof(argv[i]); + } else if (arg == "--adam-decay") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->adam_decay = std::stof(argv[i]); + } else if (arg == "--mem-model") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->mem_model_gb = std::stoi(argv[i]); + } else if (arg == "--mem-compute") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->mem_compute_gb = std::stoi(argv[i]); + } else if (arg == "--mem-compute0") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->mem_compute0_gb = std::stoi(argv[i]); + } else if (arg == "--mem-compute1") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->mem_compute1_gb = std::stoi(argv[i]); + } else if (arg == "-h" || arg == "--help") { + train_print_usage(argc, argv, &default_params); + exit(0); + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + train_print_usage(argc, argv, &default_params); + exit(1); + } + } + if (invalid_param) { + fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); + train_print_usage(argc, argv, &default_params); + exit(1); + } + + return true; +} + +int main(int argc, char ** argv) { + struct train_params params = get_default_train_params(); + + if (!train_params_parse(argc, argv, ¶ms)) { + return 1; + } + + if (params.seed < 0) { + params.seed = time(NULL); + } + printf("%s: seed: %d\n", __func__, params.seed); + srand(params.seed); + + struct llama_context_params llama_params = llama_context_default_params(); + llama_params.vocab_only = true; + + struct llama_context * lctx = llama_init_from_file(params.fn_vocab_model, llama_params); + + struct llama_vocab vocab; + { + std::vector strings; + std::vector scores; + int n_vocab = llama_n_vocab(lctx); + strings.resize(n_vocab, NULL); + scores.resize(n_vocab, 0); + n_vocab = llama_get_vocab(lctx, strings.data(), scores.data(), n_vocab); + GGML_ASSERT(n_vocab == llama_n_vocab(lctx)); + vocab.id_to_token.resize(n_vocab); + for (int i=0; i train_tokens; + if (tokenize_file(lctx, params.fn_train_data, train_tokens) < 0) { + fprintf(stderr, "%s: failed to tokenize file '%s'\n", __func__, params.fn_train_data); + } + printf("%s: number of training tokens: %d\n", __func__, (int) train_tokens.size()); + + struct my_llama_model model; + model.hparams.n_vocab = llama_n_vocab(lctx); + model.hparams.n_ctx = params.n_ctx; + model.hparams.n_embd = params.n_embd; + model.hparams.n_mult = params.n_mult; + model.hparams.n_head = params.n_head; + model.hparams.n_layer = params.n_layer; + model.hparams.n_rot = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head); + + print_params(&model.hparams); + + std::vector token_noccurs; + std::vector token_notavail; + token_noccurs.resize(model.hparams.n_vocab, 0); + token_notavail.resize(model.hparams.n_vocab, true); + for (int i = 0; i < (int) train_tokens.size(); ++i) { + ++token_noccurs[train_tokens[i]]; + token_notavail[train_tokens[i]] = false; + } + + std::vector token_freq; + token_freq.resize(model.hparams.n_vocab, 0); + int n_unique_tokens = 0; + for (int i = 0; i < (int) token_noccurs.size(); ++i) { + token_freq[i] = (float) token_noccurs[i] / (float) train_tokens.size(); + n_unique_tokens += (token_noccurs[i] > 0) ? 1 : 0; + } + printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens); + + struct my_llama_kv_cache kv_self; + + + struct ggml_init_params lcparams; + lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb); + lcparams.mem_buffer = NULL; + lcparams.no_alloc = false; + + model.ctx = ggml_init(lcparams); + kv_self.ctx = model.ctx; + + my_llama_sampler sampler; + + + int n_tokens = model.hparams.n_ctx; + int n_vocab = model.hparams.n_vocab; + int n_batch = params.n_batch; + + struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context)); + memset(opt, 0, sizeof(struct ggml_opt_context)); + + struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM); + struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS); + opt_params_adam.print_forward_graph = false; + opt_params_adam.print_backward_graph = false; + opt_params_adam.n_threads = params.n_threads; + opt_params_adam.adam.n_iter = params.adam_n_iter; + opt_params_adam.adam.sched = 1.0f; + opt_params_adam.adam.alpha = params.adam_alpha; + opt_params_adam.adam.decay = params.adam_decay; + + opt_params_lbfgs.print_forward_graph = false; + opt_params_lbfgs.print_backward_graph = false; + opt_params_lbfgs.n_threads = params.n_threads; + opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter; + + opt->ctx = model.ctx; + opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs; + + printf("%s: init model\n", __func__); + bool existed = load_checkpoint(&model, opt, params.fn_checkpoint_in, true); + set_param_model(&model); + + opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs; + + opt->iter = model.train_its; + printf("%s: opt iter %d\n", __func__, opt->iter); + + bool from_scratch = !existed; + if (from_scratch) { + randomize_model(&model, params.seed, 0.0f, 1.0f, -1.0f, +1.0f); + } + + init_kv_cache(&kv_self, &model, 1); + // init_kv_cache(&kv_self, &model, n_batch); + init_sampler(&sampler, lctx); + + printf("used_mem model+cache: %zu bytes\n", ggml_used_mem(model.ctx)); + // ggml_print_tensor_objects(model.ctx); + + size_t compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb); + uint8_t * compute_addr = new uint8_t[compute_size]; + + size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb); + size_t size_buf_1 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute1_gb); + uint8_t * compute_buf_0 = new uint8_t[size_buf_0]; + uint8_t * compute_buf_1 = new uint8_t[size_buf_1]; + + GGML_ASSERT(n_tokens < (int) train_tokens.size()); + std::vector train_samples; + train_samples.push_back(0); + for (int i = 1; i < (int) train_tokens.size() - n_tokens; ++i) { + if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl())) { + train_samples.push_back(i); + } + } + shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size()); + for (int i = 0; i < (int) train_samples.size(); ++i) { + GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size()); + } + + printf("%s: begin training\n", __func__); + + for (int ex = 0; ex < params.n_examples; ++ex) { + if (ex*n_batch >= (int) train_samples.size()) { + shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size()); + for (int i = 0; i < (int) train_samples.size(); ++i) { + GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size()); + } + } + + struct ggml_init_params cparams = { + /*.mem_size =*/ compute_size, + /*.mem_buffer =*/ compute_addr, + /*.no_alloc =*/ false, + }; + struct ggml_context * ctx0 = ggml_init(cparams); + + struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch); + //struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); + struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch); + struct ggml_tensor * target_logits = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); + struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); + + int n_past = 0; + + struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0)); + struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0)); + + memset(gfbuf->data, 0, ggml_nbytes(gfbuf)); + memset(gbbuf->data, 0, ggml_nbytes(gbbuf)); + + struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data; + struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data; + + // ggml_cgraph gf = {}; + gf->n_threads = params.n_threads; + gb->n_threads = params.n_threads; + + get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs); + + GGML_ASSERT(n_past == 0); + + struct ggml_tensor * loss = NULL; + struct ggml_tensor * logits = NULL; + + if (params.use_scratch) { + loss = forward_batch_wo_cache_flash_attn_train( + &model, ctx0, + gf, gb, + &logits, tokens_input, target_probs, + compute_buf_0, compute_buf_1, + size_buf_0, size_buf_1, + n_tokens, n_batch); + } else if (params.use_flash) { + logits = forward_batch_wo_cache_flash_attn(&model, ctx0, gf, tokens_input, n_tokens, n_batch); + loss = cross_entropy_loss(ctx0, logits, target_probs); + ggml_build_forward_expand(gf, loss); + *gb = ggml_build_backward(ctx0, gf, true); + } else { + logits = forward_batch_wo_cache(&model, ctx0, gf, tokens_input, n_tokens, n_batch); + loss = cross_entropy_loss(ctx0, logits, target_probs); + ggml_build_forward_expand(gf, loss); + *gb = ggml_build_backward(ctx0, gf, true); + } + + ggml_graph_compute(ctx0, gf); + + size_t used_mem_before_opt = ggml_used_mem(ctx0); + + float error_before_opt = ggml_get_f32_1d(loss, 0); + + opt->params.adam.sched = (opt->iter < params.warmup) + ? (float) opt->iter / (float) params.warmup + : cosine_decay_restart( + params.cos_decay_steps, + params.cos_decay_alpha, + opt->iter - params.warmup, + params.cos_decay_restart); + + printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched); + + ggml_opt_resume_g(ctx0, opt, loss, gf, gb); + + size_t used_mem_after_opt = ggml_used_mem(ctx0); + + model.train_its = opt->iter; + model.train_samples += n_batch; + model.train_tokens += n_batch * n_tokens; + + ggml_graph_compute(ctx0, gf); + + float error_after_opt = ggml_get_f32_1d(loss, 0); + + if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) { + printf("Example %d, opt iter %d\n", ex, opt->iter); + printf("error_before_opt: %.6f\n", error_before_opt); + printf("error_after_opt: %.6f\n", error_after_opt); + printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt); + printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt); + } + + if (params.print_details_interval > 0 && ex % params.print_details_interval == 0) { + // set_logits_masked(logits, token_notavail, -1e9); + for (int i=0; idata + i*logits->nb[2] + k*logits->nb[1]), + (llama_token *) ((char *) tokens_input->data + i*tokens_input->nb[1]), + k); + * ((int32_t *) ((char *) after_opt_best_samples->data + i*after_opt_best_samples->nb[1] + k*after_opt_best_samples->nb[0])) = token; + } + } + + // printf("probabilities after optimization:\n"); + // print_matrix(after_opt_probs); + printf("Example:\n---\n"); + print_tokens_batch(lctx, tokens_input); + printf("\n---\n"); + + // printf("best samples after optimization:\n---\n"); + printf("samples after optimization:\n---\n"); + print_tokens_batch(lctx, after_opt_best_samples); + printf("\n---\n"); + } + + ggml_free(ctx0); + } + + if (params.n_examples > 0) { + save_checkpoint(&model, opt, params.fn_checkpoint_out); + } + + if (strlen(params.fn_model_out) > 0) { + save_as_llama_model(&vocab, &model, params.fn_model_out); + } + + { + int n_gen = params.n_predict; + int sample_ctx = n_tokens - n_tokens/8; + + sampler.params.temp = 0.2; + sampler.params.repeat_penalty = 1.1; + sampler.params.mirostat = 2; + init_sampler(&sampler, lctx); + + printf("Generating %d tokens.\n", n_gen); + + struct ggml_tensor * tokens_input = ggml_new_tensor_1d(model.ctx, GGML_TYPE_I32, n_tokens); + struct ggml_tensor * target_logits = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens); + struct ggml_tensor * target_probs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens); + + get_example_targets(train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), rand()%train_samples.size(), tokens_input, target_logits, target_probs); + for (int i=sample_ctx; idata + (sample_ctx-1)*logits->nb[1]), + (llama_token *) tokens_input->data, + sample_ctx-1); + //int token = ggml_get_i32_1d(best_samples, sample_ctx-1); + + // print_row(probs, sample_at); + print_token(lctx, token); + + lshift_examples(tokens_input, target_logits, target_probs, 1); + ggml_set_i32_1d(tokens_input, 0, 0); + ggml_set_i32_1d(tokens_input, sample_ctx-1, token); + + ggml_free(ctx0); + } + } + + delete[] compute_addr; + delete[] compute_buf_0; + delete[] compute_buf_1; + ggml_free(model.ctx); + + return 0; +} diff --git a/ggml.c b/ggml.c index 252edd582..32c191307 100644 --- a/ggml.c +++ b/ggml.c @@ -3603,6 +3603,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "SUM_ROWS", "MEAN", "REPEAT", + "REPEAT_BACK", "ABS", "SGN", "NEG", @@ -3616,6 +3617,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "RMS_NORM_BACK", "MUL_MAT", + "OUT_PROD", "SCALE", "SET", @@ -3631,6 +3633,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "DIAG_MASK_INF", "DIAG_MASK_ZERO", "SOFT_MAX", + "SOFT_MAX_BACK", "ROPE", "ROPE_BACK", "ALIBI", @@ -3640,13 +3643,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "FLASH_ATTN", "FLASH_FF", + "FLASH_ATTN_BACK", "MAP_UNARY", "MAP_BINARY", + + "CROSS_ENTROPY_LOSS", + "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51"); - +static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3665,6 +3671,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "Σx_k", "Σx/n", "repeat(x)", + "repeat_back(x)", "abs(x)", "sgn(x)", "-x", @@ -3677,6 +3684,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rms_norm(x)", "rms_norm_back(x)", + "X*Y", "X*Y", "x*v", @@ -3693,6 +3701,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "diag_mask_inf(x)", "diag_mask_zero(x)", "soft_max(x)", + "soft_max_back(x)", "rope(x)", "rope_back(x)", "alibi(x)", @@ -3702,12 +3711,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "flash_attn(x)", "flash_ff(x)", + "flash_attn_back(x)", "f(x)", "f(x,y)", + + "cross_entropy_loss(x,y)", + "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51"); +static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); @@ -3870,6 +3883,15 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct (t0->ne[3] == t1->ne[3]); } +static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[1] == t1->ne[1]) && + (t0->ne[2] == t1->ne[2]) && + (t0->ne[3] == t1->ne[3]); +} + bool ggml_is_quantized(enum ggml_type type) { return GGML_IS_QUANTIZED[type]; } @@ -4693,7 +4715,7 @@ struct ggml_tensor * ggml_add_impl( bool is_node = false; - if (!inplace && (a->grad || b->grad)) { + if (a->grad || b->grad) { is_node = true; } @@ -4733,7 +4755,7 @@ struct ggml_tensor * ggml_add1_impl( bool is_node = false; - if (!inplace && (a->grad || b->grad)) { + if (a->grad || b->grad) { is_node = true; } @@ -5159,6 +5181,34 @@ struct ggml_tensor * ggml_repeat( return result; } +// ggml_repeat_back + +struct ggml_tensor * ggml_repeat_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_repeat(b, a)); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + if (ggml_are_same_shape(a, b) && !is_node) { + return a; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); + + result->op = GGML_OP_REPEAT_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + // ggml_abs struct ggml_tensor * ggml_abs_impl( @@ -5536,6 +5586,32 @@ struct ggml_tensor * ggml_mul_mat( return result; } +// ggml_out_prod + +struct ggml_tensor * ggml_out_prod( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_out_prod(a, b)); + GGML_ASSERT(!ggml_is_transposed(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); + + result->op = GGML_OP_OUT_PROD; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + // ggml_scale struct ggml_tensor * ggml_scale_impl( @@ -5548,7 +5624,7 @@ struct ggml_tensor * ggml_scale_impl( bool is_node = false; - if (!inplace && (a->grad || b->grad)) { + if (a->grad || b->grad) { is_node = true; } @@ -5591,7 +5667,7 @@ struct ggml_tensor * ggml_set_impl( bool is_node = false; - if (!inplace && (a->grad || b->grad)) { + if (a->grad || b->grad) { is_node = true; } @@ -5913,10 +5989,6 @@ struct ggml_tensor * ggml_view_1d( result->src1 = NULL; result->opt[0] = offs; - if (is_node) { - memcpy(result->padding, &offset, sizeof(offset)); - } - return result; } @@ -5957,10 +6029,6 @@ struct ggml_tensor * ggml_view_2d( result->src1 = NULL; result->opt[0] = offs; - if (is_node) { - memcpy(result->padding, &offset, sizeof(offset)); - } - return result; } @@ -6003,10 +6071,6 @@ struct ggml_tensor * ggml_view_3d( result->src1 = NULL; result->opt[0] = offs; - if (is_node) { - memcpy(result->padding, &offset, sizeof(offset)); - } - return result; } @@ -6051,10 +6115,6 @@ struct ggml_tensor * ggml_view_4d( result->src1 = NULL; result->opt[0] = offs; - if (is_node) { - memcpy(result->padding, &offset, sizeof(offset)); - } - return result; } @@ -6116,10 +6176,18 @@ struct ggml_tensor * ggml_permute( result->src1 = NULL; if (is_node) { - result->padding[0] = axis0; - result->padding[1] = axis1; - result->padding[2] = axis2; - result->padding[3] = axis3; + ggml_scratch_save(ctx); + + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4); + + ((int32_t *) b->data)[0] = axis0; + ((int32_t *) b->data)[1] = axis1; + ((int32_t *) b->data)[2] = axis2; + ((int32_t *) b->data)[3] = axis3; + + ggml_scratch_load(ctx); + + result->opt[0] = b; } return result; @@ -6359,6 +6427,44 @@ struct ggml_tensor * ggml_soft_max_inplace( return ggml_soft_max_impl(ctx, a, true); } + +// ggml_soft_max_back + +struct ggml_tensor * ggml_soft_max_back_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; // TODO : implement backward pass + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SOFT_MAX_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +struct ggml_tensor * ggml_soft_max_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_soft_max_back_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_soft_max_back_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_soft_max_back_impl(ctx, a, b, true); +} + // ggml_rope struct ggml_tensor * ggml_rope_impl( @@ -6371,7 +6477,7 @@ struct ggml_tensor * ggml_rope_impl( GGML_ASSERT(n_past >= 0); bool is_node = false; - if (!inplace && a->grad) { + if (a->grad) { is_node = true; } @@ -6425,8 +6531,7 @@ struct ggml_tensor * ggml_rope_back( bool is_node = false; if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; + is_node = false; // TODO: implement backward } struct ggml_tensor * result = ggml_dup_tensor(ctx, a); @@ -6591,7 +6696,6 @@ struct ggml_tensor * ggml_flash_attn( bool is_node = false; if (q->grad || k->grad || v->grad) { - GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -6623,7 +6727,6 @@ struct ggml_tensor * ggml_flash_ff( bool is_node = false; if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) { - GGML_ASSERT(false); // TODO: implement backward is_node = true; } @@ -6641,6 +6744,71 @@ struct ggml_tensor * ggml_flash_ff( return result; } +// ggml_flash_attn_back + +struct ggml_tensor * ggml_flash_attn_back( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * d, + bool masked) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + + // d shape [D,N,ne2,ne3] + // q shape [D,N,ne2,ne3] + // k shape [D,M,ne2,ne3] + // v shape [M,D,ne2,ne3] + + const int64_t D = q->ne[0]; + const int64_t N = q->ne[1]; + const int64_t M = k->ne[1]; + const int64_t ne2 = q->ne[2]; + const int64_t ne3 = q->ne[3]; + + GGML_ASSERT(k->ne[0] == D); + GGML_ASSERT(v->ne[0] == M); + GGML_ASSERT(v->ne[1] == D); + GGML_ASSERT(d->ne[0] == D); + GGML_ASSERT(d->ne[1] == N); + GGML_ASSERT(k->ne[2] == ne2); + GGML_ASSERT(k->ne[3] == ne3); + GGML_ASSERT(v->ne[2] == ne2); + GGML_ASSERT(v->ne[3] == ne3); + GGML_ASSERT(d->ne[2] == ne2); + GGML_ASSERT(d->ne[3] == ne3); + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + // when using this operation (in backwards pass) these grads are set. + // we don't want to create (big) grad of our result, so is_node is false. + is_node = false; + } + + // store gradients of q, k and v as continuous tensors concatenated in result. + // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3] + // gradq->data = result->data + // gradk->data = result->data + nb0*D*N*ne2*ne3 + // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3 + // note: v and gradv are actually transposed, i.e. v->ne[0] != D. + int64_t ne[4] = {D,M+N+M,ne2,ne3}; + + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_FLASH_ATTN_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = q; + result->src1 = k; + result->opt[0] = v; + result->opt[1] = d; + result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0); + + return result; +} + + // ggml_map_unary struct ggml_tensor * ggml_map_unary_impl_f32( @@ -6725,6 +6893,50 @@ struct ggml_tensor * ggml_map_binary_inplace_f32( return ggml_map_binary_impl_f32(ctx, a, b, fun, true); } +// ggml_cross_entropy_loss + +struct ggml_tensor * ggml_cross_entropy_loss( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); + + result->op = GGML_OP_CROSS_ENTROPY_LOSS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_cross_entropy_loss_back + +struct ggml_tensor * ggml_cross_entropy_loss_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_is_scalar(c)); + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK; + result->grad = NULL; + result->src0 = a; + result->src1 = b; + result->opt[0] = c; + + return result; +} + //////////////////////////////////////////////////////////////////////////////// void ggml_set_param( @@ -8875,6 +9087,99 @@ static void ggml_compute_forward_repeat( } } +// ggml_compute_forward_repeat_back + +static void ggml_compute_forward_repeat_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_can_repeat(dst, src0)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + // guaranteed to be an integer due to the check in ggml_can_repeat + const int nr0 = (int)(ne00/ne0); + const int nr1 = (int)(ne01/ne1); + const int nr2 = (int)(ne02/ne2); + const int nr3 = (int)(ne03/ne3); + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + if (ggml_is_contiguous(dst)) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + } else { + for (int k3 = 0; k3 < ne3; k3++) { + for (int k2 = 0; k2 < ne2; k2++) { + for (int k1 = 0; k1 < ne1; k1++) { + ggml_vec_set_f32(ne0, + (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3), + 0); + } + } + } + } + + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne3; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne2; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne1; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + ggml_vec_acc_f32(ne0, + (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1), + (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00)); + } + } + } + } + } + } + } +} + +static void ggml_compute_forward_repeat_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_repeat_back_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_abs static void ggml_compute_forward_abs_f32( @@ -10249,6 +10554,176 @@ static void ggml_compute_forward_mul_mat( } } +// ggml_compute_forward_out_prod + + +static void ggml_compute_forward_out_prod_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne10 = src1->ne[0]; + //const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const int nb00 = src0->nb[0]; + const int nb01 = src0->nb[1]; + const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + // GGML_ASSERT(nb0 <= nb1); + // GGML_ASSERT(nb1 <= nb2); + // GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod + // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) + + if (params->type == GGML_TASK_INIT) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by last three dimensions + + // total rows in dst + const int64_t nr = ne1*ne2*ne3; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i1: + // for i01: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + for (int64_t ir = ir0; ir < ir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int64_t i02 = i2; + const int64_t i03 = i3; + + //const int64_t i10 = i1; + const int64_t i12 = i2; + const int64_t i13 = i3; + + for (int64_t i01 = 0; i01 < ne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32(ne0, d, s0, *s1); + // for (int64_t i0 = 0; i0 < ne0; ++i0) { + // d[i0] += s0[i0] * s1[i1]; + // } + } + } + + //int64_t t1 = ggml_perf_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_out_prod( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + { + GGML_ASSERT(false); // todo + // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(false); // todo + // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_out_prod_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_scale static void ggml_compute_forward_scale_f32( @@ -10671,7 +11146,11 @@ static void ggml_compute_forward_get_rows_back_f32( GGML_ASSERT(ggml_is_contiguous(opt0)); GGML_ASSERT(ggml_is_contiguous(dst)); - ggml_compute_forward_dup_same_cont(params, opt0, dst); + // ggml_compute_forward_dup_same_cont(params, opt0, dst); + + if (params->type == GGML_TASK_INIT) { + memset(dst->data, 0, ggml_nbytes(dst)); + } if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -10815,8 +11294,8 @@ static void ggml_compute_forward_diag_mask_f32( const struct ggml_tensor * src1, struct ggml_tensor * dst, const float value) { - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 2); + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_nelements(src1) == 2); const int ith = params->ith; const int nth = params->nth; @@ -10824,7 +11303,7 @@ static void ggml_compute_forward_diag_mask_f32( const int n_past = ((int32_t *) src1->data)[0]; const bool inplace = (bool)((int32_t *) src1->data)[1]; - assert(n_past >= 0); + GGML_ASSERT(n_past >= 0); if (!inplace && (params->type == GGML_TASK_INIT)) { // memcpy needs to be synchronized across threads to avoid race conditions. @@ -10848,8 +11327,8 @@ static void ggml_compute_forward_diag_mask_f32( const int nr = src0->ne[1]; const int nz = n/nr; - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); + GGML_ASSERT( dst->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); for (int k = 0; k < nz; k++) { for (int j = ith; j < nr; j += nth) { @@ -10985,6 +11464,101 @@ static void ggml_compute_forward_soft_max( } } +// ggml_compute_forward_soft_max_back + +static void ggml_compute_forward_soft_max_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src1, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float *dy = (float *)((char *) src0->data + i1*src0->nb[1]); + float *y = (float *)((char *) src1->data + i1*src1->nb[1]); + float *dx = (float *)((char *) dst->data + i1*dst->nb[1]); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(dy[i])); + assert(!isnan(y[i])); + } +#endif + // Jii = yi - yi*yi + // Jij = -yi*yj + // J = diag(y)-y.T*y + // dx = J * dy + // dxk = sum_i(Jki * dyi) + // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk + // dxk = sum_i(-yk*yi * dyi) + yk*dyk + // dxk = -yk * sum_i(yi * dyi) + yk*dyk + // dxk = -yk * dot(y, dy) + yk*dyk + // dxk = yk * (- dot(y, dy) + dyk) + // dxk = yk * (dyk - dot(y, dy)) + // + // post-order: + // dot_y_dy := dot(y, dy) + // dx := dy + // dx := dx - dot_y_dy + // dx := dx * y + + // linear runtime, no additional memory + float dot_y_dy = 0; + ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy); + ggml_vec_cpy_f32 (nc, dx, dy); + ggml_vec_acc1_f32(nc, dx, -dot_y_dy); + ggml_vec_mul_f32 (nc, dx, dx, y); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(dx[i])); + assert(!isinf(dx[i])); + } +#endif + } +} + +static void ggml_compute_forward_soft_max_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_alibi static void ggml_compute_forward_alibi_f32( @@ -12938,6 +13512,414 @@ static void ggml_compute_forward_flash_ff( } } +// ggml_compute_forward_flash_attn_back + +static void ggml_compute_forward_flash_attn_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * d, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int64_t neq0 = q->ne[0]; + const int64_t neq1 = q->ne[1]; + const int64_t neq2 = q->ne[2]; + const int64_t neq3 = q->ne[3]; + + const int64_t nek0 = k->ne[0]; + const int64_t nek1 = k->ne[1]; + //const int64_t nek2 = k->ne[2]; + //const int64_t nek3 = k->ne[3]; + + const int64_t nev0 = v->ne[0]; + const int64_t nev1 = v->ne[1]; + //const int64_t nev2 = v->ne[2]; + //const int64_t nev3 = v->ne[3]; + + const int64_t ned0 = d->ne[0]; + const int64_t ned1 = d->ne[1]; + //const int64_t ned2 = d->ne[2]; + //const int64_t ned3 = d->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const int nbk0 = k->nb[0]; + const int nbk1 = k->nb[1]; + const int nbk2 = k->nb[2]; + const int nbk3 = k->nb[3]; + + const int nbq0 = q->nb[0]; + const int nbq1 = q->nb[1]; + const int nbq2 = q->nb[2]; + const int nbq3 = q->nb[3]; + + const int nbv0 = v->nb[0]; + const int nbv1 = v->nb[1]; + const int nbv2 = v->nb[2]; + const int nbv3 = v->nb[3]; + + const int nbd0 = d->nb[0]; + const int nbd1 = d->nb[1]; + const int nbd2 = d->nb[2]; + const int nbd3 = d->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + const int mxDM = MAX(D, Mup); + + // GGML_ASSERT(ne0 == D); + // GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(float)); + GGML_ASSERT(nbv0 == sizeof(float)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + GGML_ASSERT(ned0 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + GGML_ASSERT(ned1 == N); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + if (ith == 0) { + memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); + } + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0f/sqrtf(D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2); + const int iq2 = ir - iq3*neq2; + for ( int iq1 = 0; iq1 < neq1; ++iq1) { + + + // not sure about CACHE_LINE_SIZE_F32.. + // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? + float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); + float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + for (int64_t ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f32(neq0, + S + i1, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int64_t i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + ggml_float sum = 0.0; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(SM, 1, &max, SM, 1, Mup); + vvexpf(SM, SM, &Mup); + ggml_vec_sum_f32(Mup, &sum, SM); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SR = S + i; + float * SW = SM + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SR[j] == -INFINITY) { + SW[j] = 0.0f; + } else { + ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + sump[j] += (ggml_float)val; + SW[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, SM, sum); + + } + + // step-by-step explanation + { + // forward-process shape grads from backward process + // parallel_for iq2,iq3: + // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur] + // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] + // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur] + // for iq1: + // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur + // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur + // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 + // S0 = -Inf [D,1,1,1] + // ~S1[i] = dot(kcur[:D,i], qcur) + // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale + // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P) + // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur + // ~S5[i] = dot(vcur[:,i], S4) + // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3] + // ~dst[i,iq1,iq2,iq3] = S5[i] ^ + // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3] + // dst backward-/ grad[dst] = d + // + // output gradients with their dependencies: + // + // grad[kcur] = grad[S1].T @ qcur + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S4] = grad[S5] @ vcur + // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur + // grad[qcur] = grad[S1] @ kcur + // grad[vcur] = grad[S5].T @ S4 + // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4 + // + // in post-order: + // + // S1 = qcur @ kcur.T + // S2 = S1 * scale + // S3 = diag_mask_inf(S2, P) + // S4 = softmax(S3) + // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[qcur] = grad[S1] @ kcur + // grad[kcur] = grad[S1].T @ qcur + // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4 + // + // using less variables (SM=S4): + // + // S = diag_mask_inf(qcur @ kcur.T * scale, P) + // SM = softmax(S) + // S = d[:D,iq1,iq2,iq3] @ vcur + // dot_SM_gradSM = dot(SM, S) + // S = SM * (S - dot(SM, S)) + // S = diag_mask_zero(S, P) * scale + // + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[k][:D,:M,iq2,iq3] += S.T @ qcur + // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM + } + + // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur + // S = d[:D,iq1,iq2,iq3] @ vcur + // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3] + ggml_vec_set_f32(M, S, 0); + for (int64_t ic = 0; ic < D; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_mad_f32(M, + S, + (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3))); + } + + // S = SM * (S - dot(SM, S)) + float dot_SM_gradSM = 0; + ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S); + ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); + ggml_vec_mul_f32 (M, S, S, SM); + + // S = diag_mask_zero(S, P) * scale + if (masked) { + // for (int64_t i = P + iq1 + 1; i < M; i++) { + // S[i] = 0; + // } + for (int64_t i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = 0; + } + } + } + ggml_vec_scale_f32(M, S, scale); + + void * grad_q = (char *) dst->data; + void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3; + void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3; + + const size_t nbgq1 = nb0*neq0; + const size_t nbgq2 = nb0*neq0*neq1; + const size_t nbgq3 = nb0*neq0*neq1*neq2; + + const size_t nbgk1 = nb0*nek0; + const size_t nbgk2 = nb0*nek0*nek1; + const size_t nbgk3 = nb0*nek0*nek1*neq2; + + const size_t nbgv1 = nb0*nev0; + const size_t nbgv2 = nb0*nev0*nev1; + const size_t nbgv3 = nb0*nev0*nev1*neq2; + + // S shape [M,1] + // SM shape [M,1] + // kcur shape [D,M] + // qcur shape [D,1] + // vcur shape [M,D] + // + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] + // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic] + // + //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T) + //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T) + for (int64_t ic = 0; ic < M; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_mad_f32(D, + (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)), + (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)), + S[ic]); + } + + // grad[k][:D,:M,iq2,iq3] += S.T @ qcur + // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] + // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] + for (int64_t ic = 0; ic < M; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // ggml_vec_set_f32(D, + // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)), + // 0); + ggml_vec_mad_f32(D, + (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)), + (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)), + S[ic]); + } + + // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM + // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M] + // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M] + for (int64_t ic = 0; ic < D; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // ggml_vec_set_f32(M, + // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)), + // 0); + ggml_vec_mad_f32(M, + (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)), + SM, + *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3))); + } + } + } +} + +static void ggml_compute_forward_flash_attn_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * d, + const bool masked, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_map_unary static void ggml_compute_forward_map_unary_f32( @@ -13031,6 +14013,286 @@ static void ggml_compute_forward_map_binary( } } +// ggml_compute_forward_cross_entropy_loss + +static void ggml_compute_forward_cross_entropy_loss_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_scalar(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + + const int ith = params->ith; + const int nth = params->nth; + + float * sums = (float *) params->wdata; + + // TODO: handle transposed/permuted matrices + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + if (params->type == GGML_TASK_INIT) { + if (ith == 0) { + memset(sums, 0, sizeof(float) * (nth + nth * nc)); + } + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + if (ith == 0) { + float * dp = (float *) dst->data; + ggml_vec_sum_f32(nth, dp, sums); + dp[0] *= -1.0f; + } + return; + } + + const double eps = 1e-9; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); + float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); + float * st = (float *) params->wdata + nth + ith*nc; + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(s0[i])); + assert(!isnan(s1[i])); + } +#endif + // soft_max + ggml_float sum = 0.0; + { + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, s0); + + uint16_t scvt; + for (int i = 0; i < nc; i++) { + if (s0[i] == -INFINITY) { + st[i] = 0.0f; + } else { + // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max); + ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max); + memcpy(&scvt, &s, sizeof(scvt)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); + sum += (ggml_float)val; + st[i] = val; + } + } + + assert(sum > 0.0); + // sum = 1.0/sum; + } + // avoid log(0) by rescaling from [0..1] to [eps..1] + sum = (1.0 - eps) / sum; + ggml_vec_scale_f32(nc, st, sum); + ggml_vec_add1_f32(nc, st, st, eps); + ggml_vec_log_f32(nc, st, st); + ggml_vec_mul_f32(nc, st, st, s1); + + ggml_vec_sum_f32(nc, sums + ith, st); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(st[i])); + assert(!isinf(st[i])); + } +#endif + } + +} + +static void ggml_compute_forward_cross_entropy_loss( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_cross_entropy_loss_back + +static void ggml_compute_forward_cross_entropy_loss_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(opt0)); + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + const int64_t ith = params->ith; + const int64_t nth = params->nth; + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const float eps = 1e-9f; + + // TODO: handle transposed/permuted matrices + const int64_t nc = src0->ne[0]; + const int64_t nr = ggml_nrows(src0); + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + float * d = (float *) opt0->data; + + for (int64_t i1 = ir0; i1 < ir1; i1++) { + float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); + float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); + float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); + float * sm = (float *) params->wdata + ith*nc; + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(s0[i])); + assert(!isnan(s1[i])); + } +#endif + // step by step explanation: + { + //float * sums = (float *) params->wdata; + + // forward pass with annotated gradients from backward pass + // (built by going in reverse operation order, adding to gradients of current operation args) + // st0 = exp(s0-max(s0)) grad[st0] = grad[st1]*(1.0 - eps)/sum + // from softmax_back: grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1])) + // ggml_vec_scale_f32(nc, st, sum); // st1 = st0*/sum = softmax(s0) grad[st1] = grad[st2]*(1.0 - eps) + // ggml_vec_scale_f32(nc, st, (1.0f - eps)); // st2 = st1*(1.0 - eps) grad[st2] = grad[st3] + // ggml_vec_add1_f32(nc, st, st, eps); // st3 = st2 + eps grad[st3] = grad[st4]/st3 + // ggml_vec_log_f32(nc, st, st); // st4 = log(st3) grad[st4] = grad[st5] * s1 + // ggml_vec_mul_f32(nc, st, st, s1); // st5 = st4 * s1 grad[st5] = grad[sums[ith]] + // ggml_vec_sum_f32(nc, sums + ith, st); // sums[ith] = st5 grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel] + + // substitute into grad[st1], because we can reuse softmax_back from this point on + // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps)) + // postorder: + // grad[st1] := softmax(s0) + // grad[st1] := grad[st1]*(1.0 - eps) + // grad[st1] := grad[st1] + eps + // grad[st1] := s1 / grad[st1] + // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel] + + // src0 gradients by going through softmax_back + // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1])) + // from softmax_back: + // dxk = yk * (dyk - dot(y, dy)) + // dot_y_dy := dot(y, dy) + // dx := dy + // dx := dx - dot_y_dy + // dx := dx * y + // postorder: + // dot_st1_dst1 := dot(st1, grad[st1]) + // grad[s0] := grad[st1] + // grad[s0] := grad[s0] - dot_st1_dst1 + // grad[s0] := grad[s0] * st1 + + // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1] + // sm := softmax(s0) + // grad[s0] := sm*(1.0 - eps) + // grad[s0] := grad[s0] + eps + // grad[s0] := s1 / grad[s0] + // grad[s0] := grad[s0]*(1.0-eps)*-grad[cel] + // dot_st1_dst1 := dot(sm, grad[s0]) + // grad[s0] := grad[s0] - dot_st1_dst1 + // grad[s0] := grad[s0] * sm + } + + // soft_max + ggml_float sum = 0.0; + { + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, s0); + + uint16_t scvt; + for (int i = 0; i < nc; i++) { + if (s0[i] == -INFINITY) { + sm[i] = 0.0f; + } else { + // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max); + ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max); + memcpy(&scvt, &s, sizeof(scvt)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); + sum += (ggml_float)val; + sm[i] = val; + } + } + + assert(sum > 0.0); + sum = 1.0/sum; + } + + float dot_st1_dst1 = 0; + ggml_vec_scale_f32(nc, sm, sum); + ggml_vec_cpy_f32 (nc, ds0, sm); + ggml_vec_scale_f32(nc, ds0, (1.0f - eps)); + ggml_vec_add1_f32 (nc, ds0, ds0, eps); + ggml_vec_div_f32 (nc, ds0, s1, ds0); + ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]); + ggml_vec_dot_f32 (nc, &dot_st1_dst1, sm, ds0); + ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1); + ggml_vec_mul_f32 (nc, ds0, ds0, sm); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(sm[i])); + assert(!isinf(sm[i])); + assert(!isnan(ds0[i])); + assert(!isinf(ds0[i])); + } +#endif + } +} + +static void ggml_compute_forward_cross_entropy_loss_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + + ///////////////////////////////// static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { @@ -13102,6 +14364,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_repeat(params, tensor->src0, tensor); } break; + case GGML_OP_REPEAT_BACK: + { + ggml_compute_forward_repeat_back(params, tensor->src0, tensor); + } break; case GGML_OP_ABS: { ggml_compute_forward_abs(params, tensor->src0, tensor); @@ -13150,6 +14416,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor); } break; + case GGML_OP_OUT_PROD: + { + ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor); + } break; case GGML_OP_SCALE: { ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor); @@ -13206,6 +14476,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_soft_max(params, tensor->src0, tensor); } break; + case GGML_OP_SOFT_MAX_BACK: + { + ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor); + } break; case GGML_OP_ROPE: { ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor); @@ -13241,6 +14515,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor); } break; + case GGML_OP_FLASH_ATTN_BACK: + { + int32_t t = ggml_get_i32_1d(tensor->opt[2], 0); + GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor); + } break; case GGML_OP_MAP_UNARY: { const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data); @@ -13253,6 +14534,16 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun); } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + { + ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor); + } + break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor); + } + break; case GGML_OP_NONE: { // nop @@ -13391,11 +14682,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_impl(ctx, src0->grad, - ggml_mul(ctx, - tensor->grad, // this was not catched by test_grad because in test_grad tensor->grad is 1 + ggml_scale(ctx, ggml_div(ctx, - ggml_repeat(ctx, ggml_new_f32(ctx, 0.5f), tensor), - tensor)), + tensor->grad, + tensor), + ggml_new_f32(ctx, 0.5f)), inplace); } } break; @@ -13441,43 +14732,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - GGML_ASSERT(src0->n_dims == 1 || src0->n_dims == 2); - const int nc = tensor->ne[0]; - const int nr = tensor->ne[1]; - const int nc0 = src0->ne[0]; - const int nr0 = src0->ne[1]; - const int ncr = nc/nc0; // guaranteed to be an integer due to the check in ggml_can_repeat - const int nrr = nr/nr0; // guaranteed to be an integer due to the check in ggml_can_repeat - // tensor->grad [nc,nr,1,1] - // reshape [nc0,nc/nc0,nr0,nr/nr0] - // permute [nc0,nr0,nc/nc0,nr/nr0] - // substitute [nc0,nr0,ncr,nrr] - // reshape [nc0*nr0,ncr*nrr,1,1] - // transpose [ncr*nrr,nc0*nr0,1,1] - // sum rows [1,nc0*nr0,1,1] - // transpose [nc0*nr0,1,1] - // reshape [nc0,nr0,1,1] reshape_1d or reshape_2d - // add to src0->grad - - int64_t ne[4] = {nc0,ncr,nr0,nrr}; - - struct ggml_tensor* F00 = tensor->grad; - struct ggml_tensor* F01 = ggml_reshape (ctx, F00, ggml_new_tensor(ctx,tensor->grad->type,4,ne)); - struct ggml_tensor* F02 = ggml_permute (ctx, F01, 0,2,1,3); - struct ggml_tensor* F03 = ggml_cont (ctx, F02); - struct ggml_tensor* F04 = ggml_reshape_2d(ctx, F03, nc0*nr0, ncr*nrr); - struct ggml_tensor* F05 = ggml_transpose (ctx, F04); - struct ggml_tensor* F06 = ggml_cont (ctx, F05); - struct ggml_tensor* F07 = ggml_sum_rows (ctx, F06); - struct ggml_tensor* F08 = ggml_transpose (ctx, F07); - struct ggml_tensor* F09 = ggml_cont (ctx, F08); - struct ggml_tensor* F10 = ggml_reshape (ctx, F09, src0->grad); - - src0->grad = - ggml_add_impl(ctx, - src0->grad, - F10, - inplace); + src0->grad = ggml_add_impl(ctx, + src0->grad, + ggml_repeat_back(ctx, tensor->grad, src0->grad), + inplace); + } + } break; + case GGML_OP_REPEAT_BACK: + { + if (src0->grad) { + // TODO: test this + src0->grad = ggml_add_impl(ctx, + src0->grad, + ggml_repeat(ctx, tensor->grad, src0->grad), + inplace); } } break; case GGML_OP_ABS: @@ -13584,38 +14852,37 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama if (src0->grad) { - // TODO: this requires outer product - ggml_out_prod(ctx, src1, tensor->grad); src0->grad = ggml_add_impl(ctx, src0->grad, - // ds0 = dt.dot(s1.T) - // ggml_out_prod(ctx, // [n,m] - // src1, // [n,p] - // tensor->grad), // [m,p] - // for now just using A*B==(B.T*A.T).T - ggml_cont(ctx, // [n,m] - ggml_transpose(ctx, // [n,m] - ggml_mul_mat(ctx, // [m,n] - ggml_cont(ctx, // [p,m] - ggml_transpose(ctx, // [p,m] - tensor->grad)), // [m,p] - ggml_cont(ctx, // [p,n] - ggml_transpose(ctx, // [p,n] - src1))))), // [n,p] + ggml_out_prod(ctx, // [n,m] + src1, // [n,p] + tensor->grad), // [m,p] inplace); } if (src1->grad) { src1->grad = ggml_add_impl(ctx, src1->grad, - // ds1 = s0.T.dot(dt): - ggml_mul_mat(ctx, // [n,p] - ggml_cont(ctx, // [m,n] - ggml_transpose(ctx, src0)), // [m,n] - tensor->grad), // [m,p] + // ggml_mul_mat(ctx, // [n,p] + // ggml_cont(ctx, // [m,n] + // ggml_transpose(ctx, src0)), // [m,n] + // tensor->grad), // [m,p] + + // // when src0 is bigger than tensor->grad (this is mostly the case in llama), + // // avoid transpose of src0, rather transpose smaller tensor->grad + // // and then use ggml_out_prod + ggml_out_prod(ctx, // [n,p] + src0, // [n,m] + ggml_transpose(ctx, // [p,m] + tensor->grad)), // [m,p] inplace); } } break; + case GGML_OP_OUT_PROD: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_SCALE: { // necessary for llama @@ -13717,7 +14984,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama if (src0->grad) { size_t offset; - memcpy(&offset, tensor->padding, sizeof(offset)); + + GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0])); + memcpy(&offset, tensor->opt[0]->data, sizeof(offset)); size_t nb1 = tensor->nb[1]; size_t nb2 = tensor->nb[2]; @@ -13744,10 +15013,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - int axis0 = tensor->padding[0] & 0x3; - int axis1 = tensor->padding[1] & 0x3; - int axis2 = tensor->padding[2] & 0x3; - int axis3 = tensor->padding[3] & 0x3; + int32_t * axes = (int32_t *) tensor->opt[0]->data; + int axis0 = axes[0] & 0x3; + int axis1 = axes[1] & 0x3; + int axis2 = axes[2] & 0x3; + int axis3 = axes[3] & 0x3; int axes_backward[4] = {0,0,0,0}; axes_backward[axis0] = 0; axes_backward[axis1] = 1; @@ -13831,50 +15101,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - // y = softmax(x) - // - // Jii = yi - yi*yi - // Jij = -yi*yj - // J = diag(y)-y.*y - // dx = J * dy - // dxk = sum(Jkj * dyk) - - int64_t ne2[4] = { - tensor->ne[0], - 1, - tensor->ne[1]*tensor->ne[2], - tensor->ne[3] - }; - struct ggml_tensor * tensor2 = ggml_cont(ctx, - ggml_reshape_4d(ctx, - ggml_cont(ctx, tensor), - ne2[0], ne2[1], ne2[2], ne2[3])); - - struct ggml_tensor * grad2 = ggml_cont(ctx, - ggml_reshape_4d(ctx, - ggml_cont(ctx, tensor->grad), - ne2[0], ne2[1], ne2[2], ne2[3])); - - struct ggml_tensor * tensor2_t = ggml_cont(ctx, // [1,ne0,ne1*ne2,ne3] - ggml_permute(ctx, // [1,ne0,ne1*ne2,ne3] - tensor2, // [ne0,1,ne1*ne2,ne3] - 1, 0, 2, 3)); - src0->grad = - ggml_add_impl(ctx, - src0->grad, // [ne0,ne1,ne2,ne3] - ggml_reshape(ctx, // [ne0,ne1,ne2,ne3] - ggml_mul_mat(ctx, // [ne0,1,ne1*ne2,ne3] - ggml_sub(ctx, // [ne0,ne0,ne1*ne2,ne3] - ggml_diag(ctx, // [ne0,ne0,ne1*ne2,ne3] - tensor2), // [ne0,1,ne1*ne2,ne3] - ggml_mul_mat(ctx, // [ne0,ne0,ne1*ne2,ne3] - tensor2_t, // [1,ne0,ne1*ne2,ne3] - tensor2_t)), // [1,ne0,ne1*ne2,ne3] - grad2), // [ne0,1,ne1*ne2,ne3] - src0->grad), - inplace); + ggml_add_impl(ctx, src0->grad, + ggml_soft_max_back(ctx, tensor->grad, tensor), + inplace); } + + } break; + case GGML_OP_SOFT_MAX_BACK: + { + GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_ROPE: { @@ -13929,17 +15165,190 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_FLASH_ATTN: { - GGML_ASSERT(false); // not supported + struct ggml_tensor * flash_grad = NULL; + if (src0->grad || src1->grad || tensor->opt[0]->grad) { + int32_t t = ggml_get_i32_1d(tensor->opt[1], 0); + GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + flash_grad = + ggml_flash_attn_back(ctx, + src0, + src1, + tensor->opt[0], + tensor->grad, + masked); + } + + if (src0->grad) { + struct ggml_tensor * grad_q = NULL; + const size_t nb0 = flash_grad->nb[0]; + const size_t offset = 0; + switch(src0->n_dims) { + case 2: + { + grad_q = ggml_view_2d(ctx, + flash_grad, + src0->ne[0], + src0->ne[1], + nb0*src0->ne[0], + offset); + } break; + case 3: + { + grad_q = ggml_view_3d(ctx, + flash_grad, + src0->ne[0], + src0->ne[1], + src0->ne[2], + nb0*src0->ne[0], + nb0*src0->ne[0]*src0->ne[1], + offset); + } break; + case 4: + { + grad_q = ggml_view_4d(ctx, + flash_grad, + src0->ne[0], + src0->ne[1], + src0->ne[2], + src0->ne[3], + nb0*src0->ne[0], + nb0*src0->ne[0]*src0->ne[1], + nb0*src0->ne[0]*src0->ne[1]*src0->ne[2], + offset); + } break; + } + + src0->grad = ggml_add_impl(ctx, + src0->grad, + grad_q, + inplace); + } + + if (src1->grad) { + struct ggml_tensor * grad_k = NULL; + const size_t nb0 = flash_grad->nb[0]; + const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]; + switch(src1->n_dims) { + case 2: + { + grad_k = ggml_view_2d(ctx, + flash_grad, + src1->ne[0], + src1->ne[1], + nb0*src1->ne[0], + offset); + } break; + case 3: + { + grad_k = ggml_view_3d(ctx, + flash_grad, + src1->ne[0], + src1->ne[1], + src1->ne[2], + nb0*src1->ne[0], + nb0*src1->ne[0]*src1->ne[1], + offset); + } break; + case 4: + { + grad_k = ggml_view_4d(ctx, + flash_grad, + src1->ne[0], + src1->ne[1], + src1->ne[2], + src1->ne[3], + nb0*src1->ne[0], + nb0*src1->ne[0]*src1->ne[1], + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2], + offset); + } break; + } + + src1->grad = ggml_add_impl(ctx, + src1->grad, + grad_k, + inplace); + } + + struct ggml_tensor * opt0 = tensor->opt[0]; + + if (opt0->grad) { + struct ggml_tensor * grad_v = NULL; + const size_t nb0 = flash_grad->nb[0]; + const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3] + + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3]; + switch(opt0->n_dims) { + case 2: + { + grad_v = ggml_view_2d(ctx, + flash_grad, + opt0->ne[0], + opt0->ne[1], + nb0*opt0->ne[0], + offset); + } break; + case 3: + { + grad_v = ggml_view_3d(ctx, + flash_grad, + opt0->ne[0], + opt0->ne[1], + opt0->ne[2], + nb0*opt0->ne[0], + nb0*opt0->ne[0]*opt0->ne[1], + offset); + } break; + case 4: + { + grad_v = ggml_view_4d(ctx, + flash_grad, + opt0->ne[0], + opt0->ne[1], + opt0->ne[2], + opt0->ne[3], + nb0*opt0->ne[0], + nb0*opt0->ne[0]*opt0->ne[1], + nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2], + offset); + } break; + } + + opt0->grad = ggml_add_impl(ctx, + opt0->grad, + grad_v, + inplace); + } } break; case GGML_OP_FLASH_FF: { GGML_ASSERT(false); // not supported } break; + case GGML_OP_FLASH_ATTN_BACK: + { + GGML_ASSERT(false); // not supported + } break; case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: { GGML_ASSERT(false); // not supported } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, + src0->grad, + ggml_cross_entropy_loss_back(ctx, + src0, + src1, + tensor->grad), + inplace); + } + } break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + GGML_ASSERT(false); // not supported + } break; case GGML_OP_NONE: { // nop @@ -14316,6 +15725,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: case GGML_OP_ABS: case GGML_OP_SGN: case GGML_OP_NEG: @@ -14335,6 +15745,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) node->n_tasks = n_threads; } break; case GGML_OP_MUL_MAT: + case GGML_OP_OUT_PROD: { node->n_tasks = n_threads; @@ -14417,6 +15828,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } break; case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: { @@ -14496,6 +15908,27 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 } + work_size = MAX(work_size, cur); + } break; + case GGML_OP_FLASH_ATTN_BACK: + { + node->n_tasks = n_threads; + + size_t cur = 0; + + const int64_t D = node->src0->ne[0]; + const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL); + const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back + if (node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == GGML_TYPE_F16) { + cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2 + } + work_size = MAX(work_size, cur); } break; case GGML_OP_MAP_UNARY: @@ -14503,6 +15936,22 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { node->n_tasks = 1; } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + { + node->n_tasks = n_threads; + + size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks); + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + node->n_tasks = n_threads; + + size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks; + + work_size = MAX(work_size, cur); + } break; case GGML_OP_NONE: { node->n_tasks = 1; @@ -15478,6 +16927,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g static enum ggml_opt_result ggml_opt_adam( struct ggml_context * ctx, + struct ggml_opt_context * opt, struct ggml_opt_params params, struct ggml_tensor * f, struct ggml_cgraph * gf, @@ -15503,25 +16953,29 @@ static enum ggml_opt_result ggml_opt_adam( } } + if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) { + int iter = opt->iter; + ggml_opt_init(opt->ctx, opt, params, nx); + opt->iter = iter; + } + // constants - const float alpha = params.adam.alpha; + const float sched = params.adam.sched; + const float decay = params.adam.decay * sched; + const float alpha = params.adam.alpha * sched; const float beta1 = params.adam.beta1; const float beta2 = params.adam.beta2; const float eps = params.adam.eps; - float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // view of the parameters - float * g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient - float * g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // gradient squared - float * m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment - float * v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment - float * mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // first moment hat - float * vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // second moment hat + float * x = opt->adam.x->data; // view of the parameters + float * g1 = opt->adam.g1->data; // gradient + float * g2 = opt->adam.g2->data; // gradient squared + float * m = opt->adam.m->data; // first moment + float * v = opt->adam.v->data; // second moment + float * mh = opt->adam.mh->data; // first moment hat + float * vh = opt->adam.vh->data; // second moment hat - float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values - - // initialize - ggml_vec_set_f32(nx, m, 0.0f); - ggml_vec_set_f32(nx, v, 0.0f); + float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values // update view ggml_opt_get_params(np, ps, x); @@ -15531,16 +16985,27 @@ static enum ggml_opt_result ggml_opt_adam( ggml_set_f32 (f->grad, 1.0f); ggml_graph_compute(ctx, gb); - float fx_prev = ggml_get_f32_1d(f, 0); + opt->adam.fx_prev = ggml_get_f32_1d(f, 0); + opt->adam.fx_best = opt->adam.fx_prev; if (pf) { - pf[0] = fx_prev; + pf[opt->iter % params.past] = opt->adam.fx_prev; } - int n_no_improvement = 0; - float fx_best = fx_prev; + // initialize + if (opt->just_initialized) { + opt->adam.n_no_improvement = 0; + opt->just_initialized = false; + } + + float * fx_best = &opt->adam.fx_best; + float * fx_prev = &opt->adam.fx_prev; + int * n_no_improvement = &opt->adam.n_no_improvement; + + int iter0 = opt->iter; // run the optimizer for (int t = 0; t < params.adam.n_iter; ++t) { + opt->iter = iter0 + t + 1; GGML_PRINT_DEBUG ("=== iter %d ===\n", t); GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0)); @@ -15574,17 +17039,22 @@ static enum ggml_opt_result ggml_opt_adam( // m^hat = m_t / (1 - beta1^t) // v^hat = v_t / (1 - beta2^t) - // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps) + // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1) + // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1 + // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps) + // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps) + // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay) ggml_vec_cpy_f32 (nx, mh, m); ggml_vec_cpy_f32 (nx, vh, v); - ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, t + 1))); - ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, t + 1))); + ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter))); + ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter))); ggml_vec_sqrt_f32 (nx, vh, vh); ggml_vec_acc1_f32 (nx, vh, eps); ggml_vec_div_f32 (nx, mh, mh, vh); + ggml_vec_scale_f32(nx, x, 1.0f - decay); ggml_vec_sub_f32 (nx, x, x, mh); // update the parameters @@ -15598,7 +17068,7 @@ static enum ggml_opt_result ggml_opt_adam( const float fx = ggml_get_f32_1d(f, 0); // check convergence - if (fabsf(fx - fx_prev)/fx < params.adam.eps_f) { + if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) { GGML_PRINT_DEBUG("converged\n"); return GGML_OPT_OK; @@ -15607,32 +17077,32 @@ static enum ggml_opt_result ggml_opt_adam( // delta-based convergence test if (pf != NULL) { // need at least params.past iterations to start checking for convergence - if (params.past <= t) { - const float rate = (pf[t%params.past] - fx)/fx; + if (params.past <= iter0 + t) { + const float rate = (pf[(iter0 + t)%params.past] - fx)/fx; if (fabsf(rate) < params.delta) { return GGML_OPT_OK; } } - pf[t%params.past] = fx; + pf[(iter0 + t)%params.past] = fx; } // check for improvement if (params.max_no_improvement > 0) { - if (fx_best > fx) { - fx_best = fx; - n_no_improvement = 0; + if (fx_best[0] > fx) { + fx_best[0] = fx; + n_no_improvement[0] = 0; } else { - ++n_no_improvement; + ++n_no_improvement[0]; - if (n_no_improvement >= params.max_no_improvement) { + if (n_no_improvement[0] >= params.max_no_improvement) { return GGML_OPT_OK; } } } - fx_prev = fx; + fx_prev[0] = fx; { const int64_t t_end_cpu = ggml_cycles(); @@ -15771,6 +17241,7 @@ static enum ggml_opt_result linesearch_backtracking( static enum ggml_opt_result ggml_opt_lbfgs( struct ggml_context * ctx, + struct ggml_opt_context * opt, struct ggml_opt_params params, struct ggml_tensor * f, struct ggml_cgraph * gf, @@ -15803,31 +17274,32 @@ static enum ggml_opt_result ggml_opt_lbfgs( } } - float * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current parameters - float * xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous parameters - float * g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // current gradient - float * gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // previous gradient - float * d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; // search direction + if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) { + int iter = opt->iter; + ggml_opt_init(ctx, opt, params, nx); + opt->iter = iter; + } - float * pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)->data : NULL; // past function values + float * x = opt->lbfgs.x->data; // current parameters + float * xp = opt->lbfgs.xp->data; // previous parameters + float * g = opt->lbfgs.g->data; // current gradient + float * gp = opt->lbfgs.gp->data; // previous gradient + float * d = opt->lbfgs.d->data; // search direction + + float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values float fx = 0.0f; // cost function value float xnorm = 0.0f; // ||x|| float gnorm = 0.0f; // ||g|| - float step = 0.0f; // initialize x from the graph nodes ggml_opt_get_params(np, ps, x); // the L-BFGS memory - struct ggml_lbfgs_iteration_data * lm = alloca(sizeof(struct ggml_lbfgs_iteration_data)*m); - - for (int i = 0; i < m; ++i) { - lm[i].alpha = 0.0f; - lm[i].ys = 0.0f; - lm[i].s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; - lm[i].y = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx)->data; - } + float * lm_alpha = opt->lbfgs.lmal->data; + float * lm_ys = opt->lbfgs.lmys->data; + float * lm_s = opt->lbfgs.lms->data; + float * lm_y = opt->lbfgs.lmy->data; // evaluate the function value and its gradient { @@ -15842,12 +17314,6 @@ static enum ggml_opt_result ggml_opt_lbfgs( fx = ggml_get_f32_1d(f, 0); } - if (pf) { - pf[0] = fx; - } - - float fx_best = fx; - // search direction = -gradient ggml_vec_neg_f32(nx, d, g); @@ -15864,26 +17330,43 @@ static enum ggml_opt_result ggml_opt_lbfgs( return GGML_OPT_OK; } - // initial step - ggml_vec_norm_inv_f32(nx, &step, d); + if (opt->just_initialized) { + if (pf) { + pf[0] = fx; + } + opt->lbfgs.fx_best = fx; - int j = 0; - int k = 1; - int ls = 0; - int end = 0; - int bound = 0; - int n_no_improvement = 0; + // initial step + ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d); + opt->lbfgs.j = 0; + opt->lbfgs.k = 1; + opt->lbfgs.end = 0; + opt->lbfgs.n_no_improvement = 0; + opt->just_initialized = false; + } + + float * fx_best = &opt->lbfgs.fx_best; + float * step = &opt->lbfgs.step; + int * j = &opt->lbfgs.j; + int * k = &opt->lbfgs.k; + int * end = &opt->lbfgs.end; + int * n_no_improvement = &opt->lbfgs.n_no_improvement; + + int ls = 0; + int bound = 0; float ys = 0.0f; float yy = 0.0f; float beta = 0.0f; + int it = 0; + while (true) { // store the current position and gradient vectors ggml_vec_cpy_f32(nx, xp, x); ggml_vec_cpy_f32(nx, gp, g); - ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, &step, xp, f, gf, gb, np, ps); + ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps); if (ls < 0) { // linesearch failed - go back to the previous point and return @@ -15909,32 +17392,32 @@ static enum ggml_opt_result ggml_opt_lbfgs( // delta-based convergence test if (pf != NULL) { // need at least params.past iterations to start checking for convergence - if (params.past <= k) { - const float rate = (pf[k%params.past] - fx)/fx; + if (params.past <= k[0]) { + const float rate = (pf[k[0]%params.past] - fx)/fx; if (fabsf(rate) < params.delta) { return GGML_OPT_OK; } } - pf[k%params.past] = fx; + pf[k[0]%params.past] = fx; } // check for improvement if (params.max_no_improvement > 0) { - if (fx < fx_best) { - fx_best = fx; - n_no_improvement = 0; + if (fx < fx_best[0]) { + fx_best[0] = fx; + n_no_improvement[0] = 0; } else { - n_no_improvement++; + n_no_improvement[0]++; - if (n_no_improvement >= params.max_no_improvement) { + if (n_no_improvement[0] >= params.max_no_improvement) { return GGML_OPT_OK; } } } - if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < k + 1) { + if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) { // reached the maximum number of iterations return GGML_OPT_DID_NOT_CONVERGE; } @@ -15943,50 +17426,51 @@ static enum ggml_opt_result ggml_opt_lbfgs( // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}. // y_{k+1} = g_{k+1} - g_{k}. // - ggml_vec_sub_f32(nx, lm[end].s, x, xp); - ggml_vec_sub_f32(nx, lm[end].y, g, gp); + ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp); + ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp); // compute scalars ys and yy: // ys = y^t \cdot s -> 1 / \rho. // yy = y^t \cdot y. // - ggml_vec_dot_f32(nx, &ys, lm[end].y, lm[end].s); - ggml_vec_dot_f32(nx, &yy, lm[end].y, lm[end].y); + ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]); + ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]); - lm[end].ys = ys; + lm_ys[end[0]] = ys; // find new search direction // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS - bound = (m <= k) ? m : k; - k++; - end = (end + 1)%m; + bound = (m <= k[0]) ? m : k[0]; + k[0]++; + it++; + end[0] = (end[0] + 1)%m; // initialize search direction with -g ggml_vec_neg_f32(nx, d, g); - j = end; + j[0] = end[0]; for (int i = 0; i < bound; ++i) { - j = (j + m - 1) % m; + j[0] = (j[0] + m - 1) % m; // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1} - ggml_vec_dot_f32(nx, &lm[j].alpha, lm[j].s, d); - lm[j].alpha /= lm[j].ys; + ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d); + lm_alpha[j[0]] /= lm_ys[j[0]]; // q_{i} = q_{i+1} - \alpha_{i} y_{i} - ggml_vec_mad_f32(nx, d, lm[j].y, -lm[j].alpha); + ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]); } ggml_vec_scale_f32(nx, d, ys/yy); for (int i = 0; i < bound; ++i) { // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i} - ggml_vec_dot_f32(nx, &beta, lm[j].y, d); - beta /= lm[j].ys; + ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d); + beta /= lm_ys[j[0]]; // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j} - ggml_vec_mad_f32(nx, d, lm[j].s, lm[j].alpha - beta); - j = (j + 1)%m; + ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta); + j[0] = (j[0] + 1)%m; } - step = 1.0; + step[0] = 1.0; } return GGML_OPT_DID_NOT_CONVERGE; @@ -16011,6 +17495,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { .adam = { .n_iter = 10000, + .sched = 1.000f, + .decay = 0.001f, .alpha = 0.001f, .beta1 = 0.9f, .beta2 = 0.999f, @@ -16053,6 +17539,71 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { return result; } +GGML_API void ggml_opt_init( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_opt_params params, + int64_t nx) { + opt->ctx = ctx; + opt->params = params; + opt->iter = 0; + opt->nx = nx; + opt->just_initialized = true; + switch (opt->params.type) { + case GGML_OPT_ADAM: + { + opt->adam.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->adam.pf = params.past > 0 + ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past) + : NULL; + ggml_set_zero(opt->adam.x); + ggml_set_zero(opt->adam.g1); + ggml_set_zero(opt->adam.g2); + ggml_set_zero(opt->adam.m); + ggml_set_zero(opt->adam.v); + ggml_set_zero(opt->adam.mh); + ggml_set_zero(opt->adam.vh); + if (opt->adam.pf) { + ggml_set_zero(opt->adam.pf); + } + } break; + case GGML_OPT_LBFGS: + { + opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->lbfgs.pf = params.past > 0 + ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past) + : NULL; + opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m); + opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m); + opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m); + opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m); + ggml_set_zero(opt->lbfgs.x); + ggml_set_zero(opt->lbfgs.xp); + ggml_set_zero(opt->lbfgs.g); + ggml_set_zero(opt->lbfgs.gp); + ggml_set_zero(opt->lbfgs.d); + ggml_set_zero(opt->lbfgs.pf); + if (opt->lbfgs.pf) { + ggml_set_zero(opt->lbfgs.pf); + } + ggml_set_zero(opt->lbfgs.lmal); + ggml_set_zero(opt->lbfgs.lmys); + ggml_set_zero(opt->lbfgs.lms); + ggml_set_zero(opt->lbfgs.lmy); + } break; + } +} + enum ggml_opt_result ggml_opt( struct ggml_context * ctx, struct ggml_opt_params params, @@ -16075,30 +17626,10 @@ enum ggml_opt_result ggml_opt( enum ggml_opt_result result = GGML_OPT_OK; - // build forward + backward compute graphs - struct ggml_cgraph gf = ggml_build_forward (f); - struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, true); + struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context)); - switch (params.type) { - case GGML_OPT_ADAM: - { - result = ggml_opt_adam(ctx, params, f, &gf, &gb); - } break; - case GGML_OPT_LBFGS: - { - result = ggml_opt_lbfgs(ctx, params, f, &gf, &gb); - } break; - } - - if (params.print_forward_graph) { - ggml_graph_print (&gf); - ggml_graph_dump_dot(&gf, NULL, "opt-forward.dot"); - } - - if (params.print_backward_graph) { - ggml_graph_print (&gb); - ggml_graph_dump_dot(&gb, &gf, "opt-backward.dot"); - } + ggml_opt_init(ctx, opt, params, 0); + result = ggml_opt_resume(ctx, opt, f); if (free_ctx) { ggml_free(ctx); @@ -16107,6 +17638,58 @@ enum ggml_opt_result ggml_opt( return result; } +enum ggml_opt_result ggml_opt_resume( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_tensor * f) { + + // build forward + backward compute graphs + struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0)); + struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0)); + + struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data; + struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data; + + *gf = ggml_build_forward (f); + *gb = ggml_build_backward(ctx, gf, true); + + return ggml_opt_resume_g(ctx, opt, f, gf, gb); +} + +enum ggml_opt_result ggml_opt_resume_g( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb) { + + // build forward + backward compute graphs + enum ggml_opt_result result = GGML_OPT_OK; + + switch (opt->params.type) { + case GGML_OPT_ADAM: + { + result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb); + } break; + case GGML_OPT_LBFGS: + { + result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb); + } break; + } + + if (opt->params.print_forward_graph) { + ggml_graph_print (gf); + ggml_graph_dump_dot(gf, NULL, "opt-forward.dot"); + } + + if (opt->params.print_backward_graph) { + ggml_graph_print (gb); + ggml_graph_dump_dot(gb, gf, "opt-backward.dot"); + } + + return result; +} + //////////////////////////////////////////////////////////////////////////////// size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) { diff --git a/ggml.h b/ggml.h index 1b26da3ad..f2a91761b 100644 --- a/ggml.h +++ b/ggml.h @@ -296,6 +296,7 @@ extern "C" { GGML_OP_SUM_ROWS, GGML_OP_MEAN, GGML_OP_REPEAT, + GGML_OP_REPEAT_BACK, GGML_OP_ABS, GGML_OP_SGN, GGML_OP_NEG, @@ -309,6 +310,7 @@ extern "C" { GGML_OP_RMS_NORM_BACK, GGML_OP_MUL_MAT, + GGML_OP_OUT_PROD, GGML_OP_SCALE, GGML_OP_SET, @@ -324,6 +326,7 @@ extern "C" { GGML_OP_DIAG_MASK_INF, GGML_OP_DIAG_MASK_ZERO, GGML_OP_SOFT_MAX, + GGML_OP_SOFT_MAX_BACK, GGML_OP_ROPE, GGML_OP_ROPE_BACK, GGML_OP_ALIBI, @@ -333,10 +336,14 @@ extern "C" { GGML_OP_FLASH_ATTN, GGML_OP_FLASH_FF, + GGML_OP_FLASH_ATTN_BACK, GGML_OP_MAP_UNARY, GGML_OP_MAP_BINARY, + GGML_OP_CROSS_ENTROPY_LOSS, + GGML_OP_CROSS_ENTROPY_LOSS_BACK, + GGML_OP_COUNT, }; @@ -574,6 +581,11 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_add1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_acc( struct ggml_context * ctx, struct ggml_tensor * a, @@ -645,6 +657,11 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_repeat_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_abs( struct ggml_context * ctx, struct ggml_tensor * a); @@ -698,14 +715,22 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); - // A: m rows, n columns - // B: p rows, n columns (i.e. we transpose it internally) + // A: n columns, m rows + // B: n columns, p rows (i.e. we transpose it internally) // result is m columns, p rows GGML_API struct ggml_tensor * ggml_mul_mat( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b); + // A: m columns, n rows, + // B: p columns, n rows, + // result is m columns, p rows + GGML_API struct ggml_tensor * ggml_out_prod( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + // // operations on tensors without backpropagation // @@ -916,6 +941,17 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_soft_max_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_soft_max_back_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + // rotary position embedding // if mode & 1 == 1, skip n_past elements // if mode & 2 == 1, GPT-NeoX style @@ -982,6 +1018,14 @@ extern "C" { struct ggml_tensor * v, bool masked); + GGML_API struct ggml_tensor * ggml_flash_attn_back( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * d, + bool masked); + GGML_API struct ggml_tensor * ggml_flash_ff( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1005,6 +1049,19 @@ extern "C" { struct ggml_tensor * b, ggml_binary_op_f32_t fun); + // loss function + + GGML_API struct ggml_tensor * ggml_cross_entropy_loss( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c); + // // automatic differentiation // @@ -1099,6 +1156,8 @@ extern "C" { struct { int n_iter; + float sched; // schedule multiplier (fixed, decay or warmup) + float decay; // weight decay for AdamW, use 0.0f to disable float alpha; // learning rate float beta1; float beta2; @@ -1123,6 +1182,49 @@ extern "C" { } lbfgs; }; + struct ggml_opt_context { + struct ggml_context * ctx; + struct ggml_opt_params params; + + int iter; + int64_t nx; // number of parameter elements + + bool just_initialized; + + struct { + struct ggml_tensor * x; // view of the parameters + struct ggml_tensor * g1; // gradient + struct ggml_tensor * g2; // gradient squared + struct ggml_tensor * m; // first moment + struct ggml_tensor * v; // second moment + struct ggml_tensor * mh; // first moment hat + struct ggml_tensor * vh; // second moment hat + struct ggml_tensor * pf; // past function values + float fx_best; + float fx_prev; + int n_no_improvement; + } adam; + + struct { + struct ggml_tensor * x; // current parameters + struct ggml_tensor * xp; // previous parameters + struct ggml_tensor * g; // current gradient + struct ggml_tensor * gp; // previous gradient + struct ggml_tensor * d; // search direction + struct ggml_tensor * pf; // past function values + struct ggml_tensor * lmal; // the L-BFGS memory alpha + struct ggml_tensor * lmys; // the L-BFGS memory ys + struct ggml_tensor * lms; // the L-BFGS memory s + struct ggml_tensor * lmy; // the L-BFGS memory y + float fx_best; + float step; + int j; + int k; + int end; + int n_no_improvement; + } lbfgs; + }; + GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type); // optimize the function defined by the tensor f @@ -1131,6 +1233,27 @@ extern "C" { struct ggml_opt_params params, struct ggml_tensor * f); + // initialize optimizer context + GGML_API void ggml_opt_init( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_opt_params params, + int64_t nx); + + // continue optimizing the function defined by the tensor f + GGML_API enum ggml_opt_result ggml_opt_resume( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_tensor * f); + + // continue optimizing the function defined by the tensor f + GGML_API enum ggml_opt_result ggml_opt_resume_g( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb); + // // quantization // diff --git a/llama.cpp b/llama.cpp index c7a333642..d2a52bb0c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1036,6 +1036,12 @@ static void llama_model_load_internal( case 40: model.type = e_model::MODEL_13B; break; case 60: model.type = e_model::MODEL_30B; break; case 80: model.type = e_model::MODEL_65B; break; + default: + { + if (hparams.n_layer < 32) { + model.type = e_model::MODEL_7B; + } + } break; } hparams.n_ctx = n_ctx; @@ -1200,6 +1206,7 @@ static void llama_model_load_internal( mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); (void) vram_scratch; + (void) n_batch; #ifdef GGML_USE_CUBLAS vram_scratch = n_batch * MB; ggml_cuda_set_scratch_size(vram_scratch); @@ -1227,6 +1234,7 @@ static void llama_model_load_internal( model.tensors_by_name.emplace_back(lt.name, lt.ggml_tensor); } + (void) tensor_split; #if defined(GGML_USE_CUBLAS) { ggml_cuda_set_tensor_split(tensor_split); @@ -2161,6 +2169,10 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok return -log2f(candidate.p) > *mu; })); + if (candidates->size == 0) { + candidates->size = 1; + } + // Normalize the probabilities of the remaining words llama_sample_softmax(ctx, candidates); @@ -3287,6 +3299,19 @@ int llama_n_embd(const struct llama_context * ctx) { return ctx->model.hparams.n_embd; } +int llama_get_vocab( + const struct llama_context * ctx, + const char * * strings, + float * scores, + int capacity) { + int n = std::min(capacity, (int) ctx->vocab.id_to_token.size()); + for (int i = 0; ivocab.id_to_token[i].tok.c_str(); + scores[i] = ctx->vocab.id_to_token[i].score; + } + return n; +} + float * llama_get_logits(struct llama_context * ctx) { return ctx->logits.data(); } diff --git a/llama.h b/llama.h index 7c7fd481c..61f6c867d 100644 --- a/llama.h +++ b/llama.h @@ -220,6 +220,14 @@ extern "C" { LLAMA_API int llama_n_ctx (const struct llama_context * ctx); LLAMA_API int llama_n_embd (const struct llama_context * ctx); + // Get the vocabulary as output parameters. + // Returns number of results. + LLAMA_API int llama_get_vocab( + const struct llama_context * ctx, + const char * * strings, + float * scores, + int capacity); + // Token logits obtained from the last call to llama_eval() // The logits for the last token are stored in the last row // Can be mutated in order to change the probabilities of the next token diff --git a/tests/test-grad0.c b/tests/test-grad0.c index ec5059220..c8c2c0f71 100644 --- a/tests/test-grad0.c +++ b/tests/test-grad0.c @@ -5,7 +5,7 @@ #include #include -#define MAX_NARGS 2 +#define MAX_NARGS 3 #undef MIN #undef MAX @@ -1090,6 +1090,25 @@ int main(int argc, const char ** argv) { } } + // cross_entropy_loss + { + const int nargs = 1; + + int64_t ne2[4]; + get_random_dims(ne2, 4); + + for (int ndims = 1; ndims <= 3; ++ndims) { + x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f); + x[1] = get_random_tensor(ctx0, ndims, ne2, 0.0f, 1.0f); + ggml_set_param(ctx0, x[0]); + + struct ggml_tensor * f = ggml_sum(ctx0, ggml_cross_entropy_loss(ctx0, x[0], x[1])); + + check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-1f, 1e-2f, INFINITY); + // finite differences regularly fails! + } + } + // rope { const int nargs = 1; @@ -1124,6 +1143,45 @@ int main(int argc, const char ** argv) { } } + // flash_attn + { + const int nargs = 3; + + int64_t ne2[4]; + + get_random_dims(ne2, 4); + int64_t D = ne2[0]; + int64_t N = ne2[1]; + int64_t M = ne2[2] + N; + int64_t B = ne2[3]; + + for (int masked = 0; masked <= 1; ++masked) { + for (int ndims = 2; ndims <= 4; ++ndims) { + int64_t neq[4] = { D, N, B, ne[3] }; + int64_t nek[4] = { D, M, B, ne[3] }; + int64_t nev[4] = { M, D, B, ne[3] }; + if (ndims == 2) { + neq[2] = 1; neq[3] = 1; + nek[2] = 1; nek[3] = 1; + nev[2] = 1; nev[3] = 1; + } else if (ndims == 3) { + neq[3] = 1; + nek[3] = 1; + nev[3] = 1; + } + x[0] = get_random_tensor(ctx0, ndims, neq, -0.1250f, 0.1250f); + x[1] = get_random_tensor(ctx0, ndims, nek, -0.1250f, 0.1250f); + x[2] = get_random_tensor(ctx0, ndims, nev, -0.1250f, 0.1250f); + ggml_set_param(ctx0, x[0]); + ggml_set_param(ctx0, x[1]); + ggml_set_param(ctx0, x[2]); + + struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); + + check_gradient("flash_attn", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f); + } + } + } ggml_free(ctx0); }