feat(llm): autopull ollama models (#2019)

* chore: update ollama (llm)

* feat: allow autopulling ollama models

* fix: mypy

* chore: always install ollama client

* refactor: move ollama connection check and pull methods to utils

* docs: update ollama config with autopulling info
Javier Martinez 2024-07-29 13:25:42 +02:00 committed by GitHub
parent dabf556dae
commit 20bad17c98
8 changed files with 129 additions and 21 deletions
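
Taken together, the changes wire the same autopull flow into both the LLM and the embedding components. Below is a minimal sketch of that flow using only names introduced by this commit; the host URL and model tag are illustrative stand-ins for the configured settings:

from ollama import Client

from private_gpt.utils.ollama import check_connection, pull_model

# Illustrative values; the components read these from settings.ollama.
client = Client(host="http://localhost:11434", timeout=120.0)
if not check_connection(client):
    raise ValueError("Failed to connect to Ollama, check if the server is running")
pull_model(client, "mistral:latest")  # no-op if the model is already installed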

private_gpt/components/embedding/embedding_component.py

@@ -71,16 +71,46 @@ class EmbeddingComponent:
                 from llama_index.embeddings.ollama import (  # type: ignore
                     OllamaEmbedding,
                 )
+                from ollama import Client  # type: ignore
             except ImportError as e:
                 raise ImportError(
                     "Local dependencies not found, install with `poetry install --extras embeddings-ollama`"
                 ) from e

             ollama_settings = settings.ollama
+
+            # Calculate the embedding model name. If no tag is provided, ":latest" is used.
+            model_name = (
+                ollama_settings.embedding_model + ":latest"
+                if ":" not in ollama_settings.embedding_model
+                else ollama_settings.embedding_model
+            )
+
             self.embedding_model = OllamaEmbedding(
-                model_name=ollama_settings.embedding_model,
+                model_name=model_name,
                 base_url=ollama_settings.embedding_api_base,
             )
+
+            if ollama_settings.autopull_models:
+                from private_gpt.utils.ollama import check_connection, pull_model
+
+                # TODO: Reuse the llama-index client when llama-index is updated
+                client = Client(
+                    host=ollama_settings.embedding_api_base,
+                    timeout=ollama_settings.request_timeout,
+                )
+                if not check_connection(client):
+                    raise ValueError(
+                        f"Failed to connect to Ollama, "
+                        f"check if the Ollama server is running on {ollama_settings.embedding_api_base}"
+                    )
+                pull_model(client, model_name)
         case "azopenai":
             try:
                 from llama_index.embeddings.azure_openai import (  # type: ignore
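
Both components normalize the model name the same way before pulling. A quick illustration of that expression (normalize_tag is a hypothetical helper name; the components inline the expression):

# Hypothetical helper mirroring the inlined expression above:
def normalize_tag(model: str) -> str:
    return model + ":latest" if ":" not in model else model

assert normalize_tag("nomic-embed-text") == "nomic-embed-text:latest"
assert normalize_tag("mistral:7b") == "mistral:7b"  # explicit tags pass through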

private_gpt/components/llm/llm_component.py

@@ -146,8 +146,15 @@ class LLMComponent:
                 "repeat_penalty": ollama_settings.repeat_penalty,  # ollama llama-cpp
             }

-            self.llm = Ollama(
-                model=ollama_settings.llm_model,
+            # Calculate the LLM model name. If no tag is provided, ":latest" is used.
+            model_name = (
+                ollama_settings.llm_model + ":latest"
+                if ":" not in ollama_settings.llm_model
+                else ollama_settings.llm_model
+            )
+
+            llm = Ollama(
+                model=model_name,
                 base_url=ollama_settings.api_base,
                 temperature=settings.llm.temperature,
                 context_window=settings.llm.context_window,
@@ -155,6 +162,16 @@ class LLMComponent:
                 request_timeout=ollama_settings.request_timeout,
             )

+            if ollama_settings.autopull_models:
+                from private_gpt.utils.ollama import check_connection, pull_model
+
+                if not check_connection(llm.client):
+                    raise ValueError(
+                        f"Failed to connect to Ollama, "
+                        f"check if the Ollama server is running on {ollama_settings.api_base}"
+                    )
+                pull_model(llm.client, model_name)
+
             if (
                 ollama_settings.keep_alive
                 != ollama_settings.model_fields["keep_alive"].default
@@ -172,6 +189,8 @@ class LLMComponent:
                 Ollama.complete = add_keep_alive(Ollama.complete)
                 Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)

+            self.llm = llm
         case "azopenai":
             try:
                 from llama_index.llms.azure_openai import (  # type: ignore
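
For context, the keep_alive handling that the autopull block now precedes monkey-patches Ollama.complete and Ollama.stream_complete. The body of add_keep_alive is not part of this diff; here is a minimal self-contained sketch, assuming it injects the configured value into each call's kwargs:

from collections.abc import Callable
from typing import Any

keep_alive = "10m"  # stand-in for ollama_settings.keep_alive

def add_keep_alive(func: Callable[..., Any]) -> Callable[..., Any]:
    # Inject the configured keep_alive into every completion call.
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        kwargs["keep_alive"] = keep_alive
        return func(*args, **kwargs)
    return wrapper

@add_keep_alive
def complete(prompt: str, **kwargs: Any) -> str:  # stand-in for Ollama.complete
    return f"{prompt} (keep_alive={kwargs['keep_alive']})"

print(complete("hello"))  # hello (keep_alive=10m)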

private_gpt/settings/settings.py

@@ -290,6 +290,10 @@ class OllamaSettings(BaseModel):
         120.0,
         description="Time elapsed until ollama times out the request. Default is 120s. Format is float. ",
     )
+    autopull_models: bool = Field(
+        False,
+        description="If set to True, Ollama will automatically pull the requested models from the API base if they are not already installed.",
+    )


 class AzureOpenAISettings(BaseModel):
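
The LLM component's model_fields["keep_alive"].default comparison works because Pydantic v2 exposes each field's declared default through model_fields. A standalone sketch, with a hypothetical ExampleSettings mirroring the OllamaSettings pattern:

from pydantic import BaseModel, Field

class ExampleSettings(BaseModel):  # hypothetical, mirrors OllamaSettings
    keep_alive: str = Field("5m", description="Keep-alive for the Ollama model.")
    autopull_models: bool = Field(False, description="Autopull missing models.")

settings = ExampleSettings(keep_alive="10m")
# True only when the user overrides the default, which is what gates the patch:
print(settings.keep_alive != ExampleSettings.model_fields["keep_alive"].default)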

private_gpt/utils/ollama.py (new file)

@@ -0,0 +1,32 @@
+import logging
+
+try:
+    from ollama import Client  # type: ignore
+except ImportError as e:
+    raise ImportError(
+        "Ollama dependencies not found, install with `poetry install --extras llms-ollama or embeddings-ollama`"
+    ) from e
+
+logger = logging.getLogger(__name__)
+
+
+def check_connection(client: Client) -> bool:
+    try:
+        client.list()
+        return True
+    except Exception as e:
+        logger.error(f"Failed to connect to Ollama: {e!s}")
+        return False
+
+
+def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None:
+    try:
+        installed_models = [model["name"] for model in client.list().get("models", [])]
+        if model_name not in installed_models:
+            logger.info(f"Pulling model {model_name}. Please wait...")
+            client.pull(model_name)
+            logger.info(f"Model {model_name} pulled successfully")
+    except Exception as e:
+        logger.error(f"Failed to pull model {model_name}: {e!s}")
+        if raise_error:
+            raise e
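
One reason the callers normalize tags before calling pull_model: client.list() reports fully tagged names, so an untagged name would never match installed_models and would trigger a pull on every startup. A small illustration, assuming the usual response shape:

# Assumed shape of the names returned by client.list():
installed_models = ["mistral:latest", "nomic-embed-text:latest"]
print("mistral" in installed_models)         # False -> pull would run
print("mistral:latest" in installed_models)  # True  -> pull is skipped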