Mirror of https://github.com/zylon-ai/private-gpt.git (synced 2025-12-22 20:12:55 +01:00)
feat: add summary recipe

parent d080969407
commit a614b349d3

4 changed files with 218 additions and 0 deletions
@@ -15,6 +15,7 @@ from private_gpt.server.completions.completions_router import completions_router
 from private_gpt.server.embeddings.embeddings_router import embeddings_router
 from private_gpt.server.health.health_router import health_router
 from private_gpt.server.ingest.ingest_router import ingest_router
+from private_gpt.server.recipes.summarize.summarize_router import summarize_router
 from private_gpt.settings.settings import Settings
 
 logger = logging.getLogger(__name__)
@@ -32,6 +33,7 @@ def create_app(root_injector: Injector) -> FastAPI:
     app.include_router(chat_router)
     app.include_router(chunks_router)
     app.include_router(ingest_router)
+    app.include_router(summarize_router)
     app.include_router(embeddings_router)
     app.include_router(health_router)
 
0    private_gpt/server/recipes/summarize/__init__.py    Normal file
81   private_gpt/server/recipes/summarize/summarize_router.py    Normal file
@@ -0,0 +1,81 @@
+from fastapi import APIRouter, Depends, Request
+from pydantic import BaseModel
+from starlette.responses import StreamingResponse
+
+from private_gpt.open_ai.extensions.context_filter import ContextFilter
+from private_gpt.open_ai.openai_models import (
+    to_openai_sse_stream,
+)
+from private_gpt.server.recipes.summarize.summarize_service import SummarizeService
+from private_gpt.server.utils.auth import authenticated
+
+summarize_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
+
+
+class SummarizeBody(BaseModel):
+    text: str | None = None
+    use_context: bool = False
+    context_filter: ContextFilter | None = None
+    prompt: str | None = None
+    instructions: str | None = None
+    stream: bool = False
+
+
+class SummarizeResponse(BaseModel):
+    summary: str
+
+
+@summarize_router.post(
+    "/summarize",
+    response_model=None,
+    summary="Summarize",
+    responses={200: {"model": SummarizeResponse}},
+    tags=["Recipes"],
+)
+def summarize(
+    request: Request, body: SummarizeBody
+) -> SummarizeResponse | StreamingResponse:
+    """Given a text, the model will return a summary.
+
+    Optionally include `instructions` to influence the way the summary is generated.
+
+    If `use_context` is set to `true`, the model will also use the content coming
+    from the ingested documents in the summary. The documents being used can be
+    filtered by their metadata using the `context_filter`.
+    The metadata of ingested documents can be found using the `/ingest/list`
+    endpoint. If you want all ingested documents to be used, remove
+    `context_filter` altogether.
+
+    If `prompt` is set, it will be used as the prompt for the summarization,
+    otherwise the default prompt will be used.
+
+    When using `'stream': true`, the API will return data chunks following [OpenAI's
+    streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
+    ```
+    {"id":"12345","object":"completion.chunk","created":1694268190,
+    "model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
+    "finish_reason":null}]}
+    ```
+    """
+    service: SummarizeService = request.state.injector.get(SummarizeService)
+
+    completion = service.summarize(
+        text=body.text,
+        instructions=body.instructions,
+        use_context=body.use_context,
+        context_filter=body.context_filter,
+        prompt=body.prompt,
+        stream=body.stream,
+    )
+
+    if isinstance(completion, str):
+        return SummarizeResponse(
+            summary=completion,
+        )
+    else:
+        return StreamingResponse(
+            to_openai_sse_stream(
+                response_generator=completion,
+            ),
+            media_type="text/event-stream",
+        )
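For context (not part of the commit): a minimal client sketch for the new endpoint. It assumes a PrivateGPT server reachable at http://localhost:8001 with auth disabled (both are assumptions), and its streaming loop parses the OpenAI-style SSE chunks described in the docstring above.

```
# Hypothetical client for /v1/summarize; host, port, and disabled auth are
# assumptions, not part of this commit.
import json

import requests

BASE = "http://localhost:8001"  # assumed local PrivateGPT instance

# Blocking call: the response body is {"summary": "..."}.
resp = requests.post(
    f"{BASE}/v1/summarize",
    json={"text": "Long text to summarize...", "instructions": "Use two sentences."},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["summary"])

# Streaming call: consume OpenAI-style SSE chunks until the stream ends.
with requests.post(
    f"{BASE}/v1/summarize",
    json={"text": "Long text to summarize...", "stream": True},
    stream=True,
    timeout=120,
) as stream_resp:
    stream_resp.raise_for_status()
    for line in stream_resp.iter_lines():
        if not line.startswith(b"data: "):
            continue  # skip blank SSE separator lines
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":  # assumed OpenAI-style stream terminator
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["delta"].get("content") or "", end="", flush=True)
```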
135  private_gpt/server/recipes/summarize/summarize_service.py    Normal file

@@ -0,0 +1,135 @@
+from itertools import chain
+
+from injector import inject, singleton
+from llama_index.core import (
+    Document,
+    StorageContext,
+    SummaryIndex,
+)
+from llama_index.core.base.response.schema import Response, StreamingResponse
+from llama_index.core.node_parser import SentenceSplitter
+from llama_index.core.response_synthesizers import ResponseMode
+from llama_index.core.storage.docstore.types import RefDocInfo
+from llama_index.core.types import TokenGen
+
+from private_gpt.components.embedding.embedding_component import EmbeddingComponent
+from private_gpt.components.llm.llm_component import LLMComponent
+from private_gpt.components.node_store.node_store_component import NodeStoreComponent
+from private_gpt.components.vector_store.vector_store_component import (
+    VectorStoreComponent,
+)
+from private_gpt.open_ai.extensions.context_filter import ContextFilter
+
+DEFAULT_SUMMARIZE_PROMPT = (
+    "Provide a comprehensive summary of the provided context information. "
+    "The summary should cover all the key points and main ideas presented in "
+    "the original text, while also condensing the information into a concise "
+    "and easy-to-understand format. Please ensure that the summary includes "
+    "relevant details and examples that support the main ideas, while avoiding "
+    "any unnecessary information or repetition."
+)
+
+
+@singleton
+class SummarizeService:
+    @inject
+    def __init__(
+        self,
+        llm_component: LLMComponent,
+        node_store_component: NodeStoreComponent,
+        vector_store_component: VectorStoreComponent,
+        embedding_component: EmbeddingComponent,
+    ) -> None:
+        self.llm_component = llm_component
+        self.node_store_component = node_store_component
+        self.vector_store_component = vector_store_component
+        self.embedding_component = embedding_component
+        self.storage_context = StorageContext.from_defaults(
+            vector_store=vector_store_component.vector_store,
+            docstore=node_store_component.doc_store,
+            index_store=node_store_component.index_store,
+        )
+
+    @staticmethod
+    def _filter_ref_docs(
+        ref_docs: dict[str, RefDocInfo], context_filter: ContextFilter | None
+    ) -> list[RefDocInfo]:
+        if context_filter is None or not context_filter.docs_ids:
+            return list(ref_docs.values())
+
+        return [
+            ref_doc
+            for doc_id, ref_doc in ref_docs.items()
+            if doc_id in context_filter.docs_ids
+        ]
+
+    def summarize(
+        self,
+        use_context: bool = False,
+        stream: bool = False,
+        text: str | None = None,
+        instructions: str | None = None,
+        context_filter: ContextFilter | None = None,
+        prompt: str | None = None,
+    ) -> str | TokenGen:
+        nodes_to_summarize = []
+
+        # Add text to summarize
+        if text:
+            text_documents = [Document(text=text)]
+            nodes_to_summarize += (
+                SentenceSplitter.from_defaults().get_nodes_from_documents(
+                    text_documents
+                )
+            )
+
+        # Add context documents to summarize
+        if use_context:
+            # 1. Recover all ref docs
+            ref_docs: dict[
+                str, RefDocInfo
+            ] | None = self.storage_context.docstore.get_all_ref_doc_info()
+            if ref_docs is None:
+                raise ValueError("No documents have been ingested yet.")
+
+            # 2. Filter documents based on context_filter (if provided)
+            filtered_ref_docs = self._filter_ref_docs(ref_docs, context_filter)
+
+            # 3. Get all nodes from the filtered documents
+            filtered_node_ids = chain.from_iterable(
+                [ref_doc.node_ids for ref_doc in filtered_ref_docs]
+            )
+            filtered_nodes = self.storage_context.docstore.get_nodes(
+                node_ids=list(filtered_node_ids),
+            )
+
+            nodes_to_summarize += filtered_nodes
+
+        # Create a SummaryIndex to summarize the nodes
+        summary_index = SummaryIndex(
+            nodes=nodes_to_summarize,
+            storage_context=StorageContext.from_defaults(),  # In-memory SummaryIndex
+            show_progress=True,
+        )
+
+        # Make a tree summarization query
+        # over the set of all candidate nodes
+        query_engine = summary_index.as_query_engine(
+            llm=self.llm_component.llm,
+            response_mode=ResponseMode.TREE_SUMMARIZE,
+            streaming=stream,
+            use_async=False,
+        )
+
+        prompt = prompt or DEFAULT_SUMMARIZE_PROMPT
+
+        summarize_query = prompt + "\n" + (instructions or "")
+
+        response = query_engine.query(summarize_query)
+        if isinstance(response, Response):
+            return response.response or ""
+        elif isinstance(response, StreamingResponse):
+            return response.response_gen
+        else:
+            raise TypeError(f"The result is not of a supported type: {type(response)}")
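The service builds an in-memory SummaryIndex over all candidate nodes and runs a TREE_SUMMARIZE query against it. A stripped-down sketch of that same pattern with plain llama-index (not part of the commit; it assumes llama-index's global Settings already point at a working LLM, e.g. with OPENAI_API_KEY exported):

```
# Standalone sketch of the tree-summarization pattern used by SummarizeService.
# Illustrative only; assumes a configured LLM in llama-index's Settings.
from llama_index.core import Document, SummaryIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.response_synthesizers import ResponseMode

# Split the raw text into nodes, as the service does for the `text` field.
nodes = SentenceSplitter.from_defaults().get_nodes_from_documents(
    [Document(text="Long text to summarize...")]
)

# In-memory index over the candidate nodes; TREE_SUMMARIZE folds them into
# intermediate summaries and then a single final summary.
index = SummaryIndex(nodes=nodes)
engine = index.as_query_engine(response_mode=ResponseMode.TREE_SUMMARIZE)
print(engine.query("Provide a comprehensive summary of the provided context."))
```

The actual service additionally pulls the nodes of previously ingested documents out of the docstore, filters them by `context_filter`, and appends them to the candidate set before building the index.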