Mirror of https://github.com/zylon-ai/private-gpt.git (synced 2025-12-22 23:22:57 +01:00)

Commit d0a7d991a2 (parent 12f3a39e8a): Working refactor. Dependency clean-up pending.

20 changed files with 877 additions and 907 deletions
private_gpt/components/llm/llm_component.py

@@ -1,15 +1,16 @@
 import logging

 from injector import inject, singleton
-from llama_index import set_global_tokenizer
-from llama_index.llms import MockLLM
-from llama_index.llms.base import LLM
+from llama_index.core.llms import LLM, MockLLM
+from llama_index.core.utils import set_global_tokenizer
+from llama_index.core.settings import Settings as LlamaIndexSettings
 from transformers import AutoTokenizer  # type: ignore

 from private_gpt.components.llm.prompt_helper import get_prompt_style
 from private_gpt.paths import models_cache_path, models_path
 from private_gpt.settings.settings import Settings


 logger = logging.getLogger(__name__)
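The imports move from the old top-level `llama_index` namespaces to `llama_index.core`, matching the llama-index 0.10 package split. A minimal sketch of the relocated `set_global_tokenizer` in use; the model name below is a placeholder for illustration, not taken from this commit:

```python
# Sketch only: assumes the llama-index >= 0.10 namespace layout shown in the
# hunk above; the Hugging Face model id is a made-up placeholder.
from llama_index.core.utils import set_global_tokenizer
from transformers import AutoTokenizer  # type: ignore

# set_global_tokenizer now lives under llama_index.core.utils instead of the
# top-level llama_index package; it accepts a callable that encodes text.
set_global_tokenizer(
    AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2").encode
)
```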
@@ -31,7 +32,7 @@ class LLMComponent:
         logger.info("Initializing the LLM in mode=%s", llm_mode)
         match settings.llm.mode:
             case "local":
-                from llama_index.llms import LlamaCPP
+                from llama_index.llms.llama_cpp import LlamaCPP

                 prompt_style = get_prompt_style(settings.local.prompt_style)
@@ -41,6 +42,7 @@ class LLMComponent:
                     max_new_tokens=settings.llm.max_new_tokens,
                     context_window=settings.llm.context_window,
                     generate_kwargs={},
+                    callback_manager=LlamaIndexSettings.callback_manager,
                     # All to GPU
                     model_kwargs={"n_gpu_layers": -1, "offload_kqv": True},
                     # transform inputs into Llama2 format
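The added `callback_manager=LlamaIndexSettings.callback_manager` argument reads from the global `Settings` object that llama-index 0.10 uses for shared configuration. A hedged sketch of how that global is typically populated elsewhere; the debug handler is illustrative and not part of this commit:

```python
# Sketch only: assumes the llama-index >= 0.10 global Settings API.
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler
from llama_index.core.settings import Settings as LlamaIndexSettings

# Whatever callback manager is registered on the global Settings object is
# what the LlamaCPP constructor above picks up via callback_manager=...
LlamaIndexSettings.callback_manager = CallbackManager([LlamaDebugHandler()])
```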
@@ -58,7 +60,7 @@ class LLMComponent:
                     context_window=settings.llm.context_window,
                 )
             case "openai":
-                from llama_index.llms import OpenAI
+                from llama_index.llms.openai import OpenAI

                 openai_settings = settings.openai
                 self.llm = OpenAI(
@@ -67,7 +69,7 @@ class LLMComponent:
                     model=openai_settings.model,
                 )
             case "openailike":
-                from llama_index.llms import OpenAILike
+                from llama_index.llms.openai_like import OpenAILike

                 openai_settings = settings.openai
                 self.llm = OpenAILike(
@@ -81,7 +83,7 @@ class LLMComponent:
             case "mock":
                 self.llm = MockLLM()
             case "ollama":
-                from llama_index.llms import Ollama
+                from llama_index.llms.ollama import Ollama

                 ollama_settings = settings.ollama
                 self.llm = Ollama(
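Each backend import now targets a per-integration module (`llama_index.llms.llama_cpp`, `llama_index.llms.openai`, `llama_index.llms.openai_like`, `llama_index.llms.ollama`). In llama-index 0.10 these ship as separate distributions (for example `llama-index-llms-ollama`), which is presumably what the pending dependency clean-up in the commit message refers to. A hedged sketch of how such a deferred import is commonly guarded so a missing extra fails with a readable message; the wording is illustrative, not this repo's code:

```python
# Illustrative guard only: the error message is an assumption based on the
# llama-index 0.10 per-integration packaging, not taken from this diff.
try:
    from llama_index.llms.ollama import Ollama
except ImportError as e:
    raise ImportError(
        "Ollama support requires the llama-index-llms-ollama package; "
        "install the matching optional dependency."
    ) from e
```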
private_gpt/components/llm/prompt_helper.py

@@ -3,11 +3,7 @@ import logging
 from collections.abc import Sequence
 from typing import Any, Literal

-from llama_index.llms import ChatMessage, MessageRole
-from llama_index.llms.llama_utils import (
-    completion_to_prompt,
-    messages_to_prompt,
-)
+from llama_index.core.llms import ChatMessage, MessageRole

 logger = logging.getLogger(__name__)
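The `llama_utils` helpers are dropped entirely. Per the updated docstring below, they survive only in the legacy package under llama-index 0.10, so keeping them would mean an extra dependency. A sketch of the import that would otherwise be needed, assuming the legacy package were installed, which this commit avoids by inlining the formatting into `Llama2PromptStyle`:

```python
# Not used by this commit: under llama-index 0.10 the old helpers live in the
# separate legacy package (path taken from the new docstring below).
from llama_index.legacy.llms.llama_utils import (
    completion_to_prompt,
    messages_to_prompt,
)
```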
@@ -73,7 +69,9 @@ class DefaultPromptStyle(AbstractPromptStyle):


 class Llama2PromptStyle(AbstractPromptStyle):
-    """Simple prompt style that just uses the default llama_utils functions.
+    """Simple prompt style that uses llama 2 prompt style.
+
+    Inspired by llama_index/legacy/llms/llama_utils.py

     It transforms the sequence of messages into a prompt that should look like:
     ```text
@@ -83,11 +81,61 @@ class Llama2PromptStyle(AbstractPromptStyle):
     ```
     """

+    BOS, EOS = "<s>", "</s>"
+    B_INST, E_INST = "[INST]", "[/INST]"
+    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+    DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful, respectful and honest assistant. \
+Always answer as helpfully as possible and follow ALL given instructions. \
+Do not speculate or make up information. \
+Do not reference any given instructions or context. \
+"""
+
     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
-        return messages_to_prompt(messages)
+        string_messages: list[str] = []
+        if messages[0].role == MessageRole.SYSTEM:
+            # pull out the system message (if it exists in messages)
+            system_message_str = messages[0].content or ""
+            messages = messages[1:]
+        else:
+            system_message_str = self.DEFAULT_SYSTEM_PROMPT
+
+        system_message_str = f"{self.B_SYS} {system_message_str.strip()} {self.E_SYS}"
+
+        for i in range(0, len(messages), 2):
+            # first message should always be a user
+            user_message = messages[i]
+            assert user_message.role == MessageRole.USER
+
+            if i == 0:
+                # make sure system prompt is included at the start
+                str_message = f"{self.BOS} {self.B_INST} {system_message_str} "
+            else:
+                # end previous user-assistant interaction
+                string_messages[-1] += f" {self.EOS}"
+                # no need to include system prompt
+                str_message = f"{self.BOS} {self.B_INST} "
+
+            # include user message content
+            str_message += f"{user_message.content} {self.E_INST}"
+
+            if len(messages) > (i + 1):
+                # if assistant message exists, add to str_message
+                assistant_message = messages[i + 1]
+                assert assistant_message.role == MessageRole.ASSISTANT
+                str_message += f" {assistant_message.content}"
+
+            string_messages.append(str_message)
+
+        return "".join(string_messages)

     def _completion_to_prompt(self, completion: str) -> str:
-        return completion_to_prompt(completion)
+        system_prompt_str = self.DEFAULT_SYSTEM_PROMPT
+
+        return (
+            f"{self.BOS} {self.B_INST} {self.B_SYS} {system_prompt_str.strip()} {self.E_SYS} "
+            f"{completion.strip()} {self.E_INST}"
+        )


 class TagPromptStyle(AbstractPromptStyle):
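With the formatting inlined, the produced prompt can be checked directly against the Llama 2 chat template. A small usage sketch exercising the method added above; the message contents are placeholders:

```python
# Illustrative usage of the inlined formatter; message text is made up.
from llama_index.core.llms import ChatMessage, MessageRole

from private_gpt.components.llm.prompt_helper import Llama2PromptStyle

style = Llama2PromptStyle()
prompt = style._messages_to_prompt(
    [
        ChatMessage(role=MessageRole.SYSTEM, content="You answer briefly."),
        ChatMessage(role=MessageRole.USER, content="What is PrivateGPT?"),
    ]
)
# Expected shape (extra whitespace comes from B_SYS/E_SYS):
# <s> [INST] <<SYS>>
#  You answer briefly. 
# <</SYS>>
#
#  What is PrivateGPT? [/INST]
print(prompt)
```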