Files

487 lines
14 KiB
Python

"""Tests for the database module."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
if TYPE_CHECKING:
from vibe_bot.database import ChatDatabase
def test_vector_to_bytes(chat_db: ChatDatabase) -> None:
    """Round-trip a list of floats through the private byte (de)serialisers.

    The encoded blob must be ``bytes`` holding four bytes per element, and
    decoding it must give values that ``np.allclose`` accepts as equal to
    the float32 form of the input.
    """
    values: list[float] = [0.1, 0.2, 0.3, 0.4]
    expected = np.array(values, dtype=np.float32)

    encoded = chat_db._vector_to_bytes(values)
    assert isinstance(encoded, bytes)
    # Each float32 element occupies 4 bytes.
    assert len(encoded) == len(values) * 4

    decoded = chat_db._bytes_to_vector(encoded)
    assert np.allclose(decoded, expected)
def test_bytes_to_vector(chat_db: ChatDatabase) -> None:
    """Decode a raw float32 buffer and expect an exact element-wise match."""
    expected = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    raw = expected.tobytes()

    decoded = chat_db._bytes_to_vector(raw)

    assert np.array_equal(decoded, expected)
def test_calculate_similarity_self(chat_db: ChatDatabase) -> None:
    """A vector compared against itself must score 1.0 (within 1e-6)."""
    sample = np.array([1.0, 2.0, 3.0], dtype=np.float32)

    score = chat_db._calculate_similarity(sample, sample)

    assert score == pytest.approx(1.0, abs=1e-6)
def test_calculate_similarity_orthogonal(chat_db: ChatDatabase) -> None:
    """Two perpendicular unit vectors must score 0.0 (within 1e-6)."""
    x_axis = np.array([1.0, 0.0], dtype=np.float32)
    y_axis = np.array([0.0, 1.0], dtype=np.float32)

    score = chat_db._calculate_similarity(x_axis, y_axis)

    assert score == pytest.approx(0.0, abs=1e-6)
def test_calculate_similarity_negative(chat_db: ChatDatabase) -> None:
    """Two vectors pointing in opposite directions must score -1.0 (within 1e-6)."""
    forward = np.array([1.0, 0.0], dtype=np.float32)
    backward = np.array([-1.0, 0.0], dtype=np.float32)

    score = chat_db._calculate_similarity(forward, backward)

    assert score == pytest.approx(-1.0, abs=1e-6)
def test_add_message(chat_db: ChatDatabase, mock_embedding: MagicMock) -> None:
    """Store one fully specified message and read it back.

    ``add_message`` must return ``True``, and the single row returned by
    ``get_recent_messages`` must carry the message id, username and content
    at indexes 0, 1 and 2.
    """
    payload = {
        "message_id": "msg-1",
        "user_id": "user-1",
        "username": "testuser",
        "content": "Hello world",
        "channel_id": "chan-1",
        "guild_id": "guild-1",
    }

    stored = chat_db.add_message(**payload)
    assert stored is True

    rows = chat_db.get_recent_messages(limit=10)
    assert len(rows) == 1
    row = rows[0]
    assert row[0] == "msg-1"
    assert row[1] == "testuser"
    assert row[2] == "Hello world"
def test_add_message_no_embedding(chat_db: ChatDatabase) -> None:
    """A message is still accepted when the embedding function yields ``None``.

    ``vibe_bot.llama_wrapper.embedding`` is patched to return ``None`` for
    the duration of the ``add_message`` call, which must still return ``True``.
    """
    embedding_stub = patch("vibe_bot.llama_wrapper.embedding", return_value=None)
    with embedding_stub:
        outcome = chat_db.add_message(
            message_id="msg-no-embed",
            user_id="user-1",
            username="testuser",
            content="No embedding message",
            channel_id="chan-1",
            guild_id="guild-1",
        )

    assert outcome is True
def test_add_message_duplicate(
    chat_db: ChatDatabase,
    mock_embedding: MagicMock,
) -> None:
    """Re-adding an existing message id must overwrite, not append.

    The same ``message_id`` is written twice with different content; only
    one row may remain and it must hold the second content.
    """
    for body in ("First content", "Second content"):
        chat_db.add_message(
            message_id="msg-dup",
            user_id="user-1",
            username="testuser",
            content=body,
        )

    rows = chat_db.get_recent_messages(limit=10)
    assert len(rows) == 1
    surviving = rows[0]
    assert surviving[2] == "Second content"
def test_add_message_failure(chat_db: ChatDatabase) -> None:
    """``add_message`` must return ``False`` when ``_vector_to_bytes`` raises.

    The private encoder is patched to raise ``Exception("fail")`` while the
    message is being added.
    """
    broken_encoder = patch.object(
        chat_db,
        "_vector_to_bytes",
        side_effect=Exception("fail"),
    )
    with broken_encoder:
        outcome = chat_db.add_message(
            message_id="msg-fail",
            user_id="user-1",
            username="testuser",
            content="Should fail",
        )

    assert outcome is False
def test_get_recent_messages(
    chat_db: ChatDatabase,
    mock_embedding: MagicMock,
) -> None:
    """Insert three messages and fetch the latest two.

    With ``limit=2`` the result must start with the last message added
    ("Third"), followed by the one added before it ("Second").
    """
    seeded = (
        ("msg-1", "u1", "alice", "First"),
        ("msg-2", "u2", "bob", "Second"),
        ("msg-3", "u1", "alice", "Third"),
    )
    for message_id, user_id, username, content in seeded:
        chat_db.add_message(
            message_id=message_id,
            user_id=user_id,
            username=username,
            content=content,
        )

    latest = chat_db.get_recent_messages(limit=2)
    assert len(latest) == 2
    newest = latest[0]
    assert newest[2] == "Third"
    runner_up = latest[1]
    assert runner_up[2] == "Second"
def test_get_recent_messages_limit(
    chat_db: ChatDatabase,
    mock_embedding: MagicMock,
) -> None:
    """Five stored messages queried with ``limit=3`` must yield three rows."""
    entries = [(f"msg-{i}", f"Message {i}") for i in range(5)]
    for message_id, content in entries:
        chat_db.add_message(
            message_id=message_id,
            user_id="u1",
            username="alice",
            content=content,
        )

    rows = chat_db.get_recent_messages(limit=3)

    assert len(rows) == 3
def test_clear_all_messages(
    chat_db: ChatDatabase,
    mock_embedding: MagicMock,
) -> None:
    """After ``clear_all_messages`` no stored message may be returned."""
    seeded = (
        ("msg-1", "u1", "alice", "Hello"),
        ("msg-2", "u2", "bob", "World"),
    )
    for message_id, user_id, username, content in seeded:
        chat_db.add_message(
            message_id=message_id,
            user_id=user_id,
            username=username,
            content=content,
        )

    chat_db.clear_all_messages()

    remaining = chat_db.get_recent_messages(limit=10)
    assert len(remaining) == 0
def test_get_user_history(
    chat_db: ChatDatabase,
    mock_embedding: MagicMock,
) -> None:
    """A user message and its ``<id>_response`` reply form one history entry.

    The entry returned for user ``u1`` must hold the user's text at index 0
    and the bot's text at index 1.
    """
    question = {
        "message_id": "msg-1",
        "user_id": "u1",
        "username": "alice",
        "content": "User question",
    }
    answer = {
        "message_id": "msg-1_response",
        "user_id": "bot",
        "username": "vibe-bot",
        "content": "Bot answer",
    }
    chat_db.add_message(**question)
    chat_db.add_message(**answer)

    history = chat_db.get_user_history("u1")

    assert len(history) == 1
    exchange = history[0]
    assert exchange[0] == "User question"
    assert exchange[1] == "Bot answer"
def test_get_user_history_no_response(
    chat_db: ChatDatabase,
    mock_embedding: MagicMock,
) -> None:
    """A user message without a stored bot reply yields no history entries."""
    unanswered = {
        "message_id": "msg-1",
        "user_id": "u1",
        "username": "alice",
        "content": "User question with no response",
    }
    chat_db.add_message(**unanswered)

    history = chat_db.get_user_history("u1")

    assert len(history) == 0
def test_get_user_history_excludes_bot(
    chat_db: ChatDatabase,
    mock_embedding: MagicMock,
) -> None:
    """A message stored under user id ``bot`` must not appear in ``u1``'s history."""
    bot_only = {
        "message_id": "msg-1",
        "user_id": "bot",
        "username": "vibe-bot",
        "content": "Bot message",
    }
    chat_db.add_message(**bot_only)

    history = chat_db.get_user_history("u1")

    assert len(history) == 0
def test_get_conversation_context(
    chat_db: ChatDatabase,
    mock_embedding: MagicMock,
) -> None:
    """Context built from one stored exchange is a list with at least two items.

    A user question and its ``<id>_response`` reply are stored first; the
    context is then requested for the same user with a new message.
    """
    exchange = (
        ("msg-1", "u1", "alice", "Previous question"),
        ("msg-1_response", "bot", "vibe-bot", "Previous answer"),
    )
    for message_id, user_id, username, content in exchange:
        chat_db.add_message(
            message_id=message_id,
            user_id=user_id,
            username=username,
            content=content,
        )

    retrieved = chat_db.get_conversation_context("u1", "current message")

    assert isinstance(retrieved, list)
    assert len(retrieved) >= 2
def test_get_conversation_context_empty(chat_db: ChatDatabase) -> None:
    """With nothing stored, the conversation context must be an empty list."""
    retrieved = chat_db.get_conversation_context("u1", "new message")

    assert retrieved == []
def test_custom_bot_create(custom_bot_manager: Any) -> None:
    """``create_custom_bot`` must return ``True`` for a new bot."""
    spec = {
        "bot_name": "alfred",
        "system_prompt": "You are a british butler",
        "created_by": "user-123",
    }

    created = custom_bot_manager.create_custom_bot(**spec)

    assert created is True
def test_custom_bot_create_duplicate(
    custom_bot_manager: Any,
) -> None:
    """Creating a bot under an existing name must overwrite its prompt.

    The second ``create_custom_bot`` call must return ``True`` and the
    stored record must hold the second prompt at index 1.
    """
    shared = {"bot_name": "alfred", "created_by": "user-1"}

    custom_bot_manager.create_custom_bot(
        system_prompt="First personality",
        **shared,
    )
    replaced = custom_bot_manager.create_custom_bot(
        system_prompt="Second personality",
        **shared,
    )
    assert replaced is True

    record = custom_bot_manager.get_custom_bot("alfred")
    assert record is not None
    assert record[1] == "Second personality"
def test_custom_bot_create_case_insensitive(
    custom_bot_manager: Any,
) -> None:
    """A bot created as ``Alfred`` must be retrievable as ``alfred``."""
    spec = {
        "bot_name": "Alfred",
        "system_prompt": "British butler",
        "created_by": "user-1",
    }
    custom_bot_manager.create_custom_bot(**spec)

    record = custom_bot_manager.get_custom_bot("alfred")

    assert record is not None
def test_custom_bot_get_not_found(custom_bot_manager: Any) -> None:
    """Looking up a name that was never created must return ``None``."""
    record = custom_bot_manager.get_custom_bot("nonexistent")

    assert record is None
def test_custom_bot_get_returns_correct_data(
    custom_bot_manager: Any,
) -> None:
    """``get_custom_bot`` must return the fields the bot was created with.

    Indexes 0-2 of the record must equal the name, prompt and creator; index 3
    must be non-``None`` and contain the substring ``"20"``.
    """
    custom_bot_manager.create_custom_bot(
        bot_name="testbot",
        system_prompt="test prompt",
        created_by="creator-1",
    )

    record = custom_bot_manager.get_custom_bot("testbot")
    assert record is not None

    expected_fields = ("testbot", "test prompt", "creator-1")
    for position, expected in enumerate(expected_fields):
        assert record[position] == expected

    # NOTE(review): index 3 is presumably a creation timestamp whose year
    # starts with "20" - confirm against the custom bot table schema.
    fourth_field = record[3]
    assert fourth_field is not None
    assert "20" in fourth_field
def test_custom_bot_list_empty(custom_bot_manager: Any) -> None:
    """With no bots created, ``list_custom_bots`` must return an empty list."""
    listed = custom_bot_manager.list_custom_bots()

    assert listed == []
def test_custom_bot_list(custom_bot_manager: Any) -> None:
    """Two created bots must both be reported by ``list_custom_bots``."""
    specs = (
        ("bot-a", "prompt a", "user-1"),
        ("bot-b", "prompt b", "user-2"),
    )
    for bot_name, system_prompt, created_by in specs:
        custom_bot_manager.create_custom_bot(
            bot_name=bot_name,
            system_prompt=system_prompt,
            created_by=created_by,
        )

    listed = custom_bot_manager.list_custom_bots()

    assert len(listed) == 2
def test_custom_bot_delete(custom_bot_manager: Any) -> None:
    """Deleting an existing bot returns ``True`` and makes lookups return ``None``."""
    spec = {
        "bot_name": "deleteme",
        "system_prompt": "will be deleted",
        "created_by": "user-1",
    }
    custom_bot_manager.create_custom_bot(**spec)

    removed = custom_bot_manager.delete_custom_bot("deleteme")
    assert removed is True

    record = custom_bot_manager.get_custom_bot("deleteme")
    assert record is None
def test_custom_bot_delete_nonexistent(
    custom_bot_manager: Any,
) -> None:
    """Deleting a name that was never created must return ``False``."""
    removed = custom_bot_manager.delete_custom_bot("nonexistent")

    assert removed is False
def test_custom_bot_deactivate(custom_bot_manager: Any) -> None:
    """Deactivating an existing bot returns ``True`` and hides it from lookups."""
    spec = {
        "bot_name": "inactive-bot",
        "system_prompt": "will be deactivated",
        "created_by": "user-1",
    }
    custom_bot_manager.create_custom_bot(**spec)

    deactivated = custom_bot_manager.deactivate_custom_bot("inactive-bot")
    assert deactivated is True

    record = custom_bot_manager.get_custom_bot("inactive-bot")
    assert record is None
def test_custom_bot_deactivate_nonexistent(
    custom_bot_manager: Any,
) -> None:
    """Deactivating a name that was never created must return ``False``."""
    deactivated = custom_bot_manager.deactivate_custom_bot("nonexistent")

    assert deactivated is False
def test_custom_bot_list_excludes_inactive(
    custom_bot_manager: Any,
) -> None:
    """``list_custom_bots`` must omit a bot that has been deactivated.

    Two bots are created and one is deactivated; the listing must then hold
    a single entry whose index 0 is the name of the bot left active.
    """
    specs = (
        ("active-bot", "stays active"),
        ("deactivated-bot", "should not appear"),
    )
    for bot_name, system_prompt in specs:
        custom_bot_manager.create_custom_bot(
            bot_name=bot_name,
            system_prompt=system_prompt,
            created_by="user-1",
        )
    custom_bot_manager.deactivate_custom_bot("deactivated-bot")

    listed = custom_bot_manager.list_custom_bots()

    assert len(listed) == 1
    only_entry = listed[0]
    assert only_entry[0] == "active-bot"
def test_custom_bot_delete_with_error(
    custom_bot_manager: Any,
) -> None:
    """Test that delete_custom_bot returns False on error.

    ``_initialize_custom_bots_table`` is patched to raise
    ``Exception("db error")`` and ``delete_custom_bot`` is invoked while
    that patch is active, so the injected failure is in effect for the call
    under test.
    """
    with patch.object(
        custom_bot_manager,
        "_initialize_custom_bots_table",
        side_effect=Exception("db error"),
    ):
        # The call has to run inside the context: patch.object restores the
        # real method on exit, so calling afterwards would bypass the fault.
        result = custom_bot_manager.delete_custom_bot("nonexistent")

    # NOTE(review): assumes delete_custom_bot reaches
    # _initialize_custom_bots_table. If it does not, the False below comes
    # only from the bot not existing - confirm against CustomBotManager.
    assert result is False
def test_database_get_database_singleton(temp_db_path: str) -> None:
    """Test that get_database returns the same instance.

    The module-level cache ``_chat_db`` is cleared so ``get_database`` has
    to build a ``ChatDatabase``; a second call must return that same object.
    The client is closed and the cache restored even when an assertion
    fails, so no closed database stays cached for later tests.
    """
    import vibe_bot.database as db_module
    from vibe_bot.database import ChatDatabase, get_database

    # NOTE(review): temp_db_path is not referenced below; presumably the
    # fixture redirects the database to a temporary file - confirm in conftest.
    previous = getattr(db_module, "_chat_db", None)
    db_module._chat_db = None
    db1 = None
    try:
        db1 = get_database()
        assert isinstance(db1, ChatDatabase)
        db2 = get_database()
        assert db1 is db2
    finally:
        if db1 is not None:
            db1.client.close()
        # Put back whatever was cached before the test ran.
        db_module._chat_db = previous
def test_database_init_creates_tables(temp_db_path: str) -> None:
    """Test that database initialization creates the expected tables.

    A ``ChatDatabase`` and a ``CustomBotManager`` are constructed on
    ``temp_db_path``; the file is then opened directly with ``sqlite3`` and
    ``sqlite_master`` must list ``chat_messages``, ``message_embeddings``
    and ``custom_bots``.
    """
    import sqlite3
    from contextlib import closing

    from vibe_bot.database import ChatDatabase, CustomBotManager

    db = ChatDatabase(db_path=temp_db_path)
    try:
        CustomBotManager(db_path=temp_db_path)
    finally:
        # Release the client even if the manager's constructor raises.
        db.client.close()

    # A sqlite3 connection used as a context manager only commits or rolls
    # back; closing() is what guarantees the handle is closed on every path.
    with closing(sqlite3.connect(temp_db_path)) as conn:
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table'"
        ).fetchall()
    tables = {row[0] for row in rows}

    expected = {"chat_messages", "message_embeddings", "custom_bots"}
    missing = expected - tables
    assert not missing, f"missing tables: {sorted(missing)}"