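"""Chat service for PrivateGPT.

Wires the LLM, embedding, vector-store and node-store components into
LlamaIndex chat engines, and exposes blocking (``chat``) and streaming
(``stream_chat``) completions, optionally grounded on ingested documents.
"""
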
from dataclasses import dataclass

from injector import inject, singleton
from llama_index.core.chat_engine import ContextChatEngine, SimpleChatEngine
from llama_index.core.chat_engine.types import (
    BaseChatEngine,
)
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.storage import StorageContext
from llama_index.core.types import TokenGen
from pydantic import BaseModel

from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
    VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chunks.chunks_service import Chunk


class Completion(BaseModel):
    """Full response of a blocking chat call, plus the chunks it was grounded on."""

    response: str
    sources: list[Chunk] | None = None


class CompletionGen(BaseModel):
    """Streaming variant: ``response`` yields token deltas instead of full text."""

    response: TokenGen
    sources: list[Chunk] | None = None


@dataclass
class ChatEngineInput:
    system_message: ChatMessage | None = None
    last_message: ChatMessage | None = None
    chat_history: list[ChatMessage] | None = None

    @classmethod
    def from_messages(cls, messages: list[ChatMessage]) -> "ChatEngineInput":
        # Detect a leading system message and a trailing user message, if any.
        system_message = (
            messages[0]
            if len(messages) > 0 and messages[0].role == MessageRole.SYSTEM
            else None
        )
        last_message = (
            messages[-1]
            if len(messages) > 0 and messages[-1].role == MessageRole.USER
            else None
        )
        # Remove the system message and the last message from the list, if
        # they exist. Whatever remains is the chat history.
        if system_message:
            messages.pop(0)
        if last_message:
            messages.pop(-1)
        chat_history = messages if len(messages) > 0 else None

        return cls(
            system_message=system_message,
            last_message=last_message,
            chat_history=chat_history,
        )
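

# Illustrative: given a typical OpenAI-style message list, from_messages
# splits it as follows:
#
#   msgs = [
#       ChatMessage(role=MessageRole.SYSTEM, content="Be terse."),
#       ChatMessage(role=MessageRole.USER, content="Hello"),
#       ChatMessage(role=MessageRole.ASSISTANT, content="Hi!"),
#       ChatMessage(role=MessageRole.USER, content="What is PrivateGPT?"),
#   ]
#   parts = ChatEngineInput.from_messages(msgs)
#   # parts.system_message -> the SYSTEM message
#   # parts.last_message   -> the final USER message
#   # parts.chat_history   -> the two middle messages
#
# Beware: from_messages mutates the list it is given (via pop()).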


@singleton
class ChatService:
    @inject
    def __init__(
        self,
        llm_component: LLMComponent,
        vector_store_component: VectorStoreComponent,
        embedding_component: EmbeddingComponent,
        node_store_component: NodeStoreComponent,
    ) -> None:
        self.llm_component = llm_component
        self.embedding_component = embedding_component
        self.vector_store_component = vector_store_component
        # Bundle the vector store, document store and index store into a single
        # storage context, then expose the already-ingested data as an index.
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store_component.vector_store,
            docstore=node_store_component.doc_store,
            index_store=node_store_component.index_store,
        )
        self.index = VectorStoreIndex.from_vector_store(
            vector_store_component.vector_store,
            storage_context=self.storage_context,
            llm=llm_component.llm,
            embed_model=embedding_component.embedding_model,
            show_progress=True,
        )

    def _chat_engine(
        self,
        system_prompt: str | None = None,
        use_context: bool = False,
        context_filter: ContextFilter | None = None,
    ) -> BaseChatEngine:
        if use_context:
            # RAG path: retrieve matching nodes and let the engine ground its
            # answers on them.
            vector_index_retriever = self.vector_store_component.get_retriever(
                index=self.index, context_filter=context_filter
            )
            return ContextChatEngine.from_defaults(
                system_prompt=system_prompt,
                retriever=vector_index_retriever,
                llm=self.llm_component.llm,  # Has no effect at the moment
                node_postprocessors=[
                    # Replace each retrieved node's text with its "window"
                    # metadata (the sentence plus surrounding context) before
                    # the node reaches the LLM.
                    MetadataReplacementPostProcessor(target_metadata_key="window"),
                ],
            )
        else:
            # Plain chat with no retrieval.
            return SimpleChatEngine.from_defaults(
                system_prompt=system_prompt,
                llm=self.llm_component.llm,
            )
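
    # Illustrative sketch: a ContextFilter narrows retrieval to specific
    # ingested documents. Assuming it exposes the ids of the docs to search
    # (see private_gpt.open_ai.extensions.context_filter for the actual
    # fields), usage would look roughly like:
    #
    #   context_filter = ContextFilter(docs_ids=[some_doc_id])
    #   engine = self._chat_engine(use_context=True, context_filter=context_filter)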

    def stream_chat(
        self,
        messages: list[ChatMessage],
        use_context: bool = False,
        context_filter: ContextFilter | None = None,
    ) -> CompletionGen:
        chat_engine_input = ChatEngineInput.from_messages(messages)
        last_message = (
            chat_engine_input.last_message.content
            if chat_engine_input.last_message
            else None
        )
        # Note: any system message extracted from `messages` is discarded; this
        # service pins its own context-only system prompt.
        system_prompt = (
            "You can only answer questions about the provided context. If you "
            "know the answer but it is not based on the provided context, don't "
            "provide the answer; just state that the answer is not in the "
            "provided context."
        )
        chat_history = chat_engine_input.chat_history

        chat_engine = self._chat_engine(
            system_prompt=system_prompt,
            use_context=use_context,
            context_filter=context_filter,
        )
        streaming_response = chat_engine.stream_chat(
            message=last_message if last_message is not None else "",
            chat_history=chat_history,
        )
        sources = [Chunk.from_node(node) for node in streaming_response.source_nodes]
        completion_gen = CompletionGen(
            response=streaming_response.response_gen, sources=sources
        )
        return completion_gen
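
    # Illustrative consumption of a streamed completion -- TokenGen is a
    # generator of string deltas:
    #
    #   completion_gen = chat_service.stream_chat(messages, use_context=True)
    #   for delta in completion_gen.response:
    #       print(delta, end="", flush=True)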

    def chat(
        self,
        messages: list[ChatMessage],
        use_context: bool = False,
        context_filter: ContextFilter | None = None,
    ) -> Completion:
        # Same flow as stream_chat, but blocking: the engine returns the full
        # response instead of a token generator.
        chat_engine_input = ChatEngineInput.from_messages(messages)
        last_message = (
            chat_engine_input.last_message.content
            if chat_engine_input.last_message
            else None
        )
        system_prompt = (
            "You can only answer questions about the provided context. If you "
            "know the answer but it is not based on the provided context, don't "
            "provide the answer; just state that the answer is not in the "
            "provided context."
        )
        chat_history = chat_engine_input.chat_history

        chat_engine = self._chat_engine(
            system_prompt=system_prompt,
            use_context=use_context,
            context_filter=context_filter,
        )
        wrapped_response = chat_engine.chat(
            message=last_message if last_message is not None else "",
            chat_history=chat_history,
        )
        sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
        completion = Completion(response=wrapped_response.response, sources=sources)
        return completion
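

# Minimal usage sketch, assuming the application's dependency injector is
# available (`global_injector` is the name used elsewhere in private-gpt;
# adapt if your setup differs):
#
#   from private_gpt.di import global_injector
#
#   chat_service = global_injector.get(ChatService)
#   completion = chat_service.chat(
#       messages=[ChatMessage(role=MessageRole.USER, content="What is in my docs?")],
#       use_context=True,
#   )
#   print(completion.response)
#   if completion.sources:
#       print(f"{len(completion.sources)} source chunks")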