human cleanup

This commit is contained in:
2026-03-09 22:36:04 -04:00
parent 3defce1365
commit 488912a991
8 changed files with 377 additions and 329 deletions

BIN
image-edit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 MiB

BIN
image-gen.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 MiB

85
vibe_bot/config.py Normal file
View File

@@ -0,0 +1,85 @@
from dotenv import load_dotenv
import os
import logging
# Central configuration for the bot: every setting is read from the process
# environment (optionally populated from a .env file) exactly once, at import.

# Set up root logging before any other module emits a record.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Load variables from a .env file, if one is present, into the environment.
load_dotenv()

# --- Discord -----------------------------------------------------------------
DISCORD_TOKEN = os.environ.get("DISCORD_TOKEN", "")

# --- Service endpoints (base URLs) -------------------------------------------
CHAT_ENDPOINT = os.environ.get("CHAT_ENDPOINT", "")
COMPLETION_ENDPOINT = os.environ.get("COMPLETION_ENDPOINT", "")
IMAGE_GEN_ENDPOINT = os.environ.get("IMAGE_GEN_ENDPOINT", "")
IMAGE_EDIT_ENDPOINT = os.environ.get("IMAGE_EDIT_ENDPOINT", "")
EMBEDDING_ENDPOINT = os.environ.get("EMBEDDING_ENDPOINT", "")
MAX_COMPLETION_TOKENS = int(os.environ.get("MAX_COMPLETION_TOKENS", "1000"))

# --- API keys (one per endpoint; "placeholder" when unset) -------------------
CHAT_ENDPOINT_KEY = os.environ.get("CHAT_ENDPOINT_KEY", "placeholder")
COMPLETION_ENDPOINT_KEY = os.environ.get("COMPLETION_ENDPOINT_KEY", "placeholder")
IMAGE_GEN_ENDPOINT_KEY = os.environ.get("IMAGE_GEN_ENDPOINT_KEY", "placeholder")
IMAGE_EDIT_ENDPOINT_KEY = os.environ.get("IMAGE_EDIT_ENDPOINT_KEY", "placeholder")
EMBEDDING_ENDPOINT_KEY = os.environ.get("EMBEDDING_ENDPOINT_KEY", "placeholder")

# --- Model names -------------------------------------------------------------
CHAT_MODEL = os.environ.get("CHAT_MODEL", "")
COMPLETION_MODEL = os.environ.get("COMPLETION_MODEL", "")
IMAGE_GEN_MODEL = os.environ.get("IMAGE_GEN_MODEL", "")
IMAGE_EDIT_MODEL = os.environ.get("IMAGE_EDIT_MODEL", "")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "")

# --- Database and embeddings -------------------------------------------------
DB_PATH = os.environ.get("DB_PATH", "chat_history.db")
EMBEDDING_DIMENSION = 2048
MAX_HISTORY_MESSAGES = int(os.environ.get("MAX_HISTORY_MESSAGES", "1000"))
SIMILARITY_THRESHOLD = float(os.environ.get("SIMILARITY_THRESHOLD", "0.7"))
TOP_K_RESULTS = int(os.environ.get("TOP_K_RESULTS", "5"))

# --- Validation --------------------------------------------------------------
# Settings that have no usable default. They are checked in this order (token,
# then endpoints, then models) and the first empty one aborts the import.
_endpoints = (
    ("CHAT_ENDPOINT", CHAT_ENDPOINT),
    ("COMPLETION_ENDPOINT", COMPLETION_ENDPOINT),
    ("IMAGE_GEN_ENDPOINT", IMAGE_GEN_ENDPOINT),
    ("IMAGE_EDIT_ENDPOINT", IMAGE_EDIT_ENDPOINT),
    ("EMBEDDING_ENDPOINT", EMBEDDING_ENDPOINT),
)
_models = (
    ("CHAT_MODEL", CHAT_MODEL),
    ("COMPLETION_MODEL", COMPLETION_MODEL),
    ("IMAGE_GEN_MODEL", IMAGE_GEN_MODEL),
    ("IMAGE_EDIT_MODEL", IMAGE_EDIT_MODEL),
    ("EMBEDDING_MODEL", EMBEDDING_MODEL),
)

for _name, _value in (("DISCORD_TOKEN", DISCORD_TOKEN), *_endpoints, *_models):
    if not _value:
        raise Exception(f"{_name} required.")

# Record which endpoints this process will talk to.
for _name, _value in _endpoints:
    logger.info(f"{_name} set to {_value}")

# Keep the module namespace limited to the actual settings.
del _endpoints, _models, _name, _value

View File

@@ -1,32 +1,27 @@
import sqlite3 import sqlite3
import os
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from datetime import datetime from datetime import datetime
import numpy as np import numpy as np
from openai import OpenAI from openai import OpenAI
import logging import logging
import llama_wrapper # type: ignore
from config import ( # type: ignore
DB_PATH,
EMBEDDING_MODEL,
EMBEDDING_ENDPOINT,
EMBEDDING_ENDPOINT_KEY,
MAX_HISTORY_MESSAGES,
SIMILARITY_THRESHOLD,
TOP_K_RESULTS,
)
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Database configuration
DB_PATH = os.getenv("DB_PATH", "chat_history.db")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "qwen3-embed-4b")
EMBEDDING_DIMENSION = 2048 # Default for qwen3-embed-4b
MAX_HISTORY_MESSAGES = int(os.getenv("MAX_HISTORY_MESSAGES", "1000"))
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
TOP_K_RESULTS = int(os.getenv("TOP_K_RESULTS", "5"))
# OpenAI configuration
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "placeholder")
OPENAI_API_EMBED_ENDPOINT = os.getenv(
"OPENAI_API_EMBED_ENDPOINT", "https://llama-embed.reeselink.com"
)
class ChatDatabase: class ChatDatabase:
"""SQLite database with RAG support for storing chat history using OpenAI embeddings.""" """SQLite database with RAG support for storing chat history using OpenAI embeddings."""
@@ -34,7 +29,9 @@ class ChatDatabase:
def __init__(self, db_path: str = DB_PATH): def __init__(self, db_path: str = DB_PATH):
logger.info(f"Initializing ChatDatabase with path: {db_path}") logger.info(f"Initializing ChatDatabase with path: {db_path}")
self.db_path = db_path self.db_path = db_path
self.client = OpenAI(base_url=OPENAI_API_EMBED_ENDPOINT, api_key=OPENAI_API_KEY) self.client = OpenAI(
base_url=EMBEDDING_ENDPOINT, api_key=EMBEDDING_ENDPOINT_KEY
)
logger.info("Connecting to OpenAI API for embeddings") logger.info("Connecting to OpenAI API for embeddings")
self._initialize_database() self._initialize_database()
@@ -96,36 +93,6 @@ class ChatDatabase:
logger.info("Database initialization completed successfully") logger.info("Database initialization completed successfully")
conn.close() conn.close()
def _generate_embedding(self, text: str) -> List[float]:
    """Generate an embedding vector for ``text`` via the embeddings endpoint.

    Args:
        text: The text to embed.

    Returns:
        The embedding as a flat list of floats. An empty list is returned
        when the response carries no usable embedding or when the request
        fails for any reason, so callers can rely on a simple truthiness
        check and never receive ``None``.
    """
    logger.debug(f"Generating embedding for text (length: {len(text)})")
    try:
        logger.info(f"Calling OpenAI API to generate embedding with model: {EMBEDDING_MODEL}")
        response = self.client.embeddings.create(
            model=EMBEDDING_MODEL, input=text, encoding_format="float"
        )
        logger.debug("OpenAI API response received successfully")

        # NOTE(review): the response is indexed directly (response[0]) rather
        # than through `.data`; confirm the configured endpoint really returns
        # a list-like payload.
        embedding_data = response[0].embedding
        if not isinstance(embedding_data, list) or not embedding_data:
            logger.warning("Embedding data is empty or invalid")
            return []

        # The payload may be either [[values]] or [values]; unwrap the former.
        first_item = embedding_data[0]
        if isinstance(first_item, list):
            logger.debug("Extracted embedding from nested structure [[values]]")
            return first_item

        logger.debug("Extracted embedding from direct structure [values]")
        return embedding_data
    except Exception as e:
        # Best effort: a failed embedding must not abort message storage.
        # Return an empty list (not None) to honour the declared return type.
        logger.error(f"Error generating embedding: {e}")
        return []
def _vector_to_bytes(self, vector: List[float]) -> bytes: def _vector_to_bytes(self, vector: List[float]) -> bytes:
"""Convert vector to bytes for SQLite storage.""" """Convert vector to bytes for SQLite storage."""
logger.debug(f"Converting vector (length: {len(vector)}) to bytes") logger.debug(f"Converting vector (length: {len(vector)}) to bytes")
@@ -142,7 +109,9 @@ class ChatDatabase:
def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float: def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
"""Calculate cosine similarity between two vectors.""" """Calculate cosine similarity between two vectors."""
logger.debug(f"Calculating cosine similarity between vectors of dimension {len(vec1)}") logger.debug(
f"Calculating cosine similarity between vectors of dimension {len(vec1)}"
)
result = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) result = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
logger.debug(f"Similarity calculated: {result:.4f}") logger.debug(f"Similarity calculated: {result:.4f}")
return result return result
@@ -163,7 +132,9 @@ class ChatDatabase:
try: try:
# Insert message # Insert message
logger.debug(f"Inserting message into chat_messages table: message_id={message_id}") logger.debug(
f"Inserting message into chat_messages table: message_id={message_id}"
)
cursor.execute( cursor.execute(
""" """
INSERT OR REPLACE INTO chat_messages INSERT OR REPLACE INTO chat_messages
@@ -176,9 +147,16 @@ class ChatDatabase:
# Generate and store embedding # Generate and store embedding
logger.info(f"Generating embedding for message {message_id}") logger.info(f"Generating embedding for message {message_id}")
embedding = self._generate_embedding(content) embedding = llama_wrapper.embedding(
content,
openai_url=EMBEDDING_ENDPOINT,
openai_api_key=EMBEDDING_ENDPOINT_KEY,
model=EMBEDDING_MODEL,
)
if embedding: if embedding:
logger.debug(f"Embedding generated successfully for message {message_id}, storing in database") logger.debug(
f"Embedding generated successfully for message {message_id}, storing in database"
)
cursor.execute( cursor.execute(
""" """
INSERT OR REPLACE INTO message_embeddings INSERT OR REPLACE INTO message_embeddings
@@ -187,9 +165,13 @@ class ChatDatabase:
""", """,
(message_id, self._vector_to_bytes(embedding)), (message_id, self._vector_to_bytes(embedding)),
) )
logger.debug(f"Embedding stored in message_embeddings table for message {message_id}") logger.debug(
f"Embedding stored in message_embeddings table for message {message_id}"
)
else: else:
logger.warning(f"Failed to generate embedding for message {message_id}, skipping embedding storage") logger.warning(
f"Failed to generate embedding for message {message_id}, skipping embedding storage"
)
# Clean up old messages if exceeding limit # Clean up old messages if exceeding limit
logger.info("Checking if cleanup of old messages is needed") logger.info("Checking if cleanup of old messages is needed")
@@ -268,9 +250,14 @@ class ChatDatabase:
query: str, query: str,
top_k: int = TOP_K_RESULTS, top_k: int = TOP_K_RESULTS,
min_similarity: float = SIMILARITY_THRESHOLD, min_similarity: float = SIMILARITY_THRESHOLD,
) -> List[Tuple[str, str, str, float]]: ) -> List[Tuple[str, str, float]]:
"""Search for messages similar to the query using embeddings.""" """Search for messages similar to the query using embeddings."""
query_embedding = self._generate_embedding(query) query_embedding = llama_wrapper.embedding(
text=query,
model=EMBEDDING_MODEL,
openai_url=EMBEDDING_ENDPOINT,
openai_api_key=EMBEDDING_ENDPOINT_KEY,
)
if not query_embedding: if not query_embedding:
return [] return []
@@ -285,19 +272,28 @@ class ChatDatabase:
SELECT cm.message_id, cm.content, me.embedding SELECT cm.message_id, cm.content, me.embedding
FROM chat_messages cm FROM chat_messages cm
JOIN message_embeddings me ON cm.message_id = me.message_id JOIN message_embeddings me ON cm.message_id = me.message_id
WHERE cm.username != 'vibe-bot'
""" """
) )
rows = cursor.fetchall() rows = cursor.fetchall()
results = [] results: list[tuple[str, str, float]] = []
for message_id, content, embedding_blob in rows: for message_id, content, embedding_blob in rows:
embedding_vector = self._bytes_to_vector(embedding_blob) embedding_vector = self._bytes_to_vector(embedding_blob)
similarity = self._calculate_similarity(query_vector, embedding_vector) similarity = self._calculate_similarity(query_vector, embedding_vector)
if similarity >= min_similarity: if similarity >= min_similarity:
results.append( cursor.execute(
(message_id, content[:500], similarity) """
) # Limit content length SELECT content
FROM chat_messages
WHERE message_id = ?
ORDER BY timestamp DESC
""",
(f"{message_id}_response",),
)
response: str = cursor.fetchone()[0]
results.append((content, response, similarity))
conn.close() conn.close()
@@ -305,28 +301,48 @@ class ChatDatabase:
results.sort(key=lambda x: x[2], reverse=True) results.sort(key=lambda x: x[2], reverse=True)
return results[:top_k] return results[:top_k]
def get_user_history( def get_user_history(self, user_id: str, limit: int = 20) -> list[tuple[str, str]]:
self, user_id: str, limit: int = 20
) -> List[Tuple[str, str, datetime]]:
"""Get message history for a specific user.""" """Get message history for a specific user."""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
logger.info(f"Fetching last {limit} user messages")
cursor.execute( cursor.execute(
""" """
SELECT message_id, content, timestamp SELECT message_id, content, timestamp
FROM chat_messages FROM chat_messages
WHERE user_id = ? WHERE username != 'vibe-bot'
ORDER BY timestamp DESC ORDER BY timestamp DESC
LIMIT ? LIMIT ?
""", """,
(user_id, limit), (limit,),
) )
messages = cursor.fetchall() messages = cursor.fetchall()
# Format is [user message, bot response]
conversations: list[tuple[str, str]] = []
for message in messages:
msg_content: str = message[1]
logger.info(f"Finding response for {msg_content[:50]}")
cursor.execute(
"""
SELECT content
FROM chat_messages
WHERE message_id = ?
ORDER BY timestamp DESC
""",
(f"{message[0]}_response",),
)
response_content: str = cursor.fetchone()
if response_content:
logger.info(f"Found response: {response_content[0][:50]}")
conversations.append((msg_content, response_content[0]))
else:
logger.info("No response found")
conn.close() conn.close()
return messages return conversations
def get_conversation_context( def get_conversation_context(
self, user_id: str, current_message: str, max_context: int = 5 self, user_id: str, current_message: str, max_context: int = 5
@@ -344,15 +360,19 @@ class ChatDatabase:
context_parts = [] context_parts = []
# Add recent messages # Add recent messages
for message_id, content, timestamp in recent_messages: for user_message, bot_message in recent_messages:
context_parts.append(f"[{timestamp}] User: {content}") combined_content = f"[Recent chat]\n{user_message}\n{bot_message}"
context_parts.append(combined_content)
# Add similar messages # Add similar messages
for message_id, content, similarity in similar_messages: for user_message, bot_message, similarity in similar_messages:
if f"[{content}" not in "\n".join(context_parts): # Avoid duplicates combined_content = f"{user_message}\n{bot_message}"
context_parts.append(f"[Similar] {content}") if combined_content not in "\n".join(context_parts):
context_parts.append(f"[You remember]\n{combined_content}")
return "\n".join(context_parts[-max_context * 2 :]) # Limit total context # Conversation history needs to be delivered in "newest context last" order
context_parts.reverse()
return "\n".join(context_parts[-max_context * 4 :]) # Limit total context
def clear_all_messages(self): def clear_all_messages(self):
"""Clear all messages and embeddings from the database.""" """Clear all messages and embeddings from the database."""
@@ -390,6 +410,7 @@ class CustomBotManager:
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
# Create table to hold custom bots
cursor.execute( cursor.execute(
""" """
CREATE TABLE IF NOT EXISTS custom_bots ( CREATE TABLE IF NOT EXISTS custom_bots (
@@ -461,8 +482,9 @@ class CustomBotManager:
if user_id: if user_id:
cursor.execute( cursor.execute(
""" """
SELECT bot_name, system_prompt, created_by SELECT bot_name, system_prompt, name
FROM custom_bots FROM custom_bots cb, username_map um
JOIN username_map ON custom_bots.created_by = username_map.id
WHERE is_active = 1 WHERE is_active = 1
ORDER BY created_at DESC ORDER BY created_at DESC
""" """

View File

@@ -5,9 +5,12 @@
import openai import openai
from typing import Iterable from typing import Iterable
from openai.types.chat import ChatCompletionMessageParam from openai.types.chat import ChatCompletionMessageParam
from openai._types import FileTypes, SequenceNotStr
from typing import Union
from io import BufferedReader, BytesIO
def chat_completion_think( def chat_completion(
system_prompt: str, system_prompt: str,
user_prompt: str, user_prompt: str,
openai_url: str, openai_url: str,
@@ -80,35 +83,56 @@ def chat_completion_instruct(
return "" return ""
def image_generation(prompt: str, n=1) -> str: def image_generation(prompt: str, openai_url: str, openai_api_key: str, n=1) -> str:
client = openai.OpenAI(base_url=OPENAI_API_IMAGE_ENDPOINT, api_key="placeholder") """Generates an image using the given prompt and returns the base64 encoded image data
Returns:
str: The base64 encoded image data. Decode and write to a file.
"""
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
response = client.images.generate( response = client.images.generate(
prompt=prompt, prompt=prompt,
n=n, n=n,
size="1024x1024", size="1024x1024",
) )
if response.data: if response.data:
return response.data[0].url return response.data[0].b64_json or ""
else: else:
return "" return ""
def image_edit(image, mask, prompt, n=1, size="1024x1024"): def image_edit(
client = openai.OpenAI(base_url=OPENAI_API_EDIT_ENDPOINT, api_key="placeholder") image: BufferedReader | BytesIO,
prompt: str,
openai_url: str,
openai_api_key: str,
n=1,
) -> str:
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
response = client.images.edit( response = client.images.edit(
image=image, image=image,
mask=mask,
prompt=prompt, prompt=prompt,
n=n, n=n,
size=size, size="1024x1024",
) )
return response.data[0].url if response.data:
return response.data[0].b64_json or ""
else:
return ""
def embeddings(text, model="text-embedding-3-small"): def embedding(
client = openai.OpenAI(base_url=OPENAI_API_EMBED_ENDPOINT, api_key="placeholder") text: str, openai_url: str, openai_api_key: str, model: str
) -> list[float]:
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
response = client.embeddings.create( response = client.embeddings.create(
input=text, input=[text], model=model, encoding_format="float"
model=model,
) )
return response.data[0].embedding if response:
raw_data = response[0].embedding # type: ignore
# The result could be an array of floats or an array of an array of floats.
try:
return raw_data[0]
except Exception:
return raw_data
return []

View File

@@ -5,7 +5,19 @@ import base64
from io import BytesIO from io import BytesIO
from openai import OpenAI from openai import OpenAI
import logging import logging
from database import get_database, CustomBotManager from database import get_database, CustomBotManager # type: ignore
from config import ( # type: ignore
CHAT_ENDPOINT_KEY,
DISCORD_TOKEN,
CHAT_ENDPOINT,
CHAT_MODEL,
IMAGE_EDIT_ENDPOINT_KEY,
IMAGE_GEN_ENDPOINT,
IMAGE_EDIT_ENDPOINT,
MAX_COMPLETION_TOKENS,
)
import llama_wrapper # type: ignore
import requests
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
@@ -13,31 +25,11 @@ logging.basicConfig(
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DISCORD_TOKEN = os.getenv("DISCORD_TOKEN", "placeholder")
OPENAI_API_ENDPOINT = os.getenv("OPENAI_API_ENDPOINT")
IMAGE_GEN_ENDPOINT = os.getenv("IMAGE_GEN_ENDPOINT")
IMAGE_EDIT_ENDPOINT = os.getenv("IMAGE_EDIT_ENDPOINT")
MAX_COMPLETION_TOKENS = int(os.getenv("MAX_COMPLETION_TOKENS", "1000"))
if not OPENAI_API_ENDPOINT:
raise Exception("OPENAI_API_ENDPOINT required.")
if not IMAGE_GEN_ENDPOINT:
raise Exception("IMAGE_GEN_ENDPOINT required.")
# Set your OpenAI API key as an environment variable
# You can also pass it directly but environment variables are safer
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "placeholder")
# Initialize the bot # Initialize the bot
intents = discord.Intents.default() intents = discord.Intents.default()
intents.message_content = True intents.message_content = True
bot = commands.Bot(command_prefix="!", intents=intents) bot = commands.Bot(command_prefix="!", intents=intents)
# OpenAI Completions API endpoint
OPENAI_COMPLETIONS_URL = f"{OPENAI_API_ENDPOINT}/chat/completions"
@bot.event @bot.event
async def on_ready(): async def on_ready():
@@ -46,7 +38,7 @@ async def on_ready():
logger.info(f"Bot logged in as {bot.user}") logger.info(f"Bot logged in as {bot.user}")
@bot.command(name="custom-bot") @bot.command(name="custom-bot") # type: ignore
async def custom_bot(ctx, bot_name: str, *, personality: str): async def custom_bot(ctx, bot_name: str, *, personality: str):
"""Create a custom bot with a name and personality """Create a custom bot with a name and personality
@@ -129,14 +121,14 @@ async def list_custom_bots(ctx):
f"Found {len(bots)} custom bots, displaying top 10 for {ctx.author.name}" f"Found {len(bots)} custom bots, displaying top 10 for {ctx.author.name}"
) )
bot_list = "🤖 **Available Custom Bots**:\n\n" bot_list = "🤖 **Available Custom Bots**:\n\n"
for name, prompt, creator in bots[:10]: # Limit to 10 bots for name, prompt, creator in bots:
bot_list += f"• **{name}** (created by {creator})\n" bot_list += f"• **{name}**\n"
logger.info(f"Sending bot list response to {ctx.author.name}") logger.info(f"Sending bot list response to {ctx.author.name}")
await ctx.send(bot_list) await ctx.send(bot_list)
@bot.command(name="delete-custom-bot") @bot.command(name="delete-custom-bot") # type: ignore
async def delete_custom_bot(ctx, bot_name: str): async def delete_custom_bot(ctx, bot_name: str):
"""Delete a custom bot (only the creator can delete) """Delete a custom bot (only the creator can delete)
@@ -194,16 +186,16 @@ async def on_message(message):
if message.author == bot.user: if message.author == bot.user:
return return
message_author = message.author.name
message_content = message.content.lower()
logger.debug( logger.debug(
f"Processing message from {message.author.name}: '{message.content[:50]}...'" f"Processing message from {message_author}: '{message_content[:50]}...'"
) )
ctx = await bot.get_context(message) ctx = await bot.get_context(message)
# Check if the message starts with a custom bot command logger.info("Initializing CustomBotManager to check for custom bot commands")
content = message.content.lower()
logger.info(f"Initializing CustomBotManager to check for custom bot commands")
custom_bot_manager = CustomBotManager() custom_bot_manager = CustomBotManager()
logger.info("Fetching list of custom bots to check for matching commands") logger.info("Fetching list of custom bots to check for matching commands")
@@ -212,7 +204,7 @@ async def on_message(message):
logger.info(f"Checking {len(custom_bots)} custom bots for command match") logger.info(f"Checking {len(custom_bots)} custom bots for command match")
for bot_name, system_prompt, _ in custom_bots: for bot_name, system_prompt, _ in custom_bots:
# Check if message starts with the custom bot name followed by a space # Check if message starts with the custom bot name followed by a space
if content.startswith(f"!{bot_name} "): if message_content.startswith(f"!{bot_name} "):
logger.info( logger.info(
f"Custom bot command detected: '{bot_name}' triggered by {message.author.name}" f"Custom bot command detected: '{bot_name}' triggered by {message.author.name}"
) )
@@ -224,25 +216,14 @@ async def on_message(message):
) )
# Prepare the payload with custom personality # Prepare the payload with custom personality
payload = {
"model": "qwen3-vl-30b-a3b-instruct",
"messages": [
{
"role": "system",
"content": system_prompt,
},
{"role": "user", "content": user_message},
],
"max_completion_tokens": MAX_COMPLETION_TOKENS,
}
response_prefix = f"**{bot_name} response**" response_prefix = f"**{bot_name} response**"
logger.info(f"Sending request to OpenAI API for bot '{bot_name}'") logger.info(f"Sending request to OpenAI API for bot '{bot_name}'")
await handle_chat( await handle_chat(
ctx=ctx, ctx=ctx,
bot_name=bot_name,
message=user_message, message=user_message,
payload=payload, system_prompt=system_prompt,
response_prefix=response_prefix, response_prefix=response_prefix,
) )
return return
@@ -258,24 +239,22 @@ async def doodlebob(ctx, *, message: str):
logger.info(f"Doodlebob command triggered by {ctx.author.name}: {message[:100]}") logger.info(f"Doodlebob command triggered by {ctx.author.name}: {message[:100]}")
await ctx.send(f"**Doodlebob erasing {message[:100]}...**") await ctx.send(f"**Doodlebob erasing {message[:100]}...**")
image_prompt_payload = { system_prompt = (
"model": "qwen3-vl-30b-a3b-instruct",
"messages": [
{
"role": "system",
"content": (
"Given the following message, convert it to a detailed image generation prompt that will be passed directly into an image generation model." "Given the following message, convert it to a detailed image generation prompt that will be passed directly into an image generation model."
"If told to generate an image of yourself, generate a picture of a rat. If told to generate a picture of 'me', 'myself', or some other self" "If told to generate an image of yourself, generate a picture of a rat. If told to generate a picture of 'me', 'myself', or some other self"
" reference, generate a picture of a rat. Only respond with a valid image generation prompt, do not affirm the user or respond to the user's" " reference, generate a picture of a rat. Only respond with a valid image generation prompt, do not affirm the user or respond to the user's"
" questions." " questions."
), )
},
{"role": "user", "content": message},
],
}
# Wait for the generated image prompt # Wait for the generated image prompt
image_prompt = await call_llm(ctx, image_prompt_payload) image_prompt = llama_wrapper.chat_completion_instruct(
system_prompt=system_prompt,
user_prompt=message,
openai_url=CHAT_ENDPOINT,
openai_api_key=CHAT_ENDPOINT_KEY,
model=CHAT_MODEL,
max_tokens=MAX_COMPLETION_TOKENS,
)
# If the string is empty we had an error # If the string is empty we had an error
if image_prompt == "": if image_prompt == "":
@@ -285,33 +264,17 @@ async def doodlebob(ctx, *, message: str):
# Alert the user we're generating the image # Alert the user we're generating the image
await ctx.send(f"**Doodlebob calling drone strike on {image_prompt[:100]}...**") await ctx.send(f"**Doodlebob calling drone strike on {image_prompt[:100]}...**")
# Create the image prompt payload image_b64 = llama_wrapper.image_generation(
image_payload = { prompt=message,
"model": "default", openai_url=IMAGE_EDIT_ENDPOINT,
"prompt": image_prompt, openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
"n": 1,
"size": "1024x1024",
}
# Call the image generation endpoint
response = requests.post(
f"{IMAGE_GEN_ENDPOINT}/images/generations",
json=image_payload,
timeout=120,
) )
if response.status_code == 200: # Save the image to a file
result = response.json() edited_image_data = BytesIO(base64.b64decode(image_b64))
# Send image send_img = discord.File(edited_image_data, filename="image.png")
image_data = BytesIO(base64.b64decode(result["data"][0]["b64_json"]))
send_img = discord.File(image_data, filename="image.png")
await ctx.send(file=send_img) await ctx.send(file=send_img)
else:
print(f"❌ Error: {response.status_code}")
print(response.text)
return None
@bot.command(name="retcon") @bot.command(name="retcon")
async def retcon(ctx, *, message: str): async def retcon(ctx, *, message: str):
@@ -321,31 +284,23 @@ async def retcon(ctx, *, message: str):
await ctx.send(f"**Rewriting history to match {message[:100]}...**") await ctx.send(f"**Rewriting history to match {message[:100]}...**")
client = OpenAI(base_url=IMAGE_EDIT_ENDPOINT, api_key=OPENAI_API_KEY) image_b64 = llama_wrapper.image_edit(
image=image_bytestream,
result = client.images.edit(
model="placeholder",
image=[image_bytestream],
prompt=message, prompt=message,
size="1024x1024", openai_url=IMAGE_EDIT_ENDPOINT,
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
) )
image_base64 = result.data[0].b64_json
image_bytes = base64.b64decode(image_base64)
# Save the image to a file # Save the image to a file
edited_image_data = BytesIO(image_bytes) edited_image_data = BytesIO(base64.b64decode(image_b64))
send_img = discord.File(edited_image_data, filename="image.png") send_img = discord.File(edited_image_data, filename="image.png")
await ctx.send(file=send_img) await ctx.send(file=send_img)
async def handle_chat(ctx, *, message: str, payload: dict, response_prefix: str): async def handle_chat(
# Check if API key is set ctx, *, bot_name: str, message: str, system_prompt: str, response_prefix: str
if not OPENAI_API_KEY: ):
await ctx.send( await ctx.send(f"{bot_name} is searching its databanks for {message[:50]}...")
"Error: OpenAI API key is not configured. Please set the OPENAI_API_KEY environment variable."
)
return
# Get database instance # Get database instance
db = get_database() db = get_database()
@@ -356,31 +311,26 @@ async def handle_chat(ctx, *, message: str, payload: dict, response_prefix: str)
) )
if context: if context:
payload["messages"][0][ user_message = f"\n\nRelevant conversation history:\n{context}\n\n{message}"
"content" else:
] += f"\n\nRelevant conversation history:\n{context}" user_message = message
payload["messages"][1]["content"] = message logger.info(user_message)
print(payload) system_prompt_edit = (
"Keep your responses somewhat short, limited to 500 words or less. "
try: f"{system_prompt}"
# Initialize OpenAI client
client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_ENDPOINT)
# Call OpenAI API
response = client.chat.completions.create(
model=payload["model"],
messages=payload["messages"],
max_completion_tokens=MAX_COMPLETION_TOKENS,
frequency_penalty=1.5,
presence_penalty=1.5,
temperature=1,
seed=-1,
) )
# Extract the generated text try:
generated_text = response.choices[0].message.content.strip() bot_response = llama_wrapper.chat_completion_instruct(
system_prompt=system_prompt_edit,
user_prompt=user_message,
openai_url=CHAT_ENDPOINT,
openai_api_key=CHAT_ENDPOINT_KEY,
model=CHAT_MODEL,
max_tokens=MAX_COMPLETION_TOKENS,
)
# Store both user message and bot response in the database # Store both user message and bot response in the database
db.add_message( db.add_message(
@@ -394,68 +344,24 @@ async def handle_chat(ctx, *, message: str, payload: dict, response_prefix: str)
db.add_message( db.add_message(
message_id=f"{ctx.message.id}_response", message_id=f"{ctx.message.id}_response",
user_id=str(bot.user.id), user_id=str(bot.user.id), # type: ignore
username=bot.user.name, username=bot.user.name, # type: ignore
content=f"Bot: {generated_text}", content=f"Bot: {bot_response}",
channel_id=str(ctx.channel.id), channel_id=str(ctx.channel.id),
guild_id=str(ctx.guild.id) if ctx.guild else None, guild_id=str(ctx.guild.id) if ctx.guild else None,
) )
# Send the response back to the chat # Send the response back to the chat
await ctx.send(response_prefix) await ctx.send(response_prefix)
while generated_text: while bot_response:
send_chunk = generated_text[:1000] send_chunk = bot_response[:1000]
generated_text = generated_text[1000:] bot_response = bot_response[1000:]
await ctx.send(send_chunk) await ctx.send(send_chunk)
except requests.exceptions.HTTPError as e:
await ctx.send(f"Error: OpenAI API error - {e}")
except requests.exceptions.Timeout:
await ctx.send("Error: Request timed out. Please try again.")
except Exception as e: except Exception as e:
await ctx.send(f"Error: {str(e)}") await ctx.send(f"Error: {str(e)}")
async def call_llm(ctx, payload: dict) -> str:
    """Send a chat-completion request and return the generated text.

    Args:
        ctx: Command context; used only to report errors back to the channel.
        payload: Request description. ``payload["model"]`` and
            ``payload["messages"]`` are forwarded to the chat-completions call.

    Returns:
        The stripped text of the first choice, or an empty string when the
        API key is missing or the request fails (the error is sent to ``ctx``).
    """
    # Check if API key is set
    if not OPENAI_API_KEY:
        await ctx.send(
            "Error: OpenAI API key is not configured. Please set the OPENAI_API_KEY environment variable."
        )
        return ""

    # The client below carries the key itself, so no hand-built
    # Authorization headers are needed here.
    try:
        # Initialize OpenAI client
        client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_ENDPOINT)

        # Call OpenAI API
        response = client.chat.completions.create(
            model=payload["model"],
            messages=payload["messages"],
            max_tokens=MAX_COMPLETION_TOKENS,
        )

        # Extract the generated text
        generated_text = response.choices[0].message.content.strip()
        print(generated_text)
        return generated_text
    except requests.exceptions.HTTPError as e:
        await ctx.send(f"Error: OpenAI API error - {e}")
    except requests.exceptions.Timeout:
        await ctx.send("Error: Request timed out. Please try again.")
    except Exception as e:
        await ctx.send(f"Error: {str(e)}")

    # Reached only after one of the handlers above reported a failure.
    return ""
# Run the bot # Run the bot
if __name__ == "__main__": if __name__ == "__main__":
bot.run(DISCORD_TOKEN) bot.run(DISCORD_TOKEN)

View File

@@ -1,30 +0,0 @@
import os
import pytest
from dotenv import load_dotenv
# Environment files live one directory above this file's directory.
# .env.test takes precedence; .env is only consulted when .env.test is absent.
_parent_dir = os.path.dirname(os.path.dirname(__file__))
env_test_path = os.path.join(_parent_dir, '.env.test')
env_path = os.path.join(_parent_dir, '.env')

for _dotenv_file, _notice in (
    (env_test_path, "✓ Loaded environment variables from .env.test"),
    (env_path, "✓ Loaded environment variables from .env"),
):
    if os.path.exists(_dotenv_file):
        load_dotenv(_dotenv_file)
        print(_notice)
        # Stop at the first file found so .env never overrides .env.test.
        break
@pytest.fixture(autouse=True, scope="session")
def verify_env_loaded():
    """Check that the required environment variables are present.

    Only the presence of each name in ``os.environ`` is checked, not its
    value. When any name is absent, ``pytest.fail`` is called with a message
    listing every missing name.
    """
    expected = (
        "DISCORD_TOKEN",
        "OPENAI_API_ENDPOINT",
        "IMAGE_GEN_ENDPOINT",
        "IMAGE_EDIT_ENDPOINT",
    )

    # Collect every absent name so the failure message reports them all.
    absent = []
    for name in expected:
        if name not in os.environ:
            absent.append(name)

    if absent:
        pytest.fail(f"Missing required environment variables: {', '.join(absent)}")

    yield

View File

@@ -1,71 +1,112 @@
# Tests all functions in the llama-wrapper.py file # Tests all functions in the llama-wrapper.py file
# Run with: python -m pytest test_llama_wrapper.py -v # Run with: python -m pytest test_llama_wrapper.py -v
from discord import message
import pytest
from ..llama_wrapper import ( from ..llama_wrapper import (
chat_completion_think, chat_completion,
chat_completion_instruct, chat_completion_instruct,
image_generation, image_generation,
image_edit, image_edit,
embeddings, embedding,
) )
from dotenv import load_dotenv from ..config import (
import os CHAT_ENDPOINT,
CHAT_MODEL,
OPENAI_API_CHAT_ENDPOINT = os.getenv( CHAT_ENDPOINT_KEY,
"OPENAI_API_CHAT_ENDPOINT", "https://llama-cpp.reeselink.com" IMAGE_EDIT_ENDPOINT,
IMAGE_EDIT_ENDPOINT_KEY,
IMAGE_GEN_ENDPOINT,
IMAGE_GEN_ENDPOINT_KEY,
EMBEDDING_ENDPOINT,
EMBEDDING_ENDPOINT_KEY,
) )
OPENAI_API_IMAGE_ENDPOINT = os.getenv("OPENAI_API_IMAGE_ENDPOINT") from io import BytesIO
OPENAI_API_EDIT_ENDPOINT = os.getenv("OPENAI_API_EDIT_ENDPOINT") import base64
OPENAI_API_EMBED_ENDPOINT = os.getenv("OPENAI_API_EMBED_ENDPOINT") import tempfile
from pathlib import Path
import numpy as np
# Default models
DEFAULT_CHAT_MODEL = os.getenv("DEFAULT_CHAT_MODEL", "qwen3.5-35b-a3b") TEMPDIR = Path(tempfile.mkdtemp())
DEFAULT_EMBED_MODEL = os.getenv("DEFAULT_EMBED_MODEL", "text-embedding-3-small")
DEFAULT_IMAGE_MODEL = os.getenv("DEFAULT_IMAGE_MODEL", "dall-e-3")
DEFAULT_EDIT_MODEL = os.getenv("DEFAULT_EDIT_MODEL", "dall-e-2")
def test_chat_completion_think(): def test_chat_completion_think():
# This test will fail without an actual API endpoint result = chat_completion(
# But it's here to show the structure
chat_completion_think(
system_prompt="You are a helpful assistant.", system_prompt="You are a helpful assistant.",
user_prompt="Tell me about Everquest", user_prompt="Tell me about Everquest",
openai_url=OPENAI_API_CHAT_ENDPOINT, openai_url=CHAT_ENDPOINT,
openai_api_key="placeholder", openai_api_key=CHAT_ENDPOINT_KEY,
model=DEFAULT_CHAT_MODEL, model=CHAT_MODEL,
max_tokens=100, max_tokens=100,
) )
print(result)
def test_chat_completion_instruct(): def test_chat_completion_instruct():
# This test will fail without an actual API endpoint result = chat_completion_instruct(
# But it's here to show the structure
chat_completion_instruct(
system_prompt="You are a helpful assistant.", system_prompt="You are a helpful assistant.",
user_prompt="Tell me about Everquest", user_prompt="Tell me about Everquest",
openai_url=OPENAI_API_CHAT_ENDPOINT, openai_url=CHAT_ENDPOINT,
openai_api_key="placeholder", openai_api_key=CHAT_ENDPOINT_KEY,
model=DEFAULT_CHAT_MODEL, model=CHAT_MODEL,
max_tokens=100, max_tokens=100,
) )
print(result)
def test_image_generation(): def test_image_generation():
# This test will fail without an actual API endpoint result = image_generation(
# But it's here to show the structure prompt="Generate an image of a horse",
pass openai_url=IMAGE_GEN_ENDPOINT,
openai_api_key=IMAGE_GEN_ENDPOINT_KEY,
)
with open("image-gen.png", "wb") as f:
f.write(base64.b64decode(result))
def test_image_edit(): def test_image_edit():
# This test will fail without an actual API endpoint with open("image-gen.png", "rb") as f:
# But it's here to show the structure image_data = BytesIO(f.read())
pass result = image_edit(
image=image_data,
prompt="Paint the words 'horse' on the horse.",
openai_url=IMAGE_EDIT_ENDPOINT,
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
)
with open("image-edit.png", "wb") as f:
f.write(base64.b64decode(result))
def _cosine_similarity(a, b):
"""
Close to 1: very similar
Close to 0: orthogonal
Close to -1: opposite
"""
a, b = np.array(a), np.array(b)
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
def test_embeddings(): def test_embeddings():
# This test will fail without an actual API endpoint result1 = embedding(
# But it's here to show the structure "this is a horse",
pass openai_url=EMBEDDING_ENDPOINT,
openai_api_key=EMBEDDING_ENDPOINT_KEY,
model="qwen3-embed-4b",
)
result2 = embedding(
"this is a horse also",
openai_url=EMBEDDING_ENDPOINT,
openai_api_key=EMBEDDING_ENDPOINT_KEY,
model="qwen3-embed-4b",
)
result3 = embedding(
"this is a donkey",
openai_url=EMBEDDING_ENDPOINT,
openai_api_key=EMBEDDING_ENDPOINT_KEY,
model="qwen3-embed-4b",
)
similarity_1 = _cosine_similarity(result1, result2)
assert similarity_1 > 0.9
similarity_2 = _cosine_similarity(result1, result3)
assert similarity_2 < 0.5