# Wraps the openai calls in generic functions
# Supports chat, image, edit, and embeddings
# Allows custom endpoints for each of the above supported functions

import openai

from typing import Iterable

from openai.types.chat import ChatCompletionMessageParam


def chat_completion_think(
    system_prompt: str,
    user_prompt: str,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Run a chat completion with reasoning ("thinking") enabled.

    Sends *system_prompt* and *user_prompt* to the OpenAI-compatible
    endpoint at *openai_url* and returns the stripped assistant reply,
    or "" when the reply content is empty/None.

    Raises:
        AssertionError: when the response carries extra fields but no
            ``reasoning_content`` (i.e. the model apparently did not think).
    """
    client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)

    conversation: Iterable[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    completion = client.chat.completions.create(
        model=model,
        messages=conversation,
        max_tokens=max_tokens,
    )

    message = completion.choices[0].message

    # Sanity check that thinking was actually used: servers that support
    # reasoning surface it as a non-standard "reasoning_content" field.
    extra = message.model_extra
    if extra:
        assert extra.get("reasoning_content")

    return message.content.strip() if message.content else ""


def chat_completion_instruct(
    system_prompt: str,
    user_prompt: str,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Run a chat completion with reasoning ("thinking") disabled.

    Sends *system_prompt* and *user_prompt* to the OpenAI-compatible
    endpoint at *openai_url*, asking the server (via the vLLM-style
    ``chat_template_kwargs`` extra) to skip the thinking phase, and
    returns the stripped assistant reply, or "" when it is empty/None.

    Raises:
        AssertionError: when the response carries ``reasoning_content``,
            meaning the model thought despite being told not to.
    """
    client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)

    messages: Iterable[ChatCompletionMessageParam] = [
        {
            "role": "system",
            "content": system_prompt,
        },
        {
            "role": "user",
            "content": user_prompt,
        },
    ]
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        extra_body={
            "chat_template_kwargs": {"enable_thinking": False},
        },
    )

    # Assert that thinking wasn't used.
    # BUG FIX: the original asserted reasoning_content WAS present (a
    # copy-paste of chat_completion_think's check), contradicting both
    # this comment and enable_thinking=False above.
    message = response.choices[0].message
    if message.model_extra:
        assert not message.model_extra.get("reasoning_content")

    content = message.content
    if content:
        return content.strip()
    else:
        return ""


def image_generation(prompt: str, n: int = 1) -> str:
    """Generate *n* 1024x1024 image(s) from *prompt*; return the first URL.

    Returns "" when the server sends no image data or no URL.

    NOTE(review): relies on the module-level constant
    OPENAI_API_IMAGE_ENDPOINT, which is not defined in this file —
    confirm it is provided elsewhere, otherwise this raises NameError.
    """
    client = openai.OpenAI(base_url=OPENAI_API_IMAGE_ENDPOINT, api_key="placeholder")
    response = client.images.generate(
        prompt=prompt,
        n=n,
        size="1024x1024",
    )
    if response.data:
        # BUG FIX: Image.url is Optional in the openai SDK; the original
        # could return None despite the declared -> str return type.
        return response.data[0].url or ""
    else:
        return ""


def image_edit(image, mask, prompt, n=1, size="1024x1024"):
    """Edit *image* within the transparent area of *mask* per *prompt*.

    Returns the first result URL, or "" when the server sends no image
    data or no URL.

    BUG FIX: the original indexed ``response.data[0]`` unconditionally,
    raising TypeError when the API returns ``data=None``, and could
    return None instead of a string URL. Now guarded, consistent with
    image_generation.

    NOTE(review): relies on the module-level constant
    OPENAI_API_EDIT_ENDPOINT, which is not defined in this file —
    confirm it is provided elsewhere, otherwise this raises NameError.
    """
    client = openai.OpenAI(base_url=OPENAI_API_EDIT_ENDPOINT, api_key="placeholder")
    response = client.images.edit(
        image=image,
        mask=mask,
        prompt=prompt,
        n=n,
        size=size,
    )
    if response.data:
        return response.data[0].url or ""
    else:
        return ""


def embeddings(text, model="text-embedding-3-small"):
    """Embed *text* with *model*; return the first embedding vector.

    NOTE(review): relies on the module-level constant
    OPENAI_API_EMBED_ENDPOINT, which is not defined in this file —
    confirm it is provided elsewhere, otherwise this raises NameError.
    """
    client = openai.OpenAI(base_url=OPENAI_API_EMBED_ENDPOINT, api_key="placeholder")
    result = client.embeddings.create(input=text, model=model)
    return result.data[0].embedding