mirror of https://github.com/zylon-ai/private-gpt.git
more fixes to pass tests; changed type VectorStore -> BasePydanticVectorStore, see https://github.com/run-llama/llama_index/blob/main/CHANGELOG.md#2024-05-14
commit 3f6396cca8
parent 94712824d6
4 changed files with 33 additions and 28 deletions
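Only the prompt-style file is shown below; the VectorStore -> BasePydanticVectorStore swap named in the commit message tracks the llama-index 2024-05-14 changelog entry it links. A minimal sketch of that kind of annotation change, assuming llama-index >= 0.10; the component name is illustrative, not the repo's actual file:

# Hypothetical sketch of the type swap described in the commit message;
# VectorStoreComponent is an illustrative name, not necessarily the repo's class.
from llama_index.core.vector_stores.types import BasePydanticVectorStore


class VectorStoreComponent:
    # was: vector_store: VectorStore (type replaced per the linked changelog)
    vector_store: BasePydanticVectorStore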
@@ -139,20 +139,20 @@ class Llama2PromptStyle(AbstractPromptStyle):
 
 class Llama3PromptStyle(AbstractPromptStyle):
     r"""Template for Meta's Llama 3.
 
     Template:
-    {% set loop_messages = messages %}
-    {% for message in loop_messages %}
-    {% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}
-    {% if loop.index0 == 0 %}
-    {% set content = bos_token + content %}
-    {% endif %}
-    {{ content }}
-    {% endfor %}
-    {% if add_generation_prompt %}
-    {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
-    {% endif %}
+    {% for message in loop_messages %}
+    {% set content = '<|start_header_id|>' + message['role']
+    + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}
+    {% if loop.index0 == 0 %}
+    {% set content = bos_token + content %}
+    {% endif %}
+    {{ content }}
+    {% endfor %}
+    {% if add_generation_prompt %}
+    {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
+    {% endif %}
     """
 
     BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>"
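For reference, a plain-Python sketch of the string the docstring's Jinja template renders; render_llama3 is an illustrative helper, not the class's actual implementation:

# Illustrative re-implementation of the docstring's Jinja template; the real
# class builds prompts through the AbstractPromptStyle machinery, not this helper.
def render_llama3(messages: list[dict], add_generation_prompt: bool = True) -> str:
    bos_token = "<|begin_of_text|>"
    out = ""
    for i, message in enumerate(messages):
        # each turn: <|start_header_id|>role<|end_header_id|>\n\ncontent<|eot_id|>
        content = (
            f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
            f"{message['content'].strip()}<|eot_id|>"
        )
        if i == 0:
            content = bos_token + content  # BOS only before the first turn
        out += content
    if add_generation_prompt:
        # open an assistant header so the model continues from here
        out += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return out

A single user turn thus renders as <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n.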
@@ -183,7 +183,7 @@ class Llama3PromptStyle(AbstractPromptStyle):
         if i == 0:
             str_message = f"{system_message_str} {self.BOS} {self.B_INST} "
         else:
-            # end previous user-assistant interaction
+            # end previous user-assistant interaction
             string_messages[-1] += f" {self.EOS}"
             # no need to include system prompt
             str_message = f"{self.BOS} {self.B_INST} "
@ -193,7 +193,9 @@ class Llama3PromptStyle(AbstractPromptStyle):
|
|||
if len(messages) > (i + 1):
|
||||
assistant_message = messages[i + 1]
|
||||
assert assistant_message.role == MessageRole.ASSISTANT
|
||||
str_message += f" {assistant_message.content} {self.E_SYS} {self.B_INST}"
|
||||
str_message += (
|
||||
f" {assistant_message.content} {self.E_SYS} {self.B_INST}"
|
||||
)
|
||||
|
||||
string_messages.append(str_message)
|
||||
|
||||
|
|
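The two hunks above are pure line-length reflows of the same message loop; the f-strings are unchanged, only wrapped. A rough sketch of that loop's shape, assuming llama2-style pairing of user/assistant turns; the B_INST and E_SYS values are placeholders, since only BOS and EOS appear in the shown code:

# Rough sketch of the loop the two hunks belong to; constants and control flow
# outside the diffed lines are assumptions for illustration, not the repo's source.
from llama_index.core.llms import ChatMessage, MessageRole

BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>"
B_INST, E_SYS = "[INST]", "[/INST]"  # placeholder values


def sketch_messages_to_prompt(messages: list[ChatMessage], system_message_str: str) -> str:
    string_messages: list[str] = []
    for i in range(0, len(messages), 2):
        user_message = messages[i]
        assert user_message.role == MessageRole.USER
        if i == 0:
            str_message = f"{system_message_str} {BOS} {B_INST} "
        else:
            # end previous user-assistant interaction
            string_messages[-1] += f" {EOS}"
            # no need to include system prompt
            str_message = f"{BOS} {B_INST} "
        str_message += f"{user_message.content}"
        if len(messages) > (i + 1):
            assistant_message = messages[i + 1]
            assert assistant_message.role == MessageRole.ASSISTANT
            # the wrapped f-string from the second hunk; behavior unchanged
            str_message += (
                f" {assistant_message.content} {E_SYS} {B_INST}"
            )
        string_messages.append(str_message)
    return "".join(string_messages)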
@@ -289,7 +291,8 @@ class ChatMLPromptStyle(AbstractPromptStyle):
 
 
 def get_prompt_style(
-    prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"] | None
+    prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"]
+    | None
 ) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.
 
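Finally, a hedged usage sketch of the factory whose signature the last hunk reflows; the messages_to_prompt call is assumed from the AbstractPromptStyle interface:

# Usage sketch; "llama3" is one of the Literal values, and None is assumed
# to fall back to the default style.
from llama_index.core.llms import ChatMessage, MessageRole

style = get_prompt_style("llama3")
prompt = style.messages_to_prompt(
    [ChatMessage(role=MessageRole.USER, content="Hello")]
)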