From f85020b19ac853a6bbad6092e0cc344e27553aea Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 29 May 2023 20:49:24 +0300
Subject: [PATCH] mtl : export the LLaMA computation graph

---
 examples/CMakeLists.txt     |  1 +
 examples/mtl/CMakeLists.txt |  7 ++++++
 examples/mtl/mtl-export.cpp | 25 +++++++++++++++++++++
 llama.cpp                   | 44 ++++++++++++++++++++++++++++---------
 llama.h                     |  4 ++++
 5 files changed, 71 insertions(+), 10 deletions(-)
 create mode 100644 examples/mtl/CMakeLists.txt
 create mode 100644 examples/mtl/mtl-export.cpp

diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index e4ce5aca7..97a3ffd1b 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -37,6 +37,7 @@ else()
     add_subdirectory(save-load-state)
     add_subdirectory(benchmark)
     add_subdirectory(baby-llama)
+    add_subdirectory(mtl)
     if(LLAMA_BUILD_SERVER)
         add_subdirectory(server)
     endif()
diff --git a/examples/mtl/CMakeLists.txt b/examples/mtl/CMakeLists.txt
new file mode 100644
index 000000000..4dc0bc596
--- /dev/null
+++ b/examples/mtl/CMakeLists.txt
@@ -0,0 +1,7 @@
+set(TARGET mtl-export)
+add_executable(${TARGET} mtl-export.cpp)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+  add_dependencies(${TARGET} BUILD_INFO)
+endif()
diff --git a/examples/mtl/mtl-export.cpp b/examples/mtl/mtl-export.cpp
new file mode 100644
index 000000000..7872182a1
--- /dev/null
+++ b/examples/mtl/mtl-export.cpp
@@ -0,0 +1,25 @@
+#include "common.h"
+#include "llama.h"
+
+int main(int argc, char ** argv) {
+    gpt_params params;
+
+    if (!gpt_params_parse(argc, argv, params)) {
+        return 1;
+    }
+
+    llama_init_backend();
+
+    llama_context * ctx = llama_init_from_gpt_params(params);
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
+    }
+
+    llama_eval_export(ctx, "llama.ggml");
+
+    llama_print_timings(ctx);
+    llama_free(ctx);
+
+    return 0;
+}
diff --git a/llama.cpp b/llama.cpp
index 5a19316b3..9dccf0ed1 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1189,17 +1189,19 @@ static bool llama_model_load(
 // evaluate the transformer
 //
-//   - lctx:      llama context
-//   - tokens:    new batch of tokens to process
-//   - n_past:    the context size so far
-//   - n_threads: number of threads to use
+//   - lctx:         llama context
+//   - tokens:       new batch of tokens to process
+//   - n_past:       the context size so far
+//   - n_threads:    number of threads to use
+//   - cgraph_fname: filename of the exported computation graph (TODO: TMP!!!)
 //
 static bool llama_eval_internal(
-        llama_context & lctx,
-    const llama_token * tokens,
-            const int   n_tokens,
-            const int   n_past,
-            const int   n_threads) {
+        llama_context & lctx,
+    const llama_token * tokens,
+            const int   n_tokens,
+            const int   n_past,
+            const int   n_threads,
+           const char * cgraph_fname) {
 
     // enforce that the first token is BOS
     if (n_past == 0 && tokens[0] != llama_token_bos()) {
@@ -1422,6 +1424,10 @@ static bool llama_eval_internal(
     ggml_build_forward_expand(&gf, inpL);
     ggml_graph_compute       (ctx0, &gf);
 
+    if (cgraph_fname) {
+        ggml_graph_export(&gf, cgraph_fname);
+    }
+
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)
     // requires GGML_PERF to be defined
@@ -2899,7 +2905,7 @@ int llama_eval(
                          int   n_tokens,
                          int   n_past,
                          int   n_threads) {
-    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) {
+    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
@@ -2914,6 +2920,24 @@ int llama_eval(
     return 0;
 }
 
+int llama_eval_export(struct llama_context * ctx, const char * fname) {
+    // these values determine the maximum inference sizes of the exported computation graph
+    // TODO: TMP !!!
+    //const int n_ctx   = ctx->model.hparams.n_ctx;
+    //const int n_batch = 512;
+    const int n_ctx   = 128;
+    const int n_batch = 32;
+
+    const std::vector<llama_token> tmp(n_batch, llama_token_bos());
+
+    if (!llama_eval_internal(*ctx, tmp.data(), tmp.size(), n_ctx, 1, fname)) {
+        fprintf(stderr, "%s: failed to eval\n", __func__);
+        return 1;
+    }
+
+    return 0;
+}
+
 int llama_tokenize(
         struct llama_context * ctx,
                   const char * text,
diff --git a/llama.h b/llama.h
index c6b0a2889..3ba0775bd 100644
--- a/llama.h
+++ b/llama.h
@@ -173,6 +173,10 @@ extern "C" {
                              int   n_past,
                              int   n_threads);
 
+    // Export a computation graph for model inference
+    // TODO: very likely to change
+    LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);
+
     // Convert the provided text into tokens.
     // The tokens pointer must be large enough to hold the resulting tokens.
     // Returns the number of tokens on success, no more than n_max_tokens
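
Usage sketch (not part of the patch above): the file written by llama_eval_export() is intended to be
loaded back with ggml's graph import API and evaluated without the llama model-loading code. A minimal
consumer could look roughly like the code below. It assumes a ggml version that provides
ggml_graph_import() alongside the ggml_graph_export() used in the patch; the work-buffer size and the
way input tensors would be located and filled are assumptions for illustration, not something the
patch specifies.

// graph-import.cpp - hypothetical standalone consumer of the exported "llama.ggml" (sketch only)
#include "ggml.h"

#include <cstdio>

int main(int argc, char ** argv) {
    const char * fname = argc > 1 ? argv[1] : "llama.ggml";

    // work context used by ggml_graph_compute for scratch memory (size is a guess)
    struct ggml_init_params params = {
        /*.mem_size   =*/ 512ull*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx_work = ggml_init(params);

    // ggml_graph_import reconstructs the graph and allocates contexts for its data and eval tensors
    struct ggml_context * ctx_data = NULL;
    struct ggml_context * ctx_eval = NULL;

    struct ggml_cgraph gf = ggml_graph_import(fname, &ctx_data, &ctx_eval);
    gf.n_threads = 1;

    // before computing, the input tensors (token ids, n_past, ...) would have to be found in the
    // imported graph and filled in; their names depend on llama_eval_internal and are not defined
    // by this patch

    ggml_graph_compute(ctx_work, &gf);

    printf("%s: computed graph '%s' with %d nodes\n", __func__, fname, gf.n_nodes);

    ggml_free(ctx_work);
    ggml_free(ctx_data);
    ggml_free(ctx_eval);

    return 0;
}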