Mirror of https://github.com/zylon-ai/private-gpt.git, synced 2025-12-22 23:22:57 +01:00
more fixes to pass tests; changed type VectorStore -> BasePydanticVectorStore, see https://github.com/run-llama/llama_index/blob/main/CHANGELOG.md#2024-05-14
This commit is contained in:
parent 94712824d6
commit 3f6396cca8

4 changed files with 33 additions and 28 deletions
@@ -139,20 +139,20 @@ class Llama2PromptStyle(AbstractPromptStyle):
 
 
 class Llama3PromptStyle(AbstractPromptStyle):
     r"""Template for Meta's llama3.
 
     Template:
     {% set loop_messages = messages %}
-    {% for message in loop_messages %}
-    {% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}
-    {% if loop.index0 == 0 %}
-    {% set content = bos_token + content %}
-    {% endif %}
-    {{ content }}
-    {% endfor %}
-    {% if add_generation_prompt %}
-    {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
-    {% endif %}
+    {% for message in loop_messages %}
+    {% set content = '<|start_header_id|>' + message['role']
+    + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}
+    {% if loop.index0 == 0 %}
+    {% set content = bos_token + content %}
+    {% endif %}
+    {{ content }}
+    {% endfor %}
+    {% if add_generation_prompt %}
+    {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
+    {% endif %}
     """
 
     BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>"
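As a worked illustration (hand-derived from the template above, not taken from the diff): a single user message "Hello", with add_generation_prompt enabled, renders to

    <|begin_of_text|><|start_header_id|>user<|end_header_id|>

    Hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>

leaving the model to continue with the assistant's reply.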
@@ -183,7 +183,7 @@ class Llama3PromptStyle(AbstractPromptStyle):
             if i == 0:
                 str_message = f"{system_message_str} {self.BOS} {self.B_INST} "
             else:
-                # end previous user-assistant interaction
+                # end previous user-assistant interaction
                 string_messages[-1] += f" {self.EOS}"
                 # no need to include system prompt
                 str_message = f"{self.BOS} {self.B_INST} "
@@ -193,7 +193,9 @@ class Llama3PromptStyle(AbstractPromptStyle):
             if len(messages) > (i + 1):
                 assistant_message = messages[i + 1]
                 assert assistant_message.role == MessageRole.ASSISTANT
-                str_message += f" {assistant_message.content} {self.E_SYS} {self.B_INST}"
+                str_message += (
+                    f" {assistant_message.content} {self.E_SYS} {self.B_INST}"
+                )
 
             string_messages.append(str_message)
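This hunk and the signature reflow in the next one are behavior-preserving: Python evaluates a parenthesized f-string on its own line exactly like the original one-liner. A minimal sketch with hypothetical stand-in values:

    content, E_SYS, B_INST = "hi", "<</SYS>>", "[INST]"  # hypothetical stand-ins
    a = f" {content} {E_SYS} {B_INST}"
    b = (
        f" {content} {E_SYS} {B_INST}"
    )
    assert a == b  # the wrap exists only to satisfy the line-length linter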
@@ -289,7 +291,8 @@ class ChatMLPromptStyle(AbstractPromptStyle):
 
 
 def get_prompt_style(
-    prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"] | None
+    prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"]
+    | None
 ) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.
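For context, a hedged usage sketch of the factory; the module path is an assumption (it is not visible in this diff), while the function name and accepted strings come from the signature above:

    # Module path assumed; get_prompt_style and the Literal values are from the diff.
    from private_gpt.components.llm.prompt_helper import get_prompt_style

    style = get_prompt_style("llama3")  # "default", "llama2", "llama3", "tag", "mistral" or "chatml"
    fallback = get_prompt_style(None)   # None presumably selects the default style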
@@ -7,7 +7,7 @@ from llama_index.core.vector_stores.types import (
     FilterCondition,
     MetadataFilter,
     MetadataFilters,
-    VectorStore,
+    BasePydanticVectorStore,
 )
 
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
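The commit message points at llama_index's 2024-05-14 changelog for this rename; concrete store implementations subclass the pydantic base, which is why a one-line import swap is enough here. A quick sanity check, assuming the llama-index-vector-stores-qdrant package is installed:

    from llama_index.core.vector_stores.types import BasePydanticVectorStore
    from llama_index.vector_stores.qdrant import QdrantVectorStore

    # Annotations written against the base accept any concrete backend.
    assert issubclass(QdrantVectorStore, BasePydanticVectorStore)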
@@ -32,7 +32,7 @@ def _doc_id_metadata_filter(
 @singleton
 class VectorStoreComponent:
     settings: Settings
-    vector_store: VectorStore
+    vector_store: BasePydanticVectorStore
 
     @inject
     def __init__(self, settings: Settings) -> None:
@@ -54,7 +54,7 @@ class VectorStoreComponent:
             )
 
             self.vector_store = typing.cast(
-                VectorStore,
+                BasePydanticVectorStore,
                 PGVectorStore.from_params(
                     **settings.postgres.model_dump(exclude_none=True),
                     table_name="embeddings",
@@ -87,7 +87,7 @@ class VectorStoreComponent:
             ) # TODO
 
             self.vector_store = typing.cast(
-                VectorStore,
+                BasePydanticVectorStore,
                 BatchedChromaVectorStore(
                     chroma_client=chroma_client, chroma_collection=chroma_collection
                 ),
@@ -115,7 +115,7 @@ class VectorStoreComponent:
                 **settings.qdrant.model_dump(exclude_none=True)
             )
             self.vector_store = typing.cast(
-                VectorStore,
+                BasePydanticVectorStore,
                 QdrantVectorStore(
                     client=client,
                     collection_name="make_this_parameterizable_per_api_call",
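All three backends (Postgres, Chroma, Qdrant) funnel through the same cast-and-assign pattern; a condensed sketch of the Qdrant branch, with an in-memory client standing in for the settings-driven one:

    import typing

    from llama_index.core.vector_stores.types import BasePydanticVectorStore
    from llama_index.vector_stores.qdrant import QdrantVectorStore
    from qdrant_client import QdrantClient

    client = QdrantClient(location=":memory:")  # stand-in for the settings-built client
    vector_store = typing.cast(
        BasePydanticVectorStore,
        QdrantVectorStore(
            client=client,
            collection_name="make_this_parameterizable_per_api_call",
        ),
    )

typing.cast is a no-op at runtime, and since the concrete stores already subclass BasePydanticVectorStore it is arguably redundant after this change; the diff keeps it, presumably to leave the structure of each branch untouched.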