跳转到主要内容
AG-Kit 提供了一个强大的框架,用于创建扩展AI代理能力的自定义工具。通过完整的类型安全和无缝集成,为您的特定用例构建专业化工具。

工具创建模式

函数工具

从Python函数创建工具:

工具包

将相关工具组织成可重用的工具包:

快速入门

基础工具创建

创建一个简单的自定义工具:
from agkit.tools import tool
from pydantic import BaseModel, Field
from typing import Literal
import aiohttp

# Input schema for the weather tool (passed as `schema=` to tool()).
class WeatherInput(BaseModel):
    # City to look up. min_length=1 makes an empty string fail schema
    # validation; the unit tests expect error_type == "validation" for it.
    city: str = Field(description="城市名称", min_length=1)
    # Temperature unit forwarded to the weather API.
    units: Literal['celsius', 'fahrenheit'] = Field(default='celsius')

async def get_weather_func(input_data: WeatherInput):
    """Fetch current weather for ``input_data.city`` and wrap it in a result dict.

    Returns ``{"success": True, "data": {...}}`` on success. On failure returns
    ``{"success": False, "error": <message>, "error_type": <category>}``:

    * ``"network"`` for non-2xx responses and aiohttp client/transport errors
      (including timeouts);
    * ``"execution"`` for anything else, e.g. a payload missing an expected key.
    """
    import asyncio  # for asyncio.TimeoutError, raised by aiohttp on timeouts

    url = "https://api.weather.com/v1/current"
    # Let aiohttp build and percent-encode the query string. Interpolating the
    # city into the URL by hand breaks on spaces, '&', '#', non-ASCII, etc.
    params = {"city": input_data.city, "units": input_data.units}
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, params=params) as response:
                if not response.ok:
                    return {
                        "success": False,
                        "error": f"Weather API error: {response.reason}",
                        "error_type": "network"
                    }

                data = await response.json()

                return {
                    "success": True,
                    "data": {
                        "city": input_data.city,
                        "temperature": data["temperature"],
                        "condition": data["condition"],
                        "humidity": data["humidity"],
                        "units": input_data.units
                    }
                }
    except (aiohttp.ClientError, asyncio.TimeoutError) as error:
        # Transport-level failures belong to the "network" category.
        # str(TimeoutError()) is empty, so fall back to the exception name.
        return {
            "success": False,
            "error": str(error) or type(error).__name__,
            "error_type": "network"
        }
    except Exception as error:
        # Tool boundary: report any other failure as a result instead of raising.
        return {
            "success": False,
            "error": str(error) or type(error).__name__,
            "error_type": "execution"
        }

# Wrap get_weather_func as a tool named "get_weather"; arguments are
# validated against WeatherInput before the function is called.
weather_tool = tool(
    name="get_weather",
    description="获取城市的当前天气信息",
    schema=WeatherInput,
    func=get_weather_func,
)

工具集成

将自定义工具与代理一起使用(示例使用顶层 `await`,需在异步上下文中运行,例如 Jupyter 或 `python -m asyncio`):
from agkit.core import Agent
from agkit.providers.openai import OpenAIProvider
import os

provider = OpenAIProvider(
    api_key=os.getenv("OPENAI_API_KEY"),
    model="gpt-4"
)

agent = Agent(
    name="weather-agent",
    model=provider,
    tools=[weather_tool],
    instructions="您可以查询任何城市的天气信息。"
)

response = await agent.run(
    input="旧金山的天气如何?"
)

工具架构

BaseTool接口

所有工具都实现了标准化接口:
from typing import Any, Generic, TypeVar, Protocol
from pydantic import BaseModel
from abc import ABC, abstractmethod

# Type parameters: TInput is what invoke() accepts; TOutput is the payload
# type carried by the ToolResult it returns.
TInput = TypeVar('TInput')
TOutput = TypeVar('TOutput')

class BaseTool(ABC, Generic[TInput, TOutput]):
    """Abstract interface shared by all tools.

    Concrete tools declare a ``name``, a ``description`` and a pydantic
    ``schema`` class for their input, and implement the coroutine
    :meth:`invoke`.
    """

    name: str
    description: str
    schema: type[BaseModel]

    @abstractmethod
    async def invoke(self, input_data: TInput) -> 'ToolResult[TOutput]':
        """Execute the tool on ``input_data`` and return a ``ToolResult``."""
        ...

工具结果结构

所有工具的一致结果格式:
from typing import Generic, TypeVar, Optional, Literal
from dataclasses import dataclass

T = TypeVar('T')

# Closed set of failure categories a tool can report.
_ErrorType = Literal["validation", "execution", "permission", "network"]

@dataclass
class ToolResult(Generic[T]):
    """Uniform envelope for the outcome of a tool call.

    ``success`` tells whether the call worked. On success ``data`` holds the
    payload; on failure ``error`` holds a message and ``error_type`` one of the
    categories above. ``execution_time`` is optional (its unit is not fixed
    here). Every field except ``success`` defaults to ``None``.
    """

    success: bool
    data: Optional[T] = None
    error: Optional[str] = None
    error_type: Optional[_ErrorType] = None
    execution_time: Optional[float] = None

模式验证

使用Pydantic模型进行输入验证:
from pydantic import BaseModel, Field, UUID4
from typing import Optional, Literal, Union, Dict, Any, List
from datetime import datetime
from enum import Enum

# String-valued enum of task priorities. Subclassing str makes each member
# compare equal to its plain string value (Priority.LOW == "low").
class Priority(str, Enum):
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"

# Nested model for ComplexSchema.metadata: creation time and author.
class Metadata(BaseModel):
    created: datetime
    author: str

# "simple" variant of ComplexSchema.config; the Literal `type` field is the
# tag that distinguishes it from ComplexConfig.
class SimpleConfig(BaseModel):
    type: Literal["simple"]
    value: str

# "complex" variant of ComplexSchema.config; `settings` is a free-form mapping
# from string keys to arbitrary values.
class ComplexConfig(BaseModel):
    type: Literal["complex"]
    settings: Dict[str, Any]

class ComplexSchema(BaseModel):
    # Required fields
    id: UUID4
    name: str = Field(min_length=1, max_length=100)

    # Optional fields with defaults; default_factory gives every instance its
    # own list instead of a shared one.
    priority: Priority = Priority.MEDIUM
    tags: List[str] = Field(default_factory=list)

    # Nested object
    metadata: Optional[Metadata] = None

    # Conditional validation as a tagged (discriminated) union. The literal
    # "type" value selects SimpleConfig or ComplexConfig directly instead of
    # trying each member in turn, and validation errors then refer only to
    # the selected variant.
    config: Union[SimpleConfig, ComplexConfig] = Field(discriminator="type")

工具包架构

自定义工具包

创建自定义工具包以组织相关工具:
from agkit.tools import BaseToolkit, tool
from pydantic import BaseModel, Field
from typing import Literal

# Argument schemas for the toolkit's tools. A tool schema must be a BaseModel
# *subclass*; the previous `BaseModel(city=str, units=... = ...)` calls were
# invalid syntax and would have built an instance, not a schema.
class CurrentWeatherInput(BaseModel):
    city: str
    units: Literal['celsius', 'fahrenheit'] = 'celsius'


class ForecastInput(BaseModel):
    city: str
    # Forecast horizon in days, bounded to 1..14.
    days: int = Field(default=5, ge=1, le=14)


class HistoricalWeatherInput(BaseModel):
    city: str
    date: str


class WeatherToolkit(BaseToolkit):
    """Toolkit bundling current-weather, forecast and historical-weather tools.

    The tools are added in :meth:`on_initialize` (presumably invoked by
    ``BaseToolkit.initialize()``). The tool functions are placeholders that
    return fixed sample data.
    """

    def __init__(self):
        super().__init__(
            name='weather-toolkit',
            description='Comprehensive weather information toolkit'
        )

    async def on_initialize(self):
        # Add weather tools
        self.add_tool(self.create_current_weather_tool())
        self.add_tool(self.create_forecast_tool())
        self.add_tool(self.create_historical_tool())

    def create_current_weather_tool(self):
        """Build the 'current_weather' tool."""
        # Like get_weather_func, tool functions receive one validated input model.
        def current_weather(input_data: CurrentWeatherInput):
            # Placeholder data; replace with a real weather lookup.
            return {'success': True, 'data': {'temperature': 22, 'condition': 'sunny'}}

        return tool(
            func=current_weather,
            name='current_weather',
            description='Get current weather conditions',
            schema=CurrentWeatherInput
        )

    def create_forecast_tool(self):
        """Build the 'weather_forecast' tool."""
        def weather_forecast(input_data: ForecastInput):
            # Placeholder data; replace with a real forecast lookup.
            return {'success': True, 'data': {'forecast': []}}

        return tool(
            func=weather_forecast,
            name='weather_forecast',
            description='Get weather forecast',
            schema=ForecastInput
        )

    def create_historical_tool(self):
        """Build the 'weather_historical' tool."""
        def weather_historical(input_data: HistoricalWeatherInput):
            # Placeholder data; replace with a real historical lookup.
            return {'success': True, 'data': {'historical': []}}

        return tool(
            func=weather_historical,
            name='weather_historical',
            description='Get historical weather data',
            schema=HistoricalWeatherInput
        )

# Using the toolkit.
# WeatherToolkit adds its tools in on_initialize (presumably triggered by
# initialize()), so initialize() is awaited before get_tools().
# NOTE: top-level `await` needs an async context (e.g. Jupyter or
# `python -m asyncio`).
weather_kit = WeatherToolkit()
await weather_kit.initialize()
tools = weather_kit.get_tools()

使用自定义工具包

初始化自定义工具包并获取其工具,然后可通过 `tools` 参数传给 Agent:

# Create and initialize toolkit.
# NOTE: top-level `await` needs an async context (e.g. Jupyter or
# `python -m asyncio`); in a script, run this inside an `async def`.
weather_toolkit = WeatherToolkit()
await weather_toolkit.initialize()

# Get all tools from toolkit (presumably a list usable as Agent(tools=...)).
weather_tools = weather_toolkit.get_tools()

工具包管理

使用工具包管理器进行集中式工具包管理:
from agkit.tools import ToolkitManager

# NOTE: the `await` calls below need an async context (top-level await works
# in e.g. Jupyter or `python -m asyncio`).
toolkit_manager = ToolkitManager()
# Register toolkit
toolkit_manager.register(weather_toolkit)

# Get toolkit by name
toolkit = toolkit_manager.get_toolkit('weather-toolkit')

# Get all tools from all registered toolkits.
# NOTE(review): these lookups run before initialize_all(). WeatherToolkit adds
# its tools in on_initialize, so they only show up here if weather_toolkit was
# already initialized (as in the previous snippet).
all_tools = toolkit_manager.get_all_tools()

# Find specific tool across all toolkits
current_weather_tools = toolkit_manager.find_tool('current_weather')

# Initialize all registered toolkits
await toolkit_manager.initialize_all()

# Cleanup all toolkits
await toolkit_manager.destroy_all()

工具包事件

监听工具包生命周期事件:
def event_handler(event):
    """Print a one-line message for the four known toolkit lifecycle events.

    Any other ``event.type`` is silently ignored.
    """
    kind = event.type
    if kind == 'toolkit_initialized':
        message = f"Toolkit {event.toolkit.name} initialized"
    elif kind == 'tool_added':
        message = f"Tool {event.tool.name} added"
    elif kind == 'tool_executed':
        message = f"Tool {event.tool_name} executed"
    elif kind == 'toolkit_destroyed':
        message = f"Toolkit {event.toolkit.name} destroyed"
    else:
        return
    print(message)

# Subscribe event_handler to weather_toolkit's lifecycle events.
weather_toolkit.add_event_listener(event_handler)

工具测试

单元测试

全面测试自定义工具:
import pytest

class TestWeatherTool:
    """Behavioral tests for ``weather_tool``.

    NOTE(review): these go through get_weather_func to the real weather API,
    so they need network access to a reachable endpoint. ``@pytest.mark.asyncio``
    requires the pytest-asyncio plugin.
    """

    @pytest.mark.asyncio
    async def test_should_return_weather_data_for_valid_city(self):
        payload = {"city": "San Francisco", "units": "celsius"}
        outcome = await weather_tool.invoke(payload)

        assert outcome.success is True
        data = outcome.data
        assert "temperature" in data
        assert data["city"] == "San Francisco"
        assert data["units"] == "celsius"

    @pytest.mark.asyncio
    async def test_should_handle_invalid_city_gracefully(self):
        # Relies on the API answering an unknown city with a non-2xx status,
        # which get_weather_func reports as error_type "network".
        payload = {"city": "InvalidCity123", "units": "celsius"}
        outcome = await weather_tool.invoke(payload)

        assert outcome.success is False
        assert outcome.error_type == "network"

    @pytest.mark.asyncio
    async def test_should_validate_input_schema(self):
        # Invalid empty city: expected to fail input-schema validation.
        payload = {"city": "", "units": "celsius"}
        outcome = await weather_tool.invoke(payload)

        assert outcome.success is False
        assert outcome.error_type == "validation"

工具包测试

全面测试自定义工具包:
import pytest
from typing import List, Any

class TestWeatherToolkit:
    """Tests for WeatherToolkit: registration, invocation, batching,
    validation and lifecycle events.

    NOTE(review): ``setup_and_teardown`` is an async fixture declared with the
    plain ``@pytest.fixture`` decorator. pytest-asyncio only drives such
    fixtures in "auto" mode (``asyncio_mode = auto``). In its default strict
    mode, decorate the fixture with ``@pytest_asyncio.fixture`` instead.
    """

    @pytest.fixture(autouse=True)
    async def setup_and_teardown(self):
        # A fresh, initialized toolkit for every test, destroyed afterwards.
        self.weather_toolkit = WeatherToolkit()
        await self.weather_toolkit.initialize()
        yield
        await self.weather_toolkit.destroy()

    def test_should_initialize_with_correct_tools(self):
        tool_names = self.weather_toolkit.get_tool_names()
        assert 'current_weather' in tool_names
        assert 'weather_forecast' in tool_names
        # WeatherToolkit.create_historical_tool registers the tool as
        # 'weather_historical' (the old 'historical_weather' never matched).
        assert 'weather_historical' in tool_names
        assert len(self.weather_toolkit.get_tools()) == 3

    @pytest.mark.asyncio
    async def test_should_execute_tools_correctly(self):
        result = await self.weather_toolkit.invoke_tool('current_weather', {
            'city': 'San Francisco',
            'units': 'celsius'
        })

        assert result.success is True
        assert 'temperature' in result.data

    @pytest.mark.asyncio
    async def test_should_handle_batch_tool_execution(self):
        # One {'tool_name', 'input'} request per tool; expects one result each.
        results = await self.weather_toolkit.invoke_tools([
            {'tool_name': 'current_weather', 'input': {'city': 'Tokyo'}},
            {'tool_name': 'weather_forecast', 'input': {'city': 'Tokyo', 'days': 3}}
        ])

        assert len(results) == 2
        assert all(r.success for r in results)

    def test_should_validate_toolkit_integrity(self):
        validation = self.weather_toolkit.validate()
        assert validation.valid is True
        assert len(validation.errors) == 0

    @pytest.mark.asyncio
    async def test_should_emit_events_correctly(self):
        events: List[Any] = []

        def event_handler(event):
            events.append(event)

        # A separate toolkit, so the listener is attached before initialize()
        # and can observe the 'toolkit_initialized' event.
        new_toolkit = WeatherToolkit()
        new_toolkit.add_event_listener(event_handler)

        await new_toolkit.initialize()
        try:
            await new_toolkit.invoke_tool('current_weather', {'city': 'London'})
        finally:
            # Destroy even if the invocation raises, so the toolkit isn't leaked.
            await new_toolkit.destroy()

        assert any(e.type == 'toolkit_initialized' for e in events)
        assert any(e.type == 'tool_executed' for e in events)
        assert any(e.type == 'toolkit_destroyed' for e in events)

性能优化

缓存

为昂贵的操作实现缓存:
from typing import Dict, Any
import time

# Module-level in-process cache used by cached_weather_func. It maps
# "<city>-<units>" to {'data': <weather payload>, 'timestamp': <store time>}.
cache: Dict[str, Dict[str, Any]] = {}

# Input schema for the cached weather tool. The (city, units) pair also forms
# the cache key in cached_weather_func.
class CachedWeatherInput(BaseModel):
    city: str
    units: Literal['celsius', 'fahrenheit'] = 'celsius'

# How long a cached weather entry is served before it is refetched, in seconds.
CACHE_TTL_SECONDS = 300


async def cached_weather_func(input_data: CachedWeatherInput):
    """Return weather for ``input_data``, serving repeats from ``cache``.

    A hit younger than ``CACHE_TTL_SECONDS`` is returned as a copy of the
    cached payload with ``"cached": True`` added. Otherwise the data comes
    from ``fetch_weather_data``. On success, a copy is cached, and the fetch
    result is returned unchanged either way.
    """
    cache_key = f"{input_data.city}-{input_data.units}"
    cached = cache.get(cache_key)

    if cached is not None:
        # time.monotonic() is unaffected by wall-clock jumps (NTP sync, manual
        # changes) that could make time.time()-based entries expire wrongly.
        if time.monotonic() - cached['timestamp'] < CACHE_TTL_SECONDS:
            return {
                "success": True,
                "data": {**cached['data'], "cached": True}
            }
        # Evict the stale entry so a failed refresh doesn't leave it behind.
        cache.pop(cache_key, None)

    result = await fetch_weather_data(input_data.city, input_data.units)

    if result['success']:
        # Cache a shallow copy: result['data'] is handed back to the caller,
        # and mutating it must not corrupt the cached entry.
        cache[cache_key] = {
            'data': dict(result['data']),
            'timestamp': time.monotonic()
        }

    return result

# Expose cached_weather_func as the "get_weather_cached" tool, with arguments
# validated against CachedWeatherInput.
cached_weather_tool = tool(
    name="get_weather_cached",
    description="Get weather with caching",
    schema=CachedWeatherInput,
    func=cached_weather_func,
)

连接池

重用连接以获得更好的性能:
from typing import List, Any
import asyncio

class DatabaseConnectionPool:
    """Minimal LIFO pool of reusable database connections.

    ``max_size`` bounds how many *idle* connections are kept for reuse. It
    does not cap how many connections are open at once: ``get_connection``
    creates a new one whenever no idle connection is available.
    Subclasses must implement :meth:`create_new_connection`.
    """

    def __init__(self, max_size: int = 10):
        """Create an empty pool keeping at most ``max_size`` idle connections.

        Raises:
            ValueError: if ``max_size`` is negative.
        """
        if max_size < 0:
            raise ValueError("max_size must be >= 0")
        self.pool: List[Any] = []
        self.max_size = max_size

    async def get_connection(self):
        """Return an idle connection if one exists, else create a new one."""
        if self.pool:
            # Most recently released connection first (LIFO).
            return self.pool.pop()
        return await self.create_new_connection()

    def release_connection(self, connection: Any):
        """Return ``connection`` for reuse, or close it if the pool is full."""
        if len(self.pool) < self.max_size:
            self.pool.append(connection)
        else:
            connection.close()

    async def create_new_connection(self):
        """Open a new connection; must be overridden for a concrete database.

        Raises:
            NotImplementedError: always, in this base class. Silently returning
                None would hand callers an unusable connection and could later
                be released into the pool.
        """
        # Implementation depends on your database
        raise NotImplementedError("create_new_connection must be implemented for your database")

# Shared module-level pool used by optimized_db_func.
pool = DatabaseConnectionPool()

# Input schema for the pooled database query tool.
class DbQueryInput(BaseModel):
    # Query text, passed as-is to connection.query().
    query: str
    # Values passed alongside the query (presumably bound to placeholders by
    # the driver; confirm). default_factory gives each instance its own list.
    parameters: List[Any] = Field(default_factory=list)

async def optimized_db_func(input_data: DbQueryInput):
    """Run ``input_data.query`` on a connection borrowed from ``pool``.

    The connection goes back to the pool whether or not the query raises, and
    query exceptions propagate to the caller.

    NOTE(review): the query text comes straight from the tool caller (i.e. the
    model) and is executed as-is, so restrict the database account's
    permissions accordingly. A connection whose query failed is also returned
    to the pool; consider discarding it if the driver can leave it unusable.
    """
    conn = await pool.get_connection()
    try:
        rows = await conn.query(input_data.query, input_data.parameters)
    finally:
        pool.release_connection(conn)
    return {"success": True, "data": rows}

# Expose optimized_db_func as the "optimized_db_query" tool, with arguments
# validated against DbQueryInput.
optimized_db_tool = tool(
    name="optimized_db_query",
    description="Database query with connection pooling",
    schema=DbQueryInput,
    func=optimized_db_func,
)

后续步骤