add tool calling

This commit is contained in:
2026-05-24 15:26:13 -04:00
parent 879cd5cbe8
commit 7a1ba05068
8 changed files with 838 additions and 17 deletions
+119 -1
View File
@@ -7,7 +7,7 @@ Allows custom endpoints for each of the above supported functions.
from __future__ import annotations
import json
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Awaitable, Callable, cast
import openai
import requests
@@ -170,6 +170,124 @@ def chat_completion_instruct(
return ""
async def chat_completion_with_tools(
system_prompt: str,
prompts: list[dict[str, str]],
tools: list[dict[str, object]],
tool_executor: Callable[[str, dict[str, str]], str],
*,
openai_url: str,
openai_api_key: str,
model: str,
max_tokens: int = 1000,
max_tool_rounds: int = 5,
tool_call_notifier: (
Callable[[str, dict[str, str]], None]
| Callable[[str, dict[str, str]], Awaitable[None]]
| None
) = None,
) -> str:
"""Send a chat completion request with tool support and iterative tool calling.
Args:
system_prompt: The system prompt to use.
prompts: List of prompt dicts with role and content.
tools: List of tool definitions in OpenAI format.
tool_executor: A callable that takes (tool_name: str, tool_args: dict) -> str.
openai_url: The OpenAI-compatible API URL.
openai_api_key: The API key for authentication.
model: The model to use for completion.
max_tokens: Maximum number of tokens to generate.
max_tool_rounds: Maximum number of tool call rounds before giving up.
tool_call_notifier: Optional callback invoked before each tool call
with (tool_name, tool_args).
Returns:
The model's final response text, stripped of whitespace.
"""
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
messages: list[ChatCompletionMessageParam] = [
cast(
"ChatCompletionMessageParam",
{
"role": "system",
"content": system_prompt,
},
),
]
messages.extend(cast("list[ChatCompletionMessageParam]", prompts))
for _round in range(max_tool_rounds):
response = client.chat.completions.create(
model=model,
messages=messages,
tools=tools, # type: ignore[arg-type]
max_tokens=max_tokens,
seed=-1,
timeout=60.0,
)
if not response.choices:
return ""
message = response.choices[0].message
# Check if the model wants to call a tool
tool_calls = message.tool_calls
if tool_calls:
assistant_msg: dict[str, object] = {
"role": "assistant",
"content": message.content or "",
}
tool_call_dicts: list[dict[str, object]] = []
for tool_call in tool_calls:
tool_call_dicts.append(
{
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.function.name, # type: ignore[union-attr]
"arguments": tool_call.function.arguments, # type: ignore[union-attr]
},
},
)
assistant_msg["tool_calls"] = tool_call_dicts
messages.append(cast("ChatCompletionMessageParam", assistant_msg))
# Execute each tool call and add results to messages
for tool_call in tool_calls:
tool_name = tool_call.function.name # type: ignore[union-attr]
tool_args = json.loads(tool_call.function.arguments) # type: ignore[union-attr]
if tool_call_notifier:
result = tool_call_notifier(tool_name, tool_args)
if hasattr(result, "__await__"):
await result # type: ignore[misc]
tool_result = tool_executor(tool_name, tool_args)
messages.append(
cast(
"ChatCompletionMessageParam",
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": tool_result,
},
),
)
continue
# No more tool calls, return the final response
content = message.content
if content:
return content.strip()
return ""
return ""
def image_generation(
prompt: str,
*,
+55 -2
View File
@@ -33,6 +33,7 @@ from vibe_bot.config import (
VOICES_LIST,
)
from vibe_bot.database import CustomBotManager, get_database
from vibe_bot.tools import get_channel_members, get_channel_members_impl
if TYPE_CHECKING:
from discord.ext.commands import Bot
@@ -495,9 +496,35 @@ async def _speak_with_bot(
if context:
prompts = context + prompts
bot_response = llama_wrapper.chat_completion_with_history(
# Build tool definitions from LangChain tools
speak_tools: list[dict[str, object]] = [
{
"type": "function",
"function": {
"name": get_channel_members.name,
"description": get_channel_members.description,
"parameters": get_channel_members.args_schema.model_json_schema(), # type: ignore
},
},
]
def speak_tool_executor(tool_name: str, tool_args: dict[str, str]) -> str:
"""Execute a tool by name with the given arguments."""
if tool_name == "get_channel_members":
return get_channel_members_impl(ctx.channel)
return f"Unknown tool: {tool_name}"
async def speak_tool_call_notifier(tool_name: str, tool_args: dict[str, str]) -> None:
"""Send a notification message when a tool is called."""
if tool_name == "get_channel_members":
await ctx.send(f"**{bot_name}** is looking at the channel members...")
bot_response = await llama_wrapper.chat_completion_with_tools(
system_prompt=system_prompt_edit,
prompts=prompts,
tools=speak_tools,
tool_executor=speak_tool_executor,
tool_call_notifier=speak_tool_call_notifier,
openai_url=CHAT_ENDPOINT,
openai_api_key=CHAT_ENDPOINT_KEY,
model=CHAT_MODEL,
@@ -869,10 +896,36 @@ async def handle_chat(
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences.\n\nUser Information:\n{get_user_info(ctx.author)}"
# Build tool definitions from LangChain tools
tools: list[dict[str, object]] = [
{
"type": "function",
"function": {
"name": get_channel_members.name,
"description": get_channel_members.description,
"parameters": get_channel_members.args_schema.model_json_schema(), # type: ignore
},
},
]
def tool_executor(tool_name: str, tool_args: dict[str, str]) -> str:
"""Execute a tool by name with the given arguments."""
if tool_name == "get_channel_members":
return get_channel_members_impl(ctx.channel)
return f"Unknown tool: {tool_name}"
async def tool_call_notifier(tool_name: str, tool_args: dict[str, str]) -> None:
"""Send a notification message when a tool is called."""
if tool_name == "get_channel_members":
await ctx.send(f"{bot_name} is looking at the channel members...")
try:
bot_response = llama_wrapper.chat_completion_with_history(
bot_response = await llama_wrapper.chat_completion_with_tools(
system_prompt=system_prompt_edit,
prompts=prompts,
tools=tools,
tool_executor=tool_executor,
tool_call_notifier=tool_call_notifier,
openai_url=CHAT_ENDPOINT,
openai_api_key=CHAT_ENDPOINT_KEY,
model=CHAT_MODEL,
+2 -1
View File
@@ -7,7 +7,7 @@ import warnings
from collections.abc import Generator
from pathlib import Path
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
@@ -195,6 +195,7 @@ def mock_llama_wrapper() -> Generator[MagicMock]:
"""Provide mock llama_wrapper module."""
with patch("vibe_bot.main.llama_wrapper") as mock_wrapper:
mock_wrapper.chat_completion_with_history.return_value = "Bot response"
mock_wrapper.chat_completion_with_tools = AsyncMock(return_value="Bot response")
mock_wrapper.chat_completion_instruct.return_value = "image prompt"
mock_wrapper.image_generation.return_value = ""
mock_wrapper.image_edit.return_value = ""
+10 -10
View File
@@ -143,7 +143,7 @@ def test_speak_with_custom_bot(
asyncio.run(main_module.speak(mock_ctx, message="alfred what time is it"))
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
mock_llama_wrapper.chat_completion_with_tools.assert_called_once()
mock_tts_engine.generate_audio.assert_called_once()
assert mock_ctx.send.call_count >= 3
text_response = mock_ctx.send.call_args_list[1][0][0]
@@ -387,7 +387,7 @@ def test_handle_chat_success(
import vibe_bot.main as main_module
mock_llama_wrapper.chat_completion_with_history.return_value = (
mock_llama_wrapper.chat_completion_with_tools.return_value = (
"This is a bot response"
)
@@ -401,7 +401,7 @@ def test_handle_chat_success(
),
)
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
mock_llama_wrapper.chat_completion_with_tools.assert_called_once()
mock_database.add_message.assert_called()
assert mock_ctx.send.call_count >= 2
@@ -416,7 +416,7 @@ def test_handle_chat_error(
import vibe_bot.main as main_module
mock_llama_wrapper.chat_completion_with_history.side_effect = Exception("API error")
mock_llama_wrapper.chat_completion_with_tools.side_effect = Exception("API error")
asyncio.run(
main_module.handle_chat(
@@ -443,7 +443,7 @@ def test_handle_chat_long_response_chunked(
import vibe_bot.main as main_module
long_response = "x" * 2500
mock_llama_wrapper.chat_completion_with_history.return_value = long_response
mock_llama_wrapper.chat_completion_with_tools.return_value = long_response
asyncio.run(
main_module.handle_chat(
@@ -722,7 +722,7 @@ def test_handle_chat_includes_user_info(
import vibe_bot.main as main_module
mock_llama_wrapper.chat_completion_with_history.return_value = (
mock_llama_wrapper.chat_completion_with_tools.return_value = (
"This is a bot response"
)
@@ -736,8 +736,8 @@ def test_handle_chat_includes_user_info(
),
)
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
call_kwargs = mock_llama_wrapper.chat_completion_with_history.call_args
mock_llama_wrapper.chat_completion_with_tools.assert_called_once()
call_kwargs = mock_llama_wrapper.chat_completion_with_tools.call_args
system_prompt = call_kwargs.kwargs["system_prompt"]
assert "you are a butler" in system_prompt
assert "User Information:" in system_prompt
@@ -769,8 +769,8 @@ def test_speak_with_bot_includes_user_info(
asyncio.run(main_module.speak(mock_ctx, message="alfred what time is it"))
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
call_kwargs = mock_llama_wrapper.chat_completion_with_history.call_args
mock_llama_wrapper.chat_completion_with_tools.assert_called_once()
call_kwargs = mock_llama_wrapper.chat_completion_with_tools.call_args
system_prompt = call_kwargs.kwargs["system_prompt"]
assert "british butler" in system_prompt
assert "User Information:" in system_prompt
+181
View File
@@ -0,0 +1,181 @@
"""Tests for the tools module."""
from __future__ import annotations
from unittest.mock import MagicMock
from vibe_bot.tools import get_channel_members, get_channel_members_impl
def test_get_channel_members_impl_returns_formatted_list() -> None:
"""Test get_channel_members_impl returns a formatted member list."""
mock_member1 = MagicMock()
mock_member1.display_name = "Alice"
mock_member1.name = "alice"
mock_member1.nick = None
mock_member1.global_name = None
mock_member1.status = MagicMock(value="online")
mock_member2 = MagicMock()
mock_member2.display_name = "Bob"
mock_member2.name = "bob"
mock_member2.nick = "bobby"
mock_member2.global_name = None
mock_member2.status = MagicMock(value="idle")
mock_channel = MagicMock()
mock_channel.members = [mock_member1, mock_member2]
result = get_channel_members_impl(mock_channel)
assert "Members in this channel (2 total):" in result
assert "Alice" in result
assert "Bob (nickname: bobby)" in result
assert "[online]" in result
assert "[idle]" in result
def test_get_channel_members_impl_empty() -> None:
"""Test get_channel_members_impl with no members."""
mock_channel = MagicMock()
mock_channel.members = []
result = get_channel_members_impl(mock_channel)
assert "No members found in this channel." in result
def test_get_channel_members_impl_with_global_name() -> None:
"""Test get_channel_members_impl includes global name."""
mock_member = MagicMock()
mock_member.display_name = "charlie"
mock_member.name = "charlie"
mock_member.nick = None
mock_member.global_name = "Charlie Global"
mock_member.status = MagicMock(value="dnd")
mock_channel = MagicMock()
mock_channel.members = [mock_member]
result = get_channel_members_impl(mock_channel)
assert "(global name: Charlie Global)" in result
def test_get_channel_members_impl_no_status() -> None:
"""Test get_channel_members_impl when member has no status."""
mock_member = MagicMock()
mock_member.display_name = "dave"
mock_member.name = "dave"
mock_member.nick = None
mock_member.global_name = None
mock_member.status = None
mock_channel = MagicMock()
mock_channel.members = [mock_member]
result = get_channel_members_impl(mock_channel)
assert "dave" in result
assert "[]" not in result
def test_get_channel_members_impl_exception() -> None:
"""Test get_channel_members_impl handles exceptions gracefully."""
mock_channel = MagicMock()
mock_channel.members = None
type(mock_channel).members = property(
lambda self: (_ for _ in ()).throw(Exception("test"))
)
result = get_channel_members_impl(mock_channel)
assert "Failed to retrieve channel members." in result
def test_get_channel_members_tool_schema() -> None:
"""Test that get_channel_members has a valid tool schema."""
assert get_channel_members.name == "get_channel_members"
assert get_channel_members.description is not None
schema = get_channel_members.args_schema.model_json_schema() # type: ignore
assert "properties" in schema
assert schema["properties"] == {}
def test_get_channel_members_impl_sorted_by_display_name() -> None:
"""Test that members are sorted by display name."""
mock_member_z = MagicMock()
mock_member_z.display_name = "Zara"
mock_member_z.name = "zara"
mock_member_z.nick = None
mock_member_z.global_name = None
mock_member_z.status = MagicMock(value="online")
mock_member_a = MagicMock()
mock_member_a.display_name = "Aaron"
mock_member_a.name = "aaron"
mock_member_a.nick = None
mock_member_a.global_name = None
mock_member_a.status = MagicMock(value="online")
mock_channel = MagicMock()
mock_channel.members = [mock_member_z, mock_member_a]
result = get_channel_members_impl(mock_channel)
lines = [l for l in result.split("\n") if l.strip().startswith("- ")]
assert "Aaron" in lines[0]
assert "Zara" in lines[1]
def test_get_channel_members_impl_no_nick_when_same_as_display() -> None:
"""Test that nickname is not shown when same as display name."""
mock_member = MagicMock()
mock_member.display_name = "same"
mock_member.name = "same"
mock_member.nick = "same"
mock_member.global_name = None
mock_member.status = MagicMock(value="online")
mock_channel = MagicMock()
mock_channel.members = [mock_member]
result = get_channel_members_impl(mock_channel)
assert "(nickname: same)" not in result
assert "same" in result
def test_format_member_minimal() -> None:
"""Test _format_member with minimal data."""
from vibe_bot.tools import _format_member
mock_member = MagicMock()
mock_member.display_name = None
mock_member.name = "unknown_user"
mock_member.nick = None
mock_member.global_name = None
mock_member.status = None
result = _format_member(mock_member)
assert "unknown_user" in result
def test_format_member_with_all_fields() -> None:
"""Test _format_member with all fields populated."""
from vibe_bot.tools import _format_member
mock_member = MagicMock()
mock_member.display_name = "Alice"
mock_member.name = "alice"
mock_member.nick = "Al"
mock_member.global_name = "Alice Global"
mock_member.status = MagicMock(value="online")
result = _format_member(mock_member)
assert "Alice" in result
assert "(nickname: Al)" in result
assert "(global name: Alice Global)" in result
assert "[online]" in result
+78
View File
@@ -0,0 +1,78 @@
"""LangChain tools for the Discord bot."""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.tools import tool
logger = logging.getLogger(__name__)
def _format_member(member: Any) -> str:
"""Format a single member for display."""
raw_display = getattr(member, "display_name", None)
raw_name = getattr(member, "name", "Unknown")
display_name = str(raw_display) if raw_display else str(raw_name)
parts: list[str] = [display_name]
nick = getattr(member, "nick", None)
if nick and nick != display_name:
parts.append(f"(nickname: {nick})")
global_name = getattr(member, "global_name", None)
if global_name and global_name != getattr(member, "name", ""):
parts.append(f"(global name: {global_name})")
status = getattr(member, "status", None)
if status:
parts.append(f"[{status.value}]")
return " ".join(parts)
def get_channel_members_impl(channel: Any) -> str:
"""Get a list of all members in the Discord channel the bot is part of.
Use this tool when asked about who is in the channel, who the members are,
or to get a roster of people present in the current channel.
Returns:
A formatted string listing all members in the channel with their usernames,
display names, and nicknames.
"""
try:
members = channel.members
if not members:
return "No members found in this channel."
lines: list[str] = [f"Members in this channel ({len(members)} total):"]
for member in sorted(
members,
key=lambda m: (
getattr(m, "display_name", "") or getattr(m, "name", "")
).lower(),
):
lines.append(f" - {_format_member(member)}")
return "\n".join(lines)
except Exception:
logger.exception("Error fetching channel members")
return "Failed to retrieve channel members."
@tool
def get_channel_members() -> str:
"""Get a list of all members in the Discord channel the bot is part of.
Use this tool when asked about who is in the channel, who the members are,
or to get a roster of people present in the current channel.
Returns:
A formatted string listing all members in the channel with their usernames,
display names, and nicknames.
"""
return "No channel provided."