mirror of
https://github.com/zylon-ai/private-gpt.git
synced 2025-12-22 20:12:55 +01:00
Fix typing, linting and add tests
This commit is contained in:
parent
76faffb269
commit
5bc5054000
3 changed files with 111 additions and 47 deletions
|
|
@ -1,9 +1,14 @@
|
|||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import pytest
|
||||
from llama_index.llms import ChatMessage, MessageRole
|
||||
|
||||
from private_gpt.components.llm.prompt.prompt_helper import (
|
||||
DefaultPromptStyle,
|
||||
LlamaCppPromptStyle,
|
||||
LlamaIndexPromptStyle,
|
||||
TemplatePromptStyle,
|
||||
VigognePromptStyle,
|
||||
get_prompt_style,
|
||||
)
|
||||
|
|
@ -12,13 +17,44 @@ from private_gpt.components.llm.prompt.prompt_helper import (
|
|||
@pytest.mark.parametrize(
|
||||
("prompt_style", "expected_prompt_style"),
|
||||
[
|
||||
("default", DefaultPromptStyle),
|
||||
(None, DefaultPromptStyle),
|
||||
("llama2", LlamaIndexPromptStyle),
|
||||
("tag", VigognePromptStyle),
|
||||
("vigogne", VigognePromptStyle),
|
||||
("llama_cpp.alpaca", LlamaCppPromptStyle),
|
||||
("llama_cpp.zephyr", LlamaCppPromptStyle),
|
||||
],
|
||||
)
|
||||
def test_get_prompt_style_success(prompt_style, expected_prompt_style):
|
||||
assert get_prompt_style(prompt_style) == expected_prompt_style
|
||||
assert type(get_prompt_style(prompt_style)) == expected_prompt_style
|
||||
|
||||
|
||||
def test_get_prompt_style_template_success():
|
||||
jinja_template = "{% for message in messages %}<|{{message['role']}}|>: {{message['content'].strip() + '\\n'}}{% endfor %}<|assistant|>: "
|
||||
with NamedTemporaryFile("w") as tmp_file:
|
||||
path = Path(tmp_file.name)
|
||||
tmp_file.write(jinja_template)
|
||||
tmp_file.flush()
|
||||
tmp_file.seek(0)
|
||||
prompt_style = get_prompt_style(
|
||||
"template", template_name=path.name, template_dir=path.parent
|
||||
)
|
||||
assert type(prompt_style) == TemplatePromptStyle
|
||||
prompt = prompt_style.messages_to_prompt(
|
||||
[
|
||||
ChatMessage(
|
||||
content="You are an AI assistant.", role=MessageRole.SYSTEM
|
||||
),
|
||||
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
|
||||
]
|
||||
)
|
||||
|
||||
expected_prompt = (
|
||||
"<|system|>: You are an AI assistant.\n"
|
||||
"<|user|>: Hello, how are you doing?\n"
|
||||
"<|assistant|>: "
|
||||
)
|
||||
|
||||
assert prompt == expected_prompt
|
||||
|
||||
|
||||
def test_get_prompt_style_failure():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue