bots know who you are and where you live
This commit is contained in:
+35
-2
@@ -50,6 +50,38 @@ intents = discord.Intents.default()
|
|||||||
intents.message_content = True
|
intents.message_content = True
|
||||||
bot = commands.Bot(command_prefix="!", intents=intents)
|
bot = commands.Bot(command_prefix="!", intents=intents)
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_info(user: discord.User | discord.Member) -> str:
|
||||||
|
"""Format user information for inclusion in bot prompts."""
|
||||||
|
parts: list[str] = []
|
||||||
|
if user.global_name:
|
||||||
|
parts.append(f"Global Name: {user.global_name}")
|
||||||
|
nick = getattr(user, "nick", None)
|
||||||
|
if nick:
|
||||||
|
parts.append(f"Nickname: {nick}")
|
||||||
|
top_role = getattr(user, "top_role", None)
|
||||||
|
if top_role and top_role.name != "@everyone":
|
||||||
|
parts.append(f"Top Role: {top_role.name}")
|
||||||
|
activities = getattr(user, "activities", None)
|
||||||
|
if activities:
|
||||||
|
activity_names = [
|
||||||
|
getattr(a, "name", str(a))
|
||||||
|
for a in activities
|
||||||
|
if getattr(a, "name", "") != "custom_status"
|
||||||
|
]
|
||||||
|
if activity_names:
|
||||||
|
parts.append(f"Activities: {', '.join(activity_names)}")
|
||||||
|
joined_at = getattr(user, "joined_at", None)
|
||||||
|
if joined_at:
|
||||||
|
parts.append(f"Joined: {joined_at.strftime('%Y-%m-%d')}")
|
||||||
|
parts.append(f"Username: {user.name}")
|
||||||
|
parts.append(f"User ID: {user.id}")
|
||||||
|
parts.append(
|
||||||
|
f"Account Created: {user.created_at.strftime('%Y-%m-%d') if user.created_at else 'Unknown'}"
|
||||||
|
)
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
# Initialize TTS engine
|
# Initialize TTS engine
|
||||||
tts_engine: tts.TTSEngine | None = None
|
tts_engine: tts.TTSEngine | None = None
|
||||||
try:
|
try:
|
||||||
@@ -441,7 +473,7 @@ async def _speak_with_bot(
|
|||||||
return
|
return
|
||||||
|
|
||||||
_, system_prompt, _, _ = bot_info
|
_, system_prompt, _, _ = bot_info
|
||||||
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences."
|
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences.\n\nUser Information:\n{get_user_info(ctx.author)}"
|
||||||
|
|
||||||
# Determine language for the chosen voice
|
# Determine language for the chosen voice
|
||||||
chosen_voice = voice or TTS_VOICE
|
chosen_voice = voice or TTS_VOICE
|
||||||
@@ -497,6 +529,7 @@ async def _speak_with_bot(
|
|||||||
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
guild_id=str(ctx.guild.id) if ctx.guild else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await ctx.send(f"**{bot_name}**: {bot_response}")
|
||||||
await ctx.send(f"Generating speech for **{bot_name}**...")
|
await ctx.send(f"Generating speech for **{bot_name}**...")
|
||||||
audio_buffer = engine.generate_audio(
|
audio_buffer = engine.generate_audio(
|
||||||
bot_response,
|
bot_response,
|
||||||
@@ -834,7 +867,7 @@ async def handle_chat(
|
|||||||
|
|
||||||
logger.info("Chat prompts: %s", prompts)
|
logger.info("Chat prompts: %s", prompts)
|
||||||
|
|
||||||
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences."
|
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences.\n\nUser Information:\n{get_user_info(ctx.author)}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
bot_response = llama_wrapper.chat_completion_with_history(
|
bot_response = llama_wrapper.chat_completion_with_history(
|
||||||
|
|||||||
+175
-6
@@ -13,6 +13,39 @@ def mock_ctx() -> MagicMock:
|
|||||||
ctx = MagicMock()
|
ctx = MagicMock()
|
||||||
ctx.author.name = "testuser"
|
ctx.author.name = "testuser"
|
||||||
ctx.author.id = "12345"
|
ctx.author.id = "12345"
|
||||||
|
ctx.author.global_name = "Test User"
|
||||||
|
ctx.author.nick = "tester"
|
||||||
|
ctx.author.top_role.name = "@everyone"
|
||||||
|
ctx.author.activities = []
|
||||||
|
ctx.author.joined_at = None
|
||||||
|
ctx.author.created_at = None
|
||||||
|
ctx.channel.id = "channel-1"
|
||||||
|
ctx.guild.id = "guild-1"
|
||||||
|
ctx.message.id = "msg-1"
|
||||||
|
ctx.message.attachments = []
|
||||||
|
ctx.bot.user = MagicMock()
|
||||||
|
ctx.bot.user.name = "test-bot"
|
||||||
|
ctx.bot.user.id = "bot-123"
|
||||||
|
ctx.send = AsyncMock()
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_ctx_with_member() -> MagicMock:
|
||||||
|
"""Create a mock Discord command context with full member data."""
|
||||||
|
ctx = MagicMock()
|
||||||
|
ctx.author.name = "testuser"
|
||||||
|
ctx.author.id = "12345"
|
||||||
|
ctx.author.global_name = "Test User"
|
||||||
|
ctx.author.nick = "tester"
|
||||||
|
ctx.author.top_role.name = "Admin"
|
||||||
|
mock_activity = MagicMock()
|
||||||
|
mock_activity.name = "Chess"
|
||||||
|
ctx.author.activities = [mock_activity]
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
ctx.author.joined_at = datetime(2024, 1, 15)
|
||||||
|
ctx.author.created_at = datetime(2023, 6, 1)
|
||||||
ctx.channel.id = "channel-1"
|
ctx.channel.id = "channel-1"
|
||||||
ctx.guild.id = "guild-1"
|
ctx.guild.id = "guild-1"
|
||||||
ctx.message.id = "msg-1"
|
ctx.message.id = "msg-1"
|
||||||
@@ -112,6 +145,9 @@ def test_speak_with_custom_bot(
|
|||||||
|
|
||||||
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
|
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
|
||||||
mock_tts_engine.generate_audio.assert_called_once()
|
mock_tts_engine.generate_audio.assert_called_once()
|
||||||
|
assert mock_ctx.send.call_count >= 3
|
||||||
|
text_response = mock_ctx.send.call_args_list[1][0][0]
|
||||||
|
assert "**alfred**:" in text_response or "**alfred** :" in text_response
|
||||||
|
|
||||||
|
|
||||||
def test_custom_bot_command_success(
|
def test_custom_bot_command_success(
|
||||||
@@ -435,10 +471,13 @@ def test_speak_plain_with_mock_tts(
|
|||||||
|
|
||||||
from vibe_bot.config import TTS_SPEED, TTS_VOICE
|
from vibe_bot.config import TTS_SPEED, TTS_VOICE
|
||||||
|
|
||||||
|
from vibe_bot.tts import DEFAULT_LANG
|
||||||
|
|
||||||
mock_tts_engine.generate_audio.assert_called_once_with(
|
mock_tts_engine.generate_audio.assert_called_once_with(
|
||||||
"hello world",
|
"hello world",
|
||||||
voice=TTS_VOICE,
|
voice=TTS_VOICE,
|
||||||
speed=TTS_SPEED,
|
speed=TTS_SPEED,
|
||||||
|
lang=DEFAULT_LANG,
|
||||||
)
|
)
|
||||||
assert mock_ctx.send.call_count >= 2
|
assert mock_ctx.send.call_count >= 2
|
||||||
|
|
||||||
@@ -602,12 +641,142 @@ def test_history_with_data(
|
|||||||
asyncio.run(main_module.history(mock_ctx, bot_name="alfred"))
|
asyncio.run(main_module.history(mock_ctx, bot_name="alfred"))
|
||||||
|
|
||||||
assert mock_ctx.send.call_count >= 1
|
assert mock_ctx.send.call_count >= 1
|
||||||
first_call = mock_ctx.send.call_args_list[0][0][0]
|
|
||||||
assert "Chat History for **alfred**" in first_call
|
|
||||||
assert "hello" in first_call
|
def test_get_user_info_minimal(mock_ctx: MagicMock) -> None:
|
||||||
assert "alfred: yes master?" in first_call
|
"""Test get_user_info with minimal member data."""
|
||||||
assert "what time is it" in first_call
|
import vibe_bot.main as main_module
|
||||||
assert "alfred: it is currently 3pm" in first_call
|
|
||||||
|
result = main_module.get_user_info(mock_ctx.author)
|
||||||
|
|
||||||
|
assert "Username: testuser" in result
|
||||||
|
assert "User ID: 12345" in result
|
||||||
|
assert "Global Name: Test User" in result
|
||||||
|
assert "Nickname: tester" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_info_with_member_data(mock_ctx_with_member: MagicMock) -> None:
|
||||||
|
"""Test get_user_info with full member data including roles and activities."""
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
result = main_module.get_user_info(mock_ctx_with_member.author)
|
||||||
|
|
||||||
|
assert "Global Name: Test User" in result
|
||||||
|
assert "Nickname: tester" in result
|
||||||
|
assert "Username: testuser" in result
|
||||||
|
assert "User ID: 12345" in result
|
||||||
|
assert "Top Role: Admin" in result
|
||||||
|
assert "Activities: Chess" in result
|
||||||
|
assert "Joined: 2024-01-15" in result
|
||||||
|
assert "Account Created: 2023-06-01" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_info_no_global_name(mock_ctx: MagicMock) -> None:
|
||||||
|
"""Test get_user_info when user has no global name."""
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_ctx.author.global_name = None
|
||||||
|
mock_ctx.author.nick = None
|
||||||
|
mock_ctx.author.top_role.name = "@everyone"
|
||||||
|
mock_ctx.author.activities = []
|
||||||
|
|
||||||
|
result = main_module.get_user_info(mock_ctx.author)
|
||||||
|
|
||||||
|
assert "Global Name:" not in result
|
||||||
|
assert "Nickname:" not in result
|
||||||
|
assert "Top Role:" not in result
|
||||||
|
assert "Activities:" not in result
|
||||||
|
assert "Username: testuser" in result
|
||||||
|
assert "User ID: 12345" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_info_with_top_role_not_everyone(
|
||||||
|
mock_ctx_with_member: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test get_user_info includes top role when not @everyone."""
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
result = main_module.get_user_info(mock_ctx_with_member.author)
|
||||||
|
|
||||||
|
assert "Top Role: Admin" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_user_info_no_activities(mock_ctx: MagicMock) -> None:
|
||||||
|
"""Test get_user_info when user has no activities."""
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_ctx.author.activities = []
|
||||||
|
|
||||||
|
result = main_module.get_user_info(mock_ctx.author)
|
||||||
|
|
||||||
|
assert "Activities:" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_chat_includes_user_info(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_database: MagicMock,
|
||||||
|
mock_llama_wrapper: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test handle_chat includes user info in system prompt."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_llama_wrapper.chat_completion_with_history.return_value = (
|
||||||
|
"This is a bot response"
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.run(
|
||||||
|
main_module.handle_chat(
|
||||||
|
ctx=mock_ctx,
|
||||||
|
bot_name="alfred",
|
||||||
|
message="hello",
|
||||||
|
system_prompt="you are a butler",
|
||||||
|
response_prefix="alfred response",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
|
||||||
|
call_kwargs = mock_llama_wrapper.chat_completion_with_history.call_args
|
||||||
|
system_prompt = call_kwargs.kwargs["system_prompt"]
|
||||||
|
assert "you are a butler" in system_prompt
|
||||||
|
assert "User Information:" in system_prompt
|
||||||
|
assert "Username: testuser" in system_prompt
|
||||||
|
assert "User ID: 12345" in system_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_speak_with_bot_includes_user_info(
|
||||||
|
mock_ctx: MagicMock,
|
||||||
|
mock_tts_engine: MagicMock,
|
||||||
|
mock_custom_bot_manager: MagicMock,
|
||||||
|
mock_database: MagicMock,
|
||||||
|
mock_llama_wrapper: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test _speak_with_bot includes user info in system prompt."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import vibe_bot.main as main_module
|
||||||
|
|
||||||
|
mock_custom_bot_manager.list_custom_bots.return_value = [
|
||||||
|
("alfred", "british butler", "user-123"),
|
||||||
|
]
|
||||||
|
mock_custom_bot_manager.get_custom_bot.return_value = (
|
||||||
|
"alfred",
|
||||||
|
"british butler",
|
||||||
|
"user-123",
|
||||||
|
"2024-01-01",
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.run(main_module.speak(mock_ctx, message="alfred what time is it"))
|
||||||
|
|
||||||
|
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
|
||||||
|
call_kwargs = mock_llama_wrapper.chat_completion_with_history.call_args
|
||||||
|
system_prompt = call_kwargs.kwargs["system_prompt"]
|
||||||
|
assert "british butler" in system_prompt
|
||||||
|
assert "User Information:" in system_prompt
|
||||||
|
assert "Username: testuser" in system_prompt
|
||||||
|
assert "User ID: 12345" in system_prompt
|
||||||
|
mock_tts_engine.generate_audio.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_history_long_response_chunked(
|
def test_history_long_response_chunked(
|
||||||
|
|||||||
Reference in New Issue
Block a user