Files
vibe-bot/vibe_bot/database.py
2026-03-09 22:36:04 -04:00

555 lines
18 KiB
Python

import sqlite3
from typing import Optional, List, Tuple
from datetime import datetime
import numpy as np
from openai import OpenAI
import logging
import llama_wrapper # type: ignore
from config import ( # type: ignore
DB_PATH,
EMBEDDING_MODEL,
EMBEDDING_ENDPOINT,
EMBEDDING_ENDPOINT_KEY,
MAX_HISTORY_MESSAGES,
SIMILARITY_THRESHOLD,
TOP_K_RESULTS,
)
# Set up root logging once at import time (INFO and above, timestamped lines),
# then expose a module-scoped logger for everything defined in this file.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class ChatDatabase:
    """SQLite database with RAG support for storing chat history using OpenAI embeddings.

    Messages are stored in ``chat_messages``; their embedding vectors are
    stored as float32 blobs in ``message_embeddings`` under the same
    ``message_id``. The bot's reply to message ``X`` is looked up under the
    id ``"X_response"``. Every method opens its own short-lived connection to
    ``db_path`` and closes it before returning.
    """

    def __init__(self, db_path: str = DB_PATH):
        """Remember the database path, build the API client and create the schema.

        Args:
            db_path: Location of the SQLite database file.
        """
        logger.info(f"Initializing ChatDatabase with path: {db_path}")
        self.db_path = db_path
        # Kept as an attribute; the methods below obtain embeddings through
        # llama_wrapper.embedding rather than through this client.
        self.client = OpenAI(
            base_url=EMBEDDING_ENDPOINT, api_key=EMBEDDING_ENDPOINT_KEY
        )
        logger.info("Connecting to OpenAI API for embeddings")
        self._initialize_database()

    def _initialize_database(self):
        """Initialize the SQLite database with required tables and indexes."""
        logger.info(f"Initializing SQLite database at {self.db_path}")
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            # Create messages table
            logger.info("Creating chat_messages table if not exists")
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS chat_messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    message_id TEXT UNIQUE,
                    user_id TEXT,
                    username TEXT,
                    content TEXT,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    channel_id TEXT,
                    guild_id TEXT
                )
                """
            )
            logger.info("chat_messages table initialized successfully")

            # Create embeddings table for RAG
            logger.info("Creating message_embeddings table if not exists")
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS message_embeddings (
                    message_id TEXT PRIMARY KEY,
                    embedding BLOB,
                    FOREIGN KEY (message_id) REFERENCES chat_messages(message_id)
                )
                """
            )
            logger.info("message_embeddings table initialized successfully")

            # Create indexes for faster lookups
            logger.info("Creating idx_timestamp index if not exists")
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_timestamp ON chat_messages(timestamp)
                """
            )
            logger.info("idx_timestamp index created successfully")

            logger.info("Creating idx_user_id index if not exists")
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_user_id ON chat_messages(user_id)
                """
            )
            logger.info("idx_user_id index created successfully")

            conn.commit()
            logger.info("Database initialization completed successfully")
        finally:
            # Release the connection even if one of the DDL statements fails.
            conn.close()

    def _vector_to_bytes(self, vector: List[float]) -> bytes:
        """Convert vector to float32 bytes for SQLite storage."""
        logger.debug(f"Converting vector (length: {len(vector)}) to bytes")
        result = np.array(vector, dtype=np.float32).tobytes()
        logger.debug(f"Vector converted to {len(result)} bytes")
        return result

    def _bytes_to_vector(self, blob: bytes) -> np.ndarray:
        """Convert float32 bytes (as written by ``_vector_to_bytes``) back to a vector."""
        logger.debug(f"Converting {len(blob)} bytes back to vector")
        result = np.frombuffer(blob, dtype=np.float32)
        logger.debug(f"Vector reconstructed with {len(result)} dimensions")
        return result

    def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Calculate cosine similarity between two vectors.

        Returns:
            The cosine similarity as a plain Python float, or ``0.0`` when
            either vector has zero length (the cosine is undefined there).
        """
        logger.debug(
            f"Calculating cosine similarity between vectors of dimension {len(vec1)}"
        )
        denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if denominator == 0:
            # Avoid a division by zero that would yield NaN plus a warning.
            logger.debug("Zero-norm vector encountered, similarity set to 0.0")
            return 0.0
        result = float(np.dot(vec1, vec2) / denominator)
        logger.debug(f"Similarity calculated: {result:.4f}")
        return result

    def add_message(
        self,
        message_id: str,
        user_id: str,
        username: str,
        content: str,
        channel_id: Optional[str] = None,
        guild_id: Optional[str] = None,
    ) -> bool:
        """Add a message to the database and generate its embedding.

        The message row, its embedding and the history cleanup are committed
        together; any exception rolls the whole operation back.

        Returns:
            True when the message was committed, False when an error occurred.
        """
        logger.info(f"Adding message {message_id} from user {username}")
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        try:
            # Insert message
            logger.debug(
                f"Inserting message into chat_messages table: message_id={message_id}"
            )
            cursor.execute(
                """
                INSERT OR REPLACE INTO chat_messages
                (message_id, user_id, username, content, channel_id, guild_id)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (message_id, user_id, username, content, channel_id, guild_id),
            )
            logger.debug(f"Message {message_id} inserted into chat_messages table")

            # Generate and store embedding
            logger.info(f"Generating embedding for message {message_id}")
            embedding = llama_wrapper.embedding(
                content,
                openai_url=EMBEDDING_ENDPOINT,
                openai_api_key=EMBEDDING_ENDPOINT_KEY,
                model=EMBEDDING_MODEL,
            )
            if embedding:
                logger.debug(
                    f"Embedding generated successfully for message {message_id}, storing in database"
                )
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO message_embeddings
                    (message_id, embedding)
                    VALUES (?, ?)
                    """,
                    (message_id, self._vector_to_bytes(embedding)),
                )
                logger.debug(
                    f"Embedding stored in message_embeddings table for message {message_id}"
                )
            else:
                logger.warning(
                    f"Failed to generate embedding for message {message_id}, skipping embedding storage"
                )

            # Clean up old messages if exceeding limit
            logger.info("Checking if cleanup of old messages is needed")
            self._cleanup_old_messages(cursor)

            conn.commit()
            logger.info(f"Successfully added message {message_id} to database")
            return True
        except Exception as e:
            # Best-effort API: report failure through the return value, but
            # keep the traceback in the log.
            logger.exception(f"Error adding message {message_id}: {e}")
            conn.rollback()
            return False
        finally:
            conn.close()

    def _cleanup_old_messages(self, cursor):
        """Remove the oldest messages (and their embeddings) beyond the limit.

        Args:
            cursor: Cursor of the caller's open transaction; nothing is
                committed here.
        """
        cursor.execute("SELECT COUNT(*) FROM chat_messages")
        count = cursor.fetchone()[0]
        excess = count - MAX_HISTORY_MESSAGES
        if excess <= 0:
            return

        # Timestamps only have one-second resolution, so break ties on the
        # autoincrement id to make "oldest" deterministic.
        cursor.execute(
            """
            DELETE FROM chat_messages
            WHERE id IN (
                SELECT id FROM chat_messages
                ORDER BY timestamp ASC, id ASC
                LIMIT ?
            )
            """,
            (excess,),
        )
        # Drop every embedding whose message no longer exists. Selecting the
        # "oldest N" again after the delete above would hit the wrong rows.
        cursor.execute(
            """
            DELETE FROM message_embeddings
            WHERE NOT EXISTS (
                SELECT 1 FROM chat_messages
                WHERE chat_messages.message_id = message_embeddings.message_id
            )
            """
        )

    def get_recent_messages(
        self, limit: int = 10
    ) -> List[Tuple[str, str, str, datetime]]:
        """Get the most recent messages, newest first.

        Returns:
            Rows of (message_id, username, content, timestamp).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT message_id, username, content, timestamp
                FROM chat_messages
                ORDER BY timestamp DESC
                LIMIT ?
                """,
                (limit,),
            )
            return cursor.fetchall()
        finally:
            conn.close()

    def search_similar_messages(
        self,
        query: str,
        top_k: int = TOP_K_RESULTS,
        min_similarity: float = SIMILARITY_THRESHOLD,
    ) -> List[Tuple[str, str, float]]:
        """Search for messages similar to the query using embeddings.

        Only non-bot messages that have a stored ``<message_id>_response`` row
        are returned.

        Args:
            query: Text to embed and compare against stored messages.
            top_k: Maximum number of results to return.
            min_similarity: Minimum cosine similarity a message must reach.

        Returns:
            Up to ``top_k`` tuples of (message content, response content,
            similarity), best match first. Empty when the query could not be
            embedded.
        """
        query_embedding = llama_wrapper.embedding(
            text=query,
            model=EMBEDDING_MODEL,
            openai_url=EMBEDDING_ENDPOINT,
            openai_api_key=EMBEDDING_ENDPOINT_KEY,
        )
        if not query_embedding:
            return []
        query_vector = np.array(query_embedding, dtype=np.float32)

        results: list[tuple[str, str, float]] = []
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            # Join chat_messages and message_embeddings to get content and embeddings
            cursor.execute(
                """
                SELECT cm.message_id, cm.content, me.embedding
                FROM chat_messages cm
                JOIN message_embeddings me ON cm.message_id = me.message_id
                WHERE cm.username != 'vibe-bot'
                """
            )
            rows = cursor.fetchall()
            for message_id, content, embedding_blob in rows:
                embedding_vector = self._bytes_to_vector(embedding_blob)
                if embedding_vector.shape != query_vector.shape:
                    # A stored vector of another dimension cannot be compared
                    # with the query; np.dot would raise on it.
                    logger.warning(
                        f"Skipping message {message_id}: embedding dimension mismatch"
                    )
                    continue
                similarity = self._calculate_similarity(query_vector, embedding_vector)
                if similarity >= min_similarity:
                    cursor.execute(
                        """
                        SELECT content
                        FROM chat_messages
                        WHERE message_id = ?
                        ORDER BY timestamp DESC
                        """,
                        (f"{message_id}_response",),
                    )
                    response_row = cursor.fetchone()
                    if response_row is None:
                        # No stored reply for this message: skip it instead of
                        # failing on a None row.
                        logger.debug(f"No response stored for message {message_id}")
                        continue
                    results.append((content, response_row[0], similarity))
        finally:
            conn.close()

        # Sort by similarity and return top results
        results.sort(key=lambda x: x[2], reverse=True)
        return results[:top_k]

    def get_user_history(self, user_id: str, limit: int = 20) -> list[tuple[str, str]]:
        """Get message history for a specific user.

        Args:
            user_id: Matched against the ``user_id`` column of stored messages.
            limit: Maximum number of the user's messages to inspect. Messages
                without a stored response are skipped, so fewer pairs may be
                returned.

        Returns:
            (user message, bot response) pairs, newest first.
        """
        conversations: list[tuple[str, str]] = []
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            logger.info(f"Fetching last {limit} user messages")
            cursor.execute(
                """
                SELECT message_id, content, timestamp
                FROM chat_messages
                WHERE user_id = ? AND username != 'vibe-bot'
                ORDER BY timestamp DESC
                LIMIT ?
                """,
                (user_id, limit),
            )
            messages = cursor.fetchall()

            # Format is [user message, bot response]
            for message in messages:
                msg_content: str = message[1]
                logger.info(f"Finding response for {msg_content[:50]}")
                cursor.execute(
                    """
                    SELECT content
                    FROM chat_messages
                    WHERE message_id = ?
                    ORDER BY timestamp DESC
                    """,
                    (f"{message[0]}_response",),
                )
                response_row = cursor.fetchone()
                if response_row:
                    logger.info(f"Found response: {response_row[0][:50]}")
                    conversations.append((msg_content, response_row[0]))
                else:
                    logger.info("No response found")
        finally:
            conn.close()
        return conversations

    def get_conversation_context(
        self, user_id: str, current_message: str, max_context: int = 5
    ) -> str:
        """Get relevant conversation context for RAG.

        Combines the user's recent exchanges with exchanges whose user message
        is similar to ``current_message``.

        Args:
            user_id: User whose recent history is included.
            current_message: Text used for the similarity search.
            max_context: Number of similar exchanges requested; twice as many
                recent messages are inspected, and at most ``max_context * 4``
                entries are kept overall.

        Returns:
            Newline-joined context entries, newest context last.
        """
        # Get recent messages from the user
        recent_messages = self.get_user_history(user_id, limit=max_context * 2)

        # Search for similar messages
        similar_messages = self.search_similar_messages(
            current_message, top_k=max_context
        )

        # Combine contexts
        context_parts = []

        # Add recent messages
        for user_message, bot_message in recent_messages:
            combined_content = f"[Recent chat]\n{user_message}\n{bot_message}"
            context_parts.append(combined_content)

        # Add similar messages that are not already part of the context
        for user_message, bot_message, similarity in similar_messages:
            combined_content = f"{user_message}\n{bot_message}"
            if combined_content not in "\n".join(context_parts):
                context_parts.append(f"[You remember]\n{combined_content}")

        # Conversation history needs to be delivered in "newest context last" order
        context_parts.reverse()
        return "\n".join(context_parts[-max_context * 4 :])  # Limit total context

    def clear_all_messages(self):
        """Clear all messages and embeddings from the database."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            # Embeddings reference messages, so remove them first.
            cursor.execute("DELETE FROM message_embeddings")
            cursor.execute("DELETE FROM chat_messages")
            conn.commit()
        finally:
            conn.close()
# Process-wide ChatDatabase singleton, built lazily by get_database()
_chat_db: Optional[ChatDatabase] = None


def get_database() -> ChatDatabase:
    """Return the shared ChatDatabase, constructing it on the first call."""
    global _chat_db
    if _chat_db is not None:
        return _chat_db
    _chat_db = ChatDatabase()
    return _chat_db
class CustomBotManager:
    """Manages custom bot configurations stored in SQLite database.

    Bot names are lower-cased before every insert and lookup, so they are
    effectively case-insensitive. Each method opens its own short-lived
    connection to ``db_path``.
    """

    def __init__(self, db_path: str = DB_PATH):
        """Remember the database path and make sure the table exists.

        Args:
            db_path: Location of the SQLite database file.
        """
        self.db_path = db_path
        self._initialize_custom_bots_table()

    def _initialize_custom_bots_table(self):
        """Initialize the custom bots table in SQLite."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            # Create table to hold custom bots
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS custom_bots (
                    bot_name TEXT PRIMARY KEY,
                    system_prompt TEXT NOT NULL,
                    created_by TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    is_active INTEGER DEFAULT 1
                )
                """
            )
            conn.commit()
        finally:
            conn.close()

    def create_custom_bot(
        self, bot_name: str, system_prompt: str, created_by: str
    ) -> bool:
        """Create a new custom bot configuration, replacing any bot of the same name.

        Returns:
            True when the bot was stored, False when an error occurred.
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        try:
            cursor.execute(
                """
                INSERT OR REPLACE INTO custom_bots
                (bot_name, system_prompt, created_by, is_active)
                VALUES (?, ?, ?, 1)
                """,
                (bot_name.lower(), system_prompt, created_by),
            )
            conn.commit()
            return True
        except Exception as e:
            logger.exception(f"Error creating custom bot: {e}")
            conn.rollback()
            return False
        finally:
            conn.close()

    def get_custom_bot(self, bot_name: str) -> Optional[Tuple[str, str, str, datetime]]:
        """Get an active custom bot configuration by name.

        Returns:
            (bot_name, system_prompt, created_by, created_at), or None when no
            active bot has that name.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT bot_name, system_prompt, created_by, created_at
                FROM custom_bots
                WHERE bot_name = ? AND is_active = 1
                """,
                (bot_name.lower(),),
            )
            return cursor.fetchone()
        finally:
            conn.close()

    def list_custom_bots(
        self, user_id: Optional[str] = None
    ) -> List[Tuple[str, str, str]]:
        """List all active custom bots, optionally filtered by creator.

        Args:
            user_id: When given, only bots whose ``created_by`` equals this
                value are returned.

        Returns:
            (bot_name, system_prompt, created_by) rows, newest first. The
            third column is the raw ``created_by`` value; no table mapping it
            to a display name is created in this module.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            if user_id:
                cursor.execute(
                    """
                    SELECT bot_name, system_prompt, created_by
                    FROM custom_bots
                    WHERE is_active = 1 AND created_by = ?
                    ORDER BY created_at DESC
                    """,
                    (user_id,),
                )
            else:
                cursor.execute(
                    """
                    SELECT bot_name, system_prompt, created_by
                    FROM custom_bots
                    WHERE is_active = 1
                    ORDER BY created_at DESC
                    """
                )
            return cursor.fetchall()
        finally:
            conn.close()

    def delete_custom_bot(self, bot_name: str) -> bool:
        """Delete a custom bot configuration.

        Returns:
            True when a row was removed, False when no bot had that name or an
            error occurred.
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        try:
            cursor.execute(
                """
                DELETE FROM custom_bots
                WHERE bot_name = ?
                """,
                (bot_name.lower(),),
            )
            conn.commit()
            return cursor.rowcount > 0
        except Exception as e:
            logger.exception(f"Error deleting custom bot: {e}")
            conn.rollback()
            return False
        finally:
            conn.close()

    def deactivate_custom_bot(self, bot_name: str) -> bool:
        """Deactivate a custom bot (soft delete).

        Returns:
            True when a row was updated, False when no bot had that name or an
            error occurred.
        """
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        try:
            cursor.execute(
                """
                UPDATE custom_bots
                SET is_active = 0
                WHERE bot_name = ?
                """,
                (bot_name.lower(),),
            )
            conn.commit()
            return cursor.rowcount > 0
        except Exception as e:
            logger.exception(f"Error deactivating custom bot: {e}")
            conn.rollback()
            return False
        finally:
            conn.close()