Updated the llm component

This commit is contained in:
Saurab-Shrestha 2024-05-02 10:58:03 +05:45
parent bc343206cc
commit 1963190d16
10 changed files with 145 additions and 30 deletions

View file

@ -1,4 +1,6 @@
import logging
from collections.abc import Callable
from typing import Any
from injector import inject, singleton
from llama_index.core.llms import LLM, MockLLM
@ -18,14 +20,24 @@ class LLMComponent:
@inject
def __init__(self, settings: Settings) -> None:
llm_mode = settings.llm.mode
if settings.llm.tokenizer:
set_global_tokenizer(
AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=settings.llm.tokenizer,
cache_dir=str(models_cache_path),
if settings.llm.tokenizer and settings.llm.mode != "mock":
    # Try to download the tokenizer. If it fails, the LLM will still work
    # using the default one, which is less accurate.
    try:
        set_global_tokenizer(
            AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path=settings.llm.tokenizer,
                cache_dir=str(models_cache_path),
                token=settings.huggingface.access_token,
            )
        )
    except Exception as e:
        # Deliberate best-effort: any failure here only downgrades to the
        # default tokenizer, so it is logged rather than re-raised.
        # The format string carries one placeholder per argument; with a
        # missing placeholder the logging module fails to format the
        # record and the warning is never emitted properly.
        logger.warning(
            "Failed to download tokenizer %s: %s. Falling back to "
            "default tokenizer.",
            settings.llm.tokenizer,
            e,
        )
)
logger.info("Initializing the LLM in mode=%s", llm_mode)
match settings.llm.mode:
@ -47,7 +59,8 @@ class LLMComponent:
"offload_kqv": True,
}
self.llm = LlamaCPP(
model_path=str(models_path / settings.llamacpp.llm_hf_model_file),
model_path=str(
models_path / settings.llamacpp.llm_hf_model_file),
temperature=settings.llm.temperature,
max_new_tokens=settings.llm.max_new_tokens,
context_window=settings.llm.context_window,
@ -130,6 +143,44 @@ class LLMComponent:
temperature=settings.llm.temperature,
context_window=settings.llm.context_window,
additional_kwargs=settings_kwargs,
request_timeout=ollama_settings.request_timeout,
)
if (
ollama_settings.keep_alive
!= ollama_settings.model_fields["keep_alive"].default
):
# Modify Ollama methods to use the "keep_alive" field.
def add_keep_alive(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap *func* so every call is made with the configured ``keep_alive``.

    The returned callable forwards all positional and keyword arguments to
    *func*, after setting ``kwargs["keep_alive"]`` to
    ``ollama_settings.keep_alive`` (overriding any value the caller passed).

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that keeps the name, docstring and other metadata of *func*.
    """
    # Local import so this block does not depend on the file's import header.
    from functools import wraps

    # wraps() keeps __name__/__doc__/__wrapped__ of the original callable,
    # so the patched methods stay introspectable and debuggable.
    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        kwargs["keep_alive"] = ollama_settings.keep_alive
        return func(*args, **kwargs)

    return wrapper
Ollama.chat = add_keep_alive(Ollama.chat)
Ollama.stream_chat = add_keep_alive(Ollama.stream_chat)
Ollama.complete = add_keep_alive(Ollama.complete)
Ollama.stream_complete = add_keep_alive(
Ollama.stream_complete)
case "azopenai":
try:
from llama_index.llms.azure_openai import ( # type: ignore
AzureOpenAI,
)
except ImportError as e:
raise ImportError(
"Azure OpenAI dependencies not found, install with `poetry install --extras llms-azopenai`"
) from e
azopenai_settings = settings.azopenai
self.llm = AzureOpenAI(
model=azopenai_settings.llm_model,
deployment_name=azopenai_settings.llm_deployment_name,
api_key=azopenai_settings.api_key,
azure_endpoint=azopenai_settings.azure_endpoint,
api_version=azopenai_settings.api_version,
)
case "mock":
self.llm = MockLLM()