refactor: move modes to enum and update mode explanations

Javier Martinez 2024-07-31 12:44:40 +02:00
parent 0f10107783
commit 2077ff66bb

@@ -3,6 +3,7 @@ import base64
 import logging
 import time
 from collections.abc import Iterable
+from enum import Enum
 from pathlib import Path
 from typing import Any
@@ -34,11 +35,19 @@ UI_TAB_TITLE = "My Private GPT"
 SOURCES_SEPARATOR = "<hr>Sources: \n"

-MODES = [
-    "Query Files",
-    "Search Files",
-    "LLM Chat (no context from files)",
-    "Summarization",
+
+class Modes(str, Enum):
+    RAG_MODE = "RAG"
+    SEARCH_MODE = "Search"
+    BASIC_CHAT_MODE = "Basic"
+    SUMMARIZE_MODE = "Summarize"
+
+
+MODES: list[Modes] = [
+    Modes.RAG_MODE,
+    Modes.SEARCH_MODE,
+    Modes.BASIC_CHAT_MODE,
+    Modes.SUMMARIZE_MODE,
 ]
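
Note: because Modes mixes in str, each member is a real string that compares
equal to its raw value, so callers and UI code that still pass plain strings
keep working. A minimal standalone sketch of that behavior (not code from
this commit):

    from enum import Enum

    class Modes(str, Enum):
        RAG_MODE = "RAG"

    # str mix-in: members are strings and compare equal to their values
    assert Modes.RAG_MODE == "RAG"
    assert isinstance(Modes.RAG_MODE, str)
    assert Modes("RAG") is Modes.RAG_MODE  # value lookup recovers the member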
@@ -93,7 +102,9 @@ class PrivateGptUi:
         self.mode = MODES[0]
         self._system_prompt = self._get_default_system_prompt(self.mode)

-    def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
+    def _chat(
+        self, message: str, history: list[list[str]], mode: Modes, *_: Any
+    ) -> Any:
         def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
             full_response: str = ""
             stream = completion_gen.response
@@ -158,8 +169,7 @@ class PrivateGptUi:
             ),
         )
-
         match mode:
-            case "Query Files":
+            case Modes.RAG_MODE:
                 # Use only the selected file for the query
                 context_filter = None
                 if self._selected_filename is not None:
@@ -178,14 +188,14 @@
                     context_filter=context_filter,
                 )
                 yield from yield_deltas(query_stream)
-            case "LLM Chat (no context from files)":
+            case Modes.BASIC_CHAT_MODE:
                 llm_stream = self._chat_service.stream_chat(
                     messages=all_messages,
                     use_context=False,
                 )
                 yield from yield_deltas(llm_stream)
-            case "Search Files":
+            case Modes.SEARCH_MODE:
                 response = self._chunks_service.retrieve_relevant(
                     text=message, limit=4, prev_next_chunks=0
                 )
@@ -198,7 +208,7 @@ class PrivateGptUi:
                     f"{source.text}"
                     for index, source in enumerate(sources, start=1)
                 )
-            case "Summarization":
+            case Modes.SUMMARIZE_MODE:
                 # Summarize the given message, optionally using selected files
                 context_filter = None
                 if self._selected_filename:
@ -221,17 +231,17 @@ class PrivateGptUi:
# On initialization and on mode change, this function set the system prompt # On initialization and on mode change, this function set the system prompt
# to the default prompt based on the mode (and user settings). # to the default prompt based on the mode (and user settings).
@staticmethod @staticmethod
def _get_default_system_prompt(mode: str) -> str: def _get_default_system_prompt(mode: Modes) -> str:
p = "" p = ""
match mode: match mode:
# For query chat mode, obtain default system prompt from settings # For query chat mode, obtain default system prompt from settings
case "Query Files": case Modes.RAG_MODE:
p = settings().ui.default_query_system_prompt p = settings().ui.default_query_system_prompt
# For chat mode, obtain default system prompt from settings # For chat mode, obtain default system prompt from settings
case "LLM Chat (no context from files)": case Modes.BASIC_CHAT_MODE:
p = settings().ui.default_chat_system_prompt p = settings().ui.default_chat_system_prompt
# For summarization mode, obtain default system prompt from settings # For summarization mode, obtain default system prompt from settings
case "Summarization": case Modes.SUMMARIZE_MODE:
p = settings().ui.default_summarization_system_prompt p = settings().ui.default_summarization_system_prompt
# For any other mode, clear the system prompt # For any other mode, clear the system prompt
case _: case _:
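
Note: the match arms keep working after the swap because a value pattern such
as case Modes.RAG_MODE succeeds when the subject compares equal to that
member, and the str mix-in means even a raw "RAG" from an older caller selects
the same arm. A standalone sketch of the dispatch (not code from this commit):

    from enum import Enum

    class Modes(str, Enum):
        RAG_MODE = "RAG"
        SEARCH_MODE = "Search"

    def describe(mode: Modes) -> str:
        # value patterns compare with ==, so the member and its raw
        # string value both select the same arm
        match mode:
            case Modes.RAG_MODE:
                return "rag"
            case _:
                return "other"

    assert describe(Modes.RAG_MODE) == "rag"
    assert describe("RAG") == "rag"  # str subclass keeps string callers working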
@@ -239,26 +249,17 @@ class PrivateGptUi:
         return p

     @staticmethod
-    def _get_default_mode_explanation(mode: str) -> str:
+    def _get_default_mode_explanation(mode: Modes) -> str:
         match mode:
-            case "Query Files":
-                return (
-                    "Query specific files you've ingested. "
-                    "Ideal for retrieving targeted information from particular documents."
-                )
-            case "Search Files":
-                return (
-                    "Search for relevant information across all ingested files. "
-                    "Useful for broad information retrieval."
-                )
-            case "LLM Chat (no context from files)":
-                return (
-                    "Generate responses without using context from ingested files. "
-                    "Suitable for general inquiries."
-                )
-            case "Summarization":
-                return (
-                    "Generate summaries from provided ingested files. "
-                    "This may take significant time depending on the length and complexity of the input."
-                )
+            case Modes.RAG_MODE:
+                return "Get contextualized answers from selected files."
+            case Modes.SEARCH_MODE:
+                return "Find relevant chunks of text in selected files."
+            case Modes.BASIC_CHAT_MODE:
+                return "Chat with the LLM using its training data. Files are ignored."
+            case Modes.SUMMARIZE_MODE:
+                return (
+                    "Generate a summary of the selected files. Prompt to customize the result. "
+                    "This may take significant time depending on the length and complexity of the input."
+                )
             case _:
@ -271,7 +272,7 @@ class PrivateGptUi:
def _set_explanatation_mode(self, explanation_mode: str) -> None: def _set_explanatation_mode(self, explanation_mode: str) -> None:
self._explanation_mode = explanation_mode self._explanation_mode = explanation_mode
def _set_current_mode(self, mode: str) -> Any: def _set_current_mode(self, mode: Modes) -> Any:
self.mode = mode self.mode = mode
self._set_system_prompt(self._get_default_system_prompt(mode)) self._set_system_prompt(self._get_default_system_prompt(mode))
self._set_explanatation_mode(self._get_default_mode_explanation(mode)) self._set_explanatation_mode(self._get_default_mode_explanation(mode))
@@ -394,7 +395,7 @@ class PrivateGptUi:
                 with gr.Column(scale=3):
                     default_mode = MODES[0]
                     mode = gr.Radio(
-                        MODES,
+                        [mode.value for mode in MODES],
                         label="Mode",
                         value=default_mode,
                     )
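
Note: gr.Radio is now built from the members' .value strings, and the widget
hands the selection back as a plain string; the str-based enum lets handlers
map that string back to a member by value. A standalone sketch of the
round-trip, assuming the widget returns the selected choice unchanged:

    from enum import Enum

    class Modes(str, Enum):
        RAG_MODE = "RAG"
        SEARCH_MODE = "Search"
        BASIC_CHAT_MODE = "Basic"
        SUMMARIZE_MODE = "Summarize"

    choices = [mode.value for mode in Modes]  # ["RAG", "Search", "Basic", "Summarize"]
    selected = choices[1]  # the plain string a radio widget would return
    assert Modes(selected) is Modes.SEARCH_MODE  # map back to the enum member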