"""SQLite persistence for chat history (with embedding search) and custom bots."""

import os
import sqlite3
from contextlib import closing
from datetime import datetime  # noqa: F401 - kept so existing importers still find it
from typing import List, Optional, Tuple

import numpy as np
from openai import OpenAI

# Database configuration
DB_PATH = os.getenv("DB_PATH", "chat_history.db")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "qwen3-embed-4b")
EMBEDDING_DIMENSION = 2048  # Default for qwen3-embed-4b
MAX_HISTORY_MESSAGES = int(os.getenv("MAX_HISTORY_MESSAGES", "1000"))
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
TOP_K_RESULTS = int(os.getenv("TOP_K_RESULTS", "5"))

# OpenAI configuration
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "placeholder")
OPENAI_API_EMBED_ENDPOINT = os.getenv(
    "OPENAI_API_EMBED_ENDPOINT", "https://llama-embed.reeselink.com"
)


class ChatDatabase:
    """SQLite database with RAG support for storing chat history.

    Messages live in ``chat_messages``; their embeddings (float32 blobs) live
    in ``message_embeddings`` keyed by ``message_id``. Embeddings are obtained
    from an OpenAI-compatible endpoint. Every method opens and closes its own
    connection to ``db_path``.
    """

    def __init__(self, db_path: str = DB_PATH):
        """Create the embedding client and make sure the schema exists.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self.client = OpenAI(
            base_url=OPENAI_API_EMBED_ENDPOINT, api_key=OPENAI_API_KEY
        )
        self._initialize_database()

    def _initialize_database(self) -> None:
        """Create the message/embedding tables and indexes if missing."""
        # closing() guarantees the connection is released even if DDL fails.
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()

            # Create messages table
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS chat_messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    message_id TEXT UNIQUE,
                    user_id TEXT,
                    username TEXT,
                    content TEXT,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    channel_id TEXT,
                    guild_id TEXT
                )
                """
            )

            # Create embeddings table for RAG
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS message_embeddings (
                    message_id TEXT PRIMARY KEY,
                    embedding BLOB,
                    FOREIGN KEY (message_id)
                        REFERENCES chat_messages(message_id)
                )
                """
            )

            # Create indexes for faster lookups
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_timestamp
                ON chat_messages(timestamp)
                """
            )
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_user_id
                ON chat_messages(user_id)
                """
            )

            conn.commit()

    def _generate_embedding(self, text: str) -> Optional[List[float]]:
        """Request an embedding for ``text`` from the embedding endpoint.

        Returns:
            The embedding as a list of floats, ``[]`` if the response carried
            no usable vector, or ``None`` if the request failed (the error is
            printed, not raised). Callers treat any falsy value as
            "no embedding".
        """
        try:
            response = self.client.embeddings.create(
                model=EMBEDDING_MODEL, input=text, encoding_format="float"
            )
            # The OpenAI client returns an object whose ``data`` list holds
            # one entry per input; it is not subscriptable itself. Plain
            # indexing is kept as a fallback for list-like responses.
            entries = getattr(response, "data", None)
            first_entry = entries[0] if entries is not None else response[0]
            embedding_data = first_entry.embedding

            if isinstance(embedding_data, list) and len(embedding_data) > 0:
                first_item = embedding_data[0]
                if isinstance(first_item, list):
                    # Nested structure: [[values]] -> [values]
                    return first_item
                # Direct structure: [values]
                return embedding_data
            return []
        except Exception as e:
            # Best-effort: a failed embedding must not break message storage.
            print(f"Error generating embedding: {e}")
            return None

    def _vector_to_bytes(self, vector: List[float]) -> bytes:
        """Serialize a vector as raw float32 bytes for SQLite storage."""
        return np.array(vector, dtype=np.float32).tobytes()

    def _bytes_to_vector(self, blob: bytes) -> np.ndarray:
        """Deserialize raw float32 bytes back into a 1-D numpy array."""
        return np.frombuffer(blob, dtype=np.float32)

    def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Return the cosine similarity of two vectors as a Python float.

        A zero-length (all-zero) vector has no direction, so the similarity
        is defined as 0.0 instead of dividing by zero.
        """
        denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if denominator == 0:
            return 0.0
        return float(np.dot(vec1, vec2) / denominator)

    def add_message(
        self,
        message_id: str,
        user_id: str,
        username: str,
        content: str,
        channel_id: Optional[str] = None,
        guild_id: Optional[str] = None,
    ) -> bool:
        """Store (or replace) a message together with its embedding.

        Args:
            message_id: Unique identifier; an existing row with the same id
                is replaced.
            user_id: Author identifier.
            username: Author display name.
            content: Message text; also the text that gets embedded.
            channel_id: Optional channel identifier.
            guild_id: Optional guild identifier.

        Returns:
            True if the message was committed, False if a database error
            occurred (the error is printed and the transaction rolled back).
            A failed embedding does not make the call fail.
        """
        # Fetch the embedding first so the write transaction below is not
        # kept open while waiting on the network call.
        embedding = self._generate_embedding(content)

        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            # Insert message
            cursor.execute(
                """
                INSERT OR REPLACE INTO chat_messages
                    (message_id, user_id, username, content,
                     channel_id, guild_id)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (message_id, user_id, username, content, channel_id, guild_id),
            )

            if embedding:
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO message_embeddings
                        (message_id, embedding)
                    VALUES (?, ?)
                    """,
                    (message_id, self._vector_to_bytes(embedding)),
                )
            else:
                # The content may have been replaced; an embedding left over
                # from the previous content would no longer describe it.
                cursor.execute(
                    "DELETE FROM message_embeddings WHERE message_id = ?",
                    (message_id,),
                )

            # Clean up old messages if exceeding limit
            self._cleanup_old_messages(cursor)

            conn.commit()
            return True
        except Exception as e:
            print(f"Error adding message: {e}")
            conn.rollback()
            return False
        finally:
            conn.close()

    def _cleanup_old_messages(self, cursor) -> None:
        """Trim the oldest messages so at most MAX_HISTORY_MESSAGES remain.

        Runs inside the caller's transaction; the caller commits.
        """
        cursor.execute("SELECT COUNT(*) FROM chat_messages")
        excess = cursor.fetchone()[0] - MAX_HISTORY_MESSAGES
        if excess <= 0:
            return

        # Embeddings must go first: the subquery identifies the oldest rows
        # through chat_messages, so it has to run while those rows still
        # exist. ``id`` breaks ties between equal timestamps so both
        # statements pick exactly the same rows.
        cursor.execute(
            """
            DELETE FROM message_embeddings
            WHERE message_id IN (
                SELECT message_id FROM chat_messages
                ORDER BY timestamp ASC, id ASC
                LIMIT ?
            )
            """,
            (excess,),
        )
        cursor.execute(
            """
            DELETE FROM chat_messages
            WHERE id IN (
                SELECT id FROM chat_messages
                ORDER BY timestamp ASC, id ASC
                LIMIT ?
            )
            """,
            (excess,),
        )

    def get_recent_messages(
        self, limit: int = 10
    ) -> List[Tuple[str, str, str, str]]:
        """Return the newest messages, newest first.

        Args:
            limit: Maximum number of rows to return.

        Returns:
            ``(message_id, username, content, timestamp)`` tuples. The
            connection is opened without type detection, so the timestamp
            is the stored text value.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT message_id, username, content, timestamp
                FROM chat_messages
                ORDER BY timestamp DESC, id DESC
                LIMIT ?
                """,
                (limit,),
            )
            return cursor.fetchall()

    def search_similar_messages(
        self,
        query: str,
        top_k: int = TOP_K_RESULTS,
        min_similarity: float = SIMILARITY_THRESHOLD,
    ) -> List[Tuple[str, str, float]]:
        """Find stored messages whose embedding is close to the query's.

        Args:
            query: Text to embed and compare against stored messages.
            top_k: Maximum number of results.
            min_similarity: Minimum cosine similarity for a match.

        Returns:
            ``(message_id, content, similarity)`` tuples sorted by descending
            similarity, with content truncated to 500 characters. Empty if
            the query could not be embedded.
        """
        query_embedding = self._generate_embedding(query)
        if not query_embedding:
            return []

        query_vector = np.array(query_embedding, dtype=np.float32)

        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            # Join messages with their embeddings to get content and vectors
            cursor.execute(
                """
                SELECT cm.message_id, cm.content, me.embedding
                FROM chat_messages cm
                JOIN message_embeddings me
                    ON cm.message_id = me.message_id
                """
            )
            rows = cursor.fetchall()

        results = []
        for message_id, content, embedding_blob in rows:
            if embedding_blob is None:
                continue
            embedding_vector = self._bytes_to_vector(embedding_blob)
            # Vectors of a different size (e.g. stored under another model)
            # cannot be compared with the query; skip them instead of raising.
            if embedding_vector.shape != query_vector.shape:
                continue
            similarity = self._calculate_similarity(
                query_vector, embedding_vector
            )
            if similarity >= min_similarity:
                # Limit content length
                results.append((message_id, (content or "")[:500], similarity))

        # Sort by similarity and return top results
        results.sort(key=lambda item: item[2], reverse=True)
        return results[:top_k]

    def get_user_history(
        self, user_id: str, limit: int = 20
    ) -> List[Tuple[str, str, str]]:
        """Return a user's newest messages, newest first.

        Args:
            user_id: Author identifier to filter on.
            limit: Maximum number of rows to return.

        Returns:
            ``(message_id, content, timestamp)`` tuples; the timestamp is
            the stored text value.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT message_id, content, timestamp
                FROM chat_messages
                WHERE user_id = ?
                ORDER BY timestamp DESC, id DESC
                LIMIT ?
                """,
                (user_id, limit),
            )
            return cursor.fetchall()

    def get_conversation_context(
        self, user_id: str, current_message: str, max_context: int = 5
    ) -> str:
        """Build a newline-joined context string for RAG.

        Combines the user's recent messages with messages semantically
        similar to ``current_message``, skipping similar messages whose text
        already appears among the recent ones.

        Args:
            user_id: User whose recent history is included.
            current_message: Text used for the similarity search.
            max_context: Up to ``2 * max_context`` recent messages and
                ``max_context`` similar messages are gathered; the last
                ``2 * max_context`` entries are returned.
        """
        recent_messages = self.get_user_history(user_id, limit=max_context * 2)
        similar_messages = self.search_similar_messages(
            current_message, top_k=max_context
        )

        context_parts = []
        seen_contents = set()

        # Add recent messages
        for _message_id, content, timestamp in recent_messages:
            context_parts.append(f"[{timestamp}] User: {content}")
            # Search results are truncated to 500 characters, so remember the
            # same prefix to make the duplicate check comparable.
            seen_contents.add((content or "")[:500])

        # Add similar messages, avoiding duplicates
        for _message_id, content, _similarity in similar_messages:
            if content in seen_contents:
                continue
            seen_contents.add(content)
            context_parts.append(f"[Similar] {content}")

        return "\n".join(context_parts[-max_context * 2 :])  # Limit total context

    def clear_all_messages(self) -> None:
        """Delete every message and every embedding."""
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            # Embeddings reference messages, so remove them first.
            cursor.execute("DELETE FROM message_embeddings")
            cursor.execute("DELETE FROM chat_messages")
            conn.commit()


# Global database instance
_chat_db: Optional[ChatDatabase] = None


def get_database() -> ChatDatabase:
    """Get or create the global database instance."""
    global _chat_db
    if _chat_db is None:
        _chat_db = ChatDatabase()
    return _chat_db


class CustomBotManager:
    """Manages custom bot configurations stored in SQLite database.

    Bot names are lower-cased before being stored or looked up.
    """

    def __init__(self, db_path: str = DB_PATH):
        """Remember the database path and make sure the table exists.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self._initialize_custom_bots_table()

    def _initialize_custom_bots_table(self) -> None:
        """Create the ``custom_bots`` table if it does not exist."""
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS custom_bots (
                    bot_name TEXT PRIMARY KEY,
                    system_prompt TEXT NOT NULL,
                    created_by TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    is_active INTEGER DEFAULT 1
                )
                """
            )
            conn.commit()

    def create_custom_bot(
        self, bot_name: str, system_prompt: str, created_by: str
    ) -> bool:
        """Create a custom bot, replacing any bot with the same name.

        Args:
            bot_name: Bot name (stored lower-cased).
            system_prompt: System prompt for the bot.
            created_by: Identifier of the creator.

        Returns:
            True on success, False if a database error occurred (the error
            is printed and the transaction rolled back).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT OR REPLACE INTO custom_bots
                    (bot_name, system_prompt, created_by, is_active)
                VALUES (?, ?, ?, 1)
                """,
                (bot_name.lower(), system_prompt, created_by),
            )
            conn.commit()
            return True
        except Exception as e:
            print(f"Error creating custom bot: {e}")
            conn.rollback()
            return False
        finally:
            conn.close()

    def get_custom_bot(
        self, bot_name: str
    ) -> Optional[Tuple[str, str, str, str]]:
        """Look up an active custom bot by name.

        Returns:
            ``(bot_name, system_prompt, created_by, created_at)`` or ``None``
            if no active bot has that name. ``created_at`` is the stored
            text value.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT bot_name, system_prompt, created_by, created_at
                FROM custom_bots
                WHERE bot_name = ? AND is_active = 1
                """,
                (bot_name.lower(),),
            )
            return cursor.fetchone()

    def list_custom_bots(
        self, user_id: Optional[str] = None
    ) -> List[Tuple[str, str, str]]:
        """List active custom bots, newest first.

        Args:
            user_id: When given, only bots whose ``created_by`` equals this
                value are returned.

        Returns:
            ``(bot_name, system_prompt, created_by)`` tuples.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            if user_id:
                cursor.execute(
                    """
                    SELECT bot_name, system_prompt, created_by
                    FROM custom_bots
                    WHERE is_active = 1 AND created_by = ?
                    ORDER BY created_at DESC
                    """,
                    (user_id,),
                )
            else:
                cursor.execute(
                    """
                    SELECT bot_name, system_prompt, created_by
                    FROM custom_bots
                    WHERE is_active = 1
                    ORDER BY created_at DESC
                    """
                )
            return cursor.fetchall()

    def delete_custom_bot(self, bot_name: str) -> bool:
        """Permanently delete a custom bot.

        Returns:
            True if a row was deleted, False if no such bot existed or a
            database error occurred (printed and rolled back).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                DELETE FROM custom_bots
                WHERE bot_name = ?
                """,
                (bot_name.lower(),),
            )
            conn.commit()
            return cursor.rowcount > 0
        except Exception as e:
            print(f"Error deleting custom bot: {e}")
            conn.rollback()
            return False
        finally:
            conn.close()

    def deactivate_custom_bot(self, bot_name: str) -> bool:
        """Deactivate a custom bot (soft delete).

        Returns:
            True if a row was updated, False if no such bot existed or a
            database error occurred (printed and rolled back).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                UPDATE custom_bots
                SET is_active = 0
                WHERE bot_name = ?
                """,
                (bot_name.lower(),),
            )
            conn.commit()
            return cursor.rowcount > 0
        except Exception as e:
            print(f"Error deactivating custom bot: {e}")
            conn.rollback()
            return False
        finally:
            conn.close()