Merge remote-tracking branch 'origin/main' into llama3

Javier Martinez 2024-07-29 15:47:21 +02:00
commit a19e991fa1
35 changed files with 1569 additions and 221 deletions

@@ -35,10 +35,10 @@ class LLMComponent:
                 )
             except Exception as e:
                 logger.warning(
-                    "Failed to download tokenizer %s. Falling back to "
-                    "default tokenizer.",
-                    settings.llm.tokenizer,
-                    e,
+                    f"Failed to download tokenizer {settings.llm.tokenizer}: {e!s}. "
+                    f"Please follow the instructions in the documentation to download it if needed: "
+                    f"https://docs.privategpt.dev/installation/getting-started/troubleshooting#tokenizer-setup. "
+                    f"Falling back to default tokenizer."
                 )
         logger.info("Initializing the LLM in mode=%s", llm_mode)
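
Note on the new warning message: Python concatenates adjacent string literals at compile time, so the four f-strings above become a single message and each piece needs its own trailing separator. A minimal, self-contained sketch of the same pattern (the tokenizer name and error value are hypothetical):

    import logging

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger(__name__)

    tokenizer = "meta-llama/Meta-Llama-3-8B-Instruct"  # hypothetical tokenizer id
    e = OSError("401 Client Error")  # hypothetical download failure

    # Adjacent f-strings fuse into one string before the call, so the
    # trailing ". " on each piece keeps the sentences readable.
    logger.warning(
        f"Failed to download tokenizer {tokenizer}: {e!s}. "
        f"Falling back to default tokenizer."
    )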
@@ -146,8 +146,15 @@ class LLMComponent:
                 "repeat_penalty": ollama_settings.repeat_penalty,  # ollama llama-cpp
             }
-            self.llm = Ollama(
-                model=ollama_settings.llm_model,
+            # Calculate the llm model name. If no tag is provided, default to ":latest".
+            model_name = (
+                ollama_settings.llm_model + ":latest"
+                if ":" not in ollama_settings.llm_model
+                else ollama_settings.llm_model
+            )
+
+            llm = Ollama(
+                model=model_name,
                 base_url=ollama_settings.api_base,
                 temperature=settings.llm.temperature,
                 context_window=settings.llm.context_window,
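
Note on the model-name calculation: the conditional expression appends ":latest" only when the configured model carries no explicit tag. A standalone sketch of the same logic (the helper name and example models are hypothetical):

    def with_default_tag(llm_model: str) -> str:
        # Same logic as the diff: append ":latest" only when no tag is present.
        return llm_model + ":latest" if ":" not in llm_model else llm_model

    assert with_default_tag("llama3") == "llama3:latest"  # no tag -> ":latest"
    assert with_default_tag("llama3:8b") == "llama3:8b"  # explicit tag kept as-is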
@@ -155,6 +162,16 @@ class LLMComponent:
                 request_timeout=ollama_settings.request_timeout,
             )
+
+            if ollama_settings.autopull_models:
+                from private_gpt.utils.ollama import check_connection, pull_model
+
+                if not check_connection(llm.client):
+                    raise ValueError(
+                        f"Failed to connect to Ollama, "
+                        f"check if Ollama server is running on {ollama_settings.api_base}"
+                    )
+                pull_model(llm.client, model_name)
+
             if (
                 ollama_settings.keep_alive
                 != ollama_settings.model_fields["keep_alive"].default
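
Note on the autopull block: check_connection and pull_model live in private_gpt.utils.ollama and their bodies are not part of this diff. A plausible sketch of their shape, assuming the official ollama Python client (client.list() and client.pull() are real client methods; the bodies below are an assumption, not the repository's implementation):

    from ollama import Client

    def check_connection(client: Client) -> bool:
        # Assumed shape: a cheap API call that fails fast when the server
        # configured at api_base is unreachable.
        try:
            client.list()
            return True
        except Exception:
            return False

    def pull_model(client: Client, model_name: str) -> None:
        # Assumed shape: ask the Ollama server to download the model;
        # effectively a no-op when the tag is already present locally.
        client.pull(model_name)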
@@ -172,6 +189,8 @@ class LLMComponent:
                     Ollama.complete = add_keep_alive(Ollama.complete)
                     Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)
+
+            self.llm = llm
             case "azopenai":
                 try:
                     from llama_index.llms.azure_openai import (  # type: ignore
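
Note on the keep_alive patch: add_keep_alive is defined earlier in this method and is not shown in the hunk. A plausible reconstruction of such a wrapper (the default value and exact body are assumptions):

    from typing import Any, Callable

    def add_keep_alive(func: Callable[..., Any], keep_alive: str = "5m") -> Callable[..., Any]:
        # Hypothetical shape: inject a keep_alive kwarg into every completion
        # call so Ollama keeps the model loaded between requests.
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            kwargs.setdefault("keep_alive", keep_alive)
            return func(*args, **kwargs)
        return wrapper

    # Usage mirroring the diff (monkey-patching the class methods):
    # Ollama.complete = add_keep_alive(Ollama.complete)
    # Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)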
@@ -190,5 +209,18 @@ class LLMComponent:
                     azure_endpoint=azopenai_settings.azure_endpoint,
                     api_version=azopenai_settings.api_version,
                 )
+            case "gemini":
+                try:
+                    from llama_index.llms.gemini import (  # type: ignore
+                        Gemini,
+                    )
+                except ImportError as e:
+                    raise ImportError(
+                        "Google Gemini dependencies not found, install with `poetry install --extras llms-gemini`"
+                    ) from e
+                gemini_settings = settings.gemini
+                self.llm = Gemini(
+                    model_name=gemini_settings.model, api_key=gemini_settings.api_key
+                )
             case "mock":
                 self.llm = MockLLM()
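
Note on the new gemini case: once the extra is installed, the backend can be exercised directly through llama-index. A minimal sketch (the model identifier and key are placeholders; model_name is the same keyword the diff passes):

    from llama_index.llms.gemini import Gemini

    # Placeholders: substitute a real Google AI Studio key and model id.
    llm = Gemini(model_name="models/gemini-1.5-flash", api_key="YOUR_API_KEY")
    print(llm.complete("Say hello in one short sentence.").text)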