WIP: code cleanup
This commit is contained in:
114
vibe_bot/llama_wrapper.py
Normal file
114
vibe_bot/llama_wrapper.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# Wraps the openai calls in generic functions
|
||||
# Supports chat, image, edit, and embeddings
|
||||
# Allows custom endpoints for each of the above supported functions
|
||||
|
||||
import openai
|
||||
from typing import Iterable
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
|
||||
def chat_completion_think(
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
model: str,
|
||||
max_tokens: int = 1000,
|
||||
) -> str:
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
messages: Iterable[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
},
|
||||
]
|
||||
response = client.chat.completions.create(
|
||||
model=model, messages=messages, max_tokens=max_tokens
|
||||
)
|
||||
|
||||
# Assert that thinking was used
|
||||
if response.choices[0].message.model_extra:
|
||||
assert response.choices[0].message.model_extra.get("reasoning_content")
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
return content.strip()
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
def chat_completion_instruct(
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
model: str,
|
||||
max_tokens: int = 1000,
|
||||
) -> str:
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
messages: Iterable[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
},
|
||||
]
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
extra_body={
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
},
|
||||
)
|
||||
|
||||
# Assert that thinking wasn't used
|
||||
if response.choices[0].message.model_extra:
|
||||
assert response.choices[0].message.model_extra.get("reasoning_content")
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
return content.strip()
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
def image_generation(prompt: str, n=1) -> str:
|
||||
client = openai.OpenAI(base_url=OPENAI_API_IMAGE_ENDPOINT, api_key="placeholder")
|
||||
response = client.images.generate(
|
||||
prompt=prompt,
|
||||
n=n,
|
||||
size="1024x1024",
|
||||
)
|
||||
if response.data:
|
||||
return response.data[0].url
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
def image_edit(image, mask, prompt, n=1, size="1024x1024"):
|
||||
client = openai.OpenAI(base_url=OPENAI_API_EDIT_ENDPOINT, api_key="placeholder")
|
||||
response = client.images.edit(
|
||||
image=image,
|
||||
mask=mask,
|
||||
prompt=prompt,
|
||||
n=n,
|
||||
size=size,
|
||||
)
|
||||
return response.data[0].url
|
||||
|
||||
|
||||
def embeddings(text, model="text-embedding-3-small"):
|
||||
client = openai.OpenAI(base_url=OPENAI_API_EMBED_ENDPOINT, api_key="placeholder")
|
||||
response = client.embeddings.create(
|
||||
input=text,
|
||||
model=model,
|
||||
)
|
||||
return response.data[0].embedding
|
||||
Reference in New Issue
Block a user