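Turn Python functions into tools that AI agents can call, with automatically generated schemas and preserved type safety. This is the most direct way to create custom tools.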

Simple Functions

Basic Function Tool

Create a tool from a simple function:
from agkit.tools import tool
from pydantic import BaseModel, Field

class AddNumbersInput(BaseModel):
    a: float = Field(description="First number")
    b: float = Field(description="Second number")

add_numbers_tool = tool(
    func=lambda input_data: {
        "success": True,
        "data": {"result": input_data.a + input_data.b}
    },
    name="add_numbers",
    description="Add two numbers together",
    schema=AddNumbersInput
)
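Every example on this page returns the same result dictionary: `{"success": True, "data": ...}` on success, or `{"success": False, "error": ..., "error_type": ...}` on failure, where `error_type` is one of `"validation"`, `"network"`, or `"execution"`. The sketch below wraps that convention in two small helpers. `ok`, `fail`, and `AddNumbersArgs` are hypothetical names, not part of AG-Kit, and a dataclass stands in for the Pydantic model so the snippet runs on its own.

```python
from dataclasses import dataclass
from typing import Any, Dict, Literal

# Hypothetical helpers: they only reproduce the result dictionaries used on this page
ErrorType = Literal["validation", "network", "execution"]


def ok(data: Any) -> Dict[str, Any]:
    """Build a success result."""
    return {"success": True, "data": data}


def fail(error: str, error_type: ErrorType = "execution") -> Dict[str, Any]:
    """Build a failure result."""
    return {"success": False, "error": error, "error_type": error_type}


@dataclass
class AddNumbersArgs:
    # Stand-in for AddNumbersInput so this sketch runs without pydantic
    a: float
    b: float


def add_numbers(input_data: AddNumbersArgs) -> Dict[str, Any]:
    return ok({"result": input_data.a + input_data.b})


print(add_numbers(AddNumbersArgs(a=2, b=3)))
print(fail("Division by zero", "validation"))
```

Keeping the shape in one place makes it harder for individual tools to drift, for example by forgetting `error_type` on a failure path.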

Text Processing Tool

Create a text-processing tool:
from typing import Optional, Literal
from pydantic import BaseModel, Field
from agkit.tools import tool

class TextProcessOptions(BaseModel):
    trim: bool = Field(default=True)
    remove_spaces: bool = Field(default=False)

class TextProcessInput(BaseModel):
    text: str = Field(description="Input text")
    operation: Literal['uppercase', 'lowercase', 'reverse', 'word_count']
    options: Optional[TextProcessOptions] = None

def process_text_func(input_data: TextProcessInput):
    processed_text = input_data.text
    options = input_data.options or TextProcessOptions()

    if options.trim:
        processed_text = processed_text.strip()

    if options.remove_spaces:
        processed_text = ''.join(processed_text.split())

    if input_data.operation == 'uppercase':
        result = processed_text.upper()
    elif input_data.operation == 'lowercase':
        result = processed_text.lower()
    elif input_data.operation == 'reverse':
        result = processed_text[::-1]
    elif input_data.operation == 'word_count':
        result = len(processed_text.split())

    return {
        "success": True,
        "data": {
            "original": input_data.text,
            "processed": result,
            "operation": input_data.operation
        }
    }

text_process_tool = tool(
    func=process_text_func,
    name="process_text",
    description="Process text with various operations",
    schema=TextProcessInput
)
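The `if`/`elif` chain in `process_text_func` can also be written as a lookup table from operation name to function. This sketch uses plain arguments instead of the Pydantic model so it runs standalone; `OPERATIONS` and `process_text` are illustrative names. An unknown operation becomes an explicit validation error instead of leaving `result` unassigned.

```python
from typing import Callable, Dict, Union

# Each operation name maps to a function that takes the preprocessed text
OPERATIONS: Dict[str, Callable[[str], Union[str, int]]] = {
    "uppercase": str.upper,
    "lowercase": str.lower,
    "reverse": lambda s: s[::-1],
    "word_count": lambda s: len(s.split()),
}


def process_text(text: str, operation: str, trim: bool = True,
                 remove_spaces: bool = False) -> dict:
    # Same preprocessing order as process_text_func: trim, then remove whitespace
    processed = text.strip() if trim else text
    if remove_spaces:
        processed = "".join(processed.split())

    handler = OPERATIONS.get(operation)
    if handler is None:
        return {"success": False, "error": f"Unknown operation: {operation}",
                "error_type": "validation"}

    return {"success": True,
            "data": {"original": text, "processed": handler(processed),
                     "operation": operation}}


print(process_text("  Hello World  ", "word_count"))
```

Adding an operation then means adding one dictionary entry and one value to the `Literal` in the input model.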

Async Operations

HTTP Request Tool

Create a tool that makes HTTP requests:
import aiohttp
import asyncio
from typing import Optional, Dict, Any, Literal
from pydantic import BaseModel, Field, HttpUrl
from agkit.tools import tool

class HttpRequestInput(BaseModel):
    url: HttpUrl
    method: Literal['GET', 'POST', 'PUT', 'DELETE'] = 'GET'
    headers: Optional[Dict[str, str]] = None
    body: Optional[Any] = None
    timeout: int = Field(default=30000, description="Timeout in milliseconds")

async def http_request_func(input_data: HttpRequestInput):
    try:
        timeout_seconds = input_data.timeout / 1000
        timeout = aiohttp.ClientTimeout(total=timeout_seconds)

        headers = {
            'Content-Type': 'application/json',
            **(input_data.headers or {})
        }

        async with aiohttp.ClientSession(timeout=timeout) as session:
            kwargs = {
                'method': input_data.method,
                'url': str(input_data.url),
                'headers': headers
            }

            # "is not None" so falsy bodies such as {} or 0 are still sent
            if input_data.body is not None:
                kwargs['json'] = input_data.body

            async with session.request(**kwargs) as response:
                if not response.ok:
                    return {
                        "success": False,
                        "error": f"HTTP {response.status}: {response.reason}",
                        "error_type": "network"
                    }

                data = await response.json()

                return {
                    "success": True,
                    "data": {
                        "status": response.status,
                        "headers": dict(response.headers),
                        "body": data
                    }
                }
    except asyncio.TimeoutError:
        return {
            "success": False,
            "error": "Request timeout",
            "error_type": "network"
        }
    except Exception as error:
        return {
            "success": False,
            "error": str(error),
            "error_type": "execution"
        }

http_request_tool = tool(
    func=http_request_func,
    name="http_request",
    description="Make HTTP requests to external APIs",
    schema=HttpRequestInput
)
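The tool above converts `timeout` from milliseconds to seconds and reports a timeout as a `network` error. The same mapping can be exercised without any network access using `asyncio.wait_for`. In this sketch `slow_call` is a hypothetical stand-in for the HTTP request and `call_with_timeout` is an illustrative name.

```python
import asyncio
from typing import Any, Awaitable, Dict


async def slow_call(delay_s: float) -> Dict[str, Any]:
    # Hypothetical stand-in for an HTTP request that takes delay_s seconds
    await asyncio.sleep(delay_s)
    return {"status": 200}


async def call_with_timeout(coro: Awaitable[Dict[str, Any]], timeout_ms: int) -> Dict[str, Any]:
    try:
        # Same unit conversion as http_request_func: milliseconds -> seconds
        data = await asyncio.wait_for(coro, timeout=timeout_ms / 1000)
        return {"success": True, "data": data}
    except asyncio.TimeoutError:
        # wait_for cancels the slow coroutine before raising
        return {"success": False, "error": "Request timeout", "error_type": "network"}


print(asyncio.run(call_with_timeout(slow_call(0.01), timeout_ms=1000)))
print(asyncio.run(call_with_timeout(slow_call(1.0), timeout_ms=50)))
```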

Complex Parameters

Validation and Transformation

Create a tool with complex validation:
import uuid
from datetime import datetime
from typing import Optional, List, Literal
# Uses Pydantic v1-style validators; EmailStr requires the email-validator package
from pydantic import BaseModel, Field, EmailStr, validator
from agkit.tools import tool

class UserPreferences(BaseModel):
    theme: Literal['light', 'dark'] = 'light'
    notifications: bool = True
    language: str = Field(default='en', min_length=2, max_length=2)

class UserData(BaseModel):
    email: EmailStr
    name: str = Field(min_length=2, max_length=50)
    age: int = Field(ge=13, le=120)
    roles: List[Literal['admin', 'user', 'moderator']] = Field(default=['user'])
    preferences: Optional[UserPreferences] = None

class UserManagementInput(BaseModel):
    action: Literal['create', 'update', 'delete', 'get']
    user_id: Optional[str] = Field(None, description="UUID string")
    user_data: Optional[UserData] = None

    @validator('user_id')
    def validate_uuid(cls, v):
        if v is not None:
            try:
                uuid.UUID(v)
            except ValueError:
                raise ValueError('user_id must be a valid UUID')
        return v

    @validator('user_data', always=True)
    def validate_user_data_required(cls, v, values):
        action = values.get('action')
        if action == 'create' and v is None:
            raise ValueError('user_data is required for create')
        return v

    @validator('user_id', always=True)
    def validate_user_id_required(cls, v, values):
        action = values.get('action')
        if action in ['update', 'delete', 'get'] and v is None:
            raise ValueError('user_id is required for update/delete/get')
        return v

async def user_management_func(input_data: UserManagementInput):
    try:
        action = input_data.action
        user_id = input_data.user_id
        user_data = input_data.user_data

        if action == 'create':
            new_user = {
                'id': str(uuid.uuid4()),
                **user_data.dict(),
                'created_at': datetime.now().isoformat()
            }
            
            # Simulated database save
            await save_user(new_user)
            
            return {
                "success": True,
                "data": {
                    "user": new_user,
                    "message": "User created successfully"
                }
            }
            
        elif action == 'update':
            existing_user = await get_user(user_id)
            if not existing_user:
                return {
                    "success": False,
                    "error": "User not found",
                    "error_type": "execution"
                }
            
            updated_user = {
                **existing_user,
                # user_data is optional for updates; merge it only when provided
                **(user_data.dict(exclude_unset=True) if user_data else {}),
                'updated_at': datetime.now().isoformat()
            }
            
            await save_user(updated_user)
            
            return {
                "success": True,
                "data": {
                    "user": updated_user,
                    "message": "User updated successfully"
                }
            }
            
        elif action == 'delete':
            await delete_user(user_id)
            return {
                "success": True,
                "data": {"message": "User deleted successfully"}
            }
            
        elif action == 'get':
            user = await get_user(user_id)
            if not user:
                return {
                    "success": False,
                    "error": "User not found",
                    "error_type": "execution"
                }
            
            return {
                "success": True,
                "data": {"user": user}
            }
            
        else:
            return {
                "success": False,
                "error": "Invalid action",
                "error_type": "validation"
            }
            
    except Exception as error:
        return {
            "success": False,
            "error": str(error),
            "error_type": "execution"
        }

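# Mock database functions (placeholders; replace with real persistence)
async def save_user(user):
    # Implementation goes here
    pass

async def get_user(user_id: str):
    # Implementation goes here
    pass

async def delete_user(user_id: str):
    # Implementation goes here
    pass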

user_management_tool = tool(
    func=user_management_func,
    name="manage_user",
    description="Manage user accounts with validation",
    schema=UserManagementInput
)
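The two "required" validators encode a single rule: which optional fields each `action` needs. Written as a table in plain Python (no Pydantic), the rule is easier to read and extend. `REQUIRED_FIELDS` and `missing_fields` are illustrative names, and the table only mirrors the validators above, where `create` requires `user_data` and the other actions require `user_id`.

```python
from typing import Any, Dict, List, Optional

# Which optional fields each action needs; mirrors the validators in UserManagementInput
REQUIRED_FIELDS: Dict[str, List[str]] = {
    "create": ["user_data"],
    "update": ["user_id"],
    "delete": ["user_id"],
    "get": ["user_id"],
}


def missing_fields(action: str, user_id: Optional[str] = None,
                   user_data: Optional[Dict[str, Any]] = None) -> List[str]:
    """Return the names of required fields that were not provided."""
    if action not in REQUIRED_FIELDS:
        raise ValueError(f"Unknown action: {action}")
    provided = {"user_id": user_id, "user_data": user_data}
    return [name for name in REQUIRED_FIELDS[action] if provided[name] is None]


print(missing_fields("create"))
print(missing_fields("get", user_id="123e4567-e89b-12d3-a456-426614174000"))
```

A model-level validator could call a function like this once and raise a single error listing every missing field, instead of one validator per field.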

Error Handling

Comprehensive Error Handling

Implement robust error-handling patterns:
import asyncio
import aiohttp
from typing import Optional
from pydantic import BaseModel, Field, HttpUrl
from agkit.tools import tool

class RobustApiInput(BaseModel):
    endpoint: HttpUrl
    retries: int = Field(default=3, ge=0, le=5)
    backoff_ms: int = Field(default=1000, ge=100, le=10000)

async def robust_api_func(input_data: RobustApiInput):
    last_error: Optional[Exception] = None
    attempts_made = 0

    for attempt in range(input_data.retries + 1):
        attempts_made = attempt + 1
        try:
            timeout = aiohttp.ClientTimeout(total=10)
            headers = {'User-Agent': 'AG-Kit Tool/1.0'}

            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(str(input_data.endpoint), headers=headers) as response:
                    # 401 and 404 will not succeed on retry: stop immediately
                    if response.status in (401, 404):
                        last_error = Exception(f"HTTP {response.status}: {response.reason}")
                        break
                    if not response.ok:
                        raise Exception(f"HTTP {response.status}: {response.reason}")

                    data = await response.json()

                    return {
                        "success": True,
                        "data": {
                            "result": data,
                            "attempts": attempts_made,
                            "endpoint": str(input_data.endpoint)
                        }
                    }

        except Exception as error:
            last_error = error

            # Wait before the next attempt (exponential backoff)
            if attempt < input_data.retries:
                delay = input_data.backoff_ms * (2 ** attempt) / 1000
                await asyncio.sleep(delay)

    return {
        "success": False,
        "error": f"Failed after {attempts_made} attempt(s): {last_error}",
        "error_type": "network"
    }

robust_api_tool = tool(
    func=robust_api_func,
    name="robust_api_call",
    description="Make API calls with comprehensive error handling",
    schema=RobustApiInput
)
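The retry loop can be factored into a reusable helper. This is a sketch, not an AG-Kit API: `retry_with_backoff` and `NonRetryableError` are hypothetical names, and the `sleep` parameter exists only so the backoff schedule can be checked without actually waiting. With `backoff_ms=1000`, the waits before the second, third, and fourth attempts are 1, 2, and 4 seconds.

```python
import asyncio
from typing import Awaitable, Callable, Dict, List, TypeVar

T = TypeVar("T")


class NonRetryableError(Exception):
    """Raised for failures that retrying cannot fix (e.g. HTTP 401/404)."""


async def retry_with_backoff(
    op: Callable[[], Awaitable[T]],
    retries: int = 3,
    backoff_ms: int = 1000,
    sleep: Callable[[float], Awaitable[None]] = asyncio.sleep,
) -> T:
    for attempt in range(retries + 1):
        try:
            return await op()
        except NonRetryableError:
            raise  # fail fast: no further attempts
        except Exception:
            if attempt == retries:
                raise  # out of attempts: surface the last error
            # Exponential backoff: backoff_ms, then 2x, 4x, ... (converted to seconds)
            await sleep(backoff_ms * (2 ** attempt) / 1000)
    raise ValueError("retries must be >= 0")


delays: List[float] = []


async def fake_sleep(seconds: float) -> None:
    delays.append(seconds)  # record the delay instead of sleeping


calls: Dict[str, int] = {"n": 0}


async def flaky() -> str:
    # Fails twice, then succeeds
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("temporary failure")
    return "ok"


print(asyncio.run(retry_with_backoff(flaky, retries=3, backoff_ms=1000, sleep=fake_sleep)))
print(delays)
```

Signalling "do not retry" with an exception type is less fragile than searching the error message for `'404'`.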

Input Sanitization

Sanitize and validate input:
import re
import html
from pydantic import BaseModel, Field
from agkit.tools import tool

class ProcessUserInputInput(BaseModel):
    user_input: str
    allow_html: bool = False
    max_length: int = 1000

async def process_user_input_func(input_data: ProcessUserInputInput):
    try:
        user_input = input_data.user_input
        allow_html = input_data.allow_html
        max_length = input_data.max_length
        
        # Length validation
        if len(user_input) > max_length:
            return {
                "success": False,
                "error": f"Input too long (maximum {max_length} characters)",
                "error_type": "validation"
            }
        
        # Sanitize the input
        sanitized = user_input.strip()
        
        if not allow_html:
            # Remove HTML tags
            sanitized = re.sub(r'<[^>]*>', '', sanitized)
            
            # Escape special characters
            sanitized = html.escape(sanitized, quote=True)
        
        # Check for suspicious patterns
        suspicious_patterns = [
            re.compile(r'javascript:', re.IGNORECASE),
            re.compile(r'data:text/html', re.IGNORECASE),
            re.compile(r'vbscript:', re.IGNORECASE),
            re.compile(r'<script', re.IGNORECASE)
        ]
        
        has_suspicious_content = any(
            pattern.search(sanitized) for pattern in suspicious_patterns
        )
        
        if has_suspicious_content:
            return {
                "success": False,
                "error": "Input contains suspicious content",
                "error_type": "validation"
            }
        
        return {
            "success": True,
            "data": {
                "original": user_input,
                "sanitized": sanitized,
                "length": len(sanitized),
                "was_modified": user_input != sanitized
            }
        }
        
    except Exception as error:
        return {
            "success": False,
            "error": str(error),
            "error_type": "execution"
        }

sanitized_input_tool = tool(
    func=process_user_input_func,
    name="process_user_input",
    description="Sanitize user input",
    schema=ProcessUserInputInput
)
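To see what each step does, the sanitization logic of `process_user_input_func` is pulled into two standalone functions here (`sanitize` and `is_suspicious` are hypothetical names). Note that stripping tags keeps the text between them, and once the text has been stripped and escaped, the `<script` pattern can only match when `allow_html=True`.

```python
import html
import re

TAG_RE = re.compile(r"<[^>]*>")
SUSPICIOUS = [re.compile(p, re.IGNORECASE)
              for p in (r"javascript:", r"data:text/html", r"vbscript:", r"<script")]


def sanitize(text: str, allow_html: bool = False) -> str:
    # Same steps as process_user_input_func: trim, strip tags, escape
    cleaned = text.strip()
    if not allow_html:
        cleaned = TAG_RE.sub("", cleaned)
        cleaned = html.escape(cleaned, quote=True)
    return cleaned


def is_suspicious(text: str) -> bool:
    return any(p.search(text) for p in SUSPICIOUS)


for sample in ["<b>Tom & Jerry</b>", '<script>alert("x")</script>', "javascript:alert(1)"]:
    out = sanitize(sample)
    print(repr(out), is_suspicious(out))
```

Regex tag stripping is a coarse filter; for untrusted HTML that must be rendered, a dedicated HTML sanitizer library is the safer choice.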

Integration Patterns

Tool Composition

Combine several function tools:
import re
from datetime import datetime
from pydantic import BaseModel, EmailStr
from agkit.tools import tool

# Create the individual tools
class ValidateEmailInput(BaseModel):
    email: str

async def validate_email_func(input_data: ValidateEmailInput):
    email = input_data.email
    is_valid = bool(re.match(r'^[^\s@]+@[^\s@]+\.[^\s@]+$', email))
    return {
        "success": True,
        "data": {"email": email, "is_valid": is_valid}
    }

validate_email_tool = tool(
    func=validate_email_func,
    name="validate_email",
    description="Validate email address format",
    schema=ValidateEmailInput
)

class SendEmailInput(BaseModel):
    to: EmailStr
    subject: str
    body: str

async def send_email_func(input_data: SendEmailInput):
    # Email-sending logic goes here (placeholder)
    return {
        "success": True,
        "data": {
            "message_id": "msg_123",
            "sent_at": datetime.now().isoformat()
        }
    }

send_email_tool = tool(
    func=send_email_func,
    name="send_email",
    description="Send an email",
    schema=SendEmailInput
)

# Composite tool
class EmailWorkflowInput(BaseModel):
    to: str
    subject: str
    body: str

async def email_workflow_func(input_data: EmailWorkflowInput):
    # Validate the email address first
    validation_result = await validate_email_func(ValidateEmailInput(email=input_data.to))
    
    if not validation_result["success"] or not validation_result["data"]["is_valid"]:
        return {
            "success": False,
            "error": "Invalid email address",
            "error_type": "validation"
        }
    
    # Send the email
    send_result = await send_email_func(SendEmailInput(
        to=input_data.to,
        subject=input_data.subject,
        body=input_data.body
    ))
    
    return {
        "success": send_result["success"],
        "data": {
            "validation": validation_result["data"],
            "email": send_result.get("data")
        },
        "error": send_result.get("error"),
        "error_type": send_result.get("error_type")
    }

email_workflow_tool = tool(
    func=email_workflow_func,
    name="email_workflow",
    description="Validate and send an email in one step",
    schema=EmailWorkflowInput
)
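`email_workflow_func` hard-codes its two steps. When a workflow grows, the short-circuit logic (stop at the first failed step) can be written once. This is a sketch with hypothetical names `run_pipeline`, `validate`, and `send`; each step receives a shared context dictionary and returns the same result dictionary used throughout this page.

```python
import asyncio
import re
from typing import Any, Awaitable, Callable, Dict, List

Result = Dict[str, Any]
Step = Callable[[Dict[str, Any]], Awaitable[Result]]


async def run_pipeline(steps: List[Step], context: Dict[str, Any]) -> Result:
    """Run steps in order; return the first failure, or all outputs on success."""
    outputs: List[Any] = []
    for step in steps:
        result = await step(context)
        if not result.get("success"):
            return result  # short-circuit: later steps never run
        outputs.append(result.get("data"))
    return {"success": True, "data": outputs}


async def validate(ctx: Dict[str, Any]) -> Result:
    if not re.match(r"^[^\s@]+@[^\s@]+\.[^\s@]+$", ctx["to"]):
        return {"success": False, "error": "Invalid email address", "error_type": "validation"}
    return {"success": True, "data": {"email": ctx["to"], "is_valid": True}}


async def send(ctx: Dict[str, Any]) -> Result:
    # Hypothetical send step; a real one would call an email service
    return {"success": True, "data": {"message_id": "msg_123"}}


print(asyncio.run(run_pipeline([validate, send], {"to": "a@example.com"})))
print(asyncio.run(run_pipeline([validate, send], {"to": "not-an-email"})))
```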

Using with Agents

Use function tools together with AG-Kit agents:
from typing import List, Any
from pydantic import BaseModel, Field
from agkit.tools import tool

# Use a connection pool for database tools
class DatabasePool:
    def __init__(self, max_connections: int = 10):
        self.connections: List[Any] = []
        self.max_connections = max_connections
    
    async def get_connection(self):
        if self.connections:
            return self.connections.pop()
        return await self.create_new_connection()
    
    def release_connection(self, conn):
        if len(self.connections) < self.max_connections:
            self.connections.append(conn)
        else:
            conn.close()
    
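    async def create_new_connection(self):
        # Placeholder: open a connection with your database driver here
        raise NotImplementedError("create_new_connection must be implemented")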

db_pool = DatabasePool()

class OptimizedDbQueryInput(BaseModel):
    query: str
    params: List[Any] = Field(default_factory=list)

async def optimized_db_query_func(input_data: OptimizedDbQueryInput):
    conn = await db_pool.get_connection()
    try:
        # conn.query stands in for your database driver's query method
        result = await conn.query(input_data.query, input_data.params)
        return {"success": True, "data": result}
    finally:
        db_pool.release_connection(conn)

optimized_db_tool = tool(
    func=optimized_db_query_func,
    name="optimized_db_query",
    description="Database query with connection pooling",
    schema=OptimizedDbQueryInput
)
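`DatabasePool` relies on every caller remembering the `try`/`finally`. An alternative sketch exposes the pool as an async context manager via `contextlib.asynccontextmanager`, so release happens automatically. `Pool`, `connection`, and `make_conn` are hypothetical names, and an `int` stands in for a real connection object. Like the original, this caps only idle connections, not connections in use; limiting concurrency would need something like an `asyncio.Semaphore`.

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Awaitable, Callable, Generic, List, Optional, TypeVar

C = TypeVar("C")


class Pool(Generic[C]):
    """Reuses idle connections and keeps at most max_idle of them around."""

    def __init__(self, factory: Callable[[], Awaitable[C]], max_idle: int = 10,
                 close: Optional[Callable[[C], None]] = None):
        self._factory = factory  # opens a new connection
        self._close = close      # closes a surplus connection, if given
        self._idle: List[C] = []
        self._max_idle = max_idle

    @asynccontextmanager
    async def connection(self) -> AsyncIterator[C]:
        conn = self._idle.pop() if self._idle else await self._factory()
        try:
            yield conn
        finally:
            # Runs even when the body raises, so a failed query never leaks the connection
            if len(self._idle) < self._max_idle:
                self._idle.append(conn)
            elif self._close is not None:
                self._close(conn)


created = 0


async def make_conn() -> int:
    # Hypothetical factory: an int stands in for a real connection object
    global created
    created += 1
    return created


async def main() -> None:
    pool: Pool[int] = Pool(make_conn, max_idle=2)
    async with pool.connection() as first:
        pass
    async with pool.connection() as second:
        print(first, second)  # the idle connection is reused


asyncio.run(main())
```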

Caching

Implement caching for expensive operations:
import time
import aiohttp
from typing import Dict, Any
from pydantic import BaseModel, HttpUrl
from agkit.tools import tool

# Global in-memory cache (per process)
cache: Dict[str, Dict[str, Any]] = {}

class CachedApiInput(BaseModel):
    url: HttpUrl
    cache_ttl: int = 300000  # 5 minutes (in milliseconds)

async def cached_api_func(input_data: CachedApiInput):
    url_str = str(input_data.url)
    cached = cache.get(url_str)
    
    current_time = int(time.time() * 1000)  # current time in milliseconds
    
    if cached and current_time - cached["timestamp"] < input_data.cache_ttl:
        return {
            "success": True,
            # Wrap the payload so JSON arrays and scalars work, not only objects
            "data": {"result": cached["data"], "from_cache": True}
        }
    
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url_str) as response:
                if not response.ok:
                    raise Exception(f"HTTP {response.status}: {response.reason}")
                
                data = await response.json()
                
                cache[url_str] = {
                    "data": data,
                    "timestamp": current_time
                }
                
                return {
                    "success": True,
                    "data": {"result": data, "from_cache": False}
                }
    except Exception as error:
        return {
            "success": False,
            "error": str(error),
            "error_type": "network"
        }

cached_api_tool = tool(
    func=cached_api_func,
    name="cached_api_call",
    description="API call with caching",
    schema=CachedApiInput
)
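The module-level `cache` dictionary never evicts entries and relies on wall-clock time. A minimal alternative sketch is a small TTL cache class with an injectable clock. `TTLCache` is a hypothetical name; `time.monotonic` is used because it is not affected by system clock changes, and the URL is a placeholder. In this sketch `get` returns `None` on a miss, so a cached `None` value cannot be told apart from a miss.

```python
import time
from typing import Any, Callable, Dict, Optional, Tuple


class TTLCache:
    """In-memory cache whose entries expire ttl_ms milliseconds after being stored."""

    def __init__(self, ttl_ms: int, clock: Callable[[], float] = time.monotonic):
        self._ttl_s = ttl_ms / 1000
        self._clock = clock  # injectable so tests can control time
        self._entries: Dict[str, Tuple[float, Any]] = {}

    def get(self, key: str) -> Optional[Any]:
        entry = self._entries.get(key)
        if entry is None:
            return None
        stored_at, value = entry
        if self._clock() - stored_at >= self._ttl_s:
            del self._entries[key]  # expired: evict on read
            return None
        return value

    def set(self, key: str, value: Any) -> None:
        self._entries[key] = (self._clock(), value)


now = [0.0]  # fake clock, in seconds
cache = TTLCache(ttl_ms=300_000, clock=lambda: now[0])
cache.set("https://api.example.com/items", {"items": [1, 2]})

now[0] = 299.0
print(cache.get("https://api.example.com/items"))  # still fresh
now[0] = 300.0
print(cache.get("https://api.example.com/items"))  # expired
```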

Next Steps