refactor: move modes to enum and update mode explanations

This commit is contained in:
Javier Martinez 2024-07-31 12:44:40 +02:00
parent 0f10107783
commit 2077ff66bb
No known key found for this signature in database

View file

@ -3,6 +3,7 @@ import base64
import logging
import time
from collections.abc import Iterable
from enum import Enum
from pathlib import Path
from typing import Any
@ -34,11 +35,19 @@ UI_TAB_TITLE = "My Private GPT"
SOURCES_SEPARATOR = "<hr>Sources: \n"
MODES = [
"Query Files",
"Search Files",
"LLM Chat (no context from files)",
"Summarization",
class Modes(str, Enum):
RAG_MODE = "RAG"
SEARCH_MODE = "Search"
BASIC_CHAT_MODE = "Basic"
SUMMARIZE_MODE = "Summarize"
MODES: list[Modes] = [
Modes.RAG_MODE,
Modes.SEARCH_MODE,
Modes.BASIC_CHAT_MODE,
Modes.SUMMARIZE_MODE,
]
@ -93,7 +102,9 @@ class PrivateGptUi:
self.mode = MODES[0]
self._system_prompt = self._get_default_system_prompt(self.mode)
def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
def _chat(
self, message: str, history: list[list[str]], mode: Modes, *_: Any
) -> Any:
def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
full_response: str = ""
stream = completion_gen.response
@ -158,8 +169,7 @@ class PrivateGptUi:
),
)
match mode:
case "Query Files":
case Modes.RAG_MODE:
# Use only the selected file for the query
context_filter = None
if self._selected_filename is not None:
@ -178,14 +188,14 @@ class PrivateGptUi:
context_filter=context_filter,
)
yield from yield_deltas(query_stream)
case "LLM Chat (no context from files)":
case Modes.BASIC_CHAT_MODE:
llm_stream = self._chat_service.stream_chat(
messages=all_messages,
use_context=False,
)
yield from yield_deltas(llm_stream)
case "Search Files":
case Modes.SEARCH_MODE:
response = self._chunks_service.retrieve_relevant(
text=message, limit=4, prev_next_chunks=0
)
@ -198,7 +208,7 @@ class PrivateGptUi:
f"{source.text}"
for index, source in enumerate(sources, start=1)
)
case "Summarization":
case Modes.SUMMARIZE_MODE:
# Summarize the given message, optionally using selected files
context_filter = None
if self._selected_filename:
@ -221,17 +231,17 @@ class PrivateGptUi:
# On initialization and on mode change, this function set the system prompt
# to the default prompt based on the mode (and user settings).
@staticmethod
def _get_default_system_prompt(mode: str) -> str:
def _get_default_system_prompt(mode: Modes) -> str:
p = ""
match mode:
# For query chat mode, obtain default system prompt from settings
case "Query Files":
case Modes.RAG_MODE:
p = settings().ui.default_query_system_prompt
# For chat mode, obtain default system prompt from settings
case "LLM Chat (no context from files)":
case Modes.BASIC_CHAT_MODE:
p = settings().ui.default_chat_system_prompt
# For summarization mode, obtain default system prompt from settings
case "Summarization":
case Modes.SUMMARIZE_MODE:
p = settings().ui.default_summarization_system_prompt
# For any other mode, clear the system prompt
case _:
@ -239,26 +249,17 @@ class PrivateGptUi:
return p
@staticmethod
def _get_default_mode_explanation(mode: str) -> str:
def _get_default_mode_explanation(mode: Modes) -> str:
match mode:
case "Query Files":
case Modes.RAG_MODE:
return "Get contextualized answers from selected files."
case Modes.SEARCH_MODE:
return "Find relevant chunks of text in selected files."
case Modes.BASIC_CHAT_MODE:
return "Chat with the LLM using its training data. Files are ignored."
case Modes.SUMMARIZE_MODE:
return (
"Query specific files you've ingested. "
"Ideal for retrieving targeted information from particular documents."
)
case "Search Files":
return (
"Search for relevant information across all ingested files. "
"Useful for broad information retrieval."
)
case "LLM Chat (no context from files)":
return (
"Generate responses without using context from ingested files. "
"Suitable for general inquiries."
)
case "Summarization":
return (
"Generate summaries from provided ingested files. "
"Generate a summary of the selected files. Prompt to customize the result. "
"This may take significant time depending on the length and complexity of the input."
)
case _:
@ -271,7 +272,7 @@ class PrivateGptUi:
def _set_explanatation_mode(self, explanation_mode: str) -> None:
self._explanation_mode = explanation_mode
def _set_current_mode(self, mode: str) -> Any:
def _set_current_mode(self, mode: Modes) -> Any:
self.mode = mode
self._set_system_prompt(self._get_default_system_prompt(mode))
self._set_explanatation_mode(self._get_default_mode_explanation(mode))
@ -394,7 +395,7 @@ class PrivateGptUi:
with gr.Column(scale=3):
default_mode = MODES[0]
mode = gr.Radio(
MODES,
[mode.value for mode in MODES],
label="Mode",
value=default_mode,
)