add tool calling
This commit is contained in:
+119
-1
@@ -7,7 +7,7 @@ Allows custom endpoints for each of the above supported functions.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, cast
|
||||
|
||||
import openai
|
||||
import requests
|
||||
@@ -170,6 +170,124 @@ def chat_completion_instruct(
|
||||
return ""
|
||||
|
||||
|
||||
async def chat_completion_with_tools(
|
||||
system_prompt: str,
|
||||
prompts: list[dict[str, str]],
|
||||
tools: list[dict[str, object]],
|
||||
tool_executor: Callable[[str, dict[str, str]], str],
|
||||
*,
|
||||
openai_url: str,
|
||||
openai_api_key: str,
|
||||
model: str,
|
||||
max_tokens: int = 1000,
|
||||
max_tool_rounds: int = 5,
|
||||
tool_call_notifier: (
|
||||
Callable[[str, dict[str, str]], None]
|
||||
| Callable[[str, dict[str, str]], Awaitable[None]]
|
||||
| None
|
||||
) = None,
|
||||
) -> str:
|
||||
"""Send a chat completion request with tool support and iterative tool calling.
|
||||
|
||||
Args:
|
||||
system_prompt: The system prompt to use.
|
||||
prompts: List of prompt dicts with role and content.
|
||||
tools: List of tool definitions in OpenAI format.
|
||||
tool_executor: A callable that takes (tool_name: str, tool_args: dict) -> str.
|
||||
openai_url: The OpenAI-compatible API URL.
|
||||
openai_api_key: The API key for authentication.
|
||||
model: The model to use for completion.
|
||||
max_tokens: Maximum number of tokens to generate.
|
||||
max_tool_rounds: Maximum number of tool call rounds before giving up.
|
||||
tool_call_notifier: Optional callback invoked before each tool call
|
||||
with (tool_name, tool_args).
|
||||
|
||||
Returns:
|
||||
The model's final response text, stripped of whitespace.
|
||||
|
||||
"""
|
||||
client = openai.OpenAI(base_url=openai_url, api_key=openai_api_key)
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
cast(
|
||||
"ChatCompletionMessageParam",
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
),
|
||||
]
|
||||
messages.extend(cast("list[ChatCompletionMessageParam]", prompts))
|
||||
|
||||
for _round in range(max_tool_rounds):
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools, # type: ignore[arg-type]
|
||||
max_tokens=max_tokens,
|
||||
seed=-1,
|
||||
timeout=60.0,
|
||||
)
|
||||
|
||||
if not response.choices:
|
||||
return ""
|
||||
|
||||
message = response.choices[0].message
|
||||
|
||||
# Check if the model wants to call a tool
|
||||
tool_calls = message.tool_calls
|
||||
if tool_calls:
|
||||
assistant_msg: dict[str, object] = {
|
||||
"role": "assistant",
|
||||
"content": message.content or "",
|
||||
}
|
||||
tool_call_dicts: list[dict[str, object]] = []
|
||||
for tool_call in tool_calls:
|
||||
tool_call_dicts.append(
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.function.name, # type: ignore[union-attr]
|
||||
"arguments": tool_call.function.arguments, # type: ignore[union-attr]
|
||||
},
|
||||
},
|
||||
)
|
||||
assistant_msg["tool_calls"] = tool_call_dicts
|
||||
messages.append(cast("ChatCompletionMessageParam", assistant_msg))
|
||||
|
||||
# Execute each tool call and add results to messages
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.function.name # type: ignore[union-attr]
|
||||
tool_args = json.loads(tool_call.function.arguments) # type: ignore[union-attr]
|
||||
|
||||
if tool_call_notifier:
|
||||
result = tool_call_notifier(tool_name, tool_args)
|
||||
if hasattr(result, "__await__"):
|
||||
await result # type: ignore[misc]
|
||||
|
||||
tool_result = tool_executor(tool_name, tool_args)
|
||||
|
||||
messages.append(
|
||||
cast(
|
||||
"ChatCompletionMessageParam",
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": tool_result,
|
||||
},
|
||||
),
|
||||
)
|
||||
continue
|
||||
|
||||
# No more tool calls, return the final response
|
||||
content = message.content
|
||||
if content:
|
||||
return content.strip()
|
||||
return ""
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def image_generation(
|
||||
prompt: str,
|
||||
*,
|
||||
|
||||
Reference in New Issue
Block a user