add tool calling

This commit is contained in:
2026-05-24 15:26:13 -04:00
parent 879cd5cbe8
commit 7a1ba05068
8 changed files with 838 additions and 17 deletions
+2 -1
View File
@@ -7,7 +7,7 @@ import warnings
from collections.abc import Generator
from pathlib import Path
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
@@ -195,6 +195,7 @@ def mock_llama_wrapper() -> Generator[MagicMock]:
"""Provide mock llama_wrapper module."""
with patch("vibe_bot.main.llama_wrapper") as mock_wrapper:
mock_wrapper.chat_completion_with_history.return_value = "Bot response"
mock_wrapper.chat_completion_with_tools = AsyncMock(return_value="Bot response")
mock_wrapper.chat_completion_instruct.return_value = "image prompt"
mock_wrapper.image_generation.return_value = ""
mock_wrapper.image_edit.return_value = ""
+10 -10
View File
@@ -143,7 +143,7 @@ def test_speak_with_custom_bot(
asyncio.run(main_module.speak(mock_ctx, message="alfred what time is it"))
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
mock_llama_wrapper.chat_completion_with_tools.assert_called_once()
mock_tts_engine.generate_audio.assert_called_once()
assert mock_ctx.send.call_count >= 3
text_response = mock_ctx.send.call_args_list[1][0][0]
@@ -387,7 +387,7 @@ def test_handle_chat_success(
import vibe_bot.main as main_module
mock_llama_wrapper.chat_completion_with_history.return_value = (
mock_llama_wrapper.chat_completion_with_tools.return_value = (
"This is a bot response"
)
@@ -401,7 +401,7 @@ def test_handle_chat_success(
),
)
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
mock_llama_wrapper.chat_completion_with_tools.assert_called_once()
mock_database.add_message.assert_called()
assert mock_ctx.send.call_count >= 2
@@ -416,7 +416,7 @@ def test_handle_chat_error(
import vibe_bot.main as main_module
mock_llama_wrapper.chat_completion_with_history.side_effect = Exception("API error")
mock_llama_wrapper.chat_completion_with_tools.side_effect = Exception("API error")
asyncio.run(
main_module.handle_chat(
@@ -443,7 +443,7 @@ def test_handle_chat_long_response_chunked(
import vibe_bot.main as main_module
long_response = "x" * 2500
mock_llama_wrapper.chat_completion_with_history.return_value = long_response
mock_llama_wrapper.chat_completion_with_tools.return_value = long_response
asyncio.run(
main_module.handle_chat(
@@ -722,7 +722,7 @@ def test_handle_chat_includes_user_info(
import vibe_bot.main as main_module
mock_llama_wrapper.chat_completion_with_history.return_value = (
mock_llama_wrapper.chat_completion_with_tools.return_value = (
"This is a bot response"
)
@@ -736,8 +736,8 @@ def test_handle_chat_includes_user_info(
),
)
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
call_kwargs = mock_llama_wrapper.chat_completion_with_history.call_args
mock_llama_wrapper.chat_completion_with_tools.assert_called_once()
call_kwargs = mock_llama_wrapper.chat_completion_with_tools.call_args
system_prompt = call_kwargs.kwargs["system_prompt"]
assert "you are a butler" in system_prompt
assert "User Information:" in system_prompt
@@ -769,8 +769,8 @@ def test_speak_with_bot_includes_user_info(
asyncio.run(main_module.speak(mock_ctx, message="alfred what time is it"))
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
call_kwargs = mock_llama_wrapper.chat_completion_with_history.call_args
mock_llama_wrapper.chat_completion_with_tools.assert_called_once()
call_kwargs = mock_llama_wrapper.chat_completion_with_tools.call_args
system_prompt = call_kwargs.kwargs["system_prompt"]
assert "british butler" in system_prompt
assert "User Information:" in system_prompt
+181
View File
@@ -0,0 +1,181 @@
"""Tests for the tools module."""
from __future__ import annotations
from unittest.mock import MagicMock
from vibe_bot.tools import get_channel_members, get_channel_members_impl
def test_get_channel_members_impl_returns_formatted_list() -> None:
"""Test get_channel_members_impl returns a formatted member list."""
mock_member1 = MagicMock()
mock_member1.display_name = "Alice"
mock_member1.name = "alice"
mock_member1.nick = None
mock_member1.global_name = None
mock_member1.status = MagicMock(value="online")
mock_member2 = MagicMock()
mock_member2.display_name = "Bob"
mock_member2.name = "bob"
mock_member2.nick = "bobby"
mock_member2.global_name = None
mock_member2.status = MagicMock(value="idle")
mock_channel = MagicMock()
mock_channel.members = [mock_member1, mock_member2]
result = get_channel_members_impl(mock_channel)
assert "Members in this channel (2 total):" in result
assert "Alice" in result
assert "Bob (nickname: bobby)" in result
assert "[online]" in result
assert "[idle]" in result
def test_get_channel_members_impl_empty() -> None:
"""Test get_channel_members_impl with no members."""
mock_channel = MagicMock()
mock_channel.members = []
result = get_channel_members_impl(mock_channel)
assert "No members found in this channel." in result
def test_get_channel_members_impl_with_global_name() -> None:
"""Test get_channel_members_impl includes global name."""
mock_member = MagicMock()
mock_member.display_name = "charlie"
mock_member.name = "charlie"
mock_member.nick = None
mock_member.global_name = "Charlie Global"
mock_member.status = MagicMock(value="dnd")
mock_channel = MagicMock()
mock_channel.members = [mock_member]
result = get_channel_members_impl(mock_channel)
assert "(global name: Charlie Global)" in result
def test_get_channel_members_impl_no_status() -> None:
"""Test get_channel_members_impl when member has no status."""
mock_member = MagicMock()
mock_member.display_name = "dave"
mock_member.name = "dave"
mock_member.nick = None
mock_member.global_name = None
mock_member.status = None
mock_channel = MagicMock()
mock_channel.members = [mock_member]
result = get_channel_members_impl(mock_channel)
assert "dave" in result
assert "[]" not in result
def test_get_channel_members_impl_exception() -> None:
"""Test get_channel_members_impl handles exceptions gracefully."""
mock_channel = MagicMock()
mock_channel.members = None
type(mock_channel).members = property(
lambda self: (_ for _ in ()).throw(Exception("test"))
)
result = get_channel_members_impl(mock_channel)
assert "Failed to retrieve channel members." in result
def test_get_channel_members_tool_schema() -> None:
"""Test that get_channel_members has a valid tool schema."""
assert get_channel_members.name == "get_channel_members"
assert get_channel_members.description is not None
schema = get_channel_members.args_schema.model_json_schema() # type: ignore
assert "properties" in schema
assert schema["properties"] == {}
def test_get_channel_members_impl_sorted_by_display_name() -> None:
"""Test that members are sorted by display name."""
mock_member_z = MagicMock()
mock_member_z.display_name = "Zara"
mock_member_z.name = "zara"
mock_member_z.nick = None
mock_member_z.global_name = None
mock_member_z.status = MagicMock(value="online")
mock_member_a = MagicMock()
mock_member_a.display_name = "Aaron"
mock_member_a.name = "aaron"
mock_member_a.nick = None
mock_member_a.global_name = None
mock_member_a.status = MagicMock(value="online")
mock_channel = MagicMock()
mock_channel.members = [mock_member_z, mock_member_a]
result = get_channel_members_impl(mock_channel)
lines = [l for l in result.split("\n") if l.strip().startswith("- ")]
assert "Aaron" in lines[0]
assert "Zara" in lines[1]
def test_get_channel_members_impl_no_nick_when_same_as_display() -> None:
"""Test that nickname is not shown when same as display name."""
mock_member = MagicMock()
mock_member.display_name = "same"
mock_member.name = "same"
mock_member.nick = "same"
mock_member.global_name = None
mock_member.status = MagicMock(value="online")
mock_channel = MagicMock()
mock_channel.members = [mock_member]
result = get_channel_members_impl(mock_channel)
assert "(nickname: same)" not in result
assert "same" in result
def test_format_member_minimal() -> None:
"""Test _format_member with minimal data."""
from vibe_bot.tools import _format_member
mock_member = MagicMock()
mock_member.display_name = None
mock_member.name = "unknown_user"
mock_member.nick = None
mock_member.global_name = None
mock_member.status = None
result = _format_member(mock_member)
assert "unknown_user" in result
def test_format_member_with_all_fields() -> None:
"""Test _format_member with all fields populated."""
from vibe_bot.tools import _format_member
mock_member = MagicMock()
mock_member.display_name = "Alice"
mock_member.name = "alice"
mock_member.nick = "Al"
mock_member.global_name = "Alice Global"
mock_member.status = MagicMock(value="online")
result = _format_member(mock_member)
assert "Alice" in result
assert "(nickname: Al)" in result
assert "(global name: Alice Global)" in result
assert "[online]" in result