diff --git a/private_gpt/ui/ui.py b/private_gpt/ui/ui.py index 13bc480..364feb5 100644 --- a/private_gpt/ui/ui.py +++ b/private_gpt/ui/ui.py @@ -3,6 +3,7 @@ import base64 import logging import time from collections.abc import Iterable +from enum import Enum from pathlib import Path from typing import Any @@ -34,11 +35,19 @@ UI_TAB_TITLE = "My Private GPT" SOURCES_SEPARATOR = "
Sources: \n" -MODES = [ - "Query Files", - "Search Files", - "LLM Chat (no context from files)", - "Summarization", + +class Modes(str, Enum): + RAG_MODE = "RAG" + SEARCH_MODE = "Search" + BASIC_CHAT_MODE = "Basic" + SUMMARIZE_MODE = "Summarize" + + +MODES: list[Modes] = [ + Modes.RAG_MODE, + Modes.SEARCH_MODE, + Modes.BASIC_CHAT_MODE, + Modes.SUMMARIZE_MODE, ] @@ -93,7 +102,9 @@ class PrivateGptUi: self.mode = MODES[0] self._system_prompt = self._get_default_system_prompt(self.mode) - def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any: + def _chat( + self, message: str, history: list[list[str]], mode: Modes, *_: Any + ) -> Any: def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]: full_response: str = "" stream = completion_gen.response @@ -158,8 +169,7 @@ class PrivateGptUi: ), ) match mode: - case "Query Files": - + case Modes.RAG_MODE: # Use only the selected file for the query context_filter = None if self._selected_filename is not None: @@ -178,14 +188,14 @@ class PrivateGptUi: context_filter=context_filter, ) yield from yield_deltas(query_stream) - case "LLM Chat (no context from files)": + case Modes.BASIC_CHAT_MODE: llm_stream = self._chat_service.stream_chat( messages=all_messages, use_context=False, ) yield from yield_deltas(llm_stream) - case "Search Files": + case Modes.SEARCH_MODE: response = self._chunks_service.retrieve_relevant( text=message, limit=4, prev_next_chunks=0 ) @@ -198,7 +208,7 @@ class PrivateGptUi: f"{source.text}" for index, source in enumerate(sources, start=1) ) - case "Summarization": + case Modes.SUMMARIZE_MODE: # Summarize the given message, optionally using selected files context_filter = None if self._selected_filename: @@ -221,17 +231,17 @@ class PrivateGptUi: # On initialization and on mode change, this function set the system prompt # to the default prompt based on the mode (and user settings). @staticmethod - def _get_default_system_prompt(mode: str) -> str: + def _get_default_system_prompt(mode: Modes) -> str: p = "" match mode: # For query chat mode, obtain default system prompt from settings - case "Query Files": + case Modes.RAG_MODE: p = settings().ui.default_query_system_prompt # For chat mode, obtain default system prompt from settings - case "LLM Chat (no context from files)": + case Modes.BASIC_CHAT_MODE: p = settings().ui.default_chat_system_prompt # For summarization mode, obtain default system prompt from settings - case "Summarization": + case Modes.SUMMARIZE_MODE: p = settings().ui.default_summarization_system_prompt # For any other mode, clear the system prompt case _: @@ -239,26 +249,17 @@ class PrivateGptUi: return p @staticmethod - def _get_default_mode_explanation(mode: str) -> str: + def _get_default_mode_explanation(mode: Modes) -> str: match mode: - case "Query Files": + case Modes.RAG_MODE: + return "Get contextualized answers from selected files." + case Modes.SEARCH_MODE: + return "Find relevant chunks of text in selected files." + case Modes.BASIC_CHAT_MODE: + return "Chat with the LLM using its training data. Files are ignored." + case Modes.SUMMARIZE_MODE: return ( - "Query specific files you've ingested. " - "Ideal for retrieving targeted information from particular documents." - ) - case "Search Files": - return ( - "Search for relevant information across all ingested files. " - "Useful for broad information retrieval." - ) - case "LLM Chat (no context from files)": - return ( - "Generate responses without using context from ingested files. " - "Suitable for general inquiries." - ) - case "Summarization": - return ( - "Generate summaries from provided ingested files. " + "Generate a summary of the selected files. Prompt to customize the result. " "This may take significant time depending on the length and complexity of the input." ) case _: @@ -271,7 +272,7 @@ class PrivateGptUi: def _set_explanatation_mode(self, explanation_mode: str) -> None: self._explanation_mode = explanation_mode - def _set_current_mode(self, mode: str) -> Any: + def _set_current_mode(self, mode: Modes) -> Any: self.mode = mode self._set_system_prompt(self._get_default_system_prompt(mode)) self._set_explanatation_mode(self._get_default_mode_explanation(mode)) @@ -394,7 +395,7 @@ class PrivateGptUi: with gr.Column(scale=3): default_mode = MODES[0] mode = gr.Radio( - MODES, + [mode.value for mode in MODES], label="Mode", value=default_mode, )