add tool calling
This commit is contained in:
+119
-1
@@ -7,7 +7,7 @@ Allows custom endpoints for each of the above supported functions.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, cast
|
||||
|
||||
import openai
|
||||
import requests
|
||||
@@ -170,6 +170,124 @@ def chat_completion_instruct(
|
||||
return ""
|
||||
|
||||
|
||||
async def chat_completion_with_tools(
|
||||
system_prompt: str,
|
||||
prompts: list[dict[str, str]],
|
||||
tools: list[dict[str, object]],
|
||||
tool_executor: Callable[[str, dict[str, str]], str],
|
||||
*,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
model: str,
|
||||
max_tokens: int = 1000,
|
||||
max_tool_rounds: int = 5,
|
||||
tool_call_notifier: (
|
||||
Callable[[str, dict[str, str]], None]
|
||||
| Callable[[str, dict[str, str]], Awaitable[None]]
|
||||
| None
|
||||
) = None,
|
||||
) -> str:
|
||||
"""Send a chat completion request with tool support and iterative tool calling.
|
||||
|
||||
Args:
|
||||
system_prompt: The system prompt to use.
|
||||
prompts: List of prompt dicts with role and content.
|
||||
tools: List of tool definitions in OpenAI format.
|
||||
tool_executor: A callable that takes (tool_name: str, tool_args: dict) -> str.
|
||||
openai_url: The OpenAI-compatible API URL.
|
||||
openai_api_key: The API key for authentication.
|
||||
model: The model to use for completion.
|
||||
max_tokens: Maximum number of tokens to generate.
|
||||
max_tool_rounds: Maximum number of tool call rounds before giving up.
|
||||
tool_call_notifier: Optional callback invoked before each tool call
|
||||
with (tool_name, tool_args).
|
||||
|
||||
Returns:
|
||||
The model's final response text, stripped of whitespace.
|
||||
|
||||
"""
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
cast(
|
||||
"ChatCompletionMessageParam",
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
),
|
||||
]
|
||||
messages.extend(cast("list[ChatCompletionMessageParam]", prompts))
|
||||
|
||||
for _round in range(max_tool_rounds):
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools, # type: ignore[arg-type]
|
||||
max_tokens=max_tokens,
|
||||
seed=-1,
|
||||
timeout=60.0,
|
||||
)
|
||||
|
||||
if not response.choices:
|
||||
return ""
|
||||
|
||||
message = response.choices[0].message
|
||||
|
||||
# Check if the model wants to call a tool
|
||||
tool_calls = message.tool_calls
|
||||
if tool_calls:
|
||||
assistant_msg: dict[str, object] = {
|
||||
"role": "assistant",
|
||||
"content": message.content or "",
|
||||
}
|
||||
tool_call_dicts: list[dict[str, object]] = []
|
||||
for tool_call in tool_calls:
|
||||
tool_call_dicts.append(
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name, # type: ignore[union-attr]
|
||||
"arguments": tool_call.function.arguments, # type: ignore[union-attr]
|
||||
},
|
||||
},
|
||||
)
|
||||
assistant_msg["tool_calls"] = tool_call_dicts
|
||||
messages.append(cast("ChatCompletionMessageParam", assistant_msg))
|
||||
|
||||
# Execute each tool call and add results to messages
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.function.name # type: ignore[union-attr]
|
||||
tool_args = json.loads(tool_call.function.arguments) # type: ignore[union-attr]
|
||||
|
||||
if tool_call_notifier:
|
||||
result = tool_call_notifier(tool_name, tool_args)
|
||||
if hasattr(result, "__await__"):
|
||||
await result # type: ignore[misc]
|
||||
|
||||
tool_result = tool_executor(tool_name, tool_args)
|
||||
|
||||
messages.append(
|
||||
cast(
|
||||
"ChatCompletionMessageParam",
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": tool_result,
|
||||
},
|
||||
),
|
||||
)
|
||||
continue
|
||||
|
||||
# No more tool calls, return the final response
|
||||
content = message.content
|
||||
if content:
|
||||
return content.strip()
|
||||
return ""
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def image_generation(
|
||||
prompt: str,
|
||||
*,
|
||||
|
||||
+55
-2
@@ -33,6 +33,7 @@ from vibe_bot.config import (
|
||||
VOICES_LIST,
|
||||
)
|
||||
from vibe_bot.database import CustomBotManager, get_database
|
||||
from vibe_bot.tools import get_channel_members, get_channel_members_impl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from discord.ext.commands import Bot
|
||||
@@ -495,9 +496,35 @@ async def _speak_with_bot(
|
||||
if context:
|
||||
prompts = context + prompts
|
||||
|
||||
bot_response = llama_wrapper.chat_completion_with_history(
|
||||
# Build tool definitions from LangChain tools
|
||||
speak_tools: list[dict[str, object]] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": get_channel_members.name,
|
||||
"description": get_channel_members.description,
|
||||
"parameters": get_channel_members.args_schema.model_json_schema(), # type: ignore
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def speak_tool_executor(tool_name: str, tool_args: dict[str, str]) -> str:
|
||||
"""Execute a tool by name with the given arguments."""
|
||||
if tool_name == "get_channel_members":
|
||||
return get_channel_members_impl(ctx.channel)
|
||||
return f"Unknown tool: {tool_name}"
|
||||
|
||||
async def speak_tool_call_notifier(tool_name: str, tool_args: dict[str, str]) -> None:
|
||||
"""Send a notification message when a tool is called."""
|
||||
if tool_name == "get_channel_members":
|
||||
await ctx.send(f"**{bot_name}** is looking at the channel members...")
|
||||
|
||||
bot_response = await llama_wrapper.chat_completion_with_tools(
|
||||
system_prompt=system_prompt_edit,
|
||||
prompts=prompts,
|
||||
tools=speak_tools,
|
||||
tool_executor=speak_tool_executor,
|
||||
tool_call_notifier=speak_tool_call_notifier,
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||
model=CHAT_MODEL,
|
||||
@@ -869,10 +896,36 @@ async def handle_chat(
|
||||
|
||||
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences.\n\nUser Information:\n{get_user_info(ctx.author)}"
|
||||
|
||||
# Build tool definitions from LangChain tools
|
||||
tools: list[dict[str, object]] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": get_channel_members.name,
|
||||
"description": get_channel_members.description,
|
||||
"parameters": get_channel_members.args_schema.model_json_schema(), # type: ignore
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def tool_executor(tool_name: str, tool_args: dict[str, str]) -> str:
|
||||
"""Execute a tool by name with the given arguments."""
|
||||
if tool_name == "get_channel_members":
|
||||
return get_channel_members_impl(ctx.channel)
|
||||
return f"Unknown tool: {tool_name}"
|
||||
|
||||
async def tool_call_notifier(tool_name: str, tool_args: dict[str, str]) -> None:
|
||||
"""Send a notification message when a tool is called."""
|
||||
if tool_name == "get_channel_members":
|
||||
await ctx.send(f"{bot_name} is looking at the channel members...")
|
||||
|
||||
try:
|
||||
bot_response = llama_wrapper.chat_completion_with_history(
|
||||
bot_response = await llama_wrapper.chat_completion_with_tools(
|
||||
system_prompt=system_prompt_edit,
|
||||
prompts=prompts,
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
tool_call_notifier=tool_call_notifier,
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||
model=CHAT_MODEL,
|
||||
|
||||
@@ -7,7 +7,7 @@ import warnings
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -195,6 +195,7 @@ def mock_llama_wrapper() -> Generator[MagicMock]:
|
||||
"""Provide mock llama_wrapper module."""
|
||||
with patch("vibe_bot.main.llama_wrapper") as mock_wrapper:
|
||||
mock_wrapper.chat_completion_with_history.return_value = "Bot response"
|
||||
mock_wrapper.chat_completion_with_tools = AsyncMock(return_value="Bot response")
|
||||
mock_wrapper.chat_completion_instruct.return_value = "image prompt"
|
||||
mock_wrapper.image_generation.return_value = ""
|
||||
mock_wrapper.image_edit.return_value = ""
|
||||
|
||||
+10
-10
@@ -143,7 +143,7 @@ def test_speak_with_custom_bot(
|
||||
|
||||
asyncio.run(main_module.speak(mock_ctx, message="alfred what time is it"))
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
|
||||
mock_llama_wrapper.chat_completion_with_tools.assert_called_once()
|
||||
mock_tts_engine.generate_audio.assert_called_once()
|
||||
assert mock_ctx.send.call_count >= 3
|
||||
text_response = mock_ctx.send.call_args_list[1][0][0]
|
||||
@@ -387,7 +387,7 @@ def test_handle_chat_success(
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.return_value = (
|
||||
mock_llama_wrapper.chat_completion_with_tools.return_value = (
|
||||
"This is a bot response"
|
||||
)
|
||||
|
||||
@@ -401,7 +401,7 @@ def test_handle_chat_success(
|
||||
),
|
||||
)
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
|
||||
mock_llama_wrapper.chat_completion_with_tools.assert_called_once()
|
||||
mock_database.add_message.assert_called()
|
||||
assert mock_ctx.send.call_count >= 2
|
||||
|
||||
@@ -416,7 +416,7 @@ def test_handle_chat_error(
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.side_effect = Exception("API error")
|
||||
mock_llama_wrapper.chat_completion_with_tools.side_effect = Exception("API error")
|
||||
|
||||
asyncio.run(
|
||||
main_module.handle_chat(
|
||||
@@ -443,7 +443,7 @@ def test_handle_chat_long_response_chunked(
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
long_response = "x" * 2500
|
||||
mock_llama_wrapper.chat_completion_with_history.return_value = long_response
|
||||
mock_llama_wrapper.chat_completion_with_tools.return_value = long_response
|
||||
|
||||
asyncio.run(
|
||||
main_module.handle_chat(
|
||||
@@ -722,7 +722,7 @@ def test_handle_chat_includes_user_info(
|
||||
|
||||
import vibe_bot.main as main_module
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.return_value = (
|
||||
mock_llama_wrapper.chat_completion_with_tools.return_value = (
|
||||
"This is a bot response"
|
||||
)
|
||||
|
||||
@@ -736,8 +736,8 @@ def test_handle_chat_includes_user_info(
|
||||
),
|
||||
)
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
|
||||
call_kwargs = mock_llama_wrapper.chat_completion_with_history.call_args
|
||||
mock_llama_wrapper.chat_completion_with_tools.assert_called_once()
|
||||
call_kwargs = mock_llama_wrapper.chat_completion_with_tools.call_args
|
||||
system_prompt = call_kwargs.kwargs["system_prompt"]
|
||||
assert "you are a butler" in system_prompt
|
||||
assert "User Information:" in system_prompt
|
||||
@@ -769,8 +769,8 @@ def test_speak_with_bot_includes_user_info(
|
||||
|
||||
asyncio.run(main_module.speak(mock_ctx, message="alfred what time is it"))
|
||||
|
||||
mock_llama_wrapper.chat_completion_with_history.assert_called_once()
|
||||
call_kwargs = mock_llama_wrapper.chat_completion_with_history.call_args
|
||||
mock_llama_wrapper.chat_completion_with_tools.assert_called_once()
|
||||
call_kwargs = mock_llama_wrapper.chat_completion_with_tools.call_args
|
||||
system_prompt = call_kwargs.kwargs["system_prompt"]
|
||||
assert "british butler" in system_prompt
|
||||
assert "User Information:" in system_prompt
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
"""Tests for the tools module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from vibe_bot.tools import get_channel_members, get_channel_members_impl
|
||||
|
||||
|
||||
def test_get_channel_members_impl_returns_formatted_list() -> None:
|
||||
"""Test get_channel_members_impl returns a formatted member list."""
|
||||
mock_member1 = MagicMock()
|
||||
mock_member1.display_name = "Alice"
|
||||
mock_member1.name = "alice"
|
||||
mock_member1.nick = None
|
||||
mock_member1.global_name = None
|
||||
mock_member1.status = MagicMock(value="online")
|
||||
|
||||
mock_member2 = MagicMock()
|
||||
mock_member2.display_name = "Bob"
|
||||
mock_member2.name = "bob"
|
||||
mock_member2.nick = "bobby"
|
||||
mock_member2.global_name = None
|
||||
mock_member2.status = MagicMock(value="idle")
|
||||
|
||||
mock_channel = MagicMock()
|
||||
mock_channel.members = [mock_member1, mock_member2]
|
||||
|
||||
result = get_channel_members_impl(mock_channel)
|
||||
|
||||
assert "Members in this channel (2 total):" in result
|
||||
assert "Alice" in result
|
||||
assert "Bob (nickname: bobby)" in result
|
||||
assert "[online]" in result
|
||||
assert "[idle]" in result
|
||||
|
||||
|
||||
def test_get_channel_members_impl_empty() -> None:
|
||||
"""Test get_channel_members_impl with no members."""
|
||||
mock_channel = MagicMock()
|
||||
mock_channel.members = []
|
||||
|
||||
result = get_channel_members_impl(mock_channel)
|
||||
|
||||
assert "No members found in this channel." in result
|
||||
|
||||
|
||||
def test_get_channel_members_impl_with_global_name() -> None:
|
||||
"""Test get_channel_members_impl includes global name."""
|
||||
mock_member = MagicMock()
|
||||
mock_member.display_name = "charlie"
|
||||
mock_member.name = "charlie"
|
||||
mock_member.nick = None
|
||||
mock_member.global_name = "Charlie Global"
|
||||
mock_member.status = MagicMock(value="dnd")
|
||||
|
||||
mock_channel = MagicMock()
|
||||
mock_channel.members = [mock_member]
|
||||
|
||||
result = get_channel_members_impl(mock_channel)
|
||||
|
||||
assert "(global name: Charlie Global)" in result
|
||||
|
||||
|
||||
def test_get_channel_members_impl_no_status() -> None:
|
||||
"""Test get_channel_members_impl when member has no status."""
|
||||
mock_member = MagicMock()
|
||||
mock_member.display_name = "dave"
|
||||
mock_member.name = "dave"
|
||||
mock_member.nick = None
|
||||
mock_member.global_name = None
|
||||
mock_member.status = None
|
||||
|
||||
mock_channel = MagicMock()
|
||||
mock_channel.members = [mock_member]
|
||||
|
||||
result = get_channel_members_impl(mock_channel)
|
||||
|
||||
assert "dave" in result
|
||||
assert "[]" not in result
|
||||
|
||||
|
||||
def test_get_channel_members_impl_exception() -> None:
|
||||
"""Test get_channel_members_impl handles exceptions gracefully."""
|
||||
mock_channel = MagicMock()
|
||||
mock_channel.members = None
|
||||
type(mock_channel).members = property(
|
||||
lambda self: (_ for _ in ()).throw(Exception("test"))
|
||||
)
|
||||
|
||||
result = get_channel_members_impl(mock_channel)
|
||||
|
||||
assert "Failed to retrieve channel members." in result
|
||||
|
||||
|
||||
def test_get_channel_members_tool_schema() -> None:
|
||||
"""Test that get_channel_members has a valid tool schema."""
|
||||
assert get_channel_members.name == "get_channel_members"
|
||||
assert get_channel_members.description is not None
|
||||
schema = get_channel_members.args_schema.model_json_schema() # type: ignore
|
||||
assert "properties" in schema
|
||||
assert schema["properties"] == {}
|
||||
|
||||
|
||||
def test_get_channel_members_impl_sorted_by_display_name() -> None:
|
||||
"""Test that members are sorted by display name."""
|
||||
mock_member_z = MagicMock()
|
||||
mock_member_z.display_name = "Zara"
|
||||
mock_member_z.name = "zara"
|
||||
mock_member_z.nick = None
|
||||
mock_member_z.global_name = None
|
||||
mock_member_z.status = MagicMock(value="online")
|
||||
|
||||
mock_member_a = MagicMock()
|
||||
mock_member_a.display_name = "Aaron"
|
||||
mock_member_a.name = "aaron"
|
||||
mock_member_a.nick = None
|
||||
mock_member_a.global_name = None
|
||||
mock_member_a.status = MagicMock(value="online")
|
||||
|
||||
mock_channel = MagicMock()
|
||||
mock_channel.members = [mock_member_z, mock_member_a]
|
||||
|
||||
result = get_channel_members_impl(mock_channel)
|
||||
|
||||
lines = [l for l in result.split("\n") if l.strip().startswith("- ")]
|
||||
assert "Aaron" in lines[0]
|
||||
assert "Zara" in lines[1]
|
||||
|
||||
|
||||
def test_get_channel_members_impl_no_nick_when_same_as_display() -> None:
|
||||
"""Test that nickname is not shown when same as display name."""
|
||||
mock_member = MagicMock()
|
||||
mock_member.display_name = "same"
|
||||
mock_member.name = "same"
|
||||
mock_member.nick = "same"
|
||||
mock_member.global_name = None
|
||||
mock_member.status = MagicMock(value="online")
|
||||
|
||||
mock_channel = MagicMock()
|
||||
mock_channel.members = [mock_member]
|
||||
|
||||
result = get_channel_members_impl(mock_channel)
|
||||
|
||||
assert "(nickname: same)" not in result
|
||||
assert "same" in result
|
||||
|
||||
|
||||
def test_format_member_minimal() -> None:
|
||||
"""Test _format_member with minimal data."""
|
||||
from vibe_bot.tools import _format_member
|
||||
|
||||
mock_member = MagicMock()
|
||||
mock_member.display_name = None
|
||||
mock_member.name = "unknown_user"
|
||||
mock_member.nick = None
|
||||
mock_member.global_name = None
|
||||
mock_member.status = None
|
||||
|
||||
result = _format_member(mock_member)
|
||||
|
||||
assert "unknown_user" in result
|
||||
|
||||
|
||||
def test_format_member_with_all_fields() -> None:
|
||||
"""Test _format_member with all fields populated."""
|
||||
from vibe_bot.tools import _format_member
|
||||
|
||||
mock_member = MagicMock()
|
||||
mock_member.display_name = "Alice"
|
||||
mock_member.name = "alice"
|
||||
mock_member.nick = "Al"
|
||||
mock_member.global_name = "Alice Global"
|
||||
mock_member.status = MagicMock(value="online")
|
||||
|
||||
result = _format_member(mock_member)
|
||||
|
||||
assert "Alice" in result
|
||||
assert "(nickname: Al)" in result
|
||||
assert "(global name: Alice Global)" in result
|
||||
assert "[online]" in result
|
||||
@@ -0,0 +1,78 @@
|
||||
"""LangChain tools for the Discord bot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _format_member(member: Any) -> str:
|
||||
"""Format a single member for display."""
|
||||
raw_display = getattr(member, "display_name", None)
|
||||
raw_name = getattr(member, "name", "Unknown")
|
||||
display_name = str(raw_display) if raw_display else str(raw_name)
|
||||
parts: list[str] = [display_name]
|
||||
|
||||
nick = getattr(member, "nick", None)
|
||||
if nick and nick != display_name:
|
||||
parts.append(f"(nickname: {nick})")
|
||||
|
||||
global_name = getattr(member, "global_name", None)
|
||||
if global_name and global_name != getattr(member, "name", ""):
|
||||
parts.append(f"(global name: {global_name})")
|
||||
|
||||
status = getattr(member, "status", None)
|
||||
if status:
|
||||
parts.append(f"[{status.value}]")
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def get_channel_members_impl(channel: Any) -> str:
|
||||
"""Get a list of all members in the Discord channel the bot is part of.
|
||||
|
||||
Use this tool when asked about who is in the channel, who the members are,
|
||||
or to get a roster of people present in the current channel.
|
||||
|
||||
Returns:
|
||||
A formatted string listing all members in the channel with their usernames,
|
||||
display names, and nicknames.
|
||||
|
||||
"""
|
||||
try:
|
||||
members = channel.members
|
||||
if not members:
|
||||
return "No members found in this channel."
|
||||
|
||||
lines: list[str] = [f"Members in this channel ({len(members)} total):"]
|
||||
for member in sorted(
|
||||
members,
|
||||
key=lambda m: (
|
||||
getattr(m, "display_name", "") or getattr(m, "name", "")
|
||||
).lower(),
|
||||
):
|
||||
lines.append(f" - {_format_member(member)}")
|
||||
|
||||
return "\n".join(lines)
|
||||
except Exception:
|
||||
logger.exception("Error fetching channel members")
|
||||
return "Failed to retrieve channel members."
|
||||
|
||||
|
||||
@tool
|
||||
def get_channel_members() -> str:
|
||||
"""Get a list of all members in the Discord channel the bot is part of.
|
||||
|
||||
Use this tool when asked about who is in the channel, who the members are,
|
||||
or to get a roster of people present in the current channel.
|
||||
|
||||
Returns:
|
||||
A formatted string listing all members in the channel with their usernames,
|
||||
display names, and nicknames.
|
||||
|
||||
"""
|
||||
return "No channel provided."
|
||||
Reference in New Issue
Block a user