feat: add ollama queue

This commit is contained in:
Javier Martinez 2024-08-01 11:55:41 +02:00
parent 741376a085
commit 21c622ee27
No known key found for this signature in database

View file

@ -1,6 +1,7 @@
import logging import logging
from typing import Any, Generator, Mapping, Iterator from typing import Any, Generator, Mapping, Iterator
from tqdm import tqdm from tqdm import tqdm
from collections import deque
try: try:
from ollama import Client # type: ignore from ollama import Client # type: ignore
@ -23,9 +24,12 @@ def check_connection(client: Client) -> bool:
def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None: def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None:
progress_bars = {} progress_bars = {}
queue = deque()
def create_progress_bar(total: int) -> tqdm: def create_progress_bar(dgt: str, total: int) -> tqdm:
return tqdm(total=total, desc=f"Pulling model", unit='B', unit_scale=True) return tqdm(total=total, desc=f"Pulling model {dgt[7:17]}...", unit='B', unit_scale=True)
current_digest = None
for chunk in generator: for chunk in generator:
digest = chunk.get("digest") digest = chunk.get("digest")
@ -33,21 +37,34 @@ def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None:
total_size = chunk.get("total") total_size = chunk.get("total")
if digest and total_size is not None: if digest and total_size is not None:
if digest not in progress_bars: if digest not in progress_bars and completed_size > 0:
progress_bars[digest] = create_progress_bar(total=total_size) progress_bars[digest] = create_progress_bar(digest, total=total_size)
if current_digest is None:
current_digest = digest
else:
queue.append(digest)
progress_bar = progress_bars[digest] if digest in progress_bars:
progress_bar.update(completed_size - progress_bar.n) progress_bar = progress_bars[digest]
progress = completed_size - progress_bar.n
if completed_size == total_size: if completed_size > 0 and total_size >= progress != progress_bar.n:
progress_bar.close() if digest == current_digest:
del progress_bars[digest] progress_bar.update(progress)
if progress_bar.n >= total_size:
progress_bar.close()
if queue:
current_digest = queue.popleft()
else:
current_digest = None
else:
# Store progress for later update
progress_bars[digest].total = total_size
progress_bars[digest].n = completed_size
# Close any remaining progress bars at the end # Close any remaining progress bars at the end
for progress_bar in progress_bars.values(): for progress_bar in progress_bars.values():
progress_bar.close() progress_bar.close()
def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None: def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None:
try: try:
installed_models = [model["name"] for model in client.list().get("models", {})] installed_models = [model["name"] for model in client.list().get("models", {})]