Files
vibe-bot/vibe_bot/llama_wrapper.py
T
2026-05-24 15:26:13 -04:00

429 lines
12 KiB
Python

"""Wraps the openai calls in generic functions.
Supports chat, image, edit, and embeddings.
Allows custom endpoints for each of the above supported functions.
"""
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Awaitable, Callable, cast
import openai
import requests
if TYPE_CHECKING:
from io import BufferedReader, BytesIO
from openai.types.chat import ChatCompletionMessageParam
def chat_completion(
    system_prompt: str,
    user_prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Run a single-turn chat completion (one system + one user message).

    Args:
        system_prompt: Text for the ``system`` message.
        user_prompt: Text for the ``user`` message.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key sent with the request.
        model: Name of the chat model to query.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or ``""`` when the response has no choices or the first
        choice has empty/missing content. Client errors are not caught here.
    """
    conversation: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    api = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
    completion = api.chat.completions.create(
        model=model,
        messages=conversation,
        max_tokens=max_tokens,
        timeout=60.0,
    )
    choices = completion.choices
    text = choices[0].message.content if choices else None
    return text.strip() if text else ""
def chat_completion_with_history(
    system_prompt: str,
    prompts: list[dict[str, str]],
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Run a chat completion over a system prompt followed by prior turns.

    Args:
        system_prompt: Text for the leading ``system`` message.
        prompts: Conversation turns appended after the system message, each
            a dict with ``role`` and ``content`` keys.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key sent with the request.
        model: Name of the chat model to query.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or ``""`` when there are no choices or the content is
        empty/missing. Client errors are not caught here.
    """
    system_message = {"role": "system", "content": system_prompt}
    # The turn dicts are passed through unchanged; cast only informs the
    # type checker.
    history = cast(
        "list[ChatCompletionMessageParam]",
        [system_message, *prompts],
    )
    api = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
    completion = api.chat.completions.create(
        model=model,
        messages=history,
        max_tokens=max_tokens,
        seed=-1,
        timeout=60.0,
    )
    if completion.choices:
        text = completion.choices[0].message.content
        if text:
            return text.strip()
    return ""
def chat_completion_instruct(
    system_prompt: str,
    user_prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Send an instruction (system + user message) and return the reply.

    Unlike ``chat_completion``, this request also passes ``seed=-1``.

    Args:
        system_prompt: Text for the ``system`` message.
        user_prompt: Text for the ``user`` message.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key sent with the request.
        model: Name of the chat model to query.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or ``""`` when there are no choices or the content is
        empty/missing. Client errors are not caught here.
    """
    api = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
    instruction: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    completion = api.chat.completions.create(
        model=model,
        messages=instruction,
        max_tokens=max_tokens,
        seed=-1,
        timeout=60.0,
    )
    if not completion.choices:
        return ""
    reply = completion.choices[0].message.content
    if not reply:
        return ""
    return reply.strip()
def _parse_tool_arguments(raw_arguments: str | None) -> dict[str, str] | None:
    """Decode a tool call's JSON ``arguments`` string into an argument dict.

    Args:
        raw_arguments: The ``function.arguments`` string of a tool call.

    Returns:
        ``{}`` when the string is missing or blank, the decoded dict when it
        is a JSON object, and ``None`` when it is malformed JSON or decodes
        to anything other than a JSON object.
    """
    if raw_arguments is None or not raw_arguments.strip():
        return {}
    try:
        decoded = json.loads(raw_arguments)
    except json.JSONDecodeError:
        return None
    return decoded if isinstance(decoded, dict) else None


async def chat_completion_with_tools(
    system_prompt: str,
    prompts: list[dict[str, str]],
    tools: list[dict[str, object]],
    tool_executor: Callable[[str, dict[str, str]], str],
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
    max_tool_rounds: int = 5,
    tool_call_notifier: (
        Callable[[str, dict[str, str]], None]
        | Callable[[str, dict[str, str]], Awaitable[None]]
        | None
    ) = None,
) -> str:
    """Send a chat completion request with tool support and iterative tool calling.

    Each round sends the conversation to the model. If the reply requests
    tool calls, the assistant turn is appended to the conversation, each
    requested tool is run through ``tool_executor`` and its result appended
    as a ``tool`` message, and the next round starts. The first reply that
    requests no tools is returned as the final answer.

    If a tool call's arguments are not a JSON object, the tool is not
    executed; an error string is returned to the model as that call's
    result instead, so the conversation continues rather than raising.

    Args:
        system_prompt: The system prompt to use.
        prompts: List of prompt dicts with role and content.
        tools: List of tool definitions in OpenAI format.
        tool_executor: A callable that takes (tool_name: str, tool_args: dict) -> str.
        openai_url: The OpenAI-compatible API URL.
        openai_api_key: The API key for authentication.
        model: The model to use for completion.
        max_tokens: Maximum number of tokens to generate.
        max_tool_rounds: Maximum number of tool call rounds before giving up.
        tool_call_notifier: Optional callback invoked before each tool call
            with (tool_name, tool_args). If it returns an awaitable, that
            awaitable is awaited before the tool runs.

    Returns:
        The model's final response text, stripped of whitespace. Returns
        ``""`` if a response has no choices, the final reply has no content,
        or the model is still requesting tools after ``max_tool_rounds``
        rounds. API client errors are not caught here.
    """
    import inspect  # local: only used for the notifier's awaitable check

    # The async client keeps the event loop free while the request is in
    # flight; the synchronous client would block every other task on the
    # loop for the full duration of each round.
    client = openai.AsyncOpenAI(base_url=openai_url, api_key=openai_api_key)
    messages: list[ChatCompletionMessageParam] = [
        cast(
            "ChatCompletionMessageParam",
            {"role": "system", "content": system_prompt},
        ),
    ]
    messages.extend(cast("list[ChatCompletionMessageParam]", prompts))

    for _round in range(max_tool_rounds):
        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            tools=tools,  # type: ignore[arg-type]
            max_tokens=max_tokens,
            seed=-1,
            timeout=60.0,
        )
        if not response.choices:
            return ""
        message = response.choices[0].message
        tool_calls = message.tool_calls
        if not tool_calls:
            # No tool requested: this reply is the model's final answer.
            content = message.content
            return content.strip() if content else ""

        # Record the assistant turn, including its tool calls, so the tool
        # results appended below follow the turn that requested them.
        messages.append(
            cast(
                "ChatCompletionMessageParam",
                {
                    "role": "assistant",
                    "content": message.content or "",
                    "tool_calls": [
                        {
                            "id": tool_call.id,
                            "type": "function",
                            "function": {
                                "name": tool_call.function.name,  # type: ignore[union-attr]
                                "arguments": tool_call.function.arguments,  # type: ignore[union-attr]
                            },
                        }
                        for tool_call in tool_calls
                    ],
                },
            ),
        )

        # Execute each tool call and add its result to the conversation.
        for tool_call in tool_calls:
            tool_name = tool_call.function.name  # type: ignore[union-attr]
            raw_arguments = tool_call.function.arguments  # type: ignore[union-attr]
            tool_args = _parse_tool_arguments(raw_arguments)
            if tool_args is None:
                # Report bad arguments to the model instead of raising, so
                # the model sees the error and may retry in the next round.
                tool_result = (
                    f"Error: arguments for tool {tool_name!r} must be a JSON "
                    f"object, got: {raw_arguments!r}"
                )
            else:
                if tool_call_notifier:
                    notified = tool_call_notifier(tool_name, tool_args)
                    if inspect.isawaitable(notified):
                        await notified
                tool_result = tool_executor(tool_name, tool_args)
            messages.append(
                cast(
                    "ChatCompletionMessageParam",
                    {
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "content": tool_result,
                    },
                ),
            )

    # Rounds exhausted while the model was still requesting tools.
    return ""
def image_generation(
    prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str = "gen",
    n: int = 1,
) -> str:
    """Create a 1024x1024 image from a text prompt.

    The client is built with ``max_retries=0``, so a failed request is not
    retried.

    Args:
        prompt: Text describing the desired image.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key sent with the request.
        model: Name of the image generation model.
        n: Number of images to request; only the first is returned.

    Returns:
        Base64-encoded data of the first generated image (decode it and write
        it to a file). Returns ``""`` on ``openai.APIConnectionError``, when
        the response has no images, or when the first image has no
        ``b64_json``.
    """
    api = openai.OpenAI(
        base_url=openai_url,
        api_key=openai_api_key,
        max_retries=0,
    )
    try:
        result = api.images.generate(
            prompt=prompt,
            n=n,
            size="1024x1024",
            model=model,
            timeout=120.0,
        )
    except openai.APIConnectionError:
        return ""
    images = result.data
    if not images:
        return ""
    return images[0].b64_json or ""
def image_edit(
    image: BufferedReader | BytesIO | list[BufferedReader] | list[BytesIO],
    prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str = "edit",
    n: int = 1,
) -> str:
    """Apply a text-instructed edit to one or more source images (1024x1024).

    Args:
        image: Source image file object, or a list of them.
        prompt: Instruction describing the edit.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key sent with the request.
        model: Name of the image editing model.
        n: Number of edited images to request; only the first is returned.

    Returns:
        Base64-encoded data of the first edited image, or ``""`` when the
        response has no images or the first one has no ``b64_json``. Errors
        raised by the OpenAI client are not caught here.
    """
    api = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
    edited = api.images.edit(
        image=image,
        prompt=prompt,
        n=n,
        size="1024x1024",
        model=model,
    ).data
    return (edited[0].b64_json or "") if edited else ""
def _extract_embedding(data: object) -> list[float]:
    """Return the first embedding vector in a decoded embeddings response, or ``[]``.

    Accepts both an OpenAI-style body (``{"data": [{"embedding": [...]}]}``)
    and an Ollama-style body where the API returns the list directly
    (``[{"embedding": [...]}]``). An ``embedding`` value given as a string is
    decoded as JSON. Any other or malformed shape yields ``[]``.
    """
    items = data.get("data") if isinstance(data, dict) else data
    if not isinstance(items, list) or not items:
        return []
    first = items[0]
    if not isinstance(first, dict):
        return []
    raw = first.get("embedding")
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except ValueError:
            return []
    # Only a non-empty JSON array is a usable vector; previously numbers and
    # bools crashed list(), and dicts were silently turned into their keys.
    if not isinstance(raw, list) or not raw:
        return []
    return list(raw)


def embedding(
    text: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
) -> list[float]:
    """Generate an embedding vector for the given text.

    Uses a raw HTTP request to avoid the OpenAI SDK injecting
    unsupported parameters like encoding_format.

    Args:
        text: The text to embed.
        openai_url: The OpenAI-compatible API URL.
        openai_api_key: The API key for authentication.
        model: The embedding model to use.

    Returns:
        The embedding vector as a list, or an empty list on failure. Failure
        covers network/HTTP errors, a non-JSON response body, and any
        response shape that does not contain a non-empty embedding array.
    """
    url = f"{openai_url.rstrip('/')}/embeddings"
    headers = {
        "Authorization": f"Bearer {openai_api_key}",
        "Content-Type": "application/json",
    }
    payload = {"model": model, "input": [text]}
    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=30)
        resp.raise_for_status()
    except requests.RequestException:
        return []
    try:
        data = resp.json()
    except ValueError:
        # A non-JSON body raises a ValueError subclass (json.JSONDecodeError,
        # or requests' JSONDecodeError in newer releases).
        return []
    return _extract_embedding(data)