more fixes to pass tests; changed type VectorStore -> BasePydanticVectorStore, see https://github.com/run-llama/llama_index/blob/main/CHANGELOG.md#2024-05-14

This commit is contained in:
Robert Hirsch 2024-06-15 12:08:31 +02:00
parent 94712824d6
commit 3f6396cca8
No known key found for this signature in database
GPG key ID: A9D9D1205DBED12C
4 changed files with 33 additions and 28 deletions

View file

@ -139,20 +139,20 @@ class Llama2PromptStyle(AbstractPromptStyle):
class Llama3PromptStyle(AbstractPromptStyle):
r"""Template for metas lama3.
"""
Template:
{% set loop_messages = messages %}
{% for message in loop_messages %}
{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}
{% if loop.index0 == 0 %}
{% set content = bos_token + content %}
{% endif %}
{{ content }}
{% endfor %}
{% if add_generation_prompt %}
{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{% endif %}
{% for message in loop_messages %}
{% set content = '<|start_header_id|>' + message['role']
+ '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}
{% if loop.index0 == 0 %}
{% set content = bos_token + content %}
{% endif %}
{{ content }}
{% endfor %}
{% if add_generation_prompt %}
{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{% endif %}
"""
BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>"
@ -183,7 +183,7 @@ class Llama3PromptStyle(AbstractPromptStyle):
if i == 0:
str_message = f"{system_message_str} {self.BOS} {self.B_INST} "
else:
# end previous user-assistant interaction
# end previous user-assistant interaction
string_messages[-1] += f" {self.EOS}"
# no need to include system prompt
str_message = f"{self.BOS} {self.B_INST} "
@ -193,7 +193,9 @@ class Llama3PromptStyle(AbstractPromptStyle):
if len(messages) > (i + 1):
assistant_message = messages[i + 1]
assert assistant_message.role == MessageRole.ASSISTANT
str_message += f" {assistant_message.content} {self.E_SYS} {self.B_INST}"
str_message += (
f" {assistant_message.content} {self.E_SYS} {self.B_INST}"
)
string_messages.append(str_message)
@ -289,7 +291,8 @@ class ChatMLPromptStyle(AbstractPromptStyle):
def get_prompt_style(
prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"] | None
prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"]
| None
) -> AbstractPromptStyle:
"""Get the prompt style to use from the given string.