"""Generic wrappers around OpenAI-compatible API calls.

Supports chat completion, image generation, image editing, and embeddings.
Every function takes the endpoint URL and API key explicitly, so each kind of
call can be pointed at a different (custom) OpenAI-compatible server.
"""

from __future__ import annotations

import json
from typing import TYPE_CHECKING, cast

import openai
import requests

if TYPE_CHECKING:
    from io import BufferedReader, BytesIO

    from openai.types.chat import ChatCompletion, ChatCompletionMessageParam


def _system_user_messages(
    system_prompt: str,
    user_prompt: str,
) -> list[ChatCompletionMessageParam]:
    """Build a two-turn conversation: one system message, one user message."""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


def _first_choice_text(response: ChatCompletion) -> str:
    """Return the first choice's message text, stripped, or "" if there is none."""
    if not response.choices:
        return ""
    content = response.choices[0].message.content
    return content.strip() if content else ""


def chat_completion(
    system_prompt: str,
    user_prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Send a chat completion request and return the response.

    Args:
        system_prompt: The system prompt to use.
        user_prompt: The user prompt to send.
        openai_url: The OpenAI-compatible API URL.
        openai_api_key: The API key for authentication.
        model: The model to use for completion.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The first choice's response text, stripped of whitespace, or "" if
        the response has no choices or no content.
    """
    # A fresh client is created per call; using it as a context manager
    # closes its underlying HTTP client as soon as the request is done
    # instead of leaving that to garbage collection.
    with openai.OpenAI(base_url=openai_url, api_key=openai_api_key) as client:
        response = client.chat.completions.create(
            model=model,
            messages=_system_user_messages(system_prompt, user_prompt),
            max_tokens=max_tokens,
            timeout=60.0,
        )
    return _first_choice_text(response)


def chat_completion_with_history(
    system_prompt: str,
    prompts: list[dict[str, str]],
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Send a chat completion request with conversation history.

    Args:
        system_prompt: The system prompt, sent as the first message.
        prompts: Prior conversation turns, each a dict with "role" and
            "content", appended after the system message in order. The list
            is not modified.
        openai_url: The OpenAI-compatible API URL.
        openai_api_key: The API key for authentication.
        model: The model to use for completion.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The first choice's response text, stripped of whitespace, or "" if
        the response has no choices or no content.
    """
    messages: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_prompt},
    ]
    # The SDK's message type is a union of TypedDicts that cannot be inferred
    # from plain dict[str, str]; the cast only satisfies the type checker.
    messages.extend(cast("list[ChatCompletionMessageParam]", prompts))
    with openai.OpenAI(base_url=openai_url, api_key=openai_api_key) as client:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            # NOTE(review): -1 is sent verbatim; presumably the backend treats
            # it as "use a random seed" — confirm against the server in use.
            seed=-1,
            timeout=60.0,
        )
    return _first_choice_text(response)


def chat_completion_instruct(
    system_prompt: str,
    user_prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
    max_tokens: int = 1000,
) -> str:
    """Send an instruction-based chat completion request.

    Identical to ``chat_completion`` except that it also sends ``seed=-1``.

    Args:
        system_prompt: The system prompt to use.
        user_prompt: The user prompt to send.
        openai_url: The OpenAI-compatible API URL.
        openai_api_key: The API key for authentication.
        model: The model to use for completion.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The first choice's response text, stripped of whitespace, or "" if
        the response has no choices or no content.
    """
    with openai.OpenAI(base_url=openai_url, api_key=openai_api_key) as client:
        response = client.chat.completions.create(
            model=model,
            messages=_system_user_messages(system_prompt, user_prompt),
            max_tokens=max_tokens,
            # NOTE(review): see chat_completion_with_history — -1 presumably
            # means "random seed" on the target backend; confirm.
            seed=-1,
            timeout=60.0,
        )
    return _first_choice_text(response)


def image_generation(
    prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str = "gen",
    n: int = 1,
) -> str:
    """Generate a 1024x1024 image from the given prompt.

    Args:
        prompt: The image generation prompt.
        openai_url: The OpenAI-compatible API URL.
        openai_api_key: The API key for authentication.
        model: The model to use for image generation.
        n: Number of images to request. Only the first one is returned.

    Returns:
        The base64 encoded data of the first image (decode and write it to a
        file), or "" if the server could not be reached
        (``openai.APIConnectionError``) or returned no image data.
    """
    # max_retries=0: presumably so a failed generation is not silently re-run
    # (generations are slow/expensive) — confirm intent.
    with openai.OpenAI(
        base_url=openai_url,
        api_key=openai_api_key,
        max_retries=0,
    ) as client:
        try:
            response = client.images.generate(
                prompt=prompt,
                n=n,
                size="1024x1024",
                model=model,
                timeout=120.0,
            )
        except openai.APIConnectionError:
            return ""
    if response.data:
        return response.data[0].b64_json or ""
    return ""


def image_edit(
    image: BufferedReader | BytesIO | list[BufferedReader] | list[BytesIO],
    prompt: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str = "edit",
    n: int = 1,
) -> str:
    """Edit an existing image using a prompt, producing a 1024x1024 result.

    Args:
        image: The source image as a file-like object or list thereof.
        prompt: The edit instruction.
        openai_url: The OpenAI-compatible API URL.
        openai_api_key: The API key for authentication.
        model: The model to use for image editing.
        n: Number of edited images to request. Only the first is returned.

    Returns:
        The base64 encoded data of the first edited image, or "" if the
        response contained no image data.
    """
    with openai.OpenAI(base_url=openai_url, api_key=openai_api_key) as client:
        response = client.images.edit(
            image=image,
            prompt=prompt,
            n=n,
            size="1024x1024",
            model=model,
            # Explicit timeout, consistent with image_generation; without it
            # the SDK's much longer default applies.
            timeout=120.0,
        )
    if response.data:
        return response.data[0].b64_json or ""
    return ""


def embedding(
    text: str,
    *,
    openai_url: str,
    openai_api_key: str,
    model: str,
) -> list[float]:
    """Generate an embedding vector for the given text.

    Uses a raw HTTP request to avoid the OpenAI SDK injecting unsupported
    parameters like encoding_format.

    Args:
        text: The text to embed.
        openai_url: The OpenAI-compatible API URL.
        openai_api_key: The API key for authentication.
        model: The embedding model to use.

    Returns:
        The embedding vector of the first result as a list, or an empty list
        on any failure: network/HTTP error, non-JSON body, unexpected payload
        shape, or a missing/undecodable "embedding" field.
    """
    url = f"{openai_url.rstrip('/')}/embeddings"
    headers = {
        "Authorization": f"Bearer {openai_api_key}",
        "Content-Type": "application/json",
    }
    payload = {"model": model, "input": [text]}
    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=30)
        resp.raise_for_status()
        # Decoding is inside the try so a malformed body also yields [];
        # ValueError covers JSON decode errors across requests versions.
        data = resp.json()
    except (requests.RequestException, ValueError):
        return []

    items = data.get("data") if isinstance(data, dict) else None
    if not items or not isinstance(items, list):
        return []
    first = items[0]
    raw = first.get("embedding") if isinstance(first, dict) else None
    # Some servers return the vector as a JSON-encoded string.
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except ValueError:
            return []
    if raw is None:
        return []
    try:
        return list(raw)
    except TypeError:
        # A scalar or other non-iterable value is not an embedding vector.
        return []