Support for Nvidia TensorRT

This commit is contained in:
imartinez 2024-02-29 19:41:58 +01:00
parent c3fe36e070
commit a7b18058b5
7 changed files with 141 additions and 8 deletions

View file

@@ -111,5 +111,20 @@ class LLMComponent:
self.llm = Ollama(
model=ollama_settings.model, base_url=ollama_settings.api_base
)
case "tensorrt":
try:
from llama_index.llms.nvidia_tensorrt import LocalTensorRTLLM # type: ignore
except ImportError as e:
raise ImportError(
"Nvidia TensorRTLLM dependencies not found, install with `poetry install --extras llms-nvidia-tensorrt`"
) from e
prompt_style = get_prompt_style(settings.tensorrt.prompt_style)
self.llm = LocalTensorRTLLM(
model_path=settings.tensorrt.model_path,
engine_name=settings.tensorrt.engine_name,
tokenizer_dir=settings.llm.tokenizer,
completion_to_prompt=prompt_style.completion_to_prompt,
)
case "mock":
self.llm = MockLLM()