fix linting, formatting, and add tests
This commit is contained in:
+59
-19
@@ -1,9 +1,17 @@
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from io import BytesIO
|
||||
import os
|
||||
"""Text-to-speech engine using Kokoro TTS."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from kokoro_tts import Kokoro, chunk_text, process_chunk_sequential
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf # type: ignore[import-untyped]
|
||||
from kokoro_tts import ( # type: ignore[import-untyped]
|
||||
Kokoro,
|
||||
chunk_text,
|
||||
process_chunk_sequential,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -14,40 +22,72 @@ DEFAULT_LANG = "en-us"
|
||||
|
||||
|
||||
class TTSEngine:
|
||||
def __init__(self, model_path: str, voices_path: str):
|
||||
"""Text-to-speech engine wrapper around Kokoro TTS."""
|
||||
|
||||
def __init__(self, model_path: str, voices_path: str) -> None:
|
||||
"""Initialize the TTS engine with model and voices paths.
|
||||
|
||||
Args:
|
||||
model_path: Path to the Kokoro model file.
|
||||
voices_path: Path to the voices file.
|
||||
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.voices_path = voices_path
|
||||
self.kokoro = Kokoro(model_path, voices_path)
|
||||
logger.info("Kokoro TTS engine initialized")
|
||||
|
||||
def generate_audio(self, text: str, voice: str = DEFAULT_VOICE, speed: float = DEFAULT_SPEED, lang: str = DEFAULT_LANG) -> BytesIO:
|
||||
def generate_audio(
|
||||
self,
|
||||
text: str,
|
||||
voice: str = DEFAULT_VOICE,
|
||||
speed: float = DEFAULT_SPEED,
|
||||
lang: str = DEFAULT_LANG,
|
||||
) -> BytesIO:
|
||||
"""Convert text to audio and return as BytesIO (MP3 format)."""
|
||||
all_samples = []
|
||||
sample_rate = None
|
||||
all_samples: list[np.ndarray] = []
|
||||
sample_rate: int | None = None
|
||||
|
||||
chunks = chunk_text(text)
|
||||
logger.info(f"Split text into {len(chunks)} chunks")
|
||||
chunks: list[str] = list(chunk_text(text))
|
||||
logger.info("Split text into %d chunks", len(chunks))
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
try:
|
||||
samples, sr = process_chunk_sequential(chunk, self.kokoro, voice, speed, lang)
|
||||
samples, sr = process_chunk_sequential(
|
||||
chunk,
|
||||
self.kokoro,
|
||||
voice,
|
||||
speed,
|
||||
lang,
|
||||
)
|
||||
if samples is not None:
|
||||
if sample_rate is None:
|
||||
sample_rate = sr
|
||||
all_samples.append(samples)
|
||||
logger.info(f"Processed chunk {i+1}/{len(chunks)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing chunk {i+1}: {e}")
|
||||
all_samples.append(np.asarray(samples))
|
||||
logger.info("Processed chunk %d/%d", i + 1, len(chunks))
|
||||
except Exception:
|
||||
logger.exception("Error processing chunk %d", i + 1)
|
||||
continue
|
||||
|
||||
if not all_samples:
|
||||
raise ValueError("No audio samples generated - text may be invalid or too long")
|
||||
msg = "No audio samples generated - text may be invalid or too long"
|
||||
raise ValueError(msg)
|
||||
|
||||
combined = np.concatenate(all_samples)
|
||||
|
||||
buffer = BytesIO()
|
||||
sf.write(buffer, combined, sample_rate, format="MP3", subtype="MPEG_LAYER_III")
|
||||
sf.write( # pyright: ignore[reportUnknownMemberType]
|
||||
buffer,
|
||||
combined,
|
||||
sample_rate,
|
||||
format="MP3",
|
||||
subtype="MPEG_LAYER_III",
|
||||
)
|
||||
buffer.seek(0)
|
||||
|
||||
logger.info(f"Generated MP3 audio: {len(combined)} samples at {sample_rate}Hz")
|
||||
logger.info(
|
||||
"Generated MP3 audio: %d samples at %dHz",
|
||||
len(combined),
|
||||
sample_rate or 0,
|
||||
)
|
||||
return buffer
|
||||
|
||||
Reference in New Issue
Block a user