WIP more prompt format, and more maintainable

This commit is contained in:
Louis 2023-12-03 00:48:43 +01:00
parent 3d301d0c6f
commit 76faffb269
11 changed files with 476 additions and 217 deletions

View file

@ -1,10 +1,10 @@
import pytest
from llama_index.llms import ChatMessage, MessageRole
from private_gpt.components.llm.prompt_helper import (
from private_gpt.components.llm.prompt.prompt_helper import (
DefaultPromptStyle,
Llama2PromptStyle,
TagPromptStyle,
LlamaIndexPromptStyle,
VigognePromptStyle,
get_prompt_style,
)
@ -13,8 +13,8 @@ from private_gpt.components.llm.prompt_helper import (
("prompt_style", "expected_prompt_style"),
[
("default", DefaultPromptStyle),
("llama2", Llama2PromptStyle),
("tag", TagPromptStyle),
("llama2", LlamaIndexPromptStyle),
("tag", VigognePromptStyle),
],
)
def test_get_prompt_style_success(prompt_style, expected_prompt_style):
@ -29,7 +29,7 @@ def test_get_prompt_style_failure():
def test_tag_prompt_style_format():
prompt_style = TagPromptStyle()
prompt_style = VigognePromptStyle()
messages = [
ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
@ -46,7 +46,7 @@ def test_tag_prompt_style_format():
def test_tag_prompt_style_format_with_system_prompt():
system_prompt = "This is a system prompt from configuration."
prompt_style = TagPromptStyle(default_system_prompt=system_prompt)
prompt_style = VigognePromptStyle(default_system_prompt=system_prompt)
messages = [
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
]
@ -76,7 +76,7 @@ def test_tag_prompt_style_format_with_system_prompt():
def test_llama2_prompt_style_format():
prompt_style = Llama2PromptStyle()
prompt_style = LlamaIndexPromptStyle()
messages = [
ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
@ -95,7 +95,7 @@ def test_llama2_prompt_style_format():
def test_llama2_prompt_style_with_system_prompt():
system_prompt = "This is a system prompt from configuration."
prompt_style = Llama2PromptStyle(default_system_prompt=system_prompt)
prompt_style = LlamaIndexPromptStyle(default_system_prompt=system_prompt)
messages = [
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
]