fix: mypy

This commit is contained in:
Javier Martinez 2024-07-29 09:31:21 +02:00
parent 0a9c57447b
commit c096a42aa1
No known key found for this signature in database
2 changed files with 14 additions and 13 deletions

View file

@ -68,10 +68,10 @@ class EmbeddingComponent:
) )
case "ollama": case "ollama":
try: try:
from ollama import Client, AsyncClient # type: ignore
from llama_index.embeddings.ollama import ( # type: ignore from llama_index.embeddings.ollama import ( # type: ignore
OllamaEmbedding, OllamaEmbedding,
) )
from ollama import Client # type: ignore
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"Local dependencies not found, install with `poetry install --extras embeddings-ollama`" "Local dependencies not found, install with `poetry install --extras embeddings-ollama`"
@ -79,9 +79,9 @@ class EmbeddingComponent:
ollama_settings = settings.ollama ollama_settings = settings.ollama
# calculate embedding model. If not provided tag, it will be use latest model # Calculate embedding model. If not provided tag, it will be use latest
model_name = ( model_name = (
ollama_settings.embedding_model + ':latest' ollama_settings.embedding_model + ":latest"
if ":" not in ollama_settings.embedding_model if ":" not in ollama_settings.embedding_model
else ollama_settings.embedding_model else ollama_settings.embedding_model
) )
@ -98,11 +98,11 @@ class EmbeddingComponent:
host=ollama_settings.embedding_api_base, host=ollama_settings.embedding_api_base,
timeout=ollama_settings.request_timeout, timeout=ollama_settings.request_timeout,
) )
installed_models = [model['name'] for model in client.list().get("models", {})] installed_models = [
model["name"] for model in client.list().get("models", {})
]
if model_name not in installed_models: if model_name not in installed_models:
logger.info( logger.info(f"Pulling model {model_name}. Please wait...")
f"Pulling model {model_name}. Please wait..."
)
client.pull(model_name) client.pull(model_name)
logger.info(f"Model {model_name} pulled successfully") logger.info(f"Model {model_name} pulled successfully")
except Exception as e: except Exception as e:

View file

@ -146,9 +146,9 @@ class LLMComponent:
"repeat_penalty": ollama_settings.repeat_penalty, # ollama llama-cpp "repeat_penalty": ollama_settings.repeat_penalty, # ollama llama-cpp
} }
# calculate llm model. If not provided tag, it will be use latest model # calculate llm model. If not provided tag, it will be use latest
model_name = ( model_name = (
ollama_settings.llm_model + ':latest' ollama_settings.llm_model + ":latest"
if ":" not in ollama_settings.llm_model if ":" not in ollama_settings.llm_model
else ollama_settings.llm_model else ollama_settings.llm_model
) )
@ -164,11 +164,12 @@ class LLMComponent:
if ollama_settings.autopull_models: if ollama_settings.autopull_models:
try: try:
installed_models = [model['name'] for model in llm.client.list().get("models", {})] installed_models = [
model["name"]
for model in llm.client.list().get("models", {})
]
if model_name not in installed_models: if model_name not in installed_models:
logger.info( logger.info(f"Pulling model {model_name}. Please wait...")
f"Pulling model {model_name}. Please wait..."
)
llm.client.pull(model_name) llm.client.pull(model_name)
logger.info(f"Model {model_name} pulled successfully") logger.info(f"Model {model_name} pulled successfully")
except Exception as e: except Exception as e: