From 21c622ee278e3da97fef15cd180b28aa906d9dc2 Mon Sep 17 00:00:00 2001 From: Javier Martinez Date: Thu, 1 Aug 2024 11:55:41 +0200 Subject: [PATCH] feat: add ollama queue --- private_gpt/utils/ollama.py | 39 ++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/private_gpt/utils/ollama.py b/private_gpt/utils/ollama.py index e13f709..95070e4 100644 --- a/private_gpt/utils/ollama.py +++ b/private_gpt/utils/ollama.py @@ -1,6 +1,7 @@ import logging from typing import Any, Generator, Mapping, Iterator from tqdm import tqdm +from collections import deque try: from ollama import Client # type: ignore @@ -23,9 +24,12 @@ def check_connection(client: Client) -> bool: def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None: progress_bars = {} + queue = deque() - def create_progress_bar(total: int) -> tqdm: - return tqdm(total=total, desc=f"Pulling model", unit='B', unit_scale=True) + def create_progress_bar(dgt: str, total: int) -> tqdm: + return tqdm(total=total, desc=f"Pulling model {dgt[7:17]}...", unit='B', unit_scale=True) + + current_digest = None for chunk in generator: digest = chunk.get("digest") @@ -33,21 +37,34 @@ def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None: total_size = chunk.get("total") if digest and total_size is not None: - if digest not in progress_bars: - progress_bars[digest] = create_progress_bar(total=total_size) + if digest not in progress_bars and completed_size > 0: + progress_bars[digest] = create_progress_bar(digest, total=total_size) + if current_digest is None: + current_digest = digest + else: + queue.append(digest) - progress_bar = progress_bars[digest] - progress_bar.update(completed_size - progress_bar.n) - - if completed_size == total_size: - progress_bar.close() - del progress_bars[digest] + if digest in progress_bars: + progress_bar = progress_bars[digest] + progress = completed_size - progress_bar.n + if completed_size > 0 and total_size >= progress != progress_bar.n: + if digest == current_digest: + progress_bar.update(progress) + if progress_bar.n >= total_size: + progress_bar.close() + if queue: + current_digest = queue.popleft() + else: + current_digest = None + else: + # Store progress for later update + progress_bars[digest].total = total_size + progress_bars[digest].n = completed_size # Close any remaining progress bars at the end for progress_bar in progress_bars.values(): progress_bar.close() - def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None: try: installed_models = [model["name"] for model in client.list().get("models", {})]