mirror of https://github.com/zylon-ai/private-gpt.git
fix: new llama3 prompt
This commit is contained in:
parent a19e991fa1
commit 1bb7624be4

2 changed files with 96 additions and 46 deletions
private_gpt/components/llm/prompt_helper.py
@@ -139,25 +139,27 @@ class Llama2PromptStyle(AbstractPromptStyle):
 
 
 class Llama3PromptStyle(AbstractPromptStyle):
-    r"""Template for metas lama3.
+    r"""Template for Meta's Llama 3.1.
 
-    {% set loop_messages = messages %}
-    {% for message in loop_messages %}
-    {% set content = '<|start_header_id|>' + message['role']
-    + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}
-    {% if loop.index0 == 0 %}
-    {% set content = bos_token + content %}
-    {% endif %}
-    {{ content }}
-    {% endfor %}
-    {% if add_generation_prompt %}
-    {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
-    {% endif %}
+    The format follows this structure:
+    <|begin_of_text|>
+    <|start_header_id|>system<|end_header_id|>
+
+    [System message content]<|eot_id|>
+    <|start_header_id|>user<|end_header_id|>
+
+    [User message content]<|eot_id|>
+    <|start_header_id|>assistant<|end_header_id|>
+
+    [Assistant message content]<|eot_id|>
+    ...
+    (Repeat for each message, including possible 'ipython' role)
     """
 
     BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>"
-    B_INST, E_INST = "<|start_header_id|>user<|end_header_id|>", "<|eot_id|>"
-    B_SYS, E_SYS = "<|start_header_id|>system<|end_header_id|> ", "<|eot_id|>"
+    B_INST, E_INST = "<|start_header_id|>", "<|end_header_id|>"
+    EOT = "<|eot_id|>"
+    B_SYS, E_SYS = "<|start_header_id|>system<|end_header_id|>", "<|eot_id|>"
     ASSISTANT_INST = "<|start_header_id|>assistant<|end_header_id|>"
     DEFAULT_SYSTEM_PROMPT = """\
 You are a helpful, respectful and honest assistant. \
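Worth noting from this hunk: the old B_INST baked the user role into a single hard-coded header, while the new B_INST/E_INST pair composes a header for any role. A minimal sketch of that composition (plain Python, constant values copied from the diff above; the sketch itself is not part of the commit):

# Sketch: how the new constants compose into role headers and turns.
B_INST, E_INST, EOT = "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"

def role_header(role: str) -> str:
    # Works the same for "system", "user", "assistant", and "ipython".
    return f"{B_INST}{role}{E_INST}"

assert role_header("user") == "<|start_header_id|>user<|end_header_id|>"
# A complete turn: role header, blank line, content, end-of-turn token.
turn = f"{role_header('user')}\n\nHello!{EOT}"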
@@ -167,46 +169,39 @@ class Llama3PromptStyle(AbstractPromptStyle):
     """
 
     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
-        string_messages: list[str] = []
-        if messages[0].role == MessageRole.SYSTEM:
-            system_message_str = messages[0].content or ""
-            messages = messages[1:]
-        else:
-            system_message_str = self.DEFAULT_SYSTEM_PROMPT
-
-        system_message_str = f"{self.B_SYS} {system_message_str.strip()} {self.E_SYS}"
-
-        for i in range(0, len(messages), 2):
-            user_message = messages[i]
-            assert user_message.role == MessageRole.USER
-
-            if i == 0:
-                str_message = f"{system_message_str} {self.BOS} {self.B_INST} "
-            else:
-                # end previous user-assistant interaction
-                string_messages[-1] += f" {self.EOS}"
-                # no need to include system prompt
-                str_message = f"{self.BOS} {self.B_INST} "
-
-            str_message += f"{user_message.content} {self.E_INST} {self.ASSISTANT_INST}"
-
-            if len(messages) > (i + 1):
-                assistant_message = messages[i + 1]
-                assert assistant_message.role == MessageRole.ASSISTANT
-                str_message += (
-                    f" {assistant_message.content} {self.E_SYS} {self.B_INST}"
-                )
-
-            string_messages.append(str_message)
-
-        return "".join(string_messages)
+        prompt = self.BOS
+        has_system_message = False
+
+        for i, message in enumerate(messages):
+            if not message or message.content is None:
+                continue
+            if message.role == MessageRole.SYSTEM:
+                prompt += f"{self.B_SYS}\n\n{message.content.strip()}{self.E_SYS}"
+                has_system_message = True
+            else:
+                role_header = f"{self.B_INST}{message.role.value}{self.E_INST}"
+                prompt += f"{role_header}\n\n{message.content.strip()}{self.EOT}"
+
+            # Add assistant header if the last message is not from the assistant
+            if i == len(messages) - 1 and message.role != MessageRole.ASSISTANT:
+                prompt += f"{self.ASSISTANT_INST}\n\n"
+
+        # Add default system prompt if no system message was provided
+        if not has_system_message:
+            prompt = (
+                f"{self.BOS}{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
+                + prompt[len(self.BOS) :]
+            )
+
+        # TODO: Implement tool handling logic
+
+        return prompt
 
     def _completion_to_prompt(self, completion: str) -> str:
-        system_prompt_str = self.DEFAULT_SYSTEM_PROMPT
-
         return (
-            f"{self.B_SYS} {system_prompt_str.strip()} {self.E_SYS} "
-            f"{completion.strip()} {self.E_SYS} "
+            f"{self.BOS}{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
+            f"{self.B_INST}user{self.E_INST}\n\n{completion.strip()}{self.EOT}"
+            f"{self.ASSISTANT_INST}\n\n"
         )
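The rewrite serializes every message the same way and appends a trailing assistant header only when the conversation does not already end with an assistant turn; when no system message is supplied, DEFAULT_SYSTEM_PROMPT is spliced in directly after <|begin_of_text|>. A runnable standalone sketch of the core loop, using plain (role, content) tuples in place of private-gpt's ChatMessage objects (an assumption for illustration only):

# Standalone sketch of the serialization performed by the new _messages_to_prompt.
BOS, EOT = "<|begin_of_text|>", "<|eot_id|>"
B_INST, E_INST = "<|start_header_id|>", "<|end_header_id|>"
ASSISTANT_INST = f"{B_INST}assistant{E_INST}"

def render(messages: list[tuple[str, str]]) -> str:
    prompt = BOS
    for role, content in messages:
        # Each turn: role header, blank line, trimmed content, end-of-turn token.
        prompt += f"{B_INST}{role}{E_INST}\n\n{content.strip()}{EOT}"
    # Cue the model to answer unless the last turn is already the assistant's.
    if messages and messages[-1][0] != "assistant":
        prompt += f"{ASSISTANT_INST}\n\n"
    return prompt

print(render([("system", "Be terse."), ("user", "Hi!")]))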
tests/test_prompt_helper.py
@@ -5,6 +5,7 @@ from private_gpt.components.llm.prompt_helper import (
     ChatMLPromptStyle,
     DefaultPromptStyle,
     Llama2PromptStyle,
+    Llama3PromptStyle,
     MistralPromptStyle,
     TagPromptStyle,
     get_prompt_style,
@@ -139,3 +140,57 @@ def test_llama2_prompt_style_with_system_prompt():
     )
 
     assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
+def test_llama3_prompt_style_format():
+    prompt_style = Llama3PromptStyle()
+    messages = [
+        ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM),
+        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+    ]
+
+    expected_prompt = (
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+        "You are a helpful assistant<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\n"
+        "Hello, how are you doing?<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
+def test_llama3_prompt_style_with_default_system():
+    prompt_style = Llama3PromptStyle()
+    messages = [
+        ChatMessage(content="Hello!", role=MessageRole.USER),
+    ]
+    expected = (
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+        f"{prompt_style.DEFAULT_SYSTEM_PROMPT}<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+    assert prompt_style._messages_to_prompt(messages) == expected
+
+
+def test_llama3_prompt_style_with_assistant_response():
+    prompt_style = Llama3PromptStyle()
+    messages = [
+        ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM),
+        ChatMessage(content="What is the capital of France?", role=MessageRole.USER),
+        ChatMessage(
+            content="The capital of France is Paris.", role=MessageRole.ASSISTANT
+        ),
+    ]
+
+    expected_prompt = (
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+        "You are a helpful assistant<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\n"
+        "What is the capital of France?<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+        "The capital of France is Paris.<|eot_id|>"
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
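The tests above exercise only messages_to_prompt; _completion_to_prompt is not covered here. From the diff, its output for a single completion should take the following shape. "<default system prompt>" is a placeholder for DEFAULT_SYSTEM_PROMPT, whose full text is truncated in the hunk above:

# Expected shape of _completion_to_prompt("How far away is the Moon?") under the
# new implementation; "<default system prompt>" stands in for DEFAULT_SYSTEM_PROMPT.
expected = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    "<default system prompt><|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n"
    "How far away is the Moon?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)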