fix linting, formatting, and add tests

This commit is contained in:
2026-05-23 19:06:53 -04:00
parent 5e708b009c
commit 6ec9fbe85f
15 changed files with 2911 additions and 491 deletions
+137 -42
View File
@@ -1,23 +1,46 @@
# Wraps the openai calls in generic functions
# Supports chat, image, edit, and embeddings
# Allows custom endpoints for each of the above supported functions
"""Wraps the openai calls in generic functions.
Supports chat, image, edit, and embeddings.
Allows custom endpoints for each of the above supported functions.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, cast
import openai
from typing import Iterable
from openai.types.chat import ChatCompletionMessageParam
from io import BufferedReader, BytesIO
if TYPE_CHECKING:
from io import BufferedReader, BytesIO
from openai.types.chat import ChatCompletionMessageParam
def chat_completion(
system_prompt: str,
user_prompt: str,
*,
openai_url: str,
openai_api_key: str,
model: str,
max_tokens: int = 1000,
) -> str:
"""Send a chat completion request and return the response.
Args:
system_prompt: The system prompt to use.
user_prompt: The user prompt to send.
openai_url: The OpenAI-compatible API URL.
openai_api_key: The API key for authentication.
model: The model to use for completion.
max_tokens: Maximum number of tokens to generate.
Returns:
The model's response text, stripped of whitespace.
"""
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
messages: Iterable[ChatCompletionMessageParam] = [
messages: list[ChatCompletionMessageParam] = [
{
"role": "system",
"content": system_prompt,
@@ -28,35 +51,51 @@ def chat_completion(
},
]
response = client.chat.completions.create(
model=model, messages=messages, max_tokens=max_tokens
model=model,
messages=messages,
max_tokens=max_tokens,
)
# Assert that thinking was used
if response.choices[0].message.model_extra:
assert response.choices[0].message.model_extra.get("reasoning_content")
content = response.choices[0].message.content
if content:
return content.strip()
else:
return ""
return ""
def chat_completion_with_history(
system_prompt: str,
prompts: Iterable[ChatCompletionMessageParam],
prompts: list[dict[str, str]],
*,
openai_url: str,
openai_api_key: str,
model: str,
max_tokens: int = 1000,
) -> str:
"""Send a chat completion request with conversation history.
Args:
system_prompt: The system prompt to use.
prompts: List of prompt dicts with role and content.
openai_url: The OpenAI-compatible API URL.
openai_api_key: The API key for authentication.
model: The model to use for completion.
max_tokens: Maximum number of tokens to generate.
Returns:
The model's response text, stripped of whitespace.
"""
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
messages: Iterable[ChatCompletionMessageParam] = [
{
"role": "system",
"content": system_prompt,
}
] + prompts # type: ignore
messages: list[ChatCompletionMessageParam] = [
cast(
"ChatCompletionMessageParam",
{
"role": "system",
"content": system_prompt,
},
),
]
messages.extend(cast("list[ChatCompletionMessageParam]", prompts))
response = client.chat.completions.create(
model=model,
messages=messages,
@@ -67,20 +106,34 @@ def chat_completion_with_history(
content = response.choices[0].message.content
if content:
return content.strip()
else:
return ""
return ""
def chat_completion_instruct(
system_prompt: str,
user_prompt: str,
*,
openai_url: str,
openai_api_key: str,
model: str,
max_tokens: int = 1000,
) -> str:
"""Send an instruction-based chat completion request.
Args:
system_prompt: The system prompt to use.
user_prompt: The user prompt to send.
openai_url: The OpenAI-compatible API URL.
openai_api_key: The API key for authentication.
model: The model to use for completion.
max_tokens: Maximum number of tokens to generate.
Returns:
The model's response text, stripped of whitespace.
"""
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
messages: Iterable[ChatCompletionMessageParam] = [
messages: list[ChatCompletionMessageParam] = [
{
"role": "system",
"content": system_prompt,
@@ -100,26 +153,37 @@ def chat_completion_instruct(
content = response.choices[0].message.content
if content:
return content.strip()
else:
return ""
return ""
def image_generation(prompt: str, openai_url: str, openai_api_key: str, n=1) -> str:
"""Generates an image using the given prompt and returns the base64 encoded image data
def image_generation(
prompt: str,
openai_url: str,
openai_api_key: str,
n: int = 1,
) -> str:
"""Generate an image using the given prompt.
Args:
prompt: The image generation prompt.
openai_url: The OpenAI-compatible API URL.
openai_api_key: The API key for authentication.
n: Number of images to generate.
Returns:
str: The base64 encoded image data. Decode and write to a file.
The base64 encoded image data. Decode and write to a file.
"""
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
response = client.images.generate(
prompt=prompt,
n=n,
size="1024x1024",
model="gen",
)
if response.data:
return response.data[0].b64_json or ""
else:
return ""
return ""
def image_edit(
@@ -127,33 +191,64 @@ def image_edit(
prompt: str,
openai_url: str,
openai_api_key: str,
n=1,
n: int = 1,
) -> str:
"""Edit an existing image using a prompt.
Args:
image: The source image as a file-like object or list thereof.
prompt: The edit instruction.
openai_url: The OpenAI-compatible API URL.
openai_api_key: The API key for authentication.
n: Number of edited images to generate.
Returns:
The base64 encoded edited image data.
"""
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
response = client.images.edit(
image=image,
prompt=prompt,
n=n,
size="1024x1024",
model="edit",
)
if response.data:
return response.data[0].b64_json or ""
else:
return ""
return ""
def embedding(
text: str, openai_url: str, openai_api_key: str, model: str
text: str,
*,
openai_url: str,
openai_api_key: str,
model: str,
) -> list[float]:
"""Generate an embedding vector for the given text.
Args:
text: The text to embed.
openai_url: The OpenAI-compatible API URL.
openai_api_key: The API key for authentication.
model: The embedding model to use.
Returns:
The embedding vector as a list of floats, or an empty list on failure.
"""
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
response = client.embeddings.create(
input=[text], model=model, encoding_format="float"
input=[text],
model=model,
encoding_format="float",
)
if response:
raw_data = response[0].embedding # type: ignore
# The result could be an array of floats or an array of an array of floats.
try:
return raw_data[0]
except Exception:
return raw_data
data = response.data
raw_data = data[0].embedding
# The result could be an array of floats or a single float.
if not isinstance(raw_data, float):
return list(raw_data)
return [raw_data]
return []