简单函数
基础函数工具
从简单函数创建工具:
from agkit.tools import tool
from pydantic import BaseModel, Field
class AddNumbersInput(BaseModel):
    """Schema for the add_numbers tool: two required numeric operands."""

    a: float = Field(..., description="First number")
    b: float = Field(..., description="Second number")
# Register the adder; the lambda reads the `a` and `b` fields of its input object.
add_numbers_tool = tool(
    name="add_numbers",
    description="Add two numbers together",
    schema=AddNumbersInput,
    func=lambda input_data: {"success": True, "data": {"result": input_data.a + input_data.b}},
)
文本处理工具
创建文本处理工具:
from typing import Optional, Literal
from pydantic import BaseModel, Field
class TextProcessOptions(BaseModel):
    """Pre-processing switches read by process_text_func before the operation runs."""

    trim: bool = True  # strip leading/trailing whitespace
    remove_spaces: bool = False  # drop every whitespace character
class TextProcessInput(BaseModel):
    """Schema for the process_text tool."""

    text: str = Field(..., description="Input text")
    # Selects the branch taken in process_text_func.
    operation: Literal['uppercase', 'lowercase', 'reverse', 'word_count']
    # When omitted, process_text_func falls back to TextProcessOptions() defaults.
    options: Optional[TextProcessOptions] = Field(default=None)
def process_text_func(input_data: TextProcessInput):
    """Pre-process ``input_data.text`` and apply one text operation.

    Pre-processing comes from ``input_data.options``; when it is missing,
    ``TextProcessOptions()`` is used. ``trim`` strips surrounding whitespace and
    ``remove_spaces`` deletes all whitespace.

    Operations:
    - ``uppercase`` / ``lowercase``: change the case.
    - ``reverse``: reverse the character order.
    - ``word_count``: count whitespace-separated words (an int).

    Returns a tool result dict. An operation outside that set returns
    ``{"success": False, ..., "error_type": "validation"}`` instead of raising
    UnboundLocalError.
    """
    processed_text = input_data.text
    options = input_data.options or TextProcessOptions()
    if options.trim:
        processed_text = processed_text.strip()
    if options.remove_spaces:
        processed_text = ''.join(processed_text.split())

    operation = input_data.operation
    if operation == 'uppercase':
        result = processed_text.upper()
    elif operation == 'lowercase':
        result = processed_text.lower()
    elif operation == 'reverse':
        result = processed_text[::-1]
    elif operation == 'word_count':
        result = len(processed_text.split())
    else:
        # The schema restricts `operation`, but the function may be called directly.
        return {
            "success": False,
            "error": f"Unsupported operation: {operation!r}",
            "error_type": "validation"
        }

    return {
        "success": True,
        "data": {
            "original": input_data.text,
            "processed": result,
            "operation": operation
        }
    }
# Expose process_text_func to agents under the name "process_text".
text_process_tool = tool(
    name="process_text",
    description="Process text with various operations",
    schema=TextProcessInput,
    func=process_text_func,
)
异步操作
HTTP请求工具
创建发起HTTP请求的工具:
import aiohttp
import asyncio
from typing import Optional, Dict, Any, Literal
from pydantic import BaseModel, Field, HttpUrl
class HttpRequestInput(BaseModel):
    """Schema for the http_request tool."""

    url: HttpUrl
    method: Literal['GET', 'POST', 'PUT', 'DELETE'] = Field(default='GET')
    headers: Optional[Dict[str, str]] = Field(default=None)  # merged over the default headers
    body: Optional[Any] = Field(default=None)  # sent as the JSON request body
    timeout: int = Field(30000, description="Timeout in milliseconds")
async def http_request_func(input_data: HttpRequestInput):
    """Send an HTTP request and return the response as a tool result dict.

    Request details:
    - ``timeout`` is given in milliseconds and converted to seconds for aiohttp.
    - A ``Content-Type: application/json`` header is sent by default; entries
      in ``headers`` override it.
    - The body, when it is not None, is sent JSON-encoded.

    Returns ``{"success": True, "data": {"status", "headers", "body"}}`` for a
    2xx response. The response body is decoded as JSON when possible and
    returned as text otherwise. Non-2xx responses and timeouts produce
    ``error_type="network"``; any other exception produces ``"execution"``.
    """
    try:
        timeout = aiohttp.ClientTimeout(total=input_data.timeout / 1000)
        headers = {
            'Content-Type': 'application/json',
            **(input_data.headers or {})
        }
        request_kwargs = {
            'method': input_data.method,
            'url': str(input_data.url),
            'headers': headers
        }
        # Compare against None so that falsy JSON bodies ({}, [], 0, False, "")
        # are still sent.
        if input_data.body is not None:
            request_kwargs['json'] = input_data.body

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.request(**request_kwargs) as response:
                if not response.ok:
                    return {
                        "success": False,
                        "error": f"HTTP {response.status}: {response.reason}",
                        "error_type": "network"
                    }
                try:
                    data = await response.json()
                except (aiohttp.ContentTypeError, ValueError):
                    # Non-JSON content type or malformed JSON: return the raw text.
                    data = await response.text()
                return {
                    "success": True,
                    "data": {
                        "status": response.status,
                        "headers": dict(response.headers),
                        "body": data
                    }
                }
    except asyncio.TimeoutError:
        return {
            "success": False,
            "error": "Request timeout",
            "error_type": "network"
        }
    except Exception as error:
        return {
            "success": False,
            "error": str(error),
            "error_type": "execution"
        }
# Expose http_request_func to agents under the name "http_request".
http_request_tool = tool(
    name="http_request",
    description="Make HTTP requests to external APIs",
    schema=HttpRequestInput,
    func=http_request_func,
)
复杂参数
验证与转换
创建具有复杂验证的工具:
import uuid
from datetime import datetime
from typing import Optional, List, Literal
from pydantic import BaseModel, Field, EmailStr, validator
from pydantic.validators import str_validator
class UserPreferences(BaseModel):
    """Optional per-user UI preferences."""

    theme: Literal['light', 'dark'] = Field(default='light')
    notifications: bool = Field(default=True)
    language: str = Field('en', min_length=2, max_length=2)  # exactly two characters
class UserData(BaseModel):
    """Profile fields supplied when creating or updating a user."""

    email: EmailStr
    name: str = Field(..., min_length=2, max_length=50)
    age: int = Field(..., ge=13, le=120)  # inclusive bounds
    roles: List[Literal['admin', 'user', 'moderator']] = Field(['user'])
    preferences: Optional[UserPreferences] = None
class UserManagementInput(BaseModel):
    """Input for the manage_user tool, with cross-field validation.

    Rules enforced by the validators below:
    - ``user_id``, when given, must parse as a UUID.
    - ``user_id`` is required for update, delete and get.
    - ``user_data`` is required for create and update (both dereference it).
    """

    action: Literal['create', 'update', 'delete', 'get']
    user_id: Optional[str] = Field(None, description="UUID字符串")
    user_data: Optional[UserData] = None

    @validator('user_id')
    def validate_uuid(cls, v):
        """Reject a user_id that is not a valid UUID string."""
        if v is not None:
            try:
                uuid.UUID(v)
            except ValueError:
                raise ValueError('user_id必须是有效的UUID')
        return v

    @validator('user_data', always=True)
    def validate_user_data_required(cls, v, values):
        """Require user_data for the actions that read it."""
        # `values` holds fields declared before this one; `action` is missing
        # only if it failed its own validation.
        action = values.get('action')
        if action == 'create' and v is None:
            raise ValueError('创建操作需要user_data')
        if action == 'update' and v is None:
            raise ValueError('更新操作需要user_data')
        return v

    @validator('user_id', always=True)
    def validate_user_id_required(cls, v, values):
        """Require user_id for actions that target an existing user."""
        action = values.get('action')
        if action in ['update', 'delete', 'get'] and v is None:
            raise ValueError('更新/删除/获取操作需要user_id')
        return v
async def user_management_func(input_data: UserManagementInput):
    """Perform a create/update/delete/get action on a user record.

    Persistence is delegated to ``save_user``, ``get_user`` and ``delete_user``.

    Returns a tool result dict: ``{"success": True, "data": ...}`` on success,
    otherwise ``{"success": False, "error": ..., "error_type": ...}``. Any
    exception raised inside is reported with ``error_type="execution"``.
    """
    try:
        action = input_data.action
        user_id = input_data.user_id
        user_data = input_data.user_data

        if action == 'create':
            # user_data is dereferenced below. Report its absence as a
            # validation error rather than an opaque AttributeError.
            if user_data is None:
                return {
                    "success": False,
                    "error": "创建操作需要user_data",
                    "error_type": "validation"
                }
            new_user = {
                'id': str(uuid.uuid4()),
                **user_data.dict(),
                'created_at': datetime.now().isoformat()
            }
            # Simulated database write
            await save_user(new_user)
            return {
                "success": True,
                "data": {
                    "user": new_user,
                    "message": "用户创建成功"
                }
            }
        elif action == 'update':
            # Same guard as create: the merge below needs user_data.
            if user_data is None:
                return {
                    "success": False,
                    "error": "更新操作需要user_data",
                    "error_type": "validation"
                }
            existing_user = await get_user(user_id)
            if not existing_user:
                return {
                    "success": False,
                    "error": "用户未找到",
                    "error_type": "execution"
                }
            # Only fields explicitly set on user_data overwrite the stored record.
            updated_user = {
                **existing_user,
                **user_data.dict(exclude_unset=True),
                'updated_at': datetime.now().isoformat()
            }
            await save_user(updated_user)
            return {
                "success": True,
                "data": {
                    "user": updated_user,
                    "message": "用户更新成功"
                }
            }
        elif action == 'delete':
            await delete_user(user_id)
            return {
                "success": True,
                "data": {"message": "用户删除成功"}
            }
        elif action == 'get':
            user = await get_user(user_id)
            if not user:
                return {
                    "success": False,
                    "error": "用户未找到",
                    "error_type": "execution"
                }
            return {
                "success": True,
                "data": {"user": user}
            }
        else:
            return {
                "success": False,
                "error": "无效操作",
                "error_type": "validation"
            }
    except Exception as error:
        return {
            "success": False,
            "error": str(error),
            "error_type": "execution"
        }
# Simulated database helpers: placeholders that do nothing and return None.
async def save_user(user):
    """Persist ``user`` (placeholder: not implemented)."""
    return None


async def get_user(user_id: str):
    """Look up a user by id (placeholder: always returns None)."""
    return None


async def delete_user(user_id: str):
    """Delete a user by id (placeholder: not implemented)."""
    return None
# Expose user_management_func to agents under the name "manage_user".
user_management_tool = tool(
    name="manage_user",
    description="通过验证管理用户账户",
    schema=UserManagementInput,
    func=user_management_func,
)
错误处理
全面的错误处理
实现健壮的错误处理模式:
import asyncio
import aiohttp
from typing import Optional
from pydantic import BaseModel, Field, HttpUrl
class RobustApiInput(BaseModel):
    """Schema for robust_api_call: target URL plus retry/backoff settings."""

    endpoint: HttpUrl
    retries: int = Field(3, ge=0, le=5)  # extra attempts after the first one
    backoff_ms: int = Field(1000, ge=100, le=10000)  # base delay, doubled per attempt
class _HttpStatusError(Exception):
    """Non-2xx HTTP response; keeps the status code for retry decisions."""

    def __init__(self, status, reason):
        super().__init__(f"HTTP {status}: {reason}")
        self.status = status


# Statuses for which retrying cannot help, so the loop stops immediately.
_NON_RETRYABLE_STATUSES = frozenset({401, 404})


async def robust_api_func(input_data: RobustApiInput):
    """GET ``input_data.endpoint`` as JSON, retrying transient failures.

    Makes up to ``retries + 1`` attempts. Before retry number k (0-based attempt
    index), it sleeps ``backoff_ms * 2**k`` milliseconds. 401 and 404 responses
    end the loop at once.

    Returns ``{"success": True, "data": {"result", "attempts", "endpoint"}}`` on
    success. Otherwise it returns a ``"network"`` error that reports the number
    of attempts actually made and the last error.
    """
    last_error: Optional[Exception] = None
    attempts = 0
    for attempt in range(input_data.retries + 1):
        attempts = attempt + 1
        try:
            timeout = aiohttp.ClientTimeout(total=10)
            headers = {'User-Agent': 'AG-Kit Tool/1.0'}
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(str(input_data.endpoint), headers=headers) as response:
                    if not response.ok:
                        raise _HttpStatusError(response.status, response.reason)
                    data = await response.json()
                    return {
                        "success": True,
                        "data": {
                            "result": data,
                            "attempts": attempts,
                            "endpoint": str(input_data.endpoint)
                        }
                    }
        except Exception as error:
            last_error = error
            # Decide on the status code itself rather than searching the message
            # text, which could contain "404"/"401" for unrelated reasons.
            if isinstance(error, _HttpStatusError) and error.status in _NON_RETRYABLE_STATUSES:
                break
            # Exponential backoff before the next attempt (none after the last).
            if attempt < input_data.retries:
                delay = input_data.backoff_ms * (2 ** attempt) / 1000
                await asyncio.sleep(delay)
    return {
        "success": False,
        "error": f"经过{attempts}次尝试后失败: {str(last_error)}",
        "error_type": "network"
    }
# Expose robust_api_func to agents under the name "robust_api_call".
robust_api_tool = tool(
    func=robust_api_func,
    name="robust_api_call",
    description="进行具备全面错误处理的应用程序接口调用",
    schema=RobustApiInput
)
输入净化
对输入进行净化和验证:
import re
import html
from pydantic import BaseModel, Field
class ProcessUserInputInput(BaseModel):
    """Schema for process_user_input."""

    user_input: str
    # When False, tags are stripped and the remaining text is HTML-escaped.
    allow_html: bool = False
    # Upper bound on len(user_input), checked before any stripping.
    max_length: int = 1000
async def process_user_input_func(input_data: ProcessUserInputInput):
    """Sanitize free-form user text and reject script-like content.

    Steps:
    1. Reject input longer than ``max_length``, measured on the raw text.
    2. Strip surrounding whitespace.
    3. Unless ``allow_html`` is set, remove ``<...>`` tags and HTML-escape the
       rest (quotes included).
    4. Scan the result for suspicious markers.

    Returns a tool result dict. Rejected input yields
    ``error_type="validation"``; unexpected exceptions yield ``"execution"``.
    """
    try:
        raw = input_data.user_input
        html_allowed = input_data.allow_html
        limit = input_data.max_length

        if len(raw) > limit:
            return {
                "success": False,
                "error": f"输入过长(最多{limit}个字符)",
                "error_type": "validation"
            }

        cleaned = raw.strip()
        if not html_allowed:
            cleaned = html.escape(re.sub(r'<[^>]*>', '', cleaned), quote=True)

        # NOTE: the scan runs on the sanitized text. Once escaped, no '<'
        # remains, so the '<script' marker can only match when allow_html is True.
        markers = (
            r'javascript:',
            r'data:text/html',
            r'vbscript:',
            r'<script',
        )
        if any(re.search(marker, cleaned, re.IGNORECASE) for marker in markers):
            return {
                "success": False,
                "error": "输入包含可疑内容",
                "error_type": "validation"
            }

        return {
            "success": True,
            "data": {
                "original": raw,
                "sanitized": cleaned,
                "length": len(cleaned),
                "was_modified": raw != cleaned
            }
        }
    except Exception as error:
        return {
            "success": False,
            "error": str(error),
            "error_type": "execution"
        }
# Expose process_user_input_func to agents under the name "process_user_input".
sanitized_input_tool = tool(
    name="process_user_input",
    description="对用户输入进行净化处理",
    schema=ProcessUserInputInput,
    func=process_user_input_func,
)
集成模式
工具组合
组合多个功能工具:
import re
from datetime import datetime
from pydantic import BaseModel, EmailStr
# 创建独立工具
class ValidateEmailInput(BaseModel):
    """Schema for validate_email.

    ``email`` is a plain ``str`` on purpose: the tool reports invalid addresses
    in its result instead of failing schema validation.
    """

    email: str
async def validate_email_func(input_data: ValidateEmailInput):
    """Check that ``input_data.email`` looks like ``local@domain.tld``.

    This is a lightweight syntactic check. It requires exactly one ``@``, a dot
    somewhere after it, and no whitespace anywhere. The call always reports
    ``success=True``; the verdict is in ``data["is_valid"]``.
    """
    email = input_data.email
    # fullmatch anchors both ends, so a trailing newline is rejected as well
    # (a `$` anchor would still match just before a final "\n").
    is_valid = re.fullmatch(r'[^\s@]+@[^\s@]+\.[^\s@]+', email) is not None
    return {
        "success": True,
        "data": {"email": email, "is_valid": is_valid}
    }
# Expose validate_email_func to agents under the name "validate_email".
validate_email_tool = tool(
    name="validate_email",
    description="验证电子邮件地址格式",
    schema=ValidateEmailInput,
    func=validate_email_func,
)
class SendEmailInput(BaseModel):
    """Schema for send_email; the recipient must pass pydantic's EmailStr check."""

    to: EmailStr
    subject: str
    body: str
async def send_email_func(input_data: SendEmailInput):
    """Placeholder sender: return a fixed message id and the current local time.

    No email is actually sent, and ``input_data`` is not read.
    """
    # Real delivery logic would go here.
    sent_at = datetime.now().isoformat()
    payload = {"message_id": "msg_123", "sent_at": sent_at}
    return {"success": True, "data": payload}
# Expose send_email_func to agents under the name "send_email".
send_email_tool = tool(
    name="send_email",
    description="发送电子邮件",
    schema=SendEmailInput,
    func=send_email_func,
)
# 组合工具
class EmailWorkflowInput(BaseModel):
    """Schema for email_workflow.

    ``to`` is a plain ``str`` so that invalid addresses reach the workflow's own
    validation step instead of failing here.
    """

    to: str
    subject: str
    body: str
async def email_workflow_func(input_data: EmailWorkflowInput):
    """Validate the recipient address, then send the email.

    Returns a ``"validation"`` error when the address fails either check:
    - validate_email_func's regex, or
    - the stricter EmailStr check performed when building SendEmailInput.

    Otherwise it returns the send result, together with the validation data.
    """
    validation_result = await validate_email_func(ValidateEmailInput(email=input_data.to))
    if not validation_result["success"] or not validation_result["data"]["is_valid"]:
        return {
            "success": False,
            "error": "无效的电子邮件地址",
            "error_type": "validation"
        }

    try:
        send_input = SendEmailInput(
            to=input_data.to,
            subject=input_data.subject,
            body=input_data.body
        )
    except ValueError:
        # The regex above is looser than EmailStr. pydantic's ValidationError
        # subclasses ValueError, so an address EmailStr rejects lands here
        # instead of escaping the tool as an unhandled exception.
        return {
            "success": False,
            "error": "无效的电子邮件地址",
            "error_type": "validation"
        }

    send_result = await send_email_func(send_input)
    return {
        "success": send_result["success"],
        "data": {
            "validation": validation_result["data"],
            "email": send_result.get("data")
        },
        "error": send_result.get("error"),
        "error_type": send_result.get("error_type")
    }
# Expose email_workflow_func to agents under the name "email_workflow".
email_workflow_tool = tool(
    name="email_workflow",
    description="一次性完成电子邮件验证和发送",
    schema=EmailWorkflowInput,
    func=email_workflow_func,
)
与Agent配合使用
将功能工具与AG-Kit agent结合使用。以下示例为数据库查询工具使用连接池，以便复用数据库连接:
import asyncio
from typing import List, Any
from pydantic import BaseModel, Field
# Connection pool shared by the database tools.
class DatabasePool:
    """Keep up to ``max_connections`` idle connections for reuse.

    ``max_connections`` caps how many released connections are kept idle. It
    does not limit how many connections ``get_connection`` may open: a new one
    is created whenever no idle connection is available.
    """

    def __init__(self, max_connections: int = 10):
        self.connections: List[Any] = []
        self.max_connections = max_connections

    async def get_connection(self):
        """Return the most recently released idle connection, or open a new one."""
        if not self.connections:
            return await self.create_new_connection()
        return self.connections.pop()

    def release_connection(self, conn):
        """Put ``conn`` back into the idle pool, or close it if the pool is full."""
        pool_is_full = len(self.connections) >= self.max_connections
        if pool_is_full:
            conn.close()
            return
        self.connections.append(conn)

    async def create_new_connection(self):
        """Open a new connection (placeholder: returns None until implemented)."""
        # Implementation here
        return None


db_pool = DatabasePool()
class OptimizedDbQueryInput(BaseModel):
    """Schema for optimized_db_query: a query string and its bound parameters."""

    query: str
    # Handed to conn.query together with the query; each instance gets its own list.
    params: List[Any] = Field(default_factory=list)
async def optimized_db_query_func(input_data: OptimizedDbQueryInput):
    """Run a query on a pooled connection and return a tool result dict.

    Like the other tools in this module, failures are reported as
    ``{"success": False, "error": ..., "error_type": "execution"}`` rather than
    propagated to the caller.

    NOTE(review): ``query`` comes straight from the tool input. Parameters are
    passed separately, but an agent can still run arbitrary SQL. Restrict the
    connection's privileges accordingly.
    """
    try:
        conn = await db_pool.get_connection()
    except Exception as error:
        return {"success": False, "error": str(error), "error_type": "execution"}
    try:
        result = await conn.query(input_data.query, input_data.params)
        return {"success": True, "data": result}
    except Exception as error:
        return {"success": False, "error": str(error), "error_type": "execution"}
    finally:
        # Always hand the connection back, even when the query failed.
        db_pool.release_connection(conn)
# Expose optimized_db_query_func to agents under the name "optimized_db_query".
optimized_db_tool = tool(
    name="optimized_db_query",
    description="Database query with connection pooling",
    schema=OptimizedDbQueryInput,
    func=optimized_db_query_func,
)
缓存实现
为耗时操作实现缓存:
import time
import aiohttp
from typing import Dict, Any
from pydantic import BaseModel, HttpUrl
# Module-level response cache: URL -> {"data": <decoded JSON>, "timestamp": <ms since epoch>}.
cache: Dict[str, Dict[str, Any]] = dict()
class CachedApiInput(BaseModel):
    """Schema for cached_api_call."""

    url: HttpUrl
    cache_ttl: int = 5 * 60 * 1000  # entry lifetime in milliseconds (5 minutes)
def _with_cache_flag(payload, from_cache):
    """Attach the ``from_cache`` flag to a decoded JSON payload.

    A dict payload is merged with the flag. Any other JSON value (list, string,
    number, None) cannot be ``**``-unpacked, so it is wrapped as
    ``{"result": payload, "from_cache": ...}``.
    """
    if isinstance(payload, dict):
        return {**payload, "from_cache": from_cache}
    return {"result": payload, "from_cache": from_cache}


async def cached_api_func(input_data: CachedApiInput):
    """GET a JSON URL, serving repeat requests from the module-level ``cache``.

    A cached entry is reused while it is younger than ``cache_ttl``
    milliseconds. Only successful responses are cached. The cache is unbounded
    and keyed by URL.

    Returns a tool result dict whose ``data`` carries a ``from_cache`` flag.
    Any failure is reported with ``error_type="network"``.
    """
    url_str = str(input_data.url)
    cached = cache.get(url_str)
    current_time = int(time.time() * 1000)  # wall-clock ms, same unit as cache_ttl

    if cached and current_time - cached["timestamp"] < input_data.cache_ttl:
        return {
            "success": True,
            "data": _with_cache_flag(cached["data"], True)
        }

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url_str) as response:
                if not response.ok:
                    raise Exception(f"HTTP {response.status}: {response.reason}")
                data = await response.json()
        cache[url_str] = {
            "data": data,
            "timestamp": current_time
        }
        return {
            "success": True,
            "data": _with_cache_flag(data, False)
        }
    except Exception as error:
        return {
            "success": False,
            "error": str(error),
            "error_type": "network"
        }
# Expose cached_api_func to agents under the name "cached_api_call".
cached_api_tool = tool(
    name="cached_api_call",
    description="API call with caching",
    schema=CachedApiInput,
    func=cached_api_func,
)