工具创建模式
函数工具:从Python函数创建工具。
工具包:将相关工具组织成可重用的工具包。
快速入门
基础工具创建
创建一个简单的自定义工具:
from agkit.tools import tool
from pydantic import BaseModel, Field
from typing import Literal
import aiohttp
class WeatherInput(BaseModel):
    """Input schema for the ``get_weather`` tool.

    ``city`` must be non-empty. An empty name can never resolve to a real
    location, so it is rejected during validation rather than being sent
    to the weather API.
    """

    city: str = Field(min_length=1, description="城市名称")
    units: Literal['celsius', 'fahrenheit'] = Field(default='celsius')
async def get_weather_func(input_data: WeatherInput):
    """Fetch the current weather for ``input_data.city``.

    Args:
        input_data: Validated tool input (city name and temperature units).

    Returns:
        On success ``{"success": True, "data": {...}}``, where ``data`` holds
        city, temperature, condition, humidity and units. On failure
        ``{"success": False, "error": str, "error_type": str}``, where
        ``error_type`` is ``"network"`` for a non-OK HTTP status or an aiohttp
        client/transport error, and ``"execution"`` for anything else
        (e.g. a key missing from the response body).
    """
    try:
        # Fetch the weather data.
        async with aiohttp.ClientSession() as session:
            # Let aiohttp build and percent-encode the query string. City
            # names may contain spaces, '&', '#' or non-ASCII characters
            # that would corrupt a hand-formatted URL.
            params = {"city": input_data.city, "units": input_data.units}
            async with session.get(
                "https://api.weather.com/v1/current", params=params
            ) as response:
                if not response.ok:
                    return {
                        "success": False,
                        "error": f"Weather API error: {response.reason}",
                        "error_type": "network"
                    }
                data = await response.json()
                return {
                    "success": True,
                    "data": {
                        "city": input_data.city,
                        "temperature": data["temperature"],
                        "condition": data["condition"],
                        "humidity": data["humidity"],
                        "units": input_data.units
                    }
                }
    except aiohttp.ClientError as error:
        # Connection failures, DNS errors and malformed responses are
        # network problems, not bugs in the tool itself.
        return {
            "success": False,
            "error": str(error),
            "error_type": "network"
        }
    except Exception as error:
        # Tool boundary: convert any other failure into an error result.
        return {
            "success": False,
            "error": str(error),
            "error_type": "execution"
        }
# Register get_weather_func as a tool; WeatherInput is passed as its
# input schema.
weather_tool = tool(
    name="get_weather",
    description="获取城市的当前天气信息",
    schema=WeatherInput,
    func=get_weather_func,
)
工具集成
将自定义工具与代理(Agent)一起使用:
from agkit.core import Agent
from agkit.providers.openai import OpenAIProvider
import os
provider = OpenAIProvider(
api_key=os.getenv("OPENAI_API_KEY"),
model="gpt-4"
)
agent = Agent(
name="weather-agent",
model=provider,
tools=[weather_tool],
instructions="您可以查询任何城市的天气信息。"
)
response = await agent.run(
input="旧金山的天气如何?"
)
工具架构
BaseTool接口
所有工具都实现了标准化接口:
from typing import Any, Generic, TypeVar, Protocol
from pydantic import BaseModel
from abc import ABC, abstractmethod
# Type parameters for a tool's input and output payloads.
TInput = TypeVar("TInput")
TOutput = TypeVar("TOutput")


class BaseTool(ABC, Generic[TInput, TOutput]):
    """Abstract interface shared by every tool.

    Concrete tools declare ``name``, ``description`` and ``schema`` (a
    Pydantic model class) and implement the asynchronous :meth:`invoke`.
    """

    name: str
    description: str
    schema: type[BaseModel]

    @abstractmethod
    async def invoke(self, input_data: TInput) -> 'ToolResult[TOutput]':
        """Run the tool on ``input_data`` and return a ``ToolResult``."""
        ...
工具结果结构
所有工具都使用统一的结果格式:
from typing import Generic, TypeVar, Optional, Literal
from dataclasses import dataclass
T = TypeVar("T")


@dataclass
class ToolResult(Generic[T]):
    """Uniform envelope for the outcome of a tool invocation.

    Only ``success`` is required; every other field defaults to ``None``.
    """

    # Whether the invocation succeeded.
    success: bool
    # Result payload (typed by the generic parameter T).
    data: Optional[T] = None
    # Error message when the invocation failed.
    error: Optional[str] = None
    # Failure category, restricted to the four literals below.
    error_type: Optional[
        Literal["validation", "execution", "permission", "network"]
    ] = None
    # Duration of the call; the unit is not fixed here (presumably seconds).
    execution_time: Optional[float] = None
模式验证
使用Pydantic进行输入验证:
from pydantic import BaseModel, Field, UUID4
from typing import Optional, Literal, Union, Dict, Any, List
from datetime import datetime
from enum import Enum
class Priority(str, Enum):
    """Priority levels; members also compare equal to their string values."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
class Metadata(BaseModel):
    """Nested model used for ``ComplexSchema.metadata``."""

    created: datetime
    author: str
class SimpleConfig(BaseModel):
    """Config variant tagged ``type="simple"`` that carries one string value."""

    type: Literal["simple"]
    value: str
class ComplexConfig(BaseModel):
    """Config variant tagged ``type="complex"`` with free-form settings."""

    type: Literal["complex"]
    settings: dict[str, Any]
class ComplexSchema(BaseModel):
    """Example schema combining required, defaulted, nested and union fields."""

    # Required fields
    id: UUID4
    name: str = Field(min_length=1, max_length=100)

    # Optional fields with defaults
    priority: Priority = Priority.MEDIUM
    tags: List[str] = Field(default_factory=list)

    # Nested object
    metadata: Optional[Metadata] = None

    # Conditional validation: the literal ``type`` field selects which config
    # model applies. Declaring it as the discriminator makes Pydantic validate
    # against exactly that variant and report errors for it alone, instead of
    # trying each union member in turn. The field remains required.
    config: Union[SimpleConfig, ComplexConfig] = Field(discriminator="type")
工具包架构
自定义工具包
创建自定义工具包以组织相关工具:
from agkit.tools import BaseToolkit, tool
from pydantic import BaseModel, Field
from typing import Literal
class WeatherToolkit(BaseToolkit):
    """Toolkit grouping the current-weather, forecast and historical tools.

    The tools are registered in ``on_initialize``. Their handlers return
    fixed placeholder data; replace them with real API calls.
    """

    def __init__(self):
        super().__init__(
            name='weather-toolkit',
            description='Comprehensive weather information toolkit'
        )

    async def on_initialize(self):
        # Add weather tools
        self.add_tool(self.create_current_weather_tool())
        self.add_tool(self.create_forecast_tool())
        self.add_tool(self.create_historical_tool())

    # Each factory below declares its input schema as a BaseModel *subclass*.
    # Calling ``BaseModel(city=str, ...)`` would create a model instance, not
    # a schema, and the ``units=Literal[...] = 'celsius'`` keyword form is
    # not valid Python syntax. Handlers follow the same convention as
    # ``get_weather_func``: an async function that receives the validated
    # input model.

    def create_current_weather_tool(self):
        """Build the ``current_weather`` tool."""

        class CurrentWeatherInput(BaseModel):
            city: str
            units: Literal['celsius', 'fahrenheit'] = 'celsius'

        async def current_weather(input_data: CurrentWeatherInput):
            # Placeholder result; input_data is not used yet.
            return {'success': True, 'data': {'temperature': 22, 'condition': 'sunny'}}

        return tool(
            func=current_weather,
            name='current_weather',
            description='Get current weather conditions',
            schema=CurrentWeatherInput
        )

    def create_forecast_tool(self):
        """Build the ``weather_forecast`` tool (1-14 days, default 5)."""

        class ForecastInput(BaseModel):
            city: str
            days: int = Field(ge=1, le=14, default=5)

        async def weather_forecast(input_data: ForecastInput):
            # Placeholder result; input_data is not used yet.
            return {'success': True, 'data': {'forecast': []}}

        return tool(
            func=weather_forecast,
            name='weather_forecast',
            description='Get weather forecast',
            schema=ForecastInput
        )

    def create_historical_tool(self):
        """Build the ``weather_historical`` tool."""

        class HistoricalWeatherInput(BaseModel):
            city: str
            date: str

        async def weather_historical(input_data: HistoricalWeatherInput):
            # Placeholder result; input_data is not used yet.
            return {'success': True, 'data': {'historical': []}}

        return tool(
            func=weather_historical,
            name='weather_historical',
            description='Get historical weather data',
            schema=HistoricalWeatherInput
        )
# Using the toolkit.
# NOTE: top-level `await` needs an async context (asyncio REPL, Jupyter);
# in a plain script, run these lines inside an async function.
weather_kit = WeatherToolkit()
# Must run before get_tools(): the tools are added in on_initialize(),
# which initialize() presumably invokes.
await weather_kit.initialize()
tools = weather_kit.get_tools()
使用自定义工具包
初始化自定义工具包,并获取其工具以供Agent使用:
# Create and initialize toolkit.
# NOTE: top-level `await` needs an async context (asyncio REPL, Jupyter).
weather_toolkit = WeatherToolkit()
await weather_toolkit.initialize()
# Get all tools from toolkit (e.g. to pass as Agent(tools=...)).
weather_tools = weather_toolkit.get_tools()
工具包管理
使用工具包管理器集中管理工具包:
from agkit.tools import ToolkitManager
toolkit_manager = ToolkitManager()
# Register toolkit
toolkit_manager.register(weather_toolkit)
# Look up a registered toolkit by its name
toolkit = toolkit_manager.get_toolkit('weather-toolkit')
# Collect the tools of every registered toolkit
all_tools = toolkit_manager.get_all_tools()
# Search all registered toolkits for tools with this name
current_weather_tools = toolkit_manager.find_tool('current_weather')
# Initialize all registered toolkits.
# NOTE(review): the lookups above run before initialize_all(). They presumably
# see tools only because weather_toolkit was already initialized earlier.
# Whether re-initializing an initialized toolkit is a no-op is not visible
# here; confirm before relying on this order.
await toolkit_manager.initialize_all()
# Cleanup all toolkits
await toolkit_manager.destroy_all()
工具包事件
监听工具包生命周期事件:
def event_handler(event):
    """Print a one-line message for each recognised toolkit lifecycle event.

    Events whose ``type`` is not one of the four handled values are ignored.
    """
    kind = event.type
    if kind == 'toolkit_initialized':
        message = f"Toolkit {event.toolkit.name} initialized"
    elif kind == 'tool_added':
        message = f"Tool {event.tool.name} added"
    elif kind == 'tool_executed':
        message = f"Tool {event.tool_name} executed"
    elif kind == 'toolkit_destroyed':
        message = f"Toolkit {event.toolkit.name} destroyed"
    else:
        return
    print(message)
# Subscribe event_handler to weather_toolkit's lifecycle events.
weather_toolkit.add_event_listener(event_handler)
工具测试
单元测试
全面测试自定义工具:
import pytest
class TestWeatherTool:
    """Behavioural tests for ``weather_tool``.

    NOTE(review): these presumably reach the real weather API through
    ``get_weather_func``, so they need network access and depend on the
    remote service's answers.
    """

    @pytest.mark.asyncio
    async def test_should_return_weather_data_for_valid_city(self):
        payload = {"city": "San Francisco", "units": "celsius"}
        result = await weather_tool.invoke(payload)

        assert result.success is True
        weather = result.data
        assert "temperature" in weather
        assert weather["city"] == "San Francisco"
        assert weather["units"] == "celsius"

    @pytest.mark.asyncio
    async def test_should_handle_invalid_city_gracefully(self):
        payload = {"city": "InvalidCity123", "units": "celsius"}
        result = await weather_tool.invoke(payload)

        # An unknown city yields an error result rather than an exception.
        assert result.success is False
        assert result.error_type == "network"

    @pytest.mark.asyncio
    async def test_should_validate_input_schema(self):
        payload = {"city": "", "units": "celsius"}  # empty city is invalid
        result = await weather_tool.invoke(payload)

        assert result.success is False
        assert result.error_type == "validation"
工具包测试
全面测试自定义工具包:
import pytest
from typing import List, Any
class TestWeatherToolkit:
    """Tests for WeatherToolkit: registration, invocation, validation, events.

    NOTE: ``setup_and_teardown`` is an async generator fixture declared with
    plain ``pytest.fixture``. It needs pytest-asyncio to run it (e.g.
    ``asyncio_mode = "auto"``); in strict mode, declare it with
    ``pytest_asyncio.fixture`` instead.
    """

    @pytest.fixture(autouse=True)
    async def setup_and_teardown(self):
        # A fresh, initialized toolkit for every test, destroyed afterwards.
        self.weather_toolkit = WeatherToolkit()
        await self.weather_toolkit.initialize()
        yield
        await self.weather_toolkit.destroy()

    def test_should_initialize_with_correct_tools(self):
        tool_names = self.weather_toolkit.get_tool_names()
        assert 'current_weather' in tool_names
        assert 'weather_forecast' in tool_names
        # WeatherToolkit registers the historical tool as 'weather_historical'.
        assert 'weather_historical' in tool_names
        assert len(self.weather_toolkit.get_tools()) == 3

    @pytest.mark.asyncio
    async def test_should_execute_tools_correctly(self):
        result = await self.weather_toolkit.invoke_tool('current_weather', {
            'city': 'San Francisco',
            'units': 'celsius'
        })
        assert result.success is True
        assert 'temperature' in result.data

    @pytest.mark.asyncio
    async def test_should_handle_batch_tool_execution(self):
        results = await self.weather_toolkit.invoke_tools([
            {'tool_name': 'current_weather', 'input': {'city': 'Tokyo'}},
            {'tool_name': 'weather_forecast', 'input': {'city': 'Tokyo', 'days': 3}}
        ])
        assert len(results) == 2
        assert all(r.success for r in results)

    def test_should_validate_toolkit_integrity(self):
        validation = self.weather_toolkit.validate()
        assert validation.valid is True
        assert len(validation.errors) == 0

    @pytest.mark.asyncio
    async def test_should_emit_events_correctly(self):
        events: List[Any] = []

        def event_handler(event):
            events.append(event)

        # A separate toolkit, so the listener is attached before initialize().
        new_toolkit = WeatherToolkit()
        new_toolkit.add_event_listener(event_handler)
        await new_toolkit.initialize()
        await new_toolkit.invoke_tool('current_weather', {'city': 'London'})
        await new_toolkit.destroy()

        assert any(e.type == 'toolkit_initialized' for e in events)
        assert any(e.type == 'tool_executed' for e in events)
        assert any(e.type == 'toolkit_destroyed' for e in events)
性能优化
缓存
为昂贵的操作实现缓存:
from typing import Dict, Any
import time
# Module-level cache: "<city>-<units>" -> {"data": ..., "timestamp": ...}.
# NOTE(review): nothing here evicts entries, so the dict grows with every
# distinct key; an expired entry is only overwritten when it is refetched.
cache: Dict[str, Dict[str, Any]] = {}
class CachedWeatherInput(BaseModel):
    """Input schema for the cached weather tool (same fields as WeatherInput).

    ``min_length=1`` rejects an empty city name during validation, before it
    can become a cache key or trigger an API request.
    """

    city: str = Field(min_length=1)
    units: Literal['celsius', 'fahrenheit'] = 'celsius'
async def cached_weather_func(input_data: CachedWeatherInput):
    """Return weather for a city/units pair, reusing results for 5 minutes.

    A fresh cache hit is returned with ``"cached": True`` added to its data.
    On a miss or an expired entry, ``fetch_weather_data`` is awaited and its
    result is returned unchanged. Only results with ``success`` truthy are
    stored in the cache.
    """
    ttl_seconds = 300  # 5 minutes
    cache_key = f"{input_data.city}-{input_data.units}"
    entry = cache.get(cache_key)
    # Measure age with a monotonic clock. With time.time(), a system clock
    # adjustment (NTP step, manual change) could make a stale entry look
    # fresh for an arbitrarily long time, or expire it early.
    if entry and time.monotonic() - entry['timestamp'] < ttl_seconds:
        return {
            "success": True,
            "data": {**entry['data'], "cached": True}
        }
    result = await fetch_weather_data(input_data.city, input_data.units)
    if result['success']:
        cache[cache_key] = {
            'data': result['data'],
            'timestamp': time.monotonic()
        }
    return result
# Register cached_weather_func as a tool with CachedWeatherInput as its
# input schema.
cached_weather_tool = tool(
    name="get_weather_cached",
    description="Get weather with caching",
    schema=CachedWeatherInput,
    func=cached_weather_func,
)
连接池
重用连接以获得更好的性能:
from typing import List, Any
import asyncio
class DatabaseConnectionPool:
    """Minimal LIFO pool of reusable connections.

    At most ``max_size`` idle connections are kept; a connection released
    while the pool is full is closed. ``max_size`` bounds only the idle set:
    the number of connections handed out at once is not limited.
    """

    def __init__(self, max_size: int = 10):
        self.pool: List[Any] = []
        self.max_size = max_size

    async def get_connection(self):
        """Return an idle connection, or open a new one if none is idle."""
        if self.pool:
            return self.pool.pop()  # most recently released first
        return await self.create_new_connection()

    def release_connection(self, connection: Any):
        """Return ``connection`` to the pool, or close it if the pool is full."""
        if len(self.pool) < self.max_size:
            self.pool.append(connection)
        else:
            connection.close()

    async def create_new_connection(self):
        """Open a new connection; override for your database driver.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # Fail loudly. Silently returning None would only surface later as
        # "'NoneType' object has no attribute 'query'".
        raise NotImplementedError(
            "create_new_connection() must be implemented for your database"
        )
# Shared module-level pool used by optimized_db_func.
pool = DatabaseConnectionPool()
class DbQueryInput(BaseModel):
    """Input schema for the ``optimized_db_query`` tool."""

    query: str
    # default_factory gives every instance its own empty list.
    parameters: list[Any] = Field(default_factory=list)
async def optimized_db_func(input_data: DbQueryInput):
    """Run ``input_data.query`` on a connection borrowed from ``pool``.

    Returns:
        ``{"success": True, "data": <query result>}``.

    Raises:
        Whatever ``pool.get_connection()`` or ``connection.query`` raises.
    """
    connection = await pool.get_connection()
    try:
        result = await connection.query(input_data.query, input_data.parameters)
    except BaseException:
        # A connection whose query failed or was cancelled mid-flight may be
        # in an unusable state (aborted transaction, half-read reply).
        # Discard it rather than hand it to the next caller.
        connection.close()
        raise
    pool.release_connection(connection)
    return {"success": True, "data": result}
# Register optimized_db_func as a tool with DbQueryInput as its input schema.
optimized_db_tool = tool(
    name="optimized_db_query",
    description="Database query with connection pooling",
    schema=DbQueryInput,
    func=optimized_db_func,
)