human cleanup
This commit is contained in:
@@ -1,32 +1,27 @@
|
||||
import sqlite3
|
||||
import os
|
||||
from typing import Optional, List, Tuple
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
from openai import OpenAI
|
||||
import logging
|
||||
|
||||
import llama_wrapper # type: ignore
|
||||
from config import ( # type: ignore
|
||||
DB_PATH,
|
||||
EMBEDDING_MODEL,
|
||||
EMBEDDING_ENDPOINT,
|
||||
EMBEDDING_ENDPOINT_KEY,
|
||||
MAX_HISTORY_MESSAGES,
|
||||
SIMILARITY_THRESHOLD,
|
||||
TOP_K_RESULTS,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Database configuration
|
||||
DB_PATH = os.getenv("DB_PATH", "chat_history.db")
|
||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "qwen3-embed-4b")
|
||||
EMBEDDING_DIMENSION = 2048 # Default for qwen3-embed-4b
|
||||
MAX_HISTORY_MESSAGES = int(os.getenv("MAX_HISTORY_MESSAGES", "1000"))
|
||||
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
|
||||
TOP_K_RESULTS = int(os.getenv("TOP_K_RESULTS", "5"))
|
||||
|
||||
# OpenAI configuration
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "placeholder")
|
||||
OPENAI_API_EMBED_ENDPOINT = os.getenv(
|
||||
"OPENAI_API_EMBED_ENDPOINT", "https://llama-embed.reeselink.com"
|
||||
)
|
||||
|
||||
|
||||
class ChatDatabase:
|
||||
"""SQLite database with RAG support for storing chat history using OpenAI embeddings."""
|
||||
@@ -34,7 +29,9 @@ class ChatDatabase:
|
||||
def __init__(self, db_path: str = DB_PATH):
|
||||
logger.info(f"Initializing ChatDatabase with path: {db_path}")
|
||||
self.db_path = db_path
|
||||
self.client = OpenAI(base_url=OPENAI_API_EMBED_ENDPOINT, api_key=OPENAI_API_KEY)
|
||||
self.client = OpenAI(
|
||||
base_url=EMBEDDING_ENDPOINT, api_key=EMBEDDING_ENDPOINT_KEY
|
||||
)
|
||||
logger.info("Connecting to OpenAI API for embeddings")
|
||||
self._initialize_database()
|
||||
|
||||
@@ -83,7 +80,7 @@ class ChatDatabase:
|
||||
"""
|
||||
)
|
||||
logger.info("idx_timestamp index created successfully")
|
||||
|
||||
|
||||
logger.info("Creating idx_user_id index if not exists")
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -96,36 +93,6 @@ class ChatDatabase:
|
||||
logger.info("Database initialization completed successfully")
|
||||
conn.close()
|
||||
|
||||
def _generate_embedding(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text using OpenAI API."""
|
||||
logger.debug(f"Generating embedding for text (length: {len(text)})")
|
||||
try:
|
||||
logger.info(f"Calling OpenAI API to generate embedding with model: {EMBEDDING_MODEL}")
|
||||
response = self.client.embeddings.create(
|
||||
model=EMBEDDING_MODEL, input=text, encoding_format="float"
|
||||
)
|
||||
logger.debug("OpenAI API response received successfully")
|
||||
|
||||
# The embedding is returned as a nested list: [[embedding_values]]
|
||||
# We need to extract the inner list
|
||||
embedding_data = response[0].embedding
|
||||
if isinstance(embedding_data, list) and len(embedding_data) > 0:
|
||||
# The first element might be the embedding array itself or a nested list
|
||||
first_item = embedding_data[0]
|
||||
if isinstance(first_item, list):
|
||||
# Handle nested structure: [[values]] -> [values]
|
||||
logger.debug("Extracted embedding from nested structure [[values]]")
|
||||
return first_item
|
||||
else:
|
||||
# Handle direct structure: [values]
|
||||
logger.debug("Extracted embedding from direct structure [values]")
|
||||
return embedding_data
|
||||
logger.warning("Embedding data is empty or invalid")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating embedding: {e}")
|
||||
return None
|
||||
|
||||
def _vector_to_bytes(self, vector: List[float]) -> bytes:
|
||||
"""Convert vector to bytes for SQLite storage."""
|
||||
logger.debug(f"Converting vector (length: {len(vector)}) to bytes")
|
||||
@@ -142,7 +109,9 @@ class ChatDatabase:
|
||||
|
||||
def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
|
||||
"""Calculate cosine similarity between two vectors."""
|
||||
logger.debug(f"Calculating cosine similarity between vectors of dimension {len(vec1)}")
|
||||
logger.debug(
|
||||
f"Calculating cosine similarity between vectors of dimension {len(vec1)}"
|
||||
)
|
||||
result = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
||||
logger.debug(f"Similarity calculated: {result:.4f}")
|
||||
return result
|
||||
@@ -163,7 +132,9 @@ class ChatDatabase:
|
||||
|
||||
try:
|
||||
# Insert message
|
||||
logger.debug(f"Inserting message into chat_messages table: message_id={message_id}")
|
||||
logger.debug(
|
||||
f"Inserting message into chat_messages table: message_id={message_id}"
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO chat_messages
|
||||
@@ -176,9 +147,16 @@ class ChatDatabase:
|
||||
|
||||
# Generate and store embedding
|
||||
logger.info(f"Generating embedding for message {message_id}")
|
||||
embedding = self._generate_embedding(content)
|
||||
embedding = llama_wrapper.embedding(
|
||||
content,
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model=EMBEDDING_MODEL,
|
||||
)
|
||||
if embedding:
|
||||
logger.debug(f"Embedding generated successfully for message {message_id}, storing in database")
|
||||
logger.debug(
|
||||
f"Embedding generated successfully for message {message_id}, storing in database"
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO message_embeddings
|
||||
@@ -187,9 +165,13 @@ class ChatDatabase:
|
||||
""",
|
||||
(message_id, self._vector_to_bytes(embedding)),
|
||||
)
|
||||
logger.debug(f"Embedding stored in message_embeddings table for message {message_id}")
|
||||
logger.debug(
|
||||
f"Embedding stored in message_embeddings table for message {message_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Failed to generate embedding for message {message_id}, skipping embedding storage")
|
||||
logger.warning(
|
||||
f"Failed to generate embedding for message {message_id}, skipping embedding storage"
|
||||
)
|
||||
|
||||
# Clean up old messages if exceeding limit
|
||||
logger.info("Checking if cleanup of old messages is needed")
|
||||
@@ -268,9 +250,14 @@ class ChatDatabase:
|
||||
query: str,
|
||||
top_k: int = TOP_K_RESULTS,
|
||||
min_similarity: float = SIMILARITY_THRESHOLD,
|
||||
) -> List[Tuple[str, str, str, float]]:
|
||||
) -> List[Tuple[str, str, float]]:
|
||||
"""Search for messages similar to the query using embeddings."""
|
||||
query_embedding = self._generate_embedding(query)
|
||||
query_embedding = llama_wrapper.embedding(
|
||||
text=query,
|
||||
model=EMBEDDING_MODEL,
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
)
|
||||
if not query_embedding:
|
||||
return []
|
||||
|
||||
@@ -285,19 +272,28 @@ class ChatDatabase:
|
||||
SELECT cm.message_id, cm.content, me.embedding
|
||||
FROM chat_messages cm
|
||||
JOIN message_embeddings me ON cm.message_id = me.message_id
|
||||
WHERE cm.username != 'vibe-bot'
|
||||
"""
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
results = []
|
||||
results: list[tuple[str, str, float]] = []
|
||||
for message_id, content, embedding_blob in rows:
|
||||
embedding_vector = self._bytes_to_vector(embedding_blob)
|
||||
similarity = self._calculate_similarity(query_vector, embedding_vector)
|
||||
|
||||
if similarity >= min_similarity:
|
||||
results.append(
|
||||
(message_id, content[:500], similarity)
|
||||
) # Limit content length
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT content
|
||||
FROM chat_messages
|
||||
WHERE message_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
""",
|
||||
(f"{message_id}_response",),
|
||||
)
|
||||
response: str = cursor.fetchone()[0]
|
||||
results.append((content, response, similarity))
|
||||
|
||||
conn.close()
|
||||
|
||||
@@ -305,28 +301,48 @@ class ChatDatabase:
|
||||
results.sort(key=lambda x: x[2], reverse=True)
|
||||
return results[:top_k]
|
||||
|
||||
def get_user_history(
|
||||
self, user_id: str, limit: int = 20
|
||||
) -> List[Tuple[str, str, datetime]]:
|
||||
def get_user_history(self, user_id: str, limit: int = 20) -> list[tuple[str, str]]:
|
||||
"""Get message history for a specific user."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
logger.info(f"Fetching last {limit} user messages")
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT message_id, content, timestamp
|
||||
FROM chat_messages
|
||||
WHERE user_id = ?
|
||||
WHERE username != 'vibe-bot'
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(user_id, limit),
|
||||
(limit,),
|
||||
)
|
||||
|
||||
messages = cursor.fetchall()
|
||||
|
||||
# Format is [user message, bot response]
|
||||
conversations: list[tuple[str, str]] = []
|
||||
for message in messages:
|
||||
msg_content: str = message[1]
|
||||
logger.info(f"Finding response for {msg_content[:50]}")
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT content
|
||||
FROM chat_messages
|
||||
WHERE message_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
""",
|
||||
(f"{message[0]}_response",),
|
||||
)
|
||||
response_content: str = cursor.fetchone()
|
||||
if response_content:
|
||||
logger.info(f"Found response: {response_content[0][:50]}")
|
||||
conversations.append((msg_content, response_content[0]))
|
||||
else:
|
||||
logger.info("No response found")
|
||||
conn.close()
|
||||
|
||||
return messages
|
||||
return conversations
|
||||
|
||||
def get_conversation_context(
|
||||
self, user_id: str, current_message: str, max_context: int = 5
|
||||
@@ -344,15 +360,19 @@ class ChatDatabase:
|
||||
context_parts = []
|
||||
|
||||
# Add recent messages
|
||||
for message_id, content, timestamp in recent_messages:
|
||||
context_parts.append(f"[{timestamp}] User: {content}")
|
||||
for user_message, bot_message in recent_messages:
|
||||
combined_content = f"[Recent chat]\n{user_message}\n{bot_message}"
|
||||
context_parts.append(combined_content)
|
||||
|
||||
# Add similar messages
|
||||
for message_id, content, similarity in similar_messages:
|
||||
if f"[{content}" not in "\n".join(context_parts): # Avoid duplicates
|
||||
context_parts.append(f"[Similar] {content}")
|
||||
for user_message, bot_message, similarity in similar_messages:
|
||||
combined_content = f"{user_message}\n{bot_message}"
|
||||
if combined_content not in "\n".join(context_parts):
|
||||
context_parts.append(f"[You remember]\n{combined_content}")
|
||||
|
||||
return "\n".join(context_parts[-max_context * 2 :]) # Limit total context
|
||||
# Conversation history needs to be delivered in "newest context last" order
|
||||
context_parts.reverse()
|
||||
return "\n".join(context_parts[-max_context * 4 :]) # Limit total context
|
||||
|
||||
def clear_all_messages(self):
|
||||
"""Clear all messages and embeddings from the database."""
|
||||
@@ -390,6 +410,7 @@ class CustomBotManager:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create table to hold custom bots
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS custom_bots (
|
||||
@@ -399,7 +420,7 @@ class CustomBotManager:
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
is_active INTEGER DEFAULT 1
|
||||
)
|
||||
"""
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
@@ -461,8 +482,9 @@ class CustomBotManager:
|
||||
if user_id:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT bot_name, system_prompt, created_by
|
||||
FROM custom_bots
|
||||
SELECT bot_name, system_prompt, name
|
||||
FROM custom_bots cb, username_map um
|
||||
JOIN username_map ON custom_bots.created_by = username_map.id
|
||||
WHERE is_active = 1
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user