diff --git a/private_gpt/components/embedding/embedding_component.py b/private_gpt/components/embedding/embedding_component.py index d8e4589..714bb6b 100644 --- a/private_gpt/components/embedding/embedding_component.py +++ b/private_gpt/components/embedding/embedding_component.py @@ -68,10 +68,10 @@ class EmbeddingComponent: ) case "ollama": try: - from ollama import Client, AsyncClient # type: ignore from llama_index.embeddings.ollama import ( # type: ignore OllamaEmbedding, ) + from ollama import Client # type: ignore except ImportError as e: raise ImportError( "Local dependencies not found, install with `poetry install --extras embeddings-ollama`" @@ -79,9 +79,9 @@ class EmbeddingComponent: ollama_settings = settings.ollama - # calculate embedding model. If not provided tag, it will be use latest model + # Calculate embedding model. If not provided tag, it will be use latest model_name = ( - ollama_settings.embedding_model + ':latest' + ollama_settings.embedding_model + ":latest" if ":" not in ollama_settings.embedding_model else ollama_settings.embedding_model ) @@ -98,11 +98,11 @@ class EmbeddingComponent: host=ollama_settings.embedding_api_base, timeout=ollama_settings.request_timeout, ) - installed_models = [model['name'] for model in client.list().get("models", {})] + installed_models = [ + model["name"] for model in client.list().get("models", {}) + ] if model_name not in installed_models: - logger.info( - f"Pulling model {model_name}. Please wait..." - ) + logger.info(f"Pulling model {model_name}. Please wait...") client.pull(model_name) logger.info(f"Model {model_name} pulled successfully") except Exception as e: diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py index edf8af7..0107b58 100644 --- a/private_gpt/components/llm/llm_component.py +++ b/private_gpt/components/llm/llm_component.py @@ -146,9 +146,9 @@ class LLMComponent: "repeat_penalty": ollama_settings.repeat_penalty, # ollama llama-cpp } - # calculate llm model. If not provided tag, it will be use latest model + # calculate llm model. If not provided tag, it will be use latest model_name = ( - ollama_settings.llm_model + ':latest' + ollama_settings.llm_model + ":latest" if ":" not in ollama_settings.llm_model else ollama_settings.llm_model ) @@ -164,11 +164,12 @@ class LLMComponent: if ollama_settings.autopull_models: try: - installed_models = [model['name'] for model in llm.client.list().get("models", {})] + installed_models = [ + model["name"] + for model in llm.client.list().get("models", {}) + ] if model_name not in installed_models: - logger.info( - f"Pulling model {model_name}. Please wait..." - ) + logger.info(f"Pulling model {model_name}. Please wait...") llm.client.pull(model_name) logger.info(f"Model {model_name} pulled successfully") except Exception as e: