fix linting, formatting, and add tests
This commit is contained in:
@@ -0,0 +1,162 @@
|
||||
"""Tests for the tts module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
def test_tts_engine_init(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test TTSEngine initialization."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
assert engine.model_path == "/tmp/test-model.onnx"
|
||||
assert engine.voices_path == "/tmp/test-voices.bin"
|
||||
|
||||
|
||||
def test_generate_audio(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test audio generation returns a BytesIO object."""
|
||||
from io import BytesIO
|
||||
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("hello world this is a test")
|
||||
|
||||
assert isinstance(result, BytesIO)
|
||||
result.seek(0)
|
||||
data = result.read()
|
||||
assert len(data) > 0
|
||||
|
||||
|
||||
def test_generate_audio_empty_text(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test that generating audio with empty text raises ValueError."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = []
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
|
||||
with pytest.raises(ValueError, match="No audio samples generated"):
|
||||
engine.generate_audio("")
|
||||
|
||||
|
||||
def test_generate_audio_single_chunk(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test audio generation with a single chunk."""
|
||||
from io import BytesIO
|
||||
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = ["single chunk text"]
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("single chunk text")
|
||||
|
||||
assert isinstance(result, BytesIO)
|
||||
mock_kokoro_tts["process_chunk_sequential"].assert_called_once()
|
||||
|
||||
|
||||
def test_generate_audio_multiple_chunks(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test audio generation with multiple chunks."""
|
||||
from io import BytesIO
|
||||
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = ["chunk one", "chunk two", "chunk three"] # noqa: E501
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("this text is long enough to be split into multiple chunks") # noqa: E501
|
||||
|
||||
assert isinstance(result, BytesIO)
|
||||
assert mock_kokoro_tts["process_chunk_sequential"].call_count == 3
|
||||
|
||||
|
||||
def test_generate_audio_chunk_failure(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test that failed chunks are skipped but audio is still generated."""
|
||||
from io import BytesIO
|
||||
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
def process_with_failure(
|
||||
chunk: str,
|
||||
kokoro: MagicMock,
|
||||
voice: str,
|
||||
speed: float,
|
||||
lang: str,
|
||||
) -> tuple[np.ndarray, int]:
|
||||
if chunk == "bad chunk":
|
||||
raise Exception("processing error")
|
||||
return np.array([0.1, 0.2], dtype=np.float32), 24000
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = ["good chunk", "bad chunk", "another good"] # noqa: E501
|
||||
mock_kokoro_tts["process_chunk_sequential"].side_effect = process_with_failure
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("good chunk bad chunk another good")
|
||||
|
||||
assert isinstance(result, BytesIO)
|
||||
|
||||
|
||||
def test_generate_audio_all_chunks_fail(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test that ValueError is raised when all chunks fail."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
mock_kokoro_tts["chunk_text"].return_value = ["chunk1", "chunk2"]
|
||||
mock_kokoro_tts["process_chunk_sequential"].side_effect = Exception("always fails")
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
|
||||
with pytest.raises(ValueError, match="No audio samples generated"):
|
||||
engine.generate_audio("all chunks fail")
|
||||
|
||||
|
||||
def test_generate_audio_with_custom_voice(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test audio generation with custom voice parameter."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
engine.generate_audio("hello", voice="af_bella", speed=1.5, lang="en-us")
|
||||
|
||||
call_args = mock_kokoro_tts["process_chunk_sequential"].call_args
|
||||
# Called with positional args: chunk, kokoro, voice, speed, lang
|
||||
assert call_args[0][2] == "af_bella"
|
||||
assert call_args[0][3] == 1.5
|
||||
assert call_args[0][4] == "en-us"
|
||||
|
||||
|
||||
def test_generate_audio_returns_seekable(mock_kokoro_tts: MagicMock) -> None:
|
||||
"""Test that the returned BytesIO is seekable."""
|
||||
from vibe_bot.tts import TTSEngine
|
||||
|
||||
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
|
||||
result = engine.generate_audio("hello world")
|
||||
|
||||
result.seek(0)
|
||||
data = result.read()
|
||||
assert len(data) > 0
|
||||
|
||||
# Should be able to seek and read again
|
||||
result.seek(0)
|
||||
data2 = result.read()
|
||||
assert data == data2
|
||||
|
||||
|
||||
def test_default_voice_constant() -> None:
|
||||
"""Test that DEFAULT_VOICE has expected value."""
|
||||
from vibe_bot.tts import DEFAULT_VOICE
|
||||
|
||||
assert DEFAULT_VOICE == "af_sarah"
|
||||
|
||||
|
||||
def test_default_speed_constant() -> None:
|
||||
"""Test that DEFAULT_SPEED has expected value."""
|
||||
from vibe_bot.tts import DEFAULT_SPEED
|
||||
|
||||
assert DEFAULT_SPEED == 1.0
|
||||
|
||||
|
||||
def test_default_lang_constant() -> None:
|
||||
"""Test that DEFAULT_LANG has expected value."""
|
||||
from vibe_bot.tts import DEFAULT_LANG
|
||||
|
||||
assert DEFAULT_LANG == "en-us"
|
||||
Reference in New Issue
Block a user