mirror of
https://github.com/zylon-ai/private-gpt.git
synced 2025-12-22 10:45:42 +01:00
feat: add ollama queue
This commit is contained in:
parent
741376a085
commit
21c622ee27
1 changed files with 28 additions and 11 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
from typing import Any, Generator, Mapping, Iterator
|
||||
from tqdm import tqdm
|
||||
from collections import deque
|
||||
|
||||
try:
|
||||
from ollama import Client # type: ignore
|
||||
|
|
@ -23,9 +24,12 @@ def check_connection(client: Client) -> bool:
|
|||
|
||||
def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None:
|
||||
progress_bars = {}
|
||||
queue = deque()
|
||||
|
||||
def create_progress_bar(total: int) -> tqdm:
|
||||
return tqdm(total=total, desc=f"Pulling model", unit='B', unit_scale=True)
|
||||
def create_progress_bar(dgt: str, total: int) -> tqdm:
|
||||
return tqdm(total=total, desc=f"Pulling model {dgt[7:17]}...", unit='B', unit_scale=True)
|
||||
|
||||
current_digest = None
|
||||
|
||||
for chunk in generator:
|
||||
digest = chunk.get("digest")
|
||||
|
|
@ -33,21 +37,34 @@ def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None:
|
|||
total_size = chunk.get("total")
|
||||
|
||||
if digest and total_size is not None:
|
||||
if digest not in progress_bars:
|
||||
progress_bars[digest] = create_progress_bar(total=total_size)
|
||||
if digest not in progress_bars and completed_size > 0:
|
||||
progress_bars[digest] = create_progress_bar(digest, total=total_size)
|
||||
if current_digest is None:
|
||||
current_digest = digest
|
||||
else:
|
||||
queue.append(digest)
|
||||
|
||||
progress_bar = progress_bars[digest]
|
||||
progress_bar.update(completed_size - progress_bar.n)
|
||||
|
||||
if completed_size == total_size:
|
||||
progress_bar.close()
|
||||
del progress_bars[digest]
|
||||
if digest in progress_bars:
|
||||
progress_bar = progress_bars[digest]
|
||||
progress = completed_size - progress_bar.n
|
||||
if completed_size > 0 and total_size >= progress != progress_bar.n:
|
||||
if digest == current_digest:
|
||||
progress_bar.update(progress)
|
||||
if progress_bar.n >= total_size:
|
||||
progress_bar.close()
|
||||
if queue:
|
||||
current_digest = queue.popleft()
|
||||
else:
|
||||
current_digest = None
|
||||
else:
|
||||
# Store progress for later update
|
||||
progress_bars[digest].total = total_size
|
||||
progress_bars[digest].n = completed_size
|
||||
|
||||
# Close any remaining progress bars at the end
|
||||
for progress_bar in progress_bars.values():
|
||||
progress_bar.close()
|
||||
|
||||
|
||||
def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None:
|
||||
try:
|
||||
installed_models = [model["name"] for model in client.list().get("models", {})]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue