94 lines
2.6 KiB
Python
94 lines
2.6 KiB
Python
"""Text-to-speech engine using Kokoro TTS."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from io import BytesIO
|
|
|
|
import numpy as np
|
|
import soundfile as sf # type: ignore[import-untyped]
|
|
from kokoro_tts import ( # type: ignore[import-untyped]
|
|
Kokoro,
|
|
chunk_text,
|
|
process_chunk_sequential,
|
|
)
|
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Fallback synthesis settings, used as the default values of the
# ``voice``, ``speed`` and ``lang`` parameters of TTSEngine.generate_audio.
DEFAULT_VOICE: str = "af_sarah"
DEFAULT_SPEED: float = 1.0
DEFAULT_LANG: str = "en-us"
|
|
|
|
|
|
class TTSEngine:
    """Text-to-speech engine wrapper around Kokoro TTS.

    Builds one ``Kokoro`` instance from the given model and voices files and
    reuses it for every :meth:`generate_audio` call.
    """

    def __init__(self, model_path: str, voices_path: str) -> None:
        """Initialize the TTS engine with model and voices paths.

        Args:
            model_path: Path to the Kokoro model file.
            voices_path: Path to the voices file.
        """
        self.model_path = model_path
        self.voices_path = voices_path
        self.kokoro = Kokoro(model_path, voices_path)
        logger.info("Kokoro TTS engine initialized")

    def generate_audio(
        self,
        text: str,
        voice: str = DEFAULT_VOICE,
        speed: float = DEFAULT_SPEED,
        lang: str = DEFAULT_LANG,
    ) -> BytesIO:
        """Convert text to audio and return it as an in-memory MP3.

        The text is split with ``chunk_text`` and every chunk is synthesized
        with ``process_chunk_sequential``. Synthesis is best-effort: a chunk
        that raises, yields no samples, or reports a sample rate different
        from the first successful chunk is logged and skipped, and the
        remaining chunks are still rendered.

        Args:
            text: Text to synthesize.
            voice: Voice name forwarded to ``process_chunk_sequential``.
            speed: Speed value forwarded to ``process_chunk_sequential``.
            lang: Language code forwarded to ``process_chunk_sequential``.

        Returns:
            A ``BytesIO`` holding the MP3 data, rewound to offset 0.

        Raises:
            ValueError: If the text yields no chunks, or if no chunk produced
                any audio samples.
        """
        chunks: list[str] = list(chunk_text(text))
        total = len(chunks)
        logger.info("Split text into %d chunks", total)

        if not chunks:
            # Fail early with an accurate reason instead of the generic
            # "invalid or too long" message used when synthesis itself fails.
            msg = "No audio samples generated - text produced no chunks (empty input?)"
            raise ValueError(msg)

        all_samples: list[np.ndarray] = []
        sample_rate: int | None = None
        skipped = 0

        for index, chunk in enumerate(chunks, start=1):
            try:
                samples, sr = process_chunk_sequential(
                    chunk,
                    self.kokoro,
                    voice,
                    speed,
                    lang,
                )
                # Converted inside the ``try`` so a malformed sample payload
                # is handled like any other per-chunk failure.
                audio = None if samples is None else np.asarray(samples)
            except Exception:
                # Deliberate best-effort: one bad chunk must not abort the
                # whole request.
                logger.exception("Error processing chunk %d", index)
                skipped += 1
                continue

            if audio is None or sr is None:
                logger.warning(
                    "Chunk %d/%d produced no audio; skipping", index, total
                )
                skipped += 1
                continue

            if sample_rate is None:
                # The first successful chunk fixes the output sample rate.
                sample_rate = sr
            elif sr != sample_rate:
                # Samples at different rates cannot share one output stream.
                logger.warning(
                    "Chunk %d/%d has sample rate %s, expected %s; skipping",
                    index,
                    total,
                    sr,
                    sample_rate,
                )
                skipped += 1
                continue

            all_samples.append(audio)
            logger.info("Processed chunk %d/%d", index, total)

        if not all_samples or sample_rate is None:
            msg = "No audio samples generated - text may be invalid or too long"
            raise ValueError(msg)

        if skipped:
            logger.warning(
                "Skipped %d of %d chunks; generated audio is incomplete",
                skipped,
                total,
            )

        combined = np.concatenate(all_samples)

        buffer = BytesIO()
        sf.write(  # pyright: ignore[reportUnknownMemberType]
            buffer,
            combined,
            sample_rate,
            format="MP3",
            subtype="MPEG_LAYER_III",
        )
        # Rewind so callers can read the MP3 from the start.
        buffer.seek(0)

        logger.info(
            "Generated MP3 audio: %d samples at %dHz",
            len(combined),
            sample_rate,
        )
        return buffer
|