feat: allow to autopull ollama models

Javier Martinez 2024-07-29 09:25:42 +02:00
parent d34b7d22f4
commit 0a9c57447b
4 changed files with 56 additions and 3 deletions


@@ -68,6 +68,7 @@ class EmbeddingComponent:
                 )
             case "ollama":
                 try:
+                    from ollama import Client, AsyncClient  # type: ignore
                     from llama_index.embeddings.ollama import (  # type: ignore
                         OllamaEmbedding,
                     )
@@ -77,10 +78,36 @@ class EmbeddingComponent:
                     ) from e
                 ollama_settings = settings.ollama
+                # Calculate the embedding model name; if no tag is provided, use ":latest".
+                model_name = (
+                    ollama_settings.embedding_model + ":latest"
+                    if ":" not in ollama_settings.embedding_model
+                    else ollama_settings.embedding_model
+                )
                 self.embedding_model = OllamaEmbedding(
-                    model_name=ollama_settings.embedding_model,
+                    model_name=model_name,
                     base_url=ollama_settings.embedding_api_base,
                 )
+                if ollama_settings.autopull_models:
+                    try:
+                        # TODO: Reuse the llama-index client when llama-index is updated
+                        client = Client(
+                            host=ollama_settings.embedding_api_base,
+                            timeout=ollama_settings.request_timeout,
+                        )
+                        installed_models = [model["name"] for model in client.list().get("models", [])]
+                        if model_name not in installed_models:
+                            logger.info(f"Pulling model {model_name}. Please wait...")
+                            client.pull(model_name)
+                            logger.info(f"Model {model_name} pulled successfully")
+                    except Exception as e:
+                        logger.error(f"Failed to pull model {model_name}: {e!s}")
             case "azopenai":
                 try:
                     from llama_index.embeddings.azure_openai import (  # type: ignore
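For reference, the embedding-side check can be exercised on its own. The sketch below mirrors the hunk above against a standalone Ollama server; the host, timeout, and model name are assumptions for illustration, not values taken from the commit.

```python
# Minimal sketch of the autopull check, mirroring the hunk above.
# Assumptions: a reachable Ollama server at http://localhost:11434 and the
# `ollama` Python client of this era, whose list() returns a dict of the
# form {"models": [{"name": ...}, ...]}.
import logging

from ollama import Client

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def ensure_model(client: Client, model: str) -> None:
    """Pull `model` if it is not already installed on the server."""
    # Append ":latest" when no tag is given, matching the commit's behavior.
    name = model if ":" in model else model + ":latest"
    installed = [m["name"] for m in client.list().get("models", [])]
    if name not in installed:
        logger.info("Pulling model %s. Please wait...", name)
        client.pull(name)
        logger.info("Model %s pulled successfully", name)


client = Client(host="http://localhost:11434", timeout=120.0)
ensure_model(client, "nomic-embed-text")  # hypothetical embedding model
```

Pulls are network-bound and can take minutes for large models, which is presumably why the commit wraps the whole block in a broad try/except and only logs on failure rather than aborting startup.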


@@ -146,8 +146,15 @@ class LLMComponent:
                     "repeat_penalty": ollama_settings.repeat_penalty,  # ollama llama-cpp
                 }
-                self.llm = Ollama(
-                    model=ollama_settings.llm_model,
+                # Calculate the LLM model name; if no tag is provided, use ":latest".
+                model_name = (
+                    ollama_settings.llm_model + ":latest"
+                    if ":" not in ollama_settings.llm_model
+                    else ollama_settings.llm_model
+                )
+                llm = Ollama(
+                    model=model_name,
                     base_url=ollama_settings.api_base,
                     temperature=settings.llm.temperature,
                     context_window=settings.llm.context_window,
@@ -155,6 +162,18 @@ class LLMComponent:
                     request_timeout=ollama_settings.request_timeout,
                 )
+                if ollama_settings.autopull_models:
+                    try:
+                        installed_models = [model["name"] for model in llm.client.list().get("models", [])]
+                        if model_name not in installed_models:
+                            logger.info(f"Pulling model {model_name}. Please wait...")
+                            llm.client.pull(model_name)
+                            logger.info(f"Model {model_name} pulled successfully")
+                    except Exception as e:
+                        logger.error(f"Failed to pull model {model_name}: {e!s}")
                 if (
                     ollama_settings.keep_alive
                     != ollama_settings.model_fields["keep_alive"].default
@@ -172,6 +191,8 @@ class LLMComponent:
                 Ollama.complete = add_keep_alive(Ollama.complete)
                 Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)
+                self.llm = llm
             case "azopenai":
                 try:
                     from llama_index.llms.azure_openai import (  # type: ignore
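The LLM side reuses the client that llama-index already constructed instead of building a second one. A self-contained sketch under the same assumptions as before (hypothetical model name, default endpoint):

```python
# Sketch of the LLM-side autopull. llama-index's Ollama wrapper exposes the
# underlying ollama.Client as `.client`, which the hunk above relies on.
import logging

from llama_index.llms.ollama import Ollama

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

llm = Ollama(
    model="llama3.1:latest",            # hypothetical model name
    base_url="http://localhost:11434",  # default Ollama endpoint
    request_timeout=120.0,
)

installed = [m["name"] for m in llm.client.list().get("models", [])]
if "llama3.1:latest" not in installed:
    logger.info("Pulling llama3.1:latest. Please wait...")
    llm.client.pull("llama3.1:latest")
```

This asymmetry also explains the TODO in the embedding component: OllamaEmbedding apparently did not expose its client at the time, so a separate Client has to be constructed there.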


@@ -290,6 +290,10 @@ class OllamaSettings(BaseModel):
         120.0,
         description="Time elapsed until ollama times out the request. Default is 120s. Format is float.",
     )
+    autopull_models: bool = Field(
+        False,
+        description="If set to True, Ollama will automatically pull missing models from the API base.",
+    )


 class AzureOpenAISettings(BaseModel):
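The flag defaults to False, so pulling stays opt-in. A minimal sketch of the field in isolation, assuming plain Pydantic v2 (PrivateGPT's real OllamaSettings carries many more fields than shown here):

```python
# Minimal sketch of the new setting, assuming plain Pydantic v2.
from pydantic import BaseModel, Field


class OllamaSettings(BaseModel):
    request_timeout: float = Field(
        120.0,
        description="Time elapsed until ollama times out the request.",
    )
    autopull_models: bool = Field(
        False,
        description="If set to True, Ollama will automatically pull missing models.",
    )


print(OllamaSettings().autopull_models)                      # False: opt-in
print(OllamaSettings(autopull_models=True).autopull_models)  # True
```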


@@ -117,6 +117,7 @@ ollama:
   embedding_api_base: http://localhost:11434  # change if your embedding model runs on another ollama
   keep_alive: 5m
   request_timeout: 120.0
+  autopull_models: true

 azopenai:
   api_key: ${AZ_OPENAI_API_KEY:}
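To see the fragment and the model side by side, the sketch below parses a matching snippet with PyYAML and feeds it to the OllamaSettings sketch from the previous example; PrivateGPT's own loader additionally expands ${ENV_VAR:default} placeholders, which plain yaml.safe_load does not.

```python
# Sketch wiring the YAML fragment to the OllamaSettings sketch from the
# previous example; PyYAML stands in for PrivateGPT's own settings loader.
import yaml

fragment = """
ollama:
  keep_alive: 5m
  request_timeout: 120.0
  autopull_models: true
"""
data = yaml.safe_load(fragment)["ollama"]
settings = OllamaSettings(
    request_timeout=data["request_timeout"],
    autopull_models=data["autopull_models"],
)
assert settings.autopull_models is True
```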