add tts
This commit is contained in:
Binary file not shown.
@@ -13,4 +13,5 @@ dependencies = [
|
||||
"pytest>=9.0.2",
|
||||
"python-dotenv>=1.2.2",
|
||||
"pytest-env>=1.5.0",
|
||||
"kokoro-tts>=2.3.1",
|
||||
]
|
||||
|
||||
@@ -78,6 +78,12 @@ if not IMAGE_EDIT_MODEL:
|
||||
if not EMBEDDING_MODEL:
|
||||
raise Exception("EMBEDDING_MODEL required.")
|
||||
|
||||
# TTS
|
||||
TTS_MODEL_PATH = os.getenv("TTS_MODEL_PATH", "kokoro-v1.0.onnx")
|
||||
TTS_VOICES_PATH = os.getenv("TTS_VOICES_PATH", "voices-v1.0.bin")
|
||||
TTS_VOICE = os.getenv("TTS_VOICE", "af_sarah")
|
||||
TTS_SPEED = float(os.getenv("TTS_SPEED", "1.0"))
|
||||
|
||||
logger.info(f"CHAT_ENDPOINT set to {CHAT_ENDPOINT}")
|
||||
logger.info(f"COMPLETION_ENDPOINT set to {COMPLETION_ENDPOINT}")
|
||||
logger.info(f"IMAGE_GEN_ENDPOINT set to {IMAGE_GEN_ENDPOINT}")
|
||||
|
||||
@@ -61,16 +61,9 @@ def chat_completion_with_history(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
extra_body={
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
},
|
||||
seed=-1,
|
||||
)
|
||||
|
||||
# Assert that thinking was used
|
||||
if response.choices[0].message.model_extra:
|
||||
assert response.choices[0].message.model_extra.get("reasoning_content")
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
return content.strip()
|
||||
@@ -101,15 +94,9 @@ def chat_completion_instruct(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
extra_body={
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
},
|
||||
seed=-1,
|
||||
)
|
||||
|
||||
# Assert that thinking wasn't used
|
||||
if response.choices[0].message.model_extra:
|
||||
assert response.choices[0].message.model_extra.get("reasoning_content")
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
return content.strip()
|
||||
|
||||
@@ -2,6 +2,7 @@ import discord
|
||||
from discord.ext import commands
|
||||
import os
|
||||
import base64
|
||||
import traceback
|
||||
from io import BytesIO
|
||||
from openai import OpenAI
|
||||
import logging
|
||||
@@ -15,7 +16,12 @@ from config import ( # type: ignore
|
||||
IMAGE_GEN_ENDPOINT,
|
||||
IMAGE_EDIT_ENDPOINT,
|
||||
MAX_COMPLETION_TOKENS,
|
||||
TTS_MODEL_PATH,
|
||||
TTS_VOICES_PATH,
|
||||
TTS_VOICE,
|
||||
TTS_SPEED,
|
||||
)
|
||||
import tts # type: ignore
|
||||
import llama_wrapper # type: ignore
|
||||
import requests
|
||||
|
||||
@@ -30,6 +36,15 @@ intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
bot = commands.Bot(command_prefix="!", intents=intents)
|
||||
|
||||
# Initialize TTS engine
|
||||
try:
|
||||
tts_engine = tts.TTSEngine(TTS_MODEL_PATH, TTS_VOICES_PATH)
|
||||
logger.info("TTS engine initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize TTS engine: {e}")
|
||||
logger.info("Make sure kokoro-v1.0.onnx and voices-v1.0.bin are in the project directory")
|
||||
tts_engine = None
|
||||
|
||||
|
||||
@bot.event
|
||||
async def on_ready():
|
||||
@@ -232,6 +247,84 @@ async def on_message(message):
|
||||
await bot.process_commands(message)
|
||||
|
||||
|
||||
@bot.command(name="speak")
|
||||
async def speak(ctx, *, message: str):
|
||||
"""Have the bot speak the given text using Kokoro TTS, or have a custom bot speak
|
||||
|
||||
Usage: !speak <text> - plain text to speech
|
||||
Usage: !speak <bot_name> <text> - have a custom bot respond and speak
|
||||
Example: !speak hello world
|
||||
Example: !speak alfred what time is it
|
||||
"""
|
||||
if tts_engine is None:
|
||||
await ctx.send("❌ TTS engine not initialized. Make sure kokoro-v1.0.onnx and voices-v1.0.bin are present.")
|
||||
return
|
||||
|
||||
if not message or len(message.strip()) == 0:
|
||||
await ctx.send("❌ Please provide text to speak.")
|
||||
return
|
||||
|
||||
custom_bot_manager = CustomBotManager()
|
||||
custom_bots = custom_bot_manager.list_custom_bots()
|
||||
bot_names = [b[0] for b in custom_bots]
|
||||
|
||||
first_word = message.split()[0] if message.split() else ""
|
||||
if first_word in bot_names:
|
||||
bot_name = first_word
|
||||
text_to_speak = message[len(bot_name):].lstrip()
|
||||
if not text_to_speak:
|
||||
await ctx.send("❌ Please provide text for the bot to respond to.")
|
||||
return
|
||||
|
||||
await ctx.send(f"🔊 **{bot_name}** is thinking...")
|
||||
|
||||
bot_info = custom_bot_manager.get_custom_bot(bot_name)
|
||||
if not bot_info:
|
||||
await ctx.send(f"❌ Custom bot '{bot_name}' not found.")
|
||||
return
|
||||
|
||||
_, system_prompt, _, _ = bot_info
|
||||
|
||||
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences."
|
||||
|
||||
try:
|
||||
bot_response = llama_wrapper.chat_completion_with_history(
|
||||
system_prompt=system_prompt_edit,
|
||||
prompts=[{"role": "user", "content": text_to_speak}],
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||
model=CHAT_MODEL,
|
||||
max_tokens=MAX_COMPLETION_TOKENS,
|
||||
)
|
||||
|
||||
if not bot_response:
|
||||
await ctx.send(f"❌ **{bot_name}** failed to generate a response.")
|
||||
return
|
||||
|
||||
await ctx.send(f"🔊 Generating speech for **{bot_name}**...")
|
||||
audio_buffer = tts_engine.generate_audio(bot_response, voice=TTS_VOICE, speed=TTS_SPEED)
|
||||
|
||||
audio_file = discord.File(audio_buffer, filename="speech.mp3")
|
||||
await ctx.send(file=audio_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in !speak command with bot '{bot_name}': {traceback.format_exc()}")
|
||||
await ctx.send(f"❌ Error generating speech: {str(e)}")
|
||||
else:
|
||||
if not message or len(message.strip()) == 0:
|
||||
await ctx.send("❌ Please provide text to speak.")
|
||||
return
|
||||
|
||||
try:
|
||||
await ctx.send("🔊 Generating speech...")
|
||||
audio_buffer = tts_engine.generate_audio(message, voice=TTS_VOICE, speed=TTS_SPEED)
|
||||
|
||||
audio_file = discord.File(audio_buffer, filename="speech.mp3")
|
||||
await ctx.send(file=audio_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in !speak command: {e}")
|
||||
await ctx.send(f"❌ Error generating speech: {str(e)}")
|
||||
|
||||
|
||||
@bot.command(name="doodlebob")
|
||||
async def doodlebob(ctx, *, message: str):
|
||||
# add some logging
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from io import BytesIO
|
||||
import os
|
||||
import logging
|
||||
from kokoro_tts import Kokoro, chunk_text, process_chunk_sequential
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default voice settings
|
||||
DEFAULT_VOICE = "af_sarah"
|
||||
DEFAULT_SPEED = 1.0
|
||||
DEFAULT_LANG = "en-us"
|
||||
|
||||
|
||||
class TTSEngine:
|
||||
def __init__(self, model_path: str, voices_path: str):
|
||||
self.model_path = model_path
|
||||
self.voices_path = voices_path
|
||||
self.kokoro = Kokoro(model_path, voices_path)
|
||||
logger.info("Kokoro TTS engine initialized")
|
||||
|
||||
def generate_audio(self, text: str, voice: str = DEFAULT_VOICE, speed: float = DEFAULT_SPEED, lang: str = DEFAULT_LANG) -> BytesIO:
|
||||
"""Convert text to audio and return as BytesIO (MP3 format)."""
|
||||
all_samples = []
|
||||
sample_rate = None
|
||||
|
||||
chunks = chunk_text(text)
|
||||
logger.info(f"Split text into {len(chunks)} chunks")
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
try:
|
||||
samples, sr = process_chunk_sequential(chunk, self.kokoro, voice, speed, lang)
|
||||
if samples is not None:
|
||||
if sample_rate is None:
|
||||
sample_rate = sr
|
||||
all_samples.append(samples)
|
||||
logger.info(f"Processed chunk {i+1}/{len(chunks)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing chunk {i+1}: {e}")
|
||||
continue
|
||||
|
||||
if not all_samples:
|
||||
raise ValueError("No audio samples generated - text may be invalid or too long")
|
||||
|
||||
combined = np.concatenate(all_samples)
|
||||
|
||||
buffer = BytesIO()
|
||||
sf.write(buffer, combined, sample_rate, format="MP3", subtype="MPEG_LAYER_III")
|
||||
buffer.seek(0)
|
||||
|
||||
logger.info(f"Generated MP3 audio: {len(combined)} samples at {sample_rate}Hz")
|
||||
return buffer
|
||||
Binary file not shown.
Reference in New Issue
Block a user