feat(recipe): add our first recipe Summarize (#2028)
Some checks are pending
publish docs / publish-docs (push) Waiting to run
release-please / release-please (push) Waiting to run
tests / setup (push) Waiting to run
tests / ${{ matrix.quality-command }} (black) (push) Blocked by required conditions
tests / ${{ matrix.quality-command }} (mypy) (push) Blocked by required conditions
tests / ${{ matrix.quality-command }} (ruff) (push) Blocked by required conditions
tests / test (push) Blocked by required conditions
tests / all_checks_passed (push) Blocked by required conditions

* feat: add summary recipe

* test: add summary tests

* docs: move all recipes docs

* docs: add recipes and summarize doc

* docs: update openapi reference

* refactor: split method in two method (summary)

* feat: add initial summarize ui

* feat: add mode explanation

* fix: mypy

* feat: allow to configure async property in summarize

* refactor: move modes to enum and update mode explanations

* docs: fix url

* docs: remove list-llm pages

* docs: remove double header

* fix: summary description
This commit is contained in:
Javier Martinez 2024-07-31 16:53:27 +02:00 committed by GitHub
parent 40638a18a5
commit 8119842ae6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 743 additions and 148 deletions

View file

@ -15,6 +15,7 @@ from private_gpt.server.completions.completions_router import completions_router
from private_gpt.server.embeddings.embeddings_router import embeddings_router
from private_gpt.server.health.health_router import health_router
from private_gpt.server.ingest.ingest_router import ingest_router
from private_gpt.server.recipes.summarize.summarize_router import summarize_router
from private_gpt.settings.settings import Settings
logger = logging.getLogger(__name__)
@ -32,12 +33,13 @@ def create_app(root_injector: Injector) -> FastAPI:
app.include_router(chat_router)
app.include_router(chunks_router)
app.include_router(ingest_router)
app.include_router(summarize_router)
app.include_router(embeddings_router)
app.include_router(health_router)
# Add LlamaIndex simple observability
global_handler = create_global_handler("simple")
if global_handler is not None:
if global_handler:
LlamaIndexSettings.callback_manager = CallbackManager([global_handler])
settings = root_injector.get(Settings)

View file

@ -0,0 +1,86 @@
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.open_ai.openai_models import (
to_openai_sse_stream,
)
from private_gpt.server.recipes.summarize.summarize_service import SummarizeService
from private_gpt.server.utils.auth import authenticated
summarize_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
class SummarizeBody(BaseModel):
text: str | None = None
use_context: bool = False
context_filter: ContextFilter | None = None
prompt: str | None = None
instructions: str | None = None
stream: bool = False
class SummarizeResponse(BaseModel):
summary: str
@summarize_router.post(
"/summarize",
response_model=None,
summary="Summarize",
responses={200: {"model": SummarizeResponse}},
tags=["Recipes"],
)
def summarize(
request: Request, body: SummarizeBody
) -> SummarizeResponse | StreamingResponse:
"""Given a text, the model will return a summary.
Optionally include `instructions` to influence the way the summary is generated.
If `use_context`
is set to `true`, the model will also use the content coming from the ingested
documents in the summary. The documents being used can
be filtered by their metadata using the `context_filter`.
Ingested documents metadata can be found using `/ingest/list` endpoint.
If you want all ingested documents to be used, remove `context_filter` altogether.
If `prompt` is set, it will be used as the prompt for the summarization,
otherwise the default prompt will be used.
When using `'stream': true`, the API will return data chunks following [OpenAI's
streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
```
{"id":"12345","object":"completion.chunk","created":1694268190,
"model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
"finish_reason":null}]}
```
"""
service: SummarizeService = request.state.injector.get(SummarizeService)
if body.stream:
completion_gen = service.stream_summarize(
text=body.text,
instructions=body.instructions,
use_context=body.use_context,
context_filter=body.context_filter,
prompt=body.prompt,
)
return StreamingResponse(
to_openai_sse_stream(
response_generator=completion_gen,
),
media_type="text/event-stream",
)
else:
completion = service.summarize(
text=body.text,
instructions=body.instructions,
use_context=body.use_context,
context_filter=body.context_filter,
prompt=body.prompt,
)
return SummarizeResponse(
summary=completion,
)

View file

@ -0,0 +1,172 @@
from itertools import chain
from injector import inject, singleton
from llama_index.core import (
Document,
StorageContext,
SummaryIndex,
)
from llama_index.core.base.response.schema import Response, StreamingResponse
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.core.storage.docstore.types import RefDocInfo
from llama_index.core.types import TokenGen
from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.settings.settings import Settings
DEFAULT_SUMMARIZE_PROMPT = (
"Provide a comprehensive summary of the provided context information. "
"The summary should cover all the key points and main ideas presented in "
"the original text, while also condensing the information into a concise "
"and easy-to-understand format. Please ensure that the summary includes "
"relevant details and examples that support the main ideas, while avoiding "
"any unnecessary information or repetition."
)
@singleton
class SummarizeService:
@inject
def __init__(
self,
settings: Settings,
llm_component: LLMComponent,
node_store_component: NodeStoreComponent,
vector_store_component: VectorStoreComponent,
embedding_component: EmbeddingComponent,
) -> None:
self.settings = settings
self.llm_component = llm_component
self.node_store_component = node_store_component
self.vector_store_component = vector_store_component
self.embedding_component = embedding_component
self.storage_context = StorageContext.from_defaults(
vector_store=vector_store_component.vector_store,
docstore=node_store_component.doc_store,
index_store=node_store_component.index_store,
)
@staticmethod
def _filter_ref_docs(
ref_docs: dict[str, RefDocInfo], context_filter: ContextFilter | None
) -> list[RefDocInfo]:
if context_filter is None or not context_filter.docs_ids:
return list(ref_docs.values())
return [
ref_doc
for doc_id, ref_doc in ref_docs.items()
if doc_id in context_filter.docs_ids
]
def _summarize(
self,
use_context: bool = False,
stream: bool = False,
text: str | None = None,
instructions: str | None = None,
context_filter: ContextFilter | None = None,
prompt: str | None = None,
) -> str | TokenGen:
nodes_to_summarize = []
# Add text to summarize
if text:
text_documents = [Document(text=text)]
nodes_to_summarize += (
SentenceSplitter.from_defaults().get_nodes_from_documents(
text_documents
)
)
# Add context documents to summarize
if use_context:
# 1. Recover all ref docs
ref_docs: dict[
str, RefDocInfo
] | None = self.storage_context.docstore.get_all_ref_doc_info()
if ref_docs is None:
raise ValueError("No documents have been ingested yet.")
# 2. Filter documents based on context_filter (if provided)
filtered_ref_docs = self._filter_ref_docs(ref_docs, context_filter)
# 3. Get all nodes from the filtered documents
filtered_node_ids = chain.from_iterable(
[ref_doc.node_ids for ref_doc in filtered_ref_docs]
)
filtered_nodes = self.storage_context.docstore.get_nodes(
node_ids=list(filtered_node_ids),
)
nodes_to_summarize += filtered_nodes
# Create a SummaryIndex to summarize the nodes
summary_index = SummaryIndex(
nodes=nodes_to_summarize,
storage_context=StorageContext.from_defaults(), # In memory SummaryIndex
show_progress=True,
)
# Make a tree summarization query
# above the set of all candidate nodes
query_engine = summary_index.as_query_engine(
llm=self.llm_component.llm,
response_mode=ResponseMode.TREE_SUMMARIZE,
streaming=stream,
use_async=self.settings.summarize.use_async,
)
prompt = prompt or DEFAULT_SUMMARIZE_PROMPT
summarize_query = prompt + "\n" + (instructions or "")
response = query_engine.query(summarize_query)
if isinstance(response, Response):
return response.response or ""
elif isinstance(response, StreamingResponse):
return response.response_gen
else:
raise TypeError(f"The result is not of a supported type: {type(response)}")
def summarize(
self,
use_context: bool = False,
text: str | None = None,
instructions: str | None = None,
context_filter: ContextFilter | None = None,
prompt: str | None = None,
) -> str:
return self._summarize(
use_context=use_context,
stream=False,
text=text,
instructions=instructions,
context_filter=context_filter,
prompt=prompt,
) # type: ignore
def stream_summarize(
self,
use_context: bool = False,
text: str | None = None,
instructions: str | None = None,
context_filter: ContextFilter | None = None,
prompt: str | None = None,
) -> TokenGen:
return self._summarize(
use_context=use_context,
stream=True,
text=text,
instructions=instructions,
context_filter=context_filter,
prompt=prompt,
) # type: ignore

View file

@ -353,6 +353,10 @@ class UISettings(BaseModel):
default_query_system_prompt: str = Field(
None, description="The default system prompt to use for the query mode."
)
default_summarization_system_prompt: str = Field(
None,
description="The default system prompt to use for the summarization mode.",
)
delete_file_button_enabled: bool = Field(
True, description="If the button to delete a file is enabled or not."
)
@ -388,6 +392,13 @@ class RagSettings(BaseModel):
rerank: RerankSettings
class SummarizeSettings(BaseModel):
use_async: bool = Field(
True,
description="If set to True, the summarization will be done asynchronously.",
)
class ClickHouseSettings(BaseModel):
host: str = Field(
"localhost",
@ -577,6 +588,7 @@ class Settings(BaseModel):
vectorstore: VectorstoreSettings
nodestore: NodeStoreSettings
rag: RagSettings
summarize: SummarizeSettings
qdrant: QdrantSettings | None = None
postgres: PostgresSettings | None = None
clickhouse: ClickHouseSettings | None = None

View file

@ -3,6 +3,7 @@ import base64
import logging
import time
from collections.abc import Iterable
from enum import Enum
from pathlib import Path
from typing import Any
@ -11,6 +12,7 @@ from fastapi import FastAPI
from gradio.themes.utils.colors import slate # type: ignore
from injector import inject, singleton
from llama_index.core.llms import ChatMessage, ChatResponse, MessageRole
from llama_index.core.types import TokenGen
from pydantic import BaseModel
from private_gpt.constants import PROJECT_ROOT_PATH
@ -19,6 +21,7 @@ from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chat.chat_service import ChatService, CompletionGen
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.server.recipes.summarize.summarize_service import SummarizeService
from private_gpt.settings.settings import settings
from private_gpt.ui.images import logo_svg
@ -32,7 +35,20 @@ UI_TAB_TITLE = "My Private GPT"
SOURCES_SEPARATOR = "<hr>Sources: \n"
MODES = ["Query Files", "Search Files", "LLM Chat (no context from files)"]
class Modes(str, Enum):
RAG_MODE = "RAG"
SEARCH_MODE = "Search"
BASIC_CHAT_MODE = "Basic"
SUMMARIZE_MODE = "Summarize"
MODES: list[Modes] = [
Modes.RAG_MODE,
Modes.SEARCH_MODE,
Modes.BASIC_CHAT_MODE,
Modes.SUMMARIZE_MODE,
]
class Source(BaseModel):
@ -70,10 +86,12 @@ class PrivateGptUi:
ingest_service: IngestService,
chat_service: ChatService,
chunks_service: ChunksService,
summarizeService: SummarizeService,
) -> None:
self._ingest_service = ingest_service
self._chat_service = chat_service
self._chunks_service = chunks_service
self._summarize_service = summarizeService
# Cache the UI blocks
self._ui_block = None
@ -84,7 +102,9 @@ class PrivateGptUi:
self.mode = MODES[0]
self._system_prompt = self._get_default_system_prompt(self.mode)
def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
def _chat(
self, message: str, history: list[list[str]], mode: Modes, *_: Any
) -> Any:
def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
full_response: str = ""
stream = completion_gen.response
@ -112,6 +132,12 @@ class PrivateGptUi:
full_response += sources_text
yield full_response
def yield_tokens(token_gen: TokenGen) -> Iterable[str]:
full_response: str = ""
for token in token_gen:
full_response += str(token)
yield full_response
def build_history() -> list[ChatMessage]:
history_messages: list[ChatMessage] = []
@ -143,8 +169,7 @@ class PrivateGptUi:
),
)
match mode:
case "Query Files":
case Modes.RAG_MODE:
# Use only the selected file for the query
context_filter = None
if self._selected_filename is not None:
@ -163,14 +188,14 @@ class PrivateGptUi:
context_filter=context_filter,
)
yield from yield_deltas(query_stream)
case "LLM Chat (no context from files)":
case Modes.BASIC_CHAT_MODE:
llm_stream = self._chat_service.stream_chat(
messages=all_messages,
use_context=False,
)
yield from yield_deltas(llm_stream)
case "Search Files":
case Modes.SEARCH_MODE:
response = self._chunks_service.retrieve_relevant(
text=message, limit=4, prev_next_chunks=0
)
@ -183,37 +208,76 @@ class PrivateGptUi:
f"{source.text}"
for index, source in enumerate(sources, start=1)
)
case Modes.SUMMARIZE_MODE:
# Summarize the given message, optionally using selected files
context_filter = None
if self._selected_filename:
docs_ids = []
for ingested_document in self._ingest_service.list_ingested():
if (
ingested_document.doc_metadata["file_name"]
== self._selected_filename
):
docs_ids.append(ingested_document.doc_id)
context_filter = ContextFilter(docs_ids=docs_ids)
summary_stream = self._summarize_service.stream_summarize(
use_context=True,
context_filter=context_filter,
instructions=message,
)
yield from yield_tokens(summary_stream)
# On initialization and on mode change, this function set the system prompt
# to the default prompt based on the mode (and user settings).
@staticmethod
def _get_default_system_prompt(mode: str) -> str:
def _get_default_system_prompt(mode: Modes) -> str:
p = ""
match mode:
# For query chat mode, obtain default system prompt from settings
case "Query Files":
case Modes.RAG_MODE:
p = settings().ui.default_query_system_prompt
# For chat mode, obtain default system prompt from settings
case "LLM Chat (no context from files)":
case Modes.BASIC_CHAT_MODE:
p = settings().ui.default_chat_system_prompt
# For summarization mode, obtain default system prompt from settings
case Modes.SUMMARIZE_MODE:
p = settings().ui.default_summarization_system_prompt
# For any other mode, clear the system prompt
case _:
p = ""
return p
@staticmethod
def _get_default_mode_explanation(mode: Modes) -> str:
match mode:
case Modes.RAG_MODE:
return "Get contextualized answers from selected files."
case Modes.SEARCH_MODE:
return "Find relevant chunks of text in selected files."
case Modes.BASIC_CHAT_MODE:
return "Chat with the LLM using its training data. Files are ignored."
case Modes.SUMMARIZE_MODE:
return "Generate a summary of the selected files. Prompt to customize the result."
case _:
return ""
def _set_system_prompt(self, system_prompt_input: str) -> None:
logger.info(f"Setting system prompt to: {system_prompt_input}")
self._system_prompt = system_prompt_input
def _set_current_mode(self, mode: str) -> Any:
def _set_explanatation_mode(self, explanation_mode: str) -> None:
self._explanation_mode = explanation_mode
def _set_current_mode(self, mode: Modes) -> Any:
self.mode = mode
self._set_system_prompt(self._get_default_system_prompt(mode))
# Update placeholder and allow interaction if default system prompt is set
if self._system_prompt:
return gr.update(placeholder=self._system_prompt, interactive=True)
# Update placeholder and disable interaction if no default system prompt is set
else:
return gr.update(placeholder=self._system_prompt, interactive=False)
self._set_explanatation_mode(self._get_default_mode_explanation(mode))
interactive = self._system_prompt is not None
return [
gr.update(placeholder=self._system_prompt, interactive=interactive),
gr.update(value=self._explanation_mode),
]
def _list_ingested_files(self) -> list[list[str]]:
files = set()
@ -326,10 +390,17 @@ class PrivateGptUi:
with gr.Row(equal_height=False):
with gr.Column(scale=3):
default_mode = MODES[0]
mode = gr.Radio(
MODES,
[mode.value for mode in MODES],
label="Mode",
value="Query Files",
value=default_mode,
)
explanation_mode = gr.Textbox(
placeholder=self._get_default_mode_explanation(default_mode),
show_label=False,
max_lines=3,
interactive=False,
)
upload_button = gr.components.UploadButton(
"Upload File(s)",
@ -413,9 +484,11 @@ class PrivateGptUi:
interactive=True,
render=False,
)
# When mode changes, set default system prompt
# When mode changes, set default system prompt, and other stuffs
mode.change(
self._set_current_mode, inputs=mode, outputs=system_prompt_input
self._set_current_mode,
inputs=mode,
outputs=[system_prompt_input, explanation_mode],
)
# On blur, set system prompt to use in queries
system_prompt_input.blur(