server tests : more pythonic process management; fix bare except: (#6146)

* server tests : remove seemingly redundant newlines in print()

* server tests : use built-in subprocess features, not os.kill and psutil

* server tests : do not catch e.g. SystemExit; use print_exc

* server tests: handle TimeoutExpired exception

* server tests: fix connect on dual-stack systems

* server: tests: add new tokens regex on windows generated following new repeat penalties default changed in (#6127)

* server: tests: remove the hack on windows since now we get the good socket family

* server: tests: add new tokens regex following new repeat penalties default changed in (#6127)

* server: tests: add new tokens regex following new repeat penalties default changed in (#6127)

---------

Co-authored-by: Pierrick HYMBERT <pierrick.hymbert@gmail.com>
This commit is contained in:
Jared Van Bortel 2024-03-20 01:33:49 -04:00 committed by GitHub
parent 6c0b287748
commit bd60d82d0c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 46 additions and 76 deletions

View file

@ -5,15 +5,14 @@ import sys
import time import time
import traceback import traceback
from contextlib import closing from contextlib import closing
from subprocess import TimeoutExpired
import psutil
def before_scenario(context, scenario): def before_scenario(context, scenario):
context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON' context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
if context.debug: if context.debug:
print("DEBUG=ON\n") print("DEBUG=ON")
print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m\n") print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m")
port = 8080 port = 8080
if 'PORT' in os.environ: if 'PORT' in os.environ:
port = int(os.environ['PORT']) port = int(os.environ['PORT'])
@ -27,60 +26,40 @@ def after_scenario(context, scenario):
return return
if scenario.status == "failed": if scenario.status == "failed":
if 'GITHUB_ACTIONS' in os.environ: if 'GITHUB_ACTIONS' in os.environ:
print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n\n") print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n")
if os.path.isfile('llama.log'): if os.path.isfile('llama.log'):
with closing(open('llama.log', 'r')) as f: with closing(open('llama.log', 'r')) as f:
for line in f: for line in f:
print(line) print(line)
if not is_server_listening(context.server_fqdn, context.server_port): if not is_server_listening(context.server_fqdn, context.server_port):
print("\x1b[33;101mERROR: Server stopped listening\x1b[0m\n") print("\x1b[33;101mERROR: Server stopped listening\x1b[0m")
if not pid_exists(context.server_process.pid): if context.server_process.poll() is not None:
assert False, f"Server not running pid={context.server_process.pid} ..." assert False, f"Server not running pid={context.server_process.pid} ..."
server_graceful_shutdown(context) server_graceful_shutdown(context) # SIGINT
# Wait few for socket to free up try:
time.sleep(0.05) context.server_process.wait(0.5)
except TimeoutExpired:
print(f"server still alive after 500ms, force-killing pid={context.server_process.pid} ...")
context.server_process.kill() # SIGKILL
context.server_process.wait()
attempts = 0 while is_server_listening(context.server_fqdn, context.server_port):
while pid_exists(context.server_process.pid) or is_server_listening(context.server_fqdn, context.server_port):
server_kill(context)
time.sleep(0.1) time.sleep(0.1)
attempts += 1 except Exception:
if attempts > 5: print("ignoring error in after_scenario:")
server_kill_hard(context) traceback.print_exc(file=sys.stdout)
except:
exc = sys.exception()
print("error in after scenario: \n")
print(exc)
print("*** print_tb: \n")
traceback.print_tb(exc.__traceback__, file=sys.stdout)
def server_graceful_shutdown(context): def server_graceful_shutdown(context):
print(f"shutting down server pid={context.server_process.pid} ...\n") print(f"shutting down server pid={context.server_process.pid} ...")
if os.name == 'nt': if os.name == 'nt':
os.kill(context.server_process.pid, signal.CTRL_C_EVENT) interrupt = signal.CTRL_C_EVENT
else: else:
os.kill(context.server_process.pid, signal.SIGINT) interrupt = signal.SIGINT
context.server_process.send_signal(interrupt)
def server_kill(context):
print(f"killing server pid={context.server_process.pid} ...\n")
context.server_process.kill()
def server_kill_hard(context):
pid = context.server_process.pid
path = context.server_path
print(f"Server dangling exits, hard killing force {pid}={path}...\n")
try:
psutil.Process(pid).kill()
except psutil.NoSuchProcess:
return False
return True
def is_server_listening(server_fqdn, server_port): def is_server_listening(server_fqdn, server_port):
@ -88,14 +67,5 @@ def is_server_listening(server_fqdn, server_port):
result = sock.connect_ex((server_fqdn, server_port)) result = sock.connect_ex((server_fqdn, server_port))
_is_server_listening = result == 0 _is_server_listening = result == 0
if _is_server_listening: if _is_server_listening:
print(f"server is listening on {server_fqdn}:{server_port}...\n") print(f"server is listening on {server_fqdn}:{server_port}...")
return _is_server_listening return _is_server_listening
def pid_exists(pid):
try:
psutil.Process(pid)
except psutil.NoSuchProcess:
return False
return True

View file

@ -35,9 +35,9 @@ Feature: llama.cpp server
And metric llamacpp:tokens_predicted is <n_predicted> And metric llamacpp:tokens_predicted is <n_predicted>
Examples: Prompts Examples: Prompts
| prompt | n_predict | re_content | n_prompt | n_predicted | truncated | | prompt | n_predict | re_content | n_prompt | n_predicted | truncated |
| I believe the meaning of life is | 8 | (read\|going)+ | 18 | 8 | not | | I believe the meaning of life is | 8 | (read\|going)+ | 18 | 8 | not |
| Write a joke about AI from a very long prompt which will not be truncated | 256 | (princesses\|everyone\|kids)+ | 46 | 64 | not | | Write a joke about AI from a very long prompt which will not be truncated | 256 | (princesses\|everyone\|kids\|Anna\|forest)+ | 46 | 64 | not |
Scenario: Completion prompt truncated Scenario: Completion prompt truncated
Given a prompt: Given a prompt:
@ -48,7 +48,7 @@ Feature: llama.cpp server
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
""" """
And a completion request with no api error And a completion request with no api error
Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
And the completion is truncated And the completion is truncated
And 109 prompt tokens are processed And 109 prompt tokens are processed
@ -65,9 +65,9 @@ Feature: llama.cpp server
And the completion is <truncated> truncated And the completion is <truncated> truncated
Examples: Prompts Examples: Prompts
| model | system_prompt | user_prompt | max_tokens | re_content | n_prompt | n_predicted | enable_streaming | truncated | | model | system_prompt | user_prompt | max_tokens | re_content | n_prompt | n_predicted | enable_streaming | truncated |
| llama-2 | Book | What is the best book | 8 | (Here\|what)+ | 77 | 8 | disabled | not | | llama-2 | Book | What is the best book | 8 | (Here\|what)+ | 77 | 8 | disabled | not |
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird)+ | -1 | 64 | enabled | | | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird\|Annabyear)+ | -1 | 64 | enabled | |
Scenario: Tokenize / Detokenize Scenario: Tokenize / Detokenize

View file

@ -66,7 +66,7 @@ def step_server_config(context, server_fqdn, server_port):
def step_download_hf_model(context, hf_file, hf_repo): def step_download_hf_model(context, hf_file, hf_repo):
context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file) context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
if context.debug: if context.debug:
print(f"model file: {context.model_file}\n") print(f"model file: {context.model_file}")
@step('a model file {model_file}') @step('a model file {model_file}')
@ -137,9 +137,12 @@ def step_start_server(context):
if 'GITHUB_ACTIONS' in os.environ: if 'GITHUB_ACTIONS' in os.environ:
max_attempts *= 2 max_attempts *= 2
addrs = socket.getaddrinfo(context.server_fqdn, context.server_port, type=socket.SOCK_STREAM)
family, typ, proto, _, sockaddr = addrs[0]
while True: while True:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: with closing(socket.socket(family, typ, proto)) as sock:
result = sock.connect_ex((context.server_fqdn, context.server_port)) result = sock.connect_ex(sockaddr)
if result == 0: if result == 0:
print("\x1b[33;46mserver started!\x1b[0m") print("\x1b[33;46mserver started!\x1b[0m")
return return
@ -209,7 +212,7 @@ async def step_request_completion(context, api_error):
user_api_key=context.user_api_key) user_api_key=context.user_api_key)
context.tasks_result.append(completion) context.tasks_result.append(completion)
if context.debug: if context.debug:
print(f"Completion response: {completion}\n") print(f"Completion response: {completion}")
if expect_api_error: if expect_api_error:
assert completion == 401, f"completion must be an 401 status code: {completion}" assert completion == 401, f"completion must be an 401 status code: {completion}"
@ -354,7 +357,7 @@ def step_prompt_passkey(context, passkey, i_pos):
prompt += context.prompt_junk_suffix prompt += context.prompt_junk_suffix
if context.debug: if context.debug:
passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m" passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n") print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```")
context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix) context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
context.n_prompts = len(context.prompts) context.n_prompts = len(context.prompts)
@ -363,7 +366,7 @@ def step_prompt_passkey(context, passkey, i_pos):
@async_run_until_complete @async_run_until_complete
async def step_oai_chat_completions(context, api_error): async def step_oai_chat_completions(context, api_error):
if context.debug: if context.debug:
print(f"Submitting OAI compatible completions request...\n") print(f"Submitting OAI compatible completions request...")
expect_api_error = api_error == 'raised' expect_api_error = api_error == 'raised'
completion = await oai_chat_completions(context.prompts.pop(), completion = await oai_chat_completions(context.prompts.pop(),
context.system_prompt, context.system_prompt,
@ -508,12 +511,12 @@ async def step_all_embeddings_are_the_same(context):
embedding1 = np.array(embeddings[i]) embedding1 = np.array(embeddings[i])
embedding2 = np.array(embeddings[j]) embedding2 = np.array(embeddings[j])
if context.debug: if context.debug:
print(f"embedding1: {embedding1[-8:]}\n") print(f"embedding1: {embedding1[-8:]}")
print(f"embedding2: {embedding2[-8:]}\n") print(f"embedding2: {embedding2[-8:]}")
similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2)) similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
msg = f"Similarity between {i} and {j}: {similarity:.10f}" msg = f"Similarity between {i} and {j}: {similarity:.10f}"
if context.debug: if context.debug:
print(f"{msg}\n") print(f"{msg}")
assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg
@ -630,7 +633,7 @@ async def step_prometheus_metrics_exported(context):
metrics_raw = await metrics_response.text() metrics_raw = await metrics_response.text()
metric_exported = False metric_exported = False
if context.debug: if context.debug:
print(f"/metrics answer:\n{metrics_raw}\n") print(f"/metrics answer:\n{metrics_raw}")
context.metrics = {} context.metrics = {}
for metric in parser.text_string_to_metric_families(metrics_raw): for metric in parser.text_string_to_metric_families(metrics_raw):
match metric.name: match metric.name:
@ -932,7 +935,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
last_match = end last_match = end
highlighted += content[last_match:] highlighted += content[last_match:]
if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
print(f"Checking completion response: {highlighted}\n") print(f"Checking completion response: {highlighted}")
assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```' assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
if expected_predicted_n and expected_predicted_n > 0: if expected_predicted_n and expected_predicted_n > 0:
assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:' assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
@ -942,7 +945,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
async def gather_tasks_results(context): async def gather_tasks_results(context):
n_tasks = len(context.concurrent_tasks) n_tasks = len(context.concurrent_tasks)
if context.debug: if context.debug:
print(f"Waiting for all {n_tasks} tasks results...\n") print(f"Waiting for all {n_tasks} tasks results...")
for task_no in range(n_tasks): for task_no in range(n_tasks):
context.tasks_result.append(await context.concurrent_tasks.pop()) context.tasks_result.append(await context.concurrent_tasks.pop())
n_completions = len(context.tasks_result) n_completions = len(context.tasks_result)
@ -959,7 +962,7 @@ async def wait_for_health_status(context,
slots_processing=None, slots_processing=None,
expected_slots=None): expected_slots=None):
if context.debug: if context.debug:
print(f"Starting checking for health for expected_health_status={expected_health_status}\n") print(f"Starting checking for health for expected_health_status={expected_health_status}")
interval = 0.5 interval = 0.5
counter = 0 counter = 0
if 'GITHUB_ACTIONS' in os.environ: if 'GITHUB_ACTIONS' in os.environ:
@ -1048,8 +1051,6 @@ def start_server_background(context):
if 'LLAMA_SERVER_BIN_PATH' in os.environ: if 'LLAMA_SERVER_BIN_PATH' in os.environ:
context.server_path = os.environ['LLAMA_SERVER_BIN_PATH'] context.server_path = os.environ['LLAMA_SERVER_BIN_PATH']
server_listen_addr = context.server_fqdn server_listen_addr = context.server_fqdn
if os.name == 'nt':
server_listen_addr = '0.0.0.0'
server_args = [ server_args = [
'--host', server_listen_addr, '--host', server_listen_addr,
'--port', context.server_port, '--port', context.server_port,
@ -1088,7 +1089,7 @@ def start_server_background(context):
server_args.append('--verbose') server_args.append('--verbose')
if 'SERVER_LOG_FORMAT_JSON' not in os.environ: if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
server_args.extend(['--log-format', "text"]) server_args.extend(['--log-format', "text"])
print(f"starting server with: {context.server_path} {server_args}\n") print(f"starting server with: {context.server_path} {server_args}")
flags = 0 flags = 0
if 'nt' == os.name: if 'nt' == os.name:
flags |= subprocess.DETACHED_PROCESS flags |= subprocess.DETACHED_PROCESS

View file

@ -3,5 +3,4 @@ behave~=1.2.6
huggingface_hub~=0.20.3 huggingface_hub~=0.20.3
numpy~=1.24.4 numpy~=1.24.4
openai~=0.25.0 openai~=0.25.0
psutil~=5.9.8
prometheus-client~=0.20.0 prometheus-client~=0.20.0