Mirror of https://github.com/zylon-ai/private-gpt.git (synced 2025-12-22 20:12:55 +01:00)
refactor: check connection and pull ollama method to utils
parent 47f4524fe9
commit 088158e139
3 changed files with 53 additions and 21 deletions
@@ -92,21 +92,24 @@ class EmbeddingComponent:
                 )

                 if ollama_settings.autopull_models:
-                    try:
+                    if ollama_settings.autopull_models:
+                        from private_gpt.utils.ollama import (
+                            check_connection,
+                            pull_model,
+                        )
+
                         # TODO: Reuse llama-index client when llama-index is updated
                         client = Client(
                             host=ollama_settings.embedding_api_base,
                             timeout=ollama_settings.request_timeout,
                         )
-                        installed_models = [
-                            model["name"] for model in client.list().get("models", {})
-                        ]
-                        if model_name not in installed_models:
-                            logger.info(f"Pulling model {model_name}. Please wait...")
-                            client.pull(model_name)
-                            logger.info(f"Model {model_name} pulled successfully")
-                    except Exception as e:
-                        logger.error(f"Failed to pull model {model_name}: {e!s}")
+
+                        if not check_connection(client):
+                            raise ValueError(
+                                f"Failed to connect to Ollama, "
+                                f"check if Ollama server is running on {ollama_settings.api_base}"
+                            )
+                        pull_model(client, model_name)

             case "azopenai":
                 try:
@@ -163,17 +163,14 @@ class LLMComponent:
                 )

                 if ollama_settings.autopull_models:
-                    try:
-                        installed_models = [
-                            model["name"]
-                            for model in llm.client.list().get("models", {})
-                        ]
-                        if model_name not in installed_models:
-                            logger.info(f"Pulling model {model_name}. Please wait...")
-                            llm.client.pull(model_name)
-                            logger.info(f"Model {model_name} pulled successfully")
-                    except Exception as e:
-                        logger.error(f"Failed to pull model {model_name}: {e!s}")
+                    from private_gpt.utils.ollama import check_connection, pull_model
+
+                    if not check_connection(llm.client):
+                        raise ValueError(
+                            f"Failed to connect to Ollama, "
+                            f"check if Ollama server is running on {ollama_settings.api_base}"
+                        )
+                    pull_model(llm.client, model_name)

                 if (
                     ollama_settings.keep_alive
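Both call sites now reduce to the same connect-check-then-pull sequence provided by the new private_gpt/utils/ollama.py module below. A condensed standalone sketch of that sequence, with an example host, timeout, and model tag that are not part of this commit:

from ollama import Client

from private_gpt.utils.ollama import check_connection, pull_model

# Example values; the real components build the client from ollama settings.
client = Client(host="http://localhost:11434", timeout=120.0)

if not check_connection(client):
    raise ValueError("Failed to connect to Ollama, check if the server is running")

pull_model(client, "llama3:latest")  # pulls only when the tag is not installed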
private_gpt/utils/ollama.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+import logging
+
+try:
+    from ollama import Client  # type: ignore
+except ImportError as e:
+    raise ImportError(
+        "Ollama dependencies not found, install with `poetry install --extras llms-ollama or embeddings-ollama`"
+    ) from e
+
+logger = logging.getLogger(__name__)
+
+
+def check_connection(client: Client) -> bool:
+    try:
+        client.list()
+        return True
+    except Exception as e:
+        logger.error(f"Failed to connect to Ollama: {e!s}")
+        return False
+
+
+def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None:
+    try:
+        installed_models = [model["name"] for model in client.list().get("models", {})]
+        if model_name not in installed_models:
+            logger.info(f"Pulling model {model_name}. Please wait...")
+            client.pull(model_name)
+            logger.info(f"Model {model_name} pulled successfully")
+    except Exception as e:
+        logger.error(f"Failed to pull model {model_name}: {e!s}")
+        if raise_error:
+            raise e
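Note that pull_model re-raises pull failures by default (raise_error=True), whereas the inline code it replaces only logged them. A minimal sketch of opting back into that log-and-continue behavior, again with an example host and model tag rather than values from the commit:

from ollama import Client

from private_gpt.utils.ollama import check_connection, pull_model

client = Client(host="http://localhost:11434")  # example host

if check_connection(client):
    # With raise_error=False, failures are logged by pull_model but not raised.
    pull_model(client, "nomic-embed-text", raise_error=False)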