more fixes to pass tests; changed type VectorStore -> BasePydanticVectorStore, see https://github.com/run-llama/llama_index/blob/main/CHANGELOG.md#2024-05-14

This commit is contained in:
Robert Hirsch 2024-06-15 12:08:31 +02:00
parent 94712824d6
commit 3f6396cca8
No known key found for this signature in database
GPG key ID: A9D9D1205DBED12C
4 changed files with 33 additions and 28 deletions

View file

@ -139,20 +139,20 @@ class Llama2PromptStyle(AbstractPromptStyle):
class Llama3PromptStyle(AbstractPromptStyle):
r"""Template for metas lama3.
"""
Template:
{% set loop_messages = messages %}
{% for message in loop_messages %}
{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}
{% if loop.index0 == 0 %}
{% set content = bos_token + content %}
{% endif %}
{{ content }}
{% endfor %}
{% if add_generation_prompt %}
{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{% endif %}
{% for message in loop_messages %}
{% set content = '<|start_header_id|>' + message['role']
+ '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}
{% if loop.index0 == 0 %}
{% set content = bos_token + content %}
{% endif %}
{{ content }}
{% endfor %}
{% if add_generation_prompt %}
{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{% endif %}
"""
BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>"
@ -183,7 +183,7 @@ class Llama3PromptStyle(AbstractPromptStyle):
if i == 0:
str_message = f"{system_message_str} {self.BOS} {self.B_INST} "
else:
# end previous user-assistant interaction
# end previous user-assistant interaction
string_messages[-1] += f" {self.EOS}"
# no need to include system prompt
str_message = f"{self.BOS} {self.B_INST} "
@ -193,7 +193,9 @@ class Llama3PromptStyle(AbstractPromptStyle):
if len(messages) > (i + 1):
assistant_message = messages[i + 1]
assert assistant_message.role == MessageRole.ASSISTANT
str_message += f" {assistant_message.content} {self.E_SYS} {self.B_INST}"
str_message += (
f" {assistant_message.content} {self.E_SYS} {self.B_INST}"
)
string_messages.append(str_message)
@ -289,7 +291,8 @@ class ChatMLPromptStyle(AbstractPromptStyle):
def get_prompt_style(
prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"] | None
prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"]
| None
) -> AbstractPromptStyle:
"""Get the prompt style to use from the given string.

View file

@ -7,7 +7,7 @@ from llama_index.core.vector_stores.types import (
FilterCondition,
MetadataFilter,
MetadataFilters,
VectorStore,
BasePydanticVectorStore,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
@ -32,7 +32,7 @@ def _doc_id_metadata_filter(
@singleton
class VectorStoreComponent:
settings: Settings
vector_store: VectorStore
vector_store: BasePydanticVectorStore
@inject
def __init__(self, settings: Settings) -> None:
@ -54,7 +54,7 @@ class VectorStoreComponent:
)
self.vector_store = typing.cast(
VectorStore,
BasePydanticVectorStore,
PGVectorStore.from_params(
**settings.postgres.model_dump(exclude_none=True),
table_name="embeddings",
@ -87,7 +87,7 @@ class VectorStoreComponent:
) # TODO
self.vector_store = typing.cast(
VectorStore,
BasePydanticVectorStore,
BatchedChromaVectorStore(
chroma_client=chroma_client, chroma_collection=chroma_collection
),
@ -115,7 +115,7 @@ class VectorStoreComponent:
**settings.qdrant.model_dump(exclude_none=True)
)
self.vector_store = typing.cast(
VectorStore,
BasePydanticVectorStore,
QdrantVectorStore(
client=client,
collection_name="make_this_parameterizable_per_api_call",