from dataclasses import dataclass

from injector import inject, singleton
from llama_index.core.chat_engine import (
    CondensePlusContextChatEngine,
    ContextChatEngine,
    SimpleChatEngine,
)
from llama_index.core.chat_engine.types import (
    BaseChatEngine,
)
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.postprocessor import (
    SentenceTransformerRerank,
    SimilarityPostprocessor,
)
from llama_index.core.storage import StorageContext
from llama_index.core.types import TokenGen
from pydantic import BaseModel

from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
    VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chunks.chunks_service import Chunk
from private_gpt.settings.settings import Settings

# System prompt that `ChatService.chat` always sends to the chat engine.
# It replaces any system message supplied by the caller (see `ChatService.chat`).
QUICKGPT_SYSTEM_PROMPT = """
You are a helpful assistant named QuickGPT by Quickfox Consulting.
Your responses must be strictly and exclusively based on the context documents provided.
You are not allowed to use any information, knowledge, or external sources outside of the given context documents.
If the answer to a query is not present in the context documents, you should respond with "I do not have enough information in the provided context to answer this question."
Your responses should be relevant, informative, and easy to understand.
Aim to deliver high-quality answers that are respectful and helpful, using clear and concise language.
Focus on providing accurate and reliable answers based solely on the given context.
Do not make assumptions, inferences, or draw upon any prior knowledge beyond what is explicitly stated in the context documents.
"""


class Completion(BaseModel):
    """A complete (non-streamed) chat answer plus the chunks it was based on."""

    response: str
    sources: list[Chunk] | None = None


class CompletionGen(BaseModel):
    """A streamed chat answer (token generator) plus the chunks it was based on."""

    response: TokenGen
    sources: list[Chunk] | None = None


@dataclass
class ChatEngineInput:
    """An incoming message list split into the parts a chat engine consumes."""

    # First message, only when its role is SYSTEM.
    system_message: ChatMessage | None = None
    # Final message, only when its role is USER; this is the query to answer.
    last_message: ChatMessage | None = None
    # Whatever remains in between; None when nothing remains.
    chat_history: list[ChatMessage] | None = None

    @classmethod
    def from_messages(cls, messages: list[ChatMessage]) -> "ChatEngineInput":
        """Split an ordered message list into system message, query and history.

        The first message is taken as the system message only if its role is
        SYSTEM. The last message is taken as the query only if its role is
        USER. All remaining messages form the chat history. The caller's list
        is left unmodified.

        Args:
            messages: Conversation messages, oldest first.

        Returns:
            A ChatEngineInput. ``chat_history`` is ``None`` rather than an
            empty list when no messages remain.
        """
        # Work on a shallow copy: popping from `messages` directly would
        # silently mutate the caller's list.
        remaining = list(messages)

        system_message = (
            remaining[0]
            if remaining and remaining[0].role == MessageRole.SYSTEM
            else None
        )
        last_message = (
            remaining[-1]
            if remaining and remaining[-1].role == MessageRole.USER
            else None
        )

        # Remove the system message and the last message, if present.
        # The rest is the chat history. A single message cannot be both
        # SYSTEM and USER, so the two pops never hit the same element.
        if system_message is not None:
            remaining.pop(0)
        if last_message is not None:
            remaining.pop(-1)

        return cls(
            system_message=system_message,
            last_message=last_message,
            chat_history=remaining or None,
        )


@singleton
class ChatService:
    """Answers chat requests with an LLM, optionally grounded in the vector index."""

    settings: Settings

    @inject
    def __init__(
        self,
        settings: Settings,
        llm_component: LLMComponent,
        vector_store_component: VectorStoreComponent,
        embedding_component: EmbeddingComponent,
        node_store_component: NodeStoreComponent,
    ) -> None:
        self.settings = settings
        self.llm_component = llm_component
        self.embedding_component = embedding_component
        self.vector_store_component = vector_store_component
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store_component.vector_store,
            docstore=node_store_component.doc_store,
            index_store=node_store_component.index_store,
        )
        self.index = VectorStoreIndex.from_vector_store(
            vector_store_component.vector_store,
            storage_context=self.storage_context,
            llm=llm_component.llm,
            embed_model=embedding_component.embedding_model,
            show_progress=True,
        )

    def _chat_engine(
        self,
        system_prompt: str | None = None,
        use_context: bool = False,
        context_filter: ContextFilter | None = None,
    ) -> BaseChatEngine:
        """Build the chat engine for one request.

        Args:
            system_prompt: System prompt passed to the engine (may be None).
            use_context: If True, build a retrieval-augmented
                CondensePlusContextChatEngine over ``self.index``. Otherwise
                build a plain SimpleChatEngine.
            context_filter: Forwarded to the retriever when ``use_context``
                is True; ignored otherwise.
        """
        settings = self.settings
        if not use_context:
            return SimpleChatEngine.from_defaults(
                system_prompt=system_prompt,
                llm=self.llm_component.llm,
            )

        vector_index_retriever = self.vector_store_component.get_retriever(
            index=self.index,
            context_filter=context_filter,
            similarity_top_k=settings.rag.similarity_top_k,
        )
        node_postprocessors = [
            MetadataReplacementPostProcessor(target_metadata_key="window"),
            SimilarityPostprocessor(similarity_cutoff=settings.rag.similarity_value),
        ]
        # Reranking is optional and applied after the similarity cut-off.
        if settings.rag.rerank.enabled:
            node_postprocessors.append(
                SentenceTransformerRerank(
                    model=settings.rag.rerank.model,
                    top_n=settings.rag.rerank.top_n,
                )
            )
        return CondensePlusContextChatEngine.from_defaults(
            system_prompt=system_prompt,
            retriever=vector_index_retriever,
            llm=self.llm_component.llm,
            # NOTE(review): the original author noted this "takes no effect at
            # the moment"; confirm against the installed llama_index version.
            node_postprocessors=node_postprocessors,
        )

    def _prepare_chat(
        self,
        messages: list[ChatMessage],
        use_context: bool,
        context_filter: ContextFilter | None,
        system_prompt_override: str | None = None,
    ) -> tuple[BaseChatEngine, str, list[ChatMessage] | None]:
        """Parse messages and build the engine; shared by stream_chat and chat.

        Args:
            messages: Conversation messages, oldest first (not modified).
            use_context: Forwarded to ``_chat_engine``.
            context_filter: Forwarded to ``_chat_engine``.
            system_prompt_override: When not None, used as the system prompt
                instead of the caller's system message.

        Returns:
            ``(chat_engine, message, chat_history)``. ``message`` is the
            content of the trailing USER message, or ``""`` when there is no
            such message or its content is None.
        """
        chat_engine_input = ChatEngineInput.from_messages(messages)

        last = chat_engine_input.last_message
        last_content = last.content if last is not None else None
        message = last_content if last_content is not None else ""

        if system_prompt_override is not None:
            system_prompt = system_prompt_override
        elif chat_engine_input.system_message is not None:
            system_prompt = chat_engine_input.system_message.content
        else:
            system_prompt = None

        chat_engine = self._chat_engine(
            system_prompt=system_prompt,
            use_context=use_context,
            context_filter=context_filter,
        )
        return chat_engine, message, chat_engine_input.chat_history or None

    def stream_chat(
        self,
        messages: list[ChatMessage],
        use_context: bool = False,
        context_filter: ContextFilter | None = None,
    ) -> CompletionGen:
        """Answer the last USER message as a token stream.

        The caller's SYSTEM message, if any, is used as the system prompt.
        The returned ``sources`` are taken from the engine response's
        ``source_nodes``.
        """
        chat_engine, message, chat_history = self._prepare_chat(
            messages, use_context, context_filter
        )
        streaming_response = chat_engine.stream_chat(
            message=message,
            chat_history=chat_history,
        )
        sources = [Chunk.from_node(node) for node in streaming_response.source_nodes]
        return CompletionGen(response=streaming_response.response_gen, sources=sources)

    def chat(
        self,
        messages: list[ChatMessage],
        use_context: bool = False,
        context_filter: ContextFilter | None = None,
    ) -> Completion:
        """Answer the last USER message and return the full response.

        Always uses ``QUICKGPT_SYSTEM_PROMPT`` as the system prompt. Any
        SYSTEM message in ``messages`` is dropped rather than forwarded.
        The returned ``sources`` are taken from the engine response's
        ``source_nodes``.
        """
        # NOTE(review): unlike stream_chat, this path ignores the caller's
        # system message. Preserved as-is; confirm this asymmetry is intended.
        chat_engine, message, chat_history = self._prepare_chat(
            messages,
            use_context,
            context_filter,
            system_prompt_override=QUICKGPT_SYSTEM_PROMPT,
        )
        wrapped_response = chat_engine.chat(
            message=message,
            chat_history=chat_history,
        )
        sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
        return Completion(response=wrapped_response.response, sources=sources)