429 lines
12 KiB
Python
429 lines
12 KiB
Python
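"""Thin wrappers around OpenAI-compatible API calls.

Provides helpers for chat completions (single-turn, with conversation
history, and with iterative tool calling), image generation, image editing,
and embeddings. Every helper takes the endpoint URL and API key as arguments,
so each call can target a different OpenAI-compatible server.
"""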
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import TYPE_CHECKING, Awaitable, Callable, cast
|
|
|
|
import openai
|
|
import requests
|
|
|
|
if TYPE_CHECKING:
|
|
from io import BufferedReader, BytesIO
|
|
|
|
from openai.types.chat import ChatCompletionMessageParam
|
|
|
|
|
|
def chat_completion(
    system_prompt: str,
    user_prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Run a single-turn chat completion and return the reply text.

    The request carries exactly two messages: a ``system`` message followed
    by a ``user`` message.

    Args:
        system_prompt: Content of the leading ``system`` message.
        user_prompt: Content of the ``user`` message.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key sent with the request.
        model: Name of the chat model to query.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or ``""`` when the server returns no choices or an empty
        message.
    """
    conversation: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    api = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
    completion = api.chat.completions.create(
        model=model,
        messages=conversation,
        max_tokens=max_tokens,
        timeout=60.0,
    )

    choices = completion.choices
    text = choices[0].message.content if choices else None
    return text.strip() if text else ""
|
|
|
|
|
|
def chat_completion_with_history(
    system_prompt: str,
    prompts: list[dict[str, str]],
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Run a chat completion over an existing conversation.

    The request is the system prompt followed by every message in
    ``prompts``, in order. ``prompts`` itself is not modified.

    Args:
        system_prompt: Content of the leading ``system`` message.
        prompts: Prior conversation messages (dicts with role and content),
            appended after the system message.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key sent with the request.
        model: Name of the chat model to query.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or ``""`` when there are no choices or the message is empty.
    """
    api = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)

    system_message = cast(
        "ChatCompletionMessageParam",
        {"role": "system", "content": system_prompt},
    )
    conversation: list[ChatCompletionMessageParam] = [
        system_message,
        *cast("list[ChatCompletionMessageParam]", prompts),
    ]

    completion = api.chat.completions.create(
        model=model,
        messages=conversation,
        max_tokens=max_tokens,
        # seed=-1 is forwarded as-is; its meaning depends on the backend.
        seed=-1,
        timeout=60.0,
    )

    if not completion.choices:
        return ""
    reply = completion.choices[0].message.content or ""
    return reply.strip()
|
|
|
|
|
|
def chat_completion_instruct(
    system_prompt: str,
    user_prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Run a single-turn instruction-style chat completion.

    Sends the same two-message request as ``chat_completion`` (a ``system``
    message, then a ``user`` message). The only difference is that this
    helper also passes ``seed=-1`` to the server.

    Args:
        system_prompt: Content of the leading ``system`` message.
        user_prompt: Content of the ``user`` message.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key sent with the request.
        model: Name of the chat model to query.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or ``""`` when there are no choices or the message is empty.
    """
    api = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
    request_messages: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    completion = api.chat.completions.create(
        model=model,
        messages=request_messages,
        max_tokens=max_tokens,
        seed=-1,
        timeout=60.0,
    )

    if completion.choices:
        answer = completion.choices[0].message.content
        if answer:
            return answer.strip()
    return ""
|
|
|
|
|
|
def _parse_tool_arguments(raw_arguments: str | None) -> dict[str, str] | None:
    """Decode a tool call's JSON ``arguments`` string into a dict.

    An empty or whitespace-only string is treated as "no arguments" and
    yields ``{}``. Text that is not valid JSON, or JSON that is not an
    object, yields ``None`` so the caller can report the problem.
    """
    if not raw_arguments or not raw_arguments.strip():
        return {}
    try:
        parsed = json.loads(raw_arguments)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


async def chat_completion_with_tools(
    system_prompt: str,
    prompts: list[dict[str, str]],
    tools: list[dict[str, object]],
    tool_executor: Callable[[str, dict[str, str]], str],
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
    max_tool_rounds: int = 5,
    tool_call_notifier: (
        Callable[[str, dict[str, str]], None]
        | Callable[[str, dict[str, str]], Awaitable[None]]
        | None
    ) = None,
) -> str:
    """Send a chat completion request with tool support and iterative tool calling.

    Each round sends the whole conversation together with ``tools``. If the
    model requests tool calls, two things are appended to the conversation:
    the assistant turn, and one ``tool`` message per call holding the
    executor's result. Another round then starts. The first reply without
    tool calls ends the loop.

    Args:
        system_prompt: Content of the leading ``system`` message.
        prompts: Prior conversation messages (dicts with role and content).
        tools: Tool definitions in OpenAI format.
        tool_executor: Synchronous callable ``(tool_name, tool_args) -> str``.
            Its return value is sent back to the model as the tool result. It
            is called directly inside this coroutine, so a slow executor
            blocks the event loop.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key sent with the request.
        model: Name of the chat model to query.
        max_tokens: Upper bound on generated tokens per round.
        max_tool_rounds: Maximum number of request rounds before giving up.
        tool_call_notifier: Optional callback invoked with
            ``(tool_name, tool_args)`` before each tool is executed. It may
            be a plain function or return an awaitable, which is awaited.

    Returns:
        The final reply text stripped of whitespace. Returns ``""`` in three
        cases: the server returns no choices, the final reply is empty, or
        ``max_tool_rounds`` rounds pass with the model still requesting tools.

    Notes:
        If a tool call's arguments are not a JSON object, the tool is neither
        notified nor executed. An error message is sent back as that call's
        tool result instead, so the model can correct itself. Exceptions
        raised by the API client or by ``tool_executor`` propagate.
    """
    messages: list[ChatCompletionMessageParam] = [
        cast(
            "ChatCompletionMessageParam",
            {"role": "system", "content": system_prompt},
        ),
    ]
    messages.extend(cast("list[ChatCompletionMessageParam]", prompts))

    # The async client keeps network waits from blocking the event loop, and
    # ``async with`` closes its HTTP connections when we return.
    async with openai.AsyncOpenAI(
        base_url=openai_url,
        api_key=openai_api_key,
    ) as client:
        for _round in range(max_tool_rounds):
            response = await client.chat.completions.create(
                model=model,
                messages=messages,
                tools=tools,  # type: ignore[arg-type]
                max_tokens=max_tokens,
                seed=-1,
                timeout=60.0,
            )

            if not response.choices:
                return ""

            message = response.choices[0].message
            tool_calls = message.tool_calls
            if not tool_calls:
                # No more tool requests: this is the final answer.
                content = message.content
                return content.strip() if content else ""

            # Record the assistant turn with its tool calls so each following
            # ``tool`` message can reference its ``tool_call_id``.
            messages.append(
                cast(
                    "ChatCompletionMessageParam",
                    {
                        "role": "assistant",
                        "content": message.content or "",
                        "tool_calls": [
                            {
                                "id": tool_call.id,
                                "type": "function",
                                "function": {
                                    "name": tool_call.function.name,  # type: ignore[union-attr]
                                    "arguments": tool_call.function.arguments,  # type: ignore[union-attr]
                                },
                            }
                            for tool_call in tool_calls
                        ],
                    },
                ),
            )

            # Execute each tool call and add its result to the conversation.
            for tool_call in tool_calls:
                tool_name = tool_call.function.name  # type: ignore[union-attr]
                tool_args = _parse_tool_arguments(
                    tool_call.function.arguments,  # type: ignore[union-attr]
                )

                if tool_args is None:
                    tool_result = (
                        f"Error: arguments for tool '{tool_name}' "
                        "were not a valid JSON object."
                    )
                else:
                    if tool_call_notifier:
                        notified = tool_call_notifier(tool_name, tool_args)
                        if hasattr(notified, "__await__"):
                            await notified  # type: ignore[misc]
                    tool_result = tool_executor(tool_name, tool_args)

                messages.append(
                    cast(
                        "ChatCompletionMessageParam",
                        {
                            "role": "tool",
                            "tool_call_id": tool_call.id,
                            "content": tool_result,
                        },
                    ),
                )

    # Round budget exhausted while the model was still requesting tools.
    return ""
|
|
|
|
|
|
def image_generation(
    prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str = "gen",
    n: int = 1,
) -> str:
    """Generate a 1024x1024 image from a text prompt.

    The client is created with ``max_retries=0``, so a failed request is not
    retried.

    Args:
        prompt: Text description of the desired image.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key sent with the request.
        model: Name of the image generation model.
        n: Number of images to request. Only the first one is returned.

    Returns:
        The first image as a base64 string (decode it before writing it to a
        file). Returns ``""`` when the connection fails or the response holds
        no image data. Other API errors propagate.
    """
    api = openai.OpenAI(
        base_url=openai_url,
        api_key=openai_api_key,
        max_retries=0,
    )
    try:
        result = api.images.generate(
            prompt=prompt,
            n=n,
            size="1024x1024",
            model=model,
            timeout=120.0,
        )
    except openai.APIConnectionError:
        # Connection-level failure: report "no image" rather than raising.
        return ""
    else:
        images = result.data or []
        return (images[0].b64_json or "") if images else ""
|
|
|
|
|
|
def image_edit(
    image: BufferedReader | BytesIO | list[BufferedReader] | list[BytesIO],
    prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str = "edit",
    n: int = 1,
) -> str:
    """Edit one or more source images according to a text instruction.

    The output size is fixed at 1024x1024. Any error raised by the API
    client propagates to the caller.

    Args:
        image: The source image as a file-like object, or a list of them.
        prompt: Text describing the desired edit.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key sent with the request.
        model: Name of the image editing model.
        n: Number of edited images to request. Only the first is returned.

    Returns:
        The first edited image as a base64 string, or ``""`` when the
        response holds no image data.
    """
    api = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
    result = api.images.edit(
        image=image,
        prompt=prompt,
        n=n,
        size="1024x1024",
        model=model,
    )

    edited = result.data
    if not edited:
        return ""
    return edited[0].b64_json or ""
|
|
|
|
|
|
def _extract_raw_embedding(data: object) -> object:
    """Return the ``embedding`` field of the first result, or ``None``.

    Accepts an OpenAI-style body (``{"data": [{"embedding": ...}, ...]}``)
    or a bare list of result objects (``[{"embedding": ...}, ...]``).
    Any other shape, including an empty list, yields ``None``.
    """
    if isinstance(data, dict):
        data = data.get("data")
    if not isinstance(data, list) or not data:
        return None
    first = data[0]
    if not isinstance(first, dict):
        return None
    return first.get("embedding")


def _coerce_embedding(raw: object) -> list[float]:
    """Convert a raw embedding value to ``list[float]``; ``[]`` if impossible.

    A string value is treated as a JSON-encoded vector and decoded first.
    Every element must be convertible with ``float()``.
    """
    if raw is None:
        return []
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            return []
    try:
        return [float(value) for value in raw]  # type: ignore[union-attr]
    except (TypeError, ValueError):
        # Non-iterable value (e.g. a bare number) or non-numeric elements.
        return []


def embedding(
    text: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
) -> list[float]:
    """Generate an embedding vector for the given text.

    Uses a raw HTTP request to avoid the OpenAI SDK injecting
    unsupported parameters like encoding_format.

    Args:
        text: The text to embed.
        openai_url: The OpenAI-compatible API URL.
        openai_api_key: The API key for authentication.
        model: The embedding model to use.

    Returns:
        The first embedding vector as a list of floats. Returns an empty list
        on failure: an HTTP or network error, a non-JSON body, an unexpected
        response shape, or an embedding that cannot be read as numbers.
    """
    url = f"{openai_url.rstrip('/')}/embeddings"
    headers = {
        "Authorization": f"Bearer {openai_api_key}",
        "Content-Type": "application/json",
    }
    payload = {"model": model, "input": [text]}

    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=30)
        resp.raise_for_status()
        data = resp.json()
    except (requests.RequestException, ValueError):
        # ValueError covers a non-JSON body on requests versions whose
        # JSON decode error is not a RequestException subclass.
        return []

    return _coerce_embedding(_extract_raw_embedding(data))
|