add tool calling
This commit is contained in:
+55
-2
@@ -33,6 +33,7 @@ from vibe_bot.config import (
|
||||
VOICES_LIST,
|
||||
)
|
||||
from vibe_bot.database import CustomBotManager, get_database
|
||||
from vibe_bot.tools import get_channel_members, get_channel_members_impl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from discord.ext.commands import Bot
|
||||
@@ -495,9 +496,35 @@ async def _speak_with_bot(
|
||||
if context:
|
||||
prompts = context + prompts
|
||||
|
||||
bot_response = llama_wrapper.chat_completion_with_history(
|
||||
# Build tool definitions from LangChain tools
|
||||
speak_tools: list[dict[str, object]] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": get_channel_members.name,
|
||||
"description": get_channel_members.description,
|
||||
"parameters": get_channel_members.args_schema.model_json_schema(), # type: ignore
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def speak_tool_executor(tool_name: str, tool_args: dict[str, str]) -> str:
|
||||
"""Execute a tool by name with the given arguments."""
|
||||
if tool_name == "get_channel_members":
|
||||
return get_channel_members_impl(ctx.channel)
|
||||
return f"Unknown tool: {tool_name}"
|
||||
|
||||
async def speak_tool_call_notifier(tool_name: str, tool_args: dict[str, str]) -> None:
|
||||
"""Send a notification message when a tool is called."""
|
||||
if tool_name == "get_channel_members":
|
||||
await ctx.send(f"**{bot_name}** is looking at the channel members...")
|
||||
|
||||
bot_response = await llama_wrapper.chat_completion_with_tools(
|
||||
system_prompt=system_prompt_edit,
|
||||
prompts=prompts,
|
||||
tools=speak_tools,
|
||||
tool_executor=speak_tool_executor,
|
||||
tool_call_notifier=speak_tool_call_notifier,
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||
model=CHAT_MODEL,
|
||||
@@ -869,10 +896,36 @@ async def handle_chat(
|
||||
|
||||
system_prompt_edit = f"{system_prompt}\nKeep your responses under 2-3 sentences.\n\nUser Information:\n{get_user_info(ctx.author)}"
|
||||
|
||||
# Build tool definitions from LangChain tools
|
||||
tools: list[dict[str, object]] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": get_channel_members.name,
|
||||
"description": get_channel_members.description,
|
||||
"parameters": get_channel_members.args_schema.model_json_schema(), # type: ignore
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def tool_executor(tool_name: str, tool_args: dict[str, str]) -> str:
|
||||
"""Execute a tool by name with the given arguments."""
|
||||
if tool_name == "get_channel_members":
|
||||
return get_channel_members_impl(ctx.channel)
|
||||
return f"Unknown tool: {tool_name}"
|
||||
|
||||
async def tool_call_notifier(tool_name: str, tool_args: dict[str, str]) -> None:
|
||||
"""Send a notification message when a tool is called."""
|
||||
if tool_name == "get_channel_members":
|
||||
await ctx.send(f"{bot_name} is looking at the channel members...")
|
||||
|
||||
try:
|
||||
bot_response = llama_wrapper.chat_completion_with_history(
|
||||
bot_response = await llama_wrapper.chat_completion_with_tools(
|
||||
system_prompt=system_prompt_edit,
|
||||
prompts=prompts,
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
tool_call_notifier=tool_call_notifier,
|
||||
openai_url=CHAT_ENDPOINT,
|
||||
openai_api_key=CHAT_ENDPOINT_KEY,
|
||||
model=CHAT_MODEL,
|
||||
|
||||
Reference in New Issue
Block a user