From c8d6a1f34ab6f1b6bd468d256e535a61f98f114c Mon Sep 17 00:00:00 2001 From: Thibault Terrasson Date: Fri, 27 Oct 2023 16:37:41 +0200 Subject: [PATCH] simple : fix batch handling (#3803) --- examples/simple/simple.cpp | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index f376c0509..374aef6f1 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -95,13 +95,8 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(512, 0, 1); // evaluate the initial prompt - batch.n_tokens = tokens_list.size(); - - for (int32_t i = 0; i < batch.n_tokens; i++) { - batch.token[i] = tokens_list[i]; - batch.pos[i] = i; - batch.seq_id[i] = 0; - batch.logits[i] = false; + for (size_t i = 0; i < tokens_list.size(); i++) { + llama_batch_add(batch, tokens_list[i], i, { 0 }, false); } // llama_decode will output logits only for the last token of the prompt @@ -148,15 +143,10 @@ int main(int argc, char ** argv) { fflush(stdout); // prepare the next batch - batch.n_tokens = 0; + llama_batch_clear(batch); // push this new token for next evaluation - batch.token [batch.n_tokens] = new_token_id; - batch.pos [batch.n_tokens] = n_cur; - batch.seq_id[batch.n_tokens] = 0; - batch.logits[batch.n_tokens] = true; - - batch.n_tokens += 1; + llama_batch_add(batch, new_token_id, n_cur, { 0 }, true); n_decode += 1; }