fix linting, formatting, and add tests
This commit is contained in:
+137
-42
@@ -1,23 +1,46 @@
|
||||
# Wraps the openai calls in generic functions
|
||||
# Supports chat, image, edit, and embeddings
|
||||
# Allows custom endpoints for each of the above supported functions
|
||||
"""Wraps the openai calls in generic functions.
|
||||
|
||||
Supports chat, image, edit, and embeddings.
|
||||
Allows custom endpoints for each of the above supported functions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import openai
|
||||
from typing import Iterable
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from io import BufferedReader, BytesIO
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from io import BufferedReader, BytesIO
|
||||
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
|
||||
def chat_completion(
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
*,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
model: str,
|
||||
max_tokens: int = 1000,
|
||||
) -> str:
|
||||
"""Send a chat completion request and return the response.
|
||||
|
||||
Args:
|
||||
system_prompt: The system prompt to use.
|
||||
user_prompt: The user prompt to send.
|
||||
openai_url: The OpenAI-compatible API URL.
|
||||
openai_api_key: The API key for authentication.
|
||||
model: The model to use for completion.
|
||||
max_tokens: Maximum number of tokens to generate.
|
||||
|
||||
Returns:
|
||||
The model's response text, stripped of whitespace.
|
||||
|
||||
"""
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
messages: Iterable[ChatCompletionMessageParam] = [
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
@@ -28,35 +51,51 @@ def chat_completion(
|
||||
},
|
||||
]
|
||||
response = client.chat.completions.create(
|
||||
model=model, messages=messages, max_tokens=max_tokens
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# Assert that thinking was used
|
||||
if response.choices[0].message.model_extra:
|
||||
assert response.choices[0].message.model_extra.get("reasoning_content")
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
return content.strip()
|
||||
else:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
|
||||
def chat_completion_with_history(
|
||||
system_prompt: str,
|
||||
prompts: Iterable[ChatCompletionMessageParam],
|
||||
prompts: list[dict[str, str]],
|
||||
*,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
model: str,
|
||||
max_tokens: int = 1000,
|
||||
) -> str:
|
||||
"""Send a chat completion request with conversation history.
|
||||
|
||||
Args:
|
||||
system_prompt: The system prompt to use.
|
||||
prompts: List of prompt dicts with role and content.
|
||||
openai_url: The OpenAI-compatible API URL.
|
||||
openai_api_key: The API key for authentication.
|
||||
model: The model to use for completion.
|
||||
max_tokens: Maximum number of tokens to generate.
|
||||
|
||||
Returns:
|
||||
The model's response text, stripped of whitespace.
|
||||
|
||||
"""
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
messages: Iterable[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
}
|
||||
] + prompts # type: ignore
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
cast(
|
||||
"ChatCompletionMessageParam",
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
),
|
||||
]
|
||||
messages.extend(cast("list[ChatCompletionMessageParam]", prompts))
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
@@ -67,20 +106,34 @@ def chat_completion_with_history(
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
return content.strip()
|
||||
else:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
|
||||
def chat_completion_instruct(
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
*,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
model: str,
|
||||
max_tokens: int = 1000,
|
||||
) -> str:
|
||||
"""Send an instruction-based chat completion request.
|
||||
|
||||
Args:
|
||||
system_prompt: The system prompt to use.
|
||||
user_prompt: The user prompt to send.
|
||||
openai_url: The OpenAI-compatible API URL.
|
||||
openai_api_key: The API key for authentication.
|
||||
model: The model to use for completion.
|
||||
max_tokens: Maximum number of tokens to generate.
|
||||
|
||||
Returns:
|
||||
The model's response text, stripped of whitespace.
|
||||
|
||||
"""
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
messages: Iterable[ChatCompletionMessageParam] = [
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
@@ -100,26 +153,37 @@ def chat_completion_instruct(
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
return content.strip()
|
||||
else:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
|
||||
def image_generation(prompt: str, openai_url: str, openai_api_key: str, n=1) -> str:
|
||||
"""Generates an image using the given prompt and returns the base64 encoded image data
|
||||
def image_generation(
|
||||
prompt: str,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
n: int = 1,
|
||||
) -> str:
|
||||
"""Generate an image using the given prompt.
|
||||
|
||||
Args:
|
||||
prompt: The image generation prompt.
|
||||
openai_url: The OpenAI-compatible API URL.
|
||||
openai_api_key: The API key for authentication.
|
||||
n: Number of images to generate.
|
||||
|
||||
Returns:
|
||||
str: The base64 encoded image data. Decode and write to a file.
|
||||
The base64 encoded image data. Decode and write to a file.
|
||||
|
||||
"""
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
response = client.images.generate(
|
||||
prompt=prompt,
|
||||
n=n,
|
||||
size="1024x1024",
|
||||
model="gen",
|
||||
)
|
||||
if response.data:
|
||||
return response.data[0].b64_json or ""
|
||||
else:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
|
||||
def image_edit(
|
||||
@@ -127,33 +191,64 @@ def image_edit(
|
||||
prompt: str,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
n=1,
|
||||
n: int = 1,
|
||||
) -> str:
|
||||
"""Edit an existing image using a prompt.
|
||||
|
||||
Args:
|
||||
image: The source image as a file-like object or list thereof.
|
||||
prompt: The edit instruction.
|
||||
openai_url: The OpenAI-compatible API URL.
|
||||
openai_api_key: The API key for authentication.
|
||||
n: Number of edited images to generate.
|
||||
|
||||
Returns:
|
||||
The base64 encoded edited image data.
|
||||
|
||||
"""
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
response = client.images.edit(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
n=n,
|
||||
size="1024x1024",
|
||||
model="edit",
|
||||
)
|
||||
if response.data:
|
||||
return response.data[0].b64_json or ""
|
||||
else:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
|
||||
def embedding(
|
||||
text: str, openai_url: str, openai_api_key: str, model: str
|
||||
text: str,
|
||||
*,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
model: str,
|
||||
) -> list[float]:
|
||||
"""Generate an embedding vector for the given text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
openai_url: The OpenAI-compatible API URL.
|
||||
openai_api_key: The API key for authentication.
|
||||
model: The embedding model to use.
|
||||
|
||||
Returns:
|
||||
The embedding vector as a list of floats, or an empty list on failure.
|
||||
|
||||
"""
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
response = client.embeddings.create(
|
||||
input=[text], model=model, encoding_format="float"
|
||||
input=[text],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
if response:
|
||||
raw_data = response[0].embedding # type: ignore
|
||||
# The result could be an array of floats or an array of an array of floats.
|
||||
try:
|
||||
return raw_data[0]
|
||||
except Exception:
|
||||
return raw_data
|
||||
data = response.data
|
||||
raw_data = data[0].embedding
|
||||
# The result could be an array of floats or a single float.
|
||||
if not isinstance(raw_data, float):
|
||||
return list(raw_data)
|
||||
return [raw_data]
|
||||
return []
|
||||
|
||||
Reference in New Issue
Block a user