examples : add "simple" (#1840)

* Create `simple.cpp`

* minimalist example `CMakeLists.txt`

* Update Makefile for minimalist example

* remove 273: Trailing whitespace

* removed trailing white spaces simple.cpp

* typo and comments simple.cpp

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
SuperUserNameMan 2023-06-16 20:58:09 +02:00 committed by GitHub
parent 13fe9d2d84
commit b41b4cad6f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 191 additions and 1 deletions

View file

@ -1,5 +1,5 @@
# Define the default target now so that it is always the first target # Define the default target now so that it is always the first target
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple
ifdef LLAMA_BUILD_SERVER ifdef LLAMA_BUILD_SERVER
BUILD_TARGETS += server BUILD_TARGETS += server
@ -276,6 +276,12 @@ main: examples/main/main.cpp build-info.h ggml.
@echo '==== Run ./main -h for help. ====' @echo '==== Run ./main -h for help. ===='
@echo @echo
simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
@echo
@echo '==== Run ./simple -h for help. ===='
@echo
quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS) quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

View file

@ -0,0 +1,7 @@
set(TARGET simple)
add_executable(${TARGET} simple.cpp)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()

177
examples/simple/simple.cpp Normal file
View file

@ -0,0 +1,177 @@
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include "common.h"
#include "llama.h"
#include "build-info.h"
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#include <signal.h>
#endif
int main(int argc, char ** argv)
{
gpt_params params;
//---------------------------------
// Print help :
//---------------------------------
if ( argc == 1 || argv[1][0] == '-' )
{
printf( "usage: %s MODEL_PATH [PROMPT]\n" , argv[0] );
return 1 ;
}
//---------------------------------
// Load parameters :
//---------------------------------
if ( argc >= 2 )
{
params.model = argv[1];
}
if ( argc >= 3 )
{
params.prompt = argv[2];
}
if ( params.prompt.empty() )
{
params.prompt = "Hello my name is";
}
//---------------------------------
// Init LLM :
//---------------------------------
llama_init_backend();
llama_context * ctx ;
ctx = llama_init_from_gpt_params( params );
if ( ctx == NULL )
{
fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
return 1;
}
//---------------------------------
// Tokenize the prompt :
//---------------------------------
std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize( ctx , params.prompt , true );
const int max_context_size = llama_n_ctx( ctx );
const int max_tokens_list_size = max_context_size - 4 ;
if ( (int)tokens_list.size() > max_tokens_list_size )
{
fprintf( stderr , "%s: error: prompt too long (%d tokens, max %d)\n" ,
__func__ , (int)tokens_list.size() , max_tokens_list_size );
return 1;
}
fprintf( stderr, "\n\n" );
// Print the tokens from the prompt :
for( auto id : tokens_list )
{
printf( "%s" , llama_token_to_str( ctx , id ) );
}
fflush(stdout);
//---------------------------------
// Main prediction loop :
//---------------------------------
// The LLM keeps a contextual cache memory of previous token evaluation.
// Usually, once this cache is full, it is required to recompute a compressed context based on previous
// tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
// example, we will just stop the loop once this cache is full or once an end of stream is detected.
while ( llama_get_kv_cache_token_count( ctx ) < max_context_size )
{
//---------------------------------
// Evaluate the tokens :
//---------------------------------
if ( llama_eval( ctx , tokens_list.data() , tokens_list.size() , llama_get_kv_cache_token_count( ctx ) , params.n_threads ) )
{
fprintf( stderr, "%s : failed to eval\n" , __func__ );
return 1;
}
tokens_list.clear();
//---------------------------------
// Select the best prediction :
//---------------------------------
llama_token new_token_id = 0;
auto logits = llama_get_logits( ctx );
auto n_vocab = llama_n_vocab( ctx ); // the size of the LLM vocabulary (in tokens)
std::vector<llama_token_data> candidates;
candidates.reserve( n_vocab );
for( llama_token token_id = 0 ; token_id < n_vocab ; token_id++ )
{
candidates.emplace_back( llama_token_data{ token_id , logits[ token_id ] , 0.0f } );
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// Select it using the "Greedy sampling" method :
new_token_id = llama_sample_token_greedy( ctx , &candidates_p );
// is it an end of stream ?
if ( new_token_id == llama_token_eos() )
{
fprintf(stderr, " [end of text]\n");
break;
}
// Print the new token :
printf( "%s" , llama_token_to_str( ctx , new_token_id ) );
fflush( stdout );
// Push this new token for next evaluation :
tokens_list.push_back( new_token_id );
} // wend of main loop
llama_free( ctx );
return 0;
}
// EOF