diff --git a/private_gpt/server/recipes/summarize/summarize_router.py b/private_gpt/server/recipes/summarize/summarize_router.py index 2968e2d..c1770c3 100644 --- a/private_gpt/server/recipes/summarize/summarize_router.py +++ b/private_gpt/server/recipes/summarize/summarize_router.py @@ -59,23 +59,28 @@ def summarize( """ service: SummarizeService = request.state.injector.get(SummarizeService) - completion = service.summarize( - text=body.text, - instructions=body.instructions, - use_context=body.use_context, - context_filter=body.context_filter, - prompt=body.prompt, - stream=body.stream, - ) - - if isinstance(completion, str): - return SummarizeResponse( - summary=completion, + if body.stream: + completion_gen = service.stream_summarize( + text=body.text, + instructions=body.instructions, + use_context=body.use_context, + context_filter=body.context_filter, + prompt=body.prompt, ) - else: return StreamingResponse( to_openai_sse_stream( - response_generator=completion, + response_generator=completion_gen, ), media_type="text/event-stream", ) + else: + completion = service.summarize( + text=body.text, + instructions=body.instructions, + use_context=body.use_context, + context_filter=body.context_filter, + prompt=body.prompt, + ) + return SummarizeResponse( + summary=completion, + ) diff --git a/private_gpt/server/recipes/summarize/summarize_service.py b/private_gpt/server/recipes/summarize/summarize_service.py index c05393c..c4c2fb8 100644 --- a/private_gpt/server/recipes/summarize/summarize_service.py +++ b/private_gpt/server/recipes/summarize/summarize_service.py @@ -63,7 +63,7 @@ class SummarizeService: if doc_id in context_filter.docs_ids ] - def summarize( + def _summarize( self, use_context: bool = False, stream: bool = False, @@ -133,3 +133,37 @@ class SummarizeService: return response.response_gen else: raise TypeError(f"The result is not of a supported type: {type(response)}") + + def summarize( + self, + use_context: bool = False, + text: str | None = None, + instructions: str | None = None, + context_filter: ContextFilter | None = None, + prompt: str | None = None, + ) -> str: + return self._summarize( + use_context=use_context, + stream=False, + text=text, + instructions=instructions, + context_filter=context_filter, + prompt=prompt, + ) # type: ignore + + def stream_summarize( + self, + use_context: bool = False, + text: str | None = None, + instructions: str | None = None, + context_filter: ContextFilter | None = None, + prompt: str | None = None, + ) -> TokenGen: + return self._summarize( + use_context=use_context, + stream=True, + text=text, + instructions=instructions, + context_filter=context_filter, + prompt=prompt, + ) # type: ignore