Base Provider API

The base provider module defines the abstract base class and core interfaces that all model providers must implement. This ensures a consistent API across different LLM services while allowing provider-specific optimizations.

BaseModelProvider

Abstract base class for all model providers.

Class Definition

from abc import ABC, abstractmethod
from ag_kit_py.providers import ModelProviderConfig

class BaseModelProvider(ABC):
    """Abstract base class for model providers."""
    
    def __init__(self, config: ModelProviderConfig):
        """Initialize the model provider.
        
        Args:
            config: Provider configuration
        """
        self.config = config
        self.default_model = config.default_model or self.get_default_model()

Abstract Methods

All provider implementations must implement these methods:

create_completion()

Create a non-streaming chat completion.
@abstractmethod
async def create_completion(
    self, 
    params: ChatCompletionParams
) -> ChatCompletion:
    """Create a chat completion.
    
    Args:
        params: Completion parameters including model, messages, temperature, etc.
        
    Returns:
        ChatCompletion: Complete response with message content and metadata
        
    Raises:
        ModelProviderError: If completion fails
    """
    pass
Example Implementation:
import time  # used below for timestamps and fallback IDs

async def create_completion(self, params: ChatCompletionParams) -> ChatCompletion:
    try:
        model = self.get_langchain_model(
            model=params.model,
            temperature=params.temperature
        )
        
        if params.tools:
            formatted_tools = self.format_tools(params.tools)
            model = model.bind_tools(formatted_tools)
        
        response = await model.ainvoke(params.messages)
        
        return ChatCompletion(
            id=f"chatcmpl-{int(time.time())}",
            object="chat.completion",
            created=int(time.time()),
            model=params.model,
            choices=[{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response.content,
                    "tool_calls": getattr(response, "tool_calls", None)
                },
                "finish_reason": "stop"
            }],
            usage={
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            }
        )
    except Exception as e:
        raise self._create_error(str(e), "unknown", e)

create_stream()

Create a streaming chat completion.
@abstractmethod
async def create_stream(
    self, 
    params: ChatCompletionParams
) -> AsyncIterator[ChatCompletionChunk]:
    """Create a streaming chat completion.
    
    Args:
        params: Completion parameters
        
    Yields:
        ChatCompletionChunk: Incremental response chunks
        
    Raises:
        ModelProviderError: If streaming fails
    """
    pass
Example Implementation:
import time  # used below for timestamps and chunk IDs

async def create_stream(
    self, 
    params: ChatCompletionParams
) -> AsyncIterator[ChatCompletionChunk]:
    try:
        model = self.get_langchain_model(
            model=params.model,
            temperature=params.temperature
        )
        
        if params.tools:
            formatted_tools = self.format_tools(params.tools)
            model = model.bind_tools(formatted_tools)
        
        chunk_id = f"chatcmpl-{int(time.time())}"
        
        async for chunk in model.astream(params.messages):
            yield ChatCompletionChunk(
                id=chunk_id,
                object="chat.completion.chunk",
                created=int(time.time()),
                model=params.model,
                choices=[{
                    "index": 0,
                    "delta": {
                        "role": "assistant",
                        "content": chunk.content if hasattr(chunk, "content") else str(chunk)
                    },
                    "finish_reason": None
                }]
            )
        
        # Final chunk with finish_reason
        yield ChatCompletionChunk(
            id=chunk_id,
            object="chat.completion.chunk",
            created=int(time.time()),
            model=params.model,
            choices=[{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        )
    except Exception as e:
        raise self._create_error(str(e), "unknown", e)

get_provider_name()

Get the provider name identifier.
@abstractmethod
def get_provider_name(self) -> str:
    """Get the provider name.
    
    Returns:
        str: Provider name (e.g., "openai", "zhipu")
    """
    pass

get_default_model()

Get the default model for this provider.
@abstractmethod
def get_default_model(self) -> str:
    """Get the default model for this provider.
    
    Returns:
        str: Default model name (e.g., "gpt-4o-mini")
    """
    pass

format_tools()

Format tools for the provider-specific API.
@abstractmethod
def format_tools(self, tools: List[Any]) -> List[ToolDefinition]:
    """Format tools for this provider.
    
    Args:
        tools: List of tools in various formats
        
    Returns:
        List[ToolDefinition]: Formatted tool definitions
    """
    pass
Example Implementation:
def format_tools(self, tools: List[Any]) -> List[ToolDefinition]:
    formatted_tools = []
    
    for tool in tools:
        if isinstance(tool, ToolDefinition):
            # Already a ToolDefinition
            formatted_tools.append(tool)
        elif isinstance(tool, dict):
            if "type" in tool and "function" in tool:
                # Already in OpenAI-style function format
                formatted_tools.append(ToolDefinition(**tool))
            elif "name" in tool and "parameters" in tool:
                # Convert from simplified format
                formatted_tools.append(ToolDefinition(
                    type="function",
                    function={
                        "name": tool["name"],
                        "description": tool.get("description", ""),
                        "parameters": tool["parameters"]
                    }
                ))
    
    return formatted_tools

parse_tool_calls()

Parse tool calls from a completion response.
@abstractmethod
def parse_tool_calls(self, response: ChatCompletion) -> List[ToolCall]:
    """Parse tool calls from completion response.
    
    Args:
        response: Chat completion response
        
    Returns:
        List[ToolCall]: Parsed tool calls
    """
    pass
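Example Implementation (a minimal sketch assuming an OpenAI-style response layout, where tool calls live under each choice's message):
def parse_tool_calls(self, response: ChatCompletion) -> List[ToolCall]:
    tool_calls = []
    for choice in response.choices:
        message = choice.get("message", {})
        for call in message.get("tool_calls") or []:
            # Each raw call is expected to carry an "id" plus a "function"
            # payload with the name and JSON-encoded arguments
            tool_calls.append(ToolCall(
                id=call["id"],
                function=call["function"]
            ))
    return tool_calls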

get_langchain_model()

Get a LangChain-compatible model instance.
@abstractmethod
def get_langchain_model(self, **kwargs) -> Any:
    """Get a LangChain-compatible model instance.
    
    This method returns a LangChain ChatModel instance that can be used
    with LangGraph and other LangChain components.
    
    Args:
        **kwargs: Additional model configuration to override defaults
        
    Returns:
        Any: LangChain ChatModel instance
    """
    pass
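Example Implementation (a sketch for an OpenAI-backed provider; it assumes the langchain-openai package is installed):
def get_langchain_model(self, **kwargs) -> Any:
    from langchain_openai import ChatOpenAI
    
    # Fall back to the provider config for anything not overridden
    return ChatOpenAI(
        model=kwargs.get("model", self.default_model),
        temperature=kwargs.get("temperature", self.config.temperature),
        api_key=self.config.api_key,
        base_url=self.config.base_url,
    )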

Concrete Methods

These methods have default implementations but can be overridden:

supports_tools()

Check if provider supports tool/function calling.
def supports_tools(self) -> bool:
    """Check if provider supports tool/function calling.
    
    Returns:
        bool: True if tools are supported (default: True)
    """
    return True

supports_streaming()

Check if provider supports streaming responses.
def supports_streaming(self) -> bool:
    """Check if provider supports streaming responses.
    
    Returns:
        bool: True if streaming is supported (default: True)
    """
    return True
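Callers can use these capability flags to pick a compatible code path. A usage sketch (provider, params, and handle_chunk are placeholders):
if params.stream and provider.supports_streaming():
    async for chunk in provider.create_stream(params):
        handle_chunk(chunk)
else:
    completion = await provider.create_completion(params)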

validate_config()

Validate provider configuration.
def validate_config(self, config: ModelProviderConfig) -> bool:
    """Validate provider configuration.
    
    Args:
        config: Configuration to validate
        
    Returns:
        bool: True if configuration is valid
    """
    return bool(config.api_key and config.api_key.strip())
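A provider with extra requirements can override this check. A hypothetical override for a self-hosted provider that also needs a base URL:
def validate_config(self, config: ModelProviderConfig) -> bool:
    if not (config.api_key and config.api_key.strip()):
        return False
    # This provider cannot infer an endpoint, so base_url is mandatory
    return bool(config.base_url)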

Protected Methods

These helper methods are available for subclasses:

_create_error()

Create a model provider error.
def _create_error(
    self,
    message: str,
    error_type: str = "unknown",
    details: Optional[Any] = None
) -> ModelProviderError:
    """Create a model provider error.
    
    Args:
        message: Error message
        error_type: Type of error
        details: Additional error details
        
    Returns:
        ModelProviderError: Formatted error
    """
    return ModelProviderError(
        message=message,
        error_type=error_type,
        details=details
    )

_is_non_retryable_error()

Check if an error should not be retried.
def _is_non_retryable_error(self, error: Exception) -> bool:
    """Check if an error should not be retried.
    
    Args:
        error: Error to check
        
    Returns:
        bool: True if error should not be retried
    """
    if isinstance(error, ModelProviderError):
        non_retryable_types = ["authentication", "invalid_request"]
        return error.error_type in non_retryable_types
    return False
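Combined with max_retries and retry_delay from ModelProviderConfig, this check supports a simple retry loop. A sketch of such a wrapper (it is not part of the base class):
import asyncio

async def _completion_with_retries(self, params: ChatCompletionParams) -> ChatCompletion:
    for attempt in range(self.config.max_retries + 1):
        try:
            return await self.create_completion(params)
        except Exception as e:
            if self._is_non_retryable_error(e) or attempt == self.config.max_retries:
                raise
            # retry_delay is specified in milliseconds
            await asyncio.sleep(self.config.retry_delay / 1000)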

Data Models

ModelProviderConfig

Configuration dataclass for model providers.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelProviderConfig:
    """Configuration for model providers."""
    
    api_key: str                          # Required: API key for authentication
    base_url: Optional[str] = None        # Optional: Custom API endpoint
    default_model: Optional[str] = None   # Optional: Default model name
    temperature: float = 0.7              # Optional: Sampling temperature
    timeout: int = 30000                  # Optional: Request timeout (ms)
    max_retries: int = 3                  # Optional: Maximum retry attempts
    retry_delay: int = 1000               # Optional: Delay between retries (ms)
    proxy: Optional[str] = None           # Optional: Proxy URL
    organization: Optional[str] = None    # Optional: Organization ID
    project: Optional[str] = None         # Optional: Project ID
Example:
from ag_kit_py.providers import ModelProviderConfig

config = ModelProviderConfig(
    api_key="sk-...",
    base_url="https://api.openai.com/v1",
    default_model="gpt-4o-mini",
    temperature=0.7,
    timeout=30000,
    max_retries=3
)

ChatCompletionParams

Parameters for chat completion requests.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

@dataclass
class ChatCompletionParams:
    """Parameters for chat completion requests."""
    
    model: str                                          # Required: Model name
    messages: List[Dict[str, Any]]                      # Required: Conversation messages
    temperature: Optional[float] = 0.7                  # Optional: Sampling temperature
    max_tokens: Optional[int] = None                    # Optional: Maximum tokens to generate
    top_p: Optional[float] = None                       # Optional: Nucleus sampling parameter
    frequency_penalty: Optional[float] = None           # Optional: Frequency penalty
    presence_penalty: Optional[float] = None            # Optional: Presence penalty
    tools: Optional[List[ToolDefinition]] = None        # Optional: Available tools
    tool_choice: Optional[Union[str, Dict]] = "auto"    # Optional: Tool choice strategy
    stream: bool = False                                # Optional: Enable streaming
    stop: Optional[List[str]] = None                    # Optional: Stop sequences
Example:
from ag_kit_py.providers import ChatCompletionParams

params = ChatCompletionParams(
    model="gpt-4o-mini",
    messages=[
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hello!"}
    ],
    temperature=0.7,
    max_tokens=1000,
    tools=[
        {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {"type": "object", "properties": {...}}
        }
    ]
)

ChatCompletion

Chat completion response.
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class ChatCompletion:
    """Chat completion response."""
    
    id: str                                    # Unique completion ID
    object: str = "chat.completion"            # Object type
    created: int = 0                           # Creation timestamp
    model: str = ""                            # Model used
    choices: List[Dict[str, Any]] = field(default_factory=list)  # Response choices
    usage: Dict[str, int] = field(default_factory=dict)          # Token usage
Response Structure:
{
    "id": "chatcmpl-123",
    "object": "chat.completion",
    "created": 1677652288,
    "model": "gpt-4o-mini",
    "choices": [{
        "index": 0,
        "message": {
            "role": "assistant",
            "content": "Hello! How can I help you?",
            "tool_calls": None
        },
        "finish_reason": "stop"
    }],
    "usage": {
        "prompt_tokens": 10,
        "completion_tokens": 20,
        "total_tokens": 30
    }
}
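Reading the assistant message from a completion:
completion = await provider.create_completion(params)
message = completion.choices[0]["message"]
print(message["content"])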

ChatCompletionChunk

Streaming chat completion chunk.
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class ChatCompletionChunk:
    """Chat completion chunk for streaming responses."""
    
    id: str                                    # Unique completion ID
    object: str = "chat.completion.chunk"      # Object type
    created: int = 0                           # Creation timestamp
    model: str = ""                            # Model used
    choices: List[Dict[str, Any]] = field(default_factory=list)  # Response chunks
Chunk Structure:
{
    "id": "chatcmpl-123",
    "object": "chat.completion.chunk",
    "created": 1677652288,
    "model": "gpt-4o-mini",
    "choices": [{
        "index": 0,
        "delta": {
            "role": "assistant",
            "content": "Hello"
        },
        "finish_reason": None
    }]
}
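Consuming a stream and concatenating the deltas:
text_parts = []
async for chunk in provider.create_stream(params):
    delta = chunk.choices[0]["delta"]
    if delta.get("content"):
        text_parts.append(delta["content"])
print("".join(text_parts))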

ToolDefinition

Tool definition for model providers.
from pydantic import BaseModel, Field
from typing import Any, Dict

class ToolDefinition(BaseModel):
    """Tool definition for model providers."""
    
    type: str = Field(default="function", description="Tool type")
    function: Dict[str, Any] = Field(
        description="Function definition with name, description, and parameters"
    )
Example:
from ag_kit_py.providers import ToolDefinition

tool = ToolDefinition(
    type="function",
    function={
        "name": "get_weather",
        "description": "Get current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "City name"
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "Temperature unit"
                }
            },
            "required": ["location"]
        }
    }
)

ToolCall

Tool call representation.
from pydantic import BaseModel, Field
from typing import Dict

class ToolCall(BaseModel):
    """Tool call representation."""
    
    id: str = Field(description="Unique identifier for the tool call")
    function: Dict[str, str] = Field(description="Function name and arguments")
Example:
from ag_kit_py.providers import ToolCall

call = ToolCall(
    id="call_abc123",
    function={
        "name": "get_weather",
        "arguments": '{"location": "Paris", "unit": "celsius"}'
    }
)
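Note that function["arguments"] is a JSON-encoded string and must be decoded before use:
import json

args = json.loads(call.function["arguments"])
print(args["location"])  # "Paris"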

ModelProviderError

Exception class for provider errors.
from typing import Any, Optional

class ModelProviderError(Exception):
    """Base exception for model provider errors."""
    
    def __init__(
        self,
        message: str,
        error_type: str = "unknown",
        status: Optional[int] = None,
        code: Optional[str] = None,
        details: Optional[Any] = None
    ):
        """Initialize model provider error.
        
        Args:
            message: Error message
            error_type: Type of error (authentication, rate_limit, etc.)
            status: HTTP status code if applicable
            code: Error code from provider
            details: Additional error details
        """
        super().__init__(message)
        self.message = message
        self.error_type = error_type
        self.status = status
        self.code = code
        self.details = details
Error Types:
  • authentication: Invalid API key or credentials
  • rate_limit: Rate limit exceeded
  • quota_exceeded: API quota exceeded
  • invalid_request: Invalid request parameters
  • server_error: Provider server error
  • timeout: Request timeout
  • unknown: Unknown error
Example:
from ag_kit_py.providers import ModelProviderError

try:
    response = await provider.create_completion(params)
except ModelProviderError as e:
    if e.error_type == "authentication":
        print("Invalid API key")
    elif e.error_type == "rate_limit":
        print("Rate limit exceeded, retry later")
    else:
        print(f"Error: {e.message}")

ProviderType

Enum of supported provider types.
from enum import Enum

class ProviderType(str, Enum):
    """Supported model provider types."""
    
    OPENAI = "openai"
    ZHIPU = "zhipu"
    QWEN = "qwen"
    DEEPSEEK = "deepseek"
    ANTHROPIC = "anthropic"
    CUSTOM = "custom"
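Because ProviderType subclasses str, values can be parsed from and compared with plain strings:
provider_type = ProviderType("openai")  # parse from a config string
assert provider_type == ProviderType.OPENAI
assert provider_type == "openai"        # str-enum compares with plain strings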

Implementing a Custom Provider

To create a custom provider, extend BaseModelProvider and implement all abstract methods:
from ag_kit_py.providers import (
    BaseModelProvider,
    ModelProviderConfig,
    ChatCompletion,
    ChatCompletionChunk,
    ChatCompletionParams,
    ToolDefinition,
    ToolCall
)
from typing import Any, AsyncIterator, List
import time

class CustomProvider(BaseModelProvider):
    """Custom model provider implementation."""
    
    def __init__(self, config: ModelProviderConfig):
        super().__init__(config)
        # Initialize your custom client
        self.client = create_custom_client(config.api_key, config.base_url)
    
    def get_provider_name(self) -> str:
        return "custom"
    
    def get_default_model(self) -> str:
        return "custom-model-v1"
    
    async def create_completion(
        self, 
        params: ChatCompletionParams
    ) -> ChatCompletion:
        try:
            # Call your custom API
            response = await self.client.chat.create(
                model=params.model,
                messages=params.messages,
                temperature=params.temperature
            )
            
            # Convert to standard format
            return ChatCompletion(
                id=response.id,
                object="chat.completion",
                created=int(time.time()),
                model=params.model,
                choices=[{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response.content
                    },
                    "finish_reason": "stop"
                }],
                usage=response.usage
            )
        except Exception as e:
            raise self._create_error(str(e), "unknown", e)
    
    async def create_stream(
        self, 
        params: ChatCompletionParams
    ) -> AsyncIterator[ChatCompletionChunk]:
        try:
            chunk_id = f"chatcmpl-{int(time.time())}"
            
            async for chunk in self.client.chat.stream(
                model=params.model,
                messages=params.messages
            ):
                yield ChatCompletionChunk(
                    id=chunk_id,
                    object="chat.completion.chunk",
                    created=int(time.time()),
                    model=params.model,
                    choices=[{
                        "index": 0,
                        "delta": {"content": chunk.content},
                        "finish_reason": None
                    }]
                )
        except Exception as e:
            raise self._create_error(str(e), "unknown", e)
    
    def format_tools(self, tools: List[Any]) -> List[ToolDefinition]:
        # Implement tool formatting for your provider
        return [ToolDefinition(**tool) for tool in tools]
    
    def parse_tool_calls(self, response: ChatCompletion) -> List[ToolCall]:
        # Implement tool call parsing for your provider
        tool_calls = []
        if response.choices:
            message = response.choices[0].get("message", {})
            raw_calls = message.get("tool_calls", [])
            for call in raw_calls:
                tool_calls.append(ToolCall(
                    id=call["id"],
                    function=call["function"]
                ))
        return tool_calls
    
    def get_langchain_model(self, **kwargs) -> Any:
        # Return a LangChain-compatible ChatModel for your provider.
        # CustomLangChainModel stands in for your own BaseChatModel subclass (not shown).
        return CustomLangChainModel(self.client, **kwargs)
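
Once implemented, the custom provider is used through the same interface as the built-in ones. A usage sketch (the API key and model name are placeholders):
import asyncio

async def main():
    config = ModelProviderConfig(api_key="sk-...", default_model="custom-model-v1")
    provider = CustomProvider(config)
    
    params = ChatCompletionParams(
        model=provider.default_model,
        messages=[{"role": "user", "content": "Hello!"}]
    )
    completion = await provider.create_completion(params)
    print(completion.choices[0]["message"]["content"])

asyncio.run(main())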