From f27272fb5ccdac386e1c7c9fd2d2cb9565bdbb03 Mon Sep 17 00:00:00 2001 From: SkiingIsFun123 <101684827+SkiingIsFun123@users.noreply.github.com> Date: Wed, 21 Aug 2024 14:27:31 -0700 Subject: [PATCH] Adding MistralAI mode --- poetry.lock | 59 ++++++++++++++++--- .../embedding/embedding_component.py | 17 ++++++ private_gpt/components/llm/llm_component.py | 26 ++++++++ private_gpt/settings/settings.py | 13 +++- private_gpt/ui/ui.py | 1 + pyproject.toml | 2 + settings-mistral.yaml | 16 +++++ 7 files changed, 125 insertions(+), 9 deletions(-) create mode 100644 settings-mistral.yaml diff --git a/poetry.lock b/poetry.lock index 4d6d699..bb0a1ae 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiofiles" @@ -2231,6 +2231,17 @@ files = [ {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, ] +[[package]] +name = "jsonpath-python" +version = "1.0.6" +description = "A more powerful JSONPath implementation in modern python" +optional = true +python-versions = ">=3.6" +files = [ + {file = "jsonpath-python-1.0.6.tar.gz", hash = "sha256:dd5be4a72d8a2995c3f583cf82bf3cd1a9544cfdabf2d22595b67aff07349666"}, + {file = "jsonpath_python-1.0.6-py3-none-any.whl", hash = "sha256:1e3b78df579f5efc23565293612decee04214609208a2335884b3ee3f786b575"}, +] + [[package]] name = "kiwisolver" version = "1.4.5" @@ -2474,6 +2485,21 @@ huggingface-hub = {version = ">=0.19.0", extras = ["inference"]} llama-index-core = ">=0.10.1,<0.11.0" sentence-transformers = ">=2.6.1" +[[package]] +name = "llama-index-embeddings-mistralai" +version = "0.1.6" +description = "llama-index embeddings mistralai integration" +optional = true +python-versions = "<4.0,>=3.9" +files = [ + {file = "llama_index_embeddings_mistralai-0.1.6-py3-none-any.whl", 
hash = "sha256:d69d6fc0be8a1772aaf890bc036f2d575af46070b375a2649803c0eb9736ea1b"}, + {file = "llama_index_embeddings_mistralai-0.1.6.tar.gz", hash = "sha256:7c9cbf974b1e7d14ded34d3eb749a0d1a379fb151ab75115cc1ffdd08a96a045"}, +] + +[package.dependencies] +llama-index-core = ">=0.10.1,<0.11.0" +mistralai = ">=1.0.0" + [[package]] name = "llama-index-embeddings-ollama" version = "0.1.2" @@ -2911,29 +2937,24 @@ files = [ {file = "matplotlib-3.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd2a59ff4b83d33bca3b5ec58203cc65985367812cb8c257f3e101632be86d92"}, {file = "matplotlib-3.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fc001516ffcf1a221beb51198b194d9230199d6842c540108e4ce109ac05cc0"}, {file = "matplotlib-3.9.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:83c6a792f1465d174c86d06f3ae85a8fe36e6f5964633ae8106312ec0921fdf5"}, - {file = "matplotlib-3.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:421851f4f57350bcf0811edd754a708d2275533e84f52f6760b740766c6747a7"}, {file = "matplotlib-3.9.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:b3fce58971b465e01b5c538f9d44915640c20ec5ff31346e963c9e1cd66fa812"}, {file = "matplotlib-3.9.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a973c53ad0668c53e0ed76b27d2eeeae8799836fd0d0caaa4ecc66bf4e6676c0"}, {file = "matplotlib-3.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82cd5acf8f3ef43f7532c2f230249720f5dc5dd40ecafaf1c60ac8200d46d7eb"}, {file = "matplotlib-3.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab38a4f3772523179b2f772103d8030215b318fef6360cb40558f585bf3d017f"}, {file = "matplotlib-3.9.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:2315837485ca6188a4b632c5199900e28d33b481eb083663f6a44cfc8987ded3"}, - {file = "matplotlib-3.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:a0c977c5c382f6696caf0bd277ef4f936da7e2aa202ff66cad5f0ac1428ee15b"}, {file = 
"matplotlib-3.9.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:565d572efea2b94f264dd86ef27919515aa6d629252a169b42ce5f570db7f37b"}, {file = "matplotlib-3.9.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6d397fd8ccc64af2ec0af1f0efc3bacd745ebfb9d507f3f552e8adb689ed730a"}, {file = "matplotlib-3.9.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26040c8f5121cd1ad712abffcd4b5222a8aec3a0fe40bc8542c94331deb8780d"}, {file = "matplotlib-3.9.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d12cb1837cffaac087ad6b44399d5e22b78c729de3cdae4629e252067b705e2b"}, {file = "matplotlib-3.9.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0e835c6988edc3d2d08794f73c323cc62483e13df0194719ecb0723b564e0b5c"}, - {file = "matplotlib-3.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:44a21d922f78ce40435cb35b43dd7d573cf2a30138d5c4b709d19f00e3907fd7"}, {file = "matplotlib-3.9.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:0c584210c755ae921283d21d01f03a49ef46d1afa184134dd0f95b0202ee6f03"}, {file = "matplotlib-3.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:11fed08f34fa682c2b792942f8902e7aefeed400da71f9e5816bea40a7ce28fe"}, {file = "matplotlib-3.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0000354e32efcfd86bda75729716b92f5c2edd5b947200be9881f0a671565c33"}, {file = "matplotlib-3.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4db17fea0ae3aceb8e9ac69c7e3051bae0b3d083bfec932240f9bf5d0197a049"}, {file = "matplotlib-3.9.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:208cbce658b72bf6a8e675058fbbf59f67814057ae78165d8a2f87c45b48d0ff"}, - {file = "matplotlib-3.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:dc23f48ab630474264276be156d0d7710ac6c5a09648ccdf49fef9200d8cbe80"}, {file = "matplotlib-3.9.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:3fda72d4d472e2ccd1be0e9ccb6bf0d2eaf635e7f8f51d737ed7e465ac020cb3"}, {file = 
"matplotlib-3.9.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:84b3ba8429935a444f1fdc80ed930babbe06725bcf09fbeb5c8757a2cd74af04"}, {file = "matplotlib-3.9.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b918770bf3e07845408716e5bbda17eadfc3fcbd9307dc67f37d6cf834bb3d98"}, - {file = "matplotlib-3.9.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:f1f2e5d29e9435c97ad4c36fb6668e89aee13d48c75893e25cef064675038ac9"}, {file = "matplotlib-3.9.1.tar.gz", hash = "sha256:de06b19b8db95dd33d0dc17c926c7c9ebed9f572074b6fac4f65068a6814d010"}, ] @@ -2995,6 +3016,27 @@ files = [ {file = "minijinja-2.0.1.tar.gz", hash = "sha256:e774beffebfb8a1ad17e638ef70917cf5e94593f79acb8a8fff7d983169f3a4e"}, ] +[[package]] +name = "mistralai" +version = "1.0.1" +description = "Python Client SDK for the Mistral AI API." +optional = true +python-versions = "<4.0,>=3.8" +files = [ + {file = "mistralai-1.0.1-py3-none-any.whl", hash = "sha256:5e5fc28122e11aec0ce37781b6419963e31cd7caccddf89f54eac1ece81f063f"}, + {file = "mistralai-1.0.1.tar.gz", hash = "sha256:f6b055d21dd56e174e5023371295c35945d0f7b282486457d6a71ff47c703fe8"}, +] + +[package.dependencies] +httpx = ">=0.27.0,<0.28.0" +jsonpath-python = ">=1.0.6,<2.0.0" +pydantic = ">=2.8.2,<2.9.0" +python-dateutil = ">=2.9.0.post0,<3.0.0" +typing-inspect = ">=0.9.0,<0.10.0" + +[package.extras] +gcp = ["google-auth (>=2.31.0,<3.0.0)", "requests (>=2.32.3,<3.0.0)"] + [[package]] name = "mmh3" version = "4.1.0" @@ -3854,6 +3896,8 @@ files = [ {file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:960db0e31c4e52fa0fc3ecbaea5b2d3b58f379e32a95ae6b0ebeaa25b93dfd34"}, {file = "orjson-3.10.6-cp312-none-win32.whl", hash = "sha256:a6ea7afb5b30b2317e0bee03c8d34c8181bc5a36f2afd4d0952f378972c4efd5"}, {file = "orjson-3.10.6-cp312-none-win_amd64.whl", hash = "sha256:874ce88264b7e655dde4aeaacdc8fd772a7962faadfb41abe63e2a4861abc3dc"}, + {file = "orjson-3.10.6-cp313-none-win32.whl", hash = 
"sha256:efdf2c5cde290ae6b83095f03119bdc00303d7a03b42b16c54517baa3c4ca3d0"}, + {file = "orjson-3.10.6-cp313-none-win_amd64.whl", hash = "sha256:8e190fe7888e2e4392f52cafb9626113ba135ef53aacc65cd13109eb9746c43e"}, {file = "orjson-3.10.6-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:66680eae4c4e7fc193d91cfc1353ad6d01b4801ae9b5314f17e11ba55e934183"}, {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caff75b425db5ef8e8f23af93c80f072f97b4fb3afd4af44482905c9f588da28"}, {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3722fddb821b6036fd2a3c814f6bd9b57a89dc6337b9924ecd614ebce3271394"}, @@ -6671,6 +6715,7 @@ cffi = ["cffi (>=1.11)"] embeddings-azopenai = ["llama-index-embeddings-azure-openai"] embeddings-gemini = ["llama-index-embeddings-gemini"] embeddings-huggingface = ["einops", "llama-index-embeddings-huggingface"] +embeddings-mistral = ["llama-index-embeddings-mistralai"] embeddings-ollama = ["llama-index-embeddings-ollama", "ollama"] embeddings-openai = ["llama-index-embeddings-openai"] embeddings-sagemaker = ["boto3"] @@ -6693,4 +6738,4 @@ vector-stores-qdrant = ["llama-index-vector-stores-qdrant"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "25abbb45bc462dbf056b83c0925b505ad1232484a18e50f07c5e7f517dd84e6f" +content-hash = "3de5e86444ceee26b22bcecf8dd547d7c5ef2a14f62a1192a6125a4700e74640" diff --git a/private_gpt/components/embedding/embedding_component.py b/private_gpt/components/embedding/embedding_component.py index 5d3e997..8828fdc 100644 --- a/private_gpt/components/embedding/embedding_component.py +++ b/private_gpt/components/embedding/embedding_component.py @@ -144,6 +144,23 @@ class EmbeddingComponent: api_key=settings.gemini.api_key, model_name=settings.gemini.embedding_model, ) + case "mistral": + try: + from llama_index.embeddings.mistralai import ( # type: ignore + 
MistralAIEmbedding, + ) + except ImportError as e: + raise ImportError( + "Mistral dependencies not found, install with `poetry install --extras embeddings-mistral`" + ) from e + + api_key = settings.mistral.api_key + model = settings.mistral.embedding_model + + self.embedding_model = MistralAIEmbedding( + api_key=api_key, + model=model, + ) case "mock": # Not a random number, is the dimensionality used by # the default embedding model diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py index e3a0281..40fb33d 100644 --- a/private_gpt/components/llm/llm_component.py +++ b/private_gpt/components/llm/llm_component.py @@ -222,5 +222,31 @@ class LLMComponent: self.llm = Gemini( model_name=gemini_settings.model, api_key=gemini_settings.api_key ) + case "mistral": + try: + from llama_index.llms.openai_like import OpenAILike # type: ignore + except ImportError as e: + raise ImportError( + "OpenAILike dependencies not found, install with `poetry install --extras llms-openai-like`" + ) from e + + prompt_style = get_prompt_style("mistral") + mistral_settings = settings.mistral + self.llm = OpenAILike( + api_base=mistral_settings.endpoint, + api_key=mistral_settings.api_key, + model=mistral_settings.model, + is_chat_model=True, + max_tokens=settings.llm.max_new_tokens, + api_version="", + temperature=settings.llm.temperature, + context_window=settings.llm.context_window, + max_new_tokens=settings.llm.max_new_tokens, + messages_to_prompt=prompt_style.messages_to_prompt, + completion_to_prompt=prompt_style.completion_to_prompt, + tokenizer=settings.llm.tokenizer, + timeout=mistral_settings.request_timeout, + reuse_client=False, + ) case "mock": self.llm = MockLLM() diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py index 4cf192a..13ad9d8 100644 --- a/private_gpt/settings/settings.py +++ b/private_gpt/settings/settings.py @@ -104,7 +104,6 @@ class DataSettings(BaseModel): "It will be treated as an 
absolute path if it starts with /" ) - class LLMSettings(BaseModel): mode: Literal[ "llamacpp", @@ -115,6 +114,7 @@ class LLMSettings(BaseModel): "mock", "ollama", "gemini", + "mistral", ] max_new_tokens: int = Field( 256, @@ -197,7 +197,7 @@ class HuggingFaceSettings(BaseModel): class EmbeddingSettings(BaseModel): mode: Literal[ - "huggingface", "openai", "azopenai", "sagemaker", "ollama", "mock", "gemini" + "huggingface", "openai", "azopenai", "sagemaker", "ollama", "mock", "gemini", "mistral" ] ingest_mode: Literal["simple", "batch", "parallel", "pipeline"] = Field( "simple", @@ -273,6 +273,15 @@ class GeminiSettings(BaseModel): ) +class MistralSettings(BaseModel): + api_key: str + endpoint: str + model: str + prompt_style: str + embedding_model: str + request_timeout: int + + class OllamaSettings(BaseModel): api_base: str = Field( "http://localhost:11434", diff --git a/private_gpt/ui/ui.py b/private_gpt/ui/ui.py index 0bf06d1..bb8d623 100644 --- a/private_gpt/ui/ui.py +++ b/private_gpt/ui/ui.py @@ -523,6 +523,7 @@ class PrivateGptUi: "mock": llm_mode, "ollama": config_settings.ollama.llm_model, "gemini": config_settings.gemini.model, + "mistral": config_settings.mistral.model, } if llm_mode not in model_mapping: diff --git a/pyproject.toml b/pyproject.toml index 10e3c2b..da3bd16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ llama-index-embeddings-huggingface = {version ="^0.2.2", optional = true} llama-index-embeddings-openai = {version ="^0.1.10", optional = true} llama-index-embeddings-azure-openai = {version ="^0.1.10", optional = true} llama-index-embeddings-gemini = {version ="^0.1.8", optional = true} +llama-index-embeddings-mistralai = {version ="^0.1.6", optional = true} llama-index-vector-stores-qdrant = {version ="^0.2.10", optional = true} llama-index-vector-stores-milvus = {version ="^0.1.20", optional = true} llama-index-vector-stores-chroma = {version ="^0.1.10", optional = true} @@ -83,6 +84,7 @@ embeddings-openai = 
["llama-index-embeddings-openai"] embeddings-sagemaker = ["boto3"] embeddings-azopenai = ["llama-index-embeddings-azure-openai"] embeddings-gemini = ["llama-index-embeddings-gemini"] +embeddings-mistral = ["llama-index-embeddings-mistralai"] vector-stores-qdrant = ["llama-index-vector-stores-qdrant"] vector-stores-clickhouse = ["llama-index-vector-stores-clickhouse", "clickhouse_connect"] vector-stores-chroma = ["llama-index-vector-stores-chroma"] diff --git a/settings-mistral.yaml b/settings-mistral.yaml new file mode 100644 index 0000000..71e4ac0 --- /dev/null +++ b/settings-mistral.yaml @@ -0,0 +1,16 @@ +server: + env_name: ${APP_ENV:mistral} + +llm: + mode: mistral + +embedding: + mode: mistral + +mistral: + endpoint: https://api.mistral.ai/v1 + api_key: ${MISTRAL_API_KEY:} + model: mistral-large + prompt_style: mistral + embedding_model: mistral-embed + request_timeout: 1000 \ No newline at end of file