mirror of
https://github.com/zylon-ai/private-gpt.git
synced 2025-12-22 17:05:41 +01:00
WIP more prompt format, and more maintainable
This commit is contained in:
parent
3d301d0c6f
commit
76faffb269
11 changed files with 476 additions and 217 deletions
|
|
@ -9,7 +9,7 @@ import gradio as gr # type: ignore
|
|||
from fastapi import FastAPI
|
||||
from gradio.themes.utils.colors import slate # type: ignore
|
||||
from injector import inject, singleton
|
||||
from llama_index.llms import ChatMessage, ChatResponse, MessageRole
|
||||
from llama_index.llms import ChatMessage, MessageRole
|
||||
from pydantic import BaseModel
|
||||
|
||||
from private_gpt.constants import PROJECT_ROOT_PATH
|
||||
|
|
@ -55,6 +55,27 @@ class Source(BaseModel):
|
|||
return curated_sources
|
||||
|
||||
|
||||
def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
|
||||
full_response: str = ""
|
||||
stream = completion_gen.response
|
||||
for delta in stream:
|
||||
# if isinstance(delta, str):
|
||||
full_response += str(delta)
|
||||
# elif isinstance(delta, ChatResponse):
|
||||
# full_response += delta.delta or ""
|
||||
yield full_response
|
||||
|
||||
if completion_gen.sources:
|
||||
full_response += SOURCES_SEPARATOR
|
||||
cur_sources = Source.curate_sources(completion_gen.sources)
|
||||
sources_text = "\n\n\n".join(
|
||||
f"{index}. {source.file} (page {source.page})"
|
||||
for index, source in enumerate(cur_sources, start=1)
|
||||
)
|
||||
full_response += sources_text
|
||||
yield full_response
|
||||
|
||||
|
||||
@singleton
|
||||
class PrivateGptUi:
|
||||
@inject
|
||||
|
|
@ -72,26 +93,6 @@ class PrivateGptUi:
|
|||
self._ui_block = None
|
||||
|
||||
def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
|
||||
def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
|
||||
full_response: str = ""
|
||||
stream = completion_gen.response
|
||||
for delta in stream:
|
||||
if isinstance(delta, str):
|
||||
full_response += str(delta)
|
||||
elif isinstance(delta, ChatResponse):
|
||||
full_response += delta.delta or ""
|
||||
yield full_response
|
||||
|
||||
if completion_gen.sources:
|
||||
full_response += SOURCES_SEPARATOR
|
||||
cur_sources = Source.curate_sources(completion_gen.sources)
|
||||
sources_text = "\n\n\n".join(
|
||||
f"{index}. {source.file} (page {source.page})"
|
||||
for index, source in enumerate(cur_sources, start=1)
|
||||
)
|
||||
full_response += sources_text
|
||||
yield full_response
|
||||
|
||||
def build_history() -> list[ChatMessage]:
|
||||
history_messages: list[ChatMessage] = list(
|
||||
itertools.chain(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue