Mirror of https://github.com/zylon-ai/private-gpt.git
Support n_batch to improve inference performance
parent 52eb020256
commit ad661933cb
3 changed files with 5 additions and 2 deletions
@@ -17,6 +17,7 @@ persist_directory = os.environ.get('PERSIST_DIRECTORY')
 model_type = os.environ.get('MODEL_TYPE')
 model_path = os.environ.get('MODEL_PATH')
 model_n_ctx = os.environ.get('MODEL_N_CTX')
+model_n_batch = int(os.environ.get('MODEL_N_BATCH',8))
 target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS',4))
 
 from constants import CHROMA_SETTINGS
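For context, the added setting follows the same environment-variable pattern as the surrounding lines. A minimal standalone sketch (standard library only; the MODEL_N_BATCH=512 invocation below is an illustration, not part of the commit):

import os

# Read the batch size from the environment, defaulting to 8 when unset,
# exactly as in the line added above. int() fails fast on non-numeric values.
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))

# Example launch with a larger batch:
#   MODEL_N_BATCH=512 python privateGPT.py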
@@ -32,9 +33,9 @@ def main():
     # Prepare the LLM
     match model_type:
         case "LlamaCpp":
-            llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=False)
+            llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks, verbose=False)
         case "GPT4All":
-            llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', callbacks=callbacks, verbose=False)
+            llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=False)
         case _default:
             print(f"Model {model_type} not supported!")
             exit;
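Taken out of the diff, the patched construction can be sketched as below. This assumes the LangChain wrappers privateGPT imported at the time (langchain.llms.LlamaCpp and langchain.llms.GPT4All, both of which accept an n_batch argument); the model path and the n_ctx fallback are hypothetical placeholders:

import os
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import LlamaCpp

model_path = os.environ.get('MODEL_PATH', 'models/example-model.bin')  # hypothetical default
model_n_ctx = int(os.environ.get('MODEL_N_CTX', 1000))                 # hypothetical default
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))

# Stream generated tokens to stdout as they arrive.
callbacks = [StreamingStdOutCallbackHandler()]

# n_batch is forwarded to llama.cpp and sets how many prompt tokens are
# evaluated per batch; larger values usually speed up prompt processing
# at the cost of memory.
llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx,
               n_batch=model_n_batch, callbacks=callbacks, verbose=False)

A fallback of 8 mirrors the wrapper's own n_batch default, so existing setups keep their previous behavior unless MODEL_N_BATCH is set explicitly.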