feat(ui): Allows User to Set System Prompt via "Additional Options" in Chat Interface (#1353)

This commit is contained in:
3ly-13 2023-12-10 12:45:14 -06:00 committed by GitHub
parent a072a40a7c
commit 145f3ec9f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 110 additions and 18 deletions

View file

@ -30,6 +30,8 @@ UI_TAB_TITLE = "My Private GPT"
SOURCES_SEPARATOR = "\n\n Sources: \n"
MODES = ["Query Docs", "Search in Docs", "LLM Chat"]
class Source(BaseModel):
file: str
@ -71,6 +73,10 @@ class PrivateGptUi:
# Cache the UI blocks
self._ui_block = None
# Initialize system prompt based on default mode
self.mode = MODES[0]
self._system_prompt = self._get_default_system_prompt(self.mode)
def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
full_response: str = ""
@ -114,25 +120,22 @@ class PrivateGptUi:
new_message = ChatMessage(content=message, role=MessageRole.USER)
all_messages = [*build_history(), new_message]
# If a system prompt is set, add it as a system message
if self._system_prompt:
all_messages.insert(
0,
ChatMessage(
content=self._system_prompt,
role=MessageRole.SYSTEM,
),
)
match mode:
case "Query Docs":
# Add a system message to force the behaviour of the LLM
# to answer only questions about the provided context.
all_messages.insert(
0,
ChatMessage(
content="You can only answer questions about the provided context. If you know the answer "
"but it is not based in the provided context, don't provide the answer, just state "
"the answer is not in the context provided.",
role=MessageRole.SYSTEM,
),
)
query_stream = self._chat_service.stream_chat(
messages=all_messages,
use_context=True,
)
yield from yield_deltas(query_stream)
case "LLM Chat":
llm_stream = self._chat_service.stream_chat(
messages=all_messages,
@ -154,6 +157,37 @@ class PrivateGptUi:
for index, source in enumerate(sources, start=1)
)
# On initialization and on mode change, this function set the system prompt
# to the default prompt based on the mode (and user settings).
@staticmethod
def _get_default_system_prompt(mode: str) -> str:
p = ""
match mode:
# For query chat mode, obtain default system prompt from settings
case "Query Docs":
p = settings().ui.default_query_system_prompt
# For chat mode, obtain default system prompt from settings
case "LLM Chat":
p = settings().ui.default_chat_system_prompt
# For any other mode, clear the system prompt
case _:
p = ""
return p
def _set_system_prompt(self, system_prompt_input: str) -> None:
logger.info(f"Setting system prompt to: {system_prompt_input}")
self._system_prompt = system_prompt_input
def _set_current_mode(self, mode: str) -> Any:
self.mode = mode
self._set_system_prompt(self._get_default_system_prompt(mode))
# Update placeholder and allow interaction if default system prompt is set
if self._system_prompt:
return gr.update(placeholder=self._system_prompt, interactive=True)
# Update placeholder and disable interaction if no default system prompt is set
else:
return gr.update(placeholder=self._system_prompt, interactive=False)
def _list_ingested_files(self) -> list[list[str]]:
files = set()
for ingested_document in self._ingest_service.list_ingested():
@ -193,7 +227,7 @@ class PrivateGptUi:
with gr.Row():
with gr.Column(scale=3, variant="compact"):
mode = gr.Radio(
["Query Docs", "Search in Docs", "LLM Chat"],
MODES,
label="Mode",
value="Query Docs",
)
@ -220,6 +254,23 @@ class PrivateGptUi:
outputs=ingested_dataset,
)
ingested_dataset.render()
system_prompt_input = gr.Textbox(
placeholder=self._system_prompt,
label="System Prompt",
lines=2,
interactive=True,
render=False,
)
# When mode changes, set default system prompt
mode.change(
self._set_current_mode, inputs=mode, outputs=system_prompt_input
)
# On blur, set system prompt to use in queries
system_prompt_input.blur(
self._set_system_prompt,
inputs=system_prompt_input,
)
with gr.Column(scale=7):
_ = gr.ChatInterface(
self._chat,
@ -232,7 +283,7 @@ class PrivateGptUi:
AVATAR_BOT,
),
),
additional_inputs=[mode, upload_button],
additional_inputs=[mode, upload_button, system_prompt_input],
)
return blocks