hopefully fix repetition with higher temperature and frequency penalty

This commit is contained in:
2026-03-03 21:30:57 -05:00
parent c547edc44b
commit a6ab9708a0
4 changed files with 271 additions and 26 deletions

View File

@@ -4,6 +4,14 @@ from typing import Optional, List, Tuple
from datetime import datetime
import numpy as np
from openai import OpenAI
import logging
# Logging setup: emit INFO-and-above records with a
# "timestamp - logger name - level - message" layout.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Logger named after this module; used by everything defined below.
logger = logging.getLogger(__name__)

# Database configuration: the DB_PATH environment variable overrides
# the default SQLite file name.
DB_PATH = os.getenv("DB_PATH", "chat_history.db")
@@ -24,16 +32,20 @@ class ChatDatabase:
"""SQLite database with RAG support for storing chat history using OpenAI embeddings."""
def __init__(self, db_path: str = DB_PATH):
    """Store the database path, build the embeddings client, and set up the schema.

    Args:
        db_path: Path of the SQLite database file. Defaults to the
            module-level ``DB_PATH``.
    """
    logger.info(f"Initializing ChatDatabase with path: {db_path}")
    self.db_path = db_path

    # Client used for embedding requests, pointed at the configured endpoint.
    embedding_client = OpenAI(
        base_url=OPENAI_API_EMBED_ENDPOINT,
        api_key=OPENAI_API_KEY,
    )
    self.client = embedding_client
    logger.info("Connecting to OpenAI API for embeddings")

    # Create the required tables and indexes if they are missing.
    self._initialize_database()
def _initialize_database(self):
"""Initialize the SQLite database with required tables."""
logger.info(f"Initializing SQLite database at {self.db_path}")
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Create messages table
logger.info("Creating chat_messages table if not exists")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS chat_messages (
@@ -48,8 +60,10 @@ class ChatDatabase:
)
"""
)
logger.info("chat_messages table initialized successfully")
# Create embeddings table for RAG
logger.info("Creating message_embeddings table if not exists")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS message_embeddings (
@@ -59,28 +73,39 @@ class ChatDatabase:
)
"""
)
logger.info("message_embeddings table initialized successfully")
# Create index for faster lookups
logger.info("Creating idx_timestamp index if not exists")
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_timestamp ON chat_messages(timestamp)
"""
)
logger.info("idx_timestamp index created successfully")
logger.info("Creating idx_user_id index if not exists")
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_user_id ON chat_messages(user_id)
"""
)
logger.info("idx_user_id index created successfully")
conn.commit()
logger.info("Database initialization completed successfully")
conn.close()
def _generate_embedding(self, text: str) -> List[float]:
"""Generate embedding for text using OpenAI API."""
logger.debug(f"Generating embedding for text (length: {len(text)})")
try:
logger.info(f"Calling OpenAI API to generate embedding with model: {EMBEDDING_MODEL}")
response = self.client.embeddings.create(
model=EMBEDDING_MODEL, input=text, encoding_format="float"
)
logger.debug("OpenAI API response received successfully")
# The embedding is returned as a nested list: [[embedding_values]]
# We need to extract the inner list
embedding_data = response[0].embedding
@@ -89,26 +114,38 @@ class ChatDatabase:
first_item = embedding_data[0]
if isinstance(first_item, list):
# Handle nested structure: [[values]] -> [values]
logger.debug("Extracted embedding from nested structure [[values]]")
return first_item
else:
# Handle direct structure: [values]
logger.debug("Extracted embedding from direct structure [values]")
return embedding_data
logger.warning("Embedding data is empty or invalid")
return []
except Exception as e:
print(f"Error generating embedding: {e}")
logger.error(f"Error generating embedding: {e}")
return None
def _vector_to_bytes(self, vector: List[float]) -> bytes:
"""Convert vector to bytes for SQLite storage."""
return np.array(vector, dtype=np.float32).tobytes()
logger.debug(f"Converting vector (length: {len(vector)}) to bytes")
result = np.array(vector, dtype=np.float32).tobytes()
logger.debug(f"Vector converted to {len(result)} bytes")
return result
def _bytes_to_vector(self, blob: bytes) -> np.ndarray:
"""Convert bytes back to vector."""
return np.frombuffer(blob, dtype=np.float32)
logger.debug(f"Converting {len(blob)} bytes back to vector")
result = np.frombuffer(blob, dtype=np.float32)
logger.debug(f"Vector reconstructed with {len(result)} dimensions")
return result
def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
"""Calculate cosine similarity between two vectors."""
return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
logger.debug(f"Calculating cosine similarity between vectors of dimension {len(vec1)}")
result = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
logger.debug(f"Similarity calculated: {result:.4f}")
return result
def add_message(
self,
@@ -120,11 +157,13 @@ class ChatDatabase:
guild_id: Optional[str] = None,
) -> bool:
"""Add a message to the database and generate its embedding."""
logger.info(f"Adding message {message_id} from user {username}")
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
try:
# Insert message
logger.debug(f"Inserting message into chat_messages table: message_id={message_id}")
cursor.execute(
"""
INSERT OR REPLACE INTO chat_messages
@@ -133,10 +172,13 @@ class ChatDatabase:
""",
(message_id, user_id, username, content, channel_id, guild_id),
)
logger.debug(f"Message {message_id} inserted into chat_messages table")
# Generate and store embedding
logger.info(f"Generating embedding for message {message_id}")
embedding = self._generate_embedding(content)
if embedding:
logger.debug(f"Embedding generated successfully for message {message_id}, storing in database")
cursor.execute(
"""
INSERT OR REPLACE INTO message_embeddings
@@ -145,15 +187,20 @@ class ChatDatabase:
""",
(message_id, self._vector_to_bytes(embedding)),
)
logger.debug(f"Embedding stored in message_embeddings table for message {message_id}")
else:
logger.warning(f"Failed to generate embedding for message {message_id}, skipping embedding storage")
# Clean up old messages if exceeding limit
logger.info("Checking if cleanup of old messages is needed")
self._cleanup_old_messages(cursor)
conn.commit()
logger.info(f"Successfully added message {message_id} to database")
return True
except Exception as e:
print(f"Error adding message: {e}")
logger.error(f"Error adding message {message_id}: {e}")
conn.rollback()
return False
finally: