everything working again after cleanup

This commit is contained in:
2026-05-23 23:56:03 -04:00
parent 6ec9fbe85f
commit 87a578f1de
13 changed files with 380 additions and 200 deletions
+5 -4
View File
@@ -330,7 +330,7 @@ class ChatDatabase:
results.sort(key=lambda x: x[2], reverse=True)
return results[:top_k]
def get_user_history(self, _user_id: str, limit: int = 20) -> list[tuple[str, str]]:
def get_user_history(self, user_id: str, limit: int = 20) -> list[tuple[str, str]]:
"""Get message history for a specific user."""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
@@ -340,11 +340,11 @@ class ChatDatabase:
"""
SELECT message_id, content, timestamp
FROM chat_messages
WHERE username != 'vibe-bot'
WHERE user_id = ? AND username != 'vibe-bot'
ORDER BY timestamp DESC
LIMIT ?
""",
(limit,),
(user_id, limit),
)
messages = cursor.fetchall()
@@ -528,9 +528,10 @@ class CustomBotManager:
"""
SELECT bot_name, system_prompt, created_by
FROM custom_bots
WHERE is_active = 1
WHERE is_active = 1 AND created_by = ?
ORDER BY created_at DESC
""",
(user_id,),
)
else:
cursor.execute(
+61 -21
View File
@@ -6,9 +6,11 @@ Allows custom endpoints for each of the above supported functions.
from __future__ import annotations
import json
from typing import TYPE_CHECKING, cast
import openai
import requests
if TYPE_CHECKING:
from io import BufferedReader, BytesIO
@@ -54,8 +56,12 @@ def chat_completion(
model=model,
messages=messages,
max_tokens=max_tokens,
timeout=60.0,
)
if not response.choices:
return ""
content = response.choices[0].message.content
if content:
return content.strip()
@@ -101,8 +107,12 @@ def chat_completion_with_history(
messages=messages,
max_tokens=max_tokens,
seed=-1,
timeout=60.0,
)
if not response.choices:
return ""
content = response.choices[0].message.content
if content:
return content.strip()
@@ -148,8 +158,12 @@ def chat_completion_instruct(
messages=messages,
max_tokens=max_tokens,
seed=-1,
timeout=60.0,
)
if not response.choices:
return ""
content = response.choices[0].message.content
if content:
return content.strip()
@@ -158,8 +172,10 @@ def chat_completion_instruct(
def image_generation(
prompt: str,
*,
openai_url: str,
openai_api_key: str,
model: str = "gen",
n: int = 1,
) -> str:
"""Generate an image using the given prompt.
@@ -168,19 +184,28 @@ def image_generation(
prompt: The image generation prompt.
openai_url: The OpenAI-compatible API URL.
openai_api_key: The API key for authentication.
model: The model to use for image generation.
n: Number of images to generate.
Returns:
The base64 encoded image data. Decode and write to a file.
"""
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
response = client.images.generate(
prompt=prompt,
n=n,
size="1024x1024",
model="gen",
client = openai.OpenAI(
base_url=openai_url,
api_key=openai_api_key,
max_retries=0,
)
try:
response = client.images.generate(
prompt=prompt,
n=n,
size="1024x1024",
model=model,
timeout=120.0,
)
except openai.APIConnectionError:
return ""
if response.data:
return response.data[0].b64_json or ""
return ""
@@ -189,8 +214,10 @@ def image_generation(
def image_edit(
image: BufferedReader | BytesIO | list[BufferedReader] | list[BytesIO],
prompt: str,
*,
openai_url: str,
openai_api_key: str,
model: str = "edit",
n: int = 1,
) -> str:
"""Edit an existing image using a prompt.
@@ -200,6 +227,7 @@ def image_edit(
prompt: The edit instruction.
openai_url: The OpenAI-compatible API URL.
openai_api_key: The API key for authentication.
model: The model to use for image editing.
n: Number of edited images to generate.
Returns:
@@ -212,7 +240,7 @@ def image_edit(
prompt=prompt,
n=n,
size="1024x1024",
model="edit",
model=model,
)
if response.data:
return response.data[0].b64_json or ""
@@ -228,6 +256,9 @@ def embedding(
) -> list[float]:
"""Generate an embedding vector for the given text.
Uses a raw HTTP request to avoid the OpenAI SDK injecting
unsupported parameters like encoding_format.
Args:
text: The text to embed.
openai_url: The OpenAI-compatible API URL.
@@ -238,17 +269,26 @@ def embedding(
The embedding vector as a list of floats, or an empty list on failure.
"""
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
response = client.embeddings.create(
input=[text],
model=model,
encoding_format="float",
)
if response:
data = response.data
raw_data = data[0].embedding
# The result could be an array of floats or a single float.
if not isinstance(raw_data, float):
return list(raw_data)
return [raw_data]
return []
url = f"{openai_url.rstrip('/')}/embeddings"
headers = {
"Authorization": f"Bearer {openai_api_key}",
"Content-Type": "application/json",
}
payload = {"model": model, "input": [text]}
try:
resp = requests.post(url, headers=headers, json=payload, timeout=30)
resp.raise_for_status()
except requests.RequestException:
return []
data = resp.json()
if not data.get("data"):
return []
raw = data["data"][0].get("embedding")
if isinstance(raw, str):
raw = json.loads(raw)
if not isinstance(raw, list):
raw = list(raw)
return raw
+23 -9
View File
@@ -20,6 +20,10 @@ from vibe_bot.config import (
DISCORD_TOKEN,
IMAGE_EDIT_ENDPOINT,
IMAGE_EDIT_ENDPOINT_KEY,
IMAGE_EDIT_MODEL,
IMAGE_GEN_ENDPOINT,
IMAGE_GEN_ENDPOINT_KEY,
IMAGE_GEN_MODEL,
MAX_COMPLETION_TOKENS,
TTS_MODEL_PATH,
TTS_SPEED,
@@ -415,7 +419,7 @@ async def _speak_with_bot(
message_id=f"{ctx.message.id}_response",
user_id=str(ctx.bot.user.id),
username=ctx.bot.user.name,
content=f"Bot: {bot_response}",
content=bot_response,
channel_id=str(ctx.channel.id),
guild_id=str(ctx.guild.id) if ctx.guild else None,
)
@@ -497,14 +501,23 @@ async def doodlebob(ctx: CommandsContext[Bot], *, message: str) -> None:
image_b64 = llama_wrapper.image_generation(
prompt=image_prompt,
openai_url=IMAGE_EDIT_ENDPOINT,
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
openai_url=IMAGE_GEN_ENDPOINT,
openai_api_key=IMAGE_GEN_ENDPOINT_KEY,
model=IMAGE_GEN_MODEL,
)
# Save the image to a file
edited_image_data = BytesIO(base64.b64decode(image_b64))
send_img = discord.File(edited_image_data, filename="image.png")
await ctx.send(file=send_img)
if not image_b64:
logger.warning("Image generation returned empty response.")
await ctx.send("Failed to generate image. The server may be busy.")
return
try:
edited_image_data = BytesIO(base64.b64decode(image_b64))
send_img = discord.File(edited_image_data, filename="image.png")
await ctx.send(file=send_img)
except Exception:
logger.exception("Failed to decode image data")
await ctx.send("Failed to process the generated image.")
@bot.command(name="retcon")
@@ -529,6 +542,7 @@ async def retcon(ctx: CommandsContext[Bot], *, message: str) -> None:
prompt=message,
openai_url=IMAGE_EDIT_ENDPOINT,
openai_api_key=IMAGE_EDIT_ENDPOINT_KEY,
model=IMAGE_EDIT_MODEL,
)
# Save the image to a file
@@ -621,7 +635,7 @@ async def talkforme(ctx: CommandsContext[Bot], *, message: str) -> None:
bot_response = llama_wrapper.chat_completion_with_history(
system_prompt=(
current_bot[1] + f"\nKeep your responses under 2-3 sentences. "
f"{current_bot[flip_counter(bot_counter)]}"
f"You are talking to {current_bot[flip_counter(bot_counter)][0]}"
),
prompts=prompt_histories[bot_counter],
openai_url=CHAT_ENDPOINT,
@@ -709,7 +723,7 @@ async def handle_chat(
message_id=f"{ctx.message.id}_response",
user_id=str(ctx.bot.user.id),
username=ctx.bot.user.name,
content=f"Bot: {bot_response}",
content=bot_response,
channel_id=str(ctx.channel.id),
guild_id=str(ctx.guild.id) if ctx.guild else None,
)
+18 -13
View File
@@ -117,17 +117,22 @@ def mock_kokoro_tts() -> Generator[dict[str, Any]]:
mock_samples = np.array([0.1, 0.2, 0.3], dtype=np.float32)
mock_process = MagicMock(return_value=(mock_samples, 24000))
with patch("vibe_bot.tts.Kokoro", return_value=mock_kokoro_instance): # noqa: SIM117
with patch("vibe_bot.tts.chunk_text", mock_chunk):
with patch("vibe_bot.tts.process_chunk_sequential", mock_process):
yield {
"Kokoro": mock_kokoro,
"chunk_text": mock_chunk,
"process_chunk_sequential": mock_process,
"kokoro_instance": mock_kokoro_instance,
"mock_samples": mock_samples,
"mock_sr": 24000,
}
with (
patch(
"vibe_bot.tts.Kokoro",
return_value=mock_kokoro_instance,
),
patch("vibe_bot.tts.chunk_text", mock_chunk),
):
with patch("vibe_bot.tts.process_chunk_sequential", mock_process):
yield {
"Kokoro": mock_kokoro,
"chunk_text": mock_chunk,
"process_chunk_sequential": mock_process,
"kokoro_instance": mock_kokoro_instance,
"mock_samples": mock_samples,
"mock_sr": 24000,
}
@pytest.fixture
@@ -143,7 +148,7 @@ def mock_discord() -> Generator[dict[str, MagicMock]]:
mock_bot_instance.user.name = "test-bot"
mock_bot_instance.user.id = "123456789"
with patch("vibe_bot.main.discord") as mock_discord_module: # noqa: SIM117
with patch("vibe_bot.main.discord") as mock_discord_module:
with patch("vibe_bot.main.commands", MagicMock()):
with patch("vibe_bot.main.commands.Bot", mock_bot_class):
mock_bot_class.return_value = mock_bot_instance
@@ -162,7 +167,7 @@ def mock_tts_engine() -> Generator[MagicMock]:
"""Provide a mock TTSEngine."""
mock_engine = MagicMock()
mock_engine.generate_audio.return_value = MagicMock()
with patch("vibe_bot.main.tts_engine", mock_engine): # noqa: SIM117
with patch("vibe_bot.main.tts_engine", mock_engine):
with patch("vibe_bot.main.tts.TTSEngine", return_value=mock_engine):
yield mock_engine
+3 -3
View File
@@ -106,9 +106,9 @@ except Exception as e:
timeout=30,
)
output = result.stdout.strip()
assert output.startswith("ERROR:") and expected_error in output, ( # noqa: PT018
f"Expected error '{expected_error}' but got: {output}"
)
assert (
output.startswith("ERROR:") and expected_error in output
), f"Expected error '{expected_error}' but got: {output}"
def test_config_missing_discord_token() -> None:
+29 -7
View File
@@ -129,13 +129,22 @@ def test_get_recent_messages(
) -> None:
"""Test retrieving recent messages."""
chat_db.add_message(
message_id="msg-1", user_id="u1", username="alice", content="First",
message_id="msg-1",
user_id="u1",
username="alice",
content="First",
)
chat_db.add_message(
message_id="msg-2", user_id="u2", username="bob", content="Second",
message_id="msg-2",
user_id="u2",
username="bob",
content="Second",
)
chat_db.add_message(
message_id="msg-3", user_id="u1", username="alice", content="Third",
message_id="msg-3",
user_id="u1",
username="alice",
content="Third",
)
messages = chat_db.get_recent_messages(limit=2)
@@ -167,10 +176,16 @@ def test_clear_all_messages(
) -> None:
"""Test clearing all messages."""
chat_db.add_message(
message_id="msg-1", user_id="u1", username="alice", content="Hello",
message_id="msg-1",
user_id="u1",
username="alice",
content="Hello",
)
chat_db.add_message(
message_id="msg-2", user_id="u2", username="bob", content="World",
message_id="msg-2",
user_id="u2",
username="bob",
content="World",
)
chat_db.clear_all_messages()
@@ -185,7 +200,10 @@ def test_get_user_history(
) -> None:
"""Test retrieving user message history."""
chat_db.add_message(
message_id="msg-1", user_id="u1", username="alice", content="User question",
message_id="msg-1",
user_id="u1",
username="alice",
content="User question",
)
chat_db.add_message(
message_id="msg-1_response",
@@ -422,7 +440,9 @@ def test_custom_bot_delete_with_error(
) -> None:
"""Test that delete_custom_bot returns False on error."""
with patch.object(
custom_bot_manager, "_initialize_custom_bots_table", side_effect=Exception("db error"), # noqa: E501
custom_bot_manager,
"_initialize_custom_bots_table",
side_effect=Exception("db error"),
):
pass
result = custom_bot_manager.delete_custom_bot("nonexistent")
@@ -433,6 +453,7 @@ def test_database_get_database_singleton(temp_db_path: str) -> None:
"""Test that get_database returns the same instance."""
import vibe_bot.database as db_module
from vibe_bot.database import ChatDatabase, get_database
db_module._chat_db = None
db1 = get_database()
@@ -453,6 +474,7 @@ def test_database_init_creates_tables(temp_db_path: str) -> None:
db.client.close()
import sqlite3
conn = sqlite3.connect(temp_db_path)
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
+17 -16
View File
@@ -6,6 +6,7 @@ import base64
import tempfile
from io import BytesIO
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
import numpy as np
@@ -106,24 +107,24 @@ EMBEDDING_SIMILARITY_LOW = 0.5
def test_embeddings() -> None:
"""Test embedding similarity for similar and different texts."""
with patch("vibe_bot.llama_wrapper.openai.OpenAI") as mock_openai:
mock_horse_vec = [0.8] * 1024 + [0.6] * 1024
mock_horse_also_vec = [0.79] * 1024 + [0.61] * 1024
mock_donkey_vec = [-0.8] * 1024 + [-0.6] * 1024
mock_horse_vec = [0.8] * 1024 + [0.6] * 1024
mock_horse_also_vec = [0.79] * 1024 + [0.61] * 1024
mock_donkey_vec = [-0.8] * 1024 + [-0.6] * 1024
mock_response1 = MagicMock()
mock_response1.data = [MagicMock(embedding=mock_horse_vec)]
mock_response2 = MagicMock()
mock_response2.data = [MagicMock(embedding=mock_horse_also_vec)]
mock_response3 = MagicMock()
mock_response3.data = [MagicMock(embedding=mock_donkey_vec)]
mock_openai.return_value.embeddings.create.side_effect = [
mock_response1,
mock_response2,
mock_response3,
]
def mock_post(*args: Any, **kwargs: Any) -> MagicMock:
json_data = kwargs.get("json", {})
text = json_data["input"][0]
if "horse" in text and "donkey" not in text and "also" not in text:
embedding_data = mock_horse_vec
elif "also" in text:
embedding_data = mock_horse_also_vec
else:
embedding_data = mock_donkey_vec
mock_resp = MagicMock()
mock_resp.json.return_value = {"data": [{"embedding": embedding_data}]}
return mock_resp
with patch("vibe_bot.llama_wrapper.requests.post", side_effect=mock_post):
result1 = embedding(
"this is a horse",
openai_url=EMBEDDING_ENDPOINT,
+9 -3
View File
@@ -125,7 +125,9 @@ def test_custom_bot_command_success(
asyncio.run(
main_module.custom_bot(
mock_ctx, bot_name="alfred", personality="you are a british butler",
mock_ctx,
bot_name="alfred",
personality="you are a british butler",
),
)
@@ -199,7 +201,9 @@ def test_custom_bot_command_create_fails(
asyncio.run(
main_module.custom_bot(
mock_ctx, bot_name="alfred", personality="you are a british butler",
mock_ctx,
bot_name="alfred",
personality="you are a british butler",
),
)
call_args = mock_ctx.send.call_args[0][0]
@@ -347,7 +351,9 @@ def test_handle_chat_success(
import vibe_bot.main as main_module
mock_llama_wrapper.chat_completion_with_history.return_value = "This is a bot response" # noqa: E501
mock_llama_wrapper.chat_completion_with_history.return_value = (
"This is a bot response"
)
asyncio.run(
main_module.handle_chat(
+13 -3
View File
@@ -63,9 +63,15 @@ def test_generate_audio_multiple_chunks(mock_kokoro_tts: MagicMock) -> None:
from vibe_bot.tts import TTSEngine
mock_kokoro_tts["chunk_text"].return_value = ["chunk one", "chunk two", "chunk three"] # noqa: E501
mock_kokoro_tts["chunk_text"].return_value = [
"chunk one",
"chunk two",
"chunk three",
]
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")
result = engine.generate_audio("this text is long enough to be split into multiple chunks") # noqa: E501
result = engine.generate_audio(
"this text is long enough to be split into multiple chunks",
)
assert isinstance(result, BytesIO)
assert mock_kokoro_tts["process_chunk_sequential"].call_count == 3
@@ -88,7 +94,11 @@ def test_generate_audio_chunk_failure(mock_kokoro_tts: MagicMock) -> None:
raise Exception("processing error")
return np.array([0.1, 0.2], dtype=np.float32), 24000
mock_kokoro_tts["chunk_text"].return_value = ["good chunk", "bad chunk", "another good"] # noqa: E501
mock_kokoro_tts["chunk_text"].return_value = [
"good chunk",
"bad chunk",
"another good",
]
mock_kokoro_tts["process_chunk_sequential"].side_effect = process_with_failure
engine = TTSEngine("/tmp/test-model.onnx", "/tmp/test-voices.bin")