private-gpt/private_gpt/components/llm/custom/sagemaker.py
lopagela 64c5ae214a
feat: Drop loguru and use builtin logging (#1133)
* Configure simple builtin logging

Changed the 2 existing `print` calls in the `private_gpt` code base into actual Python logging, and stopped using loguru (the dependency will be dropped in a later commit).
Tried to use the `key=value` convention in logs (to indicate what the dynamic values represent, and what is dynamic vs. not).
Used the `%s` log style, so that string formatting is pushed inside the logger, giving the logger the ability to decide whether the string needs to be formatted at all (i.e. strings from debug logs may not be formatted if the log level is above debug).
The (basic) builtin log configuration has been placed in `private_gpt/__init__.py` in order to initialize the logging system before any other code in the `private_gpt` package runs (ensuring any initialization log is formatted as we want).
Disabled `uvicorn`'s custom logging format, so that uvicorn logs are output in our format.
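
For illustration, the two conventions combined look like this (a minimal sketch with made-up values, not code from this PR):

```
import logging

logger = logging.getLogger(__name__)

# `%s` style + `key=value` convention: formatting happens inside the logger,
# and only if the record is actually emitted at the current log level.
logger.debug("Initializing LLM mode=%s endpoint=%s", "sagemaker", "my-endpoint")
```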

A more concise format could be used if we wanted to, for example:
```
COMPACT_LOG_FORMAT = '%(asctime)s.%(msecs)03d [%(levelname)s] %(name)s - %(message)s'
```
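
Such a format could be wired up with the stdlib alone; a minimal sketch (the real configuration lives in `private_gpt/__init__.py`):

```
import logging

COMPACT_LOG_FORMAT = '%(asctime)s.%(msecs)03d [%(levelname)s] %(name)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=COMPACT_LOG_FORMAT, datefmt='%H:%M:%S')
```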

Python documentation and cookbook on logging for reference:
* https://docs.python.org/3/library/logging.html
* https://docs.python.org/3/howto/logging.html

* Removing loguru from the dependencies

Result of `poetry remove loguru`

* PR feedback: using `logger` variable name instead of `log`

---------

Co-authored-by: Louis Melchior <louis@jaris.io>
2023-10-29 19:11:02 +01:00


# mypy: ignore-errors
from __future__ import annotations
import io
import json
import logging
from typing import TYPE_CHECKING, Any
import boto3 # type: ignore
from llama_index.bridge.pydantic import Field
from llama_index.llms import (
CompletionResponse,
CustomLLM,
LLMMetadata,
)
from llama_index.llms.base import llm_completion_callback
from llama_index.llms.llama_utils import (
completion_to_prompt as generic_completion_to_prompt,
)
from llama_index.llms.llama_utils import (
messages_to_prompt as generic_messages_to_prompt,
)
if TYPE_CHECKING:
from llama_index.callbacks import CallbackManager
from llama_index.llms import (
CompletionResponseGen,
)
logger = logging.getLogger(__name__)
class LineIterator:
r"""A helper class for parsing the byte stream input from TGI container.
The output of the model will be in the following format:
```
b'data:{"token": {"text": " a"}}\n\n'
b'data:{"token": {"text": " challenging"}}\n\n'
b'data:{"token": {"text": " problem"
b'}}'
...
```
While usually each PayloadPart event from the event stream will contain a byte array
with a full json, this is not guaranteed and some of the json objects may be split
across PayloadPart events. For example:
```
{'PayloadPart': {'Bytes': b'{"outputs": '}}
{'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
```
    This class accounts for this by buffering bytes across PayloadPart events
    and yielding only complete lines (ending with a '\n' character) when
    iterated. It keeps track of the last read position so that previously
    returned bytes are not exposed again, and any pending partial line that
    does not end with a '\n' is kept in the buffer until the rest of it
    arrives.
    """
def __init__(self, stream: Any) -> None:
"""Line iterator initializer."""
self.byte_iterator = iter(stream)
self.buffer = io.BytesIO()
self.read_pos = 0
def __iter__(self) -> Any:
"""Self iterator."""
return self
def __next__(self) -> Any:
"""Next element from iterator."""
while True:
self.buffer.seek(self.read_pos)
line = self.buffer.readline()
if line and line[-1] == ord("\n"):
self.read_pos += len(line)
return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    # The stream ended with a trailing partial line (no '\n');
                    # flush it instead of looping forever waiting for bytes
                    # that will never arrive.
                    remainder = self.buffer.getvalue()[self.read_pos :]
                    self.read_pos = self.buffer.getbuffer().nbytes
                    return remainder
                raise
if "PayloadPart" not in chunk:
logger.warning("Unknown event type=%s", chunk)
continue
self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
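

# Usage sketch (hypothetical payloads, mirroring the docstring above): a JSON
# line split across two PayloadPart events is reassembled into one full line.
#
#   fake_stream = [
#       {"PayloadPart": {"Bytes": b'data:{"token": {"text": " a"}}\n'}},
#       {"PayloadPart": {"Bytes": b'data:{"token": {"text": " b"'}},
#       {"PayloadPart": {"Bytes": b'}}\n'}},
#   ]
#   assert list(LineIterator(fake_stream)) == [
#       b'data:{"token": {"text": " a"}}',
#       b'data:{"token": {"text": " b"}}',
#   ]
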
class SagemakerLLM(CustomLLM):
"""Sagemaker Inference Endpoint models.
To use, you must supply the endpoint name from your deployed
Sagemaker model & the region where it is deployed.
To authenticate, the AWS client uses the following methods to
automatically load credentials:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
    If a specific credential profile should be used, you must pass
    the name of that profile from the ~/.aws/credentials file.
Make sure the credentials / roles used have the required policies to
access the Sagemaker endpoint.
See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
"""
    endpoint_name: str = Field(description="The name of the deployed Sagemaker Inference Endpoint.")
temperature: float = Field(description="The temperature to use for sampling.")
max_new_tokens: int = Field(description="The maximum number of tokens to generate.")
context_window: int = Field(
description="The maximum number of context tokens for the model."
)
messages_to_prompt: Any = Field(
description="The function to convert messages to a prompt.", exclude=True
)
completion_to_prompt: Any = Field(
description="The function to convert a completion to a prompt.", exclude=True
)
generate_kwargs: dict[str, Any] = Field(
default_factory=dict, description="Kwargs used for generation."
)
model_kwargs: dict[str, Any] = Field(
default_factory=dict, description="Kwargs used for model initialization."
)
verbose: bool = Field(description="Whether to print verbose output.")
_boto_client: Any = boto3.client(
"sagemaker-runtime",
) # TODO make it an optional field
def __init__(
self,
        endpoint_name: str = "",
temperature: float = 0.1,
max_new_tokens: int = 512, # to review defaults
context_window: int = 2048, # to review defaults
messages_to_prompt: Any = None,
completion_to_prompt: Any = None,
callback_manager: CallbackManager | None = None,
generate_kwargs: dict[str, Any] | None = None,
model_kwargs: dict[str, Any] | None = None,
verbose: bool = True,
) -> None:
"""SagemakerLLM initializer."""
model_kwargs = model_kwargs or {}
model_kwargs.update({"n_ctx": context_window, "verbose": verbose})
messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
completion_to_prompt = completion_to_prompt or generic_completion_to_prompt
generate_kwargs = generate_kwargs or {}
generate_kwargs.update(
{"temperature": temperature, "max_tokens": max_new_tokens}
)
super().__init__(
endpoint_name=endpoint_name,
temperature=temperature,
context_window=context_window,
max_new_tokens=max_new_tokens,
messages_to_prompt=messages_to_prompt,
completion_to_prompt=completion_to_prompt,
callback_manager=callback_manager,
generate_kwargs=generate_kwargs,
model_kwargs=model_kwargs,
verbose=verbose,
)
@property
def inference_params(self):
# TODO expose the rest of params
return {
"do_sample": True,
"top_p": 0.7,
"temperature": self.temperature,
"top_k": 50,
"max_new_tokens": self.max_new_tokens,
}
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(
context_window=self.context_window,
num_output=self.max_new_tokens,
model_name="Sagemaker LLama 2",
)
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
self.generate_kwargs.update({"stream": False})
is_formatted = kwargs.pop("formatted", False)
if not is_formatted:
prompt = self.completion_to_prompt(prompt)
request_params = {
"inputs": prompt,
"stream": False,
"parameters": self.inference_params,
}
resp = self._boto_client.invoke_endpoint(
EndpointName=self.endpoint_name,
Body=json.dumps(request_params),
ContentType="application/json",
)
        response_body = resp["Body"]
        response_str = response_body.read().decode("utf-8")
        # The endpoint returns a JSON list like [{"generated_text": "..."}];
        # parse it with json.loads rather than eval, which is unsafe on
        # untrusted input and fails on JSON literals such as true/null.
        response_dict = json.loads(response_str)
        return CompletionResponse(
            text=response_dict[0]["generated_text"][len(prompt) :], raw=resp
        )
@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
def get_stream():
text = ""
request_params = {
"inputs": prompt,
"stream": True,
"parameters": self.inference_params,
}
resp = self._boto_client.invoke_endpoint_with_response_stream(
EndpointName=self.endpoint_name,
Body=json.dumps(request_params),
ContentType="application/json",
)
event_stream = resp["Body"]
start_json = b"{"
stop_token = "<|endoftext|>"
for line in LineIterator(event_stream):
if line != b"" and start_json in line:
data = json.loads(line[line.find(start_json) :].decode("utf-8"))
if data["token"]["text"] != stop_token:
delta = data["token"]["text"]
text += delta
yield CompletionResponse(delta=delta, text=text, raw=data)
return get_stream()
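

# Example usage (a sketch; "my-llama2-endpoint" is a hypothetical endpoint name
# and AWS credentials must already be configured for boto3):
#
#   llm = SagemakerLLM(endpoint_name="my-llama2-endpoint")
#   print(llm.complete("What is PrivateGPT?").text)
#
#   for chunk in llm.stream_complete("Tell me about Sagemaker"):
#       print(chunk.delta, end="", flush=True)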