Mirror of https://github.com/zylon-ai/private-gpt.git
Merge remote-tracking branch 'origin/main' into llama3

Commit a19e991fa1: 35 changed files with 1569 additions and 221 deletions
private_gpt/components/embedding/embedding_component.py

```diff
@@ -71,16 +71,46 @@ class EmbeddingComponent:
                     from llama_index.embeddings.ollama import (  # type: ignore
                         OllamaEmbedding,
                     )
+                    from ollama import Client  # type: ignore
                 except ImportError as e:
                     raise ImportError(
                         "Local dependencies not found, install with `poetry install --extras embeddings-ollama`"
                     ) from e

                 ollama_settings = settings.ollama
+
+                # Calculate the embedding model. If no tag is provided, "latest" is used.
+                model_name = (
+                    ollama_settings.embedding_model + ":latest"
+                    if ":" not in ollama_settings.embedding_model
+                    else ollama_settings.embedding_model
+                )
+
                 self.embedding_model = OllamaEmbedding(
-                    model_name=ollama_settings.embedding_model,
+                    model_name=model_name,
                     base_url=ollama_settings.embedding_api_base,
                 )
+
+                if ollama_settings.autopull_models:
+                    from private_gpt.utils.ollama import (
+                        check_connection,
+                        pull_model,
+                    )
+
+                    # TODO: Reuse llama-index client when llama-index is updated
+                    client = Client(
+                        host=ollama_settings.embedding_api_base,
+                        timeout=ollama_settings.request_timeout,
+                    )
+
+                    if not check_connection(client):
+                        raise ValueError(
+                            f"Failed to connect to Ollama, "
+                            f"check if Ollama server is running on {ollama_settings.api_base}"
+                        )
+                    pull_model(client, model_name)
+
             case "azopenai":
                 try:
                     from llama_index.embeddings.azure_openai import (  # type: ignore
```
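The tag normalization added above reappears later for the LLM model. As a standalone sketch (the helper name is ours, not from the commit):

```python
def ensure_model_tag(model: str) -> str:
    """Append ":latest" when an Ollama model reference has no explicit tag."""
    return model if ":" in model else model + ":latest"


# "nomic-embed-text" -> "nomic-embed-text:latest"; "llama3:8b" is unchanged.
assert ensure_model_tag("nomic-embed-text") == "nomic-embed-text:latest"
assert ensure_model_tag("llama3:8b") == "llama3:8b"
```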
```diff
@@ -99,6 +129,20 @@ class EmbeddingComponent:
                     azure_endpoint=azopenai_settings.azure_endpoint,
                     api_version=azopenai_settings.api_version,
                 )
+            case "gemini":
+                try:
+                    from llama_index.embeddings.gemini import (  # type: ignore
+                        GeminiEmbedding,
+                    )
+                except ImportError as e:
+                    raise ImportError(
+                        "Gemini dependencies not found, install with `poetry install --extras embeddings-gemini`"
+                    ) from e
+
+                self.embedding_model = GeminiEmbedding(
+                    api_key=settings.gemini.api_key,
+                    model_name=settings.gemini.embedding_model,
+                )
             case "mock":
                 # Not a random number: it is the dimensionality used by
                 # the default embedding model
```
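The new `gemini` case is a straightforward two-step: install the extra, then construct the embedding from settings. A minimal usage sketch, assuming the `llama-index-embeddings-gemini` extra is installed; the API key and model name below are placeholders, not values from the commit:

```python
from llama_index.embeddings.gemini import GeminiEmbedding  # type: ignore

# Placeholders: private-gpt reads both values from settings.gemini.
embedding_model = GeminiEmbedding(
    api_key="YOUR_GOOGLE_API_KEY",
    model_name="models/embedding-001",
)
```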
private_gpt/components/llm/llm_component.py

```diff
@@ -35,10 +35,10 @@ class LLMComponent:
                 )
             except Exception as e:
                 logger.warning(
-                    "Failed to download tokenizer %s. Falling back to "
-                    "default tokenizer.",
-                    settings.llm.tokenizer,
-                    e,
+                    f"Failed to download tokenizer {settings.llm.tokenizer}: {e!s}. "
+                    f"Please follow the instructions in the documentation to download it if needed: "
+                    f"https://docs.privategpt.dev/installation/getting-started/troubleshooting#tokenizer-setup. "
+                    f"Falling back to default tokenizer."
                )

        logger.info("Initializing the LLM in mode=%s", llm_mode)
```
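The try/except that this hunk rewrites wraps a Hugging Face tokenizer download; only the warning text changes. A sketch of the surrounding pattern, assuming the `set_global_tokenizer`/`AutoTokenizer` combination private-gpt uses (the wrapper function is ours):

```python
import logging

from llama_index.core.utils import set_global_tokenizer
from transformers import AutoTokenizer  # type: ignore

logger = logging.getLogger(__name__)


def install_tokenizer(tokenizer_name: str) -> None:
    # Try to download and register the tokenizer; a failure is not fatal,
    # the default tokenizer is simply less accurate for token counting.
    try:
        set_global_tokenizer(AutoTokenizer.from_pretrained(tokenizer_name))
    except Exception as e:
        logger.warning(
            f"Failed to download tokenizer {tokenizer_name}: {e!s}. "
            f"Falling back to default tokenizer."
        )
```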
```diff
@@ -146,8 +146,15 @@ class LLMComponent:
                     "repeat_penalty": ollama_settings.repeat_penalty,  # ollama llama-cpp
                 }

-                self.llm = Ollama(
-                    model=ollama_settings.llm_model,
+                # Calculate the LLM model. If no tag is provided, "latest" is used.
+                model_name = (
+                    ollama_settings.llm_model + ":latest"
+                    if ":" not in ollama_settings.llm_model
+                    else ollama_settings.llm_model
+                )
+
+                llm = Ollama(
+                    model=model_name,
                     base_url=ollama_settings.api_base,
                     temperature=settings.llm.temperature,
                     context_window=settings.llm.context_window,
```
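The `settings_kwargs` dict that ends with `repeat_penalty` above is ultimately handed to the llama-index `Ollama` wrapper as `additional_kwargs`, which forwards the entries as Ollama request options (the exact wiring sits outside this hunk). A sketch with illustrative values; private-gpt fills these from its settings objects:

```python
from llama_index.llms.ollama import Ollama  # type: ignore

# Illustrative values; private-gpt reads these from ollama/llm settings.
settings_kwargs = {
    "repeat_penalty": 1.1,  # ollama llama-cpp sampling option
}

llm = Ollama(
    model="llama3:latest",
    base_url="http://localhost:11434",
    temperature=0.1,
    context_window=3900,
    request_timeout=120.0,
    additional_kwargs=settings_kwargs,  # forwarded as Ollama options
)
```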
```diff
@@ -155,6 +162,16 @@ class LLMComponent:
                     request_timeout=ollama_settings.request_timeout,
                 )

+                if ollama_settings.autopull_models:
+                    from private_gpt.utils.ollama import check_connection, pull_model
+
+                    if not check_connection(llm.client):
+                        raise ValueError(
+                            f"Failed to connect to Ollama, "
+                            f"check if Ollama server is running on {ollama_settings.api_base}"
+                        )
+                    pull_model(llm.client, model_name)
+
                 if (
                     ollama_settings.keep_alive
                     != ollama_settings.model_fields["keep_alive"].default
```
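Note that the commit now builds a local `llm` first and only assigns `self.llm` after the optional pull (see the next hunk), so `llm.client` can be reused here. The bodies of `check_connection` and `pull_model` are not part of this diff; a plausible sketch using only documented `ollama` client calls:

```python
import logging

from ollama import Client  # type: ignore

logger = logging.getLogger(__name__)


def check_connection(client: Client) -> bool:
    # Sketch: treat any successful round-trip as "connected".
    # The real private_gpt.utils.ollama helpers may differ.
    try:
        client.list()
        return True
    except Exception as e:
        logger.error("Failed to connect to Ollama: %s", e)
        return False


def pull_model(client: Client, model_name: str) -> None:
    # Pull the model only if it is not already available locally.
    installed = {m.get("name") for m in client.list().get("models", [])}
    if model_name not in installed:
        logger.info("Pulling model %s, this may take a while...", model_name)
        client.pull(model_name)
```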
```diff
@@ -172,6 +189,8 @@ class LLMComponent:
                     Ollama.complete = add_keep_alive(Ollama.complete)
                     Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)

+                self.llm = llm
+
             case "azopenai":
                 try:
                     from llama_index.llms.azure_openai import (  # type: ignore
```
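`add_keep_alive` itself is defined outside this hunk; from the call sites it wraps `Ollama.complete` and `Ollama.stream_complete` so every request carries the configured `keep_alive`. A sketch of what such a wrapper could look like (an assumption, not the code from the commit):

```python
import functools


def make_add_keep_alive(keep_alive: str):
    # Returns a decorator that injects keep_alive into every wrapped call
    # unless the caller already passed one explicitly.
    def add_keep_alive(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            kwargs.setdefault("keep_alive", keep_alive)
            return func(*args, **kwargs)
        return wrapper
    return add_keep_alive
```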
```diff
@@ -190,5 +209,18 @@ class LLMComponent:
                     azure_endpoint=azopenai_settings.azure_endpoint,
                     api_version=azopenai_settings.api_version,
                 )
+            case "gemini":
+                try:
+                    from llama_index.llms.gemini import (  # type: ignore
+                        Gemini,
+                    )
+                except ImportError as e:
+                    raise ImportError(
+                        "Google Gemini dependencies not found, install with `poetry install --extras llms-gemini`"
+                    ) from e
+                gemini_settings = settings.gemini
+                self.llm = Gemini(
+                    model_name=gemini_settings.model, api_key=gemini_settings.api_key
+                )
             case "mock":
                 self.llm = MockLLM()
```
private_gpt/components/vector_store/vector_store_component.py

```diff
@@ -121,6 +121,72 @@ class VectorStoreComponent:
                         collection_name="make_this_parameterizable_per_api_call",
                     ),  # TODO
                 )
+
+            case "milvus":
+                try:
+                    from llama_index.vector_stores.milvus import (  # type: ignore
+                        MilvusVectorStore,
+                    )
+                except ImportError as e:
+                    raise ImportError(
+                        "Milvus dependencies not found, install with `poetry install --extras vector-stores-milvus`"
+                    ) from e
+
+                if settings.milvus is None:
+                    logger.info(
+                        "Milvus config not found. Using default settings.\n"
+                        "Trying to connect to Milvus at local_data/private_gpt/milvus/milvus_local.db "
+                        "with collection 'make_this_parameterizable_per_api_call'."
+                    )
+
+                    self.vector_store = typing.cast(
+                        BasePydanticVectorStore,
+                        MilvusVectorStore(
+                            dim=settings.embedding.embed_dim,
+                            collection_name="make_this_parameterizable_per_api_call",
+                            overwrite=True,
+                        ),
+                    )
+
+                else:
+                    self.vector_store = typing.cast(
+                        BasePydanticVectorStore,
+                        MilvusVectorStore(
+                            dim=settings.embedding.embed_dim,
+                            uri=settings.milvus.uri,
+                            token=settings.milvus.token,
+                            collection_name=settings.milvus.collection_name,
+                            overwrite=settings.milvus.overwrite,
+                        ),
+                    )
+
+            case "clickhouse":
+                try:
+                    from clickhouse_connect import (  # type: ignore
+                        get_client,
+                    )
+                    from llama_index.vector_stores.clickhouse import (  # type: ignore
+                        ClickHouseVectorStore,
+                    )
+                except ImportError as e:
+                    raise ImportError(
+                        "ClickHouse dependencies not found, install with `poetry install --extras vector-stores-clickhouse`"
+                    ) from e
+
+                if settings.clickhouse is None:
+                    raise ValueError(
+                        "ClickHouse settings not found. Please provide settings."
+                    )
+
+                clickhouse_client = get_client(
+                    host=settings.clickhouse.host,
+                    port=settings.clickhouse.port,
+                    username=settings.clickhouse.username,
+                    password=settings.clickhouse.password,
+                )
+                self.vector_store = ClickHouseVectorStore(
+                    clickhouse_client=clickhouse_client
+                )
             case _:
                 # Should be unreachable
                 # The settings validator should have caught this
```
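The Milvus branch falls back to an embedded local database when no `milvus` section is configured. A sketch of the settings shape this branch reads, with field names taken from the diff and defaults that are illustrative only:

```python
from pydantic import BaseModel


class MilvusSettings(BaseModel):
    # Field names as read by the milvus case above; defaults illustrative.
    uri: str = "local_data/private_gpt/milvus/milvus_local.db"
    token: str = ""
    collection_name: str = "make_this_parameterizable_per_api_call"
    overwrite: bool = True
```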