refactor: move Ollama connection-check and model-pull helpers to utils

This commit is contained in:
Javier Martinez 2024-07-29 11:40:15 +02:00
parent 47f4524fe9
commit 088158e139
No known key found for this signature in database
3 changed files with 53 additions and 21 deletions

View file

@ -92,21 +92,24 @@ class EmbeddingComponent:
)
if ollama_settings.autopull_models:
try:
if ollama_settings.autopull_models:
from private_gpt.utils.ollama import (
check_connection,
pull_model,
)
# TODO: Reuse llama-index client when llama-index is updated
client = Client(
host=ollama_settings.embedding_api_base,
timeout=ollama_settings.request_timeout,
)
installed_models = [
model["name"] for model in client.list().get("models", {})
]
if model_name not in installed_models:
logger.info(f"Pulling model {model_name}. Please wait...")
client.pull(model_name)
logger.info(f"Model {model_name} pulled successfully")
except Exception as e:
logger.error(f"Failed to pull model {model_name}: {e!s}")
if not check_connection(client):
raise ValueError(
f"Failed to connect to Ollama, "
f"check if Ollama server is running on {ollama_settings.api_base}"
)
pull_model(client, model_name)
case "azopenai":
try:

View file

@ -163,17 +163,14 @@ class LLMComponent:
)
if ollama_settings.autopull_models:
try:
installed_models = [
model["name"]
for model in llm.client.list().get("models", {})
]
if model_name not in installed_models:
logger.info(f"Pulling model {model_name}. Please wait...")
llm.client.pull(model_name)
logger.info(f"Model {model_name} pulled successfully")
except Exception as e:
logger.error(f"Failed to pull model {model_name}: {e!s}")
from private_gpt.utils.ollama import check_connection, pull_model
if not check_connection(llm.client):
raise ValueError(
f"Failed to connect to Ollama, "
f"check if Ollama server is running on {ollama_settings.api_base}"
)
pull_model(llm.client, model_name)
if (
ollama_settings.keep_alive

View file

@ -0,0 +1,32 @@
# Shared Ollama helpers: server connection check and model auto-pull.
import logging

# Ollama is an optional extra; fail fast at import time with install guidance
# instead of raising a confusing NameError later.
try:
    from ollama import Client  # type: ignore
except ImportError as e:
    raise ImportError(
        "Ollama dependencies not found, install with `poetry install --extras llms-ollama or embeddings-ollama`"
    ) from e

logger = logging.getLogger(__name__)
def check_connection(client: Client) -> bool:
    """Return True if the Ollama server is reachable through *client*.

    Performs a cheap `list()` call as a health probe; any failure is
    logged and reported as False rather than raised.
    """
    try:
        # Listing installed models exercises the full client -> server
        # round trip without side effects.
        client.list()
    except Exception as e:
        logger.error(f"Failed to connect to Ollama: {e!s}")
        return False
    return True
def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None:
    """Ensure *model_name* is available locally, pulling it if missing.

    Args:
        client: connected Ollama client.
        model_name: model to ensure is installed (e.g. "llama3").
        raise_error: when True, re-raise any failure after logging it;
            when False, the error is only logged.

    Raises:
        Exception: whatever the client raised, if `raise_error` is True.
    """
    try:
        # `client.list()` returns {"models": [...]}; default to an empty
        # *list* (not a dict) so the fallback matches the expected type.
        installed_models = [model["name"] for model in client.list().get("models", [])]
        if model_name not in installed_models:
            logger.info(f"Pulling model {model_name}. Please wait...")
            client.pull(model_name)
            logger.info(f"Model {model_name} pulled successfully")
    except Exception as e:
        logger.error(f"Failed to pull model {model_name}: {e!s}")
        if raise_error:
            # Bare raise preserves the original exception and traceback.
            raise