From 0a9c57447b3a201629f140eb506216724acd4e14 Mon Sep 17 00:00:00 2001
From: Javier Martinez
Date: Mon, 29 Jul 2024 09:25:42 +0200
Subject: [PATCH] feat: allow autopulling Ollama models

---
 .../embedding/embedding_component.py         | 29 ++++++++++++++++++-
 private_gpt/components/llm/llm_component.py  | 25 ++++++++++++++--
 private_gpt/settings/settings.py             |  4 +++
 settings.yaml                                |  1 +
 4 files changed, 56 insertions(+), 3 deletions(-)

diff --git a/private_gpt/components/embedding/embedding_component.py b/private_gpt/components/embedding/embedding_component.py
index 29ef1cf..d8e4589 100644
--- a/private_gpt/components/embedding/embedding_component.py
+++ b/private_gpt/components/embedding/embedding_component.py
@@ -68,6 +68,7 @@ class EmbeddingComponent:
                 )
             case "ollama":
                 try:
+                    from ollama import Client  # type: ignore
                     from llama_index.embeddings.ollama import (  # type: ignore
                         OllamaEmbedding,
                     )
@@ -77,10 +78,36 @@ class EmbeddingComponent:
                     ) from e

                 ollama_settings = settings.ollama
+
+                # Calculate the embedding model name; if no tag is given, use ":latest"
+                model_name = (
+                    ollama_settings.embedding_model + ":latest"
+                    if ":" not in ollama_settings.embedding_model
+                    else ollama_settings.embedding_model
+                )
+
                 self.embedding_model = OllamaEmbedding(
-                    model_name=ollama_settings.embedding_model,
+                    model_name=model_name,
                     base_url=ollama_settings.embedding_api_base,
                 )
+
+                if ollama_settings.autopull_models:
+                    try:
+                        # TODO: Reuse the llama-index client when llama-index is updated
+                        client = Client(
+                            host=ollama_settings.embedding_api_base,
+                            timeout=ollama_settings.request_timeout,
+                        )
+                        installed_models = [model["name"] for model in client.list().get("models", [])]
+                        if model_name not in installed_models:
+                            logger.info(
+                                f"Pulling model {model_name}. Please wait..."
+                            )
+                            client.pull(model_name)
+                            logger.info(f"Model {model_name} pulled successfully")
+                    except Exception as e:
+                        logger.error(f"Failed to pull model {model_name}: {e!s}")
+
             case "azopenai":
                 try:
                     from llama_index.embeddings.azure_openai import (  # type: ignore

diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index d4ab81f..edf8af7 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -146,8 +146,15 @@ class LLMComponent:
                     "repeat_penalty": ollama_settings.repeat_penalty,  # ollama llama-cpp
                 }

-                self.llm = Ollama(
-                    model=ollama_settings.llm_model,
+                # Calculate the LLM model name; if no tag is given, use ":latest"
+                model_name = (
+                    ollama_settings.llm_model + ":latest"
+                    if ":" not in ollama_settings.llm_model
+                    else ollama_settings.llm_model
+                )
+
+                llm = Ollama(
+                    model=model_name,
                     base_url=ollama_settings.api_base,
                     temperature=settings.llm.temperature,
                     context_window=settings.llm.context_window,
@@ -155,6 +162,18 @@ class LLMComponent:
                     request_timeout=ollama_settings.request_timeout,
                 )

+                if ollama_settings.autopull_models:
+                    try:
+                        installed_models = [model["name"] for model in llm.client.list().get("models", [])]
+                        if model_name not in installed_models:
+                            logger.info(
+                                f"Pulling model {model_name}. Please wait..."
+                            )
+                            llm.client.pull(model_name)
+                            logger.info(f"Model {model_name} pulled successfully")
+                    except Exception as e:
+                        logger.error(f"Failed to pull model {model_name}: {e!s}")
+
                 if (
                     ollama_settings.keep_alive
                     != ollama_settings.model_fields["keep_alive"].default
@@ -172,6 +191,8 @@ class LLMComponent:
                 Ollama.complete = add_keep_alive(Ollama.complete)
                 Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)

+                self.llm = llm
+
             case "azopenai":
                 try:
                     from llama_index.llms.azure_openai import (  # type: ignore

diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py
index 30514dd..40b96ae 100644
--- a/private_gpt/settings/settings.py
+++ b/private_gpt/settings/settings.py
@@ -290,6 +290,10 @@ class OllamaSettings(BaseModel):
         120.0,
         description="Time elapsed until ollama times out the request. Default is 120s. Format is float. ",
     )
+    autopull_models: bool = Field(
+        False,
+        description="If set to True, Ollama will automatically pull models from the API base.",
+    )


 class AzureOpenAISettings(BaseModel):

diff --git a/settings.yaml b/settings.yaml
index cd8fccd..cd977a0 100644
--- a/settings.yaml
+++ b/settings.yaml
@@ -117,6 +117,7 @@ ollama:
   embedding_api_base: http://localhost:11434  # change if your embedding model runs on another ollama
   keep_alive: 5m
   request_timeout: 120.0
+  autopull_models: true

 azopenai:
   api_key: ${AZ_OPENAI_API_KEY:}
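
Below is a minimal standalone sketch of the pull-if-missing flow this patch wires into both components, handy for trying the behavior against a local Ollama server outside PrivateGPT. It assumes the `ollama` Python package with dict-style responses (as the patched code relies on) and a server at the default API base; the helper name `ensure_model_pulled` is illustrative and not part of the patch:

    from ollama import Client

    def ensure_model_pulled(client: Client, model_name: str) -> None:
        # Ollama registers untagged models under ":latest", so normalize the
        # name the same way the patch does before comparing.
        if ":" not in model_name:
            model_name += ":latest"
        # list() returns {"models": [...]}; each entry's "name" includes the tag.
        installed = [model["name"] for model in client.list().get("models", [])]
        if model_name not in installed:
            client.pull(model_name)  # blocks until the download completes

    client = Client(host="http://localhost:11434", timeout=120.0)
    ensure_model_pulled(client, "nomic-embed-text")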