Files
vibe-bot/vibe_bot/llama_wrapper.py
T

255 lines
6.5 KiB
Python

"""Thin wrappers that expose OpenAI client calls as generic functions.
Supports chat completion, image generation, image editing, and embeddings.
Each function takes its own API URL, so every operation can target a
different OpenAI-compatible endpoint.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, cast
import openai
if TYPE_CHECKING:
from io import BufferedReader, BytesIO
from openai.types.chat import ChatCompletionMessageParam
def chat_completion(
    system_prompt: str,
    user_prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Send a single-turn chat completion request and return the reply.

    The conversation sent to the server is exactly one system message
    followed by one user message.

    Args:
        system_prompt: The system prompt to use.
        user_prompt: The user prompt to send.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: The API key for authentication.
        model: The model to use for completion.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped, or ``""`` if the server returned no choices or an empty
        message.

    Raises:
        Any exception raised by the OpenAI client (connection, auth, API
        errors, ...) propagates unchanged; nothing is caught here.
    """
    messages: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # Use the client as a context manager so its HTTP connection pool is
    # released as soon as the request finishes, not whenever it is GC'd.
    with openai.OpenAI(base_url=openai_url, api_key=openai_api_key) as client:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
        )
    # An OpenAI-compatible server may return an empty choices list; treat it
    # like an empty reply instead of raising IndexError.
    if not response.choices:
        return ""
    content = response.choices[0].message.content
    return content.strip() if content else ""
def chat_completion_with_history(
    system_prompt: str,
    prompts: list[dict[str, str]],
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Send a chat completion request that includes prior conversation turns.

    The system prompt is placed first, followed by every entry of
    ``prompts`` in order. ``prompts`` itself is not modified.

    Args:
        system_prompt: The system prompt to use.
        prompts: Conversation turns as dicts with ``role`` and ``content``
            keys, passed to the API as-is (no validation is done here).
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: The API key for authentication.
        model: The model to use for completion.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped, or ``""`` if the server returned no choices or an empty
        message.

    Raises:
        Any exception raised by the OpenAI client propagates unchanged.
    """
    messages: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_prompt},
    ]
    # The caller supplies plain dicts; the cast only informs the type checker.
    messages.extend(cast("list[ChatCompletionMessageParam]", prompts))
    # Context manager closes the client's HTTP connection pool promptly.
    with openai.OpenAI(base_url=openai_url, api_key=openai_api_key) as client:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            # NOTE(review): presumably asks llama.cpp-style servers for a
            # random seed per request — confirm against the target server.
            seed=-1,
        )
    # An empty choices list is treated like an empty reply.
    if not response.choices:
        return ""
    content = response.choices[0].message.content
    return content.strip() if content else ""
def chat_completion_instruct(
    system_prompt: str,
    user_prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Send a single-turn, instruction-style chat completion request.

    The request has the same shape as ``chat_completion`` (one system
    message plus one user message) but additionally sends ``seed=-1``.

    Args:
        system_prompt: The system prompt to use.
        user_prompt: The user prompt to send.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: The API key for authentication.
        model: The model to use for completion.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped, or ``""`` if the server returned no choices or an empty
        message.

    Raises:
        Any exception raised by the OpenAI client propagates unchanged.
    """
    messages: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # Context manager closes the client's HTTP connection pool promptly.
    with openai.OpenAI(base_url=openai_url, api_key=openai_api_key) as client:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            # NOTE(review): presumably requests a random seed on
            # llama.cpp-style servers — confirm against the target server.
            seed=-1,
        )
    # An empty choices list is treated like an empty reply.
    if not response.choices:
        return ""
    content = response.choices[0].message.content
    return content.strip() if content else ""
def image_generation(
    prompt: str,
    openai_url: str,
    openai_api_key: str,
    n: int = 1,
    *,
    model: str = "gen",
) -> str:
    """Generate a 1024x1024 image from a text prompt.

    Args:
        prompt: The image generation prompt.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: The API key for authentication.
        n: Number of images to request. Only the first one is returned.
        model: Model name sent to the server (keyword-only). Defaults to
            ``"gen"``.

    Returns:
        The base64-encoded data of the first image (decode it and write it
        to a file), or ``""`` if the server returned no images or the first
        image has no base64 payload.

    Raises:
        Any exception raised by the OpenAI client propagates unchanged.
    """
    # Context manager closes the client's HTTP connection pool promptly.
    with openai.OpenAI(base_url=openai_url, api_key=openai_api_key) as client:
        response = client.images.generate(
            prompt=prompt,
            n=n,
            size="1024x1024",
            model=model,
        )
    if not response.data:
        return ""
    return response.data[0].b64_json or ""
def image_edit(
    image: BufferedReader | BytesIO | list[BufferedReader] | list[BytesIO],
    prompt: str,
    openai_url: str,
    openai_api_key: str,
    n: int = 1,
    *,
    model: str = "edit",
) -> str:
    """Edit an existing image according to a prompt, producing 1024x1024 output.

    Args:
        image: The source image as a file-like object, or a list of them.
            It is passed to the client unchanged; this function does not
            close it.
        prompt: The edit instruction.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: The API key for authentication.
        n: Number of edited images to request. Only the first is returned.
        model: Model name sent to the server (keyword-only). Defaults to
            ``"edit"``.

    Returns:
        The base64-encoded data of the first edited image, or ``""`` if the
        server returned no images or the first image has no base64 payload.

    Raises:
        Any exception raised by the OpenAI client propagates unchanged.
    """
    # Context manager closes the client's HTTP connection pool promptly.
    with openai.OpenAI(base_url=openai_url, api_key=openai_api_key) as client:
        response = client.images.edit(
            image=image,
            prompt=prompt,
            n=n,
            size="1024x1024",
            model=model,
        )
    if not response.data:
        return ""
    return response.data[0].b64_json or ""
def embedding(
    text: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
) -> list[float]:
    """Generate an embedding vector for a single piece of text.

    Args:
        text: The text to embed (sent as a one-element input list).
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: The API key for authentication.
        model: The embedding model to use.

    Returns:
        The embedding of ``text`` as a list of floats, or an empty list if
        the server returned no embedding data.

    Raises:
        Any exception raised by the OpenAI client propagates unchanged.
    """
    # Context manager closes the client's HTTP connection pool promptly.
    with openai.OpenAI(base_url=openai_url, api_key=openai_api_key) as client:
        response = client.embeddings.create(
            input=[text],
            model=model,
            encoding_format="float",
        )
    # Check the payload list itself: the response is a model object, not a
    # container, so its truthiness does not reveal an empty result, and
    # indexing an empty ``data`` list would raise IndexError.
    data = response.data if response else None
    if not data:
        return []
    raw_data = data[0].embedding
    # The result could be an array of floats or a single float.
    if isinstance(raw_data, float):
        return [raw_data]
    return list(raw_data)