From a614b349d3a8a3cc833708460e402f45163b336e Mon Sep 17 00:00:00 2001
From: Javier Martinez
Date: Tue, 30 Jul 2024 16:05:05 +0200
Subject: [PATCH] feat: add summary recipe

---
 private_gpt/launcher.py                        |   2 +
 .../server/recipes/summarize/__init__.py       |   0
 .../recipes/summarize/summarize_router.py      |  81 +++++++++++
 .../recipes/summarize/summarize_service.py     | 135 ++++++++++++++++++
 4 files changed, 218 insertions(+)
 create mode 100644 private_gpt/server/recipes/summarize/__init__.py
 create mode 100644 private_gpt/server/recipes/summarize/summarize_router.py
 create mode 100644 private_gpt/server/recipes/summarize/summarize_service.py

diff --git a/private_gpt/launcher.py b/private_gpt/launcher.py
index 43bd803..191cb90 100644
--- a/private_gpt/launcher.py
+++ b/private_gpt/launcher.py
@@ -15,6 +15,7 @@ from private_gpt.server.completions.completions_router import completions_router
 from private_gpt.server.embeddings.embeddings_router import embeddings_router
 from private_gpt.server.health.health_router import health_router
 from private_gpt.server.ingest.ingest_router import ingest_router
+from private_gpt.server.recipes.summarize.summarize_router import summarize_router
 from private_gpt.settings.settings import Settings
 
 logger = logging.getLogger(__name__)
@@ -32,6 +33,7 @@ def create_app(root_injector: Injector) -> FastAPI:
     app.include_router(chat_router)
     app.include_router(chunks_router)
     app.include_router(ingest_router)
+    app.include_router(summarize_router)
     app.include_router(embeddings_router)
     app.include_router(health_router)
 
diff --git a/private_gpt/server/recipes/summarize/__init__.py b/private_gpt/server/recipes/summarize/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/private_gpt/server/recipes/summarize/summarize_router.py b/private_gpt/server/recipes/summarize/summarize_router.py
new file mode 100644
index 0000000..2968e2d
--- /dev/null
+++ b/private_gpt/server/recipes/summarize/summarize_router.py
@@ -0,0 +1,81 @@
+from fastapi import APIRouter, Depends, Request
+from pydantic import BaseModel
+from starlette.responses import StreamingResponse
+
+from private_gpt.open_ai.extensions.context_filter import ContextFilter
+from private_gpt.open_ai.openai_models import (
+    to_openai_sse_stream,
+)
+from private_gpt.server.recipes.summarize.summarize_service import SummarizeService
+from private_gpt.server.utils.auth import authenticated
+
+summarize_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
+
+
+class SummarizeBody(BaseModel):
+    text: str | None = None
+    use_context: bool = False
+    context_filter: ContextFilter | None = None
+    prompt: str | None = None
+    instructions: str | None = None
+    stream: bool = False
+
+
+class SummarizeResponse(BaseModel):
+    summary: str
+
+
+@summarize_router.post(
+    "/summarize",
+    response_model=None,
+    summary="Summarize",
+    responses={200: {"model": SummarizeResponse}},
+    tags=["Recipes"],
+)
+def summarize(
+    request: Request, body: SummarizeBody
+) -> SummarizeResponse | StreamingResponse:
+    """Given a text, the model will return a summary.
+
+    Optionally include `instructions` to influence the way the summary is generated.
+
+    If `use_context` is set to `true`, the model will also use the content
+    coming from the ingested documents in the summary. The documents being
+    used can be filtered by their metadata using the `context_filter`.
+    Metadata of ingested documents can be found using the `/ingest/list`
+    endpoint. If you want all ingested documents to be used, remove
+    `context_filter` altogether.
+
+    If `prompt` is set, it will be used as the prompt for the summarization;
+    otherwise the default prompt will be used.
+
+    When using `'stream': true`, the API will return data chunks following [OpenAI's
+    streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
+    ```
+    {"id":"12345","object":"completion.chunk","created":1694268190,
+    "model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
+    "finish_reason":null}]}
+    ```
+    """
+    service: SummarizeService = request.state.injector.get(SummarizeService)
+
+    completion = service.summarize(
+        text=body.text,
+        instructions=body.instructions,
+        use_context=body.use_context,
+        context_filter=body.context_filter,
+        prompt=body.prompt,
+        stream=body.stream,
+    )
+
+    if isinstance(completion, str):
+        return SummarizeResponse(
+            summary=completion,
+        )
+    else:
+        return StreamingResponse(
+            to_openai_sse_stream(
+                response_generator=completion,
+            ),
+            media_type="text/event-stream",
+        )
diff --git a/private_gpt/server/recipes/summarize/summarize_service.py b/private_gpt/server/recipes/summarize/summarize_service.py
new file mode 100644
index 0000000..c05393c
--- /dev/null
+++ b/private_gpt/server/recipes/summarize/summarize_service.py
@@ -0,0 +1,135 @@
+from itertools import chain
+
+from injector import inject, singleton
+from llama_index.core import (
+    Document,
+    StorageContext,
+    SummaryIndex,
+)
+from llama_index.core.base.response.schema import Response, StreamingResponse
+from llama_index.core.node_parser import SentenceSplitter
+from llama_index.core.response_synthesizers import ResponseMode
+from llama_index.core.storage.docstore.types import RefDocInfo
+from llama_index.core.types import TokenGen
+
+from private_gpt.components.embedding.embedding_component import EmbeddingComponent
+from private_gpt.components.llm.llm_component import LLMComponent
+from private_gpt.components.node_store.node_store_component import NodeStoreComponent
+from private_gpt.components.vector_store.vector_store_component import (
+    VectorStoreComponent,
+)
+from private_gpt.open_ai.extensions.context_filter import ContextFilter
+
+DEFAULT_SUMMARIZE_PROMPT = (
+    "Provide a comprehensive summary of the provided context information. "
+    "The summary should cover all the key points and main ideas presented in "
+    "the original text, while also condensing the information into a concise "
+    "and easy-to-understand format. Please ensure that the summary includes "
+    "relevant details and examples that support the main ideas, while avoiding "
+    "any unnecessary information or repetition."
+) + + +@singleton +class SummarizeService: + @inject + def __init__( + self, + llm_component: LLMComponent, + node_store_component: NodeStoreComponent, + vector_store_component: VectorStoreComponent, + embedding_component: EmbeddingComponent, + ) -> None: + self.llm_component = llm_component + self.node_store_component = node_store_component + self.vector_store_component = vector_store_component + self.embedding_component = embedding_component + self.storage_context = StorageContext.from_defaults( + vector_store=vector_store_component.vector_store, + docstore=node_store_component.doc_store, + index_store=node_store_component.index_store, + ) + + @staticmethod + def _filter_ref_docs( + ref_docs: dict[str, RefDocInfo], context_filter: ContextFilter | None + ) -> list[RefDocInfo]: + if context_filter is None or not context_filter.docs_ids: + return list(ref_docs.values()) + + return [ + ref_doc + for doc_id, ref_doc in ref_docs.items() + if doc_id in context_filter.docs_ids + ] + + def summarize( + self, + use_context: bool = False, + stream: bool = False, + text: str | None = None, + instructions: str | None = None, + context_filter: ContextFilter | None = None, + prompt: str | None = None, + ) -> str | TokenGen: + + nodes_to_summarize = [] + + # Add text to summarize + if text: + text_documents = [Document(text=text)] + nodes_to_summarize += ( + SentenceSplitter.from_defaults().get_nodes_from_documents( + text_documents + ) + ) + + # Add context documents to summarize + if use_context: + # 1. Recover all ref docs + ref_docs: dict[ + str, RefDocInfo + ] | None = self.storage_context.docstore.get_all_ref_doc_info() + if ref_docs is None: + raise ValueError("No documents have been ingested yet.") + + # 2. Filter documents based on context_filter (if provided) + filtered_ref_docs = self._filter_ref_docs(ref_docs, context_filter) + + # 3. Get all nodes from the filtered documents + filtered_node_ids = chain.from_iterable( + [ref_doc.node_ids for ref_doc in filtered_ref_docs] + ) + filtered_nodes = self.storage_context.docstore.get_nodes( + node_ids=list(filtered_node_ids), + ) + + nodes_to_summarize += filtered_nodes + + # Create a SummaryIndex to summarize the nodes + summary_index = SummaryIndex( + nodes=nodes_to_summarize, + storage_context=StorageContext.from_defaults(), # In memory SummaryIndex + show_progress=True, + ) + + # Make a tree summarization query + # above the set of all candidate nodes + query_engine = summary_index.as_query_engine( + llm=self.llm_component.llm, + response_mode=ResponseMode.TREE_SUMMARIZE, + streaming=stream, + use_async=False, + ) + + prompt = prompt or DEFAULT_SUMMARIZE_PROMPT + + summarize_query = prompt + "\n" + (instructions or "") + + response = query_engine.query(summarize_query) + if isinstance(response, Response): + return response.response or "" + elif isinstance(response, StreamingResponse): + return response.response_gen + else: + raise TypeError(f"The result is not of a supported type: {type(response)}")