Mirror of https://github.com/zylon-ai/private-gpt.git (synced 2025-12-22 17:05:41 +01:00)
Added hybrid search with updated chunk size for semantic embedding
This commit is contained in:
parent ebe43082cd · commit 3ba585ffc0
7 changed files with 241 additions and 21 deletions
@@ -8,7 +8,8 @@ from fastapi_pagination import add_pagination
from private_gpt.settings.settings import settings
from fastapi.staticfiles import StaticFiles
from private_gpt.constants import UPLOAD_DIR

# import nest_asyncio
# nest_asyncio.apply()
# Set log_config=None to not use the uvicorn logging configuration, and
# use ours instead. For reference, see below:
# https://github.com/tiangolo/fastapi/discussions/7457#discussioncomment-5141108
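For context, the comments above refer to the usual pattern of serving the app with uvicorn while keeping the application's own logging configuration. A minimal sketch of that wiring follows; the "/static" mount path, port, and app object are illustrative assumptions, not code from this commit.

# Sketch only: how StaticFiles and log_config=None are typically wired together.
# The mount path, port, and FastAPI app here are assumptions for illustration.
import uvicorn
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles

from private_gpt.constants import UPLOAD_DIR

app = FastAPI()
app.mount("/static", StaticFiles(directory=UPLOAD_DIR), name="static")

if __name__ == "__main__":
    # log_config=None tells uvicorn not to install its own logging config.
    uvicorn.run(app, host="0.0.0.0", port=8001, log_config=None)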
private_gpt/components/vector_store/hybrid_fn.py (new file, 151 lines)

@@ -0,0 +1,151 @@
from llama_index.core.vector_stores import VectorStoreQueryResult

from typing import Any, List, Tuple
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

doc_tokenizer = AutoTokenizer.from_pretrained(
    "naver/efficient-splade-VI-BT-large-doc"
)
doc_model = AutoModelForMaskedLM.from_pretrained(
    "naver/efficient-splade-VI-BT-large-doc"
)

query_tokenizer = AutoTokenizer.from_pretrained(
    "naver/efficient-splade-VI-BT-large-query"
)
query_model = AutoModelForMaskedLM.from_pretrained(
    "naver/efficient-splade-VI-BT-large-query"
)


def sparse_doc_vectors(
    texts: List[str],
) -> Tuple[List[List[int]], List[List[float]]]:
    """
    Computes vectors from logits and attention mask using ReLU, log, and max operations.
    """
    tokens = doc_tokenizer(
        texts, truncation=True, padding=True, return_tensors="pt"
    )
    if torch.cuda.is_available():
        tokens = tokens.to("cuda")

    output = doc_model(**tokens)
    logits, attention_mask = output.logits, tokens.attention_mask
    relu_log = torch.log(1 + torch.relu(logits))
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    tvecs, _ = torch.max(weighted_log, dim=1)

    # extract the vectors that are non-zero and their indices
    indices = []
    vecs = []
    for batch in tvecs:
        indices.append(batch.nonzero(as_tuple=True)[0].tolist())
        vecs.append(batch[indices[-1]].tolist())

    return indices, vecs


def sparse_query_vectors(
    texts: List[str],
) -> Tuple[List[List[int]], List[List[float]]]:
    """
    Computes vectors from logits and attention mask using ReLU, log, and max operations.
    """
    # TODO: compute sparse vectors in batches if max length is exceeded
    tokens = query_tokenizer(
        texts, truncation=True, padding=True, return_tensors="pt"
    )
    if torch.cuda.is_available():
        tokens = tokens.to("cuda")

    output = query_model(**tokens)
    logits, attention_mask = output.logits, tokens.attention_mask
    relu_log = torch.log(1 + torch.relu(logits))
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    tvecs, _ = torch.max(weighted_log, dim=1)

    # extract the vectors that are non-zero and their indices
    indices = []
    vecs = []
    for batch in tvecs:
        indices.append(batch.nonzero(as_tuple=True)[0].tolist())
        vecs.append(batch[indices[-1]].tolist())

    return indices, vecs

def relative_score_fusion(
    dense_result: VectorStoreQueryResult,
    sparse_result: VectorStoreQueryResult,
    alpha: float = 0.5,  # passed in from the query engine
    top_k: int = 2,  # passed in from the query engine i.e. similarity_top_k
) -> VectorStoreQueryResult:
    """
    Fuse dense and sparse results using relative score fusion.
    """
    # sanity check
    assert dense_result.nodes is not None
    assert dense_result.similarities is not None
    assert sparse_result.nodes is not None
    assert sparse_result.similarities is not None

    # deconstruct results
    sparse_result_tuples = list(
        zip(sparse_result.similarities, sparse_result.nodes)
    )
    sparse_result_tuples.sort(key=lambda x: x[0], reverse=True)

    dense_result_tuples = list(
        zip(dense_result.similarities, dense_result.nodes)
    )
    dense_result_tuples.sort(key=lambda x: x[0], reverse=True)

    # track nodes in both results
    all_nodes_dict = {x.node_id: x for x in dense_result.nodes}
    for node in sparse_result.nodes:
        if node.node_id not in all_nodes_dict:
            all_nodes_dict[node.node_id] = node

    # normalize sparse similarities from 0 to 1
    sparse_similarities = [x[0] for x in sparse_result_tuples]
    max_sparse_sim = max(sparse_similarities)
    min_sparse_sim = min(sparse_similarities)
    sparse_similarities = [
        (x - min_sparse_sim) / (max_sparse_sim - min_sparse_sim)
        for x in sparse_similarities
    ]
    sparse_per_node = {
        sparse_result_tuples[i][1].node_id: x
        for i, x in enumerate(sparse_similarities)
    }

    # normalize dense similarities from 0 to 1
    dense_similarities = [x[0] for x in dense_result_tuples]
    max_dense_sim = max(dense_similarities)
    min_dense_sim = min(dense_similarities)
    dense_similarities = [
        (x - min_dense_sim) / (max_dense_sim - min_dense_sim)
        for x in dense_similarities
    ]
    dense_per_node = {
        dense_result_tuples[i][1].node_id: x
        for i, x in enumerate(dense_similarities)
    }

    # fuse the scores
    fused_similarities = []
    for node_id in all_nodes_dict:
        sparse_sim = sparse_per_node.get(node_id, 0)
        dense_sim = dense_per_node.get(node_id, 0)
        fused_sim = alpha * (sparse_sim + dense_sim)
        fused_similarities.append((fused_sim, all_nodes_dict[node_id]))

    fused_similarities.sort(key=lambda x: x[0], reverse=True)
    fused_similarities = fused_similarities[:top_k]

    # create final response object
    return VectorStoreQueryResult(
        nodes=[x[1] for x in fused_similarities],
        similarities=[x[0] for x in fused_similarities],
        ids=[x[1].node_id for x in fused_similarities],
    )
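For reference, here is a minimal sketch of how these helpers can be exercised on their own, outside the Qdrant integration. The sample text, node ids, and similarity scores below are made-up illustration data, not part of this commit.

# Sketch only: calling sparse_doc_vectors and relative_score_fusion directly.
# Node texts, ids, and scores are illustrative.
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores import VectorStoreQueryResult

from private_gpt.components.vector_store.hybrid_fn import (
    relative_score_fusion,
    sparse_doc_vectors,
)

indices, values = sparse_doc_vectors(["PrivateGPT supports hybrid search."])
print(len(indices[0]), len(values[0]))  # one sparse vector: token ids + weights

nodes = [TextNode(text="chunk a", id_="a"), TextNode(text="chunk b", id_="b")]
dense = VectorStoreQueryResult(nodes=nodes, similarities=[0.9, 0.4], ids=["a", "b"])
sparse = VectorStoreQueryResult(
    nodes=list(reversed(nodes)), similarities=[7.1, 2.3], ids=["b", "a"]
)

fused = relative_score_fusion(dense, sparse, alpha=0.5, top_k=2)
print([n.node_id for n in fused.nodes], fused.similarities)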
@@ -13,6 +13,7 @@ from llama_index.core.vector_stores.types import (
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.paths import local_data_path
from private_gpt.settings.settings import Settings
from .hybrid_fn import sparse_query_vectors, sparse_doc_vectors, relative_score_fusion

logger = logging.getLogger(__name__)


@@ -119,6 +120,11 @@ class VectorStoreComponent:
                    QdrantVectorStore(
                        client=client,
                        collection_name="make_this_parameterizable_per_api_call",
                        enable_hybrid=True,
                        batch_size=20,
                        sparse_doc_fn=sparse_doc_vectors,
                        sparse_query_fn=sparse_query_vectors,
                        # hybrid_fusion_fn=relative_score_fusion,
                    ),  # TODO
                )
            case _:

@@ -144,6 +150,9 @@ class VectorStoreComponent:
                if self.settings.vectorstore.database != "qdrant"
                else None
            ),
            sparse_top_k=12,
            vector_store_query_mode="hybrid",
            alpha=0.5
        )

    def close(self) -> None:
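To make the retriever arguments above concrete: sparse_top_k bounds how many sparse (SPLADE) candidates are fetched, similarity_top_k how many fused results are returned, and alpha weights dense against sparse scores during fusion. Below is a hedged, self-contained sketch of the same hybrid setup driven directly through llama-index; the in-memory Qdrant client, the bge-small embedding model, and the demo document are assumptions for illustration, not part of this commit.

# Sketch only: hybrid dense + sparse retrieval against a Qdrant-backed index,
# reusing the sparse functions introduced in hybrid_fn.py.
from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.qdrant import QdrantVectorStore
from qdrant_client import QdrantClient

from private_gpt.components.vector_store.hybrid_fn import (
    sparse_doc_vectors,
    sparse_query_vectors,
)

# Assumed embedding model; private-gpt configures this via its settings instead.
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

client = QdrantClient(location=":memory:")  # throwaway in-memory Qdrant for the demo
vector_store = QdrantVectorStore(
    client=client,
    collection_name="demo",
    enable_hybrid=True,
    sparse_doc_fn=sparse_doc_vectors,
    sparse_query_fn=sparse_query_vectors,
)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
    [Document(text="PrivateGPT hybrid search demo.")],
    storage_context=storage_context,
)

retriever = index.as_retriever(
    vector_store_query_mode="hybrid",
    similarity_top_k=2,   # final fused top-k
    sparse_top_k=12,      # sparse (SPLADE) candidates fetched before fusion
    alpha=0.5,            # 0 = sparse only, 1 = dense only
)
nodes = retriever.retrieve("what does this demo index contain?")
print([n.score for n in nodes])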
@@ -16,6 +16,9 @@ from llama_index.core.storage import StorageContext
from llama_index.core.types import TokenGen
from pydantic import BaseModel

from llama_index.core import get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine

from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent

@@ -26,6 +29,7 @@ from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chunks.chunks_service import Chunk
from private_gpt.settings.settings import Settings

from private_gpt.paths import models_path

class Completion(BaseModel):
    response: str

@@ -36,7 +40,7 @@ class CompletionGen(BaseModel):
    response: TokenGen
    sources: list[Chunk] | None = None


reranker_path = models_path / 'reranker'
@dataclass
class ChatEngineInput:
    system_message: ChatMessage | None = None
@@ -126,9 +130,16 @@ class ChatService:
            )
            node_postprocessors.append(rerank_postprocessor)

        return CondensePlusContextChatEngine.from_defaults(
            system_prompt=system_prompt,
        response_synthesizer = get_response_synthesizer(structured_answer_filtering=True, llm=self.llm_component.llm)

        custom_query_engine = RetrieverQueryEngine(
            retriever=vector_index_retriever,
            response_synthesizer=response_synthesizer
        )

        return ContextChatEngine.from_defaults(
            system_prompt=system_prompt,
            retriever=custom_query_engine,
            llm=self.llm_component.llm,  # Takes no effect at the moment
            node_postprocessors=node_postprocessors,
        )
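The hunk above swaps CondensePlusContextChatEngine for a ContextChatEngine fed by a RetrieverQueryEngine whose synthesizer filters answers it cannot ground in the retrieved context. A hedged sketch of that shape follows; MockLLM, MockEmbedding, the cross-encoder checkpoint, and the demo document are stand-ins, not the components this service actually injects, and the sketch passes a plain retriever to the chat engine (the hunk passes the query engine itself) to keep the example type-correct.

# Sketch only: structured-answer-filtering synthesizer, a RetrieverQueryEngine,
# and a ContextChatEngine with a sentence-transformer rerank postprocessor.
from llama_index.core import Document, Settings, VectorStoreIndex, get_response_synthesizer
from llama_index.core.chat_engine import ContextChatEngine
from llama_index.core.embeddings import MockEmbedding
from llama_index.core.llms import MockLLM
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.query_engine import RetrieverQueryEngine

Settings.llm = MockLLM()                      # stand-in for the configured llama.cpp LLM
Settings.embed_model = MockEmbedding(embed_dim=8)  # stand-in for the HF embedding model

index = VectorStoreIndex.from_documents(
    [Document(text="PrivateGPT adds hybrid search backed by SPLADE sparse vectors.")]
)
retriever = index.as_retriever(similarity_top_k=4)

# Mirrors the construction in this hunk: synthesizer + retriever query engine.
response_synthesizer = get_response_synthesizer(
    structured_answer_filtering=True, llm=Settings.llm
)
query_engine = RetrieverQueryEngine(
    retriever=retriever, response_synthesizer=response_synthesizer
)

rerank = SentenceTransformerRerank(
    model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=2  # assumed checkpoint
)
chat_engine = ContextChatEngine.from_defaults(
    retriever=retriever,
    llm=Settings.llm,
    system_prompt="Answer strictly from the provided context.",
    node_postprocessors=[rerank],
)
print(chat_engine.chat("What kind of search does PrivateGPT add?"))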
@@ -189,16 +200,15 @@ class ChatService:
        system_prompt = (
            """
            You are a helpful assistant named QuickGPT by Quickfox Consulting.
            Your responses must be strictly and exclusively based on the context documents provided.

            You are not allowed to use any information, knowledge, or external sources outside of the given context documents.
            If the answer to a query is not present in the context documents,
            you should respond with "I do not have enough information in the provided context to answer this question."
            Engage in a two-way conversation, ensuring that your responses are strictly and exclusively based on the relevant context documents provided.

            Your responses should be relevant, informative, and easy to understand.
            Do not use any prior knowledge or external sources or make assumptions, inferences, or draw upon any prior knowledge beyond what is explicitly stated in the relevant context documents.
            If the answer to a query is not present in the relevant context documents, respond with "I do not have enough information in the provided context to answer this question."

            Your responses must be relevant, informative, and easy to understand.
            Aim to deliver high-quality answers that are respectful and helpful, using clear and concise language.
            Focus on providing accurate and reliable answers based solely on the given context.
            Do not make assumptions, inferences, or draw upon any prior knowledge beyond what is explicitly stated in the context documents.
            Consider previous queries only if the latest query is directly related to them. Address only the most recent query unless it explicitly builds upon a previous one.
            """
        )
        chat_history = (
@@ -209,7 +219,6 @@ class ChatService:
            use_context=use_context,
            context_filter=context_filter,
        )
        # chat_engine = chat_engine.as_chat_engine(chat_mode="react", llm=self.llm_component.llm, verbose=True)  # configuring ReAct Chat engine
        wrapped_response = chat_engine.chat(
            message=last_message if last_message is not None else "",
            chat_history=chat_history,
@@ -1,11 +1,12 @@
import logging
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, AnyStr, BinaryIO
from typing import TYPE_CHECKING, AnyStr, BinaryIO, Sequence, Any, List

from injector import inject, singleton
from llama_index.core.node_parser import SentenceWindowNodeParser, SemanticSplitterNodeParser
from llama_index.core.node_parser import SemanticSplitterNodeParser, SentenceSplitter
from llama_index.core.storage import StorageContext
from llama_index.core.schema import BaseNode, ObjectType, TextNode

from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.ingest.ingest_component import get_ingestion_component
@@ -17,12 +18,44 @@ from private_gpt.components.vector_store.vector_store_component import (
from private_gpt.server.ingest.model import IngestedDoc
from private_gpt.settings.settings import settings


from llama_index.core.extractors import (
    QuestionsAnsweredExtractor,
    TitleExtractor,
)
if TYPE_CHECKING:
    from llama_index.core.storage.docstore.types import RefDocInfo

logger = logging.getLogger(__name__)


DEFAULT_CHUNK_SIZE = 512
SENTENCE_CHUNK_OVERLAP = 20

class SafeSemanticSplitter(SemanticSplitterNodeParser):

    safety_chunker: SentenceSplitter = SentenceSplitter(chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=SENTENCE_CHUNK_OVERLAP)

    def _parse_nodes(
        self,
        nodes,
        show_progress: bool = False,
        **kwargs
    ) -> List[BaseNode]:
        all_nodes: List[BaseNode] = super()._parse_nodes(nodes=nodes, show_progress=show_progress, **kwargs)
        all_good = True
        for node in all_nodes:
            if node.get_type() == ObjectType.TEXT:
                node: TextNode = node
                if self.safety_chunker._token_size(node.text) > self.safety_chunker.chunk_size:
                    logging.info("Chunk size too big after semantic chunking: switching to static chunking")
                    all_good = False
                    break
        if not all_good:
            all_nodes = self.safety_chunker._parse_nodes(nodes, show_progress=show_progress, **kwargs)
        return all_nodes


@singleton
class IngestService:
    @inject
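A small sketch of what the fallback above buys you when parsing documents directly: if any semantically merged chunk exceeds DEFAULT_CHUNK_SIZE tokens, the whole batch is re-split with the plain SentenceSplitter. MockEmbedding and the sample document are placeholders, and the import path assumes the class lives in private_gpt.server.ingest.ingest_service as in upstream private-gpt.

# Sketch only: exercising SafeSemanticSplitter outside IngestService.
# MockEmbedding stands in for the injected embedding model.
from llama_index.core import Document
from llama_index.core.embeddings import MockEmbedding

from private_gpt.server.ingest.ingest_service import SafeSemanticSplitter

splitter = SafeSemanticSplitter.from_defaults(
    embed_model=MockEmbedding(embed_dim=8),
    include_metadata=True,
    include_prev_next_rel=True,
)

# A long, repetitive document: if semantic merging yields a chunk over
# DEFAULT_CHUNK_SIZE tokens, _parse_nodes falls back to the SentenceSplitter.
docs = [Document(text="Hybrid search combines dense and sparse retrieval. " * 400)]
nodes = splitter.get_nodes_from_documents(docs)
print(len(nodes), max(len(node.text) for node in nodes))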
@@ -39,14 +72,22 @@ class IngestService:
            docstore=node_store_component.doc_store,
            index_store=node_store_component.index_store,
        )
        node_parser = SemanticSplitterNodeParser.from_defaults(
        # splitter = SentenceSplitter(chunk_size=512, chunk_overlap=128)
        node_parser = SafeSemanticSplitter.from_defaults(
            embed_model=embedding_component.embedding_model,
            # sentence_splitter=splitter,
            include_metadata=True,
            include_prev_next_rel=True,
        )

        self.ingest_component = get_ingestion_component(
            self.storage_context,
            embed_model=embedding_component.embedding_model,
            transformations=[node_parser, embedding_component.embedding_model],
            transformations=[
                node_parser,
                TitleExtractor(nodes=1, llm=self.llm_service.llm),
                QuestionsAnsweredExtractor(questions=1, llm=self.llm_service.llm),
                embedding_component.embedding_model],
            settings=settings(),
        )
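For a sense of what the two extractors add, here is a hedged sketch of the same transformation chain run through a standalone IngestionPipeline; MockLLM, MockEmbedding, the SentenceSplitter stand-in, and the sample text are placeholders for the components IngestService actually injects.

# Sketch only: node parser -> metadata extractors -> embeddings as a pipeline.
from llama_index.core import Document
from llama_index.core.embeddings import MockEmbedding
from llama_index.core.extractors import QuestionsAnsweredExtractor, TitleExtractor
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.llms import MockLLM
from llama_index.core.node_parser import SentenceSplitter

llm = MockLLM()  # stand-in for the configured LLM
pipeline = IngestionPipeline(
    transformations=[
        SentenceSplitter(chunk_size=512, chunk_overlap=20),  # stand-in for SafeSemanticSplitter
        TitleExtractor(nodes=1, llm=llm),                    # adds document_title metadata
        QuestionsAnsweredExtractor(questions=1, llm=llm),    # adds questions_this_excerpt_can_answer
        MockEmbedding(embed_dim=8),                          # stand-in for the HF embedding model
    ]
)
nodes = pipeline.run(
    documents=[Document(text="PrivateGPT now supports hybrid dense + sparse retrieval.")]
)
print(nodes[0].metadata.keys())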
@@ -4,6 +4,7 @@ import argparse

from huggingface_hub import hf_hub_download, snapshot_download
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer

from private_gpt.paths import models_path, models_cache_path
from private_gpt.settings.settings import settings

@@ -46,4 +47,12 @@ AutoTokenizer.from_pretrained(
)
print("Tokenizer downloaded!")

# Download Reranker
# print(f"Downloading reranker {settings().rag.rerank.model}")

# reranker_path = r'D:/QuickGPT/privateGPT/models/reranker'
# rerank_postprocessor = SentenceTransformer(
#     settings().rag.rerank.model
# )
# rerank_postprocessor.save(reranker_path)
print("Setup done")
@@ -51,14 +51,14 @@ rag:
  # This value is disabled by default. If you enable this setting, the RAG will only use articles that meet a certain percentage score.
  rerank:
    enabled: true
    model: mixedbread-ai/mxbai-embed-large-v1
    model: avsolatorio/GIST-Embedding-v0
    top_n: 2

llamacpp:
  # llm_hf_repo_id: bartowski/Meta-Llama-3-8B-Instruct-GGUF
  # llm_hf_model_file: Meta-Llama-3-8B-Instruct-Q6_K.gguf
  llm_hf_repo_id: qwp4w3hyb/Hermes-2-Pro-Llama-3-8B-iMat-GGUF
  llm_hf_model_file: hermes-2-pro-llama-3-8b-imat-Q6_K.gguf
  llm_hf_model_file: hermes-2-pro-llama-3-8b-imat-Q4_K_S.gguf
  tfs_z: 1.0 # Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting
  top_k: 40 # Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
  top_p: 0.9 # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
@@ -68,11 +68,11 @@ llamacpp:
embedding:
  # Should be matching the value above in most cases
  mode: huggingface
  ingest_mode: parallel
  ingest_mode: pipeline
  embed_dim: 384 # 384 is for BAAI/bge-small-en-v1.5

huggingface:
  embedding_hf_model_name: mixedbread-ai/mxbai-embed-large-v1
  embedding_hf_model_name: BAAI/bge-large-en
  access_token: ${HUGGINGFACE_TOKEN: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}

vectorstore: