fix linting, formatting, and add tests
This commit is contained in:
+138
-95
@@ -1,43 +1,60 @@
|
||||
"""SQLite database with RAG support for chat history and embeddings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from typing import Optional, List, Tuple
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
from openai import OpenAI
|
||||
import logging
|
||||
|
||||
import llama_wrapper # type: ignore
|
||||
from config import ( # type: ignore
|
||||
from vibe_bot import llama_wrapper
|
||||
from vibe_bot.config import (
|
||||
DB_PATH,
|
||||
EMBEDDING_MODEL,
|
||||
EMBEDDING_ENDPOINT,
|
||||
EMBEDDING_ENDPOINT_KEY,
|
||||
EMBEDDING_MODEL,
|
||||
MAX_HISTORY_MESSAGES,
|
||||
SIMILARITY_THRESHOLD,
|
||||
TOP_K_RESULTS,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatDatabase:
|
||||
"""SQLite database with RAG support for storing chat history using OpenAI embeddings."""
|
||||
"""SQLite database with RAG support for storing chat history
|
||||
using OpenAI embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str = DB_PATH):
|
||||
logger.info(f"Initializing ChatDatabase with path: {db_path}")
|
||||
def __init__(self, db_path: str = DB_PATH) -> None:
|
||||
"""Initialize the database connection.
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database file.
|
||||
|
||||
"""
|
||||
logger.info("Initializing ChatDatabase with path: %s", db_path)
|
||||
self.db_path = db_path
|
||||
self.client = OpenAI(
|
||||
base_url=EMBEDDING_ENDPOINT, api_key=EMBEDDING_ENDPOINT_KEY
|
||||
base_url=EMBEDDING_ENDPOINT,
|
||||
api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
)
|
||||
logger.info("Connecting to OpenAI API for embeddings")
|
||||
self._initialize_database()
|
||||
|
||||
def _initialize_database(self):
|
||||
def _initialize_database(self) -> None:
|
||||
"""Initialize the SQLite database with required tables."""
|
||||
logger.info(f"Initializing SQLite database at {self.db_path}")
|
||||
logger.info("Initializing SQLite database at %s", self.db_path)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
@@ -55,7 +72,7 @@ class ChatDatabase:
|
||||
channel_id TEXT,
|
||||
guild_id TEXT
|
||||
)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
logger.info("chat_messages table initialized successfully")
|
||||
|
||||
@@ -68,7 +85,7 @@ class ChatDatabase:
|
||||
embedding BLOB,
|
||||
FOREIGN KEY (message_id) REFERENCES chat_messages(message_id)
|
||||
)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
logger.info("message_embeddings table initialized successfully")
|
||||
|
||||
@@ -77,7 +94,7 @@ class ChatDatabase:
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_timestamp ON chat_messages(timestamp)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
logger.info("idx_timestamp index created successfully")
|
||||
|
||||
@@ -85,7 +102,7 @@ class ChatDatabase:
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_user_id ON chat_messages(user_id)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
logger.info("idx_user_id index created successfully")
|
||||
|
||||
@@ -93,60 +110,65 @@ class ChatDatabase:
|
||||
logger.info("Database initialization completed successfully")
|
||||
conn.close()
|
||||
|
||||
def _vector_to_bytes(self, vector: List[float]) -> bytes:
|
||||
def _vector_to_bytes(self, vector: list[float]) -> bytes:
|
||||
"""Convert vector to bytes for SQLite storage."""
|
||||
logger.debug(f"Converting vector (length: {len(vector)}) to bytes")
|
||||
logger.debug("Converting vector (length: %d) to bytes", len(vector))
|
||||
result = np.array(vector, dtype=np.float32).tobytes()
|
||||
logger.debug(f"Vector converted to {len(result)} bytes")
|
||||
logger.debug("Vector converted to %d bytes", len(result))
|
||||
return result
|
||||
|
||||
def _bytes_to_vector(self, blob: bytes) -> np.ndarray:
|
||||
"""Convert bytes back to vector."""
|
||||
logger.debug(f"Converting {len(blob)} bytes back to vector")
|
||||
logger.debug("Converting %d bytes back to vector", len(blob))
|
||||
result = np.frombuffer(blob, dtype=np.float32)
|
||||
logger.debug(f"Vector reconstructed with {len(result)} dimensions")
|
||||
logger.debug("Vector reconstructed with %d dimensions", len(result))
|
||||
return result
|
||||
|
||||
def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
|
||||
"""Calculate cosine similarity between two vectors."""
|
||||
logger.debug(
|
||||
f"Calculating cosine similarity between vectors of dimension {len(vec1)}"
|
||||
"Calculating cosine similarity between vectors of dimension %d",
|
||||
len(vec1),
|
||||
)
|
||||
result = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
||||
logger.debug(f"Similarity calculated: {result:.4f}")
|
||||
result = float(
|
||||
np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)),
|
||||
)
|
||||
logger.debug("Similarity calculated: %.4f", result)
|
||||
return result
|
||||
|
||||
def add_message(
|
||||
self,
|
||||
*,
|
||||
message_id: str,
|
||||
user_id: str,
|
||||
username: str,
|
||||
content: str,
|
||||
channel_id: Optional[str] = None,
|
||||
guild_id: Optional[str] = None,
|
||||
channel_id: str | None = None,
|
||||
guild_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Add a message to the database and generate its embedding."""
|
||||
logger.info(f"Adding message {message_id} from user {username}")
|
||||
logger.info("Adding message %s from user %s", message_id, username)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
# Insert message
|
||||
logger.debug(
|
||||
f"Inserting message into chat_messages table: message_id={message_id}"
|
||||
"Inserting message into chat_messages table: message_id=%s",
|
||||
message_id,
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO chat_messages
|
||||
INSERT OR REPLACE INTO chat_messages
|
||||
(message_id, user_id, username, content, channel_id, guild_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(message_id, user_id, username, content, channel_id, guild_id),
|
||||
)
|
||||
logger.debug(f"Message {message_id} inserted into chat_messages table")
|
||||
logger.debug("Message %s inserted into chat_messages table", message_id)
|
||||
|
||||
# Generate and store embedding
|
||||
logger.info(f"Generating embedding for message {message_id}")
|
||||
logger.info("Generating embedding for message %s", message_id)
|
||||
embedding = llama_wrapper.embedding(
|
||||
content,
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
@@ -155,22 +177,27 @@ class ChatDatabase:
|
||||
)
|
||||
if embedding:
|
||||
logger.debug(
|
||||
f"Embedding generated successfully for message {message_id}, storing in database"
|
||||
"Embedding generated successfully for message %s, "
|
||||
"storing in database",
|
||||
message_id,
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO message_embeddings
|
||||
INSERT OR REPLACE INTO message_embeddings
|
||||
(message_id, embedding)
|
||||
VALUES (?, ?)
|
||||
""",
|
||||
(message_id, self._vector_to_bytes(embedding)),
|
||||
)
|
||||
logger.debug(
|
||||
f"Embedding stored in message_embeddings table for message {message_id}"
|
||||
"Embedding stored in message_embeddings table for message %s",
|
||||
message_id,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to generate embedding for message {message_id}, skipping embedding storage"
|
||||
"Failed to generate embedding for message %s, "
|
||||
"skipping embedding storage",
|
||||
message_id,
|
||||
)
|
||||
|
||||
# Clean up old messages if exceeding limit
|
||||
@@ -178,32 +205,32 @@ class ChatDatabase:
|
||||
self._cleanup_old_messages(cursor)
|
||||
|
||||
conn.commit()
|
||||
logger.info(f"Successfully added message {message_id} to database")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding message {message_id}: {e}")
|
||||
except Exception:
|
||||
logger.exception("Error adding message %s", message_id)
|
||||
conn.rollback()
|
||||
return False
|
||||
else:
|
||||
logger.info("Successfully added message %s to database", message_id)
|
||||
return True
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _cleanup_old_messages(self, cursor):
|
||||
def _cleanup_old_messages(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Remove old messages to stay within the limit."""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) FROM chat_messages
|
||||
"""
|
||||
""",
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
|
||||
if count > MAX_HISTORY_MESSAGES:
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM chat_messages
|
||||
DELETE FROM chat_messages
|
||||
WHERE id IN (
|
||||
SELECT id FROM chat_messages
|
||||
ORDER BY timestamp ASC
|
||||
SELECT id FROM chat_messages
|
||||
ORDER BY timestamp ASC
|
||||
LIMIT ?
|
||||
)
|
||||
""",
|
||||
@@ -213,10 +240,10 @@ class ChatDatabase:
|
||||
# Also remove corresponding embeddings
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM message_embeddings
|
||||
DELETE FROM message_embeddings
|
||||
WHERE message_id IN (
|
||||
SELECT message_id FROM chat_messages
|
||||
ORDER BY timestamp ASC
|
||||
SELECT message_id FROM chat_messages
|
||||
ORDER BY timestamp ASC
|
||||
LIMIT ?
|
||||
)
|
||||
""",
|
||||
@@ -224,8 +251,9 @@ class ChatDatabase:
|
||||
)
|
||||
|
||||
def get_recent_messages(
|
||||
self, limit: int = 10
|
||||
) -> List[Tuple[str, str, str, datetime]]:
|
||||
self,
|
||||
limit: int = 10,
|
||||
) -> list[tuple[str, str, str, datetime]]:
|
||||
"""Get recent messages from the database."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
@@ -250,7 +278,7 @@ class ChatDatabase:
|
||||
query: str,
|
||||
top_k: int = TOP_K_RESULTS,
|
||||
min_similarity: float = SIMILARITY_THRESHOLD,
|
||||
) -> List[Tuple[str, str, float]]:
|
||||
) -> list[tuple[str, str, float]]:
|
||||
"""Search for messages similar to the query using embeddings."""
|
||||
query_embedding = llama_wrapper.embedding(
|
||||
text=query,
|
||||
@@ -269,11 +297,11 @@ class ChatDatabase:
|
||||
# Join chat_messages and message_embeddings to get content and embeddings
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT cm.message_id, cm.content, me.embedding
|
||||
SELECT cm.message_id, cm.content, me.embedding
|
||||
FROM chat_messages cm
|
||||
JOIN message_embeddings me ON cm.message_id = me.message_id
|
||||
WHERE cm.username != 'vibe-bot'
|
||||
"""
|
||||
""",
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
@@ -302,12 +330,12 @@ class ChatDatabase:
|
||||
results.sort(key=lambda x: x[2], reverse=True)
|
||||
return results[:top_k]
|
||||
|
||||
def get_user_history(self, user_id: str, limit: int = 20) -> list[tuple[str, str]]:
|
||||
def get_user_history(self, _user_id: str, limit: int = 20) -> list[tuple[str, str]]:
|
||||
"""Get message history for a specific user."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
logger.info(f"Fetching last {limit} user messages")
|
||||
logger.info("Fetching last %d user messages", limit)
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT message_id, content, timestamp
|
||||
@@ -324,8 +352,8 @@ class ChatDatabase:
|
||||
# Format is [user message, bot response]
|
||||
conversations: list[tuple[str, str]] = []
|
||||
for message in messages:
|
||||
msg_content: str = message[1]
|
||||
logger.info(f"Finding response for {msg_content[:50]}")
|
||||
msg_content = message[1]
|
||||
logger.debug("Finding response for %s...", msg_content[:50])
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT content
|
||||
@@ -335,18 +363,21 @@ class ChatDatabase:
|
||||
""",
|
||||
(f"{message[0]}_response",),
|
||||
)
|
||||
response_content: str = cursor.fetchone()
|
||||
if response_content:
|
||||
logger.info(f"Found response: {response_content[0][:50]}")
|
||||
conversations.append((msg_content, response_content[0]))
|
||||
response_row = cursor.fetchone()
|
||||
if response_row:
|
||||
logger.debug("Found response: %s...", response_row[0][:50])
|
||||
conversations.append((msg_content, response_row[0]))
|
||||
else:
|
||||
logger.info("No response found")
|
||||
logger.debug("No response found")
|
||||
conn.close()
|
||||
|
||||
return conversations
|
||||
|
||||
def get_conversation_context(
|
||||
self, user_id: str, current_message: str, max_context: int = 5
|
||||
self,
|
||||
user_id: str,
|
||||
current_message: str,
|
||||
max_context: int = 5,
|
||||
) -> list[dict[str, str]]:
|
||||
"""Get relevant conversation context for RAG."""
|
||||
# Get recent messages from the user
|
||||
@@ -354,7 +385,8 @@ class ChatDatabase:
|
||||
|
||||
# Search for similar messages
|
||||
similar_messages = self.search_similar_messages(
|
||||
current_message, top_k=max_context
|
||||
current_message,
|
||||
top_k=max_context,
|
||||
)
|
||||
|
||||
# Combine contexts
|
||||
@@ -366,7 +398,7 @@ class ChatDatabase:
|
||||
context_parts.append({"role": "user", "content": user_message})
|
||||
|
||||
# Add similar messages
|
||||
for user_message, bot_message, similarity in similar_messages:
|
||||
for user_message, bot_message, _similarity in similar_messages:
|
||||
context_parts.append({"role": "assistant", "content": bot_message})
|
||||
context_parts.append({"role": "user", "content": user_message})
|
||||
|
||||
@@ -374,7 +406,7 @@ class ChatDatabase:
|
||||
context_parts.reverse()
|
||||
return context_parts
|
||||
|
||||
def clear_all_messages(self):
|
||||
def clear_all_messages(self) -> None:
|
||||
"""Clear all messages and embeddings from the database."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
@@ -387,12 +419,12 @@ class ChatDatabase:
|
||||
|
||||
|
||||
# Global database instance
|
||||
_chat_db: Optional[ChatDatabase] = None
|
||||
_chat_db: ChatDatabase | None = None
|
||||
|
||||
|
||||
def get_database() -> ChatDatabase:
|
||||
"""Get or create the global database instance."""
|
||||
global _chat_db
|
||||
global _chat_db # noqa: PLW0603
|
||||
if _chat_db is None:
|
||||
_chat_db = ChatDatabase()
|
||||
return _chat_db
|
||||
@@ -401,11 +433,17 @@ def get_database() -> ChatDatabase:
|
||||
class CustomBotManager:
|
||||
"""Manages custom bot configurations stored in SQLite database."""
|
||||
|
||||
def __init__(self, db_path: str = DB_PATH):
|
||||
def __init__(self, db_path: str = DB_PATH) -> None:
|
||||
"""Initialize the custom bot manager.
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database file.
|
||||
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self._initialize_custom_bots_table()
|
||||
|
||||
def _initialize_custom_bots_table(self):
|
||||
def _initialize_custom_bots_table(self) -> None:
|
||||
"""Initialize the custom bots table in SQLite."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
@@ -420,14 +458,17 @@ class CustomBotManager:
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
is_active INTEGER DEFAULT 1
|
||||
)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def create_custom_bot(
|
||||
self, bot_name: str, system_prompt: str, created_by: str
|
||||
self,
|
||||
bot_name: str,
|
||||
system_prompt: str,
|
||||
created_by: str,
|
||||
) -> bool:
|
||||
"""Create a new custom bot configuration."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
@@ -436,7 +477,7 @@ class CustomBotManager:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO custom_bots
|
||||
INSERT OR REPLACE INTO custom_bots
|
||||
(bot_name, system_prompt, created_by, is_active)
|
||||
VALUES (?, ?, ?, 1)
|
||||
""",
|
||||
@@ -444,16 +485,16 @@ class CustomBotManager:
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating custom bot: {e}")
|
||||
except Exception:
|
||||
logger.exception("Error creating custom bot")
|
||||
conn.rollback()
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_custom_bot(self, bot_name: str) -> Optional[Tuple[str, str, str, datetime]]:
|
||||
def get_custom_bot(self, bot_name: str) -> tuple[str, str, str, datetime] | None:
|
||||
"""Get a custom bot configuration by name."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
@@ -470,11 +511,14 @@ class CustomBotManager:
|
||||
result = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
return result
|
||||
if result is None:
|
||||
return None
|
||||
return (result[0], result[1], result[2], result[3])
|
||||
|
||||
def list_custom_bots(
|
||||
self, user_id: Optional[str] = None
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
) -> list[tuple[str, str, str]]:
|
||||
"""List all custom bots, optionally filtered by creator."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
@@ -482,12 +526,11 @@ class CustomBotManager:
|
||||
if user_id:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT bot_name, system_prompt, name
|
||||
FROM custom_bots cb, username_map um
|
||||
JOIN username_map ON custom_bots.created_by = username_map.id
|
||||
SELECT bot_name, system_prompt, created_by
|
||||
FROM custom_bots
|
||||
WHERE is_active = 1
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
""",
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
@@ -496,7 +539,7 @@ class CustomBotManager:
|
||||
FROM custom_bots
|
||||
WHERE is_active = 1
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
bots = cursor.fetchall()
|
||||
@@ -519,12 +562,12 @@ class CustomBotManager:
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error deleting custom bot: {e}")
|
||||
except Exception:
|
||||
logger.exception("Error deleting custom bot")
|
||||
conn.rollback()
|
||||
return False
|
||||
else:
|
||||
return cursor.rowcount > 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
@@ -544,11 +587,11 @@ class CustomBotManager:
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error deactivating custom bot: {e}")
|
||||
except Exception:
|
||||
logger.exception("Error deactivating custom bot")
|
||||
conn.rollback()
|
||||
return False
|
||||
else:
|
||||
return cursor.rowcount > 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
Reference in New Issue
Block a user