Fix typing, linting and add tests

This commit is contained in:
Louis 2023-12-03 16:31:20 +01:00
parent 76faffb269
commit 5bc5054000
3 changed files with 111 additions and 47 deletions

View file

@ -1,9 +1,14 @@
from pathlib import Path
from tempfile import NamedTemporaryFile
import pytest
from llama_index.llms import ChatMessage, MessageRole
from private_gpt.components.llm.prompt.prompt_helper import (
DefaultPromptStyle,
LlamaCppPromptStyle,
LlamaIndexPromptStyle,
TemplatePromptStyle,
VigognePromptStyle,
get_prompt_style,
)
@ -12,13 +17,44 @@ from private_gpt.components.llm.prompt.prompt_helper import (
@pytest.mark.parametrize(
("prompt_style", "expected_prompt_style"),
[
("default", DefaultPromptStyle),
(None, DefaultPromptStyle),
("llama2", LlamaIndexPromptStyle),
("tag", VigognePromptStyle),
("vigogne", VigognePromptStyle),
("llama_cpp.alpaca", LlamaCppPromptStyle),
("llama_cpp.zephyr", LlamaCppPromptStyle),
],
)
def test_get_prompt_style_success(prompt_style, expected_prompt_style):
assert get_prompt_style(prompt_style) == expected_prompt_style
assert type(get_prompt_style(prompt_style)) == expected_prompt_style
def test_get_prompt_style_template_success():
jinja_template = "{% for message in messages %}<|{{message['role']}}|>: {{message['content'].strip() + '\\n'}}{% endfor %}<|assistant|>: "
with NamedTemporaryFile("w") as tmp_file:
path = Path(tmp_file.name)
tmp_file.write(jinja_template)
tmp_file.flush()
tmp_file.seek(0)
prompt_style = get_prompt_style(
"template", template_name=path.name, template_dir=path.parent
)
assert type(prompt_style) == TemplatePromptStyle
prompt = prompt_style.messages_to_prompt(
[
ChatMessage(
content="You are an AI assistant.", role=MessageRole.SYSTEM
),
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
]
)
expected_prompt = (
"<|system|>: You are an AI assistant.\n"
"<|user|>: Hello, how are you doing?\n"
"<|assistant|>: "
)
assert prompt == expected_prompt
def test_get_prompt_style_failure():