WIP more prompt format, and more maintainable

Louis 2023-12-03 00:48:43 +01:00
parent 3d301d0c6f
commit 76faffb269
11 changed files with 476 additions and 217 deletions


@@ -9,7 +9,7 @@ import gradio as gr  # type: ignore
 from fastapi import FastAPI
 from gradio.themes.utils.colors import slate  # type: ignore
 from injector import inject, singleton
-from llama_index.llms import ChatMessage, ChatResponse, MessageRole
+from llama_index.llms import ChatMessage, MessageRole
 from pydantic import BaseModel
 
 from private_gpt.constants import PROJECT_ROOT_PATH
@@ -55,6 +55,27 @@ class Source(BaseModel):
         return curated_sources
 
 
+def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
+    full_response: str = ""
+    stream = completion_gen.response
+    for delta in stream:
+        # if isinstance(delta, str):
+        full_response += str(delta)
+        # elif isinstance(delta, ChatResponse):
+        #     full_response += delta.delta or ""
+        yield full_response
+
+    if completion_gen.sources:
+        full_response += SOURCES_SEPARATOR
+        cur_sources = Source.curate_sources(completion_gen.sources)
+        sources_text = "\n\n\n".join(
+            f"{index}. {source.file} (page {source.page})"
+            for index, source in enumerate(cur_sources, start=1)
+        )
+        full_response += sources_text
+        yield full_response
+
+
 @singleton
 class PrivateGptUi:
     @inject
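The hunk above lifts yield_deltas out of PrivateGptUi._chat (where it was a closure, removed in the next hunk) to module level. Below is a minimal, self-contained sketch of how the extracted function behaves; FakeSource, FakeCompletionGen, and the SOURCES_SEPARATOR value are simplified stand-ins for the real private_gpt types, which are defined elsewhere in the repo and not shown in this diff.

from collections.abc import Iterable
from dataclasses import dataclass, field

# Stand-in for the real SOURCES_SEPARATOR constant defined in ui.py.
SOURCES_SEPARATOR = "\n\nSources:\n"

@dataclass
class FakeSource:
    file: str
    page: str

@dataclass
class FakeCompletionGen:
    response: Iterable[str]                        # the streamed deltas
    sources: list[FakeSource] = field(default_factory=list)

def yield_deltas(completion_gen: FakeCompletionGen) -> Iterable[str]:
    # Same shape as the committed function: re-yield the growing response
    # so a Gradio generator can stream it, then append curated sources once
    # the stream is exhausted.
    full_response = ""
    for delta in completion_gen.response:
        full_response += str(delta)
        yield full_response
    if completion_gen.sources:
        full_response += SOURCES_SEPARATOR
        full_response += "\n\n\n".join(
            f"{i}. {s.file} (page {s.page})"
            for i, s in enumerate(completion_gen.sources, start=1)
        )
        yield full_response

gen = FakeCompletionGen(
    response=iter(["Hel", "lo"]),
    sources=[FakeSource(file="doc.pdf", page="3")],
)
for partial in yield_deltas(gen):
    print(repr(partial))  # "Hel", "Hello", then "Hello" + separator + sources

Each yielded value is the full response accumulated so far, which is what a Gradio streaming callback expects to receive on every step.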
@@ -72,26 +93,6 @@ class PrivateGptUi:
         self._ui_block = None
 
     def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
-        def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
-            full_response: str = ""
-            stream = completion_gen.response
-            for delta in stream:
-                if isinstance(delta, str):
-                    full_response += str(delta)
-                elif isinstance(delta, ChatResponse):
-                    full_response += delta.delta or ""
-                yield full_response
-
-            if completion_gen.sources:
-                full_response += SOURCES_SEPARATOR
-                cur_sources = Source.curate_sources(completion_gen.sources)
-                sources_text = "\n\n\n".join(
-                    f"{index}. {source.file} (page {source.page})"
-                    for index, source in enumerate(cur_sources, start=1)
-                )
-                full_response += sources_text
-                yield full_response
-
         def build_history() -> list[ChatMessage]:
             history_messages: list[ChatMessage] = list(
                 itertools.chain(