mirror of
https://github.com/zylon-ai/private-gpt.git
synced 2025-12-22 10:45:42 +01:00
Merge remote-tracking branch 'origin/main' into j_main
This commit is contained in:
commit
f457101b87
37 changed files with 3563 additions and 2657 deletions
|
|
@ -8,7 +8,7 @@ inputs:
|
|||
poetry_version:
|
||||
required: true
|
||||
type: string
|
||||
default: "1.5.1"
|
||||
default: "1.8.3"
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ FROM python:3.11.6-slim-bookworm as base
|
|||
# Install poetry
|
||||
RUN pip install pipx
|
||||
RUN python3 -m pipx ensurepath
|
||||
RUN pipx install poetry
|
||||
RUN pipx install poetry==1.8.3
|
||||
ENV PATH="/root/.local/bin:$PATH"
|
||||
ENV PATH=".venv/bin/:$PATH"
|
||||
|
||||
|
|
@ -14,27 +14,38 @@ FROM base as dependencies
|
|||
WORKDIR /home/worker/app
|
||||
COPY pyproject.toml poetry.lock ./
|
||||
|
||||
RUN poetry install --extras "ui vector-stores-qdrant llms-ollama embeddings-ollama"
|
||||
ARG POETRY_EXTRAS="ui vector-stores-qdrant llms-ollama embeddings-ollama"
|
||||
RUN poetry install --no-root --extras "${POETRY_EXTRAS}"
|
||||
|
||||
FROM base as app
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PORT=8080
|
||||
ENV APP_ENV=prod
|
||||
ENV PYTHONPATH="$PYTHONPATH:/home/worker/app/private_gpt/"
|
||||
EXPOSE 8080
|
||||
|
||||
# Prepare a non-root user
|
||||
RUN adduser --system worker
|
||||
# More info about how to configure UIDs and GIDs in Docker:
|
||||
# https://github.com/systemd/systemd/blob/main/docs/UIDS-GIDS.md
|
||||
|
||||
# Define the User ID (UID) for the non-root user
|
||||
# UID 100 is chosen to avoid conflicts with existing system users
|
||||
ARG UID=100
|
||||
|
||||
# Define the Group ID (GID) for the non-root user
|
||||
# GID 65534 is often used for the 'nogroup' or 'nobody' group
|
||||
ARG GID=65534
|
||||
|
||||
RUN adduser --system --gid ${GID} --uid ${UID} --home /home/worker worker
|
||||
WORKDIR /home/worker/app
|
||||
|
||||
RUN mkdir local_data; chown worker local_data
|
||||
RUN mkdir models; chown worker models
|
||||
RUN chown worker /home/worker/app
|
||||
RUN mkdir local_data && chown worker local_data
|
||||
RUN mkdir models && chown worker models
|
||||
COPY --chown=worker --from=dependencies /home/worker/app/.venv/ .venv
|
||||
COPY --chown=worker private_gpt/ private_gpt
|
||||
COPY --chown=worker fern/ fern
|
||||
COPY --chown=worker *.yaml *.md ./
|
||||
COPY --chown=worker *.yaml .
|
||||
COPY --chown=worker scripts/ scripts
|
||||
|
||||
ENV PYTHONPATH="$PYTHONPATH:/private_gpt/"
|
||||
|
||||
USER worker
|
||||
ENTRYPOINT python -m private_gpt
|
||||
ENTRYPOINT python -m private_gpt
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ FROM python:3.11.6-slim-bookworm as base
|
|||
# Install poetry
|
||||
RUN pip install pipx
|
||||
RUN python3 -m pipx ensurepath
|
||||
RUN pipx install poetry
|
||||
RUN pipx install poetry==1.8.3
|
||||
ENV PATH="/root/.local/bin:$PATH"
|
||||
ENV PATH=".venv/bin/:$PATH"
|
||||
|
||||
|
|
@ -24,28 +24,39 @@ FROM base as dependencies
|
|||
WORKDIR /home/worker/app
|
||||
COPY pyproject.toml poetry.lock ./
|
||||
|
||||
RUN poetry install --extras "ui embeddings-huggingface llms-llama-cpp vector-stores-qdrant"
|
||||
ARG POETRY_EXTRAS="ui embeddings-huggingface llms-llama-cpp vector-stores-qdrant"
|
||||
RUN poetry install --no-root --extras "${POETRY_EXTRAS}"
|
||||
|
||||
FROM base as app
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PORT=8080
|
||||
ENV APP_ENV=prod
|
||||
ENV PYTHONPATH="$PYTHONPATH:/home/worker/app/private_gpt/"
|
||||
EXPOSE 8080
|
||||
|
||||
# Prepare a non-root user
|
||||
RUN adduser --group worker
|
||||
RUN adduser --system --ingroup worker worker
|
||||
# More info about how to configure UIDs and GIDs in Docker:
|
||||
# https://github.com/systemd/systemd/blob/main/docs/UIDS-GIDS.md
|
||||
|
||||
# Define the User ID (UID) for the non-root user
|
||||
# UID 100 is chosen to avoid conflicts with existing system users
|
||||
ARG UID=100
|
||||
|
||||
# Define the Group ID (GID) for the non-root user
|
||||
# GID 65534 is often used for the 'nogroup' or 'nobody' group
|
||||
ARG GID=65534
|
||||
|
||||
RUN adduser --system --gid ${GID} --uid ${UID} --home /home/worker worker
|
||||
WORKDIR /home/worker/app
|
||||
|
||||
RUN mkdir local_data; chown worker local_data
|
||||
RUN mkdir models; chown worker models
|
||||
RUN chown worker /home/worker/app
|
||||
RUN mkdir local_data && chown worker local_data
|
||||
RUN mkdir models && chown worker models
|
||||
COPY --chown=worker --from=dependencies /home/worker/app/.venv/ .venv
|
||||
COPY --chown=worker private_gpt/ private_gpt
|
||||
COPY --chown=worker fern/ fern
|
||||
COPY --chown=worker *.yaml *.md ./
|
||||
COPY --chown=worker *.yaml ./
|
||||
COPY --chown=worker scripts/ scripts
|
||||
|
||||
ENV PYTHONPATH="$PYTHONPATH:/private_gpt/"
|
||||
|
||||
USER worker
|
||||
ENTRYPOINT python -m private_gpt
|
||||
21
README.md
21
README.md
|
|
@ -2,21 +2,21 @@
|
|||
|
||||
[](https://github.com/zylon-ai/private-gpt/actions/workflows/tests.yml?query=branch%3Amain)
|
||||
[](https://docs.privategpt.dev/)
|
||||
|
||||
[](https://discord.gg/bK6mRVpErU)
|
||||
[](https://twitter.com/ZylonPrivateGPT)
|
||||
|
||||
|
||||
> Install & usage docs: https://docs.privategpt.dev/
|
||||
>
|
||||
> Join the community: [Twitter](https://twitter.com/ZylonPrivateGPT) & [Discord](https://discord.gg/bK6mRVpErU)
|
||||
|
||||

|
||||
|
||||
PrivateGPT is a production-ready AI project that allows you to ask questions about your documents using the power
|
||||
of Large Language Models (LLMs), even in scenarios without an Internet connection. 100% private, no data leaves your
|
||||
execution environment at any point.
|
||||
|
||||
>[!TIP]
|
||||
> If you are looking for an **enterprise-ready, fully private AI workspace**
|
||||
> check out [Zylon's website](https://zylon.ai) or [request a demo](https://cal.com/zylon/demo?source=pgpt-readme).
|
||||
> Crafted by the team behind PrivateGPT, Zylon is a best-in-class AI collaborative
|
||||
> workspace that can be easily deployed on-premise (data center, bare metal...) or in your private cloud (AWS, GCP, Azure...).
|
||||
|
||||
The project provides an API offering all the primitives required to build private, context-aware AI applications.
|
||||
It follows and extends the [OpenAI API standard](https://openai.com/blog/openai-api),
|
||||
and supports both normal and streaming responses.
|
||||
|
|
@ -38,13 +38,10 @@ In addition to this, a working [Gradio UI](https://www.gradio.app/)
|
|||
client is provided to test the API, together with a set of useful tools such as bulk model
|
||||
download script, ingestion script, documents folder watch, etc.
|
||||
|
||||
> 👂 **Need help applying PrivateGPT to your specific use case?**
|
||||
> [Let us know more about it](https://forms.gle/4cSDmH13RZBHV9at7)
|
||||
> and we'll try to help! We are refining PrivateGPT through your feedback.
|
||||
|
||||
## 🎞️ Overview
|
||||
DISCLAIMER: This README is not updated as frequently as the [documentation](https://docs.privategpt.dev/).
|
||||
Please check it out for the latest updates!
|
||||
>[!WARNING]
|
||||
> This README is not updated as frequently as the [documentation](https://docs.privategpt.dev/).
|
||||
> Please check it out for the latest updates!
|
||||
|
||||
### Motivation behind PrivateGPT
|
||||
Generative AI is a game changer for our society, but adoption in companies of all sizes and data-sensitive
|
||||
|
|
|
|||
|
|
@ -5,12 +5,15 @@ services:
|
|||
volumes:
|
||||
- ./local_data/:/home/worker/app/local_data
|
||||
ports:
|
||||
- 8001:8080
|
||||
- 8001:8001
|
||||
environment:
|
||||
PORT: 8080
|
||||
PORT: 8001
|
||||
PGPT_PROFILES: docker
|
||||
PGPT_MODE: ollama
|
||||
PGPT_EMBED_MODE: ollama
|
||||
ollama:
|
||||
image: ollama/ollama:latest
|
||||
ports:
|
||||
- 11434:11434
|
||||
volumes:
|
||||
- ./models:/root/.ollama
|
||||
|
|
|
|||
|
|
@ -74,14 +74,16 @@ navigation:
|
|||
path: ./docs/pages/ui/gradio.mdx
|
||||
- page: Alternatives
|
||||
path: ./docs/pages/ui/alternatives.mdx
|
||||
# Small code snippet or example of usage to help users
|
||||
- tab: recipes
|
||||
layout:
|
||||
- section: Choice of LLM
|
||||
- section: Getting started
|
||||
contents:
|
||||
# TODO: add recipes
|
||||
- page: List of LLMs
|
||||
path: ./docs/pages/recipes/list-llm.mdx
|
||||
- page: Quickstart
|
||||
path: ./docs/pages/recipes/quickstart.mdx
|
||||
- section: General use cases
|
||||
contents:
|
||||
- page: Summarize
|
||||
path: ./docs/pages/recipes/summarize.mdx
|
||||
# More advanced usage of PrivateGPT, by API
|
||||
- tab: api-reference
|
||||
layout:
|
||||
|
|
|
|||
Binary file not shown.
|
Before Width: | Height: | Size: 212 KiB After Width: | Height: | Size: 154 KiB |
|
|
@ -28,6 +28,11 @@ pyenv local 3.11
|
|||
Install [Poetry](https://python-poetry.org/docs/#installing-with-the-official-installer) for dependency management:
|
||||
Follow the instructions on the official Poetry website to install it.
|
||||
|
||||
<Callout intent="warning">
|
||||
A bug exists in Poetry versions 1.7.0 and earlier. We strongly recommend upgrading to a tested version.
|
||||
To upgrade Poetry to latest tested version, run `poetry self update 1.8.3` after installing it.
|
||||
</Callout>
|
||||
|
||||
### 4. Optional: Install `make`
|
||||
To run various scripts, you need to install `make`. Follow the instructions for your operating system:
|
||||
#### macOS
|
||||
|
|
@ -130,18 +135,22 @@ Go to [ollama.ai](https://ollama.ai/) and follow the instructions to install Oll
|
|||
|
||||
After the installation, make sure the Ollama desktop app is closed.
|
||||
|
||||
Install the models to be used, the default settings-ollama.yaml is configured to user `mistral 7b` LLM (~4GB) and `nomic-embed-text` Embeddings (~275MB). Therefore:
|
||||
|
||||
```bash
|
||||
ollama pull mistral
|
||||
ollama pull nomic-embed-text
|
||||
```
|
||||
|
||||
Now, start Ollama service (it will start a local inference server, serving both the LLM and the Embeddings):
|
||||
```bash
|
||||
ollama serve
|
||||
```
|
||||
|
||||
Install the models to be used, the default settings-ollama.yaml is configured to user llama3.1 8b LLM (~4GB) and nomic-embed-text Embeddings (~275MB)
|
||||
|
||||
By default, PGPT will automatically pull models as needed. This behavior can be changed by modifying the `ollama.autopull_models` property.
|
||||
|
||||
In any case, if you want to manually pull models, run the following commands:
|
||||
|
||||
```bash
|
||||
ollama pull llama3.1
|
||||
ollama pull nomic-embed-text
|
||||
```
|
||||
|
||||
Once done, on a different terminal, you can install PrivateGPT with the following command:
|
||||
```bash
|
||||
poetry install --extras "ui llms-ollama embeddings-ollama vector-stores-qdrant"
|
||||
|
|
|
|||
|
|
@ -24,8 +24,26 @@ PrivateGPT uses the `AutoTokenizer` library to tokenize input text accurately. I
|
|||
In your `settings.yaml` file, specify the model you want to use:
|
||||
```yaml
|
||||
llm:
|
||||
tokenizer: mistralai/Mistral-7B-Instruct-v0.2
|
||||
tokenizer: meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
```
|
||||
2. **Set Access Token for Gated Models:**
|
||||
If you are using a gated model, ensure the `access_token` is set as mentioned in the previous section.
|
||||
This configuration ensures that PrivateGPT can download and use the correct tokenizer for the model you are working with.
|
||||
This configuration ensures that PrivateGPT can download and use the correct tokenizer for the model you are working with.
|
||||
|
||||
# Embedding dimensions mismatch
|
||||
If you encounter an error message like `Embedding dimensions mismatch`, it is likely due to the embedding model and
|
||||
current vector dimension mismatch. To resolve this issue, ensure that the model and the input data have the same vector dimensions.
|
||||
|
||||
By default, PrivateGPT uses `nomic-embed-text` embeddings, which have a vector dimension of 768.
|
||||
If you are using a different embedding model, ensure that the vector dimensions match the model's output.
|
||||
|
||||
<Callout intent = "warning">
|
||||
In versions below to 0.6.0, the default embedding model was `BAAI/bge-small-en-v1.5` in `huggingface` setup.
|
||||
If you plan to reuse the old generated embeddings, you need to update the `settings.yaml` file to use the correct embedding model:
|
||||
```yaml
|
||||
huggingface:
|
||||
embedding_hf_model_name: BAAI/bge-small-en-v1.5
|
||||
embedding:
|
||||
embed_dim: 384
|
||||
```
|
||||
</Callout>
|
||||
|
|
@ -8,6 +8,14 @@ The ingestion of documents can be done in different ways:
|
|||
|
||||
## Bulk Local Ingestion
|
||||
|
||||
You will need to activate `data.local_ingestion.enabled` in your setting file to use this feature. Additionally,
|
||||
it is probably a good idea to set `data.local_ingestion.allow_ingest_from` to specify which folders are allowed to be ingested.
|
||||
|
||||
<Callout intent = "warning">
|
||||
Be careful enabling this feature in a production environment, as it can be a security risk, as it allows users to
|
||||
ingest any local file with permissions.
|
||||
</Callout>
|
||||
|
||||
When you are running PrivateGPT in a fully local setup, you can ingest a complete folder for convenience (containing
|
||||
pdf, text files, etc.)
|
||||
and optionally watch changes on it with the command:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,13 @@
|
|||
PrivateGPT provides an **API** containing all the building blocks required to
|
||||
build **private, context-aware AI applications**.
|
||||
|
||||
<Callout intent = "tip">
|
||||
If you are looking for an **enterprise-ready, fully private AI workspace**
|
||||
check out [Zylon's website](https://zylon.ai) or [request a demo](https://cal.com/zylon/demo?source=pgpt-docs).
|
||||
Crafted by the team behind PrivateGPT, Zylon is a best-in-class AI collaborative
|
||||
workspace that can be easily deployed on-premise (data center, bare metal...) or in your private cloud (AWS, GCP, Azure...).
|
||||
</Callout>
|
||||
|
||||
The API follows and extends OpenAI API standard, and supports both normal and streaming responses.
|
||||
That means that, if you can use OpenAI API in one of your tools, you can use your own PrivateGPT API instead,
|
||||
with no code changes, **and for free** if you are running PrivateGPT in a `local` setup.
|
||||
|
|
|
|||
|
|
@ -1,122 +0,0 @@
|
|||
# List of working LLM
|
||||
|
||||
**Do you have any working combination of LLM and embeddings?**
|
||||
|
||||
Please open a PR to add it to the list, and come on our Discord to tell us about it!
|
||||
|
||||
## Prompt style
|
||||
|
||||
LLMs might have been trained with different prompt styles.
|
||||
The prompt style is the way the prompt is written, and how the system message is injected in the prompt.
|
||||
|
||||
For example, `llama2` looks like this:
|
||||
```text
|
||||
<s>[INST] <<SYS>>
|
||||
{{ system_prompt }}
|
||||
<</SYS>>
|
||||
|
||||
{{ user_message }} [/INST]
|
||||
```
|
||||
|
||||
While `default` (the `llama_index` default) looks like this:
|
||||
```text
|
||||
system: {{ system_prompt }}
|
||||
user: {{ user_message }}
|
||||
assistant: {{ assistant_message }}
|
||||
```
|
||||
|
||||
The "`tag`" style looks like this:
|
||||
|
||||
```text
|
||||
<|system|>: {{ system_prompt }}
|
||||
<|user|>: {{ user_message }}
|
||||
<|assistant|>: {{ assistant_message }}
|
||||
```
|
||||
|
||||
The "`mistral`" style looks like this:
|
||||
|
||||
```text
|
||||
<s>[INST] You are an AI assistant. [/INST]</s>[INST] Hello, how are you doing? [/INST]
|
||||
```
|
||||
|
||||
The "`chatml`" style looks like this:
|
||||
```text
|
||||
<|im_start|>system
|
||||
{{ system_prompt }}<|im_end|>
|
||||
<|im_start|>user"
|
||||
{{ user_message }}<|im_end|>
|
||||
<|im_start|>assistant
|
||||
{{ assistant_message }}
|
||||
```
|
||||
|
||||
Some LLMs will not understand these prompt styles, and will not work (returning nothing).
|
||||
You can try to change the prompt style to `default` (or `tag`) in the settings, and it will
|
||||
change the way the messages are formatted to be passed to the LLM.
|
||||
|
||||
## Example of configuration
|
||||
|
||||
You might want to change the prompt depending on the language and model you are using.
|
||||
|
||||
### English, with instructions
|
||||
|
||||
`settings-en.yaml`:
|
||||
```yml
|
||||
local:
|
||||
llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.1-GGUF
|
||||
llm_hf_model_file: mistral-7b-instruct-v0.1.Q4_K_M.gguf
|
||||
embedding_hf_model_name: BAAI/bge-small-en-v1.5
|
||||
prompt_style: "llama2"
|
||||
```
|
||||
|
||||
### French, with instructions
|
||||
|
||||
`settings-fr.yaml`:
|
||||
```yml
|
||||
local:
|
||||
llm_hf_repo_id: TheBloke/Vigogne-2-7B-Instruct-GGUF
|
||||
llm_hf_model_file: vigogne-2-7b-instruct.Q4_K_M.gguf
|
||||
embedding_hf_model_name: dangvantuan/sentence-camembert-base
|
||||
prompt_style: "default"
|
||||
# prompt_style: "tag" # also works
|
||||
# The default system prompt is injected only when the `prompt_style` != default, and there are no system message in the discussion
|
||||
# default_system_prompt: Vous êtes un assistant IA qui répond à la question posée à la fin en utilisant le contexte suivant. Si vous ne connaissez pas la réponse, dites simplement que vous ne savez pas, n'essayez pas d'inventer une réponse. Veuillez répondre exclusivement en français.
|
||||
```
|
||||
|
||||
You might want to change the prompt as the one above might not directly answer your question.
|
||||
You can read online about how to write a good prompt, but in a nutshell, make it (extremely) directive.
|
||||
|
||||
You can try and troubleshot your prompt by writing multiline requests in the UI, while
|
||||
writing your interaction with the model, for example:
|
||||
|
||||
```text
|
||||
Tu es un programmeur senior qui programme en python et utilise le framework fastapi. Ecrit moi un serveur qui retourne "hello world".
|
||||
```
|
||||
|
||||
Another example:
|
||||
```text
|
||||
Context: None
|
||||
Situation: tu es au milieu d'un champ.
|
||||
Tache: va a la rivière, en bas du champ.
|
||||
Décrit comment aller a la rivière.
|
||||
```
|
||||
|
||||
### Optimised Models
|
||||
GodziLLa2-70B LLM (English, rank 2 on HuggingFace OpenLLM Leaderboard), bge large Embedding Model (rank 1 on HuggingFace MTEB Leaderboard)
|
||||
`settings-optimised.yaml`:
|
||||
```yml
|
||||
local:
|
||||
llm_hf_repo_id: TheBloke/GodziLLa2-70B-GGUF
|
||||
llm_hf_model_file: godzilla2-70b.Q4_K_M.gguf
|
||||
embedding_hf_model_name: BAAI/bge-large-en
|
||||
prompt_style: "llama2"
|
||||
```
|
||||
### German speaking model
|
||||
`settings-de.yaml`:
|
||||
```yml
|
||||
local:
|
||||
llm_hf_repo_id: TheBloke/em_german_leo_mistral-GGUF
|
||||
llm_hf_model_file: em_german_leo_mistral.Q4_K_M.gguf
|
||||
embedding_hf_model_name: T-Systems-onsite/german-roberta-sentence-transformer-v2
|
||||
#llama, default or tag
|
||||
prompt_style: "default"
|
||||
```
|
||||
23
fern/docs/pages/recipes/quickstart.mdx
Normal file
23
fern/docs/pages/recipes/quickstart.mdx
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# Recipes
|
||||
|
||||
Recipes are predefined use cases that help users solve very specific tasks using PrivateGPT.
|
||||
They provide a streamlined approach to achieve common goals with the platform, offering both a starting point and inspiration for further exploration.
|
||||
The main goal of Recipes is to empower the community to create and share solutions, expanding the capabilities of PrivateGPT.
|
||||
|
||||
## How to Create a New Recipe
|
||||
|
||||
1. **Identify the Task**: Define a specific task or problem that the Recipe will address.
|
||||
2. **Develop the Solution**: Create a clear and concise guide, including any necessary code snippets or configurations.
|
||||
3. **Submit a PR**: Fork the PrivateGPT repository, add your Recipe to the appropriate section, and submit a PR for review.
|
||||
|
||||
We encourage you to be creative and think outside the box! Your contributions help shape the future of PrivateGPT.
|
||||
|
||||
## Available Recipes
|
||||
|
||||
<Cards>
|
||||
<Card
|
||||
title="Summarize"
|
||||
icon="fa-solid fa-file-alt"
|
||||
href="/recipes/general-use-cases/summarize"
|
||||
/>
|
||||
</Cards>
|
||||
20
fern/docs/pages/recipes/summarize.mdx
Normal file
20
fern/docs/pages/recipes/summarize.mdx
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
The Summarize Recipe provides a method to extract concise summaries from ingested documents or texts using PrivateGPT.
|
||||
This tool is particularly useful for quickly understanding large volumes of information by distilling key points and main ideas.
|
||||
|
||||
## Use Case
|
||||
|
||||
The primary use case for the `Summarize` tool is to automate the summarization of lengthy documents,
|
||||
making it easier for users to grasp the essential information without reading through entire texts.
|
||||
This can be applied in various scenarios, such as summarizing research papers, news articles, or business reports.
|
||||
|
||||
## Key Features
|
||||
|
||||
1. **Ingestion-compatible**: The user provides the text to be summarized. The text can be directly inputted or retrieved from ingested documents within the system.
|
||||
2. **Customization**: The summary generation can be influenced by providing specific `instructions` or a `prompt`. These inputs guide the model on how to frame the summary, allowing for customization according to user needs.
|
||||
3. **Streaming Support**: The tool supports streaming, allowing for real-time summary generation, which can be particularly useful for handling large texts or providing immediate feedback.
|
||||
|
||||
## Contributing
|
||||
|
||||
If you have ideas for improving the Summarize or want to add new features, feel free to contribute!
|
||||
You can submit your enhancements via a pull request on our [GitHub repository](https://github.com/zylon-ai/private-gpt).
|
||||
|
||||
|
|
@ -339,6 +339,48 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"/v1/summarize": {
|
||||
"post": {
|
||||
"tags": [
|
||||
"Recipes"
|
||||
],
|
||||
"summary": "Summarize",
|
||||
"description": "Given a text, the model will return a summary.\n\nOptionally include `instructions` to influence the way the summary is generated.\n\nIf `use_context`\nis set to `true`, the model will also use the content coming from the ingested\ndocuments in the summary. The documents being used can\nbe filtered by their metadata using the `context_filter`.\nIngested documents metadata can be found using `/ingest/list` endpoint.\nIf you want all ingested documents to be used, remove `context_filter` altogether.\n\nIf `prompt` is set, it will be used as the prompt for the summarization,\notherwise the default prompt will be used.\n\nWhen using `'stream': true`, the API will return data chunks following [OpenAI's\nstreaming model](https://platform.openai.com/docs/api-reference/chat/streaming):\n```\n{\"id\":\"12345\",\"object\":\"completion.chunk\",\"created\":1694268190,\n\"model\":\"private-gpt\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\n\"finish_reason\":null}]}\n```",
|
||||
"operationId": "summarize_v1_summarize_post",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SummarizeBody"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SummarizeResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/embeddings": {
|
||||
"post": {
|
||||
"tags": [
|
||||
|
|
@ -500,6 +542,10 @@
|
|||
"Chunk": {
|
||||
"properties": {
|
||||
"object": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"context.chunk"
|
||||
],
|
||||
"const": "context.chunk",
|
||||
"title": "Object"
|
||||
},
|
||||
|
|
@ -612,10 +658,18 @@
|
|||
"ChunksResponse": {
|
||||
"properties": {
|
||||
"object": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"list"
|
||||
],
|
||||
"const": "list",
|
||||
"title": "Object"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"private-gpt"
|
||||
],
|
||||
"const": "private-gpt",
|
||||
"title": "Model"
|
||||
},
|
||||
|
|
@ -728,6 +782,10 @@
|
|||
"title": "Index"
|
||||
},
|
||||
"object": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"embedding"
|
||||
],
|
||||
"const": "embedding",
|
||||
"title": "Object"
|
||||
},
|
||||
|
|
@ -779,10 +837,18 @@
|
|||
"EmbeddingsResponse": {
|
||||
"properties": {
|
||||
"object": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"list"
|
||||
],
|
||||
"const": "list",
|
||||
"title": "Object"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"private-gpt"
|
||||
],
|
||||
"const": "private-gpt",
|
||||
"title": "Model"
|
||||
},
|
||||
|
|
@ -818,6 +884,10 @@
|
|||
"HealthResponse": {
|
||||
"properties": {
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"ok"
|
||||
],
|
||||
"const": "ok",
|
||||
"title": "Status",
|
||||
"default": "ok"
|
||||
|
|
@ -829,10 +899,18 @@
|
|||
"IngestResponse": {
|
||||
"properties": {
|
||||
"object": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"list"
|
||||
],
|
||||
"const": "list",
|
||||
"title": "Object"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"private-gpt"
|
||||
],
|
||||
"const": "private-gpt",
|
||||
"title": "Model"
|
||||
},
|
||||
|
|
@ -879,6 +957,10 @@
|
|||
"IngestedDoc": {
|
||||
"properties": {
|
||||
"object": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"ingest.document"
|
||||
],
|
||||
"const": "ingest.document",
|
||||
"title": "Object"
|
||||
},
|
||||
|
|
@ -1001,6 +1083,10 @@
|
|||
]
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"private-gpt"
|
||||
],
|
||||
"const": "private-gpt",
|
||||
"title": "Model"
|
||||
},
|
||||
|
|
@ -1074,6 +1160,78 @@
|
|||
"title": "OpenAIMessage",
|
||||
"description": "Inference result, with the source of the message.\n\nRole could be the assistant or system\n(providing a default response, not AI generated)."
|
||||
},
|
||||
"SummarizeBody": {
|
||||
"properties": {
|
||||
"text": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Text"
|
||||
},
|
||||
"use_context": {
|
||||
"type": "boolean",
|
||||
"title": "Use Context",
|
||||
"default": false
|
||||
},
|
||||
"context_filter": {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/ContextFilter"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
]
|
||||
},
|
||||
"prompt": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Prompt"
|
||||
},
|
||||
"instructions": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Instructions"
|
||||
},
|
||||
"stream": {
|
||||
"type": "boolean",
|
||||
"title": "Stream",
|
||||
"default": false
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"title": "SummarizeBody"
|
||||
},
|
||||
"SummarizeResponse": {
|
||||
"properties": {
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"title": "Summary"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"summary"
|
||||
],
|
||||
"title": "SummarizeResponse"
|
||||
},
|
||||
"ValidationError": {
|
||||
"properties": {
|
||||
"loc": {
|
||||
|
|
|
|||
4640
poetry.lock
generated
4640
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -31,6 +31,7 @@ class EmbeddingComponent:
|
|||
self.embedding_model = HuggingFaceEmbedding(
|
||||
model_name=settings.huggingface.embedding_hf_model_name,
|
||||
cache_folder=str(models_cache_path),
|
||||
trust_remote_code=settings.huggingface.trust_remote_code,
|
||||
)
|
||||
case "sagemaker":
|
||||
try:
|
||||
|
|
@ -71,16 +72,46 @@ class EmbeddingComponent:
|
|||
from llama_index.embeddings.ollama import ( # type: ignore
|
||||
OllamaEmbedding,
|
||||
)
|
||||
from ollama import Client # type: ignore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Local dependencies not found, install with `poetry install --extras embeddings-ollama`"
|
||||
) from e
|
||||
|
||||
ollama_settings = settings.ollama
|
||||
|
||||
# Calculate embedding model. If not provided tag, it will be use latest
|
||||
model_name = (
|
||||
ollama_settings.embedding_model + ":latest"
|
||||
if ":" not in ollama_settings.embedding_model
|
||||
else ollama_settings.embedding_model
|
||||
)
|
||||
|
||||
self.embedding_model = OllamaEmbedding(
|
||||
model_name=ollama_settings.embedding_model,
|
||||
model_name=model_name,
|
||||
base_url=ollama_settings.embedding_api_base,
|
||||
)
|
||||
|
||||
if ollama_settings.autopull_models:
|
||||
if ollama_settings.autopull_models:
|
||||
from private_gpt.utils.ollama import (
|
||||
check_connection,
|
||||
pull_model,
|
||||
)
|
||||
|
||||
# TODO: Reuse llama-index client when llama-index is updated
|
||||
client = Client(
|
||||
host=ollama_settings.embedding_api_base,
|
||||
timeout=ollama_settings.request_timeout,
|
||||
)
|
||||
|
||||
if not check_connection(client):
|
||||
raise ValueError(
|
||||
f"Failed to connect to Ollama, "
|
||||
f"check if Ollama server is running on {ollama_settings.api_base}"
|
||||
)
|
||||
pull_model(client, model_name)
|
||||
|
||||
case "azopenai":
|
||||
try:
|
||||
from llama_index.embeddings.azure_openai import ( # type: ignore
|
||||
|
|
|
|||
|
|
@ -146,8 +146,15 @@ class LLMComponent:
|
|||
"repeat_penalty": ollama_settings.repeat_penalty, # ollama llama-cpp
|
||||
}
|
||||
|
||||
self.llm = Ollama(
|
||||
model=ollama_settings.llm_model,
|
||||
# calculate llm model. If not provided tag, it will be use latest
|
||||
model_name = (
|
||||
ollama_settings.llm_model + ":latest"
|
||||
if ":" not in ollama_settings.llm_model
|
||||
else ollama_settings.llm_model
|
||||
)
|
||||
|
||||
llm = Ollama(
|
||||
model=model_name,
|
||||
base_url=ollama_settings.api_base,
|
||||
temperature=settings.llm.temperature,
|
||||
context_window=settings.llm.context_window,
|
||||
|
|
@ -155,6 +162,16 @@ class LLMComponent:
|
|||
request_timeout=ollama_settings.request_timeout,
|
||||
)
|
||||
|
||||
if ollama_settings.autopull_models:
|
||||
from private_gpt.utils.ollama import check_connection, pull_model
|
||||
|
||||
if not check_connection(llm.client):
|
||||
raise ValueError(
|
||||
f"Failed to connect to Ollama, "
|
||||
f"check if Ollama server is running on {ollama_settings.api_base}"
|
||||
)
|
||||
pull_model(llm.client, model_name)
|
||||
|
||||
if (
|
||||
ollama_settings.keep_alive
|
||||
!= ollama_settings.model_fields["keep_alive"].default
|
||||
|
|
@ -172,6 +189,8 @@ class LLMComponent:
|
|||
Ollama.complete = add_keep_alive(Ollama.complete)
|
||||
Ollama.stream_complete = add_keep_alive(Ollama.stream_complete)
|
||||
|
||||
self.llm = llm
|
||||
|
||||
case "azopenai":
|
||||
try:
|
||||
from llama_index.llms.azure_openai import ( # type: ignore
|
||||
|
|
|
|||
|
|
@ -138,6 +138,72 @@ class Llama2PromptStyle(AbstractPromptStyle):
|
|||
)
|
||||
|
||||
|
||||
class Llama3PromptStyle(AbstractPromptStyle):
|
||||
r"""Template for Meta's Llama 3.1.
|
||||
|
||||
The format follows this structure:
|
||||
<|begin_of_text|>
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
[System message content]<|eot_id|>
|
||||
<|start_header_id|>user<|end_header_id|>
|
||||
|
||||
[User message content]<|eot_id|>
|
||||
<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
[Assistant message content]<|eot_id|>
|
||||
...
|
||||
(Repeat for each message, including possible 'ipython' role)
|
||||
"""
|
||||
|
||||
BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>"
|
||||
B_INST, E_INST = "<|start_header_id|>", "<|end_header_id|>"
|
||||
EOT = "<|eot_id|>"
|
||||
B_SYS, E_SYS = "<|start_header_id|>system<|end_header_id|>", "<|eot_id|>"
|
||||
ASSISTANT_INST = "<|start_header_id|>assistant<|end_header_id|>"
|
||||
DEFAULT_SYSTEM_PROMPT = """\
|
||||
You are a helpful, respectful and honest assistant. \
|
||||
Always answer as helpfully as possible and follow ALL given instructions. \
|
||||
Do not speculate or make up information. \
|
||||
Do not reference any given instructions or context. \
|
||||
"""
|
||||
|
||||
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
|
||||
prompt = ""
|
||||
has_system_message = False
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
if not message or message.content is None:
|
||||
continue
|
||||
if message.role == MessageRole.SYSTEM:
|
||||
prompt += f"{self.B_SYS}\n\n{message.content.strip()}{self.E_SYS}"
|
||||
has_system_message = True
|
||||
else:
|
||||
role_header = f"{self.B_INST}{message.role.value}{self.E_INST}"
|
||||
prompt += f"{role_header}\n\n{message.content.strip()}{self.EOT}"
|
||||
|
||||
# Add assistant header if the last message is not from the assistant
|
||||
if i == len(messages) - 1 and message.role != MessageRole.ASSISTANT:
|
||||
prompt += f"{self.ASSISTANT_INST}\n\n"
|
||||
|
||||
# Add default system prompt if no system message was provided
|
||||
if not has_system_message:
|
||||
prompt = (
|
||||
f"{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}" + prompt
|
||||
)
|
||||
|
||||
# TODO: Implement tool handling logic
|
||||
|
||||
return prompt
|
||||
|
||||
def _completion_to_prompt(self, completion: str) -> str:
|
||||
return (
|
||||
f"{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
|
||||
f"{self.B_INST}user{self.E_INST}\n\n{completion.strip()}{self.EOT}"
|
||||
f"{self.ASSISTANT_INST}\n\n"
|
||||
)
|
||||
|
||||
|
||||
class TagPromptStyle(AbstractPromptStyle):
|
||||
"""Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.
|
||||
|
||||
|
|
@ -219,7 +285,8 @@ class ChatMLPromptStyle(AbstractPromptStyle):
|
|||
|
||||
|
||||
def get_prompt_style(
|
||||
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
|
||||
prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"]
|
||||
| None
|
||||
) -> AbstractPromptStyle:
|
||||
"""Get the prompt style to use from the given string.
|
||||
|
||||
|
|
@ -230,6 +297,8 @@ def get_prompt_style(
|
|||
return DefaultPromptStyle()
|
||||
elif prompt_style == "llama2":
|
||||
return Llama2PromptStyle()
|
||||
elif prompt_style == "llama3":
|
||||
return Llama3PromptStyle()
|
||||
elif prompt_style == "tag":
|
||||
return TagPromptStyle()
|
||||
elif prompt_style == "mistral":
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from private_gpt.server.completions.completions_router import completions_router
|
|||
from private_gpt.server.embeddings.embeddings_router import embeddings_router
|
||||
from private_gpt.server.health.health_router import health_router
|
||||
from private_gpt.server.ingest.ingest_router import ingest_router
|
||||
from private_gpt.server.recipes.summarize.summarize_router import summarize_router
|
||||
from private_gpt.settings.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -32,12 +33,14 @@ def create_app(root_injector: Injector) -> FastAPI:
|
|||
app.include_router(chat_router)
|
||||
app.include_router(chunks_router)
|
||||
app.include_router(ingest_router)
|
||||
app.include_router(summarize_router)
|
||||
app.include_router(embeddings_router)
|
||||
app.include_router(health_router)
|
||||
|
||||
# Add LlamaIndex simple observability
|
||||
global_handler = create_global_handler("simple")
|
||||
LlamaIndexSettings.callback_manager = CallbackManager([global_handler])
|
||||
if global_handler:
|
||||
LlamaIndexSettings.callback_manager = CallbackManager([global_handler])
|
||||
|
||||
settings = root_injector.get(Settings)
|
||||
if settings.server.cors.enabled:
|
||||
|
|
|
|||
0
private_gpt/server/recipes/summarize/__init__.py
Normal file
0
private_gpt/server/recipes/summarize/__init__.py
Normal file
86
private_gpt/server/recipes/summarize/summarize_router.py
Normal file
86
private_gpt/server/recipes/summarize/summarize_router.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from private_gpt.open_ai.extensions.context_filter import ContextFilter
|
||||
from private_gpt.open_ai.openai_models import (
|
||||
to_openai_sse_stream,
|
||||
)
|
||||
from private_gpt.server.recipes.summarize.summarize_service import SummarizeService
|
||||
from private_gpt.server.utils.auth import authenticated
|
||||
|
||||
summarize_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
|
||||
|
||||
|
||||
class SummarizeBody(BaseModel):
|
||||
text: str | None = None
|
||||
use_context: bool = False
|
||||
context_filter: ContextFilter | None = None
|
||||
prompt: str | None = None
|
||||
instructions: str | None = None
|
||||
stream: bool = False
|
||||
|
||||
|
||||
class SummarizeResponse(BaseModel):
|
||||
summary: str
|
||||
|
||||
|
||||
@summarize_router.post(
|
||||
"/summarize",
|
||||
response_model=None,
|
||||
summary="Summarize",
|
||||
responses={200: {"model": SummarizeResponse}},
|
||||
tags=["Recipes"],
|
||||
)
|
||||
def summarize(
|
||||
request: Request, body: SummarizeBody
|
||||
) -> SummarizeResponse | StreamingResponse:
|
||||
"""Given a text, the model will return a summary.
|
||||
|
||||
Optionally include `instructions` to influence the way the summary is generated.
|
||||
|
||||
If `use_context`
|
||||
is set to `true`, the model will also use the content coming from the ingested
|
||||
documents in the summary. The documents being used can
|
||||
be filtered by their metadata using the `context_filter`.
|
||||
Ingested documents metadata can be found using `/ingest/list` endpoint.
|
||||
If you want all ingested documents to be used, remove `context_filter` altogether.
|
||||
|
||||
If `prompt` is set, it will be used as the prompt for the summarization,
|
||||
otherwise the default prompt will be used.
|
||||
|
||||
When using `'stream': true`, the API will return data chunks following [OpenAI's
|
||||
streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
|
||||
```
|
||||
{"id":"12345","object":"completion.chunk","created":1694268190,
|
||||
"model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
|
||||
"finish_reason":null}]}
|
||||
```
|
||||
"""
|
||||
service: SummarizeService = request.state.injector.get(SummarizeService)
|
||||
|
||||
if body.stream:
|
||||
completion_gen = service.stream_summarize(
|
||||
text=body.text,
|
||||
instructions=body.instructions,
|
||||
use_context=body.use_context,
|
||||
context_filter=body.context_filter,
|
||||
prompt=body.prompt,
|
||||
)
|
||||
return StreamingResponse(
|
||||
to_openai_sse_stream(
|
||||
response_generator=completion_gen,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
completion = service.summarize(
|
||||
text=body.text,
|
||||
instructions=body.instructions,
|
||||
use_context=body.use_context,
|
||||
context_filter=body.context_filter,
|
||||
prompt=body.prompt,
|
||||
)
|
||||
return SummarizeResponse(
|
||||
summary=completion,
|
||||
)
|
||||
172
private_gpt/server/recipes/summarize/summarize_service.py
Normal file
172
private_gpt/server/recipes/summarize/summarize_service.py
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
from itertools import chain
|
||||
|
||||
from injector import inject, singleton
|
||||
from llama_index.core import (
|
||||
Document,
|
||||
StorageContext,
|
||||
SummaryIndex,
|
||||
)
|
||||
from llama_index.core.base.response.schema import Response, StreamingResponse
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.response_synthesizers import ResponseMode
|
||||
from llama_index.core.storage.docstore.types import RefDocInfo
|
||||
from llama_index.core.types import TokenGen
|
||||
|
||||
from private_gpt.components.embedding.embedding_component import EmbeddingComponent
|
||||
from private_gpt.components.llm.llm_component import LLMComponent
|
||||
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
|
||||
from private_gpt.components.vector_store.vector_store_component import (
|
||||
VectorStoreComponent,
|
||||
)
|
||||
from private_gpt.open_ai.extensions.context_filter import ContextFilter
|
||||
from private_gpt.settings.settings import Settings
|
||||
|
||||
DEFAULT_SUMMARIZE_PROMPT = (
|
||||
"Provide a comprehensive summary of the provided context information. "
|
||||
"The summary should cover all the key points and main ideas presented in "
|
||||
"the original text, while also condensing the information into a concise "
|
||||
"and easy-to-understand format. Please ensure that the summary includes "
|
||||
"relevant details and examples that support the main ideas, while avoiding "
|
||||
"any unnecessary information or repetition."
|
||||
)
|
||||
|
||||
|
||||
@singleton
|
||||
class SummarizeService:
|
||||
@inject
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
llm_component: LLMComponent,
|
||||
node_store_component: NodeStoreComponent,
|
||||
vector_store_component: VectorStoreComponent,
|
||||
embedding_component: EmbeddingComponent,
|
||||
) -> None:
|
||||
self.settings = settings
|
||||
self.llm_component = llm_component
|
||||
self.node_store_component = node_store_component
|
||||
self.vector_store_component = vector_store_component
|
||||
self.embedding_component = embedding_component
|
||||
self.storage_context = StorageContext.from_defaults(
|
||||
vector_store=vector_store_component.vector_store,
|
||||
docstore=node_store_component.doc_store,
|
||||
index_store=node_store_component.index_store,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _filter_ref_docs(
|
||||
ref_docs: dict[str, RefDocInfo], context_filter: ContextFilter | None
|
||||
) -> list[RefDocInfo]:
|
||||
if context_filter is None or not context_filter.docs_ids:
|
||||
return list(ref_docs.values())
|
||||
|
||||
return [
|
||||
ref_doc
|
||||
for doc_id, ref_doc in ref_docs.items()
|
||||
if doc_id in context_filter.docs_ids
|
||||
]
|
||||
|
||||
def _summarize(
|
||||
self,
|
||||
use_context: bool = False,
|
||||
stream: bool = False,
|
||||
text: str | None = None,
|
||||
instructions: str | None = None,
|
||||
context_filter: ContextFilter | None = None,
|
||||
prompt: str | None = None,
|
||||
) -> str | TokenGen:
|
||||
|
||||
nodes_to_summarize = []
|
||||
|
||||
# Add text to summarize
|
||||
if text:
|
||||
text_documents = [Document(text=text)]
|
||||
nodes_to_summarize += (
|
||||
SentenceSplitter.from_defaults().get_nodes_from_documents(
|
||||
text_documents
|
||||
)
|
||||
)
|
||||
|
||||
# Add context documents to summarize
|
||||
if use_context:
|
||||
# 1. Recover all ref docs
|
||||
ref_docs: dict[
|
||||
str, RefDocInfo
|
||||
] | None = self.storage_context.docstore.get_all_ref_doc_info()
|
||||
if ref_docs is None:
|
||||
raise ValueError("No documents have been ingested yet.")
|
||||
|
||||
# 2. Filter documents based on context_filter (if provided)
|
||||
filtered_ref_docs = self._filter_ref_docs(ref_docs, context_filter)
|
||||
|
||||
# 3. Get all nodes from the filtered documents
|
||||
filtered_node_ids = chain.from_iterable(
|
||||
[ref_doc.node_ids for ref_doc in filtered_ref_docs]
|
||||
)
|
||||
filtered_nodes = self.storage_context.docstore.get_nodes(
|
||||
node_ids=list(filtered_node_ids),
|
||||
)
|
||||
|
||||
nodes_to_summarize += filtered_nodes
|
||||
|
||||
# Create a SummaryIndex to summarize the nodes
|
||||
summary_index = SummaryIndex(
|
||||
nodes=nodes_to_summarize,
|
||||
storage_context=StorageContext.from_defaults(), # In memory SummaryIndex
|
||||
show_progress=True,
|
||||
)
|
||||
|
||||
# Make a tree summarization query
|
||||
# above the set of all candidate nodes
|
||||
query_engine = summary_index.as_query_engine(
|
||||
llm=self.llm_component.llm,
|
||||
response_mode=ResponseMode.TREE_SUMMARIZE,
|
||||
streaming=stream,
|
||||
use_async=self.settings.summarize.use_async,
|
||||
)
|
||||
|
||||
prompt = prompt or DEFAULT_SUMMARIZE_PROMPT
|
||||
|
||||
summarize_query = prompt + "\n" + (instructions or "")
|
||||
|
||||
response = query_engine.query(summarize_query)
|
||||
if isinstance(response, Response):
|
||||
return response.response or ""
|
||||
elif isinstance(response, StreamingResponse):
|
||||
return response.response_gen
|
||||
else:
|
||||
raise TypeError(f"The result is not of a supported type: {type(response)}")
|
||||
|
||||
def summarize(
|
||||
self,
|
||||
use_context: bool = False,
|
||||
text: str | None = None,
|
||||
instructions: str | None = None,
|
||||
context_filter: ContextFilter | None = None,
|
||||
prompt: str | None = None,
|
||||
) -> str:
|
||||
return self._summarize(
|
||||
use_context=use_context,
|
||||
stream=False,
|
||||
text=text,
|
||||
instructions=instructions,
|
||||
context_filter=context_filter,
|
||||
prompt=prompt,
|
||||
) # type: ignore
|
||||
|
||||
def stream_summarize(
|
||||
self,
|
||||
use_context: bool = False,
|
||||
text: str | None = None,
|
||||
instructions: str | None = None,
|
||||
context_filter: ContextFilter | None = None,
|
||||
prompt: str | None = None,
|
||||
) -> TokenGen:
|
||||
return self._summarize(
|
||||
use_context=use_context,
|
||||
stream=True,
|
||||
text=text,
|
||||
instructions=instructions,
|
||||
context_filter=context_filter,
|
||||
prompt=prompt,
|
||||
) # type: ignore
|
||||
|
|
@ -59,6 +59,27 @@ class AuthSettings(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class IngestionSettings(BaseModel):
|
||||
"""Ingestion configuration.
|
||||
|
||||
This configuration is used to control the ingestion of data into the system
|
||||
using non-server methods. This is useful for local development and testing;
|
||||
or to ingest in bulk from a folder.
|
||||
|
||||
Please note that this configuration is not secure and should be used in
|
||||
a controlled environment only (setting right permissions, etc.).
|
||||
"""
|
||||
|
||||
enabled: bool = Field(
|
||||
description="Flag indicating if local ingestion is enabled or not.",
|
||||
default=False,
|
||||
)
|
||||
allow_ingest_from: list[str] = Field(
|
||||
description="A list of folders that should be permitted to make ingest requests.",
|
||||
default=[],
|
||||
)
|
||||
|
||||
|
||||
class ServerSettings(BaseModel):
|
||||
env_name: str = Field(
|
||||
description="Name of the environment (prod, staging, local...)"
|
||||
|
|
@ -74,6 +95,10 @@ class ServerSettings(BaseModel):
|
|||
|
||||
|
||||
class DataSettings(BaseModel):
|
||||
local_ingestion: IngestionSettings = Field(
|
||||
description="Ingestion configuration",
|
||||
default_factory=lambda: IngestionSettings(allow_ingest_from=["*"]),
|
||||
)
|
||||
local_data_folder: str = Field(
|
||||
description="Path to local storage."
|
||||
"It will be treated as an absolute path if it starts with /"
|
||||
|
|
@ -111,12 +136,15 @@ class LLMSettings(BaseModel):
|
|||
0.1,
|
||||
description="The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual.",
|
||||
)
|
||||
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] = Field(
|
||||
prompt_style: Literal[
|
||||
"default", "llama2", "llama3", "tag", "mistral", "chatml"
|
||||
] = Field(
|
||||
"llama2",
|
||||
description=(
|
||||
"The prompt style to use for the chat engine. "
|
||||
"If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
|
||||
"If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
|
||||
"If `llama3` - use the llama3 prompt style from the llama_index."
|
||||
"If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`. \n"
|
||||
"If `mistral` - use the `mistral prompt style. It shoudl look like <s>[INST] {System Prompt} [/INST]</s>[INST] { UserInstructions } [/INST]"
|
||||
"`llama2` is the historic behaviour. `default` might work better with your custom models."
|
||||
|
|
@ -161,6 +189,10 @@ class HuggingFaceSettings(BaseModel):
|
|||
None,
|
||||
description="Huggingface access token, required to download some models",
|
||||
)
|
||||
trust_remote_code: bool = Field(
|
||||
False,
|
||||
description="If set to True, the code from the remote model will be trusted and executed.",
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingSettings(BaseModel):
|
||||
|
|
@ -290,6 +322,10 @@ class OllamaSettings(BaseModel):
|
|||
120.0,
|
||||
description="Time elapsed until ollama times out the request. Default is 120s. Format is float. ",
|
||||
)
|
||||
autopull_models: bool = Field(
|
||||
False,
|
||||
description="If set to True, the Ollama will automatically pull the models from the API base.",
|
||||
)
|
||||
|
||||
|
||||
class AzureOpenAISettings(BaseModel):
|
||||
|
|
@ -321,6 +357,10 @@ class UISettings(BaseModel):
|
|||
default_query_system_prompt: str = Field(
|
||||
None, description="The default system prompt to use for the query mode."
|
||||
)
|
||||
default_summarization_system_prompt: str = Field(
|
||||
None,
|
||||
description="The default system prompt to use for the summarization mode.",
|
||||
)
|
||||
delete_file_button_enabled: bool = Field(
|
||||
True, description="If the button to delete a file is enabled or not."
|
||||
)
|
||||
|
|
@ -356,6 +396,13 @@ class RagSettings(BaseModel):
|
|||
rerank: RerankSettings
|
||||
|
||||
|
||||
class SummarizeSettings(BaseModel):
|
||||
use_async: bool = Field(
|
||||
True,
|
||||
description="If set to True, the summarization will be done asynchronously.",
|
||||
)
|
||||
|
||||
|
||||
class ClickHouseSettings(BaseModel):
|
||||
host: str = Field(
|
||||
"localhost",
|
||||
|
|
@ -545,6 +592,7 @@ class Settings(BaseModel):
|
|||
vectorstore: VectorstoreSettings
|
||||
nodestore: NodeStoreSettings
|
||||
rag: RagSettings
|
||||
summarize: SummarizeSettings
|
||||
qdrant: QdrantSettings | None = None
|
||||
postgres: PostgresSettings | None = None
|
||||
clickhouse: ClickHouseSettings | None = None
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
"""This file should be imported if and only if you want to run the UI locally."""
|
||||
|
||||
import itertools
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -12,6 +12,7 @@ from fastapi import FastAPI
|
|||
from gradio.themes.utils.colors import slate # type: ignore
|
||||
from injector import inject, singleton
|
||||
from llama_index.core.llms import ChatMessage, ChatResponse, MessageRole
|
||||
from llama_index.core.types import TokenGen
|
||||
from pydantic import BaseModel
|
||||
|
||||
from private_gpt.constants import PROJECT_ROOT_PATH
|
||||
|
|
@ -20,6 +21,7 @@ from private_gpt.open_ai.extensions.context_filter import ContextFilter
|
|||
from private_gpt.server.chat.chat_service import ChatService, CompletionGen
|
||||
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
|
||||
from private_gpt.server.ingest.ingest_service import IngestService
|
||||
from private_gpt.server.recipes.summarize.summarize_service import SummarizeService
|
||||
from private_gpt.settings.settings import settings
|
||||
from private_gpt.ui.images import logo_svg
|
||||
|
||||
|
|
@ -31,9 +33,22 @@ AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "avatar-bot.ico"
|
|||
|
||||
UI_TAB_TITLE = "My Private GPT"
|
||||
|
||||
SOURCES_SEPARATOR = "\n\n Sources: \n"
|
||||
SOURCES_SEPARATOR = "<hr>Sources: \n"
|
||||
|
||||
MODES = ["Query Files", "Search Files", "LLM Chat (no context from files)"]
|
||||
|
||||
class Modes(str, Enum):
|
||||
RAG_MODE = "RAG"
|
||||
SEARCH_MODE = "Search"
|
||||
BASIC_CHAT_MODE = "Basic"
|
||||
SUMMARIZE_MODE = "Summarize"
|
||||
|
||||
|
||||
MODES: list[Modes] = [
|
||||
Modes.RAG_MODE,
|
||||
Modes.SEARCH_MODE,
|
||||
Modes.BASIC_CHAT_MODE,
|
||||
Modes.SUMMARIZE_MODE,
|
||||
]
|
||||
|
||||
|
||||
class Source(BaseModel):
|
||||
|
|
@ -71,10 +86,12 @@ class PrivateGptUi:
|
|||
ingest_service: IngestService,
|
||||
chat_service: ChatService,
|
||||
chunks_service: ChunksService,
|
||||
summarizeService: SummarizeService,
|
||||
) -> None:
|
||||
self._ingest_service = ingest_service
|
||||
self._chat_service = chat_service
|
||||
self._chunks_service = chunks_service
|
||||
self._summarize_service = summarizeService
|
||||
|
||||
# Cache the UI blocks
|
||||
self._ui_block = None
|
||||
|
|
@ -85,7 +102,9 @@ class PrivateGptUi:
|
|||
self.mode = MODES[0]
|
||||
self._system_prompt = self._get_default_system_prompt(self.mode)
|
||||
|
||||
def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
|
||||
def _chat(
|
||||
self, message: str, history: list[list[str]], mode: Modes, *_: Any
|
||||
) -> Any:
|
||||
def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
|
||||
full_response: str = ""
|
||||
stream = completion_gen.response
|
||||
|
|
@ -109,25 +128,31 @@ class PrivateGptUi:
|
|||
+ f"{index}. {source.file} (page {source.page}) \n\n"
|
||||
)
|
||||
used_files.add(f"{source.file}-{source.page}")
|
||||
sources_text += "<hr>\n\n"
|
||||
full_response += sources_text
|
||||
yield full_response
|
||||
|
||||
def yield_tokens(token_gen: TokenGen) -> Iterable[str]:
|
||||
full_response: str = ""
|
||||
for token in token_gen:
|
||||
full_response += str(token)
|
||||
yield full_response
|
||||
|
||||
def build_history() -> list[ChatMessage]:
|
||||
history_messages: list[ChatMessage] = list(
|
||||
itertools.chain(
|
||||
*[
|
||||
[
|
||||
ChatMessage(content=interaction[0], role=MessageRole.USER),
|
||||
ChatMessage(
|
||||
# Remove from history content the Sources information
|
||||
content=interaction[1].split(SOURCES_SEPARATOR)[0],
|
||||
role=MessageRole.ASSISTANT,
|
||||
),
|
||||
]
|
||||
for interaction in history
|
||||
]
|
||||
history_messages: list[ChatMessage] = []
|
||||
|
||||
for interaction in history:
|
||||
history_messages.append(
|
||||
ChatMessage(content=interaction[0], role=MessageRole.USER)
|
||||
)
|
||||
)
|
||||
if len(interaction) > 1 and interaction[1] is not None:
|
||||
history_messages.append(
|
||||
ChatMessage(
|
||||
# Remove from history content the Sources information
|
||||
content=interaction[1].split(SOURCES_SEPARATOR)[0],
|
||||
role=MessageRole.ASSISTANT,
|
||||
)
|
||||
)
|
||||
|
||||
# max 20 messages to try to avoid context overflow
|
||||
return history_messages[:20]
|
||||
|
|
@ -144,8 +169,7 @@ class PrivateGptUi:
|
|||
),
|
||||
)
|
||||
match mode:
|
||||
case "Query Files":
|
||||
|
||||
case Modes.RAG_MODE:
|
||||
# Use only the selected file for the query
|
||||
context_filter = None
|
||||
if self._selected_filename is not None:
|
||||
|
|
@ -164,14 +188,14 @@ class PrivateGptUi:
|
|||
context_filter=context_filter,
|
||||
)
|
||||
yield from yield_deltas(query_stream)
|
||||
case "LLM Chat (no context from files)":
|
||||
case Modes.BASIC_CHAT_MODE:
|
||||
llm_stream = self._chat_service.stream_chat(
|
||||
messages=all_messages,
|
||||
use_context=False,
|
||||
)
|
||||
yield from yield_deltas(llm_stream)
|
||||
|
||||
case "Search Files":
|
||||
case Modes.SEARCH_MODE:
|
||||
response = self._chunks_service.retrieve_relevant(
|
||||
text=message, limit=4, prev_next_chunks=0
|
||||
)
|
||||
|
|
@ -184,37 +208,76 @@ class PrivateGptUi:
|
|||
f"{source.text}"
|
||||
for index, source in enumerate(sources, start=1)
|
||||
)
|
||||
case Modes.SUMMARIZE_MODE:
|
||||
# Summarize the given message, optionally using selected files
|
||||
context_filter = None
|
||||
if self._selected_filename:
|
||||
docs_ids = []
|
||||
for ingested_document in self._ingest_service.list_ingested():
|
||||
if (
|
||||
ingested_document.doc_metadata["file_name"]
|
||||
== self._selected_filename
|
||||
):
|
||||
docs_ids.append(ingested_document.doc_id)
|
||||
context_filter = ContextFilter(docs_ids=docs_ids)
|
||||
|
||||
summary_stream = self._summarize_service.stream_summarize(
|
||||
use_context=True,
|
||||
context_filter=context_filter,
|
||||
instructions=message,
|
||||
)
|
||||
yield from yield_tokens(summary_stream)
|
||||
|
||||
# On initialization and on mode change, this function set the system prompt
|
||||
# to the default prompt based on the mode (and user settings).
|
||||
@staticmethod
|
||||
def _get_default_system_prompt(mode: str) -> str:
|
||||
def _get_default_system_prompt(mode: Modes) -> str:
|
||||
p = ""
|
||||
match mode:
|
||||
# For query chat mode, obtain default system prompt from settings
|
||||
case "Query Files":
|
||||
case Modes.RAG_MODE:
|
||||
p = settings().ui.default_query_system_prompt
|
||||
# For chat mode, obtain default system prompt from settings
|
||||
case "LLM Chat (no context from files)":
|
||||
case Modes.BASIC_CHAT_MODE:
|
||||
p = settings().ui.default_chat_system_prompt
|
||||
# For summarization mode, obtain default system prompt from settings
|
||||
case Modes.SUMMARIZE_MODE:
|
||||
p = settings().ui.default_summarization_system_prompt
|
||||
# For any other mode, clear the system prompt
|
||||
case _:
|
||||
p = ""
|
||||
return p
|
||||
|
||||
@staticmethod
|
||||
def _get_default_mode_explanation(mode: Modes) -> str:
|
||||
match mode:
|
||||
case Modes.RAG_MODE:
|
||||
return "Get contextualized answers from selected files."
|
||||
case Modes.SEARCH_MODE:
|
||||
return "Find relevant chunks of text in selected files."
|
||||
case Modes.BASIC_CHAT_MODE:
|
||||
return "Chat with the LLM using its training data. Files are ignored."
|
||||
case Modes.SUMMARIZE_MODE:
|
||||
return "Generate a summary of the selected files. Prompt to customize the result."
|
||||
case _:
|
||||
return ""
|
||||
|
||||
def _set_system_prompt(self, system_prompt_input: str) -> None:
|
||||
logger.info(f"Setting system prompt to: {system_prompt_input}")
|
||||
self._system_prompt = system_prompt_input
|
||||
|
||||
def _set_current_mode(self, mode: str) -> Any:
|
||||
def _set_explanatation_mode(self, explanation_mode: str) -> None:
|
||||
self._explanation_mode = explanation_mode
|
||||
|
||||
def _set_current_mode(self, mode: Modes) -> Any:
|
||||
self.mode = mode
|
||||
self._set_system_prompt(self._get_default_system_prompt(mode))
|
||||
# Update placeholder and allow interaction if default system prompt is set
|
||||
if self._system_prompt:
|
||||
return gr.update(placeholder=self._system_prompt, interactive=True)
|
||||
# Update placeholder and disable interaction if no default system prompt is set
|
||||
else:
|
||||
return gr.update(placeholder=self._system_prompt, interactive=False)
|
||||
self._set_explanatation_mode(self._get_default_mode_explanation(mode))
|
||||
interactive = self._system_prompt is not None
|
||||
return [
|
||||
gr.update(placeholder=self._system_prompt, interactive=interactive),
|
||||
gr.update(value=self._explanation_mode),
|
||||
]
|
||||
|
||||
def _list_ingested_files(self) -> list[list[str]]:
|
||||
files = set()
|
||||
|
|
@ -314,17 +377,30 @@ class PrivateGptUi:
|
|||
".contain { display: flex !important; flex-direction: column !important; }"
|
||||
"#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
|
||||
"#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
|
||||
"#col { height: calc(100vh - 112px - 16px) !important; }",
|
||||
"#col { height: calc(100vh - 112px - 16px) !important; }"
|
||||
"hr { margin-top: 1em; margin-bottom: 1em; border: 0; border-top: 1px solid #FFF; }"
|
||||
".avatar-image { background-color: antiquewhite; border-radius: 2px; }"
|
||||
".footer { text-align: center; margin-top: 20px; font-size: 14px; display: flex; align-items: center; justify-content: center; }"
|
||||
".footer-zylon-link { display:flex; margin-left: 5px; text-decoration: auto; color: var(--body-text-color); }"
|
||||
".footer-zylon-link:hover { color: #C7BAFF; }"
|
||||
".footer-zylon-ico { height: 20px; margin-left: 5px; background-color: antiquewhite; border-radius: 2px; }",
|
||||
) as blocks:
|
||||
with gr.Row():
|
||||
gr.HTML(f"<div class='logo'/><img src={logo_svg} alt=PrivateGPT></div")
|
||||
|
||||
with gr.Row(equal_height=False):
|
||||
with gr.Column(scale=3):
|
||||
default_mode = MODES[0]
|
||||
mode = gr.Radio(
|
||||
MODES,
|
||||
[mode.value for mode in MODES],
|
||||
label="Mode",
|
||||
value="Query Files",
|
||||
value=default_mode,
|
||||
)
|
||||
explanation_mode = gr.Textbox(
|
||||
placeholder=self._get_default_mode_explanation(default_mode),
|
||||
show_label=False,
|
||||
max_lines=3,
|
||||
interactive=False,
|
||||
)
|
||||
upload_button = gr.components.UploadButton(
|
||||
"Upload File(s)",
|
||||
|
|
@ -408,9 +484,11 @@ class PrivateGptUi:
|
|||
interactive=True,
|
||||
render=False,
|
||||
)
|
||||
# When mode changes, set default system prompt
|
||||
# When mode changes, set default system prompt, and other stuffs
|
||||
mode.change(
|
||||
self._set_current_mode, inputs=mode, outputs=system_prompt_input
|
||||
self._set_current_mode,
|
||||
inputs=mode,
|
||||
outputs=[system_prompt_input, explanation_mode],
|
||||
)
|
||||
# On blur, set system prompt to use in queries
|
||||
system_prompt_input.blur(
|
||||
|
|
@ -477,6 +555,14 @@ class PrivateGptUi:
|
|||
),
|
||||
additional_inputs=[mode, upload_button, system_prompt_input],
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
avatar_byte = AVATAR_BOT.read_bytes()
|
||||
f_base64 = f"data:image/png;base64,{base64.b64encode(avatar_byte).decode('utf-8')}"
|
||||
gr.HTML(
|
||||
f"<div class='footer'><a class='footer-zylon-link' href='https://zylon.ai/'>Maintained by Zylon <img class='footer-zylon-ico' src='{f_base64}' alt=Zylon></a></div>"
|
||||
)
|
||||
|
||||
return blocks
|
||||
|
||||
def get_ui_blocks(self) -> gr.Blocks:
|
||||
|
|
@ -488,7 +574,7 @@ class PrivateGptUi:
|
|||
blocks = self.get_ui_blocks()
|
||||
blocks.queue()
|
||||
logger.info("Mounting the gradio UI, at path=%s", path)
|
||||
gr.mount_gradio_app(app, blocks, path=path)
|
||||
gr.mount_gradio_app(app, blocks, path=path, favicon_path=AVATAR_BOT)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
80
private_gpt/utils/ollama.py
Normal file
80
private_gpt/utils/ollama.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
import logging
|
||||
from collections import deque
|
||||
from collections.abc import Iterator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from tqdm import tqdm # type: ignore
|
||||
|
||||
try:
|
||||
from ollama import Client # type: ignore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Ollama dependencies not found, install with `poetry install --extras llms-ollama or embeddings-ollama`"
|
||||
) from e
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_connection(client: Client) -> bool:
|
||||
try:
|
||||
client.list()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Ollama: {e!s}")
|
||||
return False
|
||||
|
||||
|
||||
def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None:
|
||||
progress_bars = {}
|
||||
queue = deque() # type: ignore
|
||||
|
||||
def create_progress_bar(dgt: str, total: int) -> Any:
|
||||
return tqdm(
|
||||
total=total, desc=f"Pulling model {dgt[7:17]}...", unit="B", unit_scale=True
|
||||
)
|
||||
|
||||
current_digest = None
|
||||
|
||||
for chunk in generator:
|
||||
digest = chunk.get("digest")
|
||||
completed_size = chunk.get("completed", 0)
|
||||
total_size = chunk.get("total")
|
||||
|
||||
if digest and total_size is not None:
|
||||
if digest not in progress_bars and completed_size > 0:
|
||||
progress_bars[digest] = create_progress_bar(digest, total=total_size)
|
||||
if current_digest is None:
|
||||
current_digest = digest
|
||||
else:
|
||||
queue.append(digest)
|
||||
|
||||
if digest in progress_bars:
|
||||
progress_bar = progress_bars[digest]
|
||||
progress = completed_size - progress_bar.n
|
||||
if completed_size > 0 and total_size >= progress != progress_bar.n:
|
||||
if digest == current_digest:
|
||||
progress_bar.update(progress)
|
||||
if progress_bar.n >= total_size:
|
||||
progress_bar.close()
|
||||
current_digest = queue.popleft() if queue else None
|
||||
else:
|
||||
# Store progress for later update
|
||||
progress_bars[digest].total = total_size
|
||||
progress_bars[digest].n = completed_size
|
||||
|
||||
# Close any remaining progress bars at the end
|
||||
for progress_bar in progress_bars.values():
|
||||
progress_bar.close()
|
||||
|
||||
|
||||
def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None:
|
||||
try:
|
||||
installed_models = [model["name"] for model in client.list().get("models", {})]
|
||||
if model_name not in installed_models:
|
||||
logger.info(f"Pulling model {model_name}. Please wait...")
|
||||
process_streaming(client.pull(model_name, stream=True))
|
||||
logger.info(f"Model {model_name} pulled successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to pull model {model_name}: {e!s}")
|
||||
if raise_error:
|
||||
raise e
|
||||
|
|
@ -22,7 +22,7 @@ llama-index-readers-file = "^0.1.27"
|
|||
llama-index-llms-llama-cpp = {version = "^0.1.4", optional = true}
|
||||
llama-index-llms-openai = {version = "^0.1.25", optional = true}
|
||||
llama-index-llms-openai-like = {version ="^0.1.3", optional = true}
|
||||
llama-index-llms-ollama = {version ="^0.1.5", optional = true}
|
||||
llama-index-llms-ollama = {version ="^0.2.2", optional = true}
|
||||
llama-index-llms-azure-openai = {version ="^0.1.8", optional = true}
|
||||
llama-index-llms-gemini = {version ="^0.1.11", optional = true}
|
||||
llama-index-embeddings-ollama = {version ="^0.1.2", optional = true}
|
||||
|
|
@ -56,21 +56,29 @@ sentence-transformers = {version ="^3.0.1", optional = true}
|
|||
|
||||
# Optional UI
|
||||
gradio = {version ="^4.37.2", optional = true}
|
||||
# Fix: https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/16289#issuecomment-2255106490
|
||||
ffmpy = {git = "https://github.com/EuDs63/ffmpy.git", rev = "333a19ee4d21f32537c0508aa1942ef1aa7afe24", optional = true}
|
||||
|
||||
# Optional Google Gemini dependency
|
||||
google-generativeai = {version ="^0.5.4", optional = true}
|
||||
|
||||
# Optional Ollama client
|
||||
ollama = {version ="^0.3.0", optional = true}
|
||||
|
||||
# Optional HF Transformers
|
||||
einops = {version = "^0.8.0", optional = true}
|
||||
|
||||
[tool.poetry.extras]
|
||||
ui = ["gradio"]
|
||||
ui = ["gradio", "ffmpy"]
|
||||
llms-llama-cpp = ["llama-index-llms-llama-cpp"]
|
||||
llms-openai = ["llama-index-llms-openai"]
|
||||
llms-openai-like = ["llama-index-llms-openai-like"]
|
||||
llms-ollama = ["llama-index-llms-ollama"]
|
||||
llms-ollama = ["llama-index-llms-ollama", "ollama"]
|
||||
llms-sagemaker = ["boto3"]
|
||||
llms-azopenai = ["llama-index-llms-azure-openai"]
|
||||
llms-gemini = ["llama-index-llms-gemini", "google-generativeai"]
|
||||
embeddings-ollama = ["llama-index-embeddings-ollama"]
|
||||
embeddings-huggingface = ["llama-index-embeddings-huggingface"]
|
||||
embeddings-ollama = ["llama-index-embeddings-ollama", "ollama"]
|
||||
embeddings-huggingface = ["llama-index-embeddings-huggingface", "einops"]
|
||||
embeddings-openai = ["llama-index-embeddings-openai"]
|
||||
embeddings-sagemaker = ["boto3"]
|
||||
embeddings-azopenai = ["llama-index-embeddings-azure-openai"]
|
||||
|
|
@ -119,7 +127,7 @@ target-version = ['py311']
|
|||
target-version = 'py311'
|
||||
|
||||
# See all rules at https://beta.ruff.rs/docs/rules/
|
||||
select = [
|
||||
lint.select = [
|
||||
"E", # pycodestyle
|
||||
"W", # pycodestyle
|
||||
"F", # Pyflakes
|
||||
|
|
@ -136,7 +144,7 @@ select = [
|
|||
"RUF", # Ruff-specific rules
|
||||
]
|
||||
|
||||
ignore = [
|
||||
lint.ignore = [
|
||||
"E501", # "Line too long"
|
||||
# -> line length already regulated by black
|
||||
"PT011", # "pytest.raises() should specify expected exception"
|
||||
|
|
@ -154,24 +162,24 @@ ignore = [
|
|||
# -> "Missing docstring in public function too restrictive"
|
||||
]
|
||||
|
||||
[tool.ruff.pydocstyle]
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
# Automatically disable rules that are incompatible with Google docstring convention
|
||||
convention = "google"
|
||||
|
||||
[tool.ruff.pycodestyle]
|
||||
[tool.ruff.lint.pycodestyle]
|
||||
max-doc-length = 88
|
||||
|
||||
[tool.ruff.flake8-tidy-imports]
|
||||
[tool.ruff.lint.flake8-tidy-imports]
|
||||
ban-relative-imports = "all"
|
||||
|
||||
[tool.ruff.flake8-type-checking]
|
||||
[tool.ruff.lint.flake8-type-checking]
|
||||
strict = true
|
||||
runtime-evaluated-base-classes = ["pydantic.BaseModel"]
|
||||
# Pydantic needs to be able to evaluate types at runtime
|
||||
# see https://pypi.org/project/flake8-type-checking/ for flake8-type-checking documentation
|
||||
# see https://beta.ruff.rs/docs/settings/#flake8-type-checking-runtime-evaluated-base-classes for ruff documentation
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
# Allow missing docstrings for tests
|
||||
"tests/**/*.py" = ["D1"]
|
||||
|
||||
|
|
|
|||
|
|
@ -7,12 +7,13 @@ from pathlib import Path
|
|||
from private_gpt.di import global_injector
|
||||
from private_gpt.server.ingest.ingest_service import IngestService
|
||||
from private_gpt.server.ingest.ingest_watcher import IngestWatcher
|
||||
from private_gpt.settings.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocalIngestWorker:
|
||||
def __init__(self, ingest_service: IngestService) -> None:
|
||||
def __init__(self, ingest_service: IngestService, setting: Settings) -> None:
|
||||
self.ingest_service = ingest_service
|
||||
|
||||
self.total_documents = 0
|
||||
|
|
@ -20,6 +21,24 @@ class LocalIngestWorker:
|
|||
|
||||
self._files_under_root_folder: list[Path] = []
|
||||
|
||||
self.is_local_ingestion_enabled = setting.data.local_ingestion.enabled
|
||||
self.allowed_local_folders = setting.data.local_ingestion.allow_ingest_from
|
||||
|
||||
def _validate_folder(self, folder_path: Path) -> None:
|
||||
if not self.is_local_ingestion_enabled:
|
||||
raise ValueError(
|
||||
"Local ingestion is disabled."
|
||||
"You can enable it in settings `ingestion.enabled`"
|
||||
)
|
||||
|
||||
# Allow all folders if wildcard is present
|
||||
if "*" in self.allowed_local_folders:
|
||||
return
|
||||
|
||||
for allowed_folder in self.allowed_local_folders:
|
||||
if not folder_path.is_relative_to(allowed_folder):
|
||||
raise ValueError(f"Folder {folder_path} is not allowed for ingestion")
|
||||
|
||||
def _find_all_files_in_folder(self, root_path: Path, ignored: list[str]) -> None:
|
||||
"""Search all files under the root folder recursively.
|
||||
|
||||
|
|
@ -28,6 +47,7 @@ class LocalIngestWorker:
|
|||
for file_path in root_path.iterdir():
|
||||
if file_path.is_file() and file_path.name not in ignored:
|
||||
self.total_documents += 1
|
||||
self._validate_folder(file_path)
|
||||
self._files_under_root_folder.append(file_path)
|
||||
elif file_path.is_dir() and file_path.name not in ignored:
|
||||
self._find_all_files_in_folder(file_path, ignored)
|
||||
|
|
@ -92,13 +112,13 @@ if args.log_file:
|
|||
logger.addHandler(file_handler)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
root_path = Path(args.folder)
|
||||
if not root_path.exists():
|
||||
raise ValueError(f"Path {args.folder} does not exist")
|
||||
|
||||
ingest_service = global_injector.get(IngestService)
|
||||
worker = LocalIngestWorker(ingest_service)
|
||||
settings = global_injector.get(Settings)
|
||||
worker = LocalIngestWorker(ingest_service, settings)
|
||||
worker.ingest_folder(root_path, args.ignored)
|
||||
|
||||
if args.ignored:
|
||||
|
|
|
|||
|
|
@ -6,21 +6,21 @@ llm:
|
|||
mode: ${PGPT_MODE:mock}
|
||||
|
||||
embedding:
|
||||
mode: ${PGPT_MODE:sagemaker}
|
||||
mode: ${PGPT_EMBED_MODE:mock}
|
||||
|
||||
llamacpp:
|
||||
llm_hf_repo_id: ${PGPT_HF_REPO_ID:TheBloke/Mistral-7B-Instruct-v0.1-GGUF}
|
||||
llm_hf_model_file: ${PGPT_HF_MODEL_FILE:mistral-7b-instruct-v0.1.Q4_K_M.gguf}
|
||||
llm_hf_repo_id: ${PGPT_HF_REPO_ID:lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF}
|
||||
llm_hf_model_file: ${PGPT_HF_MODEL_FILE:Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf}
|
||||
|
||||
huggingface:
|
||||
embedding_hf_model_name: ${PGPT_EMBEDDING_HF_MODEL_NAME:BAAI/bge-small-en-v1.5}
|
||||
embedding_hf_model_name: ${PGPT_EMBEDDING_HF_MODEL_NAME:nomic-ai/nomic-embed-text-v1.5}
|
||||
|
||||
sagemaker:
|
||||
llm_endpoint_name: ${PGPT_SAGEMAKER_LLM_ENDPOINT_NAME:}
|
||||
embedding_endpoint_name: ${PGPT_SAGEMAKER_EMBEDDING_ENDPOINT_NAME:}
|
||||
|
||||
ollama:
|
||||
llm_model: ${PGPT_OLLAMA_LLM_MODEL:mistral}
|
||||
llm_model: ${PGPT_OLLAMA_LLM_MODEL:llama3.1}
|
||||
embedding_model: ${PGPT_OLLAMA_EMBEDDING_MODEL:nomic-embed-text}
|
||||
api_base: ${PGPT_OLLAMA_API_BASE:http://ollama:11434}
|
||||
embedding_api_base: ${PGPT_OLLAMA_EMBEDDING_API_BASE:http://ollama:11434}
|
||||
|
|
@ -30,6 +30,7 @@ ollama:
|
|||
repeat_last_n: ${PGPT_OLLAMA_REPEAT_LAST_N:64}
|
||||
repeat_penalty: ${PGPT_OLLAMA_REPEAT_PENALTY:1.2}
|
||||
request_timeout: ${PGPT_OLLAMA_REQUEST_TIMEOUT:600.0}
|
||||
autopull_models: ${PGPT_OLLAMA_AUTOPULL_MODELS:true}
|
||||
|
||||
ui:
|
||||
enabled: true
|
||||
|
|
|
|||
|
|
@ -7,18 +7,18 @@ llm:
|
|||
# Should be matching the selected model
|
||||
max_new_tokens: 512
|
||||
context_window: 3900
|
||||
tokenizer: mistralai/Mistral-7B-Instruct-v0.2
|
||||
prompt_style: "mistral"
|
||||
tokenizer: meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
prompt_style: "llama3"
|
||||
|
||||
llamacpp:
|
||||
llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.2-GGUF
|
||||
llm_hf_model_file: mistral-7b-instruct-v0.2.Q4_K_M.gguf
|
||||
llm_hf_repo_id: lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF
|
||||
llm_hf_model_file: Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
|
||||
|
||||
embedding:
|
||||
mode: huggingface
|
||||
|
||||
huggingface:
|
||||
embedding_hf_model_name: BAAI/bge-small-en-v1.5
|
||||
embedding_hf_model_name: nomic-ai/nomic-embed-text-v1.5
|
||||
|
||||
vectorstore:
|
||||
database: qdrant
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ embedding:
|
|||
embed_dim: 768
|
||||
|
||||
ollama:
|
||||
llm_model: mistral
|
||||
llm_model: llama3.1
|
||||
embedding_model: nomic-embed-text
|
||||
api_base: http://localhost:11434
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ embedding:
|
|||
mode: ollama
|
||||
|
||||
ollama:
|
||||
llm_model: mistral
|
||||
llm_model: llama3.1
|
||||
embedding_model: nomic-embed-text
|
||||
api_base: http://localhost:11434
|
||||
embedding_api_base: http://localhost:11434 # change if your embedding model runs on another ollama
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ server:
|
|||
llm:
|
||||
mode: openailike
|
||||
max_new_tokens: 512
|
||||
tokenizer: mistralai/Mistral-7B-Instruct-v0.2
|
||||
tokenizer: meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
temperature: 0.1
|
||||
|
||||
embedding:
|
||||
|
|
@ -12,7 +12,7 @@ embedding:
|
|||
ingest_mode: simple
|
||||
|
||||
huggingface:
|
||||
embedding_hf_model_name: BAAI/bge-small-en-v1.5
|
||||
embedding_hf_model_name: nomic-ai/nomic-embed-text-v1.5
|
||||
|
||||
openai:
|
||||
api_base: http://localhost:8000/v1
|
||||
|
|
|
|||
|
|
@ -17,6 +17,9 @@ server:
|
|||
secret: "Basic c2VjcmV0OmtleQ=="
|
||||
|
||||
data:
|
||||
local_ingestion:
|
||||
enabled: ${LOCAL_INGESTION_ENABLED:false}
|
||||
allow_ingest_from: ["*"]
|
||||
local_data_folder: local_data/private_gpt
|
||||
|
||||
ui:
|
||||
|
|
@ -31,17 +34,24 @@ ui:
|
|||
You can only answer questions about the provided context.
|
||||
If you know the answer but it is not based in the provided context, don't provide
|
||||
the answer, just state the answer is not in the context provided.
|
||||
default_summarization_system_prompt: >
|
||||
Provide a comprehensive summary of the provided context information.
|
||||
The summary should cover all the key points and main ideas presented in
|
||||
the original text, while also condensing the information into a concise
|
||||
and easy-to-understand format. Please ensure that the summary includes
|
||||
relevant details and examples that support the main ideas, while avoiding
|
||||
any unnecessary information or repetition.
|
||||
delete_file_button_enabled: true
|
||||
delete_all_files_button_enabled: true
|
||||
|
||||
llm:
|
||||
mode: llamacpp
|
||||
prompt_style: "mistral"
|
||||
prompt_style: "llama3"
|
||||
# Should be matching the selected model
|
||||
max_new_tokens: 512
|
||||
context_window: 3900
|
||||
# Select your tokenizer. Llama-index tokenizer is the default.
|
||||
# tokenizer: mistralai/Mistral-7B-Instruct-v0.2
|
||||
# tokenizer: meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
temperature: 0.1 # The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual. (Default: 0.1)
|
||||
|
||||
rag:
|
||||
|
|
@ -54,6 +64,9 @@ rag:
|
|||
model: cross-encoder/ms-marco-MiniLM-L-2-v2
|
||||
top_n: 1
|
||||
|
||||
summarize:
|
||||
use_async: true
|
||||
|
||||
clickhouse:
|
||||
host: localhost
|
||||
port: 8443
|
||||
|
|
@ -62,8 +75,8 @@ clickhouse:
|
|||
database: embeddings
|
||||
|
||||
llamacpp:
|
||||
llm_hf_repo_id: TheBloke/Mistral-7B-Instruct-v0.2-GGUF
|
||||
llm_hf_model_file: mistral-7b-instruct-v0.2.Q4_K_M.gguf
|
||||
llm_hf_repo_id: lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF
|
||||
llm_hf_model_file: Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
|
||||
tfs_z: 1.0 # Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting
|
||||
top_k: 40 # Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
|
||||
top_p: 1.0 # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
|
||||
|
|
@ -73,11 +86,14 @@ embedding:
|
|||
# Should be matching the value above in most cases
|
||||
mode: huggingface
|
||||
ingest_mode: simple
|
||||
embed_dim: 384 # 384 is for BAAI/bge-small-en-v1.5
|
||||
embed_dim: 768 # 768 is for nomic-ai/nomic-embed-text-v1.5
|
||||
|
||||
huggingface:
|
||||
embedding_hf_model_name: BAAI/bge-small-en-v1.5
|
||||
embedding_hf_model_name: nomic-ai/nomic-embed-text-v1.5
|
||||
access_token: ${HF_TOKEN:}
|
||||
# Warning: Enabling this option will allow the model to download and execute code from the internet.
|
||||
# Nomic AI requires this option to be enabled to use the model, be aware if you are using a different model.
|
||||
trust_remote_code: true
|
||||
|
||||
vectorstore:
|
||||
database: qdrant
|
||||
|
|
@ -111,12 +127,13 @@ openai:
|
|||
embedding_api_key: ${OPENAI_API_KEY:}
|
||||
|
||||
ollama:
|
||||
llm_model: llama2
|
||||
llm_model: llama3.1
|
||||
embedding_model: nomic-embed-text
|
||||
api_base: http://localhost:11434
|
||||
embedding_api_base: http://localhost:11434 # change if your embedding model runs on another ollama
|
||||
keep_alive: 5m
|
||||
request_timeout: 120.0
|
||||
autopull_models: true
|
||||
|
||||
azopenai:
|
||||
api_key: ${AZ_OPENAI_API_KEY:}
|
||||
|
|
|
|||
74
tests/server/ingest/test_local_ingest.py
Normal file
74
tests/server/ingest/test_local_ingest.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def file_path() -> str:
|
||||
return "test.txt"
|
||||
|
||||
|
||||
def create_test_file(file_path: str) -> None:
|
||||
with open(file_path, "w") as f:
|
||||
f.write("test")
|
||||
|
||||
|
||||
def clear_log_file(log_file_path: str) -> None:
|
||||
if Path(log_file_path).exists():
|
||||
os.remove(log_file_path)
|
||||
|
||||
|
||||
def read_log_file(log_file_path: str) -> str:
|
||||
with open(log_file_path) as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def init_structure(folder: str, file_path: str) -> None:
|
||||
clear_log_file(file_path)
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
create_test_file(f"{folder}/${file_path}")
|
||||
|
||||
|
||||
def test_ingest_one_file_in_allowed_folder(
|
||||
file_path: str, test_client: TestClient
|
||||
) -> None:
|
||||
allowed_folder = "local_data/tests/allowed_folder"
|
||||
init_structure(allowed_folder, file_path)
|
||||
|
||||
test_env = os.environ.copy()
|
||||
test_env["PGPT_PROFILES"] = "test"
|
||||
test_env["LOCAL_INGESTION_ENABLED"] = "True"
|
||||
|
||||
result = subprocess.run(
|
||||
["python", "scripts/ingest_folder.py", allowed_folder],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=test_env,
|
||||
)
|
||||
|
||||
assert result.returncode == 0, f"Script failed with error: {result.stderr}"
|
||||
response_after = test_client.get("/v1/ingest/list")
|
||||
|
||||
count_ingest_after = len(response_after.json()["data"])
|
||||
assert count_ingest_after > 0, "No documents were ingested"
|
||||
|
||||
|
||||
def test_ingest_disabled(file_path: str) -> None:
|
||||
allowed_folder = "local_data/tests/allowed_folder"
|
||||
init_structure(allowed_folder, file_path)
|
||||
|
||||
test_env = os.environ.copy()
|
||||
test_env["PGPT_PROFILES"] = "test"
|
||||
test_env["LOCAL_INGESTION_ENABLED"] = "False"
|
||||
|
||||
result = subprocess.run(
|
||||
["python", "scripts/ingest_folder.py", allowed_folder],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=test_env,
|
||||
)
|
||||
|
||||
assert result.returncode != 0, f"Script failed with error: {result.stderr}"
|
||||
159
tests/server/recipes/test_summarize_router.py
Normal file
159
tests/server/recipes/test_summarize_router.py
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
from private_gpt.server.recipes.summarize.summarize_router import (
|
||||
SummarizeBody,
|
||||
SummarizeResponse,
|
||||
)
|
||||
|
||||
|
||||
def test_summarize_route_produces_a_stream(test_client: TestClient) -> None:
|
||||
body = SummarizeBody(
|
||||
text="Test",
|
||||
stream=True,
|
||||
)
|
||||
response = test_client.post("/v1/summarize", json=body.model_dump())
|
||||
|
||||
raw_events = response.text.split("\n\n")
|
||||
events = [
|
||||
item.removeprefix("data: ") for item in raw_events if item.startswith("data: ")
|
||||
]
|
||||
assert response.status_code == 200
|
||||
assert "text/event-stream" in response.headers["content-type"]
|
||||
assert len(events) > 0
|
||||
assert events[-1] == "[DONE]"
|
||||
|
||||
|
||||
def test_summarize_route_produces_a_single_value(test_client: TestClient) -> None:
|
||||
body = SummarizeBody(
|
||||
text="test",
|
||||
stream=False,
|
||||
)
|
||||
response = test_client.post("/v1/summarize", json=body.model_dump())
|
||||
|
||||
# No asserts, if it validates it's good
|
||||
SummarizeResponse.model_validate(response.json())
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_summarize_with_document_context(test_client: TestClient) -> None:
|
||||
# Ingest an document
|
||||
ingest_response = test_client.post(
|
||||
"/v1/ingest/text",
|
||||
json={
|
||||
"file_name": "file_name",
|
||||
"text": "Lorem ipsum dolor sit amet",
|
||||
},
|
||||
)
|
||||
assert ingest_response.status_code == 200
|
||||
ingested_docs = ingest_response.json()["data"]
|
||||
assert len(ingested_docs) == 1
|
||||
|
||||
body = SummarizeBody(
|
||||
use_context=True,
|
||||
context_filter={"docs_ids": [doc["doc_id"] for doc in ingested_docs]},
|
||||
stream=False,
|
||||
)
|
||||
response = test_client.post("/v1/summarize", json=body.model_dump())
|
||||
|
||||
completion: SummarizeResponse = SummarizeResponse.model_validate(response.json())
|
||||
assert response.status_code == 200
|
||||
# We can check the content of the completion, because mock LLM used in tests
|
||||
# always echoes the prompt. In the case of summary, the input context is passed.
|
||||
assert completion.summary.find("Lorem ipsum dolor sit amet") != -1
|
||||
|
||||
|
||||
def test_summarize_with_non_existent_document_context_not_fails(
|
||||
test_client: TestClient,
|
||||
) -> None:
|
||||
body = SummarizeBody(
|
||||
use_context=True,
|
||||
context_filter={
|
||||
"docs_ids": ["non-existent-doc-id"],
|
||||
},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
response = test_client.post("/v1/summarize", json=body.model_dump())
|
||||
|
||||
completion: SummarizeResponse = SummarizeResponse.model_validate(response.json())
|
||||
assert response.status_code == 200
|
||||
# We can check the content of the completion, because mock LLM used in tests
|
||||
# always echoes the prompt. In the case of summary, the input context is passed.
|
||||
assert completion.summary.find("Empty Response") != -1
|
||||
|
||||
|
||||
def test_summarize_with_metadata_and_document_context(test_client: TestClient) -> None:
|
||||
docs = []
|
||||
|
||||
# Ingest a first document
|
||||
document_1_content = "Content of document 1"
|
||||
ingest_response = test_client.post(
|
||||
"/v1/ingest/text",
|
||||
json={
|
||||
"file_name": "file_name_1",
|
||||
"text": document_1_content,
|
||||
},
|
||||
)
|
||||
assert ingest_response.status_code == 200
|
||||
ingested_docs = ingest_response.json()["data"]
|
||||
assert len(ingested_docs) == 1
|
||||
docs += ingested_docs
|
||||
|
||||
# Ingest a second document
|
||||
document_2_content = "Text of document 2"
|
||||
ingest_response = test_client.post(
|
||||
"/v1/ingest/text",
|
||||
json={
|
||||
"file_name": "file_name_2",
|
||||
"text": document_2_content,
|
||||
},
|
||||
)
|
||||
assert ingest_response.status_code == 200
|
||||
ingested_docs = ingest_response.json()["data"]
|
||||
assert len(ingested_docs) == 1
|
||||
docs += ingested_docs
|
||||
|
||||
# Completions with the first document's id and the second document's metadata
|
||||
body = SummarizeBody(
|
||||
use_context=True,
|
||||
context_filter={"docs_ids": [doc["doc_id"] for doc in docs]},
|
||||
stream=False,
|
||||
)
|
||||
response = test_client.post("/v1/summarize", json=body.model_dump())
|
||||
|
||||
completion: SummarizeResponse = SummarizeResponse.model_validate(response.json())
|
||||
assert response.status_code == 200
|
||||
# Assert both documents are part of the used sources
|
||||
# We can check the content of the completion, because mock LLM used in tests
|
||||
# always echoes the prompt. In the case of summary, the input context is passed.
|
||||
assert completion.summary.find(document_1_content) != -1
|
||||
assert completion.summary.find(document_2_content) != -1
|
||||
|
||||
|
||||
def test_summarize_with_prompt(test_client: TestClient) -> None:
|
||||
ingest_response = test_client.post(
|
||||
"/v1/ingest/text",
|
||||
json={
|
||||
"file_name": "file_name",
|
||||
"text": "Lorem ipsum dolor sit amet",
|
||||
},
|
||||
)
|
||||
assert ingest_response.status_code == 200
|
||||
ingested_docs = ingest_response.json()["data"]
|
||||
assert len(ingested_docs) == 1
|
||||
|
||||
body = SummarizeBody(
|
||||
use_context=True,
|
||||
context_filter={
|
||||
"docs_ids": [doc["doc_id"] for doc in ingested_docs],
|
||||
},
|
||||
prompt="This is a custom summary prompt, 54321",
|
||||
stream=False,
|
||||
)
|
||||
response = test_client.post("/v1/summarize", json=body.model_dump())
|
||||
|
||||
completion: SummarizeResponse = SummarizeResponse.model_validate(response.json())
|
||||
assert response.status_code == 200
|
||||
# We can check the content of the completion, because mock LLM used in tests
|
||||
# always echoes the prompt. In the case of summary, the input context is passed.
|
||||
assert completion.summary.find("This is a custom summary prompt, 54321") != -1
|
||||
|
|
@ -5,6 +5,7 @@ from private_gpt.components.llm.prompt_helper import (
|
|||
ChatMLPromptStyle,
|
||||
DefaultPromptStyle,
|
||||
Llama2PromptStyle,
|
||||
Llama3PromptStyle,
|
||||
MistralPromptStyle,
|
||||
TagPromptStyle,
|
||||
get_prompt_style,
|
||||
|
|
@ -139,3 +140,57 @@ def test_llama2_prompt_style_with_system_prompt():
|
|||
)
|
||||
|
||||
assert prompt_style.messages_to_prompt(messages) == expected_prompt
|
||||
|
||||
|
||||
def test_llama3_prompt_style_format():
|
||||
prompt_style = Llama3PromptStyle()
|
||||
messages = [
|
||||
ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM),
|
||||
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
|
||||
]
|
||||
|
||||
expected_prompt = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n"
|
||||
"You are a helpful assistant<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n"
|
||||
"Hello, how are you doing?<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
|
||||
assert prompt_style.messages_to_prompt(messages) == expected_prompt
|
||||
|
||||
|
||||
def test_llama3_prompt_style_with_default_system():
|
||||
prompt_style = Llama3PromptStyle()
|
||||
messages = [
|
||||
ChatMessage(content="Hello!", role=MessageRole.USER),
|
||||
]
|
||||
expected = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n"
|
||||
f"{prompt_style.DEFAULT_SYSTEM_PROMPT}<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
assert prompt_style._messages_to_prompt(messages) == expected
|
||||
|
||||
|
||||
def test_llama3_prompt_style_with_assistant_response():
|
||||
prompt_style = Llama3PromptStyle()
|
||||
messages = [
|
||||
ChatMessage(content="You are a helpful assistant", role=MessageRole.SYSTEM),
|
||||
ChatMessage(content="What is the capital of France?", role=MessageRole.USER),
|
||||
ChatMessage(
|
||||
content="The capital of France is Paris.", role=MessageRole.ASSISTANT
|
||||
),
|
||||
]
|
||||
|
||||
expected_prompt = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n"
|
||||
"You are a helpful assistant<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n"
|
||||
"What is the capital of France?<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
"The capital of France is Paris.<|eot_id|>"
|
||||
)
|
||||
|
||||
assert prompt_style.messages_to_prompt(messages) == expected_prompt
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue