229 lines
7.9 KiB
Python
229 lines
7.9 KiB
Python
"""Shared test fixtures for vibe_bot tests."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import tempfile
|
|
import warnings
|
|
from collections.abc import Generator
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
# pytest reports an exception raised while an object is being finalised
# (e.g. an io.FileIO closing an already-closed descriptor) as a
# PytestUnraisableExceptionWarning whose text is "Exception ignored in: <obj>"
# followed by a blank line and the formatted traceback. "Bad file descriptor"
# therefore appears on a later line than "FileIO", so the pattern needs the
# inline (?s) / DOTALL flag for ``.*`` to span newlines; without it the
# filter never matches.
FILEIO_BAD_FD_MESSAGE = r"(?s)Exception ignored in.*FileIO.*Bad file descriptor"

warnings.filterwarnings("ignore", message=FILEIO_BAD_FD_MESSAGE)


def pytest_configure(config: pytest.Config) -> None:
    """Register the FileIO "Bad file descriptor" warning filter with pytest.

    pytest evaluates warnings inside its own ``warnings.catch_warnings``
    contexts, which restore the filter list afterwards. A filter installed
    only at conftest import time may therefore not be active while tests run.
    Adding it to the ``filterwarnings`` ini value makes pytest apply it to
    every test. The message pattern contains no ``:``, so it is safe in the
    colon-separated ini filter syntax.
    """
    config.addinivalue_line("filterwarnings", f"ignore:{FILEIO_BAD_FD_MESSAGE}")
|
|
|
|
if TYPE_CHECKING:
|
|
from vibe_bot.database import ChatDatabase, CustomBotManager
|
|
|
|
|
|
@pytest.fixture
def mock_env_vars() -> Generator[None]:
    """Overlay ``os.environ`` with dummy configuration values for one test.

    Supplies deterministic placeholders for the endpoint, model, API-key,
    limit, TTS and database settings. Variables already present in the
    environment are kept (``clear=False``), and ``patch.dict`` restores the
    original environment on teardown.
    """
    endpoints = {
        "DISCORD_TOKEN": "test-token",
        "CHAT_ENDPOINT": "https://chat.example.com/v1",
        "COMPLETION_ENDPOINT": "https://completion.example.com/v1",
        "IMAGE_GEN_ENDPOINT": "https://image.example.com/v1",
        "IMAGE_EDIT_ENDPOINT": "https://image-edit.example.com/v1",
        "EMBEDDING_ENDPOINT": "https://embedding.example.com/v1",
    }
    models = {
        "CHAT_MODEL": "test-chat-model",
        "COMPLETION_MODEL": "test-completion-model",
        "IMAGE_GEN_MODEL": "test-image-model",
        "IMAGE_EDIT_MODEL": "test-image-edit-model",
        "EMBEDDING_MODEL": "test-embedding-model",
    }
    api_keys = {
        "CHAT_ENDPOINT_KEY": "test-key",
        "COMPLETION_ENDPOINT_KEY": "test-completion-key",
        "IMAGE_GEN_ENDPOINT_KEY": "test-image-key",
        "IMAGE_EDIT_ENDPOINT_KEY": "test-image-edit-key",
        "EMBEDDING_ENDPOINT_KEY": "test-embedding-key",
    }
    limits = {
        "MAX_COMPLETION_TOKENS": "1000",
        "MAX_HISTORY_MESSAGES": "1000",
        "SIMILARITY_THRESHOLD": "0.7",
        "TOP_K_RESULTS": "5",
    }
    speech = {
        "TTS_MODEL_PATH": "/tmp/test-model.onnx",
        "TTS_VOICES_PATH": "/tmp/test-voices.bin",
        "TTS_VOICE": "af_sarah",
        "TTS_SPEED": "1.0",
    }
    overlay = {**endpoints, **models, **api_keys, **limits, **speech, "DB_PATH": ":memory:"}

    with patch.dict("os.environ", overlay, clear=False):
        yield
|
|
|
|
|
|
@pytest.fixture
def temp_db_path() -> Generator[str]:
    """Provide the path of a fresh, empty temporary SQLite database file.

    The file is created with ``delete=False`` and its handle is closed before
    the path is yielded, so the code under test can open it freely. On
    teardown the file is removed, together with any SQLite sidecar files
    (``-journal``, ``-wal``, ``-shm``) a connection may have left next to it.
    Without that, those files would accumulate in the system temp directory.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as handle:
        path = handle.name
    yield path
    for suffix in ("", "-journal", "-wal", "-shm"):
        Path(path + suffix).unlink(missing_ok=True)
|
|
|
|
|
|
@pytest.fixture
def mock_embedding() -> Generator[MagicMock]:
    """Patch ``vibe_bot.llama_wrapper.embedding`` to return a constant vector.

    Every call returns the same 2048-element list of ``0.1`` values, so
    embedding-dependent code receives deterministic input.

    NOTE(review): this patches the attribute on ``vibe_bot.llama_wrapper``.
    Modules that did ``from vibe_bot.llama_wrapper import embedding`` would
    keep the real function; confirm callers go through the module.
    """
    constant_vector: list[float] = [0.1 for _ in range(2048)]
    with patch("vibe_bot.llama_wrapper.embedding") as embedding_stub:
        embedding_stub.return_value = constant_vector
        yield embedding_stub
|
|
|
|
|
|
@pytest.fixture
def mock_openai_client() -> Generator[MagicMock]:
    """Replace ``vibe_bot.database.OpenAI`` with a mock class.

    The yielded object is the patched *class*. The client instance that code
    under test receives from ``OpenAI(...)`` is its ``return_value``.
    """
    fake_client = MagicMock()
    with patch("vibe_bot.database.OpenAI") as openai_cls:
        openai_cls.return_value = fake_client
        yield openai_cls
|
|
|
|
|
|
@pytest.fixture
def chat_db(
    temp_db_path: str,
    mock_openai_client: MagicMock,
    mock_embedding: MagicMock,
) -> Generator[ChatDatabase]:
    """Yield a ``ChatDatabase`` backed by a throwaway SQLite file.

    ``mock_openai_client`` and ``mock_embedding`` are requested only for
    their side effect. pytest tears this fixture down before them, so the
    OpenAI constructor and the embedding function stay patched for the
    database's whole lifetime. After the test, the database's ``client`` is
    closed. That is presumably the mocked OpenAI client (TODO confirm).
    """
    import vibe_bot.database as database_module

    database = database_module.ChatDatabase(db_path=temp_db_path)
    yield database
    database.client.close()
|
|
|
|
|
|
@pytest.fixture
def custom_bot_manager(temp_db_path: str) -> CustomBotManager:
    """Return a ``CustomBotManager`` bound to the temporary database path."""
    from vibe_bot.database import CustomBotManager

    return CustomBotManager(db_path=temp_db_path)
|
|
|
|
|
|
@pytest.fixture
def mock_kokoro_tts() -> Generator[dict[str, Any]]:
    """Patch the Kokoro TTS pipeline in ``vibe_bot.tts`` with mocks.

    Yields a dict exposing every mock so tests can configure or assert on
    them:

    - ``"Kokoro"``: the mock class installed as ``vibe_bot.tts.Kokoro``.
      Calling it returns ``"kokoro_instance"``.
    - ``"chunk_text"``: returns two fixed text chunks.
    - ``"process_chunk_sequential"``: returns ``(mock_samples, mock_sr)``.
    - ``"mock_samples"`` / ``"mock_sr"``: the canned float32 audio and its
      sample rate.
    """
    sample_rate = 24000
    mock_samples = np.array([0.1, 0.2, 0.3], dtype=np.float32)

    mock_kokoro_instance = MagicMock()
    # Install this exact object as the patched class. Previously patch()
    # created its own anonymous mock, so assertions on mocks["Kokoro"] could
    # never see the constructor calls made by the code under test.
    mock_kokoro = MagicMock(return_value=mock_kokoro_instance)
    mock_chunk = MagicMock(return_value=["hello world", "this is a test"])
    mock_process = MagicMock(return_value=(mock_samples, sample_rate))

    with (
        patch("vibe_bot.tts.Kokoro", mock_kokoro),
        patch("vibe_bot.tts.chunk_text", mock_chunk),
        patch("vibe_bot.tts.process_chunk_sequential", mock_process),
    ):
        yield {
            "Kokoro": mock_kokoro,
            "chunk_text": mock_chunk,
            "process_chunk_sequential": mock_process,
            "kokoro_instance": mock_kokoro_instance,
            "mock_samples": mock_samples,
            "mock_sr": sample_rate,
        }
|
|
|
|
|
|
@pytest.fixture
def mock_discord() -> Generator[dict[str, MagicMock]]:
    """Patch ``discord`` and ``commands`` as seen by ``vibe_bot.main``.

    ``discord.Intents.default()`` returns an object with
    ``message_content = True``. ``commands.Bot(...)`` returns a bot mock
    whose ``user`` has name ``"test-bot"`` and id ``"123456789"``.
    ``discord.Message`` and ``discord.File`` are replaced by the ``MagicMock``
    class itself. The yielded dict exposes the intents mock, the bot class
    mock and the bot instance.
    """
    default_intents = MagicMock()
    default_intents.message_content = True
    intents = MagicMock()
    intents.default.return_value = default_intents

    bot_instance = MagicMock()
    bot_instance.user = MagicMock()
    bot_instance.user.name = "test-bot"
    bot_instance.user.id = "123456789"
    bot_class = MagicMock()
    bot_class.return_value = bot_instance

    # The third patch resolves ``vibe_bot.main.commands`` when it is entered,
    # i.e. after the second one has swapped in the fresh MagicMock.
    with (
        patch("vibe_bot.main.discord") as discord_module,
        patch("vibe_bot.main.commands", MagicMock()),
        patch("vibe_bot.main.commands.Bot", bot_class),
    ):
        discord_module.Intents = intents
        discord_module.Message = MagicMock
        discord_module.File = MagicMock
        yield {"Intents": intents, "Bot": bot_class, "bot_instance": bot_instance}
|
|
|
|
|
|
@pytest.fixture
def mock_tts_engine() -> Generator[MagicMock]:
    """Yield a mock TTS engine wired into ``vibe_bot.main``.

    The same mock is installed as ``vibe_bot.main.tts_engine`` and returned
    by ``vibe_bot.main.tts.TTSEngine(...)``, so both the module-level engine
    and freshly constructed ones point at it. ``generate_audio`` returns a
    MagicMock.
    """
    engine = MagicMock()
    engine.generate_audio.return_value = MagicMock()
    with (
        patch("vibe_bot.main.tts_engine", engine),
        patch("vibe_bot.main.tts.TTSEngine", return_value=engine),
    ):
        yield engine
|
|
|
|
|
|
@pytest.fixture
def mock_requests() -> Generator[MagicMock]:
    """Patch the ``requests`` module used by ``vibe_bot.main``.

    ``requests.get(...)`` returns a response whose ``content`` is
    ``b"fake image data"``. The patched module mock is yielded.
    """
    canned_response = MagicMock()
    canned_response.content = b"fake image data"
    with patch("vibe_bot.main.requests") as requests_stub:
        requests_stub.get.return_value = canned_response
        yield requests_stub
|
|
|
|
|
|
@pytest.fixture
def mock_base64() -> Generator[MagicMock]:
    """Patch the ``base64`` module used by ``vibe_bot.main``.

    ``b64decode`` always returns ``b"fake image data"``. The patched module
    mock is yielded.
    """
    stub_config = {"b64decode.return_value": b"fake image data"}
    with patch("vibe_bot.main.base64", **stub_config) as base64_stub:
        yield base64_stub
|
|
|
|
|
|
@pytest.fixture
def mock_llama_wrapper() -> Generator[MagicMock]:
    """Patch ``vibe_bot.main.llama_wrapper`` with canned responses.

    Chat completion returns ``"Bot response"``, instruct completion returns
    ``"image prompt"``, image generation and editing return empty strings,
    and ``embedding`` returns a 2048-element list of ``0.1``.
    """
    canned = {
        "chat_completion_with_history.return_value": "Bot response",
        "chat_completion_instruct.return_value": "image prompt",
        "image_generation.return_value": "",
        "image_edit.return_value": "",
        "embedding.return_value": [0.1] * 2048,
    }
    with patch("vibe_bot.main.llama_wrapper") as wrapper_stub:
        wrapper_stub.configure_mock(**canned)
        yield wrapper_stub
|
|
|
|
|
|
@pytest.fixture
def mock_database() -> Generator[MagicMock]:
    """Make ``vibe_bot.main.get_database()`` return a mock database.

    The yielded object is the database mock, not the patched accessor.
    ``get_conversation_context`` returns an empty list and ``add_message``
    returns ``True``.
    """
    fake_db = MagicMock()
    fake_db.get_conversation_context.return_value = []
    fake_db.add_message.return_value = True
    with patch("vibe_bot.main.get_database", return_value=fake_db):
        yield fake_db
|
|
|
|
|
|
@pytest.fixture
def mock_custom_bot_manager() -> Generator[MagicMock]:
    """Make ``vibe_bot.main.CustomBotManager(...)`` return a mock manager.

    The manager knows a single bot, ``alfred``:

    - ``get_custom_bot`` returns its row plus the date ``"2024-01-01"``.
    - ``list_custom_bots`` returns a one-element list with the row.
    - ``create_custom_bot`` and ``delete_custom_bot`` return ``True``.

    The manager instance, not the patched class, is yielded.
    """
    alfred_row = ("alfred", "british butler personality", "user123")
    manager = MagicMock()
    manager.configure_mock(
        **{
            "create_custom_bot.return_value": True,
            "get_custom_bot.return_value": (*alfred_row, "2024-01-01"),
            "list_custom_bots.return_value": [alfred_row],
            "delete_custom_bot.return_value": True,
        }
    )
    with patch("vibe_bot.main.CustomBotManager", return_value=manager):
        yield manager
|