From 71a09da301705b9c5ad4ca3cf3fbd966dd3f1ec5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 29 Oct 2023 18:32:51 +0200 Subject: [PATCH] llama : fix kv shift bug (#3835) ggml-ci --- llama.cpp | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/llama.cpp b/llama.cpp index 1d1db8fc9..d8510a5cf 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1552,14 +1552,14 @@ static void llama_kv_cache_seq_shift( for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.cells[i].pos += delta; + cache.has_shift = true; + cache.cells[i].pos += delta; + cache.cells[i].delta += delta; + if (cache.cells[i].pos < 0) { cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); if (new_head == cache.size) new_head = i; - } else { - cache.has_shift = true; - cache.cells[i].delta = delta; } } } @@ -6073,11 +6073,20 @@ static int llama_decode_internal( #endif // update the kv ring buffer - lctx.kv_self.has_shift = false; - lctx.kv_self.head += n_tokens; - // Ensure kv cache head points to a valid index. - if (lctx.kv_self.head >= lctx.kv_self.size) { - lctx.kv_self.head = 0; + { + if (kv_self.has_shift) { + kv_self.has_shift = false; + for (uint32_t i = 0; i < kv_self.size; ++i) { + kv_self.cells[i].delta = 0; + } + } + + kv_self.head += n_tokens; + + // Ensure kv cache head points to a valid index. + if (kv_self.head >= kv_self.size) { + kv_self.head = 0; + } } #ifdef GGML_PERF