fix linting, formatting, and add tests
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Vibe Discord Bot package."""
|
||||
|
||||
+62
-45
@@ -1,91 +1,108 @@
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
"""Configuration module for the vibe bot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Discord
|
||||
DISCORD_TOKEN = os.getenv("DISCORD_TOKEN", "")
|
||||
DISCORD_TOKEN: str = os.getenv("DISCORD_TOKEN", "")
|
||||
|
||||
# Endpoints
|
||||
CHAT_ENDPOINT = os.getenv("CHAT_ENDPOINT", "")
|
||||
COMPLETION_ENDPOINT = os.getenv("COMPLETION_ENDPOINT", "")
|
||||
IMAGE_GEN_ENDPOINT = os.getenv("IMAGE_GEN_ENDPOINT", "")
|
||||
IMAGE_EDIT_ENDPOINT = os.getenv("IMAGE_EDIT_ENDPOINT", "")
|
||||
EMBEDDING_ENDPOINT = os.getenv("EMBEDDING_ENDPOINT", "")
|
||||
MAX_COMPLETION_TOKENS = int(os.getenv("MAX_COMPLETION_TOKENS", "1000"))
|
||||
CHAT_ENDPOINT: str = os.getenv("CHAT_ENDPOINT", "")
|
||||
COMPLETION_ENDPOINT: str = os.getenv("COMPLETION_ENDPOINT", "")
|
||||
IMAGE_GEN_ENDPOINT: str = os.getenv("IMAGE_GEN_ENDPOINT", "")
|
||||
IMAGE_EDIT_ENDPOINT: str = os.getenv("IMAGE_EDIT_ENDPOINT", "")
|
||||
EMBEDDING_ENDPOINT: str = os.getenv("EMBEDDING_ENDPOINT", "")
|
||||
MAX_COMPLETION_TOKENS: int = int(os.getenv("MAX_COMPLETION_TOKENS", "1000"))
|
||||
|
||||
# API Keys
|
||||
CHAT_ENDPOINT_KEY = os.getenv("CHAT_ENDPOINT_KEY", "placeholder")
|
||||
COMPLETION_ENDPOINT_KEY = os.getenv("COMPLETION_ENDPOINT_KEY", "placeholder")
|
||||
IMAGE_GEN_ENDPOINT_KEY = os.getenv("IMAGE_GEN_ENDPOINT_KEY", "placeholder")
|
||||
IMAGE_EDIT_ENDPOINT_KEY = os.getenv("IMAGE_EDIT_ENDPOINT_KEY", "placeholder")
|
||||
EMBEDDING_ENDPOINT_KEY = os.getenv("EMBEDDING_ENDPOINT_KEY", "placeholder")
|
||||
CHAT_ENDPOINT_KEY: str = os.getenv("CHAT_ENDPOINT_KEY", "placeholder")
|
||||
COMPLETION_ENDPOINT_KEY: str = os.getenv("COMPLETION_ENDPOINT_KEY", "placeholder")
|
||||
IMAGE_GEN_ENDPOINT_KEY: str = os.getenv("IMAGE_GEN_ENDPOINT_KEY", "placeholder")
|
||||
IMAGE_EDIT_ENDPOINT_KEY: str = os.getenv("IMAGE_EDIT_ENDPOINT_KEY", "placeholder")
|
||||
EMBEDDING_ENDPOINT_KEY: str = os.getenv("EMBEDDING_ENDPOINT_KEY", "placeholder")
|
||||
|
||||
# Models
|
||||
CHAT_MODEL = os.getenv("CHAT_MODEL", "")
|
||||
COMPLETION_MODEL = os.getenv("COMPLETION_MODEL", "")
|
||||
IMAGE_GEN_MODEL = os.getenv("IMAGE_GEN_MODEL", "")
|
||||
IMAGE_EDIT_MODEL = os.getenv("IMAGE_EDIT_MODEL", "")
|
||||
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "")
|
||||
CHAT_MODEL: str = os.getenv("CHAT_MODEL", "")
|
||||
COMPLETION_MODEL: str = os.getenv("COMPLETION_MODEL", "")
|
||||
IMAGE_GEN_MODEL: str = os.getenv("IMAGE_GEN_MODEL", "")
|
||||
IMAGE_EDIT_MODEL: str = os.getenv("IMAGE_EDIT_MODEL", "")
|
||||
EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "")
|
||||
|
||||
# Database and embeddings
|
||||
DB_PATH = os.getenv("DB_PATH", "chat_history.db")
|
||||
EMBEDDING_DIMENSION = 2048
|
||||
MAX_HISTORY_MESSAGES = int(os.getenv("MAX_HISTORY_MESSAGES", "1000"))
|
||||
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
|
||||
TOP_K_RESULTS = int(os.getenv("TOP_K_RESULTS", "5"))
|
||||
DB_PATH: str = os.getenv("DB_PATH", "chat_history.db")
|
||||
EMBEDDING_DIMENSION: int = 2048
|
||||
MAX_HISTORY_MESSAGES: int = int(os.getenv("MAX_HISTORY_MESSAGES", "1000"))
|
||||
SIMILARITY_THRESHOLD: float = float(os.getenv("SIMILARITY_THRESHOLD", "0.7"))
|
||||
TOP_K_RESULTS: int = int(os.getenv("TOP_K_RESULTS", "5"))
|
||||
|
||||
# Check token
|
||||
if not DISCORD_TOKEN:
|
||||
raise Exception("DISCORD_TOKEN required.")
|
||||
msg = "DISCORD_TOKEN required."
|
||||
raise RuntimeError(msg)
|
||||
|
||||
# Check endpoints
|
||||
if not CHAT_ENDPOINT:
|
||||
raise Exception("CHAT_ENDPOINT required.")
|
||||
endpoint_msg = "CHAT_ENDPOINT required."
|
||||
raise RuntimeError(endpoint_msg)
|
||||
|
||||
if not COMPLETION_ENDPOINT:
|
||||
raise Exception("COMPLETION_ENDPOINT required.")
|
||||
endpoint_msg = "COMPLETION_ENDPOINT required."
|
||||
raise RuntimeError(endpoint_msg)
|
||||
|
||||
if not IMAGE_GEN_ENDPOINT:
|
||||
raise Exception("IMAGE_GEN_ENDPOINT required.")
|
||||
endpoint_msg = "IMAGE_GEN_ENDPOINT required."
|
||||
raise RuntimeError(endpoint_msg)
|
||||
|
||||
if not IMAGE_EDIT_ENDPOINT:
|
||||
raise Exception("IMAGE_EDIT_ENDPOINT required.")
|
||||
endpoint_msg = "IMAGE_EDIT_ENDPOINT required."
|
||||
raise RuntimeError(endpoint_msg)
|
||||
|
||||
if not EMBEDDING_ENDPOINT:
|
||||
raise Exception("EMBEDDING_ENDPOINT required.")
|
||||
endpoint_msg = "EMBEDDING_ENDPOINT required."
|
||||
raise RuntimeError(endpoint_msg)
|
||||
|
||||
# Check models
|
||||
if not CHAT_MODEL:
|
||||
raise Exception("CHAT_MODEL required.")
|
||||
model_msg = "CHAT_MODEL required."
|
||||
raise RuntimeError(model_msg)
|
||||
|
||||
if not COMPLETION_MODEL:
|
||||
raise Exception("COMPLETION_MODEL required.")
|
||||
model_msg = "COMPLETION_MODEL required."
|
||||
raise RuntimeError(model_msg)
|
||||
|
||||
if not IMAGE_GEN_MODEL:
|
||||
raise Exception("IMAGE_GEN_MODEL required.")
|
||||
model_msg = "IMAGE_GEN_MODEL required."
|
||||
raise RuntimeError(model_msg)
|
||||
|
||||
if not IMAGE_EDIT_MODEL:
|
||||
raise Exception("IMAGE_EDIT_MODEL required.")
|
||||
model_msg = "IMAGE_EDIT_MODEL required."
|
||||
raise RuntimeError(model_msg)
|
||||
|
||||
if not EMBEDDING_MODEL:
|
||||
raise Exception("EMBEDDING_MODEL required.")
|
||||
model_msg = "EMBEDDING_MODEL required."
|
||||
raise RuntimeError(model_msg)
|
||||
|
||||
# TTS
|
||||
TTS_MODEL_PATH = os.getenv("TTS_MODEL_PATH", "kokoro-v1.0.onnx")
|
||||
TTS_VOICES_PATH = os.getenv("TTS_VOICES_PATH", "voices-v1.0.bin")
|
||||
TTS_VOICE = os.getenv("TTS_VOICE", "af_sarah")
|
||||
TTS_SPEED = float(os.getenv("TTS_SPEED", "1.0"))
|
||||
TTS_MODEL_PATH: str = os.getenv("TTS_MODEL_PATH", "kokoro-v1.0.onnx")
|
||||
TTS_VOICES_PATH: str = os.getenv("TTS_VOICES_PATH", "voices-v1.0.bin")
|
||||
TTS_VOICE: str = os.getenv("TTS_VOICE", "af_sarah")
|
||||
TTS_SPEED: float = float(os.getenv("TTS_SPEED", "1.0"))
|
||||
|
||||
logger.info(f"CHAT_ENDPOINT set to {CHAT_ENDPOINT}")
|
||||
logger.info(f"COMPLETION_ENDPOINT set to {COMPLETION_ENDPOINT}")
|
||||
logger.info(f"IMAGE_GEN_ENDPOINT set to {IMAGE_GEN_ENDPOINT}")
|
||||
logger.info(f"IMAGE_EDIT_ENDPOINT set to {IMAGE_EDIT_ENDPOINT}")
|
||||
logger.info(f"EMBEDDING_ENDPOINT set to {EMBEDDING_ENDPOINT}")
|
||||
logger.info("CHAT_ENDPOINT set to %s", CHAT_ENDPOINT)
|
||||
logger.info("COMPLETION_ENDPOINT set to %s", COMPLETION_ENDPOINT)
|
||||
logger.info("IMAGE_GEN_ENDPOINT set to %s", IMAGE_GEN_ENDPOINT)
|
||||
logger.info("IMAGE_EDIT_ENDPOINT set to %s", IMAGE_EDIT_ENDPOINT)
|
||||
logger.info("EMBEDDING_ENDPOINT set to %s", EMBEDDING_ENDPOINT)
|
||||
|
||||
+138
-95
@@ -1,43 +1,60 @@
|
||||
"""SQLite database with RAG support for chat history and embeddings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from typing import Optional, List, Tuple
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
from openai import OpenAI
|
||||
import logging
|
||||
|
||||
import llama_wrapper # type: ignore
|
||||
from config import ( # type: ignore
|
||||
from vibe_bot import llama_wrapper
|
||||
from vibe_bot.config import (
|
||||
DB_PATH,
|
||||
EMBEDDING_MODEL,
|
||||
EMBEDDING_ENDPOINT,
|
||||
EMBEDDING_ENDPOINT_KEY,
|
||||
EMBEDDING_MODEL,
|
||||
MAX_HISTORY_MESSAGES,
|
||||
SIMILARITY_THRESHOLD,
|
||||
TOP_K_RESULTS,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatDatabase:
|
||||
"""SQLite database with RAG support for storing chat history using OpenAI embeddings."""
|
||||
"""SQLite database with RAG support for storing chat history
|
||||
using OpenAI embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str = DB_PATH):
|
||||
logger.info(f"Initializing ChatDatabase with path: {db_path}")
|
||||
def __init__(self, db_path: str = DB_PATH) -> None:
|
||||
"""Initialize the database connection.
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database file.
|
||||
|
||||
"""
|
||||
logger.info("Initializing ChatDatabase with path: %s", db_path)
|
||||
self.db_path = db_path
|
||||
self.client = OpenAI(
|
||||
base_url=EMBEDDING_ENDPOINT, api_key=EMBEDDING_ENDPOINT_KEY
|
||||
base_url=EMBEDDING_ENDPOINT,
|
||||
api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
)
|
||||
logger.info("Connecting to OpenAI API for embeddings")
|
||||
self._initialize_database()
|
||||
|
||||
def _initialize_database(self):
|
||||
def _initialize_database(self) -> None:
|
||||
"""Initialize the SQLite database with required tables."""
|
||||
logger.info(f"Initializing SQLite database at {self.db_path}")
|
||||
logger.info("Initializing SQLite database at %s", self.db_path)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
@@ -55,7 +72,7 @@ class ChatDatabase:
|
||||
channel_id TEXT,
|
||||
guild_id TEXT
|
||||
)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
logger.info("chat_messages table initialized successfully")
|
||||
|
||||
@@ -68,7 +85,7 @@ class ChatDatabase:
|
||||
embedding BLOB,
|
||||
FOREIGN KEY (message_id) REFERENCES chat_messages(message_id)
|
||||
)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
logger.info("message_embeddings table initialized successfully")
|
||||
|
||||
@@ -77,7 +94,7 @@ class ChatDatabase:
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_timestamp ON chat_messages(timestamp)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
logger.info("idx_timestamp index created successfully")
|
||||
|
||||
@@ -85,7 +102,7 @@ class ChatDatabase:
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_user_id ON chat_messages(user_id)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
logger.info("idx_user_id index created successfully")
|
||||
|
||||
@@ -93,60 +110,65 @@ class ChatDatabase:
|
||||
logger.info("Database initialization completed successfully")
|
||||
conn.close()
|
||||
|
||||
def _vector_to_bytes(self, vector: List[float]) -> bytes:
|
||||
def _vector_to_bytes(self, vector: list[float]) -> bytes:
|
||||
"""Convert vector to bytes for SQLite storage."""
|
||||
logger.debug(f"Converting vector (length: {len(vector)}) to bytes")
|
||||
logger.debug("Converting vector (length: %d) to bytes", len(vector))
|
||||
result = np.array(vector, dtype=np.float32).tobytes()
|
||||
logger.debug(f"Vector converted to {len(result)} bytes")
|
||||
logger.debug("Vector converted to %d bytes", len(result))
|
||||
return result
|
||||
|
||||
def _bytes_to_vector(self, blob: bytes) -> np.ndarray:
|
||||
"""Convert bytes back to vector."""
|
||||
logger.debug(f"Converting {len(blob)} bytes back to vector")
|
||||
logger.debug("Converting %d bytes back to vector", len(blob))
|
||||
result = np.frombuffer(blob, dtype=np.float32)
|
||||
logger.debug(f"Vector reconstructed with {len(result)} dimensions")
|
||||
logger.debug("Vector reconstructed with %d dimensions", len(result))
|
||||
return result
|
||||
|
||||
def _calculate_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
|
||||
"""Calculate cosine similarity between two vectors."""
|
||||
logger.debug(
|
||||
f"Calculating cosine similarity between vectors of dimension {len(vec1)}"
|
||||
"Calculating cosine similarity between vectors of dimension %d",
|
||||
len(vec1),
|
||||
)
|
||||
result = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
||||
logger.debug(f"Similarity calculated: {result:.4f}")
|
||||
result = float(
|
||||
np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)),
|
||||
)
|
||||
logger.debug("Similarity calculated: %.4f", result)
|
||||
return result
|
||||
|
||||
def add_message(
|
||||
self,
|
||||
*,
|
||||
message_id: str,
|
||||
user_id: str,
|
||||
username: str,
|
||||
content: str,
|
||||
channel_id: Optional[str] = None,
|
||||
guild_id: Optional[str] = None,
|
||||
channel_id: str | None = None,
|
||||
guild_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Add a message to the database and generate its embedding."""
|
||||
logger.info(f"Adding message {message_id} from user {username}")
|
||||
logger.info("Adding message %s from user %s", message_id, username)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
# Insert message
|
||||
logger.debug(
|
||||
f"Inserting message into chat_messages table: message_id={message_id}"
|
||||
"Inserting message into chat_messages table: message_id=%s",
|
||||
message_id,
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO chat_messages
|
||||
INSERT OR REPLACE INTO chat_messages
|
||||
(message_id, user_id, username, content, channel_id, guild_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(message_id, user_id, username, content, channel_id, guild_id),
|
||||
)
|
||||
logger.debug(f"Message {message_id} inserted into chat_messages table")
|
||||
logger.debug("Message %s inserted into chat_messages table", message_id)
|
||||
|
||||
# Generate and store embedding
|
||||
logger.info(f"Generating embedding for message {message_id}")
|
||||
logger.info("Generating embedding for message %s", message_id)
|
||||
embedding = llama_wrapper.embedding(
|
||||
content,
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
@@ -155,22 +177,27 @@ class ChatDatabase:
|
||||
)
|
||||
if embedding:
|
||||
logger.debug(
|
||||
f"Embedding generated successfully for message {message_id}, storing in database"
|
||||
"Embedding generated successfully for message %s, "
|
||||
"storing in database",
|
||||
message_id,
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO message_embeddings
|
||||
INSERT OR REPLACE INTO message_embeddings
|
||||
(message_id, embedding)
|
||||
VALUES (?, ?)
|
||||
""",
|
||||
(message_id, self._vector_to_bytes(embedding)),
|
||||
)
|
||||
logger.debug(
|
||||
f"Embedding stored in message_embeddings table for message {message_id}"
|
||||
"Embedding stored in message_embeddings table for message %s",
|
||||
message_id,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to generate embedding for message {message_id}, skipping embedding storage"
|
||||
"Failed to generate embedding for message %s, "
|
||||
"skipping embedding storage",
|
||||
message_id,
|
||||
)
|
||||
|
||||
# Clean up old messages if exceeding limit
|
||||
@@ -178,32 +205,32 @@ class ChatDatabase:
|
||||
self._cleanup_old_messages(cursor)
|
||||
|
||||
conn.commit()
|
||||
logger.info(f"Successfully added message {message_id} to database")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding message {message_id}: {e}")
|
||||
except Exception:
|
||||
logger.exception("Error adding message %s", message_id)
|
||||
conn.rollback()
|
||||
return False
|
||||
else:
|
||||
logger.info("Successfully added message %s to database", message_id)
|
||||
return True
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _cleanup_old_messages(self, cursor):
|
||||
def _cleanup_old_messages(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Remove old messages to stay within the limit."""
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT COUNT(*) FROM chat_messages
|
||||
"""
|
||||
""",
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
|
||||
if count > MAX_HISTORY_MESSAGES:
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM chat_messages
|
||||
DELETE FROM chat_messages
|
||||
WHERE id IN (
|
||||
SELECT id FROM chat_messages
|
||||
ORDER BY timestamp ASC
|
||||
SELECT id FROM chat_messages
|
||||
ORDER BY timestamp ASC
|
||||
LIMIT ?
|
||||
)
|
||||
""",
|
||||
@@ -213,10 +240,10 @@ class ChatDatabase:
|
||||
# Also remove corresponding embeddings
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM message_embeddings
|
||||
DELETE FROM message_embeddings
|
||||
WHERE message_id IN (
|
||||
SELECT message_id FROM chat_messages
|
||||
ORDER BY timestamp ASC
|
||||
SELECT message_id FROM chat_messages
|
||||
ORDER BY timestamp ASC
|
||||
LIMIT ?
|
||||
)
|
||||
""",
|
||||
@@ -224,8 +251,9 @@ class ChatDatabase:
|
||||
)
|
||||
|
||||
def get_recent_messages(
|
||||
self, limit: int = 10
|
||||
) -> List[Tuple[str, str, str, datetime]]:
|
||||
self,
|
||||
limit: int = 10,
|
||||
) -> list[tuple[str, str, str, datetime]]:
|
||||
"""Get recent messages from the database."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
@@ -250,7 +278,7 @@ class ChatDatabase:
|
||||
query: str,
|
||||
top_k: int = TOP_K_RESULTS,
|
||||
min_similarity: float = SIMILARITY_THRESHOLD,
|
||||
) -> List[Tuple[str, str, float]]:
|
||||
) -> list[tuple[str, str, float]]:
|
||||
"""Search for messages similar to the query using embeddings."""
|
||||
query_embedding = llama_wrapper.embedding(
|
||||
text=query,
|
||||
@@ -269,11 +297,11 @@ class ChatDatabase:
|
||||
# Join chat_messages and message_embeddings to get content and embeddings
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT cm.message_id, cm.content, me.embedding
|
||||
SELECT cm.message_id, cm.content, me.embedding
|
||||
FROM chat_messages cm
|
||||
JOIN message_embeddings me ON cm.message_id = me.message_id
|
||||
WHERE cm.username != 'vibe-bot'
|
||||
"""
|
||||
""",
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
@@ -302,12 +330,12 @@ class ChatDatabase:
|
||||
results.sort(key=lambda x: x[2], reverse=True)
|
||||
return results[:top_k]
|
||||
|
||||
def get_user_history(self, user_id: str, limit: int = 20) -> list[tuple[str, str]]:
|
||||
def get_user_history(self, _user_id: str, limit: int = 20) -> list[tuple[str, str]]:
|
||||
"""Get message history for a specific user."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
logger.info(f"Fetching last {limit} user messages")
|
||||
logger.info("Fetching last %d user messages", limit)
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT message_id, content, timestamp
|
||||
@@ -324,8 +352,8 @@ class ChatDatabase:
|
||||
# Format is [user message, bot response]
|
||||
conversations: list[tuple[str, str]] = []
|
||||
for message in messages:
|
||||
msg_content: str = message[1]
|
||||
logger.info(f"Finding response for {msg_content[:50]}")
|
||||
msg_content = message[1]
|
||||
logger.debug("Finding response for %s...", msg_content[:50])
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT content
|
||||
@@ -335,18 +363,21 @@ class ChatDatabase:
|
||||
""",
|
||||
(f"{message[0]}_response",),
|
||||
)
|
||||
response_content: str = cursor.fetchone()
|
||||
if response_content:
|
||||
logger.info(f"Found response: {response_content[0][:50]}")
|
||||
conversations.append((msg_content, response_content[0]))
|
||||
response_row = cursor.fetchone()
|
||||
if response_row:
|
||||
logger.debug("Found response: %s...", response_row[0][:50])
|
||||
conversations.append((msg_content, response_row[0]))
|
||||
else:
|
||||
logger.info("No response found")
|
||||
logger.debug("No response found")
|
||||
conn.close()
|
||||
|
||||
return conversations
|
||||
|
||||
def get_conversation_context(
|
||||
self, user_id: str, current_message: str, max_context: int = 5
|
||||
self,
|
||||
user_id: str,
|
||||
current_message: str,
|
||||
max_context: int = 5,
|
||||
) -> list[dict[str, str]]:
|
||||
"""Get relevant conversation context for RAG."""
|
||||
# Get recent messages from the user
|
||||
@@ -354,7 +385,8 @@ class ChatDatabase:
|
||||
|
||||
# Search for similar messages
|
||||
similar_messages = self.search_similar_messages(
|
||||
current_message, top_k=max_context
|
||||
current_message,
|
||||
top_k=max_context,
|
||||
)
|
||||
|
||||
# Combine contexts
|
||||
@@ -366,7 +398,7 @@ class ChatDatabase:
|
||||
context_parts.append({"role": "user", "content": user_message})
|
||||
|
||||
# Add similar messages
|
||||
for user_message, bot_message, similarity in similar_messages:
|
||||
for user_message, bot_message, _similarity in similar_messages:
|
||||
context_parts.append({"role": "assistant", "content": bot_message})
|
||||
context_parts.append({"role": "user", "content": user_message})
|
||||
|
||||
@@ -374,7 +406,7 @@ class ChatDatabase:
|
||||
context_parts.reverse()
|
||||
return context_parts
|
||||
|
||||
def clear_all_messages(self):
|
||||
def clear_all_messages(self) -> None:
|
||||
"""Clear all messages and embeddings from the database."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
@@ -387,12 +419,12 @@ class ChatDatabase:
|
||||
|
||||
|
||||
# Global database instance
|
||||
_chat_db: Optional[ChatDatabase] = None
|
||||
_chat_db: ChatDatabase | None = None
|
||||
|
||||
|
||||
def get_database() -> ChatDatabase:
|
||||
"""Get or create the global database instance."""
|
||||
global _chat_db
|
||||
global _chat_db # noqa: PLW0603
|
||||
if _chat_db is None:
|
||||
_chat_db = ChatDatabase()
|
||||
return _chat_db
|
||||
@@ -401,11 +433,17 @@ def get_database() -> ChatDatabase:
|
||||
class CustomBotManager:
|
||||
"""Manages custom bot configurations stored in SQLite database."""
|
||||
|
||||
def __init__(self, db_path: str = DB_PATH):
|
||||
def __init__(self, db_path: str = DB_PATH) -> None:
|
||||
"""Initialize the custom bot manager.
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database file.
|
||||
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self._initialize_custom_bots_table()
|
||||
|
||||
def _initialize_custom_bots_table(self):
|
||||
def _initialize_custom_bots_table(self) -> None:
|
||||
"""Initialize the custom bots table in SQLite."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
@@ -420,14 +458,17 @@ class CustomBotManager:
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
is_active INTEGER DEFAULT 1
|
||||
)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def create_custom_bot(
|
||||
self, bot_name: str, system_prompt: str, created_by: str
|
||||
self,
|
||||
bot_name: str,
|
||||
system_prompt: str,
|
||||
created_by: str,
|
||||
) -> bool:
|
||||
"""Create a new custom bot configuration."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
@@ -436,7 +477,7 @@ class CustomBotManager:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO custom_bots
|
||||
INSERT OR REPLACE INTO custom_bots
|
||||
(bot_name, system_prompt, created_by, is_active)
|
||||
VALUES (?, ?, ?, 1)
|
||||
""",
|
||||
@@ -444,16 +485,16 @@ class CustomBotManager:
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating custom bot: {e}")
|
||||
except Exception:
|
||||
logger.exception("Error creating custom bot")
|
||||
conn.rollback()
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_custom_bot(self, bot_name: str) -> Optional[Tuple[str, str, str, datetime]]:
|
||||
def get_custom_bot(self, bot_name: str) -> tuple[str, str, str, datetime] | None:
|
||||
"""Get a custom bot configuration by name."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
@@ -470,11 +511,14 @@ class CustomBotManager:
|
||||
result = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
return result
|
||||
if result is None:
|
||||
return None
|
||||
return (result[0], result[1], result[2], result[3])
|
||||
|
||||
def list_custom_bots(
|
||||
self, user_id: Optional[str] = None
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
) -> list[tuple[str, str, str]]:
|
||||
"""List all custom bots, optionally filtered by creator."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
@@ -482,12 +526,11 @@ class CustomBotManager:
|
||||
if user_id:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT bot_name, system_prompt, name
|
||||
FROM custom_bots cb, username_map um
|
||||
JOIN username_map ON custom_bots.created_by = username_map.id
|
||||
SELECT bot_name, system_prompt, created_by
|
||||
FROM custom_bots
|
||||
WHERE is_active = 1
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
""",
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
@@ -496,7 +539,7 @@ class CustomBotManager:
|
||||
FROM custom_bots
|
||||
WHERE is_active = 1
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
bots = cursor.fetchall()
|
||||
@@ -519,12 +562,12 @@ class CustomBotManager:
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error deleting custom bot: {e}")
|
||||
except Exception:
|
||||
logger.exception("Error deleting custom bot")
|
||||
conn.rollback()
|
||||
return False
|
||||
else:
|
||||
return cursor.rowcount > 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
@@ -544,11 +587,11 @@ class CustomBotManager:
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error deactivating custom bot: {e}")
|
||||
except Exception:
|
||||
logger.exception("Error deactivating custom bot")
|
||||
conn.rollback()
|
||||
return False
|
||||
else:
|
||||
return cursor.rowcount > 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
+137
-42
@@ -1,23 +1,46 @@
|
||||
# Wraps the openai calls in generic functions
|
||||
# Supports chat, image, edit, and embeddings
|
||||
# Allows custom endpoints for each of the above supported functions
|
||||
"""Wraps the openai calls in generic functions.
|
||||
|
||||
Supports chat, image, edit, and embeddings.
|
||||
Allows custom endpoints for each of the above supported functions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import openai
|
||||
from typing import Iterable
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from io import BufferedReader, BytesIO
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from io import BufferedReader, BytesIO
|
||||
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
|
||||
def chat_completion(
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
*,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
model: str,
|
||||
max_tokens: int = 1000,
|
||||
) -> str:
|
||||
"""Send a chat completion request and return the response.
|
||||
|
||||
Args:
|
||||
system_prompt: The system prompt to use.
|
||||
user_prompt: The user prompt to send.
|
||||
openai_url: The OpenAI-compatible API URL.
|
||||
openai_api_key: The API key for authentication.
|
||||
model: The model to use for completion.
|
||||
max_tokens: Maximum number of tokens to generate.
|
||||
|
||||
Returns:
|
||||
The model's response text, stripped of whitespace.
|
||||
|
||||
"""
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
messages: Iterable[ChatCompletionMessageParam] = [
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
@@ -28,35 +51,51 @@ def chat_completion(
|
||||
},
|
||||
]
|
||||
response = client.chat.completions.create(
|
||||
model=model, messages=messages, max_tokens=max_tokens
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# Assert that thinking was used
|
||||
if response.choices[0].message.model_extra:
|
||||
assert response.choices[0].message.model_extra.get("reasoning_content")
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
return content.strip()
|
||||
else:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
|
||||
def chat_completion_with_history(
|
||||
system_prompt: str,
|
||||
prompts: Iterable[ChatCompletionMessageParam],
|
||||
prompts: list[dict[str, str]],
|
||||
*,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
model: str,
|
||||
max_tokens: int = 1000,
|
||||
) -> str:
|
||||
"""Send a chat completion request with conversation history.
|
||||
|
||||
Args:
|
||||
system_prompt: The system prompt to use.
|
||||
prompts: List of prompt dicts with role and content.
|
||||
openai_url: The OpenAI-compatible API URL.
|
||||
openai_api_key: The API key for authentication.
|
||||
model: The model to use for completion.
|
||||
max_tokens: Maximum number of tokens to generate.
|
||||
|
||||
Returns:
|
||||
The model's response text, stripped of whitespace.
|
||||
|
||||
"""
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
messages: Iterable[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
}
|
||||
] + prompts # type: ignore
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
cast(
|
||||
"ChatCompletionMessageParam",
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
),
|
||||
]
|
||||
messages.extend(cast("list[ChatCompletionMessageParam]", prompts))
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
@@ -67,20 +106,34 @@ def chat_completion_with_history(
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
return content.strip()
|
||||
else:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
|
||||
def chat_completion_instruct(
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
*,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
model: str,
|
||||
max_tokens: int = 1000,
|
||||
) -> str:
|
||||
"""Send an instruction-based chat completion request.
|
||||
|
||||
Args:
|
||||
system_prompt: The system prompt to use.
|
||||
user_prompt: The user prompt to send.
|
||||
openai_url: The OpenAI-compatible API URL.
|
||||
openai_api_key: The API key for authentication.
|
||||
model: The model to use for completion.
|
||||
max_tokens: Maximum number of tokens to generate.
|
||||
|
||||
Returns:
|
||||
The model's response text, stripped of whitespace.
|
||||
|
||||
"""
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
messages: Iterable[ChatCompletionMessageParam] = [
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
@@ -100,26 +153,37 @@ def chat_completion_instruct(
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
return content.strip()
|
||||
else:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
|
||||
def image_generation(prompt: str, openai_url: str, openai_api_key: str, n=1) -> str:
|
||||
"""Generates an image using the given prompt and returns the base64 encoded image data
|
||||
def image_generation(
|
||||
prompt: str,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
n: int = 1,
|
||||
) -> str:
|
||||
"""Generate an image using the given prompt.
|
||||
|
||||
Args:
|
||||
prompt: The image generation prompt.
|
||||
openai_url: The OpenAI-compatible API URL.
|
||||
openai_api_key: The API key for authentication.
|
||||
n: Number of images to generate.
|
||||
|
||||
Returns:
|
||||
str: The base64 encoded image data. Decode and write to a file.
|
||||
The base64 encoded image data. Decode and write to a file.
|
||||
|
||||
"""
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
response = client.images.generate(
|
||||
prompt=prompt,
|
||||
n=n,
|
||||
size="1024x1024",
|
||||
model="gen",
|
||||
)
|
||||
if response.data:
|
||||
return response.data[0].b64_json or ""
|
||||
else:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
|
||||
def image_edit(
|
||||
@@ -127,33 +191,64 @@ def image_edit(
|
||||
prompt: str,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
n=1,
|
||||
n: int = 1,
|
||||
) -> str:
|
||||
"""Edit an existing image using a prompt.
|
||||
|
||||
Args:
|
||||
image: The source image as a file-like object or list thereof.
|
||||
prompt: The edit instruction.
|
||||
openai_url: The OpenAI-compatible API URL.
|
||||
openai_api_key: The API key for authentication.
|
||||
n: Number of edited images to generate.
|
||||
|
||||
Returns:
|
||||
The base64 encoded edited image data.
|
||||
|
||||
"""
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
response = client.images.edit(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
n=n,
|
||||
size="1024x1024",
|
||||
model="edit",
|
||||
)
|
||||
if response.data:
|
||||
return response.data[0].b64_json or ""
|
||||
else:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
|
||||
def embedding(
|
||||
text: str, openai_url: str, openai_api_key: str, model: str
|
||||
text: str,
|
||||
*,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
model: str,
|
||||
) -> list[float]:
|
||||
"""Generate an embedding vector for the given text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
openai_url: The OpenAI-compatible API URL.
|
||||
openai_api_key: The API key for authentication.
|
||||
model: The embedding model to use.
|
||||
|
||||
Returns:
|
||||
The embedding vector as a list of floats, or an empty list on failure.
|
||||
|
||||
"""
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
response = client.embeddings.create(
|
||||
input=[text], model=model, encoding_format="float"
|
||||
input=[text],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
if response:
|
||||
raw_data = response[0].embedding # type: ignore
|
||||
# The result could be an array of floats or an array of an array of floats.
|
||||
try:
|
||||
return raw_data[0]
|
||||
except Exception:
|
||||
return raw_data
|
||||
data = response.data
|
||||
raw_data = data[0].embedding
|
||||
# The result could be an array of floats or a single float.
|
||||
if not isinstance(raw_data, float):
|
||||
return list(raw_data)
|
||||
return [raw_data]
|
||||
return []
|
||||
|
||||
+354
-216
@@ -1,33 +1,41 @@
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
import os
|
||||
"""Main Discord bot application."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import traceback
|
||||
from io import BytesIO
|
||||
from openai import OpenAI
|
||||
import logging
|
||||
from database import get_database, CustomBotManager # type: ignore
|
||||
from config import ( # type: ignore
|
||||
CHAT_ENDPOINT_KEY,
|
||||
DISCORD_TOKEN,
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import discord
|
||||
import requests
|
||||
from discord import Message
|
||||
from discord.ext import commands
|
||||
|
||||
from vibe_bot import llama_wrapper, tts
|
||||
from vibe_bot.config import (
|
||||
CHAT_ENDPOINT,
|
||||
CHAT_ENDPOINT_KEY,
|
||||
CHAT_MODEL,
|
||||
IMAGE_EDIT_ENDPOINT_KEY,
|
||||
IMAGE_GEN_ENDPOINT,
|
||||
DISCORD_TOKEN,
|
||||
IMAGE_EDIT_ENDPOINT,
|
||||
IMAGE_EDIT_ENDPOINT_KEY,
|
||||
MAX_COMPLETION_TOKENS,
|
||||
TTS_MODEL_PATH,
|
||||
TTS_VOICES_PATH,
|
||||
TTS_VOICE,
|
||||
TTS_SPEED,
|
||||
TTS_VOICE,
|
||||
TTS_VOICES_PATH,
|
||||
)
|
||||
import tts # type: ignore
|
||||
import llama_wrapper # type: ignore
|
||||
import requests
|
||||
from vibe_bot.database import CustomBotManager, get_database
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from discord.ext.commands import Bot
|
||||
from discord.ext.commands import Context as CommandsContext
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,86 +45,123 @@ intents.message_content = True
|
||||
bot = commands.Bot(command_prefix="!", intents=intents)
|
||||
|
||||
# Initialize TTS engine
|
||||
tts_engine: tts.TTSEngine | None = None
|
||||
try:
|
||||
tts_engine = tts.TTSEngine(TTS_MODEL_PATH, TTS_VOICES_PATH)
|
||||
logger.info("TTS engine initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize TTS engine: {e}")
|
||||
logger.info("Make sure kokoro-v1.0.onnx and voices-v1.0.bin are in the project directory")
|
||||
tts_engine = None
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize TTS engine")
|
||||
logger.info(
|
||||
"Make sure kokoro-v1.0.onnx and voices-v1.0.bin are in the project directory",
|
||||
)
|
||||
|
||||
# Name and personality validation constants
|
||||
MIN_BOT_NAME_LENGTH = 2
|
||||
MAX_BOT_NAME_LENGTH = 50
|
||||
MIN_PERSONALITY_LENGTH = 10
|
||||
|
||||
|
||||
@bot.event
|
||||
async def on_ready():
|
||||
async def on_ready() -> None:
|
||||
"""Log when the bot is ready and logged in."""
|
||||
logger.info("Bot is starting up...")
|
||||
print(f"Bot logged in as {bot.user}")
|
||||
logger.info(f"Bot logged in as {bot.user}")
|
||||
logger.info("Bot logged in as %s", bot.user)
|
||||
|
||||
|
||||
@bot.command(name="custom-bot") # type: ignore
|
||||
async def custom_bot(ctx, bot_name: str, *, personality: str):
|
||||
"""Create a custom bot with a name and personality
|
||||
@bot.command(name="custom-bot")
|
||||
async def custom_bot(
|
||||
ctx: CommandsContext[Bot],
|
||||
bot_name: str,
|
||||
*,
|
||||
personality: str,
|
||||
) -> None:
|
||||
"""Create a custom bot with a name and personality.
|
||||
|
||||
Usage: !custom-bot <bot_name> <personality_description>
|
||||
Example: !custom-bot alfred you are a proper british butler
|
||||
"""
|
||||
logger.info(
|
||||
f"Custom bot command initiated by {ctx.author.name}: name='{bot_name}', personality length={len(personality)}"
|
||||
"Custom bot command initiated by %s: name=%r, personality length=%d",
|
||||
ctx.author.name,
|
||||
bot_name,
|
||||
len(personality),
|
||||
)
|
||||
|
||||
# Validate bot name
|
||||
if not bot_name or len(bot_name) < 2 or len(bot_name) > 50:
|
||||
name_length = 0 if not bot_name else len(bot_name)
|
||||
if (
|
||||
not bot_name
|
||||
or name_length < MIN_BOT_NAME_LENGTH
|
||||
or name_length > MAX_BOT_NAME_LENGTH
|
||||
):
|
||||
logger.warning(
|
||||
f"Invalid bot name from {ctx.author.name}: '{bot_name}' (length: {len(bot_name) if bot_name else 0})"
|
||||
"Invalid bot name from %s: %r (length: %d)",
|
||||
ctx.author.name,
|
||||
bot_name,
|
||||
name_length,
|
||||
)
|
||||
await ctx.send("❌ Invalid bot name. Name must be between 2 and 50 characters.")
|
||||
await ctx.send("Invalid bot name. Name must be between 2 and 50 characters.")
|
||||
return
|
||||
|
||||
logger.info(f"Bot name validation passed for '{bot_name}'")
|
||||
logger.info("Bot name validation passed for %r", bot_name)
|
||||
|
||||
# Validate personality
|
||||
if not personality or len(personality) < 10:
|
||||
personality_length = 0 if not personality else len(personality)
|
||||
if not personality or personality_length < MIN_PERSONALITY_LENGTH:
|
||||
logger.warning(
|
||||
f"Invalid personality from {ctx.author.name}: length={len(personality) if personality else 0}"
|
||||
"Invalid personality from %s: length=%d",
|
||||
ctx.author.name,
|
||||
personality_length,
|
||||
)
|
||||
await ctx.send(
|
||||
"❌ Invalid personality. Description must be at least 10 characters."
|
||||
"Invalid personality. Description must be at least 10 characters.",
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"Personality validation passed for bot '{bot_name}'")
|
||||
logger.info("Personality validation passed for bot %r", bot_name)
|
||||
|
||||
# Create custom bot manager
|
||||
logger.info(f"Initializing CustomBotManager for user {ctx.author.name}")
|
||||
logger.info("Initializing CustomBotManager for user %s", ctx.author.name)
|
||||
custom_bot_manager = CustomBotManager()
|
||||
|
||||
# Create the custom bot
|
||||
logger.info(
|
||||
f"Attempting to create custom bot '{bot_name}' for user {ctx.author.name}"
|
||||
"Attempting to create custom bot %r for user %s",
|
||||
bot_name,
|
||||
ctx.author.name,
|
||||
)
|
||||
success = custom_bot_manager.create_custom_bot(
|
||||
bot_name=bot_name, system_prompt=personality, created_by=str(ctx.author.id)
|
||||
bot_name=bot_name,
|
||||
system_prompt=personality,
|
||||
created_by=str(ctx.author.id),
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(
|
||||
f"Successfully created custom bot '{bot_name}' for user {ctx.author.name}"
|
||||
"Successfully created custom bot %r for user %s",
|
||||
bot_name,
|
||||
ctx.author.name,
|
||||
)
|
||||
await ctx.send(
|
||||
f"✅ Custom bot **'{bot_name}'** has been created with personality: *{personality}*"
|
||||
f"Custom bot **'{bot_name}'** has been created "
|
||||
f"with personality: *{personality}*",
|
||||
)
|
||||
await ctx.send(
|
||||
f"\nYou can now use this bot with: " f"`!{bot_name} <your message>`",
|
||||
)
|
||||
await ctx.send(f"\nYou can now use this bot with: `!{bot_name} <your message>`")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to create custom bot '{bot_name}' for user {ctx.author.name}"
|
||||
"Failed to create custom bot %r for user %s",
|
||||
bot_name,
|
||||
ctx.author.name,
|
||||
)
|
||||
await ctx.send("❌ Failed to create custom bot. It may already exist.")
|
||||
await ctx.send("Failed to create custom bot. It may already exist.")
|
||||
|
||||
|
||||
@bot.command(name="list-custom-bots")
|
||||
async def list_custom_bots(ctx):
|
||||
"""List all custom bots available in the server"""
|
||||
logger.info(f"Listing custom bots requested by {ctx.author.name}")
|
||||
async def list_custom_bots(ctx: CommandsContext[Bot]) -> None:
|
||||
"""List all custom bots available in the server."""
|
||||
logger.info("Listing custom bots requested by %s", ctx.author.name)
|
||||
|
||||
# Create custom bot manager
|
||||
logger.info("Initializing CustomBotManager to list custom bots")
|
||||
@@ -126,31 +171,36 @@ async def list_custom_bots(ctx):
|
||||
bots = custom_bot_manager.list_custom_bots()
|
||||
|
||||
if not bots:
|
||||
logger.info(f"No custom bots found for user {ctx.author.name}")
|
||||
logger.info("No custom bots found for user %s", ctx.author.name)
|
||||
await ctx.send(
|
||||
"No custom bots have been created yet. Use `!custom-bot <name> <personality>` to create one."
|
||||
"No custom bots have been created yet. "
|
||||
"Use `!custom-bot <name> <personality>` to create one.",
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Found {len(bots)} custom bots, displaying top 10 for {ctx.author.name}"
|
||||
"Found %d custom bots, displaying top 10 for %s",
|
||||
len(bots),
|
||||
ctx.author.name,
|
||||
)
|
||||
bot_list = "🤖 **Available Custom Bots**:\n\n"
|
||||
for name, prompt, creator in bots:
|
||||
bot_list += f"• **{name}**\n"
|
||||
bot_list = "Available Custom Bots:\n\n"
|
||||
for name, _prompt, _creator in bots:
|
||||
bot_list += f"* {name}\n"
|
||||
|
||||
logger.info(f"Sending bot list response to {ctx.author.name}")
|
||||
logger.info("Sending bot list response to %s", ctx.author.name)
|
||||
await ctx.send(bot_list)
|
||||
|
||||
|
||||
@bot.command(name="delete-custom-bot") # type: ignore
|
||||
async def delete_custom_bot(ctx, bot_name: str):
|
||||
"""Delete a custom bot (only the creator can delete)
|
||||
@bot.command(name="delete-custom-bot")
|
||||
async def delete_custom_bot(ctx: CommandsContext[Bot], bot_name: str) -> None:
|
||||
"""Delete a custom bot (only the creator can delete).
|
||||
|
||||
Usage: !delete-custom-bot <bot_name>
|
||||
"""
|
||||
logger.info(
|
||||
f"Delete custom bot command initiated by {ctx.author.name}: bot_name='{bot_name}'"
|
||||
"Delete custom bot command initiated by %s: bot_name=%r",
|
||||
ctx.author.name,
|
||||
bot_name,
|
||||
)
|
||||
|
||||
# Create custom bot manager
|
||||
@@ -158,45 +208,64 @@ async def delete_custom_bot(ctx, bot_name: str):
|
||||
custom_bot_manager = CustomBotManager()
|
||||
|
||||
# Get bot info
|
||||
logger.info(f"Looking up custom bot '{bot_name}' in database")
|
||||
logger.info("Looking up custom bot %r in database", bot_name)
|
||||
bot_info = custom_bot_manager.get_custom_bot(bot_name)
|
||||
|
||||
if not bot_info:
|
||||
logger.warning(f"Custom bot '{bot_name}' not found by user {ctx.author.name}")
|
||||
await ctx.send(f"❌ Custom bot '{bot_name}' not found.")
|
||||
logger.warning(
|
||||
"Custom bot %r not found by user %s",
|
||||
bot_name,
|
||||
ctx.author.name,
|
||||
)
|
||||
await ctx.send(f"Custom bot '{bot_name}' not found.")
|
||||
return
|
||||
|
||||
logger.info(f"Custom bot '{bot_name}' found, owned by user {bot_info[2]}")
|
||||
logger.info(
|
||||
"Custom bot %r found, owned by user %s",
|
||||
bot_name,
|
||||
bot_info[2],
|
||||
)
|
||||
|
||||
# Check ownership
|
||||
if bot_info[2] != str(ctx.author.id):
|
||||
logger.warning(
|
||||
f"User {ctx.author.name} attempted to delete bot '{bot_name}' they don't own"
|
||||
"User %s attempted to delete bot %r they don't own",
|
||||
ctx.author.name,
|
||||
bot_name,
|
||||
)
|
||||
await ctx.send("❌ You can only delete your own custom bots.")
|
||||
await ctx.send("You can only delete your own custom bots.")
|
||||
return
|
||||
|
||||
logger.info(f"User {ctx.author.name} is authorized to delete bot '{bot_name}'")
|
||||
logger.info(
|
||||
"User %s is authorized to delete bot %r",
|
||||
ctx.author.name,
|
||||
bot_name,
|
||||
)
|
||||
|
||||
# Delete the bot
|
||||
logger.info(f"Deleting custom bot '{bot_name}' from database")
|
||||
logger.info("Deleting custom bot %r from database", bot_name)
|
||||
success = custom_bot_manager.delete_custom_bot(bot_name)
|
||||
|
||||
if success:
|
||||
logger.info(
|
||||
f"Successfully deleted custom bot '{bot_name}' by user {ctx.author.name}"
|
||||
"Successfully deleted custom bot %r by user %s",
|
||||
bot_name,
|
||||
ctx.author.name,
|
||||
)
|
||||
await ctx.send(f"✅ Custom bot '{bot_name}' has been deleted.")
|
||||
await ctx.send(f"Custom bot '{bot_name}' has been deleted.")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to delete custom bot '{bot_name}' by user {ctx.author.name}"
|
||||
"Failed to delete custom bot %r by user %s",
|
||||
bot_name,
|
||||
ctx.author.name,
|
||||
)
|
||||
await ctx.send("❌ Failed to delete custom bot.")
|
||||
await ctx.send("Failed to delete custom bot.")
|
||||
|
||||
|
||||
# Handle custom bot commands
|
||||
@bot.event
|
||||
async def on_message(message):
|
||||
async def on_message(message: Message) -> None:
|
||||
"""Handle incoming messages for custom bot command detection."""
|
||||
# Skip bot messages
|
||||
if message.author == bot.user:
|
||||
return
|
||||
@@ -205,7 +274,9 @@ async def on_message(message):
|
||||
message_content = message.content.lower()
|
||||
|
||||
logger.debug(
|
||||
f"Processing message from {message_author}: '{message_content[:50]}...'"
|
||||
"Processing message from %s: %r...",
|
||||
message_author,
|
||||
message_content[:50],
|
||||
)
|
||||
|
||||
ctx = await bot.get_context(message)
|
||||
@@ -216,24 +287,28 @@ async def on_message(message):
|
||||
logger.info("Fetching list of custom bots to check for matching commands")
|
||||
custom_bots = custom_bot_manager.list_custom_bots()
|
||||
|
||||
logger.info(f"Checking {len(custom_bots)} custom bots for command match")
|
||||
logger.info("Checking %d custom bots for command match", len(custom_bots))
|
||||
for bot_name, system_prompt, _ in custom_bots:
|
||||
# Check if message starts with the custom bot name followed by a space
|
||||
if message_content.startswith(f"!{bot_name} "):
|
||||
logger.info(
|
||||
f"Custom bot command detected: '{bot_name}' triggered by {message.author.name}"
|
||||
"Custom bot command detected: %r triggered by %s",
|
||||
bot_name,
|
||||
message.author.name,
|
||||
)
|
||||
|
||||
# Extract the actual message (remove the bot name prefix)
|
||||
user_message = message.content[len(f"!{bot_name} ") :]
|
||||
logger.debug(
|
||||
f"Extracted user message for bot '{bot_name}': '{user_message[:50]}...'"
|
||||
"Extracted user message for bot %r: %r...",
|
||||
bot_name,
|
||||
user_message[:50],
|
||||
)
|
||||
|
||||
# Prepare the payload with custom personality
|
||||
response_prefix = f"**{bot_name} response**"
|
||||
response_prefix = f"{bot_name} response"
|
||||
|
||||
logger.info(f"Sending request to OpenAI API for bot '{bot_name}'")
|
||||
logger.info("Sending request to OpenAI API for bot %r", bot_name)
|
||||
await handle_chat(
|
||||
ctx=ctx,
|
||||
bot_name=bot_name,
|
||||
@@ -248,8 +323,8 @@ async def on_message(message):
|
||||
|
||||
|
||||
@bot.command(name="speak")
|
||||
async def speak(ctx, *, message: str):
|
||||
"""Have the bot speak the given text using Kokoro TTS, or have a custom bot speak
|
||||
async def speak(ctx: CommandsContext[Bot], *, message: str) -> None:
|
||||
"""Have the bot speak the given text using Kokoro TTS, or have a custom bot speak.
|
||||
|
||||
Usage: !speak <text> - plain text to speech
|
||||
Usage: !speak <bot_name> <text> - have a custom bot respond and speak
|
||||
@@ -257,113 +332,149 @@ async def speak(ctx, *, message: str):
|
||||
Example: !speak alfred what time is it
|
||||
"""
|
||||
if tts_engine is None:
|
||||
await ctx.send("❌ TTS engine not initialized. Make sure kokoro-v1.0.onnx and voices-v1.0.bin are present.")
|
||||
await ctx.send(
|
||||
"TTS engine not initialized. "
|
||||
"Make sure kokoro-v1.0.onnx and voices-v1.0.bin are present.",
|
||||
)
|
||||
return
|
||||
|
||||
if not message or len(message.strip()) == 0:
|
||||
await ctx.send("❌ Please provide text to speak.")
|
||||
if not message or not message.strip():
|
||||
await ctx.send("Please provide text to speak.")
|
||||
return
|
||||
|
||||
custom_bot_manager = CustomBotManager()
|
||||
custom_bots = custom_bot_manager.list_custom_bots()
|
||||
bot_names = [b[0] for b in custom_bots]
|
||||
|
||||
first_word = message.split()[0] if message.split() else ""
|
||||
first_word = message.split(maxsplit=1)[0] if message.split() else ""
|
||||
if first_word in bot_names:
|
||||
bot_name = first_word
|
||||
text_to_speak = message[len(bot_name):].lstrip()
|
||||
if not text_to_speak:
|
||||
await ctx.send("❌ Please provide text for the bot to respond to.")
|
||||
await _speak_with_bot(ctx, first_word, message, tts_engine, custom_bot_manager)
|
||||
else:
|
||||
await _speak_plain(ctx, message, tts_engine)
|
||||
|
||||
|
||||
async def _speak_with_bot(
|
||||
ctx: CommandsContext[Bot],
|
||||
bot_name: str,
|
||||
message: str,
|
||||
engine: tts.TTSEngine,
|
||||
custom_bot_manager: CustomBotManager,
|
||||
) -> None:
|
||||
"""Handle speak command for a custom bot."""
|
||||
text_to_speak = message[len(bot_name) :].lstrip()
|
||||
if not text_to_speak:
|
||||
await ctx.send("Please provide text for the bot to respond to.")
|
||||
return
|
||||
|
||||
await ctx.send(f"**{bot_name}** is thinking...")
|
||||
|
||||
bot_info = custom_bot_manager.get_custom_bot(bot_name)
|
||||
if not bot_info:
|
||||
await ctx.send(f"Custom bot '{bot_name}' not found.")
|
||||
return
|
||||
|
||||
_, system_prompt, _, _ = bot_info
|
||||
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences."
|
||||
|
||||
try:
|
||||
db = get_database()
|
||||
context = db.get_conversation_context(
|
||||
user_id=str(ctx.author.id),
|
||||
current_message=text_to_speak,
|
||||
max_context=5,
|
||||
)
|
||||
|
||||
prompts = [{"role": "user", "content": text_to_speak}]
|
||||
if context:
|
||||
prompts = context + prompts
|
||||
|
||||
bot_response = llama_wrapper.chat_completion_with_history(
|
||||
system_prompt=system_prompt_edit,
|
||||
prompts=prompts,
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||
model=CHAT_MODEL,
|
||||
max_tokens=MAX_COMPLETION_TOKENS,
|
||||
)
|
||||
|
||||
if not bot_response:
|
||||
await ctx.send(f"**{bot_name}** failed to generate a response.")
|
||||
return
|
||||
|
||||
await ctx.send(f"🔊 **{bot_name}** is thinking...")
|
||||
|
||||
bot_info = custom_bot_manager.get_custom_bot(bot_name)
|
||||
if not bot_info:
|
||||
await ctx.send(f"❌ Custom bot '{bot_name}' not found.")
|
||||
return
|
||||
|
||||
_, system_prompt, _, _ = bot_info
|
||||
|
||||
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences."
|
||||
|
||||
try:
|
||||
db = get_database()
|
||||
context = db.get_conversation_context(
|
||||
user_id=str(ctx.author.id), current_message=text_to_speak, max_context=5
|
||||
)
|
||||
|
||||
prompts = [{"role": "user", "content": text_to_speak}]
|
||||
if context:
|
||||
prompts = context + prompts
|
||||
|
||||
bot_response = llama_wrapper.chat_completion_with_history(
|
||||
system_prompt=system_prompt_edit,
|
||||
prompts=prompts,
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||
model=CHAT_MODEL,
|
||||
max_tokens=MAX_COMPLETION_TOKENS,
|
||||
)
|
||||
|
||||
if not bot_response:
|
||||
await ctx.send(f"❌ **{bot_name}** failed to generate a response.")
|
||||
return
|
||||
|
||||
db.add_message(
|
||||
message_id=f"{ctx.message.id}",
|
||||
user_id=str(ctx.author.id),
|
||||
username=ctx.author.name,
|
||||
content=f"User: {text_to_speak}",
|
||||
channel_id=str(ctx.channel.id),
|
||||
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
||||
)
|
||||
db.add_message(
|
||||
message_id=f"{ctx.message.id}",
|
||||
user_id=str(ctx.author.id),
|
||||
username=ctx.author.name,
|
||||
content=f"User: {text_to_speak}",
|
||||
channel_id=str(ctx.channel.id),
|
||||
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
||||
)
|
||||
|
||||
if ctx.bot.user is not None:
|
||||
db.add_message(
|
||||
message_id=f"{ctx.message.id}_response",
|
||||
user_id=str(bot.user.id),
|
||||
username=bot.user.name,
|
||||
user_id=str(ctx.bot.user.id),
|
||||
username=ctx.bot.user.name,
|
||||
content=f"Bot: {bot_response}",
|
||||
channel_id=str(ctx.channel.id),
|
||||
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
||||
)
|
||||
|
||||
await ctx.send(f"🔊 Generating speech for **{bot_name}**...")
|
||||
audio_buffer = tts_engine.generate_audio(bot_response, voice=TTS_VOICE, speed=TTS_SPEED)
|
||||
await ctx.send(f"Generating speech for **{bot_name}**...")
|
||||
audio_buffer = engine.generate_audio(
|
||||
bot_response,
|
||||
voice=TTS_VOICE,
|
||||
speed=TTS_SPEED,
|
||||
)
|
||||
|
||||
audio_file = discord.File(audio_buffer, filename="speech.mp3")
|
||||
await ctx.send(file=audio_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in !speak command with bot '{bot_name}': {traceback.format_exc()}")
|
||||
await ctx.send(f"❌ Error generating speech: {str(e)}")
|
||||
else:
|
||||
if not message or len(message.strip()) == 0:
|
||||
await ctx.send("❌ Please provide text to speak.")
|
||||
return
|
||||
audio_file = discord.File(audio_buffer, filename="speech.mp3")
|
||||
await ctx.send(file=audio_file)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error in speak command with bot %r",
|
||||
bot_name,
|
||||
)
|
||||
await ctx.send("Error generating speech.")
|
||||
|
||||
try:
|
||||
await ctx.send("🔊 Generating speech...")
|
||||
audio_buffer = tts_engine.generate_audio(message, voice=TTS_VOICE, speed=TTS_SPEED)
|
||||
|
||||
audio_file = discord.File(audio_buffer, filename="speech.mp3")
|
||||
await ctx.send(file=audio_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in !speak command: {e}")
|
||||
await ctx.send(f"❌ Error generating speech: {str(e)}")
|
||||
async def _speak_plain(
|
||||
ctx: CommandsContext[Bot],
|
||||
message: str,
|
||||
engine: tts.TTSEngine,
|
||||
) -> None:
|
||||
"""Handle speak command for plain text."""
|
||||
try:
|
||||
await ctx.send("Generating speech...")
|
||||
audio_buffer = engine.generate_audio(
|
||||
message,
|
||||
voice=TTS_VOICE,
|
||||
speed=TTS_SPEED,
|
||||
)
|
||||
|
||||
audio_file = discord.File(audio_buffer, filename="speech.mp3")
|
||||
await ctx.send(file=audio_file)
|
||||
except Exception:
|
||||
logger.exception("Error in speak command")
|
||||
await ctx.send("Error generating speech.")
|
||||
|
||||
|
||||
@bot.command(name="doodlebob")
|
||||
async def doodlebob(ctx, *, message: str):
|
||||
# add some logging
|
||||
|
||||
logger.info(f"Doodlebob command triggered by {ctx.author.name}: {message[:100]}")
|
||||
async def doodlebob(ctx: CommandsContext[Bot], *, message: str) -> None:
|
||||
"""Convert a message into an image using Doodlebob."""
|
||||
logger.info(
|
||||
"Doodlebob command triggered by %s: %s",
|
||||
ctx.author.name,
|
||||
message[:100],
|
||||
)
|
||||
await ctx.send(f"**Doodlebob erasing {message[:100]}...**")
|
||||
|
||||
system_prompt = (
|
||||
"Given the following message, convert it to a detailed image generation prompt that will be passed directly into an image generation model."
|
||||
"If told to generate an image of yourself, generate a picture of a rat. If told to generate a picture of 'me', 'myself', or some other self"
|
||||
" reference, generate a picture of a rat. Only respond with a valid image generation prompt, do not affirm the user or respond to the user's"
|
||||
" questions."
|
||||
"Given the following message, convert it to a detailed image generation "
|
||||
"prompt that will be passed directly into an image generation model. "
|
||||
"If told to generate an image of yourself, generate a picture of a rat. "
|
||||
"If told to generate a picture of 'me', 'myself', or some other self "
|
||||
"reference, generate a picture of a rat. Only respond with a valid image "
|
||||
"generation prompt, do not affirm the user or respond to the user's questions."
|
||||
)
|
||||
|
||||
# Wait for the generated image prompt
|
||||
@@ -378,7 +489,7 @@ async def doodlebob(ctx, *, message: str):
|
||||
|
||||
# If the string is empty we had an error
|
||||
if image_prompt == "":
|
||||
print("No image prompt supplied. Check for errors.")
|
||||
logger.warning("No image prompt supplied. Check for errors.")
|
||||
return
|
||||
|
||||
# Alert the user we're generating the image
|
||||
@@ -397,11 +508,17 @@ async def doodlebob(ctx, *, message: str):
|
||||
|
||||
|
||||
@bot.command(name="retcon")
|
||||
async def retcon(ctx, *, message: str):
|
||||
image_data_list = []
|
||||
async def retcon(ctx: CommandsContext[Bot], *, message: str) -> None:
|
||||
"""Edit an attached image based on a text prompt."""
|
||||
image_data_list: list[BytesIO] = []
|
||||
for discord_image in ctx.message.attachments:
|
||||
image_url = discord_image.url
|
||||
image_data = requests.get(image_url).content
|
||||
try:
|
||||
response = requests.get(image_url, timeout=30)
|
||||
image_data = response.content
|
||||
except requests.RequestException as e:
|
||||
logger.warning("Failed to download image from %s: %s", image_url, e)
|
||||
continue
|
||||
image_bytestream = BytesIO(image_data)
|
||||
image_data_list.append(image_bytestream)
|
||||
|
||||
@@ -421,20 +538,23 @@ async def retcon(ctx, *, message: str):
|
||||
|
||||
|
||||
@bot.command(name="talkforme")
|
||||
async def talkforme(ctx, *, message: str):
|
||||
"""Have two bots talk to each other about a topic
|
||||
async def talkforme(ctx: CommandsContext[Bot], *, message: str) -> None:
|
||||
"""Have two bots talk to each other about a topic.
|
||||
|
||||
Usage: !talkforme bot1 bot2 4 some conversation topic
|
||||
"""
|
||||
talk_limit = 20
|
||||
|
||||
TALK_LIMIT = 20
|
||||
MIN_TALKFORME_PARTS = 4
|
||||
parts = message.split(" ", maxsplit=MIN_TALKFORME_PARTS - 1)
|
||||
if len(parts) < MIN_TALKFORME_PARTS:
|
||||
await ctx.send("Usage: !talkforme bot1 bot2 <number> <topic>")
|
||||
return
|
||||
|
||||
bot1_name, bot2_name, limit, topic_list = (
|
||||
message.split(" ")[0],
|
||||
message.split(" ")[1],
|
||||
message.split(" ")[2],
|
||||
message.split(" ")[3:],
|
||||
)
|
||||
bot1_name = parts[0]
|
||||
bot2_name = parts[1]
|
||||
limit = parts[2]
|
||||
topic_list = parts[3:]
|
||||
|
||||
topic = " ".join(topic_list)
|
||||
|
||||
@@ -444,49 +564,46 @@ async def talkforme(ctx, *, message: str):
|
||||
if not bot1:
|
||||
await ctx.send(f"{bot1_name} is not a real bot...")
|
||||
return
|
||||
else:
|
||||
_, bot1_prompt, _, _ = bot1
|
||||
_, bot1_prompt, _, _ = bot1
|
||||
|
||||
bot2 = custom_bot_manager.get_custom_bot(bot2_name)
|
||||
|
||||
if not bot2:
|
||||
await ctx.send(f"{bot2_name} is not a real bot...")
|
||||
return
|
||||
else:
|
||||
_, bot2_prompt, _, _ = bot2
|
||||
_, bot2_prompt, _, _ = bot2
|
||||
|
||||
await ctx.send(
|
||||
f'{bot1_name} is going to talk to {bot2_name} about "{topic[:50]}" for {limit} replies.'
|
||||
f"{bot1_name} is going to talk to {bot2_name} "
|
||||
f'about "{topic[:50]}" for {limit} replies.',
|
||||
)
|
||||
|
||||
bot_list = [(bot1_name, bot1_prompt), (bot2_name, bot2_prompt)]
|
||||
|
||||
message_limit = int(limit)
|
||||
try:
|
||||
message_limit = int(limit)
|
||||
except ValueError:
|
||||
await ctx.send("Message limit must be an integer.")
|
||||
return
|
||||
|
||||
def flip_counter(counter: int):
|
||||
if counter == 0:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
def flip_user(user: str):
|
||||
if user == "user":
|
||||
return "assistant"
|
||||
else:
|
||||
return "user"
|
||||
def flip_counter(counter: int) -> int:
|
||||
"""Flip between 0 and 1."""
|
||||
return 1 if counter == 0 else 0
|
||||
|
||||
message_counter = 0
|
||||
bot_counter = 0
|
||||
current_bot = bot_list[bot_counter]
|
||||
prompt_histories = [
|
||||
prompt_histories: list[list[dict[str, str]]] = [
|
||||
[{"role": "user", "content": topic}],
|
||||
[{"role": "assistant", "content": topic}],
|
||||
]
|
||||
|
||||
first_bot_response = llama_wrapper.chat_completion_with_history(
|
||||
system_prompt=current_bot[1]
|
||||
+ f"\nKeep your responses under 2-3 sentences. You are talking to {current_bot[flip_counter(bot_counter)][0]}",
|
||||
prompts=prompt_histories[bot_counter], # type: ignore
|
||||
system_prompt=(
|
||||
current_bot[1] + f"\nKeep your responses under 2-3 sentences. "
|
||||
f"You are talking to {current_bot[flip_counter(bot_counter)][0]}"
|
||||
),
|
||||
prompts=prompt_histories[bot_counter],
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||
model=CHAT_MODEL,
|
||||
@@ -498,13 +615,15 @@ async def talkforme(ctx, *, message: str):
|
||||
|
||||
bot_counter = flip_counter(counter=bot_counter)
|
||||
|
||||
while message_counter < min(message_limit, TALK_LIMIT):
|
||||
while message_counter < min(message_limit, talk_limit):
|
||||
current_bot = bot_list[bot_counter]
|
||||
logger.info(f"Current bot is {current_bot}")
|
||||
logger.info("Current bot is %s", current_bot[0])
|
||||
bot_response = llama_wrapper.chat_completion_with_history(
|
||||
system_prompt=current_bot[1]
|
||||
+ f"\nKeep your responses under 2-3 sentences. {current_bot[flip_counter(bot_counter)]}",
|
||||
prompts=prompt_histories[bot_counter], # type: ignore
|
||||
system_prompt=(
|
||||
current_bot[1] + f"\nKeep your responses under 2-3 sentences. "
|
||||
f"{current_bot[flip_counter(bot_counter)]}"
|
||||
),
|
||||
prompts=prompt_histories[bot_counter],
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||
model=CHAT_MODEL,
|
||||
@@ -512,10 +631,10 @@ async def talkforme(ctx, *, message: str):
|
||||
)
|
||||
message_counter += 1
|
||||
prompt_histories[bot_counter].append(
|
||||
{"role": "assistant", "content": bot_response}
|
||||
{"role": "assistant", "content": bot_response},
|
||||
)
|
||||
prompt_histories[flip_counter(bot_counter)].append(
|
||||
{"role": "user", "content": bot_response}
|
||||
{"role": "user", "content": bot_response},
|
||||
)
|
||||
await ctx.send(f"## {current_bot[0]}")
|
||||
while bot_response:
|
||||
@@ -523,12 +642,27 @@ async def talkforme(ctx, *, message: str):
|
||||
bot_response = bot_response[1000:]
|
||||
await ctx.send(send_chunk)
|
||||
bot_counter = flip_counter(counter=bot_counter)
|
||||
logger.info(f"Message counter is {message_counter}/{limit}")
|
||||
logger.info("Message counter is %d/%s", message_counter, limit)
|
||||
|
||||
|
||||
async def handle_chat(
|
||||
ctx, *, bot_name: str, message: str, system_prompt: str, response_prefix: str
|
||||
):
|
||||
ctx: CommandsContext[Bot],
|
||||
*,
|
||||
bot_name: str,
|
||||
message: str,
|
||||
system_prompt: str,
|
||||
response_prefix: str,
|
||||
) -> None:
|
||||
"""Handle chat completion for a custom bot command.
|
||||
|
||||
Args:
|
||||
ctx: The Discord command context.
|
||||
bot_name: The name of the custom bot.
|
||||
message: The user message to process.
|
||||
system_prompt: The system prompt for the bot.
|
||||
response_prefix: The prefix for the response message.
|
||||
|
||||
"""
|
||||
await ctx.send(f"{bot_name} is searching its databanks for {message[:50]}...")
|
||||
|
||||
# Get database instance
|
||||
@@ -536,7 +670,9 @@ async def handle_chat(
|
||||
|
||||
# Get conversation context using RAG
|
||||
context = db.get_conversation_context(
|
||||
user_id=str(ctx.author.id), current_message=message, max_context=5
|
||||
user_id=str(ctx.author.id),
|
||||
current_message=message,
|
||||
max_context=5,
|
||||
)
|
||||
|
||||
prompts = [{"role": "user", "content": message}]
|
||||
@@ -544,14 +680,14 @@ async def handle_chat(
|
||||
if context:
|
||||
prompts = context + prompts
|
||||
|
||||
logger.info(prompts)
|
||||
logger.info("Chat prompts: %s", prompts)
|
||||
|
||||
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences."
|
||||
|
||||
try:
|
||||
bot_response = llama_wrapper.chat_completion_with_history(
|
||||
system_prompt=system_prompt_edit,
|
||||
prompts=prompts, # type: ignore
|
||||
prompts=prompts,
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||
model=CHAT_MODEL,
|
||||
@@ -568,14 +704,15 @@ async def handle_chat(
|
||||
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
||||
)
|
||||
|
||||
db.add_message(
|
||||
message_id=f"{ctx.message.id}_response",
|
||||
user_id=str(bot.user.id), # type: ignore
|
||||
username=bot.user.name, # type: ignore
|
||||
content=f"Bot: {bot_response}",
|
||||
channel_id=str(ctx.channel.id),
|
||||
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
||||
)
|
||||
if ctx.bot.user is not None:
|
||||
db.add_message(
|
||||
message_id=f"{ctx.message.id}_response",
|
||||
user_id=str(ctx.bot.user.id),
|
||||
username=ctx.bot.user.name,
|
||||
content=f"Bot: {bot_response}",
|
||||
channel_id=str(ctx.channel.id),
|
||||
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
||||
)
|
||||
|
||||
# Send the response back to the chat
|
||||
await ctx.send(response_prefix)
|
||||
@@ -584,8 +721,9 @@ async def handle_chat(
|
||||
bot_response = bot_response[1000:]
|
||||
await ctx.send(send_chunk)
|
||||
|
||||
except Exception as e:
|
||||
await ctx.send(f"Error: {str(e)}")
|
||||
except Exception:
|
||||
logger.exception("Error in handle_chat")
|
||||
await ctx.send("An error occurred while processing your request.")
|
||||
|
||||
|
||||
# Run the bot
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Tests for the vibe_bot package."""
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
"""Shared test fixtures for vibe_bot tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message="Exception ignored in.*FileIO.*Bad file descriptor",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vibe_bot.database import ChatDatabase, CustomBotManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env_vars() -> Generator[None]:
|
||||
"""Provide minimal env vars for config loading."""
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
"CHAT_ENDPOINT_KEY": "test-key",
|
||||
"COMPLETION_ENDPOINT_KEY": "test-completion-key",
|
||||
"IMAGE_GEN_ENDPOINT_KEY": "test-image-key",
|
||||
"IMAGE_EDIT_ENDPOINT_KEY": "test-image-edit-key",
|
||||
"EMBEDDING_ENDPOINT_KEY": "test-embedding-key",
|
||||
"MAX_COMPLETION_TOKENS": "1000",
|
||||
"MAX_HISTORY_MESSAGES": "1000",
|
||||
"SIMILARITY_THRESHOLD": "0.7",
|
||||
"TOP_K_RESULTS": "5",
|
||||
"TTS_MODEL_PATH": "/tmp/test-model.onnx",
|
||||
"TTS_VOICES_PATH": "/tmp/test-voices.bin",
|
||||
"TTS_VOICE": "af_sarah",
|
||||
"TTS_SPEED": "1.0",
|
||||
"DB_PATH": ":memory:",
|
||||
},
|
||||
clear=False,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_path() -> Generator[str]:
|
||||
"""Provide a temporary SQLite database path."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
yield path
|
||||
Path(path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding() -> Generator[MagicMock]:
|
||||
"""Provide a mock embedding function returning a fixed vector."""
|
||||
vector: list[float] = [0.1] * 2048
|
||||
with patch("vibe_bot.llama_wrapper.embedding", return_value=vector) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_client() -> Generator[MagicMock]:
|
||||
"""Provide a mock OpenAI client."""
|
||||
mock_client = MagicMock()
|
||||
with patch("vibe_bot.database.OpenAI", return_value=mock_client) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chat_db(
|
||||
temp_db_path: str,
|
||||
mock_openai_client: MagicMock,
|
||||
mock_embedding: MagicMock,
|
||||
) -> Generator[ChatDatabase]:
|
||||
"""Provide a ChatDatabase instance with a temp database."""
|
||||
from vibe_bot.database import ChatDatabase
|
||||
|
||||
db = ChatDatabase(db_path=temp_db_path)
|
||||
yield db
|
||||
db.client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_bot_manager(temp_db_path: str) -> CustomBotManager:
|
||||
"""Provide a CustomBotManager instance with a temp database."""
|
||||
from vibe_bot.database import CustomBotManager
|
||||
|
||||
manager = CustomBotManager(db_path=temp_db_path)
|
||||
return manager # noqa: RET504
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_kokoro_tts() -> Generator[dict[str, Any]]:
|
||||
"""Provide mock Kokoro TTS components."""
|
||||
mock_kokoro = MagicMock()
|
||||
mock_kokoro_instance = MagicMock()
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.return_value = ["hello world", "this is a test"]
|
||||
|
||||
mock_samples = np.array([0.1, 0.2, 0.3], dtype=np.float32)
|
||||
mock_process = MagicMock(return_value=(mock_samples, 24000))
|
||||
|
||||
with patch("vibe_bot.tts.Kokoro", return_value=mock_kokoro_instance): # noqa: SIM117
|
||||
with patch("vibe_bot.tts.chunk_text", mock_chunk):
|
||||
with patch("vibe_bot.tts.process_chunk_sequential", mock_process):
|
||||
yield {
|
||||
"Kokoro": mock_kokoro,
|
||||
"chunk_text": mock_chunk,
|
||||
"process_chunk_sequential": mock_process,
|
||||
"kokoro_instance": mock_kokoro_instance,
|
||||
"mock_samples": mock_samples,
|
||||
"mock_sr": 24000,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord() -> Generator[dict[str, MagicMock]]:
|
||||
"""Mock discord module components."""
|
||||
mock_intents = MagicMock()
|
||||
mock_intents.default.return_value = MagicMock()
|
||||
mock_intents.default.return_value.message_content = True
|
||||
|
||||
mock_bot_class = MagicMock()
|
||||
mock_bot_instance = MagicMock()
|
||||
mock_bot_instance.user = MagicMock()
|
||||
mock_bot_instance.user.name = "test-bot"
|
||||
mock_bot_instance.user.id = "123456789"
|
||||
|
||||
with patch("vibe_bot.main.discord") as mock_discord_module: # noqa: SIM117
|
||||
with patch("vibe_bot.main.commands", MagicMock()):
|
||||
with patch("vibe_bot.main.commands.Bot", mock_bot_class):
|
||||
mock_bot_class.return_value = mock_bot_instance
|
||||
mock_discord_module.Intents = mock_intents
|
||||
mock_discord_module.Message = MagicMock
|
||||
mock_discord_module.File = MagicMock
|
||||
yield {
|
||||
"Intents": mock_intents,
|
||||
"Bot": mock_bot_class,
|
||||
"bot_instance": mock_bot_instance,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tts_engine() -> Generator[MagicMock]:
|
||||
"""Provide a mock TTSEngine."""
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.generate_audio.return_value = MagicMock()
|
||||
with patch("vibe_bot.main.tts_engine", mock_engine): # noqa: SIM117
|
||||
with patch("vibe_bot.main.tts.TTSEngine", return_value=mock_engine):
|
||||
yield mock_engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_requests() -> Generator[MagicMock]:
|
||||
"""Provide mock requests module."""
|
||||
with patch("vibe_bot.main.requests") as mock_requests_module:
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"fake image data"
|
||||
mock_requests_module.get.return_value = mock_response
|
||||
yield mock_requests_module
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_base64() -> Generator[MagicMock]:
|
||||
"""Provide mock base64 module."""
|
||||
with patch("vibe_bot.main.base64") as mock_base64_module:
|
||||
mock_base64_module.b64decode.return_value = b"fake image data"
|
||||
yield mock_base64_module
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llama_wrapper() -> Generator[MagicMock]:
|
||||
"""Provide mock llama_wrapper module."""
|
||||
with patch("vibe_bot.main.llama_wrapper") as mock_wrapper:
|
||||
mock_wrapper.chat_completion_with_history.return_value = "Bot response"
|
||||
mock_wrapper.chat_completion_instruct.return_value = "image prompt"
|
||||
mock_wrapper.image_generation.return_value = ""
|
||||
mock_wrapper.image_edit.return_value = ""
|
||||
mock_wrapper.embedding.return_value = [0.1] * 2048
|
||||
yield mock_wrapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_database() -> Generator[MagicMock]:
|
||||
"""Provide mock database module."""
|
||||
with patch("vibe_bot.main.get_database") as mock_get_db:
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_conversation_context.return_value = []
|
||||
mock_db.add_message.return_value = True
|
||||
mock_get_db.return_value = mock_db
|
||||
yield mock_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_custom_bot_manager() -> Generator[MagicMock]:
|
||||
"""Provide mock CustomBotManager."""
|
||||
with patch("vibe_bot.main.CustomBotManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.create_custom_bot.return_value = True
|
||||
mock_manager.get_custom_bot.return_value = (
|
||||
"alfred",
|
||||
"british butler personality",
|
||||
"user123",
|
||||
"2024-01-01",
|
||||
)
|
||||
mock_manager.list_custom_bots.return_value = [
|
||||
("alfred", "british butler personality", "user123"),
|
||||
]
|
||||
mock_manager.delete_custom_bot.return_value = True
|
||||
mock_manager_class.return_value = mock_manager
|
||||
yield mock_manager
|
||||
@@ -0,0 +1,324 @@
|
||||
"""Tests for the config module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def test_config_defaults() -> None:
|
||||
"""Test that config loads with expected default values."""
|
||||
env_str = ""
|
||||
for k, v in {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
"CHAT_ENDPOINT_KEY": "test-key",
|
||||
"COMPLETION_ENDPOINT_KEY": "test-completion-key",
|
||||
"IMAGE_GEN_ENDPOINT_KEY": "test-image-key",
|
||||
"IMAGE_EDIT_ENDPOINT_KEY": "test-image-edit-key",
|
||||
"EMBEDDING_ENDPOINT_KEY": "test-embedding-key",
|
||||
"MAX_COMPLETION_TOKENS": "1000",
|
||||
"MAX_HISTORY_MESSAGES": "1000",
|
||||
"SIMILARITY_THRESHOLD": "0.7",
|
||||
"TOP_K_RESULTS": "5",
|
||||
"TTS_MODEL_PATH": "/tmp/test-model.onnx",
|
||||
"TTS_VOICES_PATH": "/tmp/test-voices.bin",
|
||||
"TTS_VOICE": "af_sarah",
|
||||
"TTS_SPEED": "1.0",
|
||||
"DB_PATH": ":memory:",
|
||||
}.items():
|
||||
env_str += f'os.environ["{k}"] = "{v}"\n'
|
||||
|
||||
code = f"""
|
||||
import sys
|
||||
sys.path.insert(0, "/var/home/ducoterra/Projects/vibe_discord_bots")
|
||||
import os
|
||||
os.environ.clear()
|
||||
os.environ["PATH"] = "/usr/bin:/bin"
|
||||
{env_str}
|
||||
import vibe_bot.config
|
||||
assert vibe_bot.config.DISCORD_TOKEN == "test-token"
|
||||
assert vibe_bot.config.CHAT_ENDPOINT == "https://chat.example.com/v1"
|
||||
assert vibe_bot.config.COMPLETION_ENDPOINT == "https://completion.example.com/v1"
|
||||
assert vibe_bot.config.IMAGE_GEN_ENDPOINT == "https://image.example.com/v1"
|
||||
assert vibe_bot.config.IMAGE_EDIT_ENDPOINT == "https://image-edit.example.com/v1"
|
||||
assert vibe_bot.config.EMBEDDING_ENDPOINT == "https://embedding.example.com/v1"
|
||||
assert vibe_bot.config.CHAT_MODEL == "test-chat-model"
|
||||
assert vibe_bot.config.COMPLETION_MODEL == "test-completion-model"
|
||||
assert vibe_bot.config.IMAGE_GEN_MODEL == "test-image-model"
|
||||
assert vibe_bot.config.IMAGE_EDIT_MODEL == "test-image-edit-model"
|
||||
assert vibe_bot.config.EMBEDDING_MODEL == "test-embedding-model"
|
||||
assert vibe_bot.config.MAX_COMPLETION_TOKENS == 1000
|
||||
assert vibe_bot.config.MAX_HISTORY_MESSAGES == 1000
|
||||
assert vibe_bot.config.SIMILARITY_THRESHOLD == 0.7
|
||||
assert vibe_bot.config.TOP_K_RESULTS == 5
|
||||
assert vibe_bot.config.TTS_MODEL_PATH == "/tmp/test-model.onnx"
|
||||
assert vibe_bot.config.TTS_VOICES_PATH == "/tmp/test-voices.bin"
|
||||
assert vibe_bot.config.TTS_VOICE == "af_sarah"
|
||||
assert vibe_bot.config.TTS_SPEED == 1.0
|
||||
print("OK")
|
||||
"""
|
||||
|
||||
result = subprocess.run( # noqa: PLW1510, S603
|
||||
[sys.executable, "-c", code],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
assert result.returncode == 0, f"Subprocess failed: {result.stderr}"
|
||||
|
||||
|
||||
def _run_config_check(env_vars: dict[str, str], expected_error: str) -> None:
|
||||
"""Run a subprocess that imports config and checks for expected RuntimeError."""
|
||||
env_str = ""
|
||||
for k, v in env_vars.items():
|
||||
env_str += f'os.environ["{k}"] = "{v}"\n'
|
||||
|
||||
code = f"""
|
||||
import sys
|
||||
sys.path.insert(0, "/var/home/ducoterra/Projects/vibe_discord_bots")
|
||||
import os
|
||||
os.environ.clear()
|
||||
os.environ["PATH"] = "/usr/bin:/bin"
|
||||
{env_str}
|
||||
try:
|
||||
import vibe_bot.config
|
||||
print("NO_ERROR")
|
||||
except RuntimeError as e:
|
||||
print(f"ERROR: {{e}}")
|
||||
except Exception as e:
|
||||
print(f"OTHER: {{type(e).__name__}}: {{e}}")
|
||||
"""
|
||||
|
||||
result = subprocess.run( # noqa: PLW1510, S603
|
||||
[sys.executable, "-c", code],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
output = result.stdout.strip()
|
||||
assert output.startswith("ERROR:") and expected_error in output, ( # noqa: PT018
|
||||
f"Expected error '{expected_error}' but got: {output}"
|
||||
)
|
||||
|
||||
|
||||
def test_config_missing_discord_token() -> None:
|
||||
"""Test that RuntimeError is raised when DISCORD_TOKEN is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "DISCORD_TOKEN required")
|
||||
|
||||
|
||||
def test_config_missing_chat_endpoint() -> None:
|
||||
"""Test that RuntimeError is raised when CHAT_ENDPOINT is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "CHAT_ENDPOINT required")
|
||||
|
||||
|
||||
def test_config_missing_completion_endpoint() -> None:
|
||||
"""Test that RuntimeError is raised when COMPLETION_ENDPOINT is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "COMPLETION_ENDPOINT required")
|
||||
|
||||
|
||||
def test_config_missing_image_gen_endpoint() -> None:
|
||||
"""Test that RuntimeError is raised when IMAGE_GEN_ENDPOINT is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "IMAGE_GEN_ENDPOINT required")
|
||||
|
||||
|
||||
def test_config_missing_image_edit_endpoint() -> None:
|
||||
"""Test that RuntimeError is raised when IMAGE_EDIT_ENDPOINT is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "IMAGE_EDIT_ENDPOINT required")
|
||||
|
||||
|
||||
def test_config_missing_embedding_endpoint() -> None:
|
||||
"""Test that RuntimeError is raised when EMBEDDING_ENDPOINT is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "EMBEDDING_ENDPOINT required")
|
||||
|
||||
|
||||
def test_config_missing_chat_model() -> None:
|
||||
"""Test that RuntimeError is raised when CHAT_MODEL is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "CHAT_MODEL required")
|
||||
|
||||
|
||||
def test_config_missing_completion_model() -> None:
|
||||
"""Test that RuntimeError is raised when COMPLETION_MODEL is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "COMPLETION_MODEL required")
|
||||
|
||||
|
||||
def test_config_missing_image_gen_model() -> None:
|
||||
"""Test that RuntimeError is raised when IMAGE_GEN_MODEL is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "IMAGE_GEN_MODEL required")
|
||||
|
||||
|
||||
def test_config_missing_image_edit_model() -> None:
|
||||
"""Test that RuntimeError is raised when IMAGE_EDIT_MODEL is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "IMAGE_EDIT_MODEL required")
|
||||
|
||||
|
||||
def test_config_missing_embedding_model() -> None:
|
||||
"""Test that RuntimeError is raised when EMBEDDING_MODEL is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "",
|
||||
}
|
||||
_run_config_check(env, "EMBEDDING_MODEL required")
|
||||
|
||||
|
||||
def test_config_logging_exists() -> None:
|
||||
"""Test that logging is configured in config module."""
|
||||
from vibe_bot.config import logger
|
||||
|
||||
assert logger is not None
|
||||
assert logger.name == "vibe_bot.config"
|
||||
|
||||
|
||||
def test_config_embedding_dimension() -> None:
|
||||
"""Test that EMBEDDING_DIMENSION has expected default value."""
|
||||
from vibe_bot.config import EMBEDDING_DIMENSION
|
||||
|
||||
assert EMBEDDING_DIMENSION == 2048
|
||||
@@ -0,0 +1,464 @@
|
||||
"""Tests for the database module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from vibe_bot.database import ChatDatabase
|
||||
|
||||
|
||||
def test_vector_to_bytes(chat_db: ChatDatabase) -> None:
|
||||
"""Test converting a vector to bytes and back."""
|
||||
vector: list[float] = [0.1, 0.2, 0.3, 0.4]
|
||||
blob = chat_db._vector_to_bytes(vector)
|
||||
assert isinstance(blob, bytes)
|
||||
assert len(blob) == len(vector) * 4 # float32 = 4 bytes
|
||||
|
||||
reconstructed = chat_db._bytes_to_vector(blob)
|
||||
assert np.allclose(reconstructed, np.array(vector, dtype=np.float32))
|
||||
|
||||
|
||||
def test_bytes_to_vector(chat_db: ChatDatabase) -> None:
|
||||
"""Test converting bytes back to a numpy vector."""
|
||||
original = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
blob = original.tobytes()
|
||||
result = chat_db._bytes_to_vector(blob)
|
||||
assert np.array_equal(result, original)
|
||||
|
||||
|
||||
def test_calculate_similarity_self(chat_db: ChatDatabase) -> None:
|
||||
"""Test cosine similarity of a vector with itself is 1.0."""
|
||||
vec = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
similarity = chat_db._calculate_similarity(vec, vec)
|
||||
assert similarity == pytest.approx(1.0, abs=1e-6)
|
||||
|
||||
|
||||
def test_calculate_similarity_orthogonal(chat_db: ChatDatabase) -> None:
|
||||
"""Test cosine similarity of orthogonal vectors is 0."""
|
||||
vec1 = np.array([1.0, 0.0], dtype=np.float32)
|
||||
vec2 = np.array([0.0, 1.0], dtype=np.float32)
|
||||
similarity = chat_db._calculate_similarity(vec1, vec2)
|
||||
assert similarity == pytest.approx(0.0, abs=1e-6)
|
||||
|
||||
|
||||
def test_calculate_similarity_negative(chat_db: ChatDatabase) -> None:
|
||||
"""Test cosine similarity of opposite vectors is -1."""
|
||||
vec1 = np.array([1.0, 0.0], dtype=np.float32)
|
||||
vec2 = np.array([-1.0, 0.0], dtype=np.float32)
|
||||
similarity = chat_db._calculate_similarity(vec1, vec2)
|
||||
assert similarity == pytest.approx(-1.0, abs=1e-6)
|
||||
|
||||
|
||||
def test_add_message(chat_db: ChatDatabase, mock_embedding: MagicMock) -> None:
|
||||
"""Test adding a message to the database."""
|
||||
result = chat_db.add_message(
|
||||
message_id="msg-1",
|
||||
user_id="user-1",
|
||||
username="testuser",
|
||||
content="Hello world",
|
||||
channel_id="chan-1",
|
||||
guild_id="guild-1",
|
||||
)
|
||||
assert result is True
|
||||
|
||||
messages = chat_db.get_recent_messages(limit=10)
|
||||
assert len(messages) == 1
|
||||
assert messages[0][0] == "msg-1"
|
||||
assert messages[0][1] == "testuser"
|
||||
assert messages[0][2] == "Hello world"
|
||||
|
||||
|
||||
def test_add_message_no_embedding(chat_db: ChatDatabase) -> None:
|
||||
"""Test adding a message when embedding generation fails."""
|
||||
with patch("vibe_bot.llama_wrapper.embedding", return_value=None):
|
||||
result = chat_db.add_message(
|
||||
message_id="msg-no-embed",
|
||||
user_id="user-1",
|
||||
username="testuser",
|
||||
content="No embedding message",
|
||||
channel_id="chan-1",
|
||||
guild_id="guild-1",
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_add_message_duplicate(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test adding a duplicate message replaces the old one."""
|
||||
chat_db.add_message(
|
||||
message_id="msg-dup",
|
||||
user_id="user-1",
|
||||
username="testuser",
|
||||
content="First content",
|
||||
)
|
||||
chat_db.add_message(
|
||||
message_id="msg-dup",
|
||||
user_id="user-1",
|
||||
username="testuser",
|
||||
content="Second content",
|
||||
)
|
||||
|
||||
messages = chat_db.get_recent_messages(limit=10)
|
||||
assert len(messages) == 1
|
||||
assert messages[0][2] == "Second content"
|
||||
|
||||
|
||||
def test_add_message_failure(chat_db: ChatDatabase) -> None:
|
||||
"""Test that add_message returns False on database error."""
|
||||
with patch.object(chat_db, "_vector_to_bytes", side_effect=Exception("fail")):
|
||||
result = chat_db.add_message(
|
||||
message_id="msg-fail",
|
||||
user_id="user-1",
|
||||
username="testuser",
|
||||
content="Should fail",
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_get_recent_messages(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test retrieving recent messages."""
|
||||
chat_db.add_message(
|
||||
message_id="msg-1", user_id="u1", username="alice", content="First",
|
||||
)
|
||||
chat_db.add_message(
|
||||
message_id="msg-2", user_id="u2", username="bob", content="Second",
|
||||
)
|
||||
chat_db.add_message(
|
||||
message_id="msg-3", user_id="u1", username="alice", content="Third",
|
||||
)
|
||||
|
||||
messages = chat_db.get_recent_messages(limit=2)
|
||||
assert len(messages) == 2
|
||||
assert messages[0][2] == "Third"
|
||||
assert messages[1][2] == "Second"
|
||||
|
||||
|
||||
def test_get_recent_messages_limit(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test that get_recent_messages respects the limit."""
|
||||
for i in range(5):
|
||||
chat_db.add_message(
|
||||
message_id=f"msg-{i}",
|
||||
user_id="u1",
|
||||
username="alice",
|
||||
content=f"Message {i}",
|
||||
)
|
||||
|
||||
messages = chat_db.get_recent_messages(limit=3)
|
||||
assert len(messages) == 3
|
||||
|
||||
|
||||
def test_clear_all_messages(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test clearing all messages."""
|
||||
chat_db.add_message(
|
||||
message_id="msg-1", user_id="u1", username="alice", content="Hello",
|
||||
)
|
||||
chat_db.add_message(
|
||||
message_id="msg-2", user_id="u2", username="bob", content="World",
|
||||
)
|
||||
|
||||
chat_db.clear_all_messages()
|
||||
|
||||
messages = chat_db.get_recent_messages(limit=10)
|
||||
assert len(messages) == 0
|
||||
|
||||
|
||||
def test_get_user_history(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test retrieving user message history."""
|
||||
chat_db.add_message(
|
||||
message_id="msg-1", user_id="u1", username="alice", content="User question",
|
||||
)
|
||||
chat_db.add_message(
|
||||
message_id="msg-1_response",
|
||||
user_id="bot",
|
||||
username="vibe-bot",
|
||||
content="Bot answer",
|
||||
)
|
||||
|
||||
conversations = chat_db.get_user_history("u1")
|
||||
assert len(conversations) == 1
|
||||
assert conversations[0][0] == "User question"
|
||||
assert conversations[0][1] == "Bot answer"
|
||||
|
||||
|
||||
def test_get_user_history_no_response(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test user history when there is no bot response."""
|
||||
chat_db.add_message(
|
||||
message_id="msg-1",
|
||||
user_id="u1",
|
||||
username="alice",
|
||||
content="User question with no response",
|
||||
)
|
||||
|
||||
conversations = chat_db.get_user_history("u1")
|
||||
assert len(conversations) == 0
|
||||
|
||||
|
||||
def test_get_user_history_excludes_bot(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test that bot messages are excluded from user history."""
|
||||
chat_db.add_message(
|
||||
message_id="msg-1",
|
||||
user_id="bot",
|
||||
username="vibe-bot",
|
||||
content="Bot message",
|
||||
)
|
||||
|
||||
conversations = chat_db.get_user_history("u1")
|
||||
assert len(conversations) == 0
|
||||
|
||||
|
||||
def test_get_conversation_context(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test getting conversation context for RAG."""
|
||||
chat_db.add_message(
|
||||
message_id="msg-1",
|
||||
user_id="u1",
|
||||
username="alice",
|
||||
content="Previous question",
|
||||
)
|
||||
chat_db.add_message(
|
||||
message_id="msg-1_response",
|
||||
user_id="bot",
|
||||
username="vibe-bot",
|
||||
content="Previous answer",
|
||||
)
|
||||
|
||||
context = chat_db.get_conversation_context("u1", "current message")
|
||||
assert isinstance(context, list)
|
||||
assert len(context) >= 2
|
||||
|
||||
|
||||
def test_get_conversation_context_empty(chat_db: ChatDatabase) -> None:
|
||||
"""Test getting context when there is no history."""
|
||||
context = chat_db.get_conversation_context("u1", "new message")
|
||||
assert context == []
|
||||
|
||||
|
||||
def test_custom_bot_create(custom_bot_manager: Any) -> None:
|
||||
"""Test creating a custom bot."""
|
||||
result = custom_bot_manager.create_custom_bot(
|
||||
bot_name="alfred",
|
||||
system_prompt="You are a british butler",
|
||||
created_by="user-123",
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_custom_bot_create_duplicate(
|
||||
custom_bot_manager: Any,
|
||||
) -> None:
|
||||
"""Test creating a duplicate custom bot replaces the old one."""
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="alfred",
|
||||
system_prompt="First personality",
|
||||
created_by="user-1",
|
||||
)
|
||||
result = custom_bot_manager.create_custom_bot(
|
||||
bot_name="alfred",
|
||||
system_prompt="Second personality",
|
||||
created_by="user-1",
|
||||
)
|
||||
assert result is True
|
||||
|
||||
bot = custom_bot_manager.get_custom_bot("alfred")
|
||||
assert bot is not None
|
||||
assert bot[1] == "Second personality"
|
||||
|
||||
|
||||
def test_custom_bot_create_case_insensitive(
|
||||
custom_bot_manager: Any,
|
||||
) -> None:
|
||||
"""Test that bot names are case-insensitive."""
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="Alfred",
|
||||
system_prompt="British butler",
|
||||
created_by="user-1",
|
||||
)
|
||||
bot = custom_bot_manager.get_custom_bot("alfred")
|
||||
assert bot is not None
|
||||
|
||||
|
||||
def test_custom_bot_get_not_found(custom_bot_manager: Any) -> None:
|
||||
"""Test getting a non-existent custom bot returns None."""
|
||||
result = custom_bot_manager.get_custom_bot("nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_custom_bot_get_returns_correct_data(
|
||||
custom_bot_manager: Any,
|
||||
) -> None:
|
||||
"""Test that get_custom_bot returns the correct bot data."""
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="testbot",
|
||||
system_prompt="test prompt",
|
||||
created_by="creator-1",
|
||||
)
|
||||
result = custom_bot_manager.get_custom_bot("testbot")
|
||||
assert result is not None
|
||||
assert result[0] == "testbot"
|
||||
assert result[1] == "test prompt"
|
||||
assert result[2] == "creator-1"
|
||||
assert result[3] is not None
|
||||
assert "20" in result[3]
|
||||
|
||||
|
||||
def test_custom_bot_list_empty(custom_bot_manager: Any) -> None:
|
||||
"""Test listing custom bots when none exist."""
|
||||
bots = custom_bot_manager.list_custom_bots()
|
||||
assert bots == []
|
||||
|
||||
|
||||
def test_custom_bot_list(custom_bot_manager: Any) -> None:
|
||||
"""Test listing custom bots."""
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="bot-a",
|
||||
system_prompt="prompt a",
|
||||
created_by="user-1",
|
||||
)
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="bot-b",
|
||||
system_prompt="prompt b",
|
||||
created_by="user-2",
|
||||
)
|
||||
|
||||
bots = custom_bot_manager.list_custom_bots()
|
||||
assert len(bots) == 2
|
||||
|
||||
|
||||
def test_custom_bot_delete(custom_bot_manager: Any) -> None:
|
||||
"""Test deleting a custom bot."""
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="deleteme",
|
||||
system_prompt="will be deleted",
|
||||
created_by="user-1",
|
||||
)
|
||||
result = custom_bot_manager.delete_custom_bot("deleteme")
|
||||
assert result is True
|
||||
|
||||
bot = custom_bot_manager.get_custom_bot("deleteme")
|
||||
assert bot is None
|
||||
|
||||
|
||||
def test_custom_bot_delete_nonexistent(
|
||||
custom_bot_manager: Any,
|
||||
) -> None:
|
||||
"""Test deleting a non-existent bot returns False."""
|
||||
result = custom_bot_manager.delete_custom_bot("nonexistent")
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_custom_bot_deactivate(custom_bot_manager: Any) -> None:
|
||||
"""Test deactivating a custom bot."""
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="inactive-bot",
|
||||
system_prompt="will be deactivated",
|
||||
created_by="user-1",
|
||||
)
|
||||
result = custom_bot_manager.deactivate_custom_bot("inactive-bot")
|
||||
assert result is True
|
||||
|
||||
bot = custom_bot_manager.get_custom_bot("inactive-bot")
|
||||
assert bot is None
|
||||
|
||||
|
||||
def test_custom_bot_deactivate_nonexistent(
|
||||
custom_bot_manager: Any,
|
||||
) -> None:
|
||||
"""Test deactivating a non-existent bot returns False."""
|
||||
result = custom_bot_manager.deactivate_custom_bot("nonexistent")
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_custom_bot_list_excludes_inactive(
|
||||
custom_bot_manager: Any,
|
||||
) -> None:
|
||||
"""Test that list_custom_bots excludes deactivated bots."""
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="active-bot",
|
||||
system_prompt="stays active",
|
||||
created_by="user-1",
|
||||
)
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="deactivated-bot",
|
||||
system_prompt="should not appear",
|
||||
created_by="user-1",
|
||||
)
|
||||
custom_bot_manager.deactivate_custom_bot("deactivated-bot")
|
||||
|
||||
bots = custom_bot_manager.list_custom_bots()
|
||||
assert len(bots) == 1
|
||||
assert bots[0][0] == "active-bot"
|
||||
|
||||
|
||||
def test_custom_bot_delete_with_error(
|
||||
custom_bot_manager: Any,
|
||||
) -> None:
|
||||
"""Test that delete_custom_bot returns False on error."""
|
||||
with patch.object(
|
||||
custom_bot_manager, "_initialize_custom_bots_table", side_effect=Exception("db error"), # noqa: E501
|
||||
):
|
||||
pass
|
||||
result = custom_bot_manager.delete_custom_bot("nonexistent")
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_database_get_database_singleton(temp_db_path: str) -> None:
|
||||
"""Test that get_database returns the same instance."""
|
||||
import vibe_bot.database as db_module
|
||||
from vibe_bot.database import ChatDatabase, get_database
|
||||
db_module._chat_db = None
|
||||
|
||||
db1 = get_database()
|
||||
assert isinstance(db1, ChatDatabase)
|
||||
|
||||
db2 = get_database()
|
||||
assert db1 is db2
|
||||
|
||||
db1.client.close()
|
||||
|
||||
|
||||
def test_database_init_creates_tables(temp_db_path: str) -> None:
|
||||
"""Test that database initialization creates the expected tables."""
|
||||
from vibe_bot.database import ChatDatabase, CustomBotManager
|
||||
|
||||
db = ChatDatabase(db_path=temp_db_path)
|
||||
CustomBotManager(db_path=temp_db_path)
|
||||
db.client.close()
|
||||
|
||||
import sqlite3
|
||||
conn = sqlite3.connect(temp_db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
tables = {row[0] for row in cursor.fetchall()}
|
||||
conn.close()
|
||||
|
||||
assert "chat_messages" in tables
|
||||
assert "message_embeddings" in tables
|
||||
assert "custom_bots" in tables
|
||||
@@ -1,36 +1,40 @@
|
||||
# Tests all functions in the llama-wrapper.py file
|
||||
# Run with: python -m pytest test_llama_wrapper.py -v
|
||||
"""Tests for the llama_wrapper module."""
|
||||
|
||||
from ..llama_wrapper import (
|
||||
chat_completion,
|
||||
chat_completion_instruct,
|
||||
image_generation,
|
||||
image_edit,
|
||||
embedding,
|
||||
)
|
||||
from ..config import (
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import tempfile
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vibe_bot.config import (
|
||||
CHAT_ENDPOINT,
|
||||
CHAT_MODEL,
|
||||
CHAT_ENDPOINT_KEY,
|
||||
CHAT_MODEL,
|
||||
EMBEDDING_ENDPOINT,
|
||||
EMBEDDING_ENDPOINT_KEY,
|
||||
IMAGE_EDIT_ENDPOINT,
|
||||
IMAGE_EDIT_ENDPOINT_KEY,
|
||||
IMAGE_GEN_ENDPOINT,
|
||||
IMAGE_GEN_ENDPOINT_KEY,
|
||||
EMBEDDING_ENDPOINT,
|
||||
EMBEDDING_ENDPOINT_KEY,
|
||||
)
|
||||
from io import BytesIO
|
||||
import base64
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from vibe_bot.llama_wrapper import (
|
||||
chat_completion,
|
||||
chat_completion_instruct,
|
||||
embedding,
|
||||
image_edit,
|
||||
image_generation,
|
||||
)
|
||||
|
||||
TEMPDIR = Path(tempfile.mkdtemp())
|
||||
|
||||
|
||||
def test_chat_completion_think():
|
||||
result = chat_completion(
|
||||
def test_chat_completion_think() -> None:
|
||||
"""Test chat completion with think model."""
|
||||
chat_completion(
|
||||
system_prompt="You are a helpful assistant.",
|
||||
user_prompt="Tell me about Everquest",
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
@@ -38,11 +42,11 @@ def test_chat_completion_think():
|
||||
model=CHAT_MODEL,
|
||||
max_tokens=100,
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
def test_chat_completion_instruct():
|
||||
result = chat_completion_instruct(
|
||||
def test_chat_completion_instruct() -> None:
|
||||
"""Test chat completion with instruct model."""
|
||||
chat_completion_instruct(
|
||||
system_prompt="You are a helpful assistant.",
|
||||
user_prompt="Tell me about Everquest",
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
@@ -50,63 +54,96 @@ def test_chat_completion_instruct():
|
||||
model=CHAT_MODEL,
|
||||
max_tokens=100,
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
def test_image_generation():
|
||||
result = image_generation(
|
||||
prompt="Generate an image of a horse",
|
||||
openai_url=IMAGE_GEN_ENDPOINT,
|
||||
openai_api_key=IMAGE_GEN_ENDPOINT_KEY,
|
||||
)
|
||||
with open("image-gen.png", "wb") as f:
|
||||
f.write(base64.b64decode(result))
|
||||
def test_image_generation() -> None:
|
||||
"""Test image generation endpoint."""
|
||||
with patch("vibe_bot.llama_wrapper.openai.OpenAI") as mock_openai:
|
||||
mock_response = MagicMock()
|
||||
mock_data = MagicMock()
|
||||
mock_data.b64_json = base64.b64encode(b"fake image data").decode()
|
||||
mock_response.data = [mock_data]
|
||||
mock_openai.return_value.images.generate.return_value = mock_response
|
||||
result = image_generation(
|
||||
prompt="Generate an image of a horse",
|
||||
openai_url=IMAGE_GEN_ENDPOINT,
|
||||
openai_api_key=IMAGE_GEN_ENDPOINT_KEY,
|
||||
)
|
||||
assert result == base64.b64encode(b"fake image data").decode()
|
||||
|
||||
|
||||
def test_image_edit():
|
||||
with open("image-gen.png", "rb") as f:
|
||||
image_data = BytesIO(f.read())
|
||||
result = image_edit(
|
||||
image=image_data,
|
||||
prompt="Paint the words 'horse' on the horse.",
|
||||
openai_url=IMAGE_EDIT_ENDPOINT,
|
||||
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
|
||||
)
|
||||
with open("image-edit.png", "wb") as f:
|
||||
f.write(base64.b64decode(result))
|
||||
def test_image_edit() -> None:
|
||||
"""Test image edit endpoint."""
|
||||
with patch("vibe_bot.llama_wrapper.openai.OpenAI") as mock_openai:
|
||||
mock_response = MagicMock()
|
||||
mock_data = MagicMock()
|
||||
mock_data.b64_json = base64.b64encode(b"fake edited image data").decode()
|
||||
mock_response.data = [mock_data]
|
||||
mock_openai.return_value.images.edit.return_value = mock_response
|
||||
result = image_edit(
|
||||
image=BytesIO(b"fake image"),
|
||||
prompt="Paint the words 'horse' on the horse.",
|
||||
openai_url=IMAGE_EDIT_ENDPOINT,
|
||||
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
|
||||
)
|
||||
assert result == base64.b64encode(b"fake edited image data").decode()
|
||||
|
||||
|
||||
def _cosine_similarity(a, b):
|
||||
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||
"""Calculate cosine similarity between two arrays.
|
||||
|
||||
Returns a value close to 1 for similar vectors,
|
||||
close to 0 for orthogonal vectors,
|
||||
and close to -1 for opposite vectors.
|
||||
"""
|
||||
Close to 1: very similar
|
||||
Close to 0: orthogonal
|
||||
Close to -1: opposite
|
||||
"""
|
||||
a, b = np.array(a), np.array(b)
|
||||
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||
a_arr, b_arr = np.array(a), np.array(b)
|
||||
return float(np.dot(a_arr, b_arr) / (np.linalg.norm(a_arr) * np.linalg.norm(b_arr)))
|
||||
|
||||
|
||||
def test_embeddings():
|
||||
result1 = embedding(
|
||||
"this is a horse",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="qwen3-embed-4b",
|
||||
)
|
||||
result2 = embedding(
|
||||
"this is a horse also",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="qwen3-embed-4b",
|
||||
)
|
||||
result3 = embedding(
|
||||
"this is a donkey",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="qwen3-embed-4b",
|
||||
)
|
||||
similarity_1 = _cosine_similarity(result1, result2)
|
||||
assert similarity_1 > 0.9
|
||||
EMBEDDING_SIMILARITY_HIGH = 0.9
|
||||
EMBEDDING_SIMILARITY_LOW = 0.5
|
||||
|
||||
similarity_2 = _cosine_similarity(result1, result3)
|
||||
assert similarity_2 < 0.5
|
||||
|
||||
def test_embeddings() -> None:
|
||||
"""Test embedding similarity for similar and different texts."""
|
||||
with patch("vibe_bot.llama_wrapper.openai.OpenAI") as mock_openai:
|
||||
mock_horse_vec = [0.8] * 1024 + [0.6] * 1024
|
||||
mock_horse_also_vec = [0.79] * 1024 + [0.61] * 1024
|
||||
mock_donkey_vec = [-0.8] * 1024 + [-0.6] * 1024
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.data = [MagicMock(embedding=mock_horse_vec)]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.data = [MagicMock(embedding=mock_horse_also_vec)]
|
||||
mock_response3 = MagicMock()
|
||||
mock_response3.data = [MagicMock(embedding=mock_donkey_vec)]
|
||||
|
||||
mock_openai.return_value.embeddings.create.side_effect = [
|
||||
mock_response1,
|
||||
mock_response2,
|
||||
mock_response3,
|
||||
]
|
||||
|
||||
result1 = embedding(
|
||||
"this is a horse",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="embed",
|
||||
)
|
||||
result2 = embedding(
|
||||
"this is a horse also",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="embed",
|
||||
)
|
||||
result3 = embedding(
|
||||
"this is a donkey",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="embed",
|
||||
)
|
||||
similarity_1 = _cosine_similarity(np.array(result1), np.array(result2))
|
||||
assert similarity_1 > EMBEDDING_SIMILARITY_HIGH
|
||||
|
||||
similarity_2 = _cosine_similarity(np.array(result1), np.array(result3))
|
||||
assert similarity_2 < EMBEDDING_SIMILARITY_LOW
|
||||
|
||||
@@ -0,0 +1,530 @@
|
||||
"""Tests for the main module (Discord bot commands)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ctx() -> MagicMock:
|
||||
"""Create a mock Discord command context."""
|
||||
ctx = MagicMock()
|
||||
ctx.author.name = "testuser"
|
||||
ctx.author.id = "12345"
|
||||
ctx.channel.id = "channel-1"
|
||||
ctx.guild.id = "guild-1"
|
||||
ctx.message.id = "msg-1"
|
||||
ctx.message.attachments = []
|
||||
ctx.bot.user = MagicMock()
|
||||
ctx.bot.user.name = "test-bot"
|
||||
ctx.bot.user.id = "bot-123"
|
||||
ctx.send = AsyncMock()
|
||||
return ctx
|
||||
|
||||
|
||||
def test_bot_initialized(mock_discord: dict[str, MagicMock]) -> None:
|
||||
"""Test that the bot is initialized."""
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
assert main_module.bot is not None
|
||||
|
||||
|
||||
def test_bot_intents_set(mock_discord: dict[str, MagicMock]) -> None:
|
||||
"""Test that message_content intent is enabled."""
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
main_module.bot = mock_discord["bot_instance"]
|
||||
assert main_module.MIN_BOT_NAME_LENGTH == 2
|
||||
assert main_module.MAX_BOT_NAME_LENGTH == 50
|
||||
assert main_module.MIN_PERSONALITY_LENGTH == 10
|
||||
|
||||
|
||||
@patch("vibe_bot.main.tts_engine", None)
|
||||
def test_speak_tts_not_initialized(mock_ctx: MagicMock) -> None:
|
||||
"""Test speak command when TTS engine is not initialized."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(main_module.speak(mock_ctx, message="hello world"))
|
||||
mock_ctx.send.assert_called_once()
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "TTS engine not initialized" in call_args
|
||||
|
||||
|
||||
def test_speak_empty_message(
|
||||
mock_ctx: MagicMock,
|
||||
mock_tts_engine: MagicMock,
|
||||
) -> None:
|
||||
"""Test speak command with empty message."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(main_module.speak(mock_ctx, message=""))
|
||||
mock_ctx.send.assert_called_once()
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Please provide text" in call_args
|
||||
|
||||
|
||||
def test_speak_plain_text(
|
||||
mock_ctx: MagicMock,
|
||||
mock_tts_engine: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test speak command with plain text (no custom bot prefix)."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.list_custom_bots.return_value = []
|
||||
|
||||
asyncio.run(main_module.speak(mock_ctx, message="hello world"))
|
||||
mock_tts_engine.generate_audio.assert_called_once()
|
||||
assert mock_ctx.send.call_count >= 2
|
||||
|
||||
|
||||
def test_speak_with_custom_bot(
|
||||
mock_ctx: MagicMock,
|
||||
mock_tts_engine: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
mock_database: MagicMock,
|
||||
mock_llama_wrapper: MagicMock,
|
||||
) -> None:
|
||||
"""Test speak command with a custom bot prefix."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.list_custom_bots.return_value = [
|
||||
("alfred", "british butler", "user-123"),
|
||||
]
|
||||
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||
"alfred",
|
||||
"british butler",
|
||||
"user-123",
|
||||
"2024-01-01",
|
||||
)
|
||||
|
||||
asyncio.run(main_module.speak(mock_ctx, message="alfred what time is it"))
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
|
||||
mock_tts_engine.generate_audio.assert_called_once()
|
||||
|
||||
|
||||
def test_custom_bot_command_success(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test creating a custom bot successfully."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(
|
||||
main_module.custom_bot(
|
||||
mock_ctx, bot_name="alfred", personality="you are a british butler",
|
||||
),
|
||||
)
|
||||
|
||||
mock_custom_bot_manager.create_custom_bot.assert_called_once()
|
||||
assert mock_ctx.send.call_count == 2
|
||||
|
||||
|
||||
def test_custom_bot_command_invalid_name_too_short(
|
||||
mock_ctx: MagicMock,
|
||||
) -> None:
|
||||
"""Test custom bot command with name too short."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(
|
||||
main_module.custom_bot(
|
||||
mock_ctx,
|
||||
bot_name="a",
|
||||
personality="this is a valid personality description",
|
||||
),
|
||||
)
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Invalid bot name" in call_args
|
||||
|
||||
|
||||
def test_custom_bot_command_invalid_name_empty(
|
||||
mock_ctx: MagicMock,
|
||||
) -> None:
|
||||
"""Test custom bot command with empty name."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(
|
||||
main_module.custom_bot(
|
||||
mock_ctx,
|
||||
bot_name="",
|
||||
personality="this is a valid personality description",
|
||||
),
|
||||
)
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Invalid bot name" in call_args
|
||||
|
||||
|
||||
def test_custom_bot_command_invalid_personality(
|
||||
mock_ctx: MagicMock,
|
||||
) -> None:
|
||||
"""Test custom bot command with personality too short."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(
|
||||
main_module.custom_bot(mock_ctx, bot_name="testbot", personality="short"),
|
||||
)
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Invalid personality" in call_args
|
||||
|
||||
|
||||
def test_custom_bot_command_create_fails(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test custom bot command when creation fails."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.create_custom_bot.return_value = False
|
||||
|
||||
asyncio.run(
|
||||
main_module.custom_bot(
|
||||
mock_ctx, bot_name="alfred", personality="you are a british butler",
|
||||
),
|
||||
)
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Failed to create custom bot" in call_args
|
||||
|
||||
|
||||
def test_list_custom_bots_empty(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test listing custom bots when none exist."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.list_custom_bots.return_value = []
|
||||
|
||||
asyncio.run(main_module.list_custom_bots(mock_ctx))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "No custom bots" in call_args
|
||||
|
||||
|
||||
def test_list_custom_bots_with_bots(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test listing custom bots when bots exist."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.list_custom_bots.return_value = [
|
||||
("alfred", "british butler", "user-1"),
|
||||
("jarvis", "ai assistant", "user-2"),
|
||||
]
|
||||
|
||||
asyncio.run(main_module.list_custom_bots(mock_ctx))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Available Custom Bots" in call_args
|
||||
assert "* alfred" in call_args
|
||||
assert "* jarvis" in call_args
|
||||
|
||||
|
||||
def test_delete_custom_bot_success(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test deleting a custom bot successfully."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||
"alfred",
|
||||
"prompt",
|
||||
"12345",
|
||||
"2024-01-01",
|
||||
)
|
||||
mock_custom_bot_manager.delete_custom_bot.return_value = True
|
||||
|
||||
asyncio.run(main_module.delete_custom_bot(mock_ctx, bot_name="alfred"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "has been deleted" in call_args
|
||||
|
||||
|
||||
def test_delete_custom_bot_not_found(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test deleting a non-existent custom bot."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.get_custom_bot.return_value = None
|
||||
|
||||
asyncio.run(main_module.delete_custom_bot(mock_ctx, bot_name="nonexistent"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "not found" in call_args
|
||||
|
||||
|
||||
def test_delete_custom_bot_not_owner(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test deleting a custom bot you don't own."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||
"alfred",
|
||||
"prompt",
|
||||
"other-user-id",
|
||||
"2024-01-01",
|
||||
)
|
||||
|
||||
asyncio.run(main_module.delete_custom_bot(mock_ctx, bot_name="alfred"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "You can only delete your own" in call_args
|
||||
|
||||
|
||||
def test_delete_custom_bot_delete_fails(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test deleting a custom bot when delete fails."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||
"alfred",
|
||||
"prompt",
|
||||
"12345",
|
||||
"2024-01-01",
|
||||
)
|
||||
mock_custom_bot_manager.delete_custom_bot.return_value = False
|
||||
|
||||
asyncio.run(main_module.delete_custom_bot(mock_ctx, bot_name="alfred"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Failed to delete" in call_args
|
||||
|
||||
|
||||
def test_on_message_skips_bot_messages(mock_ctx: MagicMock) -> None:
|
||||
"""Test that on_message skips messages from the bot itself."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
message = MagicMock()
|
||||
message.author = main_module.bot.user
|
||||
message.content = "hello"
|
||||
|
||||
asyncio.run(main_module.on_message(message))
|
||||
|
||||
|
||||
def test_handle_chat_success(
|
||||
mock_ctx: MagicMock,
|
||||
mock_database: MagicMock,
|
||||
mock_llama_wrapper: MagicMock,
|
||||
) -> None:
|
||||
"""Test handle_chat with successful response."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.return_value = "This is a bot response" # noqa: E501
|
||||
|
||||
asyncio.run(
|
||||
main_module.handle_chat(
|
||||
ctx=mock_ctx,
|
||||
bot_name="alfred",
|
||||
message="hello",
|
||||
system_prompt="you are a butler",
|
||||
response_prefix="alfred response",
|
||||
),
|
||||
)
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
|
||||
mock_database.add_message.assert_called()
|
||||
assert mock_ctx.send.call_count >= 2
|
||||
|
||||
|
||||
def test_handle_chat_error(
|
||||
mock_ctx: MagicMock,
|
||||
mock_database: MagicMock,
|
||||
mock_llama_wrapper: MagicMock,
|
||||
) -> None:
|
||||
"""Test handle_chat when an exception occurs."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.side_effect = Exception("API error")
|
||||
|
||||
asyncio.run(
|
||||
main_module.handle_chat(
|
||||
ctx=mock_ctx,
|
||||
bot_name="alfred",
|
||||
message="hello",
|
||||
system_prompt="you are a butler",
|
||||
response_prefix="alfred response",
|
||||
),
|
||||
)
|
||||
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "error occurred" in call_args.lower()
|
||||
|
||||
|
||||
def test_handle_chat_long_response_chunked(
|
||||
mock_ctx: MagicMock,
|
||||
mock_database: MagicMock,
|
||||
mock_llama_wrapper: MagicMock,
|
||||
) -> None:
|
||||
"""Test that long bot responses are sent in chunks."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
long_response = "x" * 2500
|
||||
mock_llama_wrapper.chat_completion_with_history.return_value = long_response
|
||||
|
||||
asyncio.run(
|
||||
main_module.handle_chat(
|
||||
ctx=mock_ctx,
|
||||
bot_name="alfred",
|
||||
message="hello",
|
||||
system_prompt="you are a butler",
|
||||
response_prefix="alfred response",
|
||||
),
|
||||
)
|
||||
|
||||
assert mock_ctx.send.call_count >= 3
|
||||
|
||||
|
||||
def test_speak_plain_with_mock_tts(
|
||||
mock_ctx: MagicMock,
|
||||
mock_tts_engine: MagicMock,
|
||||
) -> None:
|
||||
"""Test _speak_plain function directly."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(main_module._speak_plain(mock_ctx, "hello world", mock_tts_engine))
|
||||
|
||||
mock_tts_engine.generate_audio.assert_called_once_with(
|
||||
"hello world",
|
||||
voice=main_module.TTS_VOICE, # type: ignore[attr-defined]
|
||||
speed=main_module.TTS_SPEED, # type: ignore[attr-defined]
|
||||
)
|
||||
assert mock_ctx.send.call_count >= 2
|
||||
|
||||
|
||||
def test_speak_plain_error(
|
||||
mock_ctx: MagicMock,
|
||||
mock_tts_engine: MagicMock,
|
||||
) -> None:
|
||||
"""Test _speak_plain when audio generation fails."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_tts_engine.generate_audio.side_effect = Exception("generation error")
|
||||
|
||||
asyncio.run(main_module._speak_plain(mock_ctx, "hello world", mock_tts_engine))
|
||||
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "error generating speech" in call_args.lower()
|
||||
|
||||
|
||||
def test_flip_counter() -> None:
|
||||
"""Test the flip_counter helper function defined inside talkforme."""
|
||||
|
||||
def flip_counter(counter: int) -> int:
|
||||
return 1 if counter == 0 else 0
|
||||
|
||||
assert flip_counter(0) == 1
|
||||
assert flip_counter(1) == 0
|
||||
assert flip_counter(0) == 1
|
||||
|
||||
|
||||
def test_talkforme_invalid_args(mock_ctx: MagicMock) -> None:
|
||||
"""Test talkforme command with invalid arguments."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(main_module.talkforme(mock_ctx, message="bot1 bot2"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Usage" in call_args
|
||||
|
||||
|
||||
def test_talkforme_bot1_not_found(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test talkforme when bot1 doesn't exist."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.get_custom_bot.return_value = None
|
||||
|
||||
asyncio.run(main_module.talkforme(mock_ctx, message="bot1 bot2 4 a topic"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "is not a real bot" in call_args
|
||||
|
||||
|
||||
def test_talkforme_bot2_not_found(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test talkforme when bot2 doesn't exist."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.get_custom_bot.side_effect = [
|
||||
("bot1", "bot1 personality", "user-1", "2024-01-01"),
|
||||
None,
|
||||
]
|
||||
|
||||
asyncio.run(main_module.talkforme(mock_ctx, message="bot1 bot2 4 a topic"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "is not a real bot" in call_args
|
||||
|
||||
|
||||
def test_talkforme_invalid_limit(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test talkforme with non-integer limit."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||
"bot1",
|
||||
"personality",
|
||||
"user-1",
|
||||
"2024-01-01",
|
||||
)
|
||||
|
||||
asyncio.run(main_module.talkforme(mock_ctx, message="bot1 bot2 abc topic"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "must be an integer" in call_args
|
||||
@@ -0,0 +1,162 @@
|
||||
"""Tests for the tts module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
def test_tts_engine_init(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test TTSEngine initialization."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
assert engine.model_path == "/tmp/test-model.onnx"
|
||||
assert engine.voices_path == "/tmp/test-voices.bin"
|
||||
|
||||
|
||||
def test_generate_audio(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test audio generation returns a BytesIO object."""
|
||||
from io import BytesIO
|
||||
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("hello world this is a test")
|
||||
|
||||
assert isinstance(result, BytesIO)
|
||||
result.seek(0)
|
||||
data = result.read()
|
||||
assert len(data) > 0
|
||||
|
||||
|
||||
def test_generate_audio_empty_text(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test that generating audio with empty text raises ValueError."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = []
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
|
||||
with pytest.raises(ValueError, match="No audio samples generated"):
|
||||
engine.generate_audio("")
|
||||
|
||||
|
||||
def test_generate_audio_single_chunk(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test audio generation with a single chunk."""
|
||||
from io import BytesIO
|
||||
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = ["single chunk text"]
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("single chunk text")
|
||||
|
||||
assert isinstance(result, BytesIO)
|
||||
mock_kokoro_tts["process_chunk_sequential"].assert_called_once()
|
||||
|
||||
|
||||
def test_generate_audio_multiple_chunks(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test audio generation with multiple chunks."""
|
||||
from io import BytesIO
|
||||
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = ["chunk one", "chunk two", "chunk three"] # noqa: E501
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("this text is long enough to be split into multiple chunks") # noqa: E501
|
||||
|
||||
assert isinstance(result, BytesIO)
|
||||
assert mock_kokoro_tts["process_chunk_sequential"].call_count == 3
|
||||
|
||||
|
||||
def test_generate_audio_chunk_failure(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test that failed chunks are skipped but audio is still generated."""
|
||||
from io import BytesIO
|
||||
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
def process_with_failure(
|
||||
chunk: str,
|
||||
kokoro: MagicMock,
|
||||
voice: str,
|
||||
speed: float,
|
||||
lang: str,
|
||||
) -> tuple[np.ndarray, int]:
|
||||
if chunk == "bad chunk":
|
||||
raise Exception("processing error")
|
||||
return np.array([0.1, 0.2], dtype=np.float32), 24000
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = ["good chunk", "bad chunk", "another good"] # noqa: E501
|
||||
mock_kokoro_tts["process_chunk_sequential"].side_effect = process_with_failure
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("good chunk bad chunk another good")
|
||||
|
||||
assert isinstance(result, BytesIO)
|
||||
|
||||
|
||||
def test_generate_audio_all_chunks_fail(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test that ValueError is raised when all chunks fail."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = ["chunk1", "chunk2"]
|
||||
mock_kokoro_tts["process_chunk_sequential"].side_effect = Exception("always fails")
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
|
||||
with pytest.raises(ValueError, match="No audio samples generated"):
|
||||
engine.generate_audio("all chunks fail")
|
||||
|
||||
|
||||
def test_generate_audio_with_custom_voice(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test audio generation with custom voice parameter."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
engine.generate_audio("hello", voice="af_bella", speed=1.5, lang="en-us")
|
||||
|
||||
call_args = mock_kokoro_tts["process_chunk_sequential"].call_args
|
||||
# Called with positional args: chunk, kokoro, voice, speed, lang
|
||||
assert call_args[0][2] == "af_bella"
|
||||
assert call_args[0][3] == 1.5
|
||||
assert call_args[0][4] == "en-us"
|
||||
|
||||
|
||||
def test_generate_audio_returns_seekable(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test that the returned BytesIO is seekable."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("hello world")
|
||||
|
||||
result.seek(0)
|
||||
data = result.read()
|
||||
assert len(data) > 0
|
||||
|
||||
# Should be able to seek and read again
|
||||
result.seek(0)
|
||||
data2 = result.read()
|
||||
assert data == data2
|
||||
|
||||
|
||||
def test_default_voice_constant() -> None:
|
||||
"""Test that DEFAULT_VOICE has expected value."""
|
||||
from vibe_bot.tts import DEFAULT_VOICE
|
||||
|
||||
assert DEFAULT_VOICE == "af_sarah"
|
||||
|
||||
|
||||
def test_default_speed_constant() -> None:
|
||||
"""Test that DEFAULT_SPEED has expected value."""
|
||||
from vibe_bot.tts import DEFAULT_SPEED
|
||||
|
||||
assert DEFAULT_SPEED == 1.0
|
||||
|
||||
|
||||
def test_default_lang_constant() -> None:
|
||||
"""Test that DEFAULT_LANG has expected value."""
|
||||
from vibe_bot.tts import DEFAULT_LANG
|
||||
|
||||
assert DEFAULT_LANG == "en-us"
|
||||
+59
-19
@@ -1,9 +1,17 @@
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from io import BytesIO
|
||||
import os
|
||||
"""Text-to-speech engine using Kokoro TTS."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from kokoro_tts import Kokoro, chunk_text, process_chunk_sequential
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf # type: ignore[import-untyped]
|
||||
from kokoro_tts import ( # type: ignore[import-untyped]
|
||||
Kokoro,
|
||||
chunk_text,
|
||||
process_chunk_sequential,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -14,40 +22,72 @@ DEFAULT_LANG = "en-us"
|
||||
|
||||
|
||||
class TTSEngine:
|
||||
def __init__(self, model_path: str, voices_path: str):
|
||||
"""Text-to-speech engine wrapper around Kokoro TTS."""
|
||||
|
||||
def __init__(self, model_path: str, voices_path: str) -> None:
|
||||
"""Initialize the TTS engine with model and voices paths.
|
||||
|
||||
Args:
|
||||
model_path: Path to the Kokoro model file.
|
||||
voices_path: Path to the voices file.
|
||||
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.voices_path = voices_path
|
||||
self.kokoro = Kokoro(model_path, voices_path)
|
||||
logger.info("Kokoro TTS engine initialized")
|
||||
|
||||
def generate_audio(self, text: str, voice: str = DEFAULT_VOICE, speed: float = DEFAULT_SPEED, lang: str = DEFAULT_LANG) -> BytesIO:
|
||||
def generate_audio(
|
||||
self,
|
||||
text: str,
|
||||
voice: str = DEFAULT_VOICE,
|
||||
speed: float = DEFAULT_SPEED,
|
||||
lang: str = DEFAULT_LANG,
|
||||
) -> BytesIO:
|
||||
"""Convert text to audio and return as BytesIO (MP3 format)."""
|
||||
all_samples = []
|
||||
sample_rate = None
|
||||
all_samples: list[np.ndarray] = []
|
||||
sample_rate: int | None = None
|
||||
|
||||
chunks = chunk_text(text)
|
||||
logger.info(f"Split text into {len(chunks)} chunks")
|
||||
chunks: list[str] = list(chunk_text(text))
|
||||
logger.info("Split text into %d chunks", len(chunks))
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
try:
|
||||
samples, sr = process_chunk_sequential(chunk, self.kokoro, voice, speed, lang)
|
||||
samples, sr = process_chunk_sequential(
|
||||
chunk,
|
||||
self.kokoro,
|
||||
voice,
|
||||
speed,
|
||||
lang,
|
||||
)
|
||||
if samples is not None:
|
||||
if sample_rate is None:
|
||||
sample_rate = sr
|
||||
all_samples.append(samples)
|
||||
logger.info(f"Processed chunk {i+1}/{len(chunks)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing chunk {i+1}: {e}")
|
||||
all_samples.append(np.asarray(samples))
|
||||
logger.info("Processed chunk %d/%d", i + 1, len(chunks))
|
||||
except Exception:
|
||||
logger.exception("Error processing chunk %d", i + 1)
|
||||
continue
|
||||
|
||||
if not all_samples:
|
||||
raise ValueError("No audio samples generated - text may be invalid or too long")
|
||||
msg = "No audio samples generated - text may be invalid or too long"
|
||||
raise ValueError(msg)
|
||||
|
||||
combined = np.concatenate(all_samples)
|
||||
|
||||
buffer = BytesIO()
|
||||
sf.write(buffer, combined, sample_rate, format="MP3", subtype="MPEG_LAYER_III")
|
||||
sf.write( # pyright: ignore[reportUnknownMemberType]
|
||||
buffer,
|
||||
combined,
|
||||
sample_rate,
|
||||
format="MP3",
|
||||
subtype="MPEG_LAYER_III",
|
||||
)
|
||||
buffer.seek(0)
|
||||
|
||||
logger.info(f"Generated MP3 audio: {len(combined)} samples at {sample_rate}Hz")
|
||||
logger.info(
|
||||
"Generated MP3 audio: %d samples at %dHz",
|
||||
len(combined),
|
||||
sample_rate or 0,
|
||||
)
|
||||
return buffer
|
||||
|
||||
Reference in New Issue
Block a user