295 lines
7.3 KiB
Python
295 lines
7.3 KiB
Python
"""Wraps the openai calls in generic functions.
|
|
|
|
Supports chat completions, image generation, image editing, and embeddings.

Each function accepts a custom OpenAI-compatible endpoint URL.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import TYPE_CHECKING, cast
|
|
|
|
import openai
|
|
import requests
|
|
|
|
if TYPE_CHECKING:
|
|
from io import BufferedReader, BytesIO
|
|
|
|
from openai.types.chat import ChatCompletionMessageParam
|
|
|
|
|
|
def chat_completion(
    system_prompt: str,
    user_prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Run a single-turn chat completion and return the reply text.

    The request contains exactly one system message followed by one user
    message and uses a 60 second request timeout.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key used to authenticate the request.
        model: Name of the chat model to query.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or an empty string when the response has no choices or the
        content is empty/None.
    """
    client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
    conversation: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    completion = client.chat.completions.create(
        model=model,
        messages=conversation,
        max_tokens=max_tokens,
        timeout=60.0,
    )

    if completion.choices:
        reply = completion.choices[0].message.content
        # content may be None (e.g. no text produced); normalise to "".
        return reply.strip() if reply else ""
    return ""
|
|
|
|
|
|
def chat_completion_with_history(
    system_prompt: str,
    prompts: list[dict[str, str]],
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Run a chat completion over an existing conversation history.

    A system message built from ``system_prompt`` is placed first, followed
    by every entry of ``prompts`` in order, passed through unchanged.

    Args:
        system_prompt: Content of the leading system message.
        prompts: Prior conversation turns, each a dict with ``role`` and
            ``content`` keys. The list itself is not modified.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key used to authenticate the request.
        model: Name of the chat model to query.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or an empty string when there are no choices or the content
        is empty/None.
    """
    client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
    system_message = cast(
        "ChatCompletionMessageParam",
        {"role": "system", "content": system_prompt},
    )
    # The caller's dicts are trusted to already be valid message params.
    history = cast("list[ChatCompletionMessageParam]", prompts)
    messages: list[ChatCompletionMessageParam] = [system_message, *history]

    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        # NOTE(review): seed=-1 is forwarded as-is; presumably asks
        # compatible backends for a random seed — confirm per server.
        seed=-1,
        timeout=60.0,
    )

    if not completion.choices:
        return ""
    reply = completion.choices[0].message.content
    return reply.strip() if reply else ""
|
|
|
|
|
|
def chat_completion_instruct(
    system_prompt: str,
    user_prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Send an instruction-style chat completion request.

    The request holds one system message and one user message, forwards
    ``seed=-1`` and uses a 60 second request timeout.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the user (instruction) message.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key used to authenticate the request.
        model: Name of the chat model to query.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or an empty string when there are no choices or the content
        is empty/None.
    """
    client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
    conversation: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    completion = client.chat.completions.create(
        model=model,
        messages=conversation,
        max_tokens=max_tokens,
        # NOTE(review): seed=-1 is forwarded as-is; presumably asks
        # compatible backends for a random seed — confirm per server.
        seed=-1,
        timeout=60.0,
    )

    if not completion.choices:
        return ""
    reply = completion.choices[0].message.content
    return reply.strip() if reply else ""
|
|
|
|
|
|
def image_generation(
    prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str = "gen",
    n: int = 1,
) -> str:
    """Generate a 1024x1024 image from a text prompt.

    Args:
        prompt: Text describing the image to generate.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key used to authenticate the request.
        model: Name of the image generation model.
        n: Number of images to request. Only the first is returned.

    Returns:
        Base64-encoded data of the first generated image (decode it and write
        it to a file), or an empty string when the server could not be
        reached or the response carries no image data.

    Raises:
        openai.APIStatusError: Non-connection API errors are not caught.
    """
    # max_retries=0: the generation request is attempted exactly once.
    client = openai.OpenAI(
        base_url=openai_url,
        api_key=openai_api_key,
        max_retries=0,
    )
    try:
        result = client.images.generate(
            prompt=prompt,
            n=n,
            size="1024x1024",
            model=model,
            timeout=120.0,
        )
    except openai.APIConnectionError:
        # Connection failures (the SDK also raises a subclass of this for
        # timeouts) are reported as "no image".
        return ""

    images = result.data
    if not images:
        return ""
    return images[0].b64_json or ""
|
|
|
|
|
|
def image_edit(
    image: BufferedReader | BytesIO | list[BufferedReader] | list[BytesIO],
    prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str = "edit",
    n: int = 1,
) -> str:
    """Edit an existing image using a text prompt.

    Args:
        image: The source image as a file-like object, or a list of them.
        prompt: The edit instruction.
        openai_url: Base URL of the OpenAI-compatible API.
        openai_api_key: API key used to authenticate the request.
        model: Name of the image editing model.
        n: Number of edited images to request. Only the first is returned.

    Returns:
        Base64-encoded data of the first edited 1024x1024 image, or an empty
        string when the response carries no image data.

    Raises:
        openai.APIError: SDK errors (connection failures, timeouts, non-2xx
            responses) are not caught here and propagate to the caller.
    """
    client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
    response = client.images.edit(
        image=image,
        prompt=prompt,
        n=n,
        size="1024x1024",
        model=model,
        # Explicit per-request timeout, matching image_generation. Without it
        # the SDK's default (10 minutes) applies, unlike every other call in
        # this module.
        timeout=120.0,
    )
    if response.data:
        return response.data[0].b64_json or ""
    return ""
|
|
|
|
|
|
def embedding(
    text: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
) -> list[float]:
    """Generate an embedding vector for the given text.

    Uses a raw HTTP request to avoid the OpenAI SDK injecting
    unsupported parameters like encoding_format.

    Args:
        text: The text to embed. It is sent as a single-element ``input`` list.
        openai_url: Base URL of the OpenAI-compatible API. ``/embeddings`` is
            appended, and a trailing slash on the base URL is tolerated.
        openai_api_key: The API key, sent as a Bearer token.
        model: The embedding model to use.

    Returns:
        The first embedding in the response as a list of floats, or an empty
        list on any failure. Failures include network/HTTP errors, a non-JSON
        body, and a payload without a usable ``data[0].embedding`` vector.
    """
    url = f"{openai_url.rstrip('/')}/embeddings"
    headers = {
        "Authorization": f"Bearer {openai_api_key}",
        "Content-Type": "application/json",
    }
    payload = {"model": model, "input": [text]}

    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=30)
        resp.raise_for_status()
        # Parse inside the try: a non-JSON body (e.g. an HTML error page from
        # a proxy) must yield [] like any other failure, not raise.
        data = resp.json()
    except (requests.RequestException, ValueError):
        # requests' JSONDecodeError derives from ValueError, so this covers
        # invalid JSON on both old and new requests releases.
        return []

    items = data.get("data") if isinstance(data, dict) else None
    if not isinstance(items, list) or not items or not isinstance(items[0], dict):
        return []

    raw = items[0].get("embedding")
    if isinstance(raw, str):
        # Some servers return the vector as a JSON-encoded string.
        try:
            raw = json.loads(raw)
        except ValueError:
            return []

    try:
        # Coerce each element so the result really is list[float]. This also
        # rejects None, scalars, and non-numeric entries.
        return [float(value) for value in raw]
    except (TypeError, ValueError):
        return []
|