add tts
This commit is contained in:
Binary file not shown.
@@ -13,4 +13,5 @@ dependencies = [
|
|||||||
"pytest>=9.0.2",
|
"pytest>=9.0.2",
|
||||||
"python-dotenv>=1.2.2",
|
"python-dotenv>=1.2.2",
|
||||||
"pytest-env>=1.5.0",
|
"pytest-env>=1.5.0",
|
||||||
|
"kokoro-tts>=2.3.1",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -78,6 +78,12 @@ if not IMAGE_EDIT_MODEL:
|
|||||||
if not EMBEDDING_MODEL:
|
if not EMBEDDING_MODEL:
|
||||||
raise Exception("EMBEDDING_MODEL required.")
|
raise Exception("EMBEDDING_MODEL required.")
|
||||||
|
|
||||||
|
# TTS
|
||||||
|
TTS_MODEL_PATH = os.getenv("TTS_MODEL_PATH", "kokoro-v1.0.onnx")
|
||||||
|
TTS_VOICES_PATH = os.getenv("TTS_VOICES_PATH", "voices-v1.0.bin")
|
||||||
|
TTS_VOICE = os.getenv("TTS_VOICE", "af_sarah")
|
||||||
|
TTS_SPEED = float(os.getenv("TTS_SPEED", "1.0"))
|
||||||
|
|
||||||
logger.info(f"CHAT_ENDPOINT set to {CHAT_ENDPOINT}")
|
logger.info(f"CHAT_ENDPOINT set to {CHAT_ENDPOINT}")
|
||||||
logger.info(f"COMPLETION_ENDPOINT set to {COMPLETION_ENDPOINT}")
|
logger.info(f"COMPLETION_ENDPOINT set to {COMPLETION_ENDPOINT}")
|
||||||
logger.info(f"IMAGE_GEN_ENDPOINT set to {IMAGE_GEN_ENDPOINT}")
|
logger.info(f"IMAGE_GEN_ENDPOINT set to {IMAGE_GEN_ENDPOINT}")
|
||||||
|
|||||||
@@ -61,16 +61,9 @@ def chat_completion_with_history(
|
|||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
extra_body={
|
|
||||||
"chat_template_kwargs": {"enable_thinking": False},
|
|
||||||
},
|
|
||||||
seed=-1,
|
seed=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert that thinking was used
|
|
||||||
if response.choices[0].message.model_extra:
|
|
||||||
assert response.choices[0].message.model_extra.get("reasoning_content")
|
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
if content:
|
if content:
|
||||||
return content.strip()
|
return content.strip()
|
||||||
@@ -101,15 +94,9 @@ def chat_completion_instruct(
|
|||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
extra_body={
|
seed=-1,
|
||||||
"chat_template_kwargs": {"enable_thinking": False},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert that thinking wasn't used
|
|
||||||
if response.choices[0].message.model_extra:
|
|
||||||
assert response.choices[0].message.model_extra.get("reasoning_content")
|
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
if content:
|
if content:
|
||||||
return content.strip()
|
return content.strip()
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import discord
|
|||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
import os
|
import os
|
||||||
import base64
|
import base64
|
||||||
|
import traceback
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import logging
|
import logging
|
||||||
@@ -15,7 +16,12 @@ from config import ( # type: ignore
|
|||||||
IMAGE_GEN_ENDPOINT,
|
IMAGE_GEN_ENDPOINT,
|
||||||
IMAGE_EDIT_ENDPOINT,
|
IMAGE_EDIT_ENDPOINT,
|
||||||
MAX_COMPLETION_TOKENS,
|
MAX_COMPLETION_TOKENS,
|
||||||
|
TTS_MODEL_PATH,
|
||||||
|
TTS_VOICES_PATH,
|
||||||
|
TTS_VOICE,
|
||||||
|
TTS_SPEED,
|
||||||
)
|
)
|
||||||
|
import tts # type: ignore
|
||||||
import llama_wrapper # type: ignore
|
import llama_wrapper # type: ignore
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -30,6 +36,15 @@ intents = discord.Intents.default()
|
|||||||
intents.message_content = True
|
intents.message_content = True
|
||||||
bot = commands.Bot(command_prefix="!", intents=intents)
|
bot = commands.Bot(command_prefix="!", intents=intents)
|
||||||
|
|
||||||
|
# Initialize TTS engine
|
||||||
|
try:
|
||||||
|
tts_engine = tts.TTSEngine(TTS_MODEL_PATH, TTS_VOICES_PATH)
|
||||||
|
logger.info("TTS engine initialized successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize TTS engine: {e}")
|
||||||
|
logger.info("Make sure kokoro-v1.0.onnx and voices-v1.0.bin are in the project directory")
|
||||||
|
tts_engine = None
|
||||||
|
|
||||||
|
|
||||||
@bot.event
|
@bot.event
|
||||||
async def on_ready():
|
async def on_ready():
|
||||||
@@ -232,6 +247,84 @@ async def on_message(message):
|
|||||||
await bot.process_commands(message)
|
await bot.process_commands(message)
|
||||||
|
|
||||||
|
|
||||||
|
@bot.command(name="speak")
|
||||||
|
async def speak(ctx, *, message: str):
|
||||||
|
"""Have the bot speak the given text using Kokoro TTS, or have a custom bot speak
|
||||||
|
|
||||||
|
Usage: !speak <text> - plain text to speech
|
||||||
|
Usage: !speak <bot_name> <text> - have a custom bot respond and speak
|
||||||
|
Example: !speak hello world
|
||||||
|
Example: !speak alfred what time is it
|
||||||
|
"""
|
||||||
|
if tts_engine is None:
|
||||||
|
await ctx.send("❌ TTS engine not initialized. Make sure kokoro-v1.0.onnx and voices-v1.0.bin are present.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not message or len(message.strip()) == 0:
|
||||||
|
await ctx.send("❌ Please provide text to speak.")
|
||||||
|
return
|
||||||
|
|
||||||
|
custom_bot_manager = CustomBotManager()
|
||||||
|
custom_bots = custom_bot_manager.list_custom_bots()
|
||||||
|
bot_names = [b[0] for b in custom_bots]
|
||||||
|
|
||||||
|
first_word = message.split()[0] if message.split() else ""
|
||||||
|
if first_word in bot_names:
|
||||||
|
bot_name = first_word
|
||||||
|
text_to_speak = message[len(bot_name):].lstrip()
|
||||||
|
if not text_to_speak:
|
||||||
|
await ctx.send("❌ Please provide text for the bot to respond to.")
|
||||||
|
return
|
||||||
|
|
||||||
|
await ctx.send(f"🔊 **{bot_name}** is thinking...")
|
||||||
|
|
||||||
|
bot_info = custom_bot_manager.get_custom_bot(bot_name)
|
||||||
|
if not bot_info:
|
||||||
|
await ctx.send(f"❌ Custom bot '{bot_name}' not found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
_, system_prompt, _, _ = bot_info
|
||||||
|
|
||||||
|
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences."
|
||||||
|
|
||||||
|
try:
|
||||||
|
bot_response = llama_wrapper.chat_completion_with_history(
|
||||||
|
system_prompt=system_prompt_edit,
|
||||||
|
prompts=[{"role": "user", "content": text_to_speak}],
|
||||||
|
openai_url=CHAT_ENDPOINT,
|
||||||
|
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||||
|
model=CHAT_MODEL,
|
||||||
|
max_tokens=MAX_COMPLETION_TOKENS,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not bot_response:
|
||||||
|
await ctx.send(f"❌ **{bot_name}** failed to generate a response.")
|
||||||
|
return
|
||||||
|
|
||||||
|
await ctx.send(f"🔊 Generating speech for **{bot_name}**...")
|
||||||
|
audio_buffer = tts_engine.generate_audio(bot_response, voice=TTS_VOICE, speed=TTS_SPEED)
|
||||||
|
|
||||||
|
audio_file = discord.File(audio_buffer, filename="speech.mp3")
|
||||||
|
await ctx.send(file=audio_file)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in !speak command with bot '{bot_name}': {traceback.format_exc()}")
|
||||||
|
await ctx.send(f"❌ Error generating speech: {str(e)}")
|
||||||
|
else:
|
||||||
|
if not message or len(message.strip()) == 0:
|
||||||
|
await ctx.send("❌ Please provide text to speak.")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await ctx.send("🔊 Generating speech...")
|
||||||
|
audio_buffer = tts_engine.generate_audio(message, voice=TTS_VOICE, speed=TTS_SPEED)
|
||||||
|
|
||||||
|
audio_file = discord.File(audio_buffer, filename="speech.mp3")
|
||||||
|
await ctx.send(file=audio_file)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in !speak command: {e}")
|
||||||
|
await ctx.send(f"❌ Error generating speech: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@bot.command(name="doodlebob")
|
@bot.command(name="doodlebob")
|
||||||
async def doodlebob(ctx, *, message: str):
|
async def doodlebob(ctx, *, message: str):
|
||||||
# add some logging
|
# add some logging
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
from io import BytesIO
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from kokoro_tts import Kokoro, chunk_text, process_chunk_sequential
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Default voice settings
|
||||||
|
DEFAULT_VOICE = "af_sarah"
|
||||||
|
DEFAULT_SPEED = 1.0
|
||||||
|
DEFAULT_LANG = "en-us"
|
||||||
|
|
||||||
|
|
||||||
|
class TTSEngine:
|
||||||
|
def __init__(self, model_path: str, voices_path: str):
|
||||||
|
self.model_path = model_path
|
||||||
|
self.voices_path = voices_path
|
||||||
|
self.kokoro = Kokoro(model_path, voices_path)
|
||||||
|
logger.info("Kokoro TTS engine initialized")
|
||||||
|
|
||||||
|
def generate_audio(self, text: str, voice: str = DEFAULT_VOICE, speed: float = DEFAULT_SPEED, lang: str = DEFAULT_LANG) -> BytesIO:
|
||||||
|
"""Convert text to audio and return as BytesIO (MP3 format)."""
|
||||||
|
all_samples = []
|
||||||
|
sample_rate = None
|
||||||
|
|
||||||
|
chunks = chunk_text(text)
|
||||||
|
logger.info(f"Split text into {len(chunks)} chunks")
|
||||||
|
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
try:
|
||||||
|
samples, sr = process_chunk_sequential(chunk, self.kokoro, voice, speed, lang)
|
||||||
|
if samples is not None:
|
||||||
|
if sample_rate is None:
|
||||||
|
sample_rate = sr
|
||||||
|
all_samples.append(samples)
|
||||||
|
logger.info(f"Processed chunk {i+1}/{len(chunks)}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing chunk {i+1}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_samples:
|
||||||
|
raise ValueError("No audio samples generated - text may be invalid or too long")
|
||||||
|
|
||||||
|
combined = np.concatenate(all_samples)
|
||||||
|
|
||||||
|
buffer = BytesIO()
|
||||||
|
sf.write(buffer, combined, sample_rate, format="MP3", subtype="MPEG_LAYER_III")
|
||||||
|
buffer.seek(0)
|
||||||
|
|
||||||
|
logger.info(f"Generated MP3 audio: {len(combined)} samples at {sample_rate}Hz")
|
||||||
|
return buffer
|
||||||
Binary file not shown.
Reference in New Issue
Block a user