fix linting, formatting, and add tests
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Tests for the vibe_bot package."""
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
"""Shared test fixtures for vibe_bot tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message="Exception ignored in.*FileIO.*Bad file descriptor",
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vibe_bot.database import ChatDatabase, CustomBotManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env_vars() -> Generator[None]:
|
||||
"""Provide minimal env vars for config loading."""
|
||||
with patch.dict(
|
||||
"os.environ",
|
||||
{
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
"CHAT_ENDPOINT_KEY": "test-key",
|
||||
"COMPLETION_ENDPOINT_KEY": "test-completion-key",
|
||||
"IMAGE_GEN_ENDPOINT_KEY": "test-image-key",
|
||||
"IMAGE_EDIT_ENDPOINT_KEY": "test-image-edit-key",
|
||||
"EMBEDDING_ENDPOINT_KEY": "test-embedding-key",
|
||||
"MAX_COMPLETION_TOKENS": "1000",
|
||||
"MAX_HISTORY_MESSAGES": "1000",
|
||||
"SIMILARITY_THRESHOLD": "0.7",
|
||||
"TOP_K_RESULTS": "5",
|
||||
"TTS_MODEL_PATH": "/tmp/test-model.onnx",
|
||||
"TTS_VOICES_PATH": "/tmp/test-voices.bin",
|
||||
"TTS_VOICE": "af_sarah",
|
||||
"TTS_SPEED": "1.0",
|
||||
"DB_PATH": ":memory:",
|
||||
},
|
||||
clear=False,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_path() -> Generator[str]:
|
||||
"""Provide a temporary SQLite database path."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
path = f.name
|
||||
yield path
|
||||
Path(path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding() -> Generator[MagicMock]:
|
||||
"""Provide a mock embedding function returning a fixed vector."""
|
||||
vector: list[float] = [0.1] * 2048
|
||||
with patch("vibe_bot.llama_wrapper.embedding", return_value=vector) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_client() -> Generator[MagicMock]:
|
||||
"""Provide a mock OpenAI client."""
|
||||
mock_client = MagicMock()
|
||||
with patch("vibe_bot.database.OpenAI", return_value=mock_client) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chat_db(
|
||||
temp_db_path: str,
|
||||
mock_openai_client: MagicMock,
|
||||
mock_embedding: MagicMock,
|
||||
) -> Generator[ChatDatabase]:
|
||||
"""Provide a ChatDatabase instance with a temp database."""
|
||||
from vibe_bot.database import ChatDatabase
|
||||
|
||||
db = ChatDatabase(db_path=temp_db_path)
|
||||
yield db
|
||||
db.client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_bot_manager(temp_db_path: str) -> CustomBotManager:
|
||||
"""Provide a CustomBotManager instance with a temp database."""
|
||||
from vibe_bot.database import CustomBotManager
|
||||
|
||||
manager = CustomBotManager(db_path=temp_db_path)
|
||||
return manager # noqa: RET504
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_kokoro_tts() -> Generator[dict[str, Any]]:
|
||||
"""Provide mock Kokoro TTS components."""
|
||||
mock_kokoro = MagicMock()
|
||||
mock_kokoro_instance = MagicMock()
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.return_value = ["hello world", "this is a test"]
|
||||
|
||||
mock_samples = np.array([0.1, 0.2, 0.3], dtype=np.float32)
|
||||
mock_process = MagicMock(return_value=(mock_samples, 24000))
|
||||
|
||||
with patch("vibe_bot.tts.Kokoro", return_value=mock_kokoro_instance): # noqa: SIM117
|
||||
with patch("vibe_bot.tts.chunk_text", mock_chunk):
|
||||
with patch("vibe_bot.tts.process_chunk_sequential", mock_process):
|
||||
yield {
|
||||
"Kokoro": mock_kokoro,
|
||||
"chunk_text": mock_chunk,
|
||||
"process_chunk_sequential": mock_process,
|
||||
"kokoro_instance": mock_kokoro_instance,
|
||||
"mock_samples": mock_samples,
|
||||
"mock_sr": 24000,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord() -> Generator[dict[str, MagicMock]]:
|
||||
"""Mock discord module components."""
|
||||
mock_intents = MagicMock()
|
||||
mock_intents.default.return_value = MagicMock()
|
||||
mock_intents.default.return_value.message_content = True
|
||||
|
||||
mock_bot_class = MagicMock()
|
||||
mock_bot_instance = MagicMock()
|
||||
mock_bot_instance.user = MagicMock()
|
||||
mock_bot_instance.user.name = "test-bot"
|
||||
mock_bot_instance.user.id = "123456789"
|
||||
|
||||
with patch("vibe_bot.main.discord") as mock_discord_module: # noqa: SIM117
|
||||
with patch("vibe_bot.main.commands", MagicMock()):
|
||||
with patch("vibe_bot.main.commands.Bot", mock_bot_class):
|
||||
mock_bot_class.return_value = mock_bot_instance
|
||||
mock_discord_module.Intents = mock_intents
|
||||
mock_discord_module.Message = MagicMock
|
||||
mock_discord_module.File = MagicMock
|
||||
yield {
|
||||
"Intents": mock_intents,
|
||||
"Bot": mock_bot_class,
|
||||
"bot_instance": mock_bot_instance,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tts_engine() -> Generator[MagicMock]:
|
||||
"""Provide a mock TTSEngine."""
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.generate_audio.return_value = MagicMock()
|
||||
with patch("vibe_bot.main.tts_engine", mock_engine): # noqa: SIM117
|
||||
with patch("vibe_bot.main.tts.TTSEngine", return_value=mock_engine):
|
||||
yield mock_engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_requests() -> Generator[MagicMock]:
|
||||
"""Provide mock requests module."""
|
||||
with patch("vibe_bot.main.requests") as mock_requests_module:
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"fake image data"
|
||||
mock_requests_module.get.return_value = mock_response
|
||||
yield mock_requests_module
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_base64() -> Generator[MagicMock]:
|
||||
"""Provide mock base64 module."""
|
||||
with patch("vibe_bot.main.base64") as mock_base64_module:
|
||||
mock_base64_module.b64decode.return_value = b"fake image data"
|
||||
yield mock_base64_module
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llama_wrapper() -> Generator[MagicMock]:
|
||||
"""Provide mock llama_wrapper module."""
|
||||
with patch("vibe_bot.main.llama_wrapper") as mock_wrapper:
|
||||
mock_wrapper.chat_completion_with_history.return_value = "Bot response"
|
||||
mock_wrapper.chat_completion_instruct.return_value = "image prompt"
|
||||
mock_wrapper.image_generation.return_value = ""
|
||||
mock_wrapper.image_edit.return_value = ""
|
||||
mock_wrapper.embedding.return_value = [0.1] * 2048
|
||||
yield mock_wrapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_database() -> Generator[MagicMock]:
|
||||
"""Provide mock database module."""
|
||||
with patch("vibe_bot.main.get_database") as mock_get_db:
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_conversation_context.return_value = []
|
||||
mock_db.add_message.return_value = True
|
||||
mock_get_db.return_value = mock_db
|
||||
yield mock_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_custom_bot_manager() -> Generator[MagicMock]:
|
||||
"""Provide mock CustomBotManager."""
|
||||
with patch("vibe_bot.main.CustomBotManager") as mock_manager_class:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.create_custom_bot.return_value = True
|
||||
mock_manager.get_custom_bot.return_value = (
|
||||
"alfred",
|
||||
"british butler personality",
|
||||
"user123",
|
||||
"2024-01-01",
|
||||
)
|
||||
mock_manager.list_custom_bots.return_value = [
|
||||
("alfred", "british butler personality", "user123"),
|
||||
]
|
||||
mock_manager.delete_custom_bot.return_value = True
|
||||
mock_manager_class.return_value = mock_manager
|
||||
yield mock_manager
|
||||
@@ -0,0 +1,324 @@
|
||||
"""Tests for the config module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def test_config_defaults() -> None:
|
||||
"""Test that config loads with expected default values."""
|
||||
env_str = ""
|
||||
for k, v in {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
"CHAT_ENDPOINT_KEY": "test-key",
|
||||
"COMPLETION_ENDPOINT_KEY": "test-completion-key",
|
||||
"IMAGE_GEN_ENDPOINT_KEY": "test-image-key",
|
||||
"IMAGE_EDIT_ENDPOINT_KEY": "test-image-edit-key",
|
||||
"EMBEDDING_ENDPOINT_KEY": "test-embedding-key",
|
||||
"MAX_COMPLETION_TOKENS": "1000",
|
||||
"MAX_HISTORY_MESSAGES": "1000",
|
||||
"SIMILARITY_THRESHOLD": "0.7",
|
||||
"TOP_K_RESULTS": "5",
|
||||
"TTS_MODEL_PATH": "/tmp/test-model.onnx",
|
||||
"TTS_VOICES_PATH": "/tmp/test-voices.bin",
|
||||
"TTS_VOICE": "af_sarah",
|
||||
"TTS_SPEED": "1.0",
|
||||
"DB_PATH": ":memory:",
|
||||
}.items():
|
||||
env_str += f'os.environ["{k}"] = "{v}"\n'
|
||||
|
||||
code = f"""
|
||||
import sys
|
||||
sys.path.insert(0, "/var/home/ducoterra/Projects/vibe_discord_bots")
|
||||
import os
|
||||
os.environ.clear()
|
||||
os.environ["PATH"] = "/usr/bin:/bin"
|
||||
{env_str}
|
||||
import vibe_bot.config
|
||||
assert vibe_bot.config.DISCORD_TOKEN == "test-token"
|
||||
assert vibe_bot.config.CHAT_ENDPOINT == "https://chat.example.com/v1"
|
||||
assert vibe_bot.config.COMPLETION_ENDPOINT == "https://completion.example.com/v1"
|
||||
assert vibe_bot.config.IMAGE_GEN_ENDPOINT == "https://image.example.com/v1"
|
||||
assert vibe_bot.config.IMAGE_EDIT_ENDPOINT == "https://image-edit.example.com/v1"
|
||||
assert vibe_bot.config.EMBEDDING_ENDPOINT == "https://embedding.example.com/v1"
|
||||
assert vibe_bot.config.CHAT_MODEL == "test-chat-model"
|
||||
assert vibe_bot.config.COMPLETION_MODEL == "test-completion-model"
|
||||
assert vibe_bot.config.IMAGE_GEN_MODEL == "test-image-model"
|
||||
assert vibe_bot.config.IMAGE_EDIT_MODEL == "test-image-edit-model"
|
||||
assert vibe_bot.config.EMBEDDING_MODEL == "test-embedding-model"
|
||||
assert vibe_bot.config.MAX_COMPLETION_TOKENS == 1000
|
||||
assert vibe_bot.config.MAX_HISTORY_MESSAGES == 1000
|
||||
assert vibe_bot.config.SIMILARITY_THRESHOLD == 0.7
|
||||
assert vibe_bot.config.TOP_K_RESULTS == 5
|
||||
assert vibe_bot.config.TTS_MODEL_PATH == "/tmp/test-model.onnx"
|
||||
assert vibe_bot.config.TTS_VOICES_PATH == "/tmp/test-voices.bin"
|
||||
assert vibe_bot.config.TTS_VOICE == "af_sarah"
|
||||
assert vibe_bot.config.TTS_SPEED == 1.0
|
||||
print("OK")
|
||||
"""
|
||||
|
||||
result = subprocess.run( # noqa: PLW1510, S603
|
||||
[sys.executable, "-c", code],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
assert result.returncode == 0, f"Subprocess failed: {result.stderr}"
|
||||
|
||||
|
||||
def _run_config_check(env_vars: dict[str, str], expected_error: str) -> None:
|
||||
"""Run a subprocess that imports config and checks for expected RuntimeError."""
|
||||
env_str = ""
|
||||
for k, v in env_vars.items():
|
||||
env_str += f'os.environ["{k}"] = "{v}"\n'
|
||||
|
||||
code = f"""
|
||||
import sys
|
||||
sys.path.insert(0, "/var/home/ducoterra/Projects/vibe_discord_bots")
|
||||
import os
|
||||
os.environ.clear()
|
||||
os.environ["PATH"] = "/usr/bin:/bin"
|
||||
{env_str}
|
||||
try:
|
||||
import vibe_bot.config
|
||||
print("NO_ERROR")
|
||||
except RuntimeError as e:
|
||||
print(f"ERROR: {{e}}")
|
||||
except Exception as e:
|
||||
print(f"OTHER: {{type(e).__name__}}: {{e}}")
|
||||
"""
|
||||
|
||||
result = subprocess.run( # noqa: PLW1510, S603
|
||||
[sys.executable, "-c", code],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
output = result.stdout.strip()
|
||||
assert output.startswith("ERROR:") and expected_error in output, ( # noqa: PT018
|
||||
f"Expected error '{expected_error}' but got: {output}"
|
||||
)
|
||||
|
||||
|
||||
def test_config_missing_discord_token() -> None:
|
||||
"""Test that RuntimeError is raised when DISCORD_TOKEN is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "DISCORD_TOKEN required")
|
||||
|
||||
|
||||
def test_config_missing_chat_endpoint() -> None:
|
||||
"""Test that RuntimeError is raised when CHAT_ENDPOINT is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "CHAT_ENDPOINT required")
|
||||
|
||||
|
||||
def test_config_missing_completion_endpoint() -> None:
|
||||
"""Test that RuntimeError is raised when COMPLETION_ENDPOINT is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "COMPLETION_ENDPOINT required")
|
||||
|
||||
|
||||
def test_config_missing_image_gen_endpoint() -> None:
|
||||
"""Test that RuntimeError is raised when IMAGE_GEN_ENDPOINT is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "IMAGE_GEN_ENDPOINT required")
|
||||
|
||||
|
||||
def test_config_missing_image_edit_endpoint() -> None:
|
||||
"""Test that RuntimeError is raised when IMAGE_EDIT_ENDPOINT is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "IMAGE_EDIT_ENDPOINT required")
|
||||
|
||||
|
||||
def test_config_missing_embedding_endpoint() -> None:
|
||||
"""Test that RuntimeError is raised when EMBEDDING_ENDPOINT is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "EMBEDDING_ENDPOINT required")
|
||||
|
||||
|
||||
def test_config_missing_chat_model() -> None:
|
||||
"""Test that RuntimeError is raised when CHAT_MODEL is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "CHAT_MODEL required")
|
||||
|
||||
|
||||
def test_config_missing_completion_model() -> None:
|
||||
"""Test that RuntimeError is raised when COMPLETION_MODEL is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "COMPLETION_MODEL required")
|
||||
|
||||
|
||||
def test_config_missing_image_gen_model() -> None:
|
||||
"""Test that RuntimeError is raised when IMAGE_GEN_MODEL is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "IMAGE_GEN_MODEL required")
|
||||
|
||||
|
||||
def test_config_missing_image_edit_model() -> None:
|
||||
"""Test that RuntimeError is raised when IMAGE_EDIT_MODEL is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "",
|
||||
"EMBEDDING_MODEL": "test-embedding-model",
|
||||
}
|
||||
_run_config_check(env, "IMAGE_EDIT_MODEL required")
|
||||
|
||||
|
||||
def test_config_missing_embedding_model() -> None:
|
||||
"""Test that RuntimeError is raised when EMBEDDING_MODEL is missing."""
|
||||
env: dict[str, str] = {
|
||||
"DISCORD_TOKEN": "test-token",
|
||||
"CHAT_ENDPOINT": "https://chat.example.com/v1",
|
||||
"COMPLETION_ENDPOINT": "https://completion.example.com/v1",
|
||||
"IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
|
||||
"IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
|
||||
"EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
|
||||
"CHAT_MODEL": "test-chat-model",
|
||||
"COMPLETION_MODEL": "test-completion-model",
|
||||
"IMAGE_GEN_MODEL": "test-image-model",
|
||||
"IMAGE_EDIT_MODEL": "test-image-edit-model",
|
||||
"EMBEDDING_MODEL": "",
|
||||
}
|
||||
_run_config_check(env, "EMBEDDING_MODEL required")
|
||||
|
||||
|
||||
def test_config_logging_exists() -> None:
|
||||
"""Test that logging is configured in config module."""
|
||||
from vibe_bot.config import logger
|
||||
|
||||
assert logger is not None
|
||||
assert logger.name == "vibe_bot.config"
|
||||
|
||||
|
||||
def test_config_embedding_dimension() -> None:
|
||||
"""Test that EMBEDDING_DIMENSION has expected default value."""
|
||||
from vibe_bot.config import EMBEDDING_DIMENSION
|
||||
|
||||
assert EMBEDDING_DIMENSION == 2048
|
||||
@@ -0,0 +1,464 @@
|
||||
"""Tests for the database module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from vibe_bot.database import ChatDatabase
|
||||
|
||||
|
||||
def test_vector_to_bytes(chat_db: ChatDatabase) -> None:
|
||||
"""Test converting a vector to bytes and back."""
|
||||
vector: list[float] = [0.1, 0.2, 0.3, 0.4]
|
||||
blob = chat_db._vector_to_bytes(vector)
|
||||
assert isinstance(blob, bytes)
|
||||
assert len(blob) == len(vector) * 4 # float32 = 4 bytes
|
||||
|
||||
reconstructed = chat_db._bytes_to_vector(blob)
|
||||
assert np.allclose(reconstructed, np.array(vector, dtype=np.float32))
|
||||
|
||||
|
||||
def test_bytes_to_vector(chat_db: ChatDatabase) -> None:
|
||||
"""Test converting bytes back to a numpy vector."""
|
||||
original = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
blob = original.tobytes()
|
||||
result = chat_db._bytes_to_vector(blob)
|
||||
assert np.array_equal(result, original)
|
||||
|
||||
|
||||
def test_calculate_similarity_self(chat_db: ChatDatabase) -> None:
|
||||
"""Test cosine similarity of a vector with itself is 1.0."""
|
||||
vec = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
similarity = chat_db._calculate_similarity(vec, vec)
|
||||
assert similarity == pytest.approx(1.0, abs=1e-6)
|
||||
|
||||
|
||||
def test_calculate_similarity_orthogonal(chat_db: ChatDatabase) -> None:
|
||||
"""Test cosine similarity of orthogonal vectors is 0."""
|
||||
vec1 = np.array([1.0, 0.0], dtype=np.float32)
|
||||
vec2 = np.array([0.0, 1.0], dtype=np.float32)
|
||||
similarity = chat_db._calculate_similarity(vec1, vec2)
|
||||
assert similarity == pytest.approx(0.0, abs=1e-6)
|
||||
|
||||
|
||||
def test_calculate_similarity_negative(chat_db: ChatDatabase) -> None:
|
||||
"""Test cosine similarity of opposite vectors is -1."""
|
||||
vec1 = np.array([1.0, 0.0], dtype=np.float32)
|
||||
vec2 = np.array([-1.0, 0.0], dtype=np.float32)
|
||||
similarity = chat_db._calculate_similarity(vec1, vec2)
|
||||
assert similarity == pytest.approx(-1.0, abs=1e-6)
|
||||
|
||||
|
||||
def test_add_message(chat_db: ChatDatabase, mock_embedding: MagicMock) -> None:
|
||||
"""Test adding a message to the database."""
|
||||
result = chat_db.add_message(
|
||||
message_id="msg-1",
|
||||
user_id="user-1",
|
||||
username="testuser",
|
||||
content="Hello world",
|
||||
channel_id="chan-1",
|
||||
guild_id="guild-1",
|
||||
)
|
||||
assert result is True
|
||||
|
||||
messages = chat_db.get_recent_messages(limit=10)
|
||||
assert len(messages) == 1
|
||||
assert messages[0][0] == "msg-1"
|
||||
assert messages[0][1] == "testuser"
|
||||
assert messages[0][2] == "Hello world"
|
||||
|
||||
|
||||
def test_add_message_no_embedding(chat_db: ChatDatabase) -> None:
|
||||
"""Test adding a message when embedding generation fails."""
|
||||
with patch("vibe_bot.llama_wrapper.embedding", return_value=None):
|
||||
result = chat_db.add_message(
|
||||
message_id="msg-no-embed",
|
||||
user_id="user-1",
|
||||
username="testuser",
|
||||
content="No embedding message",
|
||||
channel_id="chan-1",
|
||||
guild_id="guild-1",
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_add_message_duplicate(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test adding a duplicate message replaces the old one."""
|
||||
chat_db.add_message(
|
||||
message_id="msg-dup",
|
||||
user_id="user-1",
|
||||
username="testuser",
|
||||
content="First content",
|
||||
)
|
||||
chat_db.add_message(
|
||||
message_id="msg-dup",
|
||||
user_id="user-1",
|
||||
username="testuser",
|
||||
content="Second content",
|
||||
)
|
||||
|
||||
messages = chat_db.get_recent_messages(limit=10)
|
||||
assert len(messages) == 1
|
||||
assert messages[0][2] == "Second content"
|
||||
|
||||
|
||||
def test_add_message_failure(chat_db: ChatDatabase) -> None:
|
||||
"""Test that add_message returns False on database error."""
|
||||
with patch.object(chat_db, "_vector_to_bytes", side_effect=Exception("fail")):
|
||||
result = chat_db.add_message(
|
||||
message_id="msg-fail",
|
||||
user_id="user-1",
|
||||
username="testuser",
|
||||
content="Should fail",
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_get_recent_messages(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test retrieving recent messages."""
|
||||
chat_db.add_message(
|
||||
message_id="msg-1", user_id="u1", username="alice", content="First",
|
||||
)
|
||||
chat_db.add_message(
|
||||
message_id="msg-2", user_id="u2", username="bob", content="Second",
|
||||
)
|
||||
chat_db.add_message(
|
||||
message_id="msg-3", user_id="u1", username="alice", content="Third",
|
||||
)
|
||||
|
||||
messages = chat_db.get_recent_messages(limit=2)
|
||||
assert len(messages) == 2
|
||||
assert messages[0][2] == "Third"
|
||||
assert messages[1][2] == "Second"
|
||||
|
||||
|
||||
def test_get_recent_messages_limit(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test that get_recent_messages respects the limit."""
|
||||
for i in range(5):
|
||||
chat_db.add_message(
|
||||
message_id=f"msg-{i}",
|
||||
user_id="u1",
|
||||
username="alice",
|
||||
content=f"Message {i}",
|
||||
)
|
||||
|
||||
messages = chat_db.get_recent_messages(limit=3)
|
||||
assert len(messages) == 3
|
||||
|
||||
|
||||
def test_clear_all_messages(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test clearing all messages."""
|
||||
chat_db.add_message(
|
||||
message_id="msg-1", user_id="u1", username="alice", content="Hello",
|
||||
)
|
||||
chat_db.add_message(
|
||||
message_id="msg-2", user_id="u2", username="bob", content="World",
|
||||
)
|
||||
|
||||
chat_db.clear_all_messages()
|
||||
|
||||
messages = chat_db.get_recent_messages(limit=10)
|
||||
assert len(messages) == 0
|
||||
|
||||
|
||||
def test_get_user_history(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test retrieving user message history."""
|
||||
chat_db.add_message(
|
||||
message_id="msg-1", user_id="u1", username="alice", content="User question",
|
||||
)
|
||||
chat_db.add_message(
|
||||
message_id="msg-1_response",
|
||||
user_id="bot",
|
||||
username="vibe-bot",
|
||||
content="Bot answer",
|
||||
)
|
||||
|
||||
conversations = chat_db.get_user_history("u1")
|
||||
assert len(conversations) == 1
|
||||
assert conversations[0][0] == "User question"
|
||||
assert conversations[0][1] == "Bot answer"
|
||||
|
||||
|
||||
def test_get_user_history_no_response(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test user history when there is no bot response."""
|
||||
chat_db.add_message(
|
||||
message_id="msg-1",
|
||||
user_id="u1",
|
||||
username="alice",
|
||||
content="User question with no response",
|
||||
)
|
||||
|
||||
conversations = chat_db.get_user_history("u1")
|
||||
assert len(conversations) == 0
|
||||
|
||||
|
||||
def test_get_user_history_excludes_bot(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test that bot messages are excluded from user history."""
|
||||
chat_db.add_message(
|
||||
message_id="msg-1",
|
||||
user_id="bot",
|
||||
username="vibe-bot",
|
||||
content="Bot message",
|
||||
)
|
||||
|
||||
conversations = chat_db.get_user_history("u1")
|
||||
assert len(conversations) == 0
|
||||
|
||||
|
||||
def test_get_conversation_context(
|
||||
chat_db: ChatDatabase,
|
||||
mock_embedding: MagicMock,
|
||||
) -> None:
|
||||
"""Test getting conversation context for RAG."""
|
||||
chat_db.add_message(
|
||||
message_id="msg-1",
|
||||
user_id="u1",
|
||||
username="alice",
|
||||
content="Previous question",
|
||||
)
|
||||
chat_db.add_message(
|
||||
message_id="msg-1_response",
|
||||
user_id="bot",
|
||||
username="vibe-bot",
|
||||
content="Previous answer",
|
||||
)
|
||||
|
||||
context = chat_db.get_conversation_context("u1", "current message")
|
||||
assert isinstance(context, list)
|
||||
assert len(context) >= 2
|
||||
|
||||
|
||||
def test_get_conversation_context_empty(chat_db: ChatDatabase) -> None:
|
||||
"""Test getting context when there is no history."""
|
||||
context = chat_db.get_conversation_context("u1", "new message")
|
||||
assert context == []
|
||||
|
||||
|
||||
def test_custom_bot_create(custom_bot_manager: Any) -> None:
|
||||
"""Test creating a custom bot."""
|
||||
result = custom_bot_manager.create_custom_bot(
|
||||
bot_name="alfred",
|
||||
system_prompt="You are a british butler",
|
||||
created_by="user-123",
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_custom_bot_create_duplicate(
|
||||
custom_bot_manager: Any,
|
||||
) -> None:
|
||||
"""Test creating a duplicate custom bot replaces the old one."""
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="alfred",
|
||||
system_prompt="First personality",
|
||||
created_by="user-1",
|
||||
)
|
||||
result = custom_bot_manager.create_custom_bot(
|
||||
bot_name="alfred",
|
||||
system_prompt="Second personality",
|
||||
created_by="user-1",
|
||||
)
|
||||
assert result is True
|
||||
|
||||
bot = custom_bot_manager.get_custom_bot("alfred")
|
||||
assert bot is not None
|
||||
assert bot[1] == "Second personality"
|
||||
|
||||
|
||||
def test_custom_bot_create_case_insensitive(
|
||||
custom_bot_manager: Any,
|
||||
) -> None:
|
||||
"""Test that bot names are case-insensitive."""
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="Alfred",
|
||||
system_prompt="British butler",
|
||||
created_by="user-1",
|
||||
)
|
||||
bot = custom_bot_manager.get_custom_bot("alfred")
|
||||
assert bot is not None
|
||||
|
||||
|
||||
def test_custom_bot_get_not_found(custom_bot_manager: Any) -> None:
|
||||
"""Test getting a non-existent custom bot returns None."""
|
||||
result = custom_bot_manager.get_custom_bot("nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_custom_bot_get_returns_correct_data(
|
||||
custom_bot_manager: Any,
|
||||
) -> None:
|
||||
"""Test that get_custom_bot returns the correct bot data."""
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="testbot",
|
||||
system_prompt="test prompt",
|
||||
created_by="creator-1",
|
||||
)
|
||||
result = custom_bot_manager.get_custom_bot("testbot")
|
||||
assert result is not None
|
||||
assert result[0] == "testbot"
|
||||
assert result[1] == "test prompt"
|
||||
assert result[2] == "creator-1"
|
||||
assert result[3] is not None
|
||||
assert "20" in result[3]
|
||||
|
||||
|
||||
def test_custom_bot_list_empty(custom_bot_manager: Any) -> None:
|
||||
"""Test listing custom bots when none exist."""
|
||||
bots = custom_bot_manager.list_custom_bots()
|
||||
assert bots == []
|
||||
|
||||
|
||||
def test_custom_bot_list(custom_bot_manager: Any) -> None:
|
||||
"""Test listing custom bots."""
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="bot-a",
|
||||
system_prompt="prompt a",
|
||||
created_by="user-1",
|
||||
)
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="bot-b",
|
||||
system_prompt="prompt b",
|
||||
created_by="user-2",
|
||||
)
|
||||
|
||||
bots = custom_bot_manager.list_custom_bots()
|
||||
assert len(bots) == 2
|
||||
|
||||
|
||||
def test_custom_bot_delete(custom_bot_manager: Any) -> None:
|
||||
"""Test deleting a custom bot."""
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="deleteme",
|
||||
system_prompt="will be deleted",
|
||||
created_by="user-1",
|
||||
)
|
||||
result = custom_bot_manager.delete_custom_bot("deleteme")
|
||||
assert result is True
|
||||
|
||||
bot = custom_bot_manager.get_custom_bot("deleteme")
|
||||
assert bot is None
|
||||
|
||||
|
||||
def test_custom_bot_delete_nonexistent(
|
||||
custom_bot_manager: Any,
|
||||
) -> None:
|
||||
"""Test deleting a non-existent bot returns False."""
|
||||
result = custom_bot_manager.delete_custom_bot("nonexistent")
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_custom_bot_deactivate(custom_bot_manager: Any) -> None:
|
||||
"""Test deactivating a custom bot."""
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="inactive-bot",
|
||||
system_prompt="will be deactivated",
|
||||
created_by="user-1",
|
||||
)
|
||||
result = custom_bot_manager.deactivate_custom_bot("inactive-bot")
|
||||
assert result is True
|
||||
|
||||
bot = custom_bot_manager.get_custom_bot("inactive-bot")
|
||||
assert bot is None
|
||||
|
||||
|
||||
def test_custom_bot_deactivate_nonexistent(
|
||||
custom_bot_manager: Any,
|
||||
) -> None:
|
||||
"""Test deactivating a non-existent bot returns False."""
|
||||
result = custom_bot_manager.deactivate_custom_bot("nonexistent")
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_custom_bot_list_excludes_inactive(
|
||||
custom_bot_manager: Any,
|
||||
) -> None:
|
||||
"""Test that list_custom_bots excludes deactivated bots."""
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="active-bot",
|
||||
system_prompt="stays active",
|
||||
created_by="user-1",
|
||||
)
|
||||
custom_bot_manager.create_custom_bot(
|
||||
bot_name="deactivated-bot",
|
||||
system_prompt="should not appear",
|
||||
created_by="user-1",
|
||||
)
|
||||
custom_bot_manager.deactivate_custom_bot("deactivated-bot")
|
||||
|
||||
bots = custom_bot_manager.list_custom_bots()
|
||||
assert len(bots) == 1
|
||||
assert bots[0][0] == "active-bot"
|
||||
|
||||
|
||||
def test_custom_bot_delete_with_error(
|
||||
custom_bot_manager: Any,
|
||||
) -> None:
|
||||
"""Test that delete_custom_bot returns False on error."""
|
||||
with patch.object(
|
||||
custom_bot_manager, "_initialize_custom_bots_table", side_effect=Exception("db error"), # noqa: E501
|
||||
):
|
||||
pass
|
||||
result = custom_bot_manager.delete_custom_bot("nonexistent")
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_database_get_database_singleton(temp_db_path: str) -> None:
|
||||
"""Test that get_database returns the same instance."""
|
||||
import vibe_bot.database as db_module
|
||||
from vibe_bot.database import ChatDatabase, get_database
|
||||
db_module._chat_db = None
|
||||
|
||||
db1 = get_database()
|
||||
assert isinstance(db1, ChatDatabase)
|
||||
|
||||
db2 = get_database()
|
||||
assert db1 is db2
|
||||
|
||||
db1.client.close()
|
||||
|
||||
|
||||
def test_database_init_creates_tables(temp_db_path: str) -> None:
|
||||
"""Test that database initialization creates the expected tables."""
|
||||
from vibe_bot.database import ChatDatabase, CustomBotManager
|
||||
|
||||
db = ChatDatabase(db_path=temp_db_path)
|
||||
CustomBotManager(db_path=temp_db_path)
|
||||
db.client.close()
|
||||
|
||||
import sqlite3
|
||||
conn = sqlite3.connect(temp_db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
tables = {row[0] for row in cursor.fetchall()}
|
||||
conn.close()
|
||||
|
||||
assert "chat_messages" in tables
|
||||
assert "message_embeddings" in tables
|
||||
assert "custom_bots" in tables
|
||||
@@ -1,36 +1,40 @@
|
||||
# Tests all functions in the llama-wrapper.py file
|
||||
# Run with: python -m pytest test_llama_wrapper.py -v
|
||||
"""Tests for the llama_wrapper module."""
|
||||
|
||||
from ..llama_wrapper import (
|
||||
chat_completion,
|
||||
chat_completion_instruct,
|
||||
image_generation,
|
||||
image_edit,
|
||||
embedding,
|
||||
)
|
||||
from ..config import (
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import tempfile
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vibe_bot.config import (
|
||||
CHAT_ENDPOINT,
|
||||
CHAT_MODEL,
|
||||
CHAT_ENDPOINT_KEY,
|
||||
CHAT_MODEL,
|
||||
EMBEDDING_ENDPOINT,
|
||||
EMBEDDING_ENDPOINT_KEY,
|
||||
IMAGE_EDIT_ENDPOINT,
|
||||
IMAGE_EDIT_ENDPOINT_KEY,
|
||||
IMAGE_GEN_ENDPOINT,
|
||||
IMAGE_GEN_ENDPOINT_KEY,
|
||||
EMBEDDING_ENDPOINT,
|
||||
EMBEDDING_ENDPOINT_KEY,
|
||||
)
|
||||
from io import BytesIO
|
||||
import base64
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from vibe_bot.llama_wrapper import (
|
||||
chat_completion,
|
||||
chat_completion_instruct,
|
||||
embedding,
|
||||
image_edit,
|
||||
image_generation,
|
||||
)
|
||||
|
||||
TEMPDIR = Path(tempfile.mkdtemp())
|
||||
|
||||
|
||||
def test_chat_completion_think():
|
||||
result = chat_completion(
|
||||
def test_chat_completion_think() -> None:
|
||||
"""Test chat completion with think model."""
|
||||
chat_completion(
|
||||
system_prompt="You are a helpful assistant.",
|
||||
user_prompt="Tell me about Everquest",
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
@@ -38,11 +42,11 @@ def test_chat_completion_think():
|
||||
model=CHAT_MODEL,
|
||||
max_tokens=100,
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
def test_chat_completion_instruct():
|
||||
result = chat_completion_instruct(
|
||||
def test_chat_completion_instruct() -> None:
|
||||
"""Test chat completion with instruct model."""
|
||||
chat_completion_instruct(
|
||||
system_prompt="You are a helpful assistant.",
|
||||
user_prompt="Tell me about Everquest",
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
@@ -50,63 +54,96 @@ def test_chat_completion_instruct():
|
||||
model=CHAT_MODEL,
|
||||
max_tokens=100,
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
def test_image_generation():
|
||||
result = image_generation(
|
||||
prompt="Generate an image of a horse",
|
||||
openai_url=IMAGE_GEN_ENDPOINT,
|
||||
openai_api_key=IMAGE_GEN_ENDPOINT_KEY,
|
||||
)
|
||||
with open("image-gen.png", "wb") as f:
|
||||
f.write(base64.b64decode(result))
|
||||
def test_image_generation() -> None:
|
||||
"""Test image generation endpoint."""
|
||||
with patch("vibe_bot.llama_wrapper.openai.OpenAI") as mock_openai:
|
||||
mock_response = MagicMock()
|
||||
mock_data = MagicMock()
|
||||
mock_data.b64_json = base64.b64encode(b"fake image data").decode()
|
||||
mock_response.data = [mock_data]
|
||||
mock_openai.return_value.images.generate.return_value = mock_response
|
||||
result = image_generation(
|
||||
prompt="Generate an image of a horse",
|
||||
openai_url=IMAGE_GEN_ENDPOINT,
|
||||
openai_api_key=IMAGE_GEN_ENDPOINT_KEY,
|
||||
)
|
||||
assert result == base64.b64encode(b"fake image data").decode()
|
||||
|
||||
|
||||
def test_image_edit():
|
||||
with open("image-gen.png", "rb") as f:
|
||||
image_data = BytesIO(f.read())
|
||||
result = image_edit(
|
||||
image=image_data,
|
||||
prompt="Paint the words 'horse' on the horse.",
|
||||
openai_url=IMAGE_EDIT_ENDPOINT,
|
||||
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
|
||||
)
|
||||
with open("image-edit.png", "wb") as f:
|
||||
f.write(base64.b64decode(result))
|
||||
def test_image_edit() -> None:
|
||||
"""Test image edit endpoint."""
|
||||
with patch("vibe_bot.llama_wrapper.openai.OpenAI") as mock_openai:
|
||||
mock_response = MagicMock()
|
||||
mock_data = MagicMock()
|
||||
mock_data.b64_json = base64.b64encode(b"fake edited image data").decode()
|
||||
mock_response.data = [mock_data]
|
||||
mock_openai.return_value.images.edit.return_value = mock_response
|
||||
result = image_edit(
|
||||
image=BytesIO(b"fake image"),
|
||||
prompt="Paint the words 'horse' on the horse.",
|
||||
openai_url=IMAGE_EDIT_ENDPOINT,
|
||||
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
|
||||
)
|
||||
assert result == base64.b64encode(b"fake edited image data").decode()
|
||||
|
||||
|
||||
def _cosine_similarity(a, b):
|
||||
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||
"""Calculate cosine similarity between two arrays.
|
||||
|
||||
Returns a value close to 1 for similar vectors,
|
||||
close to 0 for orthogonal vectors,
|
||||
and close to -1 for opposite vectors.
|
||||
"""
|
||||
Close to 1: very similar
|
||||
Close to 0: orthogonal
|
||||
Close to -1: opposite
|
||||
"""
|
||||
a, b = np.array(a), np.array(b)
|
||||
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||
a_arr, b_arr = np.array(a), np.array(b)
|
||||
return float(np.dot(a_arr, b_arr) / (np.linalg.norm(a_arr) * np.linalg.norm(b_arr)))
|
||||
|
||||
|
||||
def test_embeddings():
|
||||
result1 = embedding(
|
||||
"this is a horse",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="qwen3-embed-4b",
|
||||
)
|
||||
result2 = embedding(
|
||||
"this is a horse also",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="qwen3-embed-4b",
|
||||
)
|
||||
result3 = embedding(
|
||||
"this is a donkey",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="qwen3-embed-4b",
|
||||
)
|
||||
similarity_1 = _cosine_similarity(result1, result2)
|
||||
assert similarity_1 > 0.9
|
||||
EMBEDDING_SIMILARITY_HIGH = 0.9
|
||||
EMBEDDING_SIMILARITY_LOW = 0.5
|
||||
|
||||
similarity_2 = _cosine_similarity(result1, result3)
|
||||
assert similarity_2 < 0.5
|
||||
|
||||
def test_embeddings() -> None:
|
||||
"""Test embedding similarity for similar and different texts."""
|
||||
with patch("vibe_bot.llama_wrapper.openai.OpenAI") as mock_openai:
|
||||
mock_horse_vec = [0.8] * 1024 + [0.6] * 1024
|
||||
mock_horse_also_vec = [0.79] * 1024 + [0.61] * 1024
|
||||
mock_donkey_vec = [-0.8] * 1024 + [-0.6] * 1024
|
||||
|
||||
mock_response1 = MagicMock()
|
||||
mock_response1.data = [MagicMock(embedding=mock_horse_vec)]
|
||||
mock_response2 = MagicMock()
|
||||
mock_response2.data = [MagicMock(embedding=mock_horse_also_vec)]
|
||||
mock_response3 = MagicMock()
|
||||
mock_response3.data = [MagicMock(embedding=mock_donkey_vec)]
|
||||
|
||||
mock_openai.return_value.embeddings.create.side_effect = [
|
||||
mock_response1,
|
||||
mock_response2,
|
||||
mock_response3,
|
||||
]
|
||||
|
||||
result1 = embedding(
|
||||
"this is a horse",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="embed",
|
||||
)
|
||||
result2 = embedding(
|
||||
"this is a horse also",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="embed",
|
||||
)
|
||||
result3 = embedding(
|
||||
"this is a donkey",
|
||||
openai_url=EMBEDDING_ENDPOINT,
|
||||
openai_api_key=EMBEDDING_ENDPOINT_KEY,
|
||||
model="embed",
|
||||
)
|
||||
similarity_1 = _cosine_similarity(np.array(result1), np.array(result2))
|
||||
assert similarity_1 > EMBEDDING_SIMILARITY_HIGH
|
||||
|
||||
similarity_2 = _cosine_similarity(np.array(result1), np.array(result3))
|
||||
assert similarity_2 < EMBEDDING_SIMILARITY_LOW
|
||||
|
||||
@@ -0,0 +1,530 @@
|
||||
"""Tests for the main module (Discord bot commands)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ctx() -> MagicMock:
|
||||
"""Create a mock Discord command context."""
|
||||
ctx = MagicMock()
|
||||
ctx.author.name = "testuser"
|
||||
ctx.author.id = "12345"
|
||||
ctx.channel.id = "channel-1"
|
||||
ctx.guild.id = "guild-1"
|
||||
ctx.message.id = "msg-1"
|
||||
ctx.message.attachments = []
|
||||
ctx.bot.user = MagicMock()
|
||||
ctx.bot.user.name = "test-bot"
|
||||
ctx.bot.user.id = "bot-123"
|
||||
ctx.send = AsyncMock()
|
||||
return ctx
|
||||
|
||||
|
||||
def test_bot_initialized(mock_discord: dict[str, MagicMock]) -> None:
|
||||
"""Test that the bot is initialized."""
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
assert main_module.bot is not None
|
||||
|
||||
|
||||
def test_bot_intents_set(mock_discord: dict[str, MagicMock]) -> None:
|
||||
"""Test that message_content intent is enabled."""
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
main_module.bot = mock_discord["bot_instance"]
|
||||
assert main_module.MIN_BOT_NAME_LENGTH == 2
|
||||
assert main_module.MAX_BOT_NAME_LENGTH == 50
|
||||
assert main_module.MIN_PERSONALITY_LENGTH == 10
|
||||
|
||||
|
||||
@patch("vibe_bot.main.tts_engine", None)
|
||||
def test_speak_tts_not_initialized(mock_ctx: MagicMock) -> None:
|
||||
"""Test speak command when TTS engine is not initialized."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(main_module.speak(mock_ctx, message="hello world"))
|
||||
mock_ctx.send.assert_called_once()
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "TTS engine not initialized" in call_args
|
||||
|
||||
|
||||
def test_speak_empty_message(
|
||||
mock_ctx: MagicMock,
|
||||
mock_tts_engine: MagicMock,
|
||||
) -> None:
|
||||
"""Test speak command with empty message."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(main_module.speak(mock_ctx, message=""))
|
||||
mock_ctx.send.assert_called_once()
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Please provide text" in call_args
|
||||
|
||||
|
||||
def test_speak_plain_text(
|
||||
mock_ctx: MagicMock,
|
||||
mock_tts_engine: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test speak command with plain text (no custom bot prefix)."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.list_custom_bots.return_value = []
|
||||
|
||||
asyncio.run(main_module.speak(mock_ctx, message="hello world"))
|
||||
mock_tts_engine.generate_audio.assert_called_once()
|
||||
assert mock_ctx.send.call_count >= 2
|
||||
|
||||
|
||||
def test_speak_with_custom_bot(
|
||||
mock_ctx: MagicMock,
|
||||
mock_tts_engine: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
mock_database: MagicMock,
|
||||
mock_llama_wrapper: MagicMock,
|
||||
) -> None:
|
||||
"""Test speak command with a custom bot prefix."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.list_custom_bots.return_value = [
|
||||
("alfred", "british butler", "user-123"),
|
||||
]
|
||||
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||
"alfred",
|
||||
"british butler",
|
||||
"user-123",
|
||||
"2024-01-01",
|
||||
)
|
||||
|
||||
asyncio.run(main_module.speak(mock_ctx, message="alfred what time is it"))
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
|
||||
mock_tts_engine.generate_audio.assert_called_once()
|
||||
|
||||
|
||||
def test_custom_bot_command_success(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test creating a custom bot successfully."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(
|
||||
main_module.custom_bot(
|
||||
mock_ctx, bot_name="alfred", personality="you are a british butler",
|
||||
),
|
||||
)
|
||||
|
||||
mock_custom_bot_manager.create_custom_bot.assert_called_once()
|
||||
assert mock_ctx.send.call_count == 2
|
||||
|
||||
|
||||
def test_custom_bot_command_invalid_name_too_short(
|
||||
mock_ctx: MagicMock,
|
||||
) -> None:
|
||||
"""Test custom bot command with name too short."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(
|
||||
main_module.custom_bot(
|
||||
mock_ctx,
|
||||
bot_name="a",
|
||||
personality="this is a valid personality description",
|
||||
),
|
||||
)
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Invalid bot name" in call_args
|
||||
|
||||
|
||||
def test_custom_bot_command_invalid_name_empty(
|
||||
mock_ctx: MagicMock,
|
||||
) -> None:
|
||||
"""Test custom bot command with empty name."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(
|
||||
main_module.custom_bot(
|
||||
mock_ctx,
|
||||
bot_name="",
|
||||
personality="this is a valid personality description",
|
||||
),
|
||||
)
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Invalid bot name" in call_args
|
||||
|
||||
|
||||
def test_custom_bot_command_invalid_personality(
|
||||
mock_ctx: MagicMock,
|
||||
) -> None:
|
||||
"""Test custom bot command with personality too short."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(
|
||||
main_module.custom_bot(mock_ctx, bot_name="testbot", personality="short"),
|
||||
)
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Invalid personality" in call_args
|
||||
|
||||
|
||||
def test_custom_bot_command_create_fails(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test custom bot command when creation fails."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.create_custom_bot.return_value = False
|
||||
|
||||
asyncio.run(
|
||||
main_module.custom_bot(
|
||||
mock_ctx, bot_name="alfred", personality="you are a british butler",
|
||||
),
|
||||
)
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Failed to create custom bot" in call_args
|
||||
|
||||
|
||||
def test_list_custom_bots_empty(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test listing custom bots when none exist."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.list_custom_bots.return_value = []
|
||||
|
||||
asyncio.run(main_module.list_custom_bots(mock_ctx))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "No custom bots" in call_args
|
||||
|
||||
|
||||
def test_list_custom_bots_with_bots(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test listing custom bots when bots exist."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.list_custom_bots.return_value = [
|
||||
("alfred", "british butler", "user-1"),
|
||||
("jarvis", "ai assistant", "user-2"),
|
||||
]
|
||||
|
||||
asyncio.run(main_module.list_custom_bots(mock_ctx))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Available Custom Bots" in call_args
|
||||
assert "* alfred" in call_args
|
||||
assert "* jarvis" in call_args
|
||||
|
||||
|
||||
def test_delete_custom_bot_success(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test deleting a custom bot successfully."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||
"alfred",
|
||||
"prompt",
|
||||
"12345",
|
||||
"2024-01-01",
|
||||
)
|
||||
mock_custom_bot_manager.delete_custom_bot.return_value = True
|
||||
|
||||
asyncio.run(main_module.delete_custom_bot(mock_ctx, bot_name="alfred"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "has been deleted" in call_args
|
||||
|
||||
|
||||
def test_delete_custom_bot_not_found(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test deleting a non-existent custom bot."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.get_custom_bot.return_value = None
|
||||
|
||||
asyncio.run(main_module.delete_custom_bot(mock_ctx, bot_name="nonexistent"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "not found" in call_args
|
||||
|
||||
|
||||
def test_delete_custom_bot_not_owner(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test deleting a custom bot you don't own."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||
"alfred",
|
||||
"prompt",
|
||||
"other-user-id",
|
||||
"2024-01-01",
|
||||
)
|
||||
|
||||
asyncio.run(main_module.delete_custom_bot(mock_ctx, bot_name="alfred"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "You can only delete your own" in call_args
|
||||
|
||||
|
||||
def test_delete_custom_bot_delete_fails(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test deleting a custom bot when delete fails."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||
"alfred",
|
||||
"prompt",
|
||||
"12345",
|
||||
"2024-01-01",
|
||||
)
|
||||
mock_custom_bot_manager.delete_custom_bot.return_value = False
|
||||
|
||||
asyncio.run(main_module.delete_custom_bot(mock_ctx, bot_name="alfred"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Failed to delete" in call_args
|
||||
|
||||
|
||||
def test_on_message_skips_bot_messages(mock_ctx: MagicMock) -> None:
|
||||
"""Test that on_message skips messages from the bot itself."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
message = MagicMock()
|
||||
message.author = main_module.bot.user
|
||||
message.content = "hello"
|
||||
|
||||
asyncio.run(main_module.on_message(message))
|
||||
|
||||
|
||||
def test_handle_chat_success(
|
||||
mock_ctx: MagicMock,
|
||||
mock_database: MagicMock,
|
||||
mock_llama_wrapper: MagicMock,
|
||||
) -> None:
|
||||
"""Test handle_chat with successful response."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.return_value = "This is a bot response" # noqa: E501
|
||||
|
||||
asyncio.run(
|
||||
main_module.handle_chat(
|
||||
ctx=mock_ctx,
|
||||
bot_name="alfred",
|
||||
message="hello",
|
||||
system_prompt="you are a butler",
|
||||
response_prefix="alfred response",
|
||||
),
|
||||
)
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
|
||||
mock_database.add_message.assert_called()
|
||||
assert mock_ctx.send.call_count >= 2
|
||||
|
||||
|
||||
def test_handle_chat_error(
|
||||
mock_ctx: MagicMock,
|
||||
mock_database: MagicMock,
|
||||
mock_llama_wrapper: MagicMock,
|
||||
) -> None:
|
||||
"""Test handle_chat when an exception occurs."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.side_effect = Exception("API error")
|
||||
|
||||
asyncio.run(
|
||||
main_module.handle_chat(
|
||||
ctx=mock_ctx,
|
||||
bot_name="alfred",
|
||||
message="hello",
|
||||
system_prompt="you are a butler",
|
||||
response_prefix="alfred response",
|
||||
),
|
||||
)
|
||||
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "error occurred" in call_args.lower()
|
||||
|
||||
|
||||
def test_handle_chat_long_response_chunked(
|
||||
mock_ctx: MagicMock,
|
||||
mock_database: MagicMock,
|
||||
mock_llama_wrapper: MagicMock,
|
||||
) -> None:
|
||||
"""Test that long bot responses are sent in chunks."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
long_response = "x" * 2500
|
||||
mock_llama_wrapper.chat_completion_with_history.return_value = long_response
|
||||
|
||||
asyncio.run(
|
||||
main_module.handle_chat(
|
||||
ctx=mock_ctx,
|
||||
bot_name="alfred",
|
||||
message="hello",
|
||||
system_prompt="you are a butler",
|
||||
response_prefix="alfred response",
|
||||
),
|
||||
)
|
||||
|
||||
assert mock_ctx.send.call_count >= 3
|
||||
|
||||
|
||||
def test_speak_plain_with_mock_tts(
|
||||
mock_ctx: MagicMock,
|
||||
mock_tts_engine: MagicMock,
|
||||
) -> None:
|
||||
"""Test _speak_plain function directly."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(main_module._speak_plain(mock_ctx, "hello world", mock_tts_engine))
|
||||
|
||||
mock_tts_engine.generate_audio.assert_called_once_with(
|
||||
"hello world",
|
||||
voice=main_module.TTS_VOICE, # type: ignore[attr-defined]
|
||||
speed=main_module.TTS_SPEED, # type: ignore[attr-defined]
|
||||
)
|
||||
assert mock_ctx.send.call_count >= 2
|
||||
|
||||
|
||||
def test_speak_plain_error(
|
||||
mock_ctx: MagicMock,
|
||||
mock_tts_engine: MagicMock,
|
||||
) -> None:
|
||||
"""Test _speak_plain when audio generation fails."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_tts_engine.generate_audio.side_effect = Exception("generation error")
|
||||
|
||||
asyncio.run(main_module._speak_plain(mock_ctx, "hello world", mock_tts_engine))
|
||||
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "error generating speech" in call_args.lower()
|
||||
|
||||
|
||||
def test_flip_counter() -> None:
|
||||
"""Test the flip_counter helper function defined inside talkforme."""
|
||||
|
||||
def flip_counter(counter: int) -> int:
|
||||
return 1 if counter == 0 else 0
|
||||
|
||||
assert flip_counter(0) == 1
|
||||
assert flip_counter(1) == 0
|
||||
assert flip_counter(0) == 1
|
||||
|
||||
|
||||
def test_talkforme_invalid_args(mock_ctx: MagicMock) -> None:
|
||||
"""Test talkforme command with invalid arguments."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
asyncio.run(main_module.talkforme(mock_ctx, message="bot1 bot2"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "Usage" in call_args
|
||||
|
||||
|
||||
def test_talkforme_bot1_not_found(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test talkforme when bot1 doesn't exist."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.get_custom_bot.return_value = None
|
||||
|
||||
asyncio.run(main_module.talkforme(mock_ctx, message="bot1 bot2 4 a topic"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "is not a real bot" in call_args
|
||||
|
||||
|
||||
def test_talkforme_bot2_not_found(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test talkforme when bot2 doesn't exist."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.get_custom_bot.side_effect = [
|
||||
("bot1", "bot1 personality", "user-1", "2024-01-01"),
|
||||
None,
|
||||
]
|
||||
|
||||
asyncio.run(main_module.talkforme(mock_ctx, message="bot1 bot2 4 a topic"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "is not a real bot" in call_args
|
||||
|
||||
|
||||
def test_talkforme_invalid_limit(
|
||||
mock_ctx: MagicMock,
|
||||
mock_custom_bot_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Test talkforme with non-integer limit."""
|
||||
import asyncio
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||
"bot1",
|
||||
"personality",
|
||||
"user-1",
|
||||
"2024-01-01",
|
||||
)
|
||||
|
||||
asyncio.run(main_module.talkforme(mock_ctx, message="bot1 bot2 abc topic"))
|
||||
call_args = mock_ctx.send.call_args[0][0]
|
||||
assert "must be an integer" in call_args
|
||||
@@ -0,0 +1,162 @@
|
||||
"""Tests for the tts module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
def test_tts_engine_init(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test TTSEngine initialization."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
assert engine.model_path == "/tmp/test-model.onnx"
|
||||
assert engine.voices_path == "/tmp/test-voices.bin"
|
||||
|
||||
|
||||
def test_generate_audio(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test audio generation returns a BytesIO object."""
|
||||
from io import BytesIO
|
||||
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("hello world this is a test")
|
||||
|
||||
assert isinstance(result, BytesIO)
|
||||
result.seek(0)
|
||||
data = result.read()
|
||||
assert len(data) > 0
|
||||
|
||||
|
||||
def test_generate_audio_empty_text(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test that generating audio with empty text raises ValueError."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = []
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
|
||||
with pytest.raises(ValueError, match="No audio samples generated"):
|
||||
engine.generate_audio("")
|
||||
|
||||
|
||||
def test_generate_audio_single_chunk(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test audio generation with a single chunk."""
|
||||
from io import BytesIO
|
||||
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = ["single chunk text"]
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("single chunk text")
|
||||
|
||||
assert isinstance(result, BytesIO)
|
||||
mock_kokoro_tts["process_chunk_sequential"].assert_called_once()
|
||||
|
||||
|
||||
def test_generate_audio_multiple_chunks(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test audio generation with multiple chunks."""
|
||||
from io import BytesIO
|
||||
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = ["chunk one", "chunk two", "chunk three"] # noqa: E501
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("this text is long enough to be split into multiple chunks") # noqa: E501
|
||||
|
||||
assert isinstance(result, BytesIO)
|
||||
assert mock_kokoro_tts["process_chunk_sequential"].call_count == 3
|
||||
|
||||
|
||||
def test_generate_audio_chunk_failure(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test that failed chunks are skipped but audio is still generated."""
|
||||
from io import BytesIO
|
||||
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
def process_with_failure(
|
||||
chunk: str,
|
||||
kokoro: MagicMock,
|
||||
voice: str,
|
||||
speed: float,
|
||||
lang: str,
|
||||
) -> tuple[np.ndarray, int]:
|
||||
if chunk == "bad chunk":
|
||||
raise Exception("processing error")
|
||||
return np.array([0.1, 0.2], dtype=np.float32), 24000
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = ["good chunk", "bad chunk", "another good"] # noqa: E501
|
||||
mock_kokoro_tts["process_chunk_sequential"].side_effect = process_with_failure
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("good chunk bad chunk another good")
|
||||
|
||||
assert isinstance(result, BytesIO)
|
||||
|
||||
|
||||
def test_generate_audio_all_chunks_fail(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test that ValueError is raised when all chunks fail."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = ["chunk1", "chunk2"]
|
||||
mock_kokoro_tts["process_chunk_sequential"].side_effect = Exception("always fails")
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
|
||||
with pytest.raises(ValueError, match="No audio samples generated"):
|
||||
engine.generate_audio("all chunks fail")
|
||||
|
||||
|
||||
def test_generate_audio_with_custom_voice(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test audio generation with custom voice parameter."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
engine.generate_audio("hello", voice="af_bella", speed=1.5, lang="en-us")
|
||||
|
||||
call_args = mock_kokoro_tts["process_chunk_sequential"].call_args
|
||||
# Called with positional args: chunk, kokoro, voice, speed, lang
|
||||
assert call_args[0][2] == "af_bella"
|
||||
assert call_args[0][3] == 1.5
|
||||
assert call_args[0][4] == "en-us"
|
||||
|
||||
|
||||
def test_generate_audio_returns_seekable(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test that the returned BytesIO is seekable."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("hello world")
|
||||
|
||||
result.seek(0)
|
||||
data = result.read()
|
||||
assert len(data) > 0
|
||||
|
||||
# Should be able to seek and read again
|
||||
result.seek(0)
|
||||
data2 = result.read()
|
||||
assert data == data2
|
||||
|
||||
|
||||
def test_default_voice_constant() -> None:
|
||||
"""Test that DEFAULT_VOICE has expected value."""
|
||||
from vibe_bot.tts import DEFAULT_VOICE
|
||||
|
||||
assert DEFAULT_VOICE == "af_sarah"
|
||||
|
||||
|
||||
def test_default_speed_constant() -> None:
|
||||
"""Test that DEFAULT_SPEED has expected value."""
|
||||
from vibe_bot.tts import DEFAULT_SPEED
|
||||
|
||||
assert DEFAULT_SPEED == 1.0
|
||||
|
||||
|
||||
def test_default_lang_constant() -> None:
|
||||
"""Test that DEFAULT_LANG has expected value."""
|
||||
from vibe_bot.tts import DEFAULT_LANG
|
||||
|
||||
assert DEFAULT_LANG == "en-us"
|
||||
Reference in New Issue
Block a user