import logging
import sqlite3
from contextlib import closing
from datetime import datetime
from typing import List, Optional, Tuple

import numpy as np
from openai import OpenAI

import llama_wrapper  # type: ignore
from config import (  # type: ignore
    DB_PATH,
    EMBEDDING_MODEL,
    EMBEDDING_ENDPOINT,
    EMBEDDING_ENDPOINT_KEY,
    MAX_HISTORY_MESSAGES,
    SIMILARITY_THRESHOLD,
    TOP_K_RESULTS,
)

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class ChatDatabase:
    """SQLite database with RAG support for storing chat history using OpenAI embeddings.

    Messages are stored in ``chat_messages``; their embedding vectors are stored
    as raw float32 bytes in ``message_embeddings`` keyed by ``message_id``.
    The lookup methods below expect a bot reply to be stored as its own row
    whose ``message_id`` is the user message's id with a ``"_response"`` suffix.

    Every method opens and closes its own connection, so an instance holds no
    open database handle between calls.
    """

    def __init__(self, db_path: str = DB_PATH):
        """Create the embedding client and ensure the schema exists.

        Args:
            db_path: Path of the SQLite database file.
        """
        logger.info(f"Initializing ChatDatabase with path: {db_path}")
        self.db_path = db_path
        # NOTE(review): the methods in this class request embeddings through
        # llama_wrapper, not this client; it is kept in case other code uses it.
        self.client = OpenAI(
            base_url=EMBEDDING_ENDPOINT, api_key=EMBEDDING_ENDPOINT_KEY
        )
        logger.info("Connecting to OpenAI API for embeddings")
        self._initialize_database()

    def _initialize_database(self):
        """Create the chat tables and their indexes if they do not exist yet."""
        logger.info(f"Initializing SQLite database at {self.db_path}")
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()

            # Create messages table
            logger.info("Creating chat_messages table if not exists")
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS chat_messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    message_id TEXT UNIQUE,
                    user_id TEXT,
                    username TEXT,
                    content TEXT,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    channel_id TEXT,
                    guild_id TEXT
                )
                """
            )
            logger.info("chat_messages table initialized successfully")

            # Create embeddings table for RAG
            logger.info("Creating message_embeddings table if not exists")
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS message_embeddings (
                    message_id TEXT PRIMARY KEY,
                    embedding BLOB,
                    FOREIGN KEY (message_id) REFERENCES chat_messages(message_id)
                )
                """
            )
            logger.info("message_embeddings table initialized successfully")

            # Create index for faster lookups
            logger.info("Creating idx_timestamp index if not exists")
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_timestamp
                ON chat_messages(timestamp)
                """
            )
            logger.info("idx_timestamp index created successfully")

            logger.info("Creating idx_user_id index if not exists")
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_user_id
                ON chat_messages(user_id)
                """
            )
            logger.info("idx_user_id index created successfully")

            conn.commit()
            logger.info("Database initialization completed successfully")

    def _vector_to_bytes(self, vector: List[float]) -> bytes:
        """Convert vector to float32 bytes for SQLite storage."""
        logger.debug(f"Converting vector (length: {len(vector)}) to bytes")
        result = np.array(vector, dtype=np.float32).tobytes()
        logger.debug(f"Vector converted to {len(result)} bytes")
        return result

    def _bytes_to_vector(self, blob: bytes) -> np.ndarray:
        """Convert float32 bytes (as written by ``_vector_to_bytes``) back to a vector."""
        logger.debug(f"Converting {len(blob)} bytes back to vector")
        result = np.frombuffer(blob, dtype=np.float32)
        logger.debug(f"Vector reconstructed with {len(result)} dimensions")
        return result

    def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Return the cosine similarity of two equal-length vectors.

        An all-zero vector has no direction, so its similarity is defined as
        0.0 here rather than producing NaN from a division by zero.
        """
        logger.debug(
            f"Calculating cosine similarity between vectors of dimension {len(vec1)}"
        )
        denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if denominator == 0:
            logger.debug("Zero-norm vector encountered, similarity set to 0.0")
            return 0.0
        result = float(np.dot(vec1, vec2) / denominator)
        logger.debug(f"Similarity calculated: {result:.4f}")
        return result

    def add_message(
        self,
        message_id: str,
        user_id: str,
        username: str,
        content: str,
        channel_id: Optional[str] = None,
        guild_id: Optional[str] = None,
    ) -> bool:
        """Add a message to the database and store its embedding.

        An existing row with the same ``message_id`` is replaced. If the
        embedding service returns nothing, the message is still stored without
        an embedding. After the insert, the oldest messages beyond
        ``MAX_HISTORY_MESSAGES`` are pruned in the same transaction.

        Returns:
            True on success, False if any step failed. On failure nothing is
            written, because the transaction is rolled back.
        """
        logger.info(f"Adding message {message_id} from user {username}")

        # Request the embedding before opening a write transaction, so the
        # database is not locked against other writers during the network call.
        try:
            logger.info(f"Generating embedding for message {message_id}")
            embedding = llama_wrapper.embedding(
                content,
                openai_url=EMBEDDING_ENDPOINT,
                openai_api_key=EMBEDDING_ENDPOINT_KEY,
                model=EMBEDDING_MODEL,
            )
        except Exception as e:
            logger.error(f"Error adding message {message_id}: {e}")
            return False

        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            try:
                # Insert message
                logger.debug(
                    f"Inserting message into chat_messages table: message_id={message_id}"
                )
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO chat_messages
                    (message_id, user_id, username, content, channel_id, guild_id)
                    VALUES (?, ?, ?, ?, ?, ?)
                    """,
                    (message_id, user_id, username, content, channel_id, guild_id),
                )
                logger.debug(f"Message {message_id} inserted into chat_messages table")

                if embedding:
                    logger.debug(
                        f"Embedding generated successfully for message {message_id}, storing in database"
                    )
                    cursor.execute(
                        """
                        INSERT OR REPLACE INTO message_embeddings (message_id, embedding)
                        VALUES (?, ?)
                        """,
                        (message_id, self._vector_to_bytes(embedding)),
                    )
                    logger.debug(
                        f"Embedding stored in message_embeddings table for message {message_id}"
                    )
                else:
                    logger.warning(
                        f"Failed to generate embedding for message {message_id}, skipping embedding storage"
                    )

                # Clean up old messages if exceeding limit
                logger.info("Checking if cleanup of old messages is needed")
                self._cleanup_old_messages(cursor)

                conn.commit()
                logger.info(f"Successfully added message {message_id} to database")
                return True
            except Exception as e:
                logger.error(f"Error adding message {message_id}: {e}")
                conn.rollback()
                return False

    def _cleanup_old_messages(self, cursor):
        """Delete the oldest messages, and their embeddings, beyond MAX_HISTORY_MESSAGES.

        Uses the caller's cursor so the deletes join the caller's transaction;
        the caller is responsible for committing.
        """
        cursor.execute(
            """
            SELECT COUNT(*) FROM chat_messages
            """
        )
        count = cursor.fetchone()[0]
        if count <= MAX_HISTORY_MESSAGES:
            return

        excess = count - MAX_HISTORY_MESSAGES
        # CURRENT_TIMESTAMP has one-second resolution, so the autoincrement id
        # breaks ties and the oldest-inserted rows go first.
        cursor.execute(
            """
            DELETE FROM chat_messages
            WHERE id IN (
                SELECT id FROM chat_messages
                ORDER BY timestamp ASC, id ASC
                LIMIT ?
            )
            """,
            (excess,),
        )
        # Remove every embedding whose message no longer exists. Re-running the
        # "oldest N" subquery here would select the next-oldest *surviving*
        # messages and wrongly delete their embeddings. This sweep also clears
        # orphans left behind by earlier cleanups.
        cursor.execute(
            """
            DELETE FROM message_embeddings
            WHERE NOT EXISTS (
                SELECT 1 FROM chat_messages cm
                WHERE cm.message_id = message_embeddings.message_id
            )
            """
        )

    def get_recent_messages(
        self, limit: int = 10
    ) -> List[Tuple[str, str, str, datetime]]:
        """Return up to ``limit`` messages, newest first.

        Each row is ``(message_id, username, content, timestamp)``.

        NOTE(review): the connection is opened without ``detect_types``, so
        ``timestamp`` is returned as the stored text rather than the
        ``datetime`` the annotation suggests.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT message_id, username, content, timestamp
                FROM chat_messages
                ORDER BY timestamp DESC
                LIMIT ?
                """,
                (limit,),
            )
            return cursor.fetchall()

    def search_similar_messages(
        self,
        query: str,
        top_k: int = TOP_K_RESULTS,
        min_similarity: float = SIMILARITY_THRESHOLD,
    ) -> List[Tuple[str, str, float]]:
        """Search for past exchanges similar to ``query`` using embeddings.

        Compares the query embedding against every stored embedding of a
        message not written by ``vibe-bot``. A match is kept when its cosine
        similarity is at least ``min_similarity`` and a ``<message_id>_response``
        row exists for it.

        Returns:
            Up to ``top_k`` ``(user_message, bot_response, similarity)`` tuples,
            most similar first. Returns an empty list when the query cannot be
            embedded.
        """
        query_embedding = llama_wrapper.embedding(
            text=query,
            model=EMBEDDING_MODEL,
            openai_url=EMBEDDING_ENDPOINT,
            openai_api_key=EMBEDDING_ENDPOINT_KEY,
        )
        if not query_embedding:
            return []

        query_vector = np.array(query_embedding, dtype=np.float32)
        results: list[tuple[str, str, float]] = []
        mismatched_dimensions = 0

        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            # Join chat_messages and message_embeddings to get content and embeddings
            cursor.execute(
                """
                SELECT cm.message_id, cm.content, me.embedding
                FROM chat_messages cm
                JOIN message_embeddings me ON cm.message_id = me.message_id
                WHERE cm.username != 'vibe-bot'
                """
            )
            rows = cursor.fetchall()

            for message_id, content, embedding_blob in rows:
                embedding_vector = self._bytes_to_vector(embedding_blob)
                # Embeddings from a different model/dimension cannot be compared.
                if embedding_vector.shape != query_vector.shape:
                    mismatched_dimensions += 1
                    continue

                similarity = self._calculate_similarity(query_vector, embedding_vector)
                if similarity < min_similarity:
                    continue

                cursor.execute(
                    """
                    SELECT content FROM chat_messages
                    WHERE message_id = ?
                    ORDER BY timestamp DESC
                    """,
                    (f"{message_id}_response",),
                )
                response_row = cursor.fetchone()
                if response_row is None:
                    # No stored bot reply for this message; nothing to pair it with.
                    logger.debug(f"No response stored for similar message {message_id}")
                    continue
                results.append((content, response_row[0], similarity))

        if mismatched_dimensions:
            logger.warning(
                f"Skipped {mismatched_dimensions} stored embeddings whose dimension "
                f"differs from the query embedding ({query_vector.shape[0]})"
            )

        # Sort by similarity and return top results
        results.sort(key=lambda x: x[2], reverse=True)
        return results[:top_k]

    def get_user_history(self, user_id: str, limit: int = 20) -> list[tuple[str, str]]:
        """Return recent ``(user message, bot response)`` pairs, newest first.

        Examines the ``limit`` most recent messages not written by ``vibe-bot``
        and pairs each with the row stored under ``<message_id>_response``.
        Messages without a stored response are skipped, so fewer than ``limit``
        pairs may be returned.

        NOTE(review): ``user_id`` is not used by the query, so this returns
        recent history from all users. Confirm whether that is intended.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            logger.info(f"Fetching last {limit} user messages")
            cursor.execute(
                """
                SELECT message_id, content, timestamp
                FROM chat_messages
                WHERE username != 'vibe-bot'
                ORDER BY timestamp DESC
                LIMIT ?
                """,
                (limit,),
            )
            messages = cursor.fetchall()

            # Format is [user message, bot response]
            conversations: list[tuple[str, str]] = []
            for message_id, msg_content, _timestamp in messages:
                logger.info(f"Finding response for {msg_content[:50]}")
                cursor.execute(
                    """
                    SELECT content FROM chat_messages
                    WHERE message_id = ?
                    ORDER BY timestamp DESC
                    """,
                    (f"{message_id}_response",),
                )
                response_row: Optional[Tuple[str]] = cursor.fetchone()
                if response_row:
                    logger.info(f"Found response: {response_row[0][:50]}")
                    conversations.append((msg_content, response_row[0]))
                else:
                    logger.info("No response found")

        return conversations

    def get_conversation_context(
        self, user_id: str, current_message: str, max_context: int = 5
    ) -> str:
        """Build the RAG context string for ``current_message``.

        Combines up to ``2 * max_context`` recent exchanges (tagged
        ``[Recent chat]``) with up to ``max_context`` semantically similar past
        exchanges (tagged ``[You remember]``). A similar exchange is dropped
        when its text already appears among the collected parts. The result is
        ordered so the newest context comes last, and is capped at
        ``4 * max_context`` parts.
        """
        # Get recent messages from the user
        recent_messages = self.get_user_history(user_id, limit=max_context * 2)

        # Search for similar messages
        similar_messages = self.search_similar_messages(
            current_message, top_k=max_context
        )

        # Combine contexts
        context_parts = []

        # Add recent messages
        for user_message, bot_message in recent_messages:
            combined_content = f"[Recent chat]\n{user_message}\n{bot_message}"
            context_parts.append(combined_content)

        # Add similar messages, skipping any already included verbatim
        for user_message, bot_message, _similarity in similar_messages:
            combined_content = f"{user_message}\n{bot_message}"
            if combined_content not in "\n".join(context_parts):
                context_parts.append(f"[You remember]\n{combined_content}")

        # Conversation history needs to be delivered in "newest context last" order
        context_parts.reverse()

        return "\n".join(context_parts[-max_context * 4 :])  # Limit total context

    def clear_all_messages(self):
        """Clear all messages and embeddings from the database."""
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM message_embeddings")
            cursor.execute("DELETE FROM chat_messages")
            conn.commit()


# Global database instance
_chat_db: Optional[ChatDatabase] = None


def get_database() -> ChatDatabase:
    """Get or create the global database instance."""
    global _chat_db
    if _chat_db is None:
        _chat_db = ChatDatabase()
    return _chat_db


class CustomBotManager:
    """Manages custom bot configurations stored in SQLite database.

    Bot names are lower-cased on every read and write, so lookups are
    case-insensitive. Each method opens and closes its own connection.
    """

    def __init__(self, db_path: str = DB_PATH):
        """Ensure the ``custom_bots`` table exists in ``db_path``."""
        self.db_path = db_path
        self._initialize_custom_bots_table()

    def _initialize_custom_bots_table(self):
        """Initialize the custom bots table in SQLite."""
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()

            # Create table to hold custom bots
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS custom_bots (
                    bot_name TEXT PRIMARY KEY,
                    system_prompt TEXT NOT NULL,
                    created_by TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    is_active INTEGER DEFAULT 1
                )
                """
            )

            conn.commit()

    def create_custom_bot(
        self, bot_name: str, system_prompt: str, created_by: str
    ) -> bool:
        """Create, or overwrite, a custom bot and mark it active.

        An existing bot with the same (lower-cased) name is replaced via
        ``INSERT OR REPLACE``, which also resets its ``created_at``.

        Returns:
            True on success, False if the write failed.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO custom_bots (bot_name, system_prompt, created_by, is_active)
                    VALUES (?, ?, ?, 1)
                    """,
                    (bot_name.lower(), system_prompt, created_by),
                )
                conn.commit()
                return True
            except Exception as e:
                logger.error(f"Error creating custom bot: {e}")
                conn.rollback()
                return False

    def get_custom_bot(self, bot_name: str) -> Optional[Tuple[str, str, str, datetime]]:
        """Get an active custom bot by name.

        Returns:
            ``(bot_name, system_prompt, created_by, created_at)``, or None if
            no active bot has that name. ``created_at`` is returned as stored
            text because the connection does not use ``detect_types``.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT bot_name, system_prompt, created_by, created_at
                FROM custom_bots
                WHERE bot_name = ? AND is_active = 1
                """,
                (bot_name.lower(),),
            )
            return cursor.fetchone()

    def list_custom_bots(
        self, user_id: Optional[str] = None
    ) -> List[Tuple[str, str, str]]:
        """List active custom bots, newest first, optionally filtered by creator.

        Args:
            user_id: When truthy, only bots whose ``created_by`` equals it are
                returned.

        Returns:
            ``(bot_name, system_prompt, created_by)`` tuples.
        """
        # NOTE(review): the filtered branch previously attempted a join against
        # a ``username_map`` table with invalid SQL (aliased table referenced by
        # its original name, table joined twice, ``user_id`` never bound), so it
        # always raised. It now filters by creator as the docstring describes.
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            if user_id:
                cursor.execute(
                    """
                    SELECT bot_name, system_prompt, created_by
                    FROM custom_bots
                    WHERE is_active = 1 AND created_by = ?
                    ORDER BY created_at DESC
                    """,
                    (user_id,),
                )
            else:
                cursor.execute(
                    """
                    SELECT bot_name, system_prompt, created_by
                    FROM custom_bots
                    WHERE is_active = 1
                    ORDER BY created_at DESC
                    """
                )
            return cursor.fetchall()

    def delete_custom_bot(self, bot_name: str) -> bool:
        """Permanently delete a custom bot configuration.

        Returns:
            True if a row was deleted, False if none matched or the delete failed.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    """
                    DELETE FROM custom_bots WHERE bot_name = ?
                    """,
                    (bot_name.lower(),),
                )
                conn.commit()
                return cursor.rowcount > 0
            except Exception as e:
                logger.error(f"Error deleting custom bot: {e}")
                conn.rollback()
                return False

    def deactivate_custom_bot(self, bot_name: str) -> bool:
        """Deactivate a custom bot (soft delete) so lookups no longer return it.

        Returns:
            True if a row was updated, False if none matched or the update failed.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(
                    """
                    UPDATE custom_bots SET is_active = 0 WHERE bot_name = ?
                    """,
                    (bot_name.lower(),),
                )
                conn.commit()
                return cursor.rowcount > 0
            except Exception as e:
                logger.error(f"Error deactivating custom bot: {e}")
                conn.rollback()
                return False