From 1bb7624be428dd1bb3232571a6fe40c1ab2bd0e6 Mon Sep 17 00:00:00 2001
From: Javier Martinez
Date: Mon, 29 Jul 2024 15:57:56 +0200
Subject: [PATCH] fix: new llama3 prompt

---
 private_gpt/components/llm/prompt_helper.py | 87 ++++++++++-----------
 tests/test_prompt_helper.py                 | 55 +++++++++++++
 2 files changed, 96 insertions(+), 46 deletions(-)

diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py
index b0cda46..b550020 100644
--- a/private_gpt/components/llm/prompt_helper.py
+++ b/private_gpt/components/llm/prompt_helper.py
@@ -139,25 +139,27 @@ class Llama2PromptStyle(AbstractPromptStyle):
 
 
 class Llama3PromptStyle(AbstractPromptStyle):
-    r"""Template for metas lama3.
+    r"""Template for Meta's Llama 3.1.
 
-    {% set loop_messages = messages %}
-    {% for message in loop_messages %}
-    {% set content = '<|start_header_id|>' + message['role']
-    + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}
-    {% if loop.index0 == 0 %}
-    {% set content = bos_token + content %}
-    {% endif %}
-    {{ content }}
-    {% endfor %}
-    {% if add_generation_prompt %}
-    {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
-    {% endif %}
+    The format follows this structure:
+    <|begin_of_text|>
+    <|start_header_id|>system<|end_header_id|>
+
+    [System message content]<|eot_id|>
+    <|start_header_id|>user<|end_header_id|>
+
+    [User message content]<|eot_id|>
+    <|start_header_id|>assistant<|end_header_id|>
+
+    [Assistant message content]<|eot_id|>
+    ...
+    (Repeat for each message, including possible 'ipython' role)
     """
 
     BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>"
-    B_INST, E_INST = "<|start_header_id|>user<|end_header_id|>", "<|eot_id|>"
-    B_SYS, E_SYS = "<|start_header_id|>system<|end_header_id|> ", "<|eot_id|>"
+    B_INST, E_INST = "<|start_header_id|>", "<|end_header_id|>"
+    EOT = "<|eot_id|>"
+    B_SYS, E_SYS = "<|start_header_id|>system<|end_header_id|>", "<|eot_id|>"
     ASSISTANT_INST = "<|start_header_id|>assistant<|end_header_id|>"
     DEFAULT_SYSTEM_PROMPT = """\
     You are a helpful, respectful and honest assistant. \
@@ -167,46 +169,39 @@ class Llama3PromptStyle(AbstractPromptStyle):
     """
 
     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
-        string_messages: list[str] = []
-        if messages[0].role == MessageRole.SYSTEM:
-            system_message_str = messages[0].content or ""
-            messages = messages[1:]
-        else:
-            system_message_str = self.DEFAULT_SYSTEM_PROMPT
+        prompt = self.BOS
+        has_system_message = False
 
-        system_message_str = f"{self.B_SYS} {system_message_str.strip()} {self.E_SYS}"
-
-        for i in range(0, len(messages), 2):
-            user_message = messages[i]
-            assert user_message.role == MessageRole.USER
-
-            if i == 0:
-                str_message = f"{system_message_str} {self.BOS} {self.B_INST} "
+        for i, message in enumerate(messages):
+            if not message or message.content is None:
+                continue
+            if message.role == MessageRole.SYSTEM:
+                prompt += f"{self.B_SYS}\n\n{message.content.strip()}{self.E_SYS}"
+                has_system_message = True
             else:
-                # end previous user-assistant interaction
-                string_messages[-1] += f" {self.EOS}"
-                # no need to include system prompt
-                str_message = f"{self.BOS} {self.B_INST} "
+                role_header = f"{self.B_INST}{message.role.value}{self.E_INST}"
+                prompt += f"{role_header}\n\n{message.content.strip()}{self.EOT}"
 
-            str_message += f"{user_message.content} {self.E_INST} {self.ASSISTANT_INST}"
+            # Add assistant header if the last message is not from the assistant
+            if i == len(messages) - 1 and message.role != MessageRole.ASSISTANT:
+                prompt += f"{self.ASSISTANT_INST}\n\n"
 
-            if len(messages) > (i + 1):
-                assistant_message = messages[i + 1]
-                assert assistant_message.role == MessageRole.ASSISTANT
-                str_message += (
-                    f" {assistant_message.content} {self.E_SYS} {self.B_INST}"
-                )
+        # Add default system prompt if no system message was provided
+        if not has_system_message:
+            prompt = (
+                f"{self.BOS}{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
+                + prompt[len(self.BOS) :]
+            )
 
-            string_messages.append(str_message)
+        # TODO: Implement tool handling logic
 
-        return "".join(string_messages)
+        return prompt
 
     def _completion_to_prompt(self, completion: str) -> str:
-        system_prompt_str = self.DEFAULT_SYSTEM_PROMPT
-
         return (
-            f"{self.B_SYS} {system_prompt_str.strip()} {self.E_SYS} "
-            f"{completion.strip()} {self.E_SYS} "
+            f"{self.BOS}{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
+            f"{self.B_INST}user{self.E_INST}\n\n{completion.strip()}{self.EOT}"
+            f"{self.ASSISTANT_INST}\n\n"
         )
 
 
diff --git a/tests/test_prompt_helper.py b/tests/test_prompt_helper.py
index ef76437..ad9349c 100644
--- a/tests/test_prompt_helper.py
+++ b/tests/test_prompt_helper.py
@@ -5,6 +5,7 @@ from private_gpt.components.llm.prompt_helper import (
     ChatMLPromptStyle,
     DefaultPromptStyle,
     Llama2PromptStyle,
+    Llama3PromptStyle,
     MistralPromptStyle,
     TagPromptStyle,
     get_prompt_style,
@@ -139,3 +140,57 @@ def test_llama2_prompt_style_with_system_prompt():
     )
 
     assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
+def test_llama3_prompt_style_format():
+    prompt_style = Llama3PromptStyle()
+    messages = [
+        ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM),
+        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
+    ]
+
+    expected_prompt = (
+        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+        "You are a helpful assistant<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\n"
+        "Hello, how are you doing?<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+
+    assert prompt_style.messages_to_prompt(messages) == expected_prompt
+
+
+def test_llama3_prompt_style_with_default_system(): + prompt_style = Llama3PromptStyle() + messages = [ + ChatMessage(content="Hello!", role=MessageRole.USER), + ] + expected = ( + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + f"{prompt_style.DEFAULT_SYSTEM_PROMPT}<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + assert prompt_style._messages_to_prompt(messages) == expected + + +def test_llama3_prompt_style_with_assistant_response(): + prompt_style = Llama3PromptStyle() + messages = [ + ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM), + ChatMessage(content="What is the capital of France?", role=MessageRole.USER), + ChatMessage( + content="The capital of France is Paris.", role=MessageRole.ASSISTANT + ), + ] + + expected_prompt = ( + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n" + "What is the capital of France?<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + "The capital of France is Paris.<|eot_id|>" + ) + + assert prompt_style.messages_to_prompt(messages) == expected_prompt
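Note, as a reviewer sketch that is not part of the patch itself: the snippet below shows the prompt the reworked Llama3PromptStyle produces for a short exchange, mirroring the new tests. It assumes the llama_index.core.llms import path already used by prompt_helper.py; the message contents are made up for illustration.

    from llama_index.core.llms import ChatMessage, MessageRole

    from private_gpt.components.llm.prompt_helper import Llama3PromptStyle

    prompt_style = Llama3PromptStyle()
    messages = [
        ChatMessage(content="You are a terse assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Name the largest planet.", role=MessageRole.USER),
    ]

    # messages_to_prompt is the public wrapper around _messages_to_prompt.
    print(prompt_style.messages_to_prompt(messages))

which prints:

    <|begin_of_text|><|start_header_id|>system<|end_header_id|>

    You are a terse assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>

    Name the largest planet.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The trailing assistant header is left open so generation can continue, matching the add_generation_prompt branch of the Jinja template quoted in the old docstring.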