"""SQLite database with RAG support for chat history and embeddings.""" from __future__ import annotations import logging import sqlite3 from typing import TYPE_CHECKING import numpy as np from openai import OpenAI from vibe_bot import llama_wrapper from vibe_bot.config import ( DB_PATH, EMBEDDING_ENDPOINT, EMBEDDING_ENDPOINT_KEY, EMBEDDING_MODEL, MAX_HISTORY_MESSAGES, SIMILARITY_THRESHOLD, TOP_K_RESULTS, ) if TYPE_CHECKING: from datetime import datetime # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) class ChatDatabase: """SQLite database with RAG support for storing chat history using OpenAI embeddings. """ def __init__(self, db_path: str = DB_PATH) -> None: """Initialize the database connection. Args: db_path: Path to the SQLite database file. """ logger.info("Initializing ChatDatabase with path: %s", db_path) self.db_path = db_path self.client = OpenAI( base_url=EMBEDDING_ENDPOINT, api_key=EMBEDDING_ENDPOINT_KEY, ) logger.info("Connecting to OpenAI API for embeddings") self._initialize_database() def _initialize_database(self) -> None: """Initialize the SQLite database with required tables.""" logger.info("Initializing SQLite database at %s", self.db_path) conn = sqlite3.connect(self.db_path) cursor = conn.cursor() # Create messages table logger.info("Creating chat_messages table if not exists") cursor.execute( """ CREATE TABLE IF NOT EXISTS chat_messages ( id INTEGER PRIMARY KEY AUTOINCREMENT, message_id TEXT UNIQUE, user_id TEXT, username TEXT, content TEXT, bot_name TEXT, timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, channel_id TEXT, guild_id TEXT ) """, ) logger.info("chat_messages table initialized successfully") # Migrate: add bot_name column if it doesn't exist logger.info("Checking for bot_name column migration") cursor.execute("PRAGMA table_info(chat_messages)") columns = [row[1] for row in cursor.fetchall()] if "bot_name" not in columns: logger.info("Adding bot_name column to chat_messages table") cursor.execute( "ALTER TABLE chat_messages ADD COLUMN bot_name TEXT", ) logger.info("bot_name column added successfully") # Create embeddings table for RAG logger.info("Creating message_embeddings table if not exists") cursor.execute( """ CREATE TABLE IF NOT EXISTS message_embeddings ( message_id TEXT PRIMARY KEY, embedding BLOB, FOREIGN KEY (message_id) REFERENCES chat_messages(message_id) ) """, ) logger.info("message_embeddings table initialized successfully") # Create index for faster lookups logger.info("Creating idx_timestamp index if not exists") cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_timestamp ON chat_messages(timestamp) """, ) logger.info("idx_timestamp index created successfully") logger.info("Creating idx_user_id index if not exists") cursor.execute( """ CREATE INDEX IF NOT EXISTS idx_user_id ON chat_messages(user_id) """, ) logger.info("idx_user_id index created successfully") conn.commit() logger.info("Database initialization completed successfully") conn.close() def _vector_to_bytes(self, vector: list[float]) -> bytes: """Convert vector to bytes for SQLite storage.""" logger.debug("Converting vector (length: %d) to bytes", len(vector)) result = np.array(vector, dtype=np.float32).tobytes() logger.debug("Vector converted to %d bytes", len(result)) return result def _bytes_to_vector(self, blob: bytes) -> np.ndarray: """Convert bytes back to vector.""" logger.debug("Converting %d bytes back to vector", len(blob)) result = np.frombuffer(blob, dtype=np.float32) logger.debug("Vector reconstructed with %d dimensions", len(result)) return result def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float: """Calculate cosine similarity between two vectors.""" vec1 = vec1.flatten() vec2 = vec2.flatten() logger.debug( "Calculating cosine similarity between vectors of dimension %d", len(vec1), ) norm1 = np.linalg.norm(vec1) norm2 = np.linalg.norm(vec2) if norm1 == 0 or norm2 == 0: return 0.0 result = float(np.dot(vec1, vec2) / (norm1 * norm2)) logger.debug("Similarity calculated: %.4f", result) return result def add_message( self, *, message_id: str, user_id: str, username: str, content: str, bot_name: str | None = None, channel_id: str | None = None, guild_id: str | None = None, ) -> bool: """Add a message to the database and generate its embedding.""" logger.info("Adding message %s from user %s", message_id, username) conn = sqlite3.connect(self.db_path) cursor = conn.cursor() try: # Insert message logger.debug( "Inserting message into chat_messages table: message_id=%s", message_id, ) cursor.execute( """ INSERT OR REPLACE INTO chat_messages (message_id, user_id, username, content, bot_name, channel_id, guild_id) VALUES (?, ?, ?, ?, ?, ?, ?) """, ( message_id, user_id, username, content, bot_name, channel_id, guild_id, ), ) logger.debug("Message %s inserted into chat_messages table", message_id) # Generate and store embedding logger.info("Generating embedding for message %s", message_id) embedding = llama_wrapper.embedding( content, openai_url=EMBEDDING_ENDPOINT, openai_api_key=EMBEDDING_ENDPOINT_KEY, model=EMBEDDING_MODEL, ) if embedding: logger.debug( "Embedding generated successfully for message %s, " "storing in database", message_id, ) cursor.execute( """ INSERT OR REPLACE INTO message_embeddings (message_id, embedding) VALUES (?, ?) """, (message_id, self._vector_to_bytes(embedding)), ) logger.debug( "Embedding stored in message_embeddings table for message %s", message_id, ) else: logger.warning( "Failed to generate embedding for message %s, " "skipping embedding storage", message_id, ) # Clean up old messages if exceeding limit logger.info("Checking if cleanup of old messages is needed") self._cleanup_old_messages(cursor) conn.commit() except Exception: logger.exception("Error adding message %s", message_id) conn.rollback() return False else: logger.info("Successfully added message %s to database", message_id) return True finally: conn.close() def _cleanup_old_messages(self, cursor: sqlite3.Cursor) -> None: """Remove old messages to stay within the limit.""" cursor.execute( """ SELECT COUNT(*) FROM chat_messages """, ) count = cursor.fetchone()[0] if count > MAX_HISTORY_MESSAGES: cursor.execute( """ DELETE FROM chat_messages WHERE id IN ( SELECT id FROM chat_messages ORDER BY timestamp ASC LIMIT ? ) """, (count - MAX_HISTORY_MESSAGES,), ) # Also remove corresponding embeddings cursor.execute( """ DELETE FROM message_embeddings WHERE message_id IN ( SELECT message_id FROM chat_messages ORDER BY timestamp ASC LIMIT ? ) """, (count - MAX_HISTORY_MESSAGES,), ) def get_recent_messages( self, limit: int = 10, ) -> list[tuple[str, str, str, datetime]]: """Get recent messages from the database.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() cursor.execute( """ SELECT message_id, username, content, timestamp FROM chat_messages ORDER BY timestamp DESC LIMIT ? """, (limit,), ) messages = cursor.fetchall() conn.close() return messages def search_similar_messages( self, query: str, top_k: int = TOP_K_RESULTS, min_similarity: float = SIMILARITY_THRESHOLD, ) -> list[tuple[str, str, float]]: """Search for messages similar to the query using embeddings.""" query_embedding = llama_wrapper.embedding( text=query, model=EMBEDDING_MODEL, openai_url=EMBEDDING_ENDPOINT, openai_api_key=EMBEDDING_ENDPOINT_KEY, ) if not query_embedding: return [] query_vector = np.array(query_embedding, dtype=np.float32) conn = sqlite3.connect(self.db_path) cursor = conn.cursor() # Join chat_messages and message_embeddings to get content and embeddings cursor.execute( """ SELECT cm.message_id, cm.content, me.embedding FROM chat_messages cm JOIN message_embeddings me ON cm.message_id = me.message_id WHERE cm.username != 'vibe-bot' """, ) rows = cursor.fetchall() results: list[tuple[str, str, float]] = [] for message_id, content, embedding_blob in rows: embedding_vector = self._bytes_to_vector(embedding_blob) similarity = self._calculate_similarity(query_vector, embedding_vector) if similarity >= min_similarity: cursor.execute( """ SELECT content FROM chat_messages WHERE message_id = ? ORDER BY timestamp DESC """, (f"{message_id}_response",), ) response_row = cursor.fetchone() if response_row: results.append((content, response_row[0], similarity)) conn.close() # Sort by similarity and return top results results.sort(key=lambda x: x[2], reverse=True) return results[:top_k] def get_bot_history(self, bot_name: str, limit: int = 20) -> list[tuple[str, str]]: """Get message history for a specific custom bot. Args: bot_name: The name of the custom bot. limit: Maximum number of messages to retrieve. Returns: List of (user_message, bot_response) tuples. """ conn = sqlite3.connect(self.db_path) cursor = conn.cursor() logger.info( "Fetching last %d messages for bot %r", limit, bot_name, ) cursor.execute( """ SELECT message_id, content, timestamp FROM chat_messages WHERE bot_name = ? AND message_id NOT LIKE '%%_response' ORDER BY timestamp DESC LIMIT ? """, (bot_name, limit), ) messages = cursor.fetchall() conversations: list[tuple[str, str]] = [] for message in messages: msg_content = message[1] logger.debug("Finding response for %s...", msg_content[:50]) cursor.execute( """ SELECT content FROM chat_messages WHERE message_id = ? ORDER BY timestamp DESC """, (f"{message[0]}_response",), ) response_row = cursor.fetchone() if response_row: logger.debug("Found response: %s...", response_row[0][:50]) conversations.append((msg_content, response_row[0])) else: logger.debug("No response found") conn.close() return conversations def get_user_history(self, user_id: str, limit: int = 20) -> list[tuple[str, str]]: """Get message history for a specific user.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() logger.info("Fetching last %d user messages", limit) cursor.execute( """ SELECT message_id, content, timestamp FROM chat_messages WHERE user_id = ? AND username != 'vibe-bot' ORDER BY timestamp DESC LIMIT ? """, (user_id, limit), ) messages = cursor.fetchall() # Format is [user message, bot response] conversations: list[tuple[str, str]] = [] for message in messages: msg_content = message[1] logger.debug("Finding response for %s...", msg_content[:50]) cursor.execute( """ SELECT content FROM chat_messages WHERE message_id = ? ORDER BY timestamp DESC """, (f"{message[0]}_response",), ) response_row = cursor.fetchone() if response_row: logger.debug("Found response: %s...", response_row[0][:50]) conversations.append((msg_content, response_row[0])) else: logger.debug("No response found") conn.close() return conversations def get_conversation_context( self, user_id: str, current_message: str, max_context: int = 5, ) -> list[dict[str, str]]: """Get relevant conversation context for RAG.""" # Get recent messages from the user recent_messages = self.get_user_history(user_id, limit=max_context * 2) # Search for similar messages similar_messages = self.search_similar_messages( current_message, top_k=max_context, ) # Combine contexts context_parts: list[dict[str, str]] = [] # Add recent messages for user_message, bot_message in recent_messages: context_parts.append({"role": "assistant", "content": bot_message}) context_parts.append({"role": "user", "content": user_message}) # Add similar messages for user_message, bot_message, _similarity in similar_messages: context_parts.append({"role": "assistant", "content": bot_message}) context_parts.append({"role": "user", "content": user_message}) # Conversation history needs to be delivered in "newest context last" order context_parts.reverse() return context_parts def clear_all_messages(self) -> None: """Clear all messages and embeddings from the database.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() cursor.execute("DELETE FROM message_embeddings") cursor.execute("DELETE FROM chat_messages") conn.commit() conn.close() # Global database instance _chat_db: ChatDatabase | None = None def get_database() -> ChatDatabase: """Get or create the global database instance.""" global _chat_db # noqa: PLW0603 if _chat_db is None: _chat_db = ChatDatabase() return _chat_db class CustomBotManager: """Manages custom bot configurations stored in SQLite database.""" def __init__(self, db_path: str = DB_PATH) -> None: """Initialize the custom bot manager. Args: db_path: Path to the SQLite database file. """ self.db_path = db_path self._initialize_custom_bots_table() def _initialize_custom_bots_table(self) -> None: """Initialize the custom bots table in SQLite.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() # Create table to hold custom bots cursor.execute( """ CREATE TABLE IF NOT EXISTS custom_bots ( bot_name TEXT PRIMARY KEY, system_prompt TEXT NOT NULL, created_by TEXT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, is_active INTEGER DEFAULT 1 ) """, ) conn.commit() conn.close() def create_custom_bot( self, bot_name: str, system_prompt: str, created_by: str, ) -> bool: """Create a new custom bot configuration.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() try: cursor.execute( """ INSERT OR REPLACE INTO custom_bots (bot_name, system_prompt, created_by, is_active) VALUES (?, ?, ?, 1) """, (bot_name.lower(), system_prompt, created_by), ) conn.commit() except Exception: logger.exception("Error creating custom bot") conn.rollback() return False else: return True finally: conn.close() def get_custom_bot(self, bot_name: str) -> tuple[str, str, str, datetime] | None: """Get a custom bot configuration by name.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() cursor.execute( """ SELECT bot_name, system_prompt, created_by, created_at FROM custom_bots WHERE bot_name = ? AND is_active = 1 """, (bot_name.lower(),), ) result = cursor.fetchone() conn.close() if result is None: return None return (result[0], result[1], result[2], result[3]) def list_custom_bots( self, user_id: str | None = None, ) -> list[tuple[str, str, str]]: """List all custom bots, optionally filtered by creator.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() if user_id: cursor.execute( """ SELECT bot_name, system_prompt, created_by FROM custom_bots WHERE is_active = 1 AND created_by = ? ORDER BY created_at DESC """, (user_id,), ) else: cursor.execute( """ SELECT bot_name, system_prompt, created_by FROM custom_bots WHERE is_active = 1 ORDER BY created_at DESC """, ) bots = cursor.fetchall() conn.close() return bots def delete_custom_bot(self, bot_name: str) -> bool: """Delete a custom bot configuration.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() try: cursor.execute( """ DELETE FROM custom_bots WHERE bot_name = ? """, (bot_name.lower(),), ) conn.commit() except Exception: logger.exception("Error deleting custom bot") conn.rollback() return False else: return cursor.rowcount > 0 finally: conn.close() def deactivate_custom_bot(self, bot_name: str) -> bool: """Deactivate a custom bot (soft delete).""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() try: cursor.execute( """ UPDATE custom_bots SET is_active = 0 WHERE bot_name = ? """, (bot_name.lower(),), ) conn.commit() except Exception: logger.exception("Error deactivating custom bot") conn.rollback() return False else: return cursor.rowcount > 0 finally: conn.close()