server: tests: add truncated prompt tests, better kv cache size (#5933)

* server: tests: add truncated prompt tests, better size

* server, tests : update regex

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Pierrick Hymbert 2024-03-09 10:30:04 +01:00 committed by GitHub
parent c2101a2e90
commit fd72d2d2a5
4 changed files with 81 additions and 23 deletions

examples/server/server.cpp

@@ -1128,6 +1128,7 @@ struct server_context {
LOG_VERBOSE("stopped by limit", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_decoded", slot.n_decoded},
{"n_predict", slot.params.n_predict},
});
@@ -1141,6 +1142,8 @@ struct server_context {
}
LOG_VERBOSE("next token", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"token", result.tok},
{"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
{"has_next_token", slot.has_next_token},
@@ -1750,6 +1753,15 @@ struct server_context {
slot.n_past = 0;
slot.n_prompt_tokens = prompt_tokens.size();
LOG_VERBOSE("prompt tokenized", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_prompt_tokens", slot.n_prompt_tokens},
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
});
if (slot.embedding) {
// this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_batch) {
@@ -1788,10 +1800,13 @@ struct server_context {
slot.n_prompt_tokens = prompt_tokens.size();
LOG_VERBOSE("input truncated", {
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
{"id_slot", slot.id},
{"id_task", slot.id_task},
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"n_prompt_tokens", slot.n_prompt_tokens},
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
});
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
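
The GGML_ASSERT above only holds because of the truncation performed just before it: when a prompt does not fit the slot context, the server keeps the first n_keep tokens and erases whole blocks from the middle of the rest. A rough Python sketch of that keep-prefix / drop-middle scheme (an illustration with made-up names, not the exact C++ code):

def truncate_prompt(prompt_tokens, n_ctx, n_keep):
    # illustrative sketch only -- the real logic lives in server.cpp above
    if len(prompt_tokens) < n_ctx:
        return prompt_tokens                      # fits, nothing to do
    n_left = n_ctx - n_keep                       # context left after the kept prefix
    n_block = n_left // 2                         # erase the middle in blocks of this size
    n_erased = ((len(prompt_tokens) - n_keep - n_block) // n_block) * n_block
    truncated = prompt_tokens[:n_keep] + prompt_tokens[n_keep + n_erased:]
    assert len(truncated) < n_ctx                 # what the GGML_ASSERT above checks
    return truncated

With n_ctx = 128 and n_keep = 0, for example, a 173-token prompt comes out at 173 - 64 = 109 tokens, which is consistent with the 109 processed prompt tokens asserted in the new server.feature scenario further down.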

examples/server/tests/features/parallel.feature

@@ -6,8 +6,8 @@ Feature: Parallel
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And 42 as server seed
And 512 as batch size
And 64 KV cache size
And 128 as batch size
And 256 KV cache size
And 2 slots
And continuous batching
Then the server is starting
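
The resizing here is presumably the "better kv cache size" from the commit title: the KV cache is shared by all slots, so the old 64-token cache left each of the 2 slots only about 32 tokens of context, while the examples visible in this hunk go up to 128 tokens to predict per user. A quick sanity check of the new numbers (illustrative arithmetic, assuming an even split across slots):

kv_cache_size = 256
n_slots = 2
n_ctx_per_slot = kv_cache_size // n_slots   # 128 tokens of context per user
max_n_predict = 128                         # largest n_predict in the examples below
assert max_n_predict <= n_ctx_per_slot      # the prediction budget alone no longer exceeds a slot's context
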
@@ -76,6 +76,7 @@ Feature: Parallel
| disabled | 128 |
| enabled | 64 |
Scenario: Multi users with total number of tokens to predict exceeds the KV Cache size #3969
Given a prompt:
"""

examples/server/tests/features/server.feature

@@ -10,11 +10,10 @@ Feature: llama.cpp server
# KV Cache corresponds to the total amount of tokens
# that can be stored across all independent sequences: #4130
# see --ctx-size and #5568
And 32 KV cache size
And 512 as batch size
And 1 slots
And embeddings extraction
And 32 server max tokens to predict
And 256 KV cache size
And 32 as batch size
And 2 slots
And 64 server max tokens to predict
And prometheus compatible metrics exposed
Then the server is starting
Then the server is healthy
@@ -23,18 +22,35 @@ Feature: llama.cpp server
Then the server is ready
And all slots are idle
Scenario Outline: Completion
Given a prompt <prompt>
And <n_predict> max tokens to predict
And a completion request with no api error
Then <n_predicted> tokens are predicted matching <re_content>
And the completion is <truncated> truncated
And <n_prompt> prompt tokens are processed
And prometheus metrics are exposed
And metric llamacpp:tokens_predicted is <n_predicted>
Examples: Prompts
| prompt | n_predict | re_content | n_predicted |
| I believe the meaning of life is | 8 | (read\|going)+ | 8 |
| Write a joke about AI | 64 | (park\|friends\|scared\|always)+ | 32 |
| prompt | n_predict | re_content | n_prompt | n_predicted | truncated |
| I believe the meaning of life is | 8 | (read\|going)+ | 18 | 8 | not |
| Write a joke about AI from a very long prompt which will not be truncated | 256 | (princesses\|everyone\|kids)+ | 46 | 64 | not |
Scenario: Completion prompt truncated
Given a prompt:
"""
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
"""
And a completion request with no api error
Then 64 tokens are predicted matching fun|Annaks|popcorns
And the completion is truncated
And 109 prompt tokens are processed
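
A sizing note on why this scenario truncates while the Completion outline above does not (assuming, as the comment at the top of this feature says, that the KV cache is the total across all sequences, i.e. 256 tokens split over the 2 slots):

# illustrative arithmetic, not part of the test suite
kv_cache_size = 256
n_slots = 2
n_ctx_per_slot = kv_cache_size // n_slots   # 128 tokens of context per slot

assert 18 < n_ctx_per_slot    # "I believe the meaning of life is" fits untouched
assert 46 < n_ctx_per_slot    # the long joke prompt fits untouched
assert 109 < n_ctx_per_slot   # the Lorem ipsum prompt, but only after being cut down

The Lorem ipsum text apparently tokenizes past the 128-token per-slot limit with the stories260K vocabulary, so the server truncates it to the 109 tokens asserted here and flags the completion as truncated.
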
Scenario Outline: OAI Compatibility
Given a model <model>
@@ -44,11 +60,14 @@ Feature: llama.cpp server
And streaming is <enable_streaming>
Given an OAI compatible chat completions request with no api error
Then <n_predicted> tokens are predicted matching <re_content>
And <n_prompt> prompt tokens are processed
And the completion is <truncated> truncated
Examples: Prompts
| model | system_prompt | user_prompt | max_tokens | re_content | n_predicted | enable_streaming |
| llama-2 | Book | What is the best book | 8 | (Mom\|what)+ | 8 | disabled |
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks\|happy\|bird)+ | 32 | enabled |
| model | system_prompt | user_prompt | max_tokens | re_content | n_prompt | n_predicted | enable_streaming | truncated |
| llama-2 | Book | What is the best book | 8 | (Here\|what)+ | 77 | 8 | disabled | not |
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird)+ | -1 | 64 | enabled | |
Scenario: Tokenize / Detokenize
When tokenizing:

examples/server/tests/features/steps/steps.py

@@ -196,12 +196,30 @@ async def step_request_completion(context, api_error):
@step(u'{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)
    context.completion = context.tasks_result.pop()
    assert_n_tokens_predicted(context.completion, predicted_n, re_content)


@step(u'{predicted_n:d} tokens are predicted')
def step_n_tokens_predicted(context, predicted_n):
    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)
    context.completion = context.tasks_result.pop()
    assert_n_tokens_predicted(context.completion, predicted_n)


@step(u'the completion is truncated')
def step_assert_completion_truncated(context):
    step_assert_completion_truncated(context, '')


@step(u'the completion is {truncated} truncated')
def step_assert_completion_truncated(context, truncated):
    truncated = truncated != "not"
    assert context.completion['truncated'] == truncated, f'{context.completion}'


@step(u'{n_prompt:d} prompt tokens are processed')
def step_impl(context, n_prompt):
    assert n_prompt < 0 or n_prompt == context.completion['timings']['prompt_n'], f"n_prompt={context.completion['timings']['prompt_n']}"
@step(u'a user prompt {user_prompt}')
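
All of the new assertions read from context.completion, which the predicted-tokens steps now populate. A minimal sketch of the shape they expect (illustrative values, not a real server response):

# hypothetical example of a popped task result that satisfies the steps above
context_completion = {
    'content': '... generated text ...',
    'truncated': True,            # checked by "the completion is {truncated} truncated"
    'timings': {
        'predicted_n': 64,        # checked by assert_n_tokens_predicted
        'prompt_n': 109,          # checked by "{n_prompt:d} prompt tokens are processed"
    },
}

Note that a scenario can pass n_prompt = -1 (as the codellama70b row does) to skip the prompt_n comparison, since the step only compares when n_prompt is non-negative.
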
@@ -722,7 +740,8 @@ async def oai_chat_completions(user_prompt,
completion_response = {
'content': '',
'timings': {
'predicted_n': 0
'predicted_n': 0,
'prompt_n': 0
}
}
if async_client:
@@ -763,7 +782,8 @@ async def oai_chat_completions(user_prompt,
completion_response = {
'content': chat_completion_raw['choices'][0]['message'],
'timings': {
'predicted_n': chat_completion_raw['usage']['completion_tokens']
'predicted_n': chat_completion_raw['usage']['completion_tokens'],
'prompt_n': chat_completion_raw['usage']['prompt_tokens']
}
}
else:
@@ -792,13 +812,16 @@ async def oai_chat_completions(user_prompt,
if 'content' in delta:
completion_response['content'] += delta['content']
completion_response['timings']['predicted_n'] += 1
completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop'
else:
assert len(chat_completion.choices) == 1
completion_response = {
'content': chat_completion.choices[0].message.content,
'timings': {
'predicted_n': chat_completion.usage.completion_tokens
}
'predicted_n': chat_completion.usage.completion_tokens,
'prompt_n': chat_completion.usage.prompt_tokens
},
'truncated': chat_completion.choices[0].finish_reason != 'stop'
}
if debug:
print("OAI response formatted to llama.cpp:", completion_response)