diff --git a/image-edit.png b/image-edit.png new file mode 100644 index 0000000..fe8ea03 Binary files /dev/null and b/image-edit.png differ diff --git a/image-gen.png b/image-gen.png new file mode 100644 index 0000000..70b0719 Binary files /dev/null and b/image-gen.png differ diff --git a/vibe_bot/config.py b/vibe_bot/config.py new file mode 100644 index 0000000..829fbee --- /dev/null +++ b/vibe_bot/config.py @@ -0,0 +1,85 @@ +from dotenv import load_dotenv +import os +import logging + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +load_dotenv() + +# Discord +DISCORD_TOKEN = os.getenv("DISCORD_TOKEN", "") + +# Endpoints +CHAT_ENDPOINT = os.getenv("CHAT_ENDPOINT", "") +COMPLETION_ENDPOINT = os.getenv("COMPLETION_ENDPOINT", "") +IMAGE_GEN_ENDPOINT = os.getenv("IMAGE_GEN_ENDPOINT", "") +IMAGE_EDIT_ENDPOINT = os.getenv("IMAGE_EDIT_ENDPOINT", "") +EMBEDDING_ENDPOINT = os.getenv("EMBEDDING_ENDPOINT", "") +MAX_COMPLETION_TOKENS = int(os.getenv("MAX_COMPLETION_TOKENS", "1000")) + +# API Keys +CHAT_ENDPOINT_KEY = os.getenv("CHAT_ENDPOINT_KEY", "placeholder") +COMPLETION_ENDPOINT_KEY = os.getenv("COMPLETION_ENDPOINT_KEY", "placeholder") +IMAGE_GEN_ENDPOINT_KEY = os.getenv("IMAGE_GEN_ENDPOINT_KEY", "placeholder") +IMAGE_EDIT_ENDPOINT_KEY = os.getenv("IMAGE_EDIT_ENDPOINT_KEY", "placeholder") +EMBEDDING_ENDPOINT_KEY = os.getenv("EMBEDDING_ENDPOINT_KEY", "placeholder") + +# Models +CHAT_MODEL = os.getenv("CHAT_MODEL", "") +COMPLETION_MODEL = os.getenv("COMPLETION_MODEL", "") +IMAGE_GEN_MODEL = os.getenv("IMAGE_GEN_MODEL", "") +IMAGE_EDIT_MODEL = os.getenv("IMAGE_EDIT_MODEL", "") +EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "") + +# Database and embeddings +DB_PATH = os.getenv("DB_PATH", "chat_history.db") +EMBEDDING_DIMENSION = 2048 +MAX_HISTORY_MESSAGES = int(os.getenv("MAX_HISTORY_MESSAGES", "1000")) +SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7")) +TOP_K_RESULTS = int(os.getenv("TOP_K_RESULTS", "5")) + +# Check token +if not DISCORD_TOKEN: + raise Exception("DISCORD_TOKEN required.") + +# Check endpoints +if not CHAT_ENDPOINT: + raise Exception("CHAT_ENDPOINT required.") + +if not COMPLETION_ENDPOINT: + raise Exception("COMPLETION_ENDPOINT required.") + +if not IMAGE_GEN_ENDPOINT: + raise Exception("IMAGE_GEN_ENDPOINT required.") + +if not IMAGE_EDIT_ENDPOINT: + raise Exception("IMAGE_EDIT_ENDPOINT required.") + +if not EMBEDDING_ENDPOINT: + raise Exception("EMBEDDING_ENDPOINT required.") + +# Check models +if not CHAT_MODEL: + raise Exception("CHAT_MODEL required.") + +if not COMPLETION_MODEL: + raise Exception("COMPLETION_MODEL required.") + +if not IMAGE_GEN_MODEL: + raise Exception("IMAGE_GEN_MODEL required.") + +if not IMAGE_EDIT_MODEL: + raise Exception("IMAGE_EDIT_MODEL required.") + +if not EMBEDDING_MODEL: + raise Exception("EMBEDDING_MODEL required.") + +logger.info(f"CHAT_ENDPOINT set to {CHAT_ENDPOINT}") +logger.info(f"COMPLETION_ENDPOINT set to {COMPLETION_ENDPOINT}") +logger.info(f"IMAGE_GEN_ENDPOINT set to {IMAGE_GEN_ENDPOINT}") +logger.info(f"IMAGE_EDIT_ENDPOINT set to {IMAGE_EDIT_ENDPOINT}") +logger.info(f"EMBEDDING_ENDPOINT set to {EMBEDDING_ENDPOINT}") diff --git a/vibe_bot/database.py b/vibe_bot/database.py index 3f0661e..445cea9 100644 --- a/vibe_bot/database.py +++ b/vibe_bot/database.py @@ -1,32 +1,27 @@ import sqlite3 -import os from typing import Optional, List, Tuple from datetime import datetime import numpy as np from openai import OpenAI import logging +import llama_wrapper # type: ignore +from config import ( # type: ignore + DB_PATH, + EMBEDDING_MODEL, + EMBEDDING_ENDPOINT, + EMBEDDING_ENDPOINT_KEY, + MAX_HISTORY_MESSAGES, + SIMILARITY_THRESHOLD, + TOP_K_RESULTS, +) + # Configure logging logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) -# Database configuration -DB_PATH = os.getenv("DB_PATH", "chat_history.db") -EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "qwen3-embed-4b") -EMBEDDING_DIMENSION = 2048 # Default for qwen3-embed-4b -MAX_HISTORY_MESSAGES = int(os.getenv("MAX_HISTORY_MESSAGES", "1000")) -SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7")) -TOP_K_RESULTS = int(os.getenv("TOP_K_RESULTS", "5")) - -# OpenAI configuration -OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "placeholder") -OPENAI_API_EMBED_ENDPOINT = os.getenv( - "OPENAI_API_EMBED_ENDPOINT", "https://llama-embed.reeselink.com" -) - class ChatDatabase: """SQLite database with RAG support for storing chat history using OpenAI embeddings.""" @@ -34,7 +29,9 @@ class ChatDatabase: def __init__(self, db_path: str = DB_PATH): logger.info(f"Initializing ChatDatabase with path: {db_path}") self.db_path = db_path - self.client = OpenAI(base_url=OPENAI_API_EMBED_ENDPOINT, api_key=OPENAI_API_KEY) + self.client = OpenAI( + base_url=EMBEDDING_ENDPOINT, api_key=EMBEDDING_ENDPOINT_KEY + ) logger.info("Connecting to OpenAI API for embeddings") self._initialize_database() @@ -83,7 +80,7 @@ class ChatDatabase: """ ) logger.info("idx_timestamp index created successfully") - + logger.info("Creating idx_user_id index if not exists") cursor.execute( """ @@ -96,36 +93,6 @@ class ChatDatabase: logger.info("Database initialization completed successfully") conn.close() - def _generate_embedding(self, text: str) -> List[float]: - """Generate embedding for text using OpenAI API.""" - logger.debug(f"Generating embedding for text (length: {len(text)})") - try: - logger.info(f"Calling OpenAI API to generate embedding with model: {EMBEDDING_MODEL}") - response = self.client.embeddings.create( - model=EMBEDDING_MODEL, input=text, encoding_format="float" - ) - logger.debug("OpenAI API response received successfully") - - # The embedding is returned as a nested list: [[embedding_values]] - # We need to extract the inner list - embedding_data = response[0].embedding - if isinstance(embedding_data, list) and len(embedding_data) > 0: - # The first element might be the embedding array itself or a nested list - first_item = embedding_data[0] - if isinstance(first_item, list): - # Handle nested structure: [[values]] -> [values] - logger.debug("Extracted embedding from nested structure [[values]]") - return first_item - else: - # Handle direct structure: [values] - logger.debug("Extracted embedding from direct structure [values]") - return embedding_data - logger.warning("Embedding data is empty or invalid") - return [] - except Exception as e: - logger.error(f"Error generating embedding: {e}") - return None - def _vector_to_bytes(self, vector: List[float]) -> bytes: """Convert vector to bytes for SQLite storage.""" logger.debug(f"Converting vector (length: {len(vector)}) to bytes") @@ -142,7 +109,9 @@ class ChatDatabase: def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float: """Calculate cosine similarity between two vectors.""" - logger.debug(f"Calculating cosine similarity between vectors of dimension {len(vec1)}") + logger.debug( + f"Calculating cosine similarity between vectors of dimension {len(vec1)}" + ) result = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) logger.debug(f"Similarity calculated: {result:.4f}") return result @@ -163,7 +132,9 @@ class ChatDatabase: try: # Insert message - logger.debug(f"Inserting message into chat_messages table: message_id={message_id}") + logger.debug( + f"Inserting message into chat_messages table: message_id={message_id}" + ) cursor.execute( """ INSERT OR REPLACE INTO chat_messages @@ -176,9 +147,16 @@ class ChatDatabase: # Generate and store embedding logger.info(f"Generating embedding for message {message_id}") - embedding = self._generate_embedding(content) + embedding = llama_wrapper.embedding( + content, + openai_url=EMBEDDING_ENDPOINT, + openai_api_key=EMBEDDING_ENDPOINT_KEY, + model=EMBEDDING_MODEL, + ) if embedding: - logger.debug(f"Embedding generated successfully for message {message_id}, storing in database") + logger.debug( + f"Embedding generated successfully for message {message_id}, storing in database" + ) cursor.execute( """ INSERT OR REPLACE INTO message_embeddings @@ -187,9 +165,13 @@ class ChatDatabase: """, (message_id, self._vector_to_bytes(embedding)), ) - logger.debug(f"Embedding stored in message_embeddings table for message {message_id}") + logger.debug( + f"Embedding stored in message_embeddings table for message {message_id}" + ) else: - logger.warning(f"Failed to generate embedding for message {message_id}, skipping embedding storage") + logger.warning( + f"Failed to generate embedding for message {message_id}, skipping embedding storage" + ) # Clean up old messages if exceeding limit logger.info("Checking if cleanup of old messages is needed") @@ -268,9 +250,14 @@ class ChatDatabase: query: str, top_k: int = TOP_K_RESULTS, min_similarity: float = SIMILARITY_THRESHOLD, - ) -> List[Tuple[str, str, str, float]]: + ) -> List[Tuple[str, str, float]]: """Search for messages similar to the query using embeddings.""" - query_embedding = self._generate_embedding(query) + query_embedding = llama_wrapper.embedding( + text=query, + model=EMBEDDING_MODEL, + openai_url=EMBEDDING_ENDPOINT, + openai_api_key=EMBEDDING_ENDPOINT_KEY, + ) if not query_embedding: return [] @@ -285,19 +272,28 @@ class ChatDatabase: SELECT cm.message_id, cm.content, me.embedding FROM chat_messages cm JOIN message_embeddings me ON cm.message_id = me.message_id + WHERE cm.username != 'vibe-bot' """ ) rows = cursor.fetchall() - results = [] + results: list[tuple[str, str, float]] = [] for message_id, content, embedding_blob in rows: embedding_vector = self._bytes_to_vector(embedding_blob) similarity = self._calculate_similarity(query_vector, embedding_vector) if similarity >= min_similarity: - results.append( - (message_id, content[:500], similarity) - ) # Limit content length + cursor.execute( + """ + SELECT content + FROM chat_messages + WHERE message_id = ? + ORDER BY timestamp DESC + """, + (f"{message_id}_response",), + ) + response: str = cursor.fetchone()[0] + results.append((content, response, similarity)) conn.close() @@ -305,28 +301,48 @@ class ChatDatabase: results.sort(key=lambda x: x[2], reverse=True) return results[:top_k] - def get_user_history( - self, user_id: str, limit: int = 20 - ) -> List[Tuple[str, str, datetime]]: + def get_user_history(self, user_id: str, limit: int = 20) -> list[tuple[str, str]]: """Get message history for a specific user.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() + logger.info(f"Fetching last {limit} user messages") cursor.execute( """ SELECT message_id, content, timestamp FROM chat_messages - WHERE user_id = ? + WHERE username != 'vibe-bot' ORDER BY timestamp DESC LIMIT ? """, - (user_id, limit), + (limit,), ) messages = cursor.fetchall() + + # Format is [user message, bot response] + conversations: list[tuple[str, str]] = [] + for message in messages: + msg_content: str = message[1] + logger.info(f"Finding response for {msg_content[:50]}") + cursor.execute( + """ + SELECT content + FROM chat_messages + WHERE message_id = ? + ORDER BY timestamp DESC + """, + (f"{message[0]}_response",), + ) + response_content: str = cursor.fetchone() + if response_content: + logger.info(f"Found response: {response_content[0][:50]}") + conversations.append((msg_content, response_content[0])) + else: + logger.info("No response found") conn.close() - return messages + return conversations def get_conversation_context( self, user_id: str, current_message: str, max_context: int = 5 @@ -344,15 +360,19 @@ class ChatDatabase: context_parts = [] # Add recent messages - for message_id, content, timestamp in recent_messages: - context_parts.append(f"[{timestamp}] User: {content}") + for user_message, bot_message in recent_messages: + combined_content = f"[Recent chat]\n{user_message}\n{bot_message}" + context_parts.append(combined_content) # Add similar messages - for message_id, content, similarity in similar_messages: - if f"[{content}" not in "\n".join(context_parts): # Avoid duplicates - context_parts.append(f"[Similar] {content}") + for user_message, bot_message, similarity in similar_messages: + combined_content = f"{user_message}\n{bot_message}" + if combined_content not in "\n".join(context_parts): + context_parts.append(f"[You remember]\n{combined_content}") - return "\n".join(context_parts[-max_context * 2 :]) # Limit total context + # Conversation history needs to be delivered in "newest context last" order + context_parts.reverse() + return "\n".join(context_parts[-max_context * 4 :]) # Limit total context def clear_all_messages(self): """Clear all messages and embeddings from the database.""" @@ -390,6 +410,7 @@ class CustomBotManager: conn = sqlite3.connect(self.db_path) cursor = conn.cursor() + # Create table to hold custom bots cursor.execute( """ CREATE TABLE IF NOT EXISTS custom_bots ( @@ -399,7 +420,7 @@ class CustomBotManager: created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, is_active INTEGER DEFAULT 1 ) - """ + """ ) conn.commit() @@ -461,8 +482,9 @@ class CustomBotManager: if user_id: cursor.execute( """ - SELECT bot_name, system_prompt, created_by - FROM custom_bots + SELECT bot_name, system_prompt, name + FROM custom_bots cb, username_map um + JOIN username_map ON custom_bots.created_by = username_map.id WHERE is_active = 1 ORDER BY created_at DESC """ diff --git a/vibe_bot/llama_wrapper.py b/vibe_bot/llama_wrapper.py index 8d8b164..af85cd4 100644 --- a/vibe_bot/llama_wrapper.py +++ b/vibe_bot/llama_wrapper.py @@ -5,9 +5,12 @@ import openai from typing import Iterable from openai.types.chat import ChatCompletionMessageParam +from openai._types import FileTypes, SequenceNotStr +from typing import Union +from io import BufferedReader, BytesIO -def chat_completion_think( +def chat_completion( system_prompt: str, user_prompt: str, openai_url: str, @@ -80,35 +83,56 @@ def chat_completion_instruct( return "" -def image_generation(prompt: str, n=1) -> str: - client = openai.OpenAI(base_url=OPENAI_API_IMAGE_ENDPOINT, api_key="placeholder") +def image_generation(prompt: str, openai_url: str, openai_api_key: str, n=1) -> str: + """Generates an image using the given prompt and returns the base64 encoded image data + + Returns: + str: The base64 encoded image data. Decode and write to a file. + """ + client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key) response = client.images.generate( prompt=prompt, n=n, size="1024x1024", ) if response.data: - return response.data[0].url + return response.data[0].b64_json or "" else: return "" -def image_edit(image, mask, prompt, n=1, size="1024x1024"): - client = openai.OpenAI(base_url=OPENAI_API_EDIT_ENDPOINT, api_key="placeholder") +def image_edit( + image: BufferedReader | BytesIO, + prompt: str, + openai_url: str, + openai_api_key: str, + n=1, +) -> str: + client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key) response = client.images.edit( image=image, - mask=mask, prompt=prompt, n=n, - size=size, + size="1024x1024", ) - return response.data[0].url + if response.data: + return response.data[0].b64_json or "" + else: + return "" -def embeddings(text, model="text-embedding-3-small"): - client = openai.OpenAI(base_url=OPENAI_API_EMBED_ENDPOINT, api_key="placeholder") +def embedding( + text: str, openai_url: str, openai_api_key: str, model: str +) -> list[float]: + client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key) response = client.embeddings.create( - input=text, - model=model, + input=[text], model=model, encoding_format="float" ) - return response.data[0].embedding + if response: + raw_data = response[0].embedding # type: ignore + # The result could be an array of floats or an array of an array of floats. + try: + return raw_data[0] + except Exception: + return raw_data + return [] diff --git a/vibe_bot/main.py b/vibe_bot/main.py index f36904f..2a4c9c9 100644 --- a/vibe_bot/main.py +++ b/vibe_bot/main.py @@ -5,7 +5,19 @@ import base64 from io import BytesIO from openai import OpenAI import logging -from database import get_database, CustomBotManager +from database import get_database, CustomBotManager # type: ignore +from config import ( # type: ignore + CHAT_ENDPOINT_KEY, + DISCORD_TOKEN, + CHAT_ENDPOINT, + CHAT_MODEL, + IMAGE_EDIT_ENDPOINT_KEY, + IMAGE_GEN_ENDPOINT, + IMAGE_EDIT_ENDPOINT, + MAX_COMPLETION_TOKENS, +) +import llama_wrapper # type: ignore +import requests # Configure logging logging.basicConfig( @@ -13,31 +25,11 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) -DISCORD_TOKEN = os.getenv("DISCORD_TOKEN", "placeholder") - -OPENAI_API_ENDPOINT = os.getenv("OPENAI_API_ENDPOINT") -IMAGE_GEN_ENDPOINT = os.getenv("IMAGE_GEN_ENDPOINT") -IMAGE_EDIT_ENDPOINT = os.getenv("IMAGE_EDIT_ENDPOINT") -MAX_COMPLETION_TOKENS = int(os.getenv("MAX_COMPLETION_TOKENS", "1000")) - -if not OPENAI_API_ENDPOINT: - raise Exception("OPENAI_API_ENDPOINT required.") - -if not IMAGE_GEN_ENDPOINT: - raise Exception("IMAGE_GEN_ENDPOINT required.") - -# Set your OpenAI API key as an environment variable -# You can also pass it directly but environment variables are safer -OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "placeholder") - # Initialize the bot intents = discord.Intents.default() intents.message_content = True bot = commands.Bot(command_prefix="!", intents=intents) -# OpenAI Completions API endpoint -OPENAI_COMPLETIONS_URL = f"{OPENAI_API_ENDPOINT}/chat/completions" - @bot.event async def on_ready(): @@ -46,7 +38,7 @@ async def on_ready(): logger.info(f"Bot logged in as {bot.user}") -@bot.command(name="custom-bot") +@bot.command(name="custom-bot") # type: ignore async def custom_bot(ctx, bot_name: str, *, personality: str): """Create a custom bot with a name and personality @@ -129,14 +121,14 @@ async def list_custom_bots(ctx): f"Found {len(bots)} custom bots, displaying top 10 for {ctx.author.name}" ) bot_list = "🤖 **Available Custom Bots**:\n\n" - for name, prompt, creator in bots[:10]: # Limit to 10 bots - bot_list += f"• **{name}** (created by {creator})\n" + for name, prompt, creator in bots: + bot_list += f"• **{name}**\n" logger.info(f"Sending bot list response to {ctx.author.name}") await ctx.send(bot_list) -@bot.command(name="delete-custom-bot") +@bot.command(name="delete-custom-bot") # type: ignore async def delete_custom_bot(ctx, bot_name: str): """Delete a custom bot (only the creator can delete) @@ -194,16 +186,16 @@ async def on_message(message): if message.author == bot.user: return + message_author = message.author.name + message_content = message.content.lower() + logger.debug( - f"Processing message from {message.author.name}: '{message.content[:50]}...'" + f"Processing message from {message_author}: '{message_content[:50]}...'" ) ctx = await bot.get_context(message) - # Check if the message starts with a custom bot command - content = message.content.lower() - - logger.info(f"Initializing CustomBotManager to check for custom bot commands") + logger.info("Initializing CustomBotManager to check for custom bot commands") custom_bot_manager = CustomBotManager() logger.info("Fetching list of custom bots to check for matching commands") @@ -212,7 +204,7 @@ async def on_message(message): logger.info(f"Checking {len(custom_bots)} custom bots for command match") for bot_name, system_prompt, _ in custom_bots: # Check if message starts with the custom bot name followed by a space - if content.startswith(f"!{bot_name} "): + if message_content.startswith(f"!{bot_name} "): logger.info( f"Custom bot command detected: '{bot_name}' triggered by {message.author.name}" ) @@ -224,25 +216,14 @@ async def on_message(message): ) # Prepare the payload with custom personality - payload = { - "model": "qwen3-vl-30b-a3b-instruct", - "messages": [ - { - "role": "system", - "content": system_prompt, - }, - {"role": "user", "content": user_message}, - ], - "max_completion_tokens": MAX_COMPLETION_TOKENS, - } - response_prefix = f"**{bot_name} response**" logger.info(f"Sending request to OpenAI API for bot '{bot_name}'") await handle_chat( ctx=ctx, + bot_name=bot_name, message=user_message, - payload=payload, + system_prompt=system_prompt, response_prefix=response_prefix, ) return @@ -258,24 +239,22 @@ async def doodlebob(ctx, *, message: str): logger.info(f"Doodlebob command triggered by {ctx.author.name}: {message[:100]}") await ctx.send(f"**Doodlebob erasing {message[:100]}...**") - image_prompt_payload = { - "model": "qwen3-vl-30b-a3b-instruct", - "messages": [ - { - "role": "system", - "content": ( - "Given the following message, convert it to a detailed image generation prompt that will be passed directly into an image generation model." - "If told to generate an image of yourself, generate a picture of a rat. If told to generate a picture of 'me', 'myself', or some other self" - " reference, generate a picture of a rat. Only respond with a valid image generation prompt, do not affirm the user or respond to the user's" - " questions." - ), - }, - {"role": "user", "content": message}, - ], - } + system_prompt = ( + "Given the following message, convert it to a detailed image generation prompt that will be passed directly into an image generation model." + "If told to generate an image of yourself, generate a picture of a rat. If told to generate a picture of 'me', 'myself', or some other self" + " reference, generate a picture of a rat. Only respond with a valid image generation prompt, do not affirm the user or respond to the user's" + " questions." + ) # Wait for the generated image prompt - image_prompt = await call_llm(ctx, image_prompt_payload) + image_prompt = llama_wrapper.chat_completion_instruct( + system_prompt=system_prompt, + user_prompt=message, + openai_url=CHAT_ENDPOINT, + openai_api_key=CHAT_ENDPOINT_KEY, + model=CHAT_MODEL, + max_tokens=MAX_COMPLETION_TOKENS, + ) # If the string is empty we had an error if image_prompt == "": @@ -285,32 +264,16 @@ async def doodlebob(ctx, *, message: str): # Alert the user we're generating the image await ctx.send(f"**Doodlebob calling drone strike on {image_prompt[:100]}...**") - # Create the image prompt payload - image_payload = { - "model": "default", - "prompt": image_prompt, - "n": 1, - "size": "1024x1024", - } - - # Call the image generation endpoint - response = requests.post( - f"{IMAGE_GEN_ENDPOINT}/images/generations", - json=image_payload, - timeout=120, + image_b64 = llama_wrapper.image_generation( + prompt=message, + openai_url=IMAGE_EDIT_ENDPOINT, + openai_api_key=IMAGE_EDIT_ENDPOINT_KEY, ) - if response.status_code == 200: - result = response.json() - # Send image - image_data = BytesIO(base64.b64decode(result["data"][0]["b64_json"])) - send_img = discord.File(image_data, filename="image.png") - await ctx.send(file=send_img) - - else: - print(f"❌ Error: {response.status_code}") - print(response.text) - return None + # Save the image to a file + edited_image_data = BytesIO(base64.b64decode(image_b64)) + send_img = discord.File(edited_image_data, filename="image.png") + await ctx.send(file=send_img) @bot.command(name="retcon") @@ -321,31 +284,23 @@ async def retcon(ctx, *, message: str): await ctx.send(f"**Rewriting history to match {message[:100]}...**") - client = OpenAI(base_url=IMAGE_EDIT_ENDPOINT, api_key=OPENAI_API_KEY) - - result = client.images.edit( - model="placeholder", - image=[image_bytestream], + image_b64 = llama_wrapper.image_edit( + image=image_bytestream, prompt=message, - size="1024x1024", + openai_url=IMAGE_EDIT_ENDPOINT, + openai_api_key=IMAGE_EDIT_ENDPOINT_KEY, ) - image_base64 = result.data[0].b64_json - image_bytes = base64.b64decode(image_base64) - # Save the image to a file - edited_image_data = BytesIO(image_bytes) + edited_image_data = BytesIO(base64.b64decode(image_b64)) send_img = discord.File(edited_image_data, filename="image.png") await ctx.send(file=send_img) -async def handle_chat(ctx, *, message: str, payload: dict, response_prefix: str): - # Check if API key is set - if not OPENAI_API_KEY: - await ctx.send( - "Error: OpenAI API key is not configured. Please set the OPENAI_API_KEY environment variable." - ) - return +async def handle_chat( + ctx, *, bot_name: str, message: str, system_prompt: str, response_prefix: str +): + await ctx.send(f"{bot_name} is searching its databanks for {message[:50]}...") # Get database instance db = get_database() @@ -356,32 +311,27 @@ async def handle_chat(ctx, *, message: str, payload: dict, response_prefix: str) ) if context: - payload["messages"][0][ - "content" - ] += f"\n\nRelevant conversation history:\n{context}" + user_message = f"\n\nRelevant conversation history:\n{context}\n\n{message}" + else: + user_message = message - payload["messages"][1]["content"] = message + logger.info(user_message) - print(payload) + system_prompt_edit = ( + "Keep your responses somewhat short, limited to 500 words or less. " + f"{system_prompt}" + ) try: - # Initialize OpenAI client - client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_ENDPOINT) - - # Call OpenAI API - response = client.chat.completions.create( - model=payload["model"], - messages=payload["messages"], - max_completion_tokens=MAX_COMPLETION_TOKENS, - frequency_penalty=1.5, - presence_penalty=1.5, - temperature=1, - seed=-1, + bot_response = llama_wrapper.chat_completion_instruct( + system_prompt=system_prompt_edit, + user_prompt=user_message, + openai_url=CHAT_ENDPOINT, + openai_api_key=CHAT_ENDPOINT_KEY, + model=CHAT_MODEL, + max_tokens=MAX_COMPLETION_TOKENS, ) - # Extract the generated text - generated_text = response.choices[0].message.content.strip() - # Store both user message and bot response in the database db.add_message( message_id=f"{ctx.message.id}", @@ -394,68 +344,24 @@ async def handle_chat(ctx, *, message: str, payload: dict, response_prefix: str) db.add_message( message_id=f"{ctx.message.id}_response", - user_id=str(bot.user.id), - username=bot.user.name, - content=f"Bot: {generated_text}", + user_id=str(bot.user.id), # type: ignore + username=bot.user.name, # type: ignore + content=f"Bot: {bot_response}", channel_id=str(ctx.channel.id), guild_id=str(ctx.guild.id) if ctx.guild else None, ) # Send the response back to the chat await ctx.send(response_prefix) - while generated_text: - send_chunk = generated_text[:1000] - generated_text = generated_text[1000:] + while bot_response: + send_chunk = bot_response[:1000] + bot_response = bot_response[1000:] await ctx.send(send_chunk) - except requests.exceptions.HTTPError as e: - await ctx.send(f"Error: OpenAI API error - {e}") - except requests.exceptions.Timeout: - await ctx.send("Error: Request timed out. Please try again.") except Exception as e: await ctx.send(f"Error: {str(e)}") -async def call_llm(ctx, payload: dict) -> str: - # Check if API key is set - if not OPENAI_API_KEY: - await ctx.send( - "Error: OpenAI API key is not configured. Please set the OPENAI_API_KEY environment variable." - ) - return "" - - # Set headers - headers = { - "Authorization": f"Bearer {OPENAI_API_KEY}", - "Content-Type": "application/json", - } - - try: - # Initialize OpenAI client - client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_ENDPOINT) - - # Call OpenAI API - response = client.chat.completions.create( - model=payload["model"], - messages=payload["messages"], - max_tokens=MAX_COMPLETION_TOKENS, - ) - - # Extract the generated text - generated_text = response.choices[0].message.content.strip() - print(generated_text) - - return generated_text - - except requests.exceptions.HTTPError as e: - await ctx.send(f"Error: OpenAI API error - {e}") - except requests.exceptions.Timeout: - await ctx.send("Error: Request timed out. Please try again.") - except Exception as e: - await ctx.send(f"Error: {str(e)}") - return "" - - # Run the bot if __name__ == "__main__": bot.run(DISCORD_TOKEN) diff --git a/vibe_bot/tests/conftest.py b/vibe_bot/tests/conftest.py deleted file mode 100644 index 8c95b76..0000000 --- a/vibe_bot/tests/conftest.py +++ /dev/null @@ -1,30 +0,0 @@ -import os -import pytest -from dotenv import load_dotenv - -# Try to load .env.test first, fallback to .env -env_test_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), '.env.test') -env_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), '.env') - -if os.path.exists(env_test_path): - load_dotenv(env_test_path) - print("✓ Loaded environment variables from .env.test") -elif os.path.exists(env_path): - load_dotenv(env_path) - print("✓ Loaded environment variables from .env") - -@pytest.fixture(autouse=True, scope="session") -def verify_env_loaded(): - """Verify critical environment variables are loaded before tests run""" - required_vars = [ - "DISCORD_TOKEN", - "OPENAI_API_ENDPOINT", - "IMAGE_GEN_ENDPOINT", - "IMAGE_EDIT_ENDPOINT" - ] - - missing_vars = [var for var in required_vars if var not in os.environ] - if missing_vars: - pytest.fail(f"Missing required environment variables: {', '.join(missing_vars)}") - - yield diff --git a/vibe_bot/tests/test_llama_wrapper.py b/vibe_bot/tests/test_llama_wrapper.py index d923256..9152d19 100644 --- a/vibe_bot/tests/test_llama_wrapper.py +++ b/vibe_bot/tests/test_llama_wrapper.py @@ -1,71 +1,112 @@ # Tests all functions in the llama-wrapper.py file # Run with: python -m pytest test_llama_wrapper.py -v -from discord import message -import pytest from ..llama_wrapper import ( - chat_completion_think, + chat_completion, chat_completion_instruct, image_generation, image_edit, - embeddings, + embedding, ) -from dotenv import load_dotenv -import os - -OPENAI_API_CHAT_ENDPOINT = os.getenv( - "OPENAI_API_CHAT_ENDPOINT", "https://llama-cpp.reeselink.com" +from ..config import ( + CHAT_ENDPOINT, + CHAT_MODEL, + CHAT_ENDPOINT_KEY, + IMAGE_EDIT_ENDPOINT, + IMAGE_EDIT_ENDPOINT_KEY, + IMAGE_GEN_ENDPOINT, + IMAGE_GEN_ENDPOINT_KEY, + EMBEDDING_ENDPOINT, + EMBEDDING_ENDPOINT_KEY, ) -OPENAI_API_IMAGE_ENDPOINT = os.getenv("OPENAI_API_IMAGE_ENDPOINT") -OPENAI_API_EDIT_ENDPOINT = os.getenv("OPENAI_API_EDIT_ENDPOINT") -OPENAI_API_EMBED_ENDPOINT = os.getenv("OPENAI_API_EMBED_ENDPOINT") +from io import BytesIO +import base64 +import tempfile +from pathlib import Path +import numpy as np -# Default models -DEFAULT_CHAT_MODEL = os.getenv("DEFAULT_CHAT_MODEL", "qwen3.5-35b-a3b") -DEFAULT_EMBED_MODEL = os.getenv("DEFAULT_EMBED_MODEL", "text-embedding-3-small") -DEFAULT_IMAGE_MODEL = os.getenv("DEFAULT_IMAGE_MODEL", "dall-e-3") -DEFAULT_EDIT_MODEL = os.getenv("DEFAULT_EDIT_MODEL", "dall-e-2") + +TEMPDIR = Path(tempfile.mkdtemp()) def test_chat_completion_think(): - # This test will fail without an actual API endpoint - # But it's here to show the structure - chat_completion_think( + result = chat_completion( system_prompt="You are a helpful assistant.", user_prompt="Tell me about Everquest", - openai_url=OPENAI_API_CHAT_ENDPOINT, - openai_api_key="placeholder", - model=DEFAULT_CHAT_MODEL, + openai_url=CHAT_ENDPOINT, + openai_api_key=CHAT_ENDPOINT_KEY, + model=CHAT_MODEL, max_tokens=100, ) + print(result) def test_chat_completion_instruct(): - # This test will fail without an actual API endpoint - # But it's here to show the structure - chat_completion_instruct( + result = chat_completion_instruct( system_prompt="You are a helpful assistant.", user_prompt="Tell me about Everquest", - openai_url=OPENAI_API_CHAT_ENDPOINT, - openai_api_key="placeholder", - model=DEFAULT_CHAT_MODEL, + openai_url=CHAT_ENDPOINT, + openai_api_key=CHAT_ENDPOINT_KEY, + model=CHAT_MODEL, max_tokens=100, ) + print(result) def test_image_generation(): - # This test will fail without an actual API endpoint - # But it's here to show the structure - pass + result = image_generation( + prompt="Generate an image of a horse", + openai_url=IMAGE_GEN_ENDPOINT, + openai_api_key=IMAGE_GEN_ENDPOINT_KEY, + ) + with open("image-gen.png", "wb") as f: + f.write(base64.b64decode(result)) def test_image_edit(): - # This test will fail without an actual API endpoint - # But it's here to show the structure - pass + with open("image-gen.png", "rb") as f: + image_data = BytesIO(f.read()) + result = image_edit( + image=image_data, + prompt="Paint the words 'horse' on the horse.", + openai_url=IMAGE_EDIT_ENDPOINT, + openai_api_key=IMAGE_EDIT_ENDPOINT_KEY, + ) + with open("image-edit.png", "wb") as f: + f.write(base64.b64decode(result)) + + +def _cosine_similarity(a, b): + """ + Close to 1: very similar + Close to 0: orthogonal + Close to -1: opposite + """ + a, b = np.array(a), np.array(b) + return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) def test_embeddings(): - # This test will fail without an actual API endpoint - # But it's here to show the structure - pass + result1 = embedding( + "this is a horse", + openai_url=EMBEDDING_ENDPOINT, + openai_api_key=EMBEDDING_ENDPOINT_KEY, + model="qwen3-embed-4b", + ) + result2 = embedding( + "this is a horse also", + openai_url=EMBEDDING_ENDPOINT, + openai_api_key=EMBEDDING_ENDPOINT_KEY, + model="qwen3-embed-4b", + ) + result3 = embedding( + "this is a donkey", + openai_url=EMBEDDING_ENDPOINT, + openai_api_key=EMBEDDING_ENDPOINT_KEY, + model="qwen3-embed-4b", + ) + similarity_1 = _cosine_similarity(result1, result2) + assert similarity_1 > 0.9 + + similarity_2 = _cosine_similarity(result1, result3) + assert similarity_2 < 0.5