mirror of
https://github.com/zylon-ai/private-gpt.git
synced 2025-12-22 10:45:42 +01:00
* Update ui.py Changed 'curated_sources' from a list, in order to maintain score order when returning the curated sources. * Maintain score order after curating sources
453 lines
18 KiB
Python
453 lines
18 KiB
Python
"""This file should be imported only and only if you want to run the UI locally."""
|
|
import itertools
|
|
import logging
|
|
import time
|
|
from collections.abc import Iterable
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import gradio as gr # type: ignore
|
|
from fastapi import FastAPI
|
|
from gradio.themes.utils.colors import slate # type: ignore
|
|
from injector import inject, singleton
|
|
from llama_index.core.llms import ChatMessage, ChatResponse, MessageRole
|
|
from pydantic import BaseModel
|
|
|
|
from private_gpt.constants import PROJECT_ROOT_PATH
|
|
from private_gpt.di import global_injector
|
|
from private_gpt.open_ai.extensions.context_filter import ContextFilter
|
|
from private_gpt.server.chat.chat_service import ChatService, CompletionGen
|
|
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
|
|
from private_gpt.server.ingest.ingest_service import IngestService
|
|
from private_gpt.settings.settings import settings
|
|
from private_gpt.ui.images import logo_svg
|
|
|
|
logger = logging.getLogger(__name__)

# Directory containing this module, relative to the project root.
THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)
# Avatar icon for bot messages in the chatbot; expected to be
# "private_gpt/ui/avatar-bot.ico".
AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "avatar-bot.ico"

# Title passed to gr.Blocks(title=...).
UI_TAB_TITLE = "My Private GPT"

# Text inserted between an answer and its appended source list. The chat
# history builder splits on it to drop the source list before re-sending
# earlier answers to the LLM.
SOURCES_SEPARATOR = "\n\n Sources: \n"

# Selectable UI modes; MODES[0] is used as the initial mode.
MODES = ["Query Files", "Search Files", "LLM Chat (no context from files)"]
|
|
|
|
|
|
class Source(BaseModel):
    """One citation displayed in the UI: file name, page label and chunk text.

    The model is frozen, so instances can be used as dict keys; this is what
    ``curate_sources`` relies on to drop duplicates by value.
    """

    file: str
    page: str
    text: str

    class Config:
        frozen = True

    @staticmethod
    def curate_sources(sources: list[Chunk]) -> list["Source"]:
        """Convert retrieved chunks into a list of unique ``Source`` objects.

        A chunk whose document has no metadata, or lacks the ``file_name`` /
        ``page_label`` key, gets "-" for that field. Duplicate sources are
        removed keeping the first occurrence, so the caller's ordering (e.g.
        retrieval score order) is kept.
        """

        def to_source(chunk: Chunk) -> "Source":
            metadata = chunk.document.doc_metadata
            if not metadata:
                return Source(file="-", page="-", text=chunk.text)
            return Source(
                file=metadata.get("file_name", "-"),
                page=metadata.get("page_label", "-"),
                text=chunk.text,
            )

        # dict.fromkeys keeps insertion order, unlike a set, so this dedups
        # without reordering.
        return list(dict.fromkeys(to_source(chunk) for chunk in sources))
|
|
|
|
|
|
@singleton
|
|
class PrivateGptUi:
|
|
@inject
|
|
def __init__(
|
|
self,
|
|
ingest_service: IngestService,
|
|
chat_service: ChatService,
|
|
chunks_service: ChunksService,
|
|
) -> None:
|
|
self._ingest_service = ingest_service
|
|
self._chat_service = chat_service
|
|
self._chunks_service = chunks_service
|
|
|
|
# Cache the UI blocks
|
|
self._ui_block = None
|
|
|
|
self._selected_filename = None
|
|
|
|
# Initialize system prompt based on default mode
|
|
self.mode = MODES[0]
|
|
self._system_prompt = self._get_default_system_prompt(self.mode)
|
|
|
|
def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
|
|
def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
|
|
full_response: str = ""
|
|
stream = completion_gen.response
|
|
for delta in stream:
|
|
if isinstance(delta, str):
|
|
full_response += str(delta)
|
|
elif isinstance(delta, ChatResponse):
|
|
full_response += delta.delta or ""
|
|
yield full_response
|
|
time.sleep(0.02)
|
|
|
|
if completion_gen.sources:
|
|
full_response += SOURCES_SEPARATOR
|
|
cur_sources = Source.curate_sources(completion_gen.sources)
|
|
sources_text = "\n\n\n"
|
|
used_files = set()
|
|
for index, source in enumerate(cur_sources, start=1):
|
|
if (source.file + "-" + source.page) not in used_files:
|
|
sources_text = (
|
|
sources_text
|
|
+ f"{index}. {source.file} (page {source.page}) \n\n"
|
|
)
|
|
used_files.add(source.file + "-" + source.page)
|
|
full_response += sources_text
|
|
yield full_response
|
|
|
|
def build_history() -> list[ChatMessage]:
|
|
history_messages: list[ChatMessage] = list(
|
|
itertools.chain(
|
|
*[
|
|
[
|
|
ChatMessage(content=interaction[0], role=MessageRole.USER),
|
|
ChatMessage(
|
|
# Remove from history content the Sources information
|
|
content=interaction[1].split(SOURCES_SEPARATOR)[0],
|
|
role=MessageRole.ASSISTANT,
|
|
),
|
|
]
|
|
for interaction in history
|
|
]
|
|
)
|
|
)
|
|
|
|
# max 20 messages to try to avoid context overflow
|
|
return history_messages[:20]
|
|
|
|
new_message = ChatMessage(content=message, role=MessageRole.USER)
|
|
all_messages = [*build_history(), new_message]
|
|
# If a system prompt is set, add it as a system message
|
|
if self._system_prompt:
|
|
all_messages.insert(
|
|
0,
|
|
ChatMessage(
|
|
content=self._system_prompt,
|
|
role=MessageRole.SYSTEM,
|
|
),
|
|
)
|
|
match mode:
|
|
case "Query Files":
|
|
|
|
# Use only the selected file for the query
|
|
context_filter = None
|
|
if self._selected_filename is not None:
|
|
docs_ids = []
|
|
for ingested_document in self._ingest_service.list_ingested():
|
|
if (
|
|
ingested_document.doc_metadata["file_name"]
|
|
== self._selected_filename
|
|
):
|
|
docs_ids.append(ingested_document.doc_id)
|
|
context_filter = ContextFilter(docs_ids=docs_ids)
|
|
|
|
query_stream = self._chat_service.stream_chat(
|
|
messages=all_messages,
|
|
use_context=True,
|
|
context_filter=context_filter,
|
|
)
|
|
yield from yield_deltas(query_stream)
|
|
case "LLM Chat (no context from files)":
|
|
llm_stream = self._chat_service.stream_chat(
|
|
messages=all_messages,
|
|
use_context=False,
|
|
)
|
|
yield from yield_deltas(llm_stream)
|
|
|
|
case "Search Files":
|
|
response = self._chunks_service.retrieve_relevant(
|
|
text=message, limit=4, prev_next_chunks=0
|
|
)
|
|
|
|
sources = Source.curate_sources(response)
|
|
|
|
yield "\n\n\n".join(
|
|
f"{index}. **{source.file} "
|
|
f"(page {source.page})**\n "
|
|
f"{source.text}"
|
|
for index, source in enumerate(sources, start=1)
|
|
)
|
|
|
|
# On initialization and on mode change, this function set the system prompt
|
|
# to the default prompt based on the mode (and user settings).
|
|
@staticmethod
|
|
def _get_default_system_prompt(mode: str) -> str:
|
|
p = ""
|
|
match mode:
|
|
# For query chat mode, obtain default system prompt from settings
|
|
case "Query Files":
|
|
p = settings().ui.default_query_system_prompt
|
|
# For chat mode, obtain default system prompt from settings
|
|
case "LLM Chat (no context from files)":
|
|
p = settings().ui.default_chat_system_prompt
|
|
# For any other mode, clear the system prompt
|
|
case _:
|
|
p = ""
|
|
return p
|
|
|
|
def _set_system_prompt(self, system_prompt_input: str) -> None:
|
|
logger.info(f"Setting system prompt to: {system_prompt_input}")
|
|
self._system_prompt = system_prompt_input
|
|
|
|
def _set_current_mode(self, mode: str) -> Any:
|
|
self.mode = mode
|
|
self._set_system_prompt(self._get_default_system_prompt(mode))
|
|
# Update placeholder and allow interaction if default system prompt is set
|
|
if self._system_prompt:
|
|
return gr.update(placeholder=self._system_prompt, interactive=True)
|
|
# Update placeholder and disable interaction if no default system prompt is set
|
|
else:
|
|
return gr.update(placeholder=self._system_prompt, interactive=False)
|
|
|
|
def _list_ingested_files(self) -> list[list[str]]:
|
|
files = set()
|
|
for ingested_document in self._ingest_service.list_ingested():
|
|
if ingested_document.doc_metadata is None:
|
|
# Skipping documents without metadata
|
|
continue
|
|
file_name = ingested_document.doc_metadata.get(
|
|
"file_name", "[FILE NAME MISSING]"
|
|
)
|
|
files.add(file_name)
|
|
return [[row] for row in files]
|
|
|
|
def _upload_file(self, files: list[str]) -> None:
|
|
logger.debug("Loading count=%s files", len(files))
|
|
paths = [Path(file) for file in files]
|
|
|
|
# remove all existing Documents with name identical to a new file upload:
|
|
file_names = [path.name for path in paths]
|
|
doc_ids_to_delete = []
|
|
for ingested_document in self._ingest_service.list_ingested():
|
|
if (
|
|
ingested_document.doc_metadata
|
|
and ingested_document.doc_metadata["file_name"] in file_names
|
|
):
|
|
doc_ids_to_delete.append(ingested_document.doc_id)
|
|
if len(doc_ids_to_delete) > 0:
|
|
logger.info(
|
|
"Uploading file(s) which were already ingested: %s document(s) will be replaced.",
|
|
len(doc_ids_to_delete),
|
|
)
|
|
for doc_id in doc_ids_to_delete:
|
|
self._ingest_service.delete(doc_id)
|
|
|
|
self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths])
|
|
|
|
def _delete_all_files(self) -> Any:
|
|
ingested_files = self._ingest_service.list_ingested()
|
|
logger.debug("Deleting count=%s files", len(ingested_files))
|
|
for ingested_document in ingested_files:
|
|
self._ingest_service.delete(ingested_document.doc_id)
|
|
return [
|
|
gr.List(self._list_ingested_files()),
|
|
gr.components.Button(interactive=False),
|
|
gr.components.Button(interactive=False),
|
|
gr.components.Textbox("All files"),
|
|
]
|
|
|
|
def _delete_selected_file(self) -> Any:
|
|
logger.debug("Deleting selected %s", self._selected_filename)
|
|
# Note: keep looping for pdf's (each page became a Document)
|
|
for ingested_document in self._ingest_service.list_ingested():
|
|
if (
|
|
ingested_document.doc_metadata
|
|
and ingested_document.doc_metadata["file_name"]
|
|
== self._selected_filename
|
|
):
|
|
self._ingest_service.delete(ingested_document.doc_id)
|
|
return [
|
|
gr.List(self._list_ingested_files()),
|
|
gr.components.Button(interactive=False),
|
|
gr.components.Button(interactive=False),
|
|
gr.components.Textbox("All files"),
|
|
]
|
|
|
|
def _deselect_selected_file(self) -> Any:
|
|
self._selected_filename = None
|
|
return [
|
|
gr.components.Button(interactive=False),
|
|
gr.components.Button(interactive=False),
|
|
gr.components.Textbox("All files"),
|
|
]
|
|
|
|
def _selected_a_file(self, select_data: gr.SelectData) -> Any:
|
|
self._selected_filename = select_data.value
|
|
return [
|
|
gr.components.Button(interactive=True),
|
|
gr.components.Button(interactive=True),
|
|
gr.components.Textbox(self._selected_filename),
|
|
]
|
|
|
|
def _build_ui_blocks(self) -> gr.Blocks:
|
|
logger.debug("Creating the UI blocks")
|
|
with gr.Blocks(
|
|
title=UI_TAB_TITLE,
|
|
theme=gr.themes.Soft(primary_hue=slate),
|
|
css=".logo { "
|
|
"display:flex;"
|
|
"background-color: #C7BAFF;"
|
|
"height: 80px;"
|
|
"border-radius: 8px;"
|
|
"align-content: center;"
|
|
"justify-content: center;"
|
|
"align-items: center;"
|
|
"}"
|
|
".logo img { height: 25% }"
|
|
".contain { display: flex !important; flex-direction: column !important; }"
|
|
"#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
|
|
"#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
|
|
"#col { height: calc(100vh - 112px - 16px) !important; }",
|
|
) as blocks:
|
|
with gr.Row():
|
|
gr.HTML(f"<div class='logo'/><img src={logo_svg} alt=PrivateGPT></div")
|
|
|
|
with gr.Row(equal_height=False):
|
|
with gr.Column(scale=3):
|
|
mode = gr.Radio(
|
|
MODES,
|
|
label="Mode",
|
|
value="Query Files",
|
|
)
|
|
upload_button = gr.components.UploadButton(
|
|
"Upload File(s)",
|
|
type="filepath",
|
|
file_count="multiple",
|
|
size="sm",
|
|
)
|
|
ingested_dataset = gr.List(
|
|
self._list_ingested_files,
|
|
headers=["File name"],
|
|
label="Ingested Files",
|
|
height=235,
|
|
interactive=False,
|
|
render=False, # Rendered under the button
|
|
)
|
|
upload_button.upload(
|
|
self._upload_file,
|
|
inputs=upload_button,
|
|
outputs=ingested_dataset,
|
|
)
|
|
ingested_dataset.change(
|
|
self._list_ingested_files,
|
|
outputs=ingested_dataset,
|
|
)
|
|
ingested_dataset.render()
|
|
deselect_file_button = gr.components.Button(
|
|
"De-select selected file", size="sm", interactive=False
|
|
)
|
|
selected_text = gr.components.Textbox(
|
|
"All files", label="Selected for Query or Deletion", max_lines=1
|
|
)
|
|
delete_file_button = gr.components.Button(
|
|
"🗑️ Delete selected file",
|
|
size="sm",
|
|
visible=settings().ui.delete_file_button_enabled,
|
|
interactive=False,
|
|
)
|
|
delete_files_button = gr.components.Button(
|
|
"⚠️ Delete ALL files",
|
|
size="sm",
|
|
visible=settings().ui.delete_all_files_button_enabled,
|
|
)
|
|
deselect_file_button.click(
|
|
self._deselect_selected_file,
|
|
outputs=[
|
|
delete_file_button,
|
|
deselect_file_button,
|
|
selected_text,
|
|
],
|
|
)
|
|
ingested_dataset.select(
|
|
fn=self._selected_a_file,
|
|
outputs=[
|
|
delete_file_button,
|
|
deselect_file_button,
|
|
selected_text,
|
|
],
|
|
)
|
|
delete_file_button.click(
|
|
self._delete_selected_file,
|
|
outputs=[
|
|
ingested_dataset,
|
|
delete_file_button,
|
|
deselect_file_button,
|
|
selected_text,
|
|
],
|
|
)
|
|
delete_files_button.click(
|
|
self._delete_all_files,
|
|
outputs=[
|
|
ingested_dataset,
|
|
delete_file_button,
|
|
deselect_file_button,
|
|
selected_text,
|
|
],
|
|
)
|
|
system_prompt_input = gr.Textbox(
|
|
placeholder=self._system_prompt,
|
|
label="System Prompt",
|
|
lines=2,
|
|
interactive=True,
|
|
render=False,
|
|
)
|
|
# When mode changes, set default system prompt
|
|
mode.change(
|
|
self._set_current_mode, inputs=mode, outputs=system_prompt_input
|
|
)
|
|
# On blur, set system prompt to use in queries
|
|
system_prompt_input.blur(
|
|
self._set_system_prompt,
|
|
inputs=system_prompt_input,
|
|
)
|
|
|
|
with gr.Column(scale=7, elem_id="col"):
|
|
_ = gr.ChatInterface(
|
|
self._chat,
|
|
chatbot=gr.Chatbot(
|
|
label=f"LLM: {settings().llm.mode}",
|
|
show_copy_button=True,
|
|
elem_id="chatbot",
|
|
render=False,
|
|
avatar_images=(
|
|
None,
|
|
AVATAR_BOT,
|
|
),
|
|
),
|
|
additional_inputs=[mode, upload_button, system_prompt_input],
|
|
)
|
|
return blocks
|
|
|
|
def get_ui_blocks(self) -> gr.Blocks:
|
|
if self._ui_block is None:
|
|
self._ui_block = self._build_ui_blocks()
|
|
return self._ui_block
|
|
|
|
def mount_in_app(self, app: FastAPI, path: str) -> None:
|
|
blocks = self.get_ui_blocks()
|
|
blocks.queue()
|
|
logger.info("Mounting the gradio UI, at path=%s", path)
|
|
gr.mount_gradio_app(app, blocks, path=path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Standalone run: obtain the UI from the DI container, enable queuing and
    # launch the Gradio server (API docs hidden, debug off).
    _blocks = global_injector.get(PrivateGptUi).get_ui_blocks()
    _blocks.queue()
    _blocks.launch(debug=False, show_api=False)
|