From b41b4cad6f956b5f501db0711dd7007c32b5eee5 Mon Sep 17 00:00:00 2001 From: SuperUserNameMan Date: Fri, 16 Jun 2023 20:58:09 +0200 Subject: [PATCH] examples : add "simple" (#1840) * Create `simple.cpp` * minimalist example `CMakeLists.txt` * Update Makefile for minimalist example * remove 273: Trailing whitespace * removed trailing white spaces simple.cpp * typo and comments simple.cpp --------- Co-authored-by: Georgi Gerganov --- Makefile | 8 +- examples/simple/CMakeLists.txt | 7 ++ examples/simple/simple.cpp | 177 +++++++++++++++++++++++++++++++++ 3 files changed, 191 insertions(+), 1 deletion(-) create mode 100644 examples/simple/CMakeLists.txt create mode 100644 examples/simple/simple.cpp diff --git a/Makefile b/Makefile index b24caf8dd..5306a114f 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Define the default target now so that it is always the first target -BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch +BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple ifdef LLAMA_BUILD_SERVER BUILD_TARGETS += server @@ -276,6 +276,12 @@ main: examples/main/main.cpp build-info.h ggml. @echo '==== Run ./main -h for help. ====' @echo +simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) + @echo + @echo '==== Run ./simple -h for help. ====' + @echo + quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) diff --git a/examples/simple/CMakeLists.txt b/examples/simple/CMakeLists.txt new file mode 100644 index 000000000..1568f7364 --- /dev/null +++ b/examples/simple/CMakeLists.txt @@ -0,0 +1,7 @@ +set(TARGET simple) +add_executable(${TARGET} simple.cpp) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) +if(TARGET BUILD_INFO) + add_dependencies(${TARGET} BUILD_INFO) +endif() diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp new file mode 100644 index 000000000..76f991cdc --- /dev/null +++ b/examples/simple/simple.cpp @@ -0,0 +1,177 @@ +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include "common.h" +#include "llama.h" +#include "build-info.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) +#include +#include +#elif defined (_WIN32) +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include +#include +#endif + + + +int main(int argc, char ** argv) +{ + gpt_params params; + + //--------------------------------- + // Print help : + //--------------------------------- + + if ( argc == 1 || argv[1][0] == '-' ) + { + printf( "usage: %s MODEL_PATH [PROMPT]\n" , argv[0] ); + return 1 ; + } + + //--------------------------------- + // Load parameters : + //--------------------------------- + + if ( argc >= 2 ) + { + params.model = argv[1]; + } + + if ( argc >= 3 ) + { + params.prompt = argv[2]; + } + + if ( params.prompt.empty() ) + { + params.prompt = "Hello my name is"; + } + + //--------------------------------- + // Init LLM : + //--------------------------------- + + llama_init_backend(); + + llama_context * ctx ; + + ctx = llama_init_from_gpt_params( params ); + + if ( ctx == NULL ) + { + fprintf( stderr , "%s: error: unable to load model\n" , __func__ ); + return 1; + } + + //--------------------------------- + // Tokenize the prompt : + //--------------------------------- + + std::vector tokens_list; + tokens_list = ::llama_tokenize( ctx , params.prompt , true ); + + const int max_context_size = llama_n_ctx( ctx ); + const int max_tokens_list_size = max_context_size - 4 ; + + if ( (int)tokens_list.size() > max_tokens_list_size ) + { + fprintf( stderr , "%s: error: prompt too long (%d tokens, max %d)\n" , + __func__ , (int)tokens_list.size() , max_tokens_list_size ); + return 1; + } + + fprintf( stderr, "\n\n" ); + + // Print the tokens from the prompt : + + for( auto id : tokens_list ) + { + printf( "%s" , llama_token_to_str( ctx , id ) ); + } + + fflush(stdout); + + + //--------------------------------- + // Main prediction loop : + //--------------------------------- + + // The LLM keeps a contextual cache memory of previous token evaluation. + // Usually, once this cache is full, it is required to recompute a compressed context based on previous + // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist + // example, we will just stop the loop once this cache is full or once an end of stream is detected. + + while ( llama_get_kv_cache_token_count( ctx ) < max_context_size ) + { + //--------------------------------- + // Evaluate the tokens : + //--------------------------------- + + if ( llama_eval( ctx , tokens_list.data() , tokens_list.size() , llama_get_kv_cache_token_count( ctx ) , params.n_threads ) ) + { + fprintf( stderr, "%s : failed to eval\n" , __func__ ); + return 1; + } + + tokens_list.clear(); + + //--------------------------------- + // Select the best prediction : + //--------------------------------- + + llama_token new_token_id = 0; + + auto logits = llama_get_logits( ctx ); + auto n_vocab = llama_n_vocab( ctx ); // the size of the LLM vocabulary (in tokens) + + std::vector candidates; + candidates.reserve( n_vocab ); + + for( llama_token token_id = 0 ; token_id < n_vocab ; token_id++ ) + { + candidates.emplace_back( llama_token_data{ token_id , logits[ token_id ] , 0.0f } ); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + + // Select it using the "Greedy sampling" method : + new_token_id = llama_sample_token_greedy( ctx , &candidates_p ); + + + // is it an end of stream ? + if ( new_token_id == llama_token_eos() ) + { + fprintf(stderr, " [end of text]\n"); + break; + } + + // Print the new token : + printf( "%s" , llama_token_to_str( ctx , new_token_id ) ); + fflush( stdout ); + + // Push this new token for next evaluation : + tokens_list.push_back( new_token_id ); + + } // wend of main loop + + llama_free( ctx ); + + return 0; +} + +// EOF