Update poetry lock and fix run for template prompt format

This commit is contained in:
Louis 2023-12-03 18:46:18 +01:00
parent 5bc5054000
commit af1463637b
4 changed files with 957 additions and 764 deletions

View file

@@ -178,8 +178,10 @@ class AbstractPromptStyle(abc.ABC):
class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
_DEFAULT_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT
def __init__(self, default_system_prompt: str | None) -> None:
super().__init__()
def __init__(
self, default_system_prompt: str | None, *args: Any, **kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
logger.debug("Got default_system_prompt='%s'", default_system_prompt)
self.default_system_prompt = default_system_prompt
@@ -235,9 +237,13 @@ class LlamaIndexPromptStyle(AbstractPromptStyleWithSystemPrompt):
```
"""
def __init__(self, default_system_prompt: str | None = None) -> None:
def __init__(
self, default_system_prompt: str | None = None, *args: Any, **kwargs: Any
) -> None:
# If no system prompt is given, the default one of the implementation is used.
super().__init__(default_system_prompt=default_system_prompt)
# default_system_prompt can be None here
kwargs["default_system_prompt"] = default_system_prompt
super().__init__(*args, **kwargs)
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
return messages_to_prompt(messages, self.default_system_prompt)
@@ -264,12 +270,14 @@ class VigognePromptStyle(AbstractPromptStyleWithSystemPrompt):
self,
default_system_prompt: str | None = None,
add_generation_prompt: bool = True,
*args: Any,
**kwargs: Any,
) -> None:
# We have to define a default system prompt here as the LLM will not
# use the default llama_utils functions.
default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
super().__init__(default_system_prompt)
self.system_prompt: str = default_system_prompt
kwargs["default_system_prompt"] = default_system_prompt
super().__init__(*args, **kwargs)
self.add_generation_prompt = add_generation_prompt
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
@@ -300,7 +308,11 @@ class VigognePromptStyle(AbstractPromptStyleWithSystemPrompt):
class LlamaCppPromptStyle(AbstractPromptStyleWithSystemPrompt):
def __init__(
self, prompt_style: str, default_system_prompt: str | None = None
self,
prompt_style: str,
default_system_prompt: str | None = None,
*args: Any,
**kwargs: Any,
) -> None:
"""Wrapper for llama_cpp_python defined prompt format.
@@ -309,7 +321,8 @@ class LlamaCppPromptStyle(AbstractPromptStyleWithSystemPrompt):
"""
assert prompt_style.startswith("llama_cpp.")
default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
super().__init__(default_system_prompt)
kwargs["default_system_prompt"] = default_system_prompt
super().__init__(*args, **kwargs)
self.prompt_style = prompt_style[len("llama_cpp.") :]
if self.prompt_style is None:
@@ -339,6 +352,8 @@ class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
template_dir: str | None = None,
add_generation_prompt: bool = True,
default_system_prompt: str | None = None,
*args: Any,
**kwargs: Any,
) -> None:
"""Prompt format using a Jinja template.
@@ -350,7 +365,8 @@ class TemplatePromptStyle(AbstractPromptStyleWithSystemPrompt):
given in the messages.
"""
default_system_prompt = default_system_prompt or DEFAULT_SYSTEM_PROMPT
super().__init__(default_system_prompt)
kwargs["default_system_prompt"] = default_system_prompt
super().__init__(*args, **kwargs)
self._add_generation_prompt = add_generation_prompt