Mirror of https://github.com/zylon-ai/private-gpt.git, synced 2025-12-22 07:40:12 +01:00.
Multi language support - fern debug (#1307)
--------- Co-authored-by: Louis <lpglm@orange.fr> Co-authored-by: LeMoussel <cnhx27@gmail.com>
This commit is contained in:
parent
e8d88f8952
commit
944c43bfa8
10 changed files with 402 additions and 8 deletions
128
tests/test_prompt_helper.py
Normal file
128
tests/test_prompt_helper.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
import pytest
|
||||
from llama_index.llms import ChatMessage, MessageRole
|
||||
|
||||
from private_gpt.components.llm.prompt_helper import (
|
||||
DefaultPromptStyle,
|
||||
Llama2PromptStyle,
|
||||
TagPromptStyle,
|
||||
get_prompt_style,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("prompt_style", "expected_prompt_style"),
    [
        ("default", DefaultPromptStyle),
        ("llama2", Llama2PromptStyle),
        ("tag", TagPromptStyle),
    ],
)
def test_get_prompt_style_success(prompt_style, expected_prompt_style):
    """Each supported style name resolves to its corresponding prompt-style class."""
    resolved = get_prompt_style(prompt_style)
    assert resolved == expected_prompt_style
|
||||
|
||||
|
||||
def test_get_prompt_style_failure():
    """An unrecognized style name raises ValueError with a precise message."""
    unknown_style = "unknown"
    with pytest.raises(ValueError) as exc_info:
        get_prompt_style(unknown_style)
    assert str(exc_info.value) == f"Unknown prompt_style='{unknown_style}'"
|
||||
|
||||
|
||||
def test_tag_prompt_style_format():
    """TagPromptStyle renders each message as '<|role|>: content' and appends an open assistant tag."""
    style = TagPromptStyle()
    conversation = [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]

    expected = (
        "<|system|>: You are an AI assistant.\n"
        "<|user|>: Hello, how are you doing?\n"
        "<|assistant|>: "
    )

    assert style.messages_to_prompt(conversation) == expected
|
||||
|
||||
|
||||
def test_tag_prompt_style_format_with_system_prompt():
    """A configured default system prompt is injected when absent, and an explicit
    system message in the conversation takes precedence over it."""
    system_prompt = "This is a system prompt from configuration."
    style = TagPromptStyle(default_system_prompt=system_prompt)

    # Conversation with no system message: the configured default is used.
    user_only = [
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]
    expected = (
        f"<|system|>: {system_prompt}\n"
        "<|user|>: Hello, how are you doing?\n"
        "<|assistant|>: "
    )
    assert style.messages_to_prompt(user_only) == expected

    # Conversation carrying its own system message: it overrides the default.
    with_system = [
        ChatMessage(
            content="FOO BAR Custom sys prompt from messages.", role=MessageRole.SYSTEM
        ),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]
    expected = (
        "<|system|>: FOO BAR Custom sys prompt from messages.\n"
        "<|user|>: Hello, how are you doing?\n"
        "<|assistant|>: "
    )
    assert style.messages_to_prompt(with_system) == expected
|
||||
|
||||
|
||||
def test_llama2_prompt_style_format():
    """Llama2PromptStyle wraps the system message in <<SYS>> markers inside an [INST] block."""
    style = Llama2PromptStyle()
    conversation = [
        ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]

    expected = (
        "<s> [INST] <<SYS>>\n"
        " You are an AI assistant. \n"
        "<</SYS>>\n"
        "\n"
        " Hello, how are you doing? [/INST]"
    )

    assert style.messages_to_prompt(conversation) == expected
|
||||
|
||||
|
||||
def test_llama2_prompt_style_with_system_prompt():
    """For Llama2PromptStyle, a configured default system prompt is used when the
    conversation has none, and an explicit system message overrides it."""
    system_prompt = "This is a system prompt from configuration."
    style = Llama2PromptStyle(default_system_prompt=system_prompt)

    # Conversation with no system message: the configured default fills <<SYS>>.
    user_only = [
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]
    expected = (
        "<s> [INST] <<SYS>>\n"
        f" {system_prompt} \n"
        "<</SYS>>\n"
        "\n"
        " Hello, how are you doing? [/INST]"
    )
    assert style.messages_to_prompt(user_only) == expected

    # Conversation carrying its own system message: it overrides the default.
    with_system = [
        ChatMessage(
            content="FOO BAR Custom sys prompt from messages.", role=MessageRole.SYSTEM
        ),
        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
    ]
    expected = (
        "<s> [INST] <<SYS>>\n"
        " FOO BAR Custom sys prompt from messages. \n"
        "<</SYS>>\n"
        "\n"
        " Hello, how are you doing? [/INST]"
    )
    assert style.messages_to_prompt(with_system) == expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue