diff --git a/private_gpt/components/embedding/embedding_component.py b/private_gpt/components/embedding/embedding_component.py
index 714bb6b..89a577b 100644
--- a/private_gpt/components/embedding/embedding_component.py
+++ b/private_gpt/components/embedding/embedding_component.py
@@ -92,21 +92,24 @@ class EmbeddingComponent:
                 )
 
                 if ollama_settings.autopull_models:
-                    try:
+                    if ollama_settings.autopull_models:
+                        from private_gpt.utils.ollama import (
+                            check_connection,
+                            pull_model,
+                        )
+
                         # TODO: Reuse llama-index client when llama-index is updated
                         client = Client(
                             host=ollama_settings.embedding_api_base,
                             timeout=ollama_settings.request_timeout,
                         )
-                        installed_models = [
-                            model["name"] for model in client.list().get("models", {})
-                        ]
-                        if model_name not in installed_models:
-                            logger.info(f"Pulling model {model_name}. Please wait...")
-                            client.pull(model_name)
-                            logger.info(f"Model {model_name} pulled successfully")
-                    except Exception as e:
-                        logger.error(f"Failed to pull model {model_name}: {e!s}")
+
+                        if not check_connection(client):
+                            raise ValueError(
+                                f"Failed to connect to Ollama, "
+                                f"check if Ollama server is running on {ollama_settings.embedding_api_base}"
+                            )
+                        pull_model(client, model_name)
 
             case "azopenai":
                 try:
diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index 0107b58..e3a0281 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -163,17 +163,14 @@ class LLMComponent:
                 )
 
                 if ollama_settings.autopull_models:
-                    try:
-                        installed_models = [
-                            model["name"]
-                            for model in llm.client.list().get("models", {})
-                        ]
-                        if model_name not in installed_models:
-                            logger.info(f"Pulling model {model_name}. Please wait...")
-                            llm.client.pull(model_name)
-                            logger.info(f"Model {model_name} pulled successfully")
-                    except Exception as e:
-                        logger.error(f"Failed to pull model {model_name}: {e!s}")
+                    from private_gpt.utils.ollama import check_connection, pull_model
+
+                    if not check_connection(llm.client):
+                        raise ValueError(
+                            f"Failed to connect to Ollama, "
+                            f"check if Ollama server is running on {ollama_settings.api_base}"
+                        )
+                    pull_model(llm.client, model_name)
 
                 if (
                     ollama_settings.keep_alive
diff --git a/private_gpt/utils/ollama.py b/private_gpt/utils/ollama.py
new file mode 100644
index 0000000..41c7ecc
--- /dev/null
+++ b/private_gpt/utils/ollama.py
@@ -0,0 +1,32 @@
+import logging
+
+try:
+    from ollama import Client  # type: ignore
+except ImportError as e:
+    raise ImportError(
+        "Ollama dependencies not found, install with `poetry install --extras llms-ollama or embeddings-ollama`"
+    ) from e
+
+logger = logging.getLogger(__name__)
+
+
+def check_connection(client: Client) -> bool:
+    try:
+        client.list()
+        return True
+    except Exception as e:
+        logger.error(f"Failed to connect to Ollama: {e!s}")
+        return False
+
+
+def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None:
+    try:
+        installed_models = [model["name"] for model in client.list().get("models", {})]
+        if model_name not in installed_models:
+            logger.info(f"Pulling model {model_name}. Please wait...")
+            client.pull(model_name)
+            logger.info(f"Model {model_name} pulled successfully")
+    except Exception as e:
+        logger.error(f"Failed to pull model {model_name}: {e!s}")
+        if raise_error:
+            raise e
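
Usage note (illustrative, not part of the patch): the sketch below shows how the new helpers in private_gpt/utils/ollama.py are meant to be driven, assuming the `ollama` Python client is installed and a server is reachable; the host and model name are example values only.

    # Hypothetical standalone driver for the new helpers; host/model are examples.
    from ollama import Client

    from private_gpt.utils.ollama import check_connection, pull_model

    client = Client(host="http://localhost:11434", timeout=120)

    # check_connection() logs and returns False instead of raising, so the
    # caller chooses the failure mode (here: a hard error, as the components do).
    if not check_connection(client):
        raise ValueError("Failed to connect to Ollama, check if the server is running")

    # pull_model() is a no-op when the model is already installed; with the
    # default raise_error=True, a failed pull re-raises after logging.
    pull_model(client, "llama3.1")

The design choice worth noting is that connection failures now surface as a hard ValueError at startup instead of being swallowed by a catch-all log, and a failed pull propagates by default (raise_error=True), so misconfiguration is visible immediately rather than at first query.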