human cleanup
This commit is contained in:
BIN
image-edit.png
Normal file
BIN
image-edit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.9 MiB |
BIN
image-gen.png
Normal file
BIN
image-gen.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.7 MiB |
85
vibe_bot/config.py
Normal file
85
vibe_bot/config.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Must run before the os.getenv() lookups below so that values supplied via a
# .env file are visible to them.
load_dotenv()

# Discord
DISCORD_TOKEN = os.getenv("DISCORD_TOKEN", "")

# Endpoints
CHAT_ENDPOINT = os.getenv("CHAT_ENDPOINT", "")
COMPLETION_ENDPOINT = os.getenv("COMPLETION_ENDPOINT", "")
IMAGE_GEN_ENDPOINT = os.getenv("IMAGE_GEN_ENDPOINT", "")
IMAGE_EDIT_ENDPOINT = os.getenv("IMAGE_EDIT_ENDPOINT", "")
EMBEDDING_ENDPOINT = os.getenv("EMBEDDING_ENDPOINT", "")
MAX_COMPLETION_TOKENS = int(os.getenv("MAX_COMPLETION_TOKENS", "1000"))

# API Keys
# These default to the literal "placeholder" and are not validated below.
CHAT_ENDPOINT_KEY = os.getenv("CHAT_ENDPOINT_KEY", "placeholder")
COMPLETION_ENDPOINT_KEY = os.getenv("COMPLETION_ENDPOINT_KEY", "placeholder")
IMAGE_GEN_ENDPOINT_KEY = os.getenv("IMAGE_GEN_ENDPOINT_KEY", "placeholder")
IMAGE_EDIT_ENDPOINT_KEY = os.getenv("IMAGE_EDIT_ENDPOINT_KEY", "placeholder")
EMBEDDING_ENDPOINT_KEY = os.getenv("EMBEDDING_ENDPOINT_KEY", "placeholder")

# Models
CHAT_MODEL = os.getenv("CHAT_MODEL", "")
COMPLETION_MODEL = os.getenv("COMPLETION_MODEL", "")
IMAGE_GEN_MODEL = os.getenv("IMAGE_GEN_MODEL", "")
IMAGE_EDIT_MODEL = os.getenv("IMAGE_EDIT_MODEL", "")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "")

# Database and embeddings
DB_PATH = os.getenv("DB_PATH", "chat_history.db")
# NOTE(review): hard-coded rather than read from the environment; presumably
# it has to match the output size of EMBEDDING_MODEL - confirm.
EMBEDDING_DIMENSION = 2048
MAX_HISTORY_MESSAGES = int(os.getenv("MAX_HISTORY_MESSAGES", "1000"))
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
TOP_K_RESULTS = int(os.getenv("TOP_K_RESULTS", "5"))

# Settings that must be non-empty for the bot to start. They are checked in
# this order (token, endpoints, models) and the first empty one aborts the
# import with "<NAME> required.".
_REQUIRED_SETTINGS = {
    "DISCORD_TOKEN": DISCORD_TOKEN,
    "CHAT_ENDPOINT": CHAT_ENDPOINT,
    "COMPLETION_ENDPOINT": COMPLETION_ENDPOINT,
    "IMAGE_GEN_ENDPOINT": IMAGE_GEN_ENDPOINT,
    "IMAGE_EDIT_ENDPOINT": IMAGE_EDIT_ENDPOINT,
    "EMBEDDING_ENDPOINT": EMBEDDING_ENDPOINT,
    "CHAT_MODEL": CHAT_MODEL,
    "COMPLETION_MODEL": COMPLETION_MODEL,
    "IMAGE_GEN_MODEL": IMAGE_GEN_MODEL,
    "IMAGE_EDIT_MODEL": IMAGE_EDIT_MODEL,
    "EMBEDDING_MODEL": EMBEDDING_MODEL,
}
for _name, _value in _REQUIRED_SETTINGS.items():
    if not _value:
        raise Exception(f"{_name} required.")

# Record which endpoints are in use; only the URLs are logged, never the keys.
for _name in (
    "CHAT_ENDPOINT",
    "COMPLETION_ENDPOINT",
    "IMAGE_GEN_ENDPOINT",
    "IMAGE_EDIT_ENDPOINT",
    "EMBEDDING_ENDPOINT",
):
    logger.info("%s set to %s", _name, _REQUIRED_SETTINGS[_name])
|
||||||
@@ -1,32 +1,27 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import os
|
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import llama_wrapper # type: ignore
|
||||||
|
from config import ( # type: ignore
|
||||||
|
DB_PATH,
|
||||||
|
EMBEDDING_MODEL,
|
||||||
|
EMBEDDING_ENDPOINT,
|
||||||
|
EMBEDDING_ENDPOINT_KEY,
|
||||||
|
MAX_HISTORY_MESSAGES,
|
||||||
|
SIMILARITY_THRESHOLD,
|
||||||
|
TOP_K_RESULTS,
|
||||||
|
)
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Database configuration
|
|
||||||
DB_PATH = os.getenv("DB_PATH", "chat_history.db")
|
|
||||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "qwen3-embed-4b")
|
|
||||||
EMBEDDING_DIMENSION = 2048 # Default for qwen3-embed-4b
|
|
||||||
MAX_HISTORY_MESSAGES = int(os.getenv("MAX_HISTORY_MESSAGES", "1000"))
|
|
||||||
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
|
|
||||||
TOP_K_RESULTS = int(os.getenv("TOP_K_RESULTS", "5"))
|
|
||||||
|
|
||||||
# OpenAI configuration
|
|
||||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "placeholder")
|
|
||||||
OPENAI_API_EMBED_ENDPOINT = os.getenv(
|
|
||||||
"OPENAI_API_EMBED_ENDPOINT", "https://llama-embed.reeselink.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatDatabase:
|
class ChatDatabase:
|
||||||
"""SQLite database with RAG support for storing chat history using OpenAI embeddings."""
|
"""SQLite database with RAG support for storing chat history using OpenAI embeddings."""
|
||||||
@@ -34,7 +29,9 @@ class ChatDatabase:
|
|||||||
def __init__(self, db_path: str = DB_PATH):
|
def __init__(self, db_path: str = DB_PATH):
|
||||||
logger.info(f"Initializing ChatDatabase with path: {db_path}")
|
logger.info(f"Initializing ChatDatabase with path: {db_path}")
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
self.client = OpenAI(base_url=OPENAI_API_EMBED_ENDPOINT, api_key=OPENAI_API_KEY)
|
self.client = OpenAI(
|
||||||
|
base_url=EMBEDDING_ENDPOINT, api_key=EMBEDDING_ENDPOINT_KEY
|
||||||
|
)
|
||||||
logger.info("Connecting to OpenAI API for embeddings")
|
logger.info("Connecting to OpenAI API for embeddings")
|
||||||
self._initialize_database()
|
self._initialize_database()
|
||||||
|
|
||||||
@@ -96,36 +93,6 @@ class ChatDatabase:
|
|||||||
logger.info("Database initialization completed successfully")
|
logger.info("Database initialization completed successfully")
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def _generate_embedding(self, text: str) -> Optional[List[float]]:
    """Generate an embedding for ``text`` using ``self.client``.

    Args:
        text: The text to embed.

    Returns:
        The embedding as a list of floats, an empty list when the response
        carried no usable embedding data, or ``None`` when the request or
        the response handling raised. Both failure values are falsy.
    """
    logger.debug(f"Generating embedding for text (length: {len(text)})")
    try:
        logger.info(
            f"Calling OpenAI API to generate embedding with model: {EMBEDDING_MODEL}"
        )
        response = self.client.embeddings.create(
            model=EMBEDDING_MODEL, input=text, encoding_format="float"
        )
        logger.debug("OpenAI API response received successfully")

        # Accept both response shapes: the standard envelope that keeps the
        # results under ``.data``, and a bare list of result objects (which
        # is what this code originally indexed directly).
        items = getattr(response, "data", response)
        embedding_data = items[0].embedding

        if isinstance(embedding_data, list) and len(embedding_data) > 0:
            first_item = embedding_data[0]
            if isinstance(first_item, list):
                # Nested structure: [[values]] -> [values]
                logger.debug("Extracted embedding from nested structure [[values]]")
                return first_item
            # Direct structure: [values]
            logger.debug("Extracted embedding from direct structure [values]")
            return embedding_data

        logger.warning("Embedding data is empty or invalid")
        return []
    except Exception as e:
        # Best effort by design: a failed embedding must not abort the
        # caller, so log (with traceback) and signal failure with None.
        logger.exception("Error generating embedding: %s", e)
        return None
|
|
||||||
|
|
||||||
def _vector_to_bytes(self, vector: List[float]) -> bytes:
|
def _vector_to_bytes(self, vector: List[float]) -> bytes:
|
||||||
"""Convert vector to bytes for SQLite storage."""
|
"""Convert vector to bytes for SQLite storage."""
|
||||||
logger.debug(f"Converting vector (length: {len(vector)}) to bytes")
|
logger.debug(f"Converting vector (length: {len(vector)}) to bytes")
|
||||||
@@ -142,7 +109,9 @@ class ChatDatabase:
|
|||||||
|
|
||||||
def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Calculate cosine similarity between two vectors.

    Args:
        vec1: First vector.
        vec2: Second vector; must be compatible with ``vec1`` for ``np.dot``.

    Returns:
        ``dot(vec1, vec2) / (|vec1| * |vec2|)``, or ``0.0`` when either
        vector has zero magnitude (the cosine is undefined there and the
        plain division would produce NaN).
    """
    logger.debug(
        f"Calculating cosine similarity between vectors of dimension {len(vec1)}"
    )
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denominator == 0:
        # An all-zero embedding carries no direction; treat it as dissimilar
        # instead of dividing by zero.
        logger.debug("Zero-magnitude vector encountered, similarity set to 0.0")
        return 0.0
    result = float(np.dot(vec1, vec2) / denominator)
    logger.debug(f"Similarity calculated: {result:.4f}")
    return result
|
||||||
@@ -163,7 +132,9 @@ class ChatDatabase:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Insert message
|
# Insert message
|
||||||
logger.debug(f"Inserting message into chat_messages table: message_id={message_id}")
|
logger.debug(
|
||||||
|
f"Inserting message into chat_messages table: message_id={message_id}"
|
||||||
|
)
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
INSERT OR REPLACE INTO chat_messages
|
INSERT OR REPLACE INTO chat_messages
|
||||||
@@ -176,9 +147,16 @@ class ChatDatabase:
|
|||||||
|
|
||||||
# Generate and store embedding
|
# Generate and store embedding
|
||||||
logger.info(f"Generating embedding for message {message_id}")
|
logger.info(f"Generating embedding for message {message_id}")
|
||||||
embedding = self._generate_embedding(content)
|
embedding = llama_wrapper.embedding(
|
||||||
|
content,
|
||||||
|
openai_url=EMBEDDING_ENDPOINT,
|
||||||
|
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||||
|
model=EMBEDDING_MODEL,
|
||||||
|
)
|
||||||
if embedding:
|
if embedding:
|
||||||
logger.debug(f"Embedding generated successfully for message {message_id}, storing in database")
|
logger.debug(
|
||||||
|
f"Embedding generated successfully for message {message_id}, storing in database"
|
||||||
|
)
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
INSERT OR REPLACE INTO message_embeddings
|
INSERT OR REPLACE INTO message_embeddings
|
||||||
@@ -187,9 +165,13 @@ class ChatDatabase:
|
|||||||
""",
|
""",
|
||||||
(message_id, self._vector_to_bytes(embedding)),
|
(message_id, self._vector_to_bytes(embedding)),
|
||||||
)
|
)
|
||||||
logger.debug(f"Embedding stored in message_embeddings table for message {message_id}")
|
logger.debug(
|
||||||
|
f"Embedding stored in message_embeddings table for message {message_id}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Failed to generate embedding for message {message_id}, skipping embedding storage")
|
logger.warning(
|
||||||
|
f"Failed to generate embedding for message {message_id}, skipping embedding storage"
|
||||||
|
)
|
||||||
|
|
||||||
# Clean up old messages if exceeding limit
|
# Clean up old messages if exceeding limit
|
||||||
logger.info("Checking if cleanup of old messages is needed")
|
logger.info("Checking if cleanup of old messages is needed")
|
||||||
@@ -268,9 +250,14 @@ class ChatDatabase:
|
|||||||
query: str,
|
query: str,
|
||||||
top_k: int = TOP_K_RESULTS,
|
top_k: int = TOP_K_RESULTS,
|
||||||
min_similarity: float = SIMILARITY_THRESHOLD,
|
min_similarity: float = SIMILARITY_THRESHOLD,
|
||||||
) -> List[Tuple[str, str, str, float]]:
|
) -> List[Tuple[str, str, float]]:
|
||||||
"""Search for messages similar to the query using embeddings."""
|
"""Search for messages similar to the query using embeddings."""
|
||||||
query_embedding = self._generate_embedding(query)
|
query_embedding = llama_wrapper.embedding(
|
||||||
|
text=query,
|
||||||
|
model=EMBEDDING_MODEL,
|
||||||
|
openai_url=EMBEDDING_ENDPOINT,
|
||||||
|
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||||
|
)
|
||||||
if not query_embedding:
|
if not query_embedding:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -285,19 +272,28 @@ class ChatDatabase:
|
|||||||
SELECT cm.message_id, cm.content, me.embedding
|
SELECT cm.message_id, cm.content, me.embedding
|
||||||
FROM chat_messages cm
|
FROM chat_messages cm
|
||||||
JOIN message_embeddings me ON cm.message_id = me.message_id
|
JOIN message_embeddings me ON cm.message_id = me.message_id
|
||||||
|
WHERE cm.username != 'vibe-bot'
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
|
|
||||||
results = []
|
results: list[tuple[str, str, float]] = []
|
||||||
for message_id, content, embedding_blob in rows:
|
for message_id, content, embedding_blob in rows:
|
||||||
embedding_vector = self._bytes_to_vector(embedding_blob)
|
embedding_vector = self._bytes_to_vector(embedding_blob)
|
||||||
similarity = self._calculate_similarity(query_vector, embedding_vector)
|
similarity = self._calculate_similarity(query_vector, embedding_vector)
|
||||||
|
|
||||||
if similarity >= min_similarity:
|
if similarity >= min_similarity:
|
||||||
results.append(
|
cursor.execute(
|
||||||
(message_id, content[:500], similarity)
|
"""
|
||||||
) # Limit content length
|
SELECT content
|
||||||
|
FROM chat_messages
|
||||||
|
WHERE message_id = ?
|
||||||
|
ORDER BY timestamp DESC
|
||||||
|
""",
|
||||||
|
(f"{message_id}_response",),
|
||||||
|
)
|
||||||
|
response: str = cursor.fetchone()[0]
|
||||||
|
results.append((content, response, similarity))
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@@ -305,28 +301,48 @@ class ChatDatabase:
|
|||||||
results.sort(key=lambda x: x[2], reverse=True)
|
results.sort(key=lambda x: x[2], reverse=True)
|
||||||
return results[:top_k]
|
return results[:top_k]
|
||||||
|
|
||||||
def get_user_history(self, user_id: str, limit: int = 20) -> list[tuple[str, str]]:
    """Get recent (user message, bot response) pairs, newest first.

    Args:
        user_id: Not used by the query. Messages are selected by
            ``username != 'vibe-bot'`` rather than by user; the parameter
            is kept so existing callers keep working.
        limit: Maximum number of non-bot messages to fetch.

    Returns:
        A list of ``(user message, bot response)`` tuples. A message is
        only included when a row with id ``"<message_id>_response"``
        exists, so fewer than ``limit`` pairs may be returned.
    """
    conn = sqlite3.connect(self.db_path)
    try:
        cursor = conn.cursor()

        logger.info(f"Fetching last {limit} user messages")
        cursor.execute(
            """
            SELECT message_id, content, timestamp
            FROM chat_messages
            WHERE username != 'vibe-bot'
            ORDER BY timestamp DESC
            LIMIT ?
            """,
            (limit,),
        )
        messages = cursor.fetchall()

        # Format is [user message, bot response]
        conversations: list[tuple[str, str]] = []
        for message in messages:
            msg_content: str = message[1]
            logger.info(f"Finding response for {msg_content[:50]}")
            # The bot's reply is looked up under the user's message id
            # with a "_response" suffix.
            cursor.execute(
                """
                SELECT content
                FROM chat_messages
                WHERE message_id = ?
                ORDER BY timestamp DESC
                """,
                (f"{message[0]}_response",),
            )
            response_row = cursor.fetchone()  # one-column row, or None
            if response_row is not None:
                logger.info(f"Found response: {response_row[0][:50]}")
                conversations.append((msg_content, response_row[0]))
            else:
                logger.info("No response found")
    finally:
        # Close the connection even when a query raises.
        conn.close()

    return conversations
|
||||||
|
|
||||||
def get_conversation_context(
|
def get_conversation_context(
|
||||||
self, user_id: str, current_message: str, max_context: int = 5
|
self, user_id: str, current_message: str, max_context: int = 5
|
||||||
@@ -344,15 +360,19 @@ class ChatDatabase:
|
|||||||
context_parts = []
|
context_parts = []
|
||||||
|
|
||||||
# Add recent messages
|
# Add recent messages
|
||||||
for message_id, content, timestamp in recent_messages:
|
for user_message, bot_message in recent_messages:
|
||||||
context_parts.append(f"[{timestamp}] User: {content}")
|
combined_content = f"[Recent chat]\n{user_message}\n{bot_message}"
|
||||||
|
context_parts.append(combined_content)
|
||||||
|
|
||||||
# Add similar messages
|
# Add similar messages
|
||||||
for message_id, content, similarity in similar_messages:
|
for user_message, bot_message, similarity in similar_messages:
|
||||||
if f"[{content}" not in "\n".join(context_parts): # Avoid duplicates
|
combined_content = f"{user_message}\n{bot_message}"
|
||||||
context_parts.append(f"[Similar] {content}")
|
if combined_content not in "\n".join(context_parts):
|
||||||
|
context_parts.append(f"[You remember]\n{combined_content}")
|
||||||
|
|
||||||
return "\n".join(context_parts[-max_context * 2 :]) # Limit total context
|
# Conversation history needs to be delivered in "newest context last" order
|
||||||
|
context_parts.reverse()
|
||||||
|
return "\n".join(context_parts[-max_context * 4 :]) # Limit total context
|
||||||
|
|
||||||
def clear_all_messages(self):
|
def clear_all_messages(self):
|
||||||
"""Clear all messages and embeddings from the database."""
|
"""Clear all messages and embeddings from the database."""
|
||||||
@@ -390,6 +410,7 @@ class CustomBotManager:
|
|||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# Create table to hold custom bots
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS custom_bots (
|
CREATE TABLE IF NOT EXISTS custom_bots (
|
||||||
@@ -399,7 +420,7 @@ class CustomBotManager:
|
|||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
is_active INTEGER DEFAULT 1
|
is_active INTEGER DEFAULT 1
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
@@ -461,8 +482,9 @@ class CustomBotManager:
|
|||||||
if user_id:
|
if user_id:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
SELECT bot_name, system_prompt, created_by
|
SELECT bot_name, system_prompt, name
|
||||||
FROM custom_bots
|
FROM custom_bots cb, username_map um
|
||||||
|
JOIN username_map ON custom_bots.created_by = username_map.id
|
||||||
WHERE is_active = 1
|
WHERE is_active = 1
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -5,9 +5,12 @@
|
|||||||
import openai
|
import openai
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
|
from openai._types import FileTypes, SequenceNotStr
|
||||||
|
from typing import Union
|
||||||
|
from io import BufferedReader, BytesIO
|
||||||
|
|
||||||
|
|
||||||
def chat_completion_think(
|
def chat_completion(
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
user_prompt: str,
|
user_prompt: str,
|
||||||
openai_url: str,
|
openai_url: str,
|
||||||
@@ -80,35 +83,56 @@ def chat_completion_instruct(
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def image_generation(prompt: str, openai_url: str, openai_api_key: str, n=1) -> str:
    """Generate a 1024x1024 image for ``prompt`` and return it base64 encoded.

    Args:
        prompt: Text prompt sent to the image endpoint.
        openai_url: Base URL of the OpenAI-compatible image endpoint.
        openai_api_key: API key used for the request.
        n: Number of images requested; only the first one is returned.

    Returns:
        str: The base64 encoded data of the first image, or an empty string
        when the response has no data or no ``b64_json`` payload.
    """
    api = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
    result = api.images.generate(
        prompt=prompt,
        n=n,
        size="1024x1024",
    )
    images = result.data
    if not images:
        return ""
    return images[0].b64_json or ""
|
||||||
|
|
||||||
|
|
||||||
def image_edit(
    image: BufferedReader | BytesIO,
    prompt: str,
    openai_url: str,
    openai_api_key: str,
    n=1,
) -> str:
    """Edit ``image`` according to ``prompt`` and return the result base64 encoded.

    Args:
        image: Open binary stream holding the source image.
        prompt: Text describing the desired edit.
        openai_url: Base URL of the OpenAI-compatible image-edit endpoint.
        openai_api_key: API key used for the request.
        n: Number of images requested; only the first one is returned.

    Returns:
        str: The base64 encoded data of the first edited image (requested
        at 1024x1024), or an empty string when the response has no data or
        no ``b64_json`` payload.
    """
    api = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
    result = api.images.edit(
        image=image,
        prompt=prompt,
        n=n,
        size="1024x1024",
    )
    images = result.data
    if not images:
        return ""
    return images[0].b64_json or ""
|
||||||
|
|
||||||
|
|
||||||
def embedding(
    text: str, openai_url: str, openai_api_key: str, model: str
) -> list[float]:
    """Return the embedding vector for ``text`` from an OpenAI-compatible endpoint.

    Args:
        text: The text to embed.
        openai_url: Base URL of the embeddings endpoint.
        openai_api_key: API key used for the request.
        model: Name of the embedding model to request.

    Returns:
        The embedding as a flat list of floats, or an empty list when the
        response contains no embedding data.
    """
    client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
    response = client.embeddings.create(
        input=[text], model=model, encoding_format="float"
    )
    if not response:
        return []

    # Accept both response shapes: the standard envelope that keeps results
    # under ``.data``, and a bare list of result objects.
    items = getattr(response, "data", response)
    if not items:
        return []

    raw_data = items[0].embedding  # type: ignore
    if not raw_data:
        return []

    # The result could be an array of floats or an array of an array of
    # floats. Check the element type explicitly: indexing a flat list never
    # raises, so a try/except here would return a single float.
    if isinstance(raw_data[0], (list, tuple)):
        return list(raw_data[0])
    return list(raw_data)
|
||||||
|
|||||||
248
vibe_bot/main.py
248
vibe_bot/main.py
@@ -5,7 +5,19 @@ import base64
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import logging
|
import logging
|
||||||
from database import get_database, CustomBotManager
|
from database import get_database, CustomBotManager # type: ignore
|
||||||
|
from config import ( # type: ignore
|
||||||
|
CHAT_ENDPOINT_KEY,
|
||||||
|
DISCORD_TOKEN,
|
||||||
|
CHAT_ENDPOINT,
|
||||||
|
CHAT_MODEL,
|
||||||
|
IMAGE_EDIT_ENDPOINT_KEY,
|
||||||
|
IMAGE_GEN_ENDPOINT,
|
||||||
|
IMAGE_EDIT_ENDPOINT,
|
||||||
|
MAX_COMPLETION_TOKENS,
|
||||||
|
)
|
||||||
|
import llama_wrapper # type: ignore
|
||||||
|
import requests
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -13,31 +25,11 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DISCORD_TOKEN = os.getenv("DISCORD_TOKEN", "placeholder")
|
|
||||||
|
|
||||||
OPENAI_API_ENDPOINT = os.getenv("OPENAI_API_ENDPOINT")
|
|
||||||
IMAGE_GEN_ENDPOINT = os.getenv("IMAGE_GEN_ENDPOINT")
|
|
||||||
IMAGE_EDIT_ENDPOINT = os.getenv("IMAGE_EDIT_ENDPOINT")
|
|
||||||
MAX_COMPLETION_TOKENS = int(os.getenv("MAX_COMPLETION_TOKENS", "1000"))
|
|
||||||
|
|
||||||
if not OPENAI_API_ENDPOINT:
|
|
||||||
raise Exception("OPENAI_API_ENDPOINT required.")
|
|
||||||
|
|
||||||
if not IMAGE_GEN_ENDPOINT:
|
|
||||||
raise Exception("IMAGE_GEN_ENDPOINT required.")
|
|
||||||
|
|
||||||
# Set your OpenAI API key as an environment variable
|
|
||||||
# You can also pass it directly but environment variables are safer
|
|
||||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "placeholder")
|
|
||||||
|
|
||||||
# Initialize the bot
|
# Initialize the bot
|
||||||
intents = discord.Intents.default()
|
intents = discord.Intents.default()
|
||||||
intents.message_content = True
|
intents.message_content = True
|
||||||
bot = commands.Bot(command_prefix="!", intents=intents)
|
bot = commands.Bot(command_prefix="!", intents=intents)
|
||||||
|
|
||||||
# OpenAI Completions API endpoint
|
|
||||||
OPENAI_COMPLETIONS_URL = f"{OPENAI_API_ENDPOINT}/chat/completions"
|
|
||||||
|
|
||||||
|
|
||||||
@bot.event
|
@bot.event
|
||||||
async def on_ready():
|
async def on_ready():
|
||||||
@@ -46,7 +38,7 @@ async def on_ready():
|
|||||||
logger.info(f"Bot logged in as {bot.user}")
|
logger.info(f"Bot logged in as {bot.user}")
|
||||||
|
|
||||||
|
|
||||||
@bot.command(name="custom-bot")
|
@bot.command(name="custom-bot") # type: ignore
|
||||||
async def custom_bot(ctx, bot_name: str, *, personality: str):
|
async def custom_bot(ctx, bot_name: str, *, personality: str):
|
||||||
"""Create a custom bot with a name and personality
|
"""Create a custom bot with a name and personality
|
||||||
|
|
||||||
@@ -129,14 +121,14 @@ async def list_custom_bots(ctx):
|
|||||||
f"Found {len(bots)} custom bots, displaying top 10 for {ctx.author.name}"
|
f"Found {len(bots)} custom bots, displaying top 10 for {ctx.author.name}"
|
||||||
)
|
)
|
||||||
bot_list = "🤖 **Available Custom Bots**:\n\n"
|
bot_list = "🤖 **Available Custom Bots**:\n\n"
|
||||||
for name, prompt, creator in bots[:10]: # Limit to 10 bots
|
for name, prompt, creator in bots:
|
||||||
bot_list += f"• **{name}** (created by {creator})\n"
|
bot_list += f"• **{name}**\n"
|
||||||
|
|
||||||
logger.info(f"Sending bot list response to {ctx.author.name}")
|
logger.info(f"Sending bot list response to {ctx.author.name}")
|
||||||
await ctx.send(bot_list)
|
await ctx.send(bot_list)
|
||||||
|
|
||||||
|
|
||||||
@bot.command(name="delete-custom-bot")
|
@bot.command(name="delete-custom-bot") # type: ignore
|
||||||
async def delete_custom_bot(ctx, bot_name: str):
|
async def delete_custom_bot(ctx, bot_name: str):
|
||||||
"""Delete a custom bot (only the creator can delete)
|
"""Delete a custom bot (only the creator can delete)
|
||||||
|
|
||||||
@@ -194,16 +186,16 @@ async def on_message(message):
|
|||||||
if message.author == bot.user:
|
if message.author == bot.user:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
message_author = message.author.name
|
||||||
|
message_content = message.content.lower()
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Processing message from {message.author.name}: '{message.content[:50]}...'"
|
f"Processing message from {message_author}: '{message_content[:50]}...'"
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx = await bot.get_context(message)
|
ctx = await bot.get_context(message)
|
||||||
|
|
||||||
# Check if the message starts with a custom bot command
|
logger.info("Initializing CustomBotManager to check for custom bot commands")
|
||||||
content = message.content.lower()
|
|
||||||
|
|
||||||
logger.info(f"Initializing CustomBotManager to check for custom bot commands")
|
|
||||||
custom_bot_manager = CustomBotManager()
|
custom_bot_manager = CustomBotManager()
|
||||||
|
|
||||||
logger.info("Fetching list of custom bots to check for matching commands")
|
logger.info("Fetching list of custom bots to check for matching commands")
|
||||||
@@ -212,7 +204,7 @@ async def on_message(message):
|
|||||||
logger.info(f"Checking {len(custom_bots)} custom bots for command match")
|
logger.info(f"Checking {len(custom_bots)} custom bots for command match")
|
||||||
for bot_name, system_prompt, _ in custom_bots:
|
for bot_name, system_prompt, _ in custom_bots:
|
||||||
# Check if message starts with the custom bot name followed by a space
|
# Check if message starts with the custom bot name followed by a space
|
||||||
if content.startswith(f"!{bot_name} "):
|
if message_content.startswith(f"!{bot_name} "):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Custom bot command detected: '{bot_name}' triggered by {message.author.name}"
|
f"Custom bot command detected: '{bot_name}' triggered by {message.author.name}"
|
||||||
)
|
)
|
||||||
@@ -224,25 +216,14 @@ async def on_message(message):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Prepare the payload with custom personality
|
# Prepare the payload with custom personality
|
||||||
payload = {
|
|
||||||
"model": "qwen3-vl-30b-a3b-instruct",
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": system_prompt,
|
|
||||||
},
|
|
||||||
{"role": "user", "content": user_message},
|
|
||||||
],
|
|
||||||
"max_completion_tokens": MAX_COMPLETION_TOKENS,
|
|
||||||
}
|
|
||||||
|
|
||||||
response_prefix = f"**{bot_name} response**"
|
response_prefix = f"**{bot_name} response**"
|
||||||
|
|
||||||
logger.info(f"Sending request to OpenAI API for bot '{bot_name}'")
|
logger.info(f"Sending request to OpenAI API for bot '{bot_name}'")
|
||||||
await handle_chat(
|
await handle_chat(
|
||||||
ctx=ctx,
|
ctx=ctx,
|
||||||
|
bot_name=bot_name,
|
||||||
message=user_message,
|
message=user_message,
|
||||||
payload=payload,
|
system_prompt=system_prompt,
|
||||||
response_prefix=response_prefix,
|
response_prefix=response_prefix,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -258,24 +239,22 @@ async def doodlebob(ctx, *, message: str):
|
|||||||
logger.info(f"Doodlebob command triggered by {ctx.author.name}: {message[:100]}")
|
logger.info(f"Doodlebob command triggered by {ctx.author.name}: {message[:100]}")
|
||||||
await ctx.send(f"**Doodlebob erasing {message[:100]}...**")
|
await ctx.send(f"**Doodlebob erasing {message[:100]}...**")
|
||||||
|
|
||||||
image_prompt_payload = {
|
system_prompt = (
|
||||||
"model": "qwen3-vl-30b-a3b-instruct",
|
"Given the following message, convert it to a detailed image generation prompt that will be passed directly into an image generation model."
|
||||||
"messages": [
|
"If told to generate an image of yourself, generate a picture of a rat. If told to generate a picture of 'me', 'myself', or some other self"
|
||||||
{
|
" reference, generate a picture of a rat. Only respond with a valid image generation prompt, do not affirm the user or respond to the user's"
|
||||||
"role": "system",
|
" questions."
|
||||||
"content": (
|
)
|
||||||
"Given the following message, convert it to a detailed image generation prompt that will be passed directly into an image generation model."
|
|
||||||
"If told to generate an image of yourself, generate a picture of a rat. If told to generate a picture of 'me', 'myself', or some other self"
|
|
||||||
" reference, generate a picture of a rat. Only respond with a valid image generation prompt, do not affirm the user or respond to the user's"
|
|
||||||
" questions."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{"role": "user", "content": message},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Wait for the generated image prompt
|
# Wait for the generated image prompt
|
||||||
image_prompt = await call_llm(ctx, image_prompt_payload)
|
image_prompt = llama_wrapper.chat_completion_instruct(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_prompt=message,
|
||||||
|
openai_url=CHAT_ENDPOINT,
|
||||||
|
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||||
|
model=CHAT_MODEL,
|
||||||
|
max_tokens=MAX_COMPLETION_TOKENS,
|
||||||
|
)
|
||||||
|
|
||||||
# If the string is empty we had an error
|
# If the string is empty we had an error
|
||||||
if image_prompt == "":
|
if image_prompt == "":
|
||||||
@@ -285,32 +264,16 @@ async def doodlebob(ctx, *, message: str):
|
|||||||
# Alert the user we're generating the image
|
# Alert the user we're generating the image
|
||||||
await ctx.send(f"**Doodlebob calling drone strike on {image_prompt[:100]}...**")
|
await ctx.send(f"**Doodlebob calling drone strike on {image_prompt[:100]}...**")
|
||||||
|
|
||||||
# Create the image prompt payload
|
image_b64 = llama_wrapper.image_generation(
|
||||||
image_payload = {
|
prompt=message,
|
||||||
"model": "default",
|
openai_url=IMAGE_EDIT_ENDPOINT,
|
||||||
"prompt": image_prompt,
|
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
|
||||||
"n": 1,
|
|
||||||
"size": "1024x1024",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Call the image generation endpoint
|
|
||||||
response = requests.post(
|
|
||||||
f"{IMAGE_GEN_ENDPOINT}/images/generations",
|
|
||||||
json=image_payload,
|
|
||||||
timeout=120,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code == 200:
|
# Save the image to a file
|
||||||
result = response.json()
|
edited_image_data = BytesIO(base64.b64decode(image_b64))
|
||||||
# Send image
|
send_img = discord.File(edited_image_data, filename="image.png")
|
||||||
image_data = BytesIO(base64.b64decode(result["data"][0]["b64_json"]))
|
await ctx.send(file=send_img)
|
||||||
send_img = discord.File(image_data, filename="image.png")
|
|
||||||
await ctx.send(file=send_img)
|
|
||||||
|
|
||||||
else:
|
|
||||||
print(f"❌ Error: {response.status_code}")
|
|
||||||
print(response.text)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@bot.command(name="retcon")
|
@bot.command(name="retcon")
|
||||||
@@ -321,31 +284,23 @@ async def retcon(ctx, *, message: str):
|
|||||||
|
|
||||||
await ctx.send(f"**Rewriting history to match {message[:100]}...**")
|
await ctx.send(f"**Rewriting history to match {message[:100]}...**")
|
||||||
|
|
||||||
client = OpenAI(base_url=IMAGE_EDIT_ENDPOINT, api_key=OPENAI_API_KEY)
|
image_b64 = llama_wrapper.image_edit(
|
||||||
|
image=image_bytestream,
|
||||||
result = client.images.edit(
|
|
||||||
model="placeholder",
|
|
||||||
image=[image_bytestream],
|
|
||||||
prompt=message,
|
prompt=message,
|
||||||
size="1024x1024",
|
openai_url=IMAGE_EDIT_ENDPOINT,
|
||||||
|
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_base64 = result.data[0].b64_json
|
|
||||||
image_bytes = base64.b64decode(image_base64)
|
|
||||||
|
|
||||||
# Save the image to a file
|
# Save the image to a file
|
||||||
edited_image_data = BytesIO(image_bytes)
|
edited_image_data = BytesIO(base64.b64decode(image_b64))
|
||||||
send_img = discord.File(edited_image_data, filename="image.png")
|
send_img = discord.File(edited_image_data, filename="image.png")
|
||||||
await ctx.send(file=send_img)
|
await ctx.send(file=send_img)
|
||||||
|
|
||||||
|
|
||||||
async def handle_chat(ctx, *, message: str, payload: dict, response_prefix: str):
|
async def handle_chat(
|
||||||
# Check if API key is set
|
ctx, *, bot_name: str, message: str, system_prompt: str, response_prefix: str
|
||||||
if not OPENAI_API_KEY:
|
):
|
||||||
await ctx.send(
|
await ctx.send(f"{bot_name} is searching its databanks for {message[:50]}...")
|
||||||
"Error: OpenAI API key is not configured. Please set the OPENAI_API_KEY environment variable."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get database instance
|
# Get database instance
|
||||||
db = get_database()
|
db = get_database()
|
||||||
@@ -356,32 +311,27 @@ async def handle_chat(ctx, *, message: str, payload: dict, response_prefix: str)
|
|||||||
)
|
)
|
||||||
|
|
||||||
if context:
|
if context:
|
||||||
payload["messages"][0][
|
user_message = f"\n\nRelevant conversation history:\n{context}\n\n{message}"
|
||||||
"content"
|
else:
|
||||||
] += f"\n\nRelevant conversation history:\n{context}"
|
user_message = message
|
||||||
|
|
||||||
payload["messages"][1]["content"] = message
|
logger.info(user_message)
|
||||||
|
|
||||||
print(payload)
|
system_prompt_edit = (
|
||||||
|
"Keep your responses somewhat short, limited to 500 words or less. "
|
||||||
|
f"{system_prompt}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize OpenAI client
|
bot_response = llama_wrapper.chat_completion_instruct(
|
||||||
client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_ENDPOINT)
|
system_prompt=system_prompt_edit,
|
||||||
|
user_prompt=user_message,
|
||||||
# Call OpenAI API
|
openai_url=CHAT_ENDPOINT,
|
||||||
response = client.chat.completions.create(
|
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||||
model=payload["model"],
|
model=CHAT_MODEL,
|
||||||
messages=payload["messages"],
|
max_tokens=MAX_COMPLETION_TOKENS,
|
||||||
max_completion_tokens=MAX_COMPLETION_TOKENS,
|
|
||||||
frequency_penalty=1.5,
|
|
||||||
presence_penalty=1.5,
|
|
||||||
temperature=1,
|
|
||||||
seed=-1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract the generated text
|
|
||||||
generated_text = response.choices[0].message.content.strip()
|
|
||||||
|
|
||||||
# Store both user message and bot response in the database
|
# Store both user message and bot response in the database
|
||||||
db.add_message(
|
db.add_message(
|
||||||
message_id=f"{ctx.message.id}",
|
message_id=f"{ctx.message.id}",
|
||||||
@@ -394,68 +344,24 @@ async def handle_chat(ctx, *, message: str, payload: dict, response_prefix: str)
|
|||||||
|
|
||||||
db.add_message(
|
db.add_message(
|
||||||
message_id=f"{ctx.message.id}_response",
|
message_id=f"{ctx.message.id}_response",
|
||||||
user_id=str(bot.user.id),
|
user_id=str(bot.user.id), # type: ignore
|
||||||
username=bot.user.name,
|
username=bot.user.name, # type: ignore
|
||||||
content=f"Bot: {generated_text}",
|
content=f"Bot: {bot_response}",
|
||||||
channel_id=str(ctx.channel.id),
|
channel_id=str(ctx.channel.id),
|
||||||
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send the response back to the chat
|
# Send the response back to the chat
|
||||||
await ctx.send(response_prefix)
|
await ctx.send(response_prefix)
|
||||||
while generated_text:
|
while bot_response:
|
||||||
send_chunk = generated_text[:1000]
|
send_chunk = bot_response[:1000]
|
||||||
generated_text = generated_text[1000:]
|
bot_response = bot_response[1000:]
|
||||||
await ctx.send(send_chunk)
|
await ctx.send(send_chunk)
|
||||||
|
|
||||||
except requests.exceptions.HTTPError as e:
|
|
||||||
await ctx.send(f"Error: OpenAI API error - {e}")
|
|
||||||
except requests.exceptions.Timeout:
|
|
||||||
await ctx.send("Error: Request timed out. Please try again.")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await ctx.send(f"Error: {str(e)}")
|
await ctx.send(f"Error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
async def call_llm(ctx, payload: dict) -> str:
|
|
||||||
# Check if API key is set
|
|
||||||
if not OPENAI_API_KEY:
|
|
||||||
await ctx.send(
|
|
||||||
"Error: OpenAI API key is not configured. Please set the OPENAI_API_KEY environment variable."
|
|
||||||
)
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# Set headers
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {OPENAI_API_KEY}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Initialize OpenAI client
|
|
||||||
client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_ENDPOINT)
|
|
||||||
|
|
||||||
# Call OpenAI API
|
|
||||||
response = client.chat.completions.create(
|
|
||||||
model=payload["model"],
|
|
||||||
messages=payload["messages"],
|
|
||||||
max_tokens=MAX_COMPLETION_TOKENS,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract the generated text
|
|
||||||
generated_text = response.choices[0].message.content.strip()
|
|
||||||
print(generated_text)
|
|
||||||
|
|
||||||
return generated_text
|
|
||||||
|
|
||||||
except requests.exceptions.HTTPError as e:
|
|
||||||
await ctx.send(f"Error: OpenAI API error - {e}")
|
|
||||||
except requests.exceptions.Timeout:
|
|
||||||
await ctx.send("Error: Request timed out. Please try again.")
|
|
||||||
except Exception as e:
|
|
||||||
await ctx.send(f"Error: {str(e)}")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
# Run the bot
|
# Run the bot
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
bot.run(DISCORD_TOKEN)
|
bot.run(DISCORD_TOKEN)
|
||||||
|
|||||||
@@ -1,30 +0,0 @@
|
|||||||
import os
|
|
||||||
import pytest
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
# Try to load .env.test first, fallback to .env
|
|
||||||
env_test_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), '.env.test')
|
|
||||||
env_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), '.env')
|
|
||||||
|
|
||||||
if os.path.exists(env_test_path):
|
|
||||||
load_dotenv(env_test_path)
|
|
||||||
print("✓ Loaded environment variables from .env.test")
|
|
||||||
elif os.path.exists(env_path):
|
|
||||||
load_dotenv(env_path)
|
|
||||||
print("✓ Loaded environment variables from .env")
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True, scope="session")
|
|
||||||
def verify_env_loaded():
|
|
||||||
"""Verify critical environment variables are loaded before tests run"""
|
|
||||||
required_vars = [
|
|
||||||
"DISCORD_TOKEN",
|
|
||||||
"OPENAI_API_ENDPOINT",
|
|
||||||
"IMAGE_GEN_ENDPOINT",
|
|
||||||
"IMAGE_EDIT_ENDPOINT"
|
|
||||||
]
|
|
||||||
|
|
||||||
missing_vars = [var for var in required_vars if var not in os.environ]
|
|
||||||
if missing_vars:
|
|
||||||
pytest.fail(f"Missing required environment variables: {', '.join(missing_vars)}")
|
|
||||||
|
|
||||||
yield
|
|
||||||
@@ -1,71 +1,112 @@
|
|||||||
# Tests all functions in the llama-wrapper.py file
|
# Tests all functions in the llama-wrapper.py file
|
||||||
# Run with: python -m pytest test_llama_wrapper.py -v
|
# Run with: python -m pytest test_llama_wrapper.py -v
|
||||||
|
|
||||||
from discord import message
|
|
||||||
import pytest
|
|
||||||
from ..llama_wrapper import (
|
from ..llama_wrapper import (
|
||||||
chat_completion_think,
|
chat_completion,
|
||||||
chat_completion_instruct,
|
chat_completion_instruct,
|
||||||
image_generation,
|
image_generation,
|
||||||
image_edit,
|
image_edit,
|
||||||
embeddings,
|
embedding,
|
||||||
)
|
)
|
||||||
from dotenv import load_dotenv
|
from ..config import (
|
||||||
import os
|
CHAT_ENDPOINT,
|
||||||
|
CHAT_MODEL,
|
||||||
OPENAI_API_CHAT_ENDPOINT = os.getenv(
|
CHAT_ENDPOINT_KEY,
|
||||||
"OPENAI_API_CHAT_ENDPOINT", "https://llama-cpp.reeselink.com"
|
IMAGE_EDIT_ENDPOINT,
|
||||||
|
IMAGE_EDIT_ENDPOINT_KEY,
|
||||||
|
IMAGE_GEN_ENDPOINT,
|
||||||
|
IMAGE_GEN_ENDPOINT_KEY,
|
||||||
|
EMBEDDING_ENDPOINT,
|
||||||
|
EMBEDDING_ENDPOINT_KEY,
|
||||||
)
|
)
|
||||||
OPENAI_API_IMAGE_ENDPOINT = os.getenv("OPENAI_API_IMAGE_ENDPOINT")
|
from io import BytesIO
|
||||||
OPENAI_API_EDIT_ENDPOINT = os.getenv("OPENAI_API_EDIT_ENDPOINT")
|
import base64
|
||||||
OPENAI_API_EMBED_ENDPOINT = os.getenv("OPENAI_API_EMBED_ENDPOINT")
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# Default models
|
|
||||||
DEFAULT_CHAT_MODEL = os.getenv("DEFAULT_CHAT_MODEL", "qwen3.5-35b-a3b")
|
TEMPDIR = Path(tempfile.mkdtemp())
|
||||||
DEFAULT_EMBED_MODEL = os.getenv("DEFAULT_EMBED_MODEL", "text-embedding-3-small")
|
|
||||||
DEFAULT_IMAGE_MODEL = os.getenv("DEFAULT_IMAGE_MODEL", "dall-e-3")
|
|
||||||
DEFAULT_EDIT_MODEL = os.getenv("DEFAULT_EDIT_MODEL", "dall-e-2")
|
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion_think():
|
def test_chat_completion_think():
|
||||||
# This test will fail without an actual API endpoint
|
result = chat_completion(
|
||||||
# But it's here to show the structure
|
|
||||||
chat_completion_think(
|
|
||||||
system_prompt="You are a helpful assistant.",
|
system_prompt="You are a helpful assistant.",
|
||||||
user_prompt="Tell me about Everquest",
|
user_prompt="Tell me about Everquest",
|
||||||
openai_url=OPENAI_API_CHAT_ENDPOINT,
|
openai_url=CHAT_ENDPOINT,
|
||||||
openai_api_key="placeholder",
|
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||||
model=DEFAULT_CHAT_MODEL,
|
model=CHAT_MODEL,
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
)
|
)
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion_instruct():
|
def test_chat_completion_instruct():
|
||||||
# This test will fail without an actual API endpoint
|
result = chat_completion_instruct(
|
||||||
# But it's here to show the structure
|
|
||||||
chat_completion_instruct(
|
|
||||||
system_prompt="You are a helpful assistant.",
|
system_prompt="You are a helpful assistant.",
|
||||||
user_prompt="Tell me about Everquest",
|
user_prompt="Tell me about Everquest",
|
||||||
openai_url=OPENAI_API_CHAT_ENDPOINT,
|
openai_url=CHAT_ENDPOINT,
|
||||||
openai_api_key="placeholder",
|
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||||
model=DEFAULT_CHAT_MODEL,
|
model=CHAT_MODEL,
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
)
|
)
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
|
||||||
def test_image_generation():
|
def test_image_generation():
|
||||||
# This test will fail without an actual API endpoint
|
result = image_generation(
|
||||||
# But it's here to show the structure
|
prompt="Generate an image of a horse",
|
||||||
pass
|
openai_url=IMAGE_GEN_ENDPOINT,
|
||||||
|
openai_api_key=IMAGE_GEN_ENDPOINT_KEY,
|
||||||
|
)
|
||||||
|
with open("image-gen.png", "wb") as f:
|
||||||
|
f.write(base64.b64decode(result))
|
||||||
|
|
||||||
|
|
||||||
def test_image_edit():
|
def test_image_edit():
|
||||||
# This test will fail without an actual API endpoint
|
with open("image-gen.png", "rb") as f:
|
||||||
# But it's here to show the structure
|
image_data = BytesIO(f.read())
|
||||||
pass
|
result = image_edit(
|
||||||
|
image=image_data,
|
||||||
|
prompt="Paint the words 'horse' on the horse.",
|
||||||
|
openai_url=IMAGE_EDIT_ENDPOINT,
|
||||||
|
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
|
||||||
|
)
|
||||||
|
with open("image-edit.png", "wb") as f:
|
||||||
|
f.write(base64.b64decode(result))
|
||||||
|
|
||||||
|
|
||||||
|
def _cosine_similarity(a, b):
|
||||||
|
"""
|
||||||
|
Close to 1: very similar
|
||||||
|
Close to 0: orthogonal
|
||||||
|
Close to -1: opposite
|
||||||
|
"""
|
||||||
|
a, b = np.array(a), np.array(b)
|
||||||
|
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||||
|
|
||||||
|
|
||||||
def test_embeddings():
|
def test_embeddings():
|
||||||
# This test will fail without an actual API endpoint
|
result1 = embedding(
|
||||||
# But it's here to show the structure
|
"this is a horse",
|
||||||
pass
|
openai_url=EMBEDDING_ENDPOINT,
|
||||||
|
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||||
|
model="qwen3-embed-4b",
|
||||||
|
)
|
||||||
|
result2 = embedding(
|
||||||
|
"this is a horse also",
|
||||||
|
openai_url=EMBEDDING_ENDPOINT,
|
||||||
|
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||||
|
model="qwen3-embed-4b",
|
||||||
|
)
|
||||||
|
result3 = embedding(
|
||||||
|
"this is a donkey",
|
||||||
|
openai_url=EMBEDDING_ENDPOINT,
|
||||||
|
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||||
|
model="qwen3-embed-4b",
|
||||||
|
)
|
||||||
|
similarity_1 = _cosine_similarity(result1, result2)
|
||||||
|
assert similarity_1 > 0.9
|
||||||
|
|
||||||
|
similarity_2 = _cosine_similarity(result1, result3)
|
||||||
|
assert similarity_2 < 0.5
|
||||||
|
|||||||
Reference in New Issue
Block a user