add tool calling

This commit is contained in:
2026-05-24 15:26:13 -04:00
parent 879cd5cbe8
commit 7a1ba05068
8 changed files with 838 additions and 17 deletions
+119 -1
View File
@@ -7,7 +7,7 @@ Allows custom endpoints for each of the above supported functions.
from __future__ import annotations
import json
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Awaitable, Callable, cast
import openai
import requests
@@ -170,6 +170,124 @@ def chat_completion_instruct(
return ""
async def chat_completion_with_tools(
system_prompt: str,
prompts: list[dict[str, str]],
tools: list[dict[str, object]],
tool_executor: Callable[[str, dict[str, str]], str],
*,
openai_url: str,
openai_api_key: str,
model: str,
max_tokens: int = 1000,
max_tool_rounds: int = 5,
tool_call_notifier: (
Callable[[str, dict[str, str]], None]
| Callable[[str, dict[str, str]], Awaitable[None]]
| None
) = None,
) -> str:
"""Send a chat completion request with tool support and iterative tool calling.
Args:
system_prompt: The system prompt to use.
prompts: List of prompt dicts with role and content.
tools: List of tool definitions in OpenAI format.
tool_executor: A callable that takes (tool_name: str, tool_args: dict) -> str.
openai_url: The OpenAI-compatible API URL.
openai_api_key: The API key for authentication.
model: The model to use for completion.
max_tokens: Maximum number of tokens to generate.
max_tool_rounds: Maximum number of tool call rounds before giving up.
tool_call_notifier: Optional callback invoked before each tool call
with (tool_name, tool_args).
Returns:
The model's final response text, stripped of whitespace.
"""
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
messages: list[ChatCompletionMessageParam] = [
cast(
"ChatCompletionMessageParam",
{
"role": "system",
"content": system_prompt,
},
),
]
messages.extend(cast("list[ChatCompletionMessageParam]", prompts))
for _round in range(max_tool_rounds):
response = client.chat.completions.create(
model=model,
messages=messages,
tools=tools, # type: ignore[arg-type]
max_tokens=max_tokens,
seed=-1,
timeout=60.0,
)
if not response.choices:
return ""
message = response.choices[0].message
# Check if the model wants to call a tool
tool_calls = message.tool_calls
if tool_calls:
assistant_msg: dict[str, object] = {
"role": "assistant",
"content": message.content or "",
}
tool_call_dicts: list[dict[str, object]] = []
for tool_call in tool_calls:
tool_call_dicts.append(
{
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.function.name, # type: ignore[union-attr]
"arguments": tool_call.function.arguments, # type: ignore[union-attr]
},
},
)
assistant_msg["tool_calls"] = tool_call_dicts
messages.append(cast("ChatCompletionMessageParam", assistant_msg))
# Execute each tool call and add results to messages
for tool_call in tool_calls:
tool_name = tool_call.function.name # type: ignore[union-attr]
tool_args = json.loads(tool_call.function.arguments) # type: ignore[union-attr]
if tool_call_notifier:
result = tool_call_notifier(tool_name, tool_args)
if hasattr(result, "__await__"):
await result # type: ignore[misc]
tool_result = tool_executor(tool_name, tool_args)
messages.append(
cast(
"ChatCompletionMessageParam",
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": tool_result,
},
),
)
continue
# No more tool calls, return the final response
content = message.content
if content:
return content.strip()
return ""
return ""
def image_generation(
prompt: str,
*,