679 lines
21 KiB
Python
679 lines
21 KiB
Python
"""SQLite database with RAG support for chat history and embeddings."""
|
|
|
|
from __future__ import annotations

import logging
import sqlite3
from contextlib import closing
from typing import TYPE_CHECKING

import numpy as np
from openai import OpenAI

from vibe_bot import llama_wrapper
from vibe_bot.config import (
    DB_PATH,
    EMBEDDING_ENDPOINT,
    EMBEDDING_ENDPOINT_KEY,
    EMBEDDING_MODEL,
    MAX_HISTORY_MESSAGES,
    SIMILARITY_THRESHOLD,
    TOP_K_RESULTS,
)

if TYPE_CHECKING:
    from datetime import datetime
|
|
|
|
# Configure root logging at import time (logging.basicConfig is a no-op when
# the root logger already has handlers). The module logger below propagates
# to the root logger.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ChatDatabase:
    """SQLite database with RAG support for storing chat history.

    Messages are stored in ``chat_messages`` and their embeddings (float32
    vectors serialized with ``tobytes``) in ``message_embeddings``, keyed by
    ``message_id``. A bot reply to a message is looked up as the row whose
    ``message_id`` is the original id followed by ``"_response"``.

    Every method opens its own short-lived connection and closes it before
    returning, including when an exception is raised.

    NOTE(review): connections are opened without ``detect_types``, so
    TIMESTAMP columns come back as the text SQLite stored, even where the
    return annotations say ``datetime``.
    """

    def __init__(self, db_path: str = DB_PATH) -> None:
        """Initialize the database wrapper and ensure the schema exists.

        Args:
            db_path: Path to the SQLite database file.

        """
        logger.info("Initializing ChatDatabase with path: %s", db_path)
        self.db_path = db_path
        # NOTE(review): no method of this class uses ``self.client``;
        # embeddings are requested through ``llama_wrapper.embedding``.
        # Kept because external code may access the attribute.
        self.client = OpenAI(
            base_url=EMBEDDING_ENDPOINT,
            api_key=EMBEDDING_ENDPOINT_KEY,
        )
        logger.info("Connecting to OpenAI API for embeddings")
        self._initialize_database()

    def _connect(self) -> closing[sqlite3.Connection]:
        """Open a connection to ``db_path`` that is closed when the ``with`` exits.

        ``sqlite3.Connection``'s own context manager only commits or rolls
        back and never closes, hence the ``closing`` wrapper.
        """
        return closing(sqlite3.connect(self.db_path))

    def _initialize_database(self) -> None:
        """Create tables and indexes if missing and apply schema migrations."""
        logger.info("Initializing SQLite database at %s", self.db_path)
        with self._connect() as conn:
            cursor = conn.cursor()

            # Create messages table
            logger.info("Creating chat_messages table if not exists")
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS chat_messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    message_id TEXT UNIQUE,
                    user_id TEXT,
                    username TEXT,
                    content TEXT,
                    bot_name TEXT,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    channel_id TEXT,
                    guild_id TEXT
                )
                """,
            )
            logger.info("chat_messages table initialized successfully")

            # Databases created before bot_name was introduced lack the
            # column; CREATE TABLE IF NOT EXISTS does not add it for them.
            logger.info("Checking for bot_name column migration")
            cursor.execute("PRAGMA table_info(chat_messages)")
            columns = {row[1] for row in cursor.fetchall()}
            if "bot_name" not in columns:
                logger.info("Adding bot_name column to chat_messages table")
                cursor.execute(
                    "ALTER TABLE chat_messages ADD COLUMN bot_name TEXT",
                )
                logger.info("bot_name column added successfully")

            # Create embeddings table for RAG
            logger.info("Creating message_embeddings table if not exists")
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS message_embeddings (
                    message_id TEXT PRIMARY KEY,
                    embedding BLOB,
                    FOREIGN KEY (message_id) REFERENCES chat_messages(message_id)
                )
                """,
            )
            logger.info("message_embeddings table initialized successfully")

            # Drop embeddings whose message row no longer exists. The previous
            # trimming logic could leave such orphans behind; they are never
            # read (search joins on chat_messages) and only waste space.
            cursor.execute(
                """
                DELETE FROM message_embeddings
                WHERE NOT EXISTS (
                    SELECT 1 FROM chat_messages cm
                    WHERE cm.message_id = message_embeddings.message_id
                )
                """,
            )
            if cursor.rowcount > 0:
                logger.info("Removed %d orphaned embeddings", cursor.rowcount)

            # Create indexes for faster lookups
            logger.info("Creating idx_timestamp index if not exists")
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_timestamp ON chat_messages(timestamp)
                """,
            )
            logger.info("idx_timestamp index created successfully")

            logger.info("Creating idx_user_id index if not exists")
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_user_id ON chat_messages(user_id)
                """,
            )
            logger.info("idx_user_id index created successfully")

            conn.commit()
            logger.info("Database initialization completed successfully")

    def _vector_to_bytes(self, vector: list[float]) -> bytes:
        """Serialize a vector as raw float32 bytes for BLOB storage."""
        logger.debug("Converting vector (length: %d) to bytes", len(vector))
        result = np.array(vector, dtype=np.float32).tobytes()
        logger.debug("Vector converted to %d bytes", len(result))
        return result

    def _bytes_to_vector(self, blob: bytes) -> np.ndarray:
        """Deserialize raw float32 bytes (see ``_vector_to_bytes``) into a 1-D array."""
        logger.debug("Converting %d bytes back to vector", len(blob))
        result = np.frombuffer(blob, dtype=np.float32)
        logger.debug("Vector reconstructed with %d dimensions", len(result))
        return result

    def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Return the cosine similarity of two equal-length vectors.

        Both inputs are flattened first. Returns 0.0 if either has zero norm.
        """
        vec1 = vec1.flatten()
        vec2 = vec2.flatten()
        logger.debug(
            "Calculating cosine similarity between vectors of dimension %d",
            len(vec1),
        )
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        result = float(np.dot(vec1, vec2) / (norm1 * norm2))
        logger.debug("Similarity calculated: %.4f", result)
        return result

    def _embed(self, text: str) -> list[float] | None:
        """Request an embedding of *text* via ``llama_wrapper.embedding``.

        NOTE(review): the return type is presumed from how the result is used
        (``_vector_to_bytes`` takes ``list[float]``); callers treat ``None``
        or an empty result as "no embedding available".
        """
        return llama_wrapper.embedding(
            text=text,
            model=EMBEDDING_MODEL,
            openai_url=EMBEDDING_ENDPOINT,
            openai_api_key=EMBEDDING_ENDPOINT_KEY,
        )

    def add_message(
        self,
        *,
        message_id: str,
        user_id: str,
        username: str,
        content: str,
        bot_name: str | None = None,
        channel_id: str | None = None,
        guild_id: str | None = None,
    ) -> bool:
        """Add a message to the database and store its embedding.

        The embedding is requested before the database is touched, so SQLite's
        write lock is not held during the network round-trip. If no embedding
        comes back, the message is still stored without one (and therefore
        cannot be found by ``search_similar_messages``). After the insert,
        history is trimmed to ``MAX_HISTORY_MESSAGES`` rows.

        Returns:
            True if the message was committed. False if any step (including
            the embedding request or opening the database) raised; the error
            is logged and nothing from this call is committed.

        """
        logger.info("Adding message %s from user %s", message_id, username)
        try:
            logger.info("Generating embedding for message %s", message_id)
            embedding = self._embed(content)

            # The outer ``closing`` closes the connection; the inner ``conn``
            # context manager commits on success or rolls back on error.
            with self._connect() as conn, conn:
                cursor = conn.cursor()

                logger.debug(
                    "Inserting message into chat_messages table: message_id=%s",
                    message_id,
                )
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO chat_messages
                    (message_id, user_id, username, content, bot_name, channel_id, guild_id)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        message_id,
                        user_id,
                        username,
                        content,
                        bot_name,
                        channel_id,
                        guild_id,
                    ),
                )
                logger.debug("Message %s inserted into chat_messages table", message_id)

                if embedding is not None and len(embedding) > 0:
                    logger.debug(
                        "Embedding generated successfully for message %s, "
                        "storing in database",
                        message_id,
                    )
                    cursor.execute(
                        """
                        INSERT OR REPLACE INTO message_embeddings
                        (message_id, embedding)
                        VALUES (?, ?)
                        """,
                        (message_id, self._vector_to_bytes(embedding)),
                    )
                    logger.debug(
                        "Embedding stored in message_embeddings table for message %s",
                        message_id,
                    )
                else:
                    logger.warning(
                        "Failed to generate embedding for message %s, "
                        "skipping embedding storage",
                        message_id,
                    )

                logger.info("Checking if cleanup of old messages is needed")
                self._cleanup_old_messages(cursor)
        except Exception:
            logger.exception("Error adding message %s", message_id)
            return False

        logger.info("Successfully added message %s to database", message_id)
        return True

    def _cleanup_old_messages(self, cursor: sqlite3.Cursor) -> None:
        """Trim ``chat_messages`` to the newest MAX_HISTORY_MESSAGES rows.

        Uses the caller's cursor, so the deletes are part of the caller's
        transaction. The embeddings of the trimmed messages are deleted first,
        while those messages can still be selected. Ordering by ``id`` after
        ``timestamp`` makes both statements pick exactly the same rows, even
        when several messages share a (second-resolution) timestamp.
        """
        cursor.execute("SELECT COUNT(*) FROM chat_messages")
        excess = cursor.fetchone()[0] - MAX_HISTORY_MESSAGES
        if excess <= 0:
            return

        logger.info("Removing %d messages beyond the history limit", excess)
        cursor.execute(
            """
            DELETE FROM message_embeddings
            WHERE message_id IN (
                SELECT message_id FROM chat_messages
                ORDER BY timestamp ASC, id ASC
                LIMIT ?
            )
            """,
            (excess,),
        )
        cursor.execute(
            """
            DELETE FROM chat_messages
            WHERE id IN (
                SELECT id FROM chat_messages
                ORDER BY timestamp ASC, id ASC
                LIMIT ?
            )
            """,
            (excess,),
        )

    def get_recent_messages(
        self,
        limit: int = 10,
    ) -> list[tuple[str, str, str, datetime]]:
        """Return up to *limit* ``(message_id, username, content, timestamp)`` rows, newest first."""
        with self._connect() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT message_id, username, content, timestamp
                FROM chat_messages
                ORDER BY timestamp DESC
                LIMIT ?
                """,
                (limit,),
            )
            return cursor.fetchall()

    def _find_response(self, cursor: sqlite3.Cursor, message_id: str) -> str | None:
        """Return the stored reply to *message_id*, or None if there is none.

        ``message_id`` is UNIQUE, so at most one reply row can match.
        """
        cursor.execute(
            """
            SELECT content
            FROM chat_messages
            WHERE message_id = ?
            """,
            (f"{message_id}_response",),
        )
        row = cursor.fetchone()
        return None if row is None else row[0]

    def _pair_with_responses(
        self,
        cursor: sqlite3.Cursor,
        rows: list[tuple[str, str, str]],
    ) -> list[tuple[str, str]]:
        """Pair ``(message_id, content, ...)`` rows with their replies.

        Rows without a reply are dropped; input order is preserved.
        """
        conversations: list[tuple[str, str]] = []
        for message_id, msg_content, *_rest in rows:
            logger.debug("Finding response for %s...", msg_content[:50])
            response = self._find_response(cursor, message_id)
            if response is None:
                logger.debug("No response found")
                continue
            logger.debug("Found response: %s...", response[:50])
            conversations.append((msg_content, response))
        return conversations

    def search_similar_messages(
        self,
        query: str,
        top_k: int = TOP_K_RESULTS,
        min_similarity: float = SIMILARITY_THRESHOLD,
    ) -> list[tuple[str, str, float]]:
        """Find stored exchanges whose message is similar to *query*.

        Compares the query embedding against every stored embedding of a
        message not authored by ``'vibe-bot'``. Stored embeddings with a
        different dimension (e.g. produced by another embedding model) are
        skipped rather than crashing the search.

        Returns:
            Up to *top_k* ``(message, reply, similarity)`` tuples with
            ``similarity >= min_similarity``, most similar first. Messages
            without a stored reply are omitted. Returns [] if no query
            embedding could be generated.

        """
        query_embedding = self._embed(query)
        if query_embedding is None or len(query_embedding) == 0:
            return []

        query_vector = np.array(query_embedding, dtype=np.float32)

        results: list[tuple[str, str, float]] = []
        with self._connect() as conn:
            cursor = conn.cursor()
            # Join chat_messages and message_embeddings to get content and embeddings
            cursor.execute(
                """
                SELECT cm.message_id, cm.content, me.embedding
                FROM chat_messages cm
                JOIN message_embeddings me ON cm.message_id = me.message_id
                WHERE cm.username != 'vibe-bot'
                """,
            )
            rows = cursor.fetchall()

            for message_id, content, embedding_blob in rows:
                embedding_vector = self._bytes_to_vector(embedding_blob)
                if embedding_vector.size != query_vector.size:
                    logger.debug(
                        "Skipping message %s: embedding dimension %d != %d",
                        message_id,
                        embedding_vector.size,
                        query_vector.size,
                    )
                    continue

                similarity = self._calculate_similarity(query_vector, embedding_vector)
                if similarity < min_similarity:
                    continue

                response = self._find_response(cursor, message_id)
                if response is not None:
                    results.append((content, response, similarity))

        # Sort by similarity and return top results (none for top_k <= 0).
        results.sort(key=lambda item: item[2], reverse=True)
        return results[: max(top_k, 0)]

    def get_bot_history(self, bot_name: str, limit: int = 20) -> list[tuple[str, str]]:
        """Get message history for a specific custom bot.

        Args:
            bot_name: The name of the custom bot.
            limit: Maximum number of (non-reply) messages to consider.

        Returns:
            List of (user_message, bot_response) tuples, newest first.
            Messages without a stored reply are omitted.

        """
        logger.info(
            "Fetching last %d messages for bot %r",
            limit,
            bot_name,
        )
        with self._connect() as conn:
            cursor = conn.cursor()
            # '_' is a LIKE wildcard, so it is escaped to match only the
            # literal "_response" suffix used for reply rows.
            cursor.execute(
                """
                SELECT message_id, content, timestamp
                FROM chat_messages
                WHERE bot_name = ? AND message_id NOT LIKE '%!_response' ESCAPE '!'
                ORDER BY timestamp DESC
                LIMIT ?
                """,
                (bot_name, limit),
            )
            return self._pair_with_responses(cursor, cursor.fetchall())

    def get_user_history(self, user_id: str, limit: int = 20) -> list[tuple[str, str]]:
        """Get message history for a specific user.

        Considers the user's *limit* most recent messages not authored by
        ``'vibe-bot'``.

        Returns:
            List of (user_message, bot_response) tuples, newest first.
            Messages without a stored reply are omitted.

        """
        logger.info("Fetching last %d user messages", limit)
        with self._connect() as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT message_id, content, timestamp
                FROM chat_messages
                WHERE user_id = ? AND username != 'vibe-bot'
                ORDER BY timestamp DESC
                LIMIT ?
                """,
                (user_id, limit),
            )
            return self._pair_with_responses(cursor, cursor.fetchall())

    def get_conversation_context(
        self,
        user_id: str,
        current_message: str,
        max_context: int = 5,
    ) -> list[dict[str, str]]:
        """Build chat-completion context for RAG.

        Combines up to ``2 * max_context`` recent exchanges from the user with
        up to ``max_context`` exchanges similar to *current_message*. Similar
        exchanges that already appear in the recent history are not repeated.

        Returns:
            ``{"role": ..., "content": ...}`` dicts with each exchange as a
            user turn followed by its assistant turn, ordered with the newest
            context last: similar exchanges (least similar first), then the
            recent history from oldest to newest.

        """
        recent_messages = self.get_user_history(user_id, limit=max_context * 2)
        similar_messages = self.search_similar_messages(
            current_message,
            top_k=max_context,
        )

        # Newest/most relevant first: recent history, then similar exchanges.
        exchanges: list[tuple[str, str]] = list(recent_messages)
        seen = set(recent_messages)
        for user_message, bot_message, _similarity in similar_messages:
            pair = (user_message, bot_message)
            if pair in seen:
                continue
            seen.add(pair)
            exchanges.append(pair)

        # Conversation history needs to be delivered in "newest context last" order
        context_parts: list[dict[str, str]] = []
        for user_message, bot_message in reversed(exchanges):
            context_parts.append({"role": "user", "content": user_message})
            context_parts.append({"role": "assistant", "content": bot_message})
        return context_parts

    def clear_all_messages(self) -> None:
        """Delete every message and embedding in a single transaction."""
        with self._connect() as conn, conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM message_embeddings")
            cursor.execute("DELETE FROM chat_messages")
|
|
|
|
|
|
# Lazily created, module-wide ChatDatabase shared by all get_database() callers.
_chat_db: ChatDatabase | None = None


def get_database() -> ChatDatabase:
    """Return the module-wide ChatDatabase, constructing it on first use.

    NOTE(review): creation is not guarded by a lock; concurrent first calls
    from several threads could each construct an instance.
    """
    global _chat_db  # noqa: PLW0603
    if _chat_db is not None:
        return _chat_db
    _chat_db = ChatDatabase()
    return _chat_db
|
|
|
|
|
|
class CustomBotManager:
    """Manages custom bot configurations stored in SQLite database.

    Bot names are matched case-insensitively: every method lower-cases
    ``bot_name`` before using it as the ``custom_bots`` primary key. Each
    method opens its own connection and closes it before returning, even
    when an exception is raised.
    """

    def __init__(self, db_path: str = DB_PATH) -> None:
        """Initialize the custom bot manager and ensure its table exists.

        Args:
            db_path: Path to the SQLite database file.

        """
        self.db_path = db_path
        self._initialize_custom_bots_table()

    def _connect(self) -> closing[sqlite3.Connection]:
        """Open a connection to ``db_path`` that is closed when the ``with`` exits.

        ``sqlite3.Connection``'s own context manager only commits or rolls
        back and never closes, hence the ``closing`` wrapper.
        """
        return closing(sqlite3.connect(self.db_path))

    def _initialize_custom_bots_table(self) -> None:
        """Create the ``custom_bots`` table if it does not exist yet."""
        with self._connect() as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS custom_bots (
                    bot_name TEXT PRIMARY KEY,
                    system_prompt TEXT NOT NULL,
                    created_by TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    is_active INTEGER DEFAULT 1
                )
                """,
            )
            conn.commit()

    def create_custom_bot(
        self,
        bot_name: str,
        system_prompt: str,
        created_by: str,
    ) -> bool:
        """Create a custom bot, or overwrite an existing one with the same name.

        The row is written with INSERT OR REPLACE, so an existing bot with the
        same (lower-cased) name is replaced wholesale: its prompt and creator
        are overwritten, ``created_at`` is reset to the current time, and the
        bot is (re)activated.

        Returns:
            True on success; False if anything failed (the error is logged and
            the write rolled back).

        """
        try:
            # Inner ``conn`` context manager: commit on success, rollback on error.
            with self._connect() as conn, conn:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO custom_bots
                    (bot_name, system_prompt, created_by, is_active)
                    VALUES (?, ?, ?, 1)
                    """,
                    (bot_name.lower(), system_prompt, created_by),
                )
        except Exception:
            logger.exception("Error creating custom bot")
            return False
        return True

    def get_custom_bot(self, bot_name: str) -> tuple[str, str, str, datetime] | None:
        """Return ``(bot_name, system_prompt, created_by, created_at)`` for an active bot.

        Returns None if no active bot has that (lower-cased) name.

        NOTE(review): ``created_at`` is returned as the text SQLite stored (the
        connection does not enable ``detect_types``), despite the annotation.
        """
        with self._connect() as conn:
            row = conn.execute(
                """
                SELECT bot_name, system_prompt, created_by, created_at
                FROM custom_bots
                WHERE bot_name = ? AND is_active = 1
                """,
                (bot_name.lower(),),
            ).fetchone()

        if row is None:
            return None
        name, prompt, creator, created_at = row
        return (name, prompt, creator, created_at)

    def list_custom_bots(
        self,
        user_id: str | None = None,
    ) -> list[tuple[str, str, str]]:
        """List active bots as ``(bot_name, system_prompt, created_by)``, newest first.

        Args:
            user_id: If truthy, only bots whose ``created_by`` equals it.

        """
        with self._connect() as conn:
            if user_id:
                cursor = conn.execute(
                    """
                    SELECT bot_name, system_prompt, created_by
                    FROM custom_bots
                    WHERE is_active = 1 AND created_by = ?
                    ORDER BY created_at DESC
                    """,
                    (user_id,),
                )
            else:
                cursor = conn.execute(
                    """
                    SELECT bot_name, system_prompt, created_by
                    FROM custom_bots
                    WHERE is_active = 1
                    ORDER BY created_at DESC
                    """,
                )
            return cursor.fetchall()

    def delete_custom_bot(self, bot_name: str) -> bool:
        """Permanently delete a custom bot configuration.

        Returns:
            True if a row was deleted; False if no bot had that name or the
            delete failed (the error is logged and rolled back).

        """
        try:
            with self._connect() as conn, conn:
                cursor = conn.execute(
                    """
                    DELETE FROM custom_bots
                    WHERE bot_name = ?
                    """,
                    (bot_name.lower(),),
                )
        except Exception:
            logger.exception("Error deleting custom bot")
            return False
        return cursor.rowcount > 0

    def deactivate_custom_bot(self, bot_name: str) -> bool:
        """Deactivate a custom bot (soft delete) by setting ``is_active = 0``.

        Returns:
            True if a bot with that name exists (whether or not it was already
            inactive); False if none matched or the update failed (the error
            is logged and rolled back).

        """
        try:
            with self._connect() as conn, conn:
                cursor = conn.execute(
                    """
                    UPDATE custom_bots
                    SET is_active = 0
                    WHERE bot_name = ?
                    """,
                    (bot_name.lower(),),
                )
        except Exception:
            logger.exception("Error deactivating custom bot")
            return False
        return cursor.rowcount > 0
|