# Base Provider API
The base provider module defines the abstract base class and core interfaces that all model providers must implement. This ensures a consistent API across different LLM services while allowing provider-specific optimizations.

## BaseModelProvider

Abstract base class for all model providers.

### Class Definition
```python
from abc import ABC, abstractmethod

from ag_kit_py.providers import ModelProviderConfig


class BaseModelProvider(ABC):
    """Abstract base class for model providers."""

    def __init__(self, config: ModelProviderConfig):
        """Initialize the model provider.

        Args:
            config: Provider configuration
        """
        self.config = config
        self.default_model = config.default_model or self.get_default_model()
```
## Abstract Methods

All provider implementations must implement these methods:

### create_completion()
Create a non-streaming chat completion.

```python
@abstractmethod
async def create_completion(
    self,
    params: ChatCompletionParams
) -> ChatCompletion:
    """Create a chat completion.

    Args:
        params: Completion parameters including model, messages, temperature, etc.

    Returns:
        ChatCompletion: Complete response with message content and metadata

    Raises:
        ModelProviderError: If completion fails
    """
    pass
```
A reference implementation built on `get_langchain_model()` looks like this:

```python
import time

async def create_completion(self, params: ChatCompletionParams) -> ChatCompletion:
    try:
        # Build a LangChain model configured for this request
        model = self.get_langchain_model(
            model=params.model,
            temperature=params.temperature
        )
        # Bind tools when the request includes any
        if params.tools:
            formatted_tools = self.format_tools(params.tools)
            model = model.bind_tools(formatted_tools)
        response = await model.ainvoke(params.messages)
        return ChatCompletion(
            id=f"chatcmpl-{int(time.time())}",
            object="chat.completion",
            created=int(time.time()),
            model=params.model,
            choices=[{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response.content,
                    "tool_calls": getattr(response, "tool_calls", None)
                },
                "finish_reason": "stop"
            }],
            usage={
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            }
        )
    except Exception as e:
        raise self._handle_error(e)
```
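Calling it from application code is straightforward. A minimal usage sketch, assuming `provider` is any concrete `BaseModelProvider` subclass:

```python
import asyncio

from ag_kit_py.providers import ChatCompletionParams

async def main():
    params = ChatCompletionParams(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello!"}]
    )
    completion = await provider.create_completion(params)
    print(completion.choices[0]["message"]["content"])

asyncio.run(main())
```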
### create_stream()

Create a streaming chat completion.

```python
@abstractmethod
async def create_stream(
    self,
    params: ChatCompletionParams
) -> AsyncIterator[ChatCompletionChunk]:
    """Create a streaming chat completion.

    Args:
        params: Completion parameters

    Yields:
        ChatCompletionChunk: Incremental response chunks

    Raises:
        ModelProviderError: If streaming fails
    """
    pass
```
A reference implementation mirrors `create_completion()` but yields chunks as they arrive:

```python
import time

async def create_stream(
    self,
    params: ChatCompletionParams
) -> AsyncIterator[ChatCompletionChunk]:
    try:
        model = self.get_langchain_model(
            model=params.model,
            temperature=params.temperature
        )
        if params.tools:
            formatted_tools = self.format_tools(params.tools)
            model = model.bind_tools(formatted_tools)
        chunk_id = f"chatcmpl-{int(time.time())}"
        async for chunk in model.astream(params.messages):
            yield ChatCompletionChunk(
                id=chunk_id,
                object="chat.completion.chunk",
                created=int(time.time()),
                model=params.model,
                choices=[{
                    "index": 0,
                    "delta": {
                        "role": "assistant",
                        "content": chunk.content if hasattr(chunk, "content") else str(chunk)
                    },
                    "finish_reason": None
                }]
            )
        # Final chunk with finish_reason
        yield ChatCompletionChunk(
            id=chunk_id,
            object="chat.completion.chunk",
            created=int(time.time()),
            model=params.model,
            choices=[{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        )
    except Exception as e:
        raise self._handle_error(e)
```
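Consumers iterate over the chunks as they arrive. A minimal sketch that prints tokens live (inside an async function), assuming `provider` and `params` from the earlier example:

```python
async for chunk in provider.create_stream(params):
    delta = chunk.choices[0]["delta"]
    if delta.get("content"):
        print(delta["content"], end="", flush=True)
```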
### get_provider_name()

Get the provider name identifier.

```python
@abstractmethod
def get_provider_name(self) -> str:
    """Get the provider name.

    Returns:
        str: Provider name (e.g., "openai", "zhipu")
    """
    pass
```
### get_default_model()

Get the default model for this provider.

```python
@abstractmethod
def get_default_model(self) -> str:
    """Get the default model for this provider.

    Returns:
        str: Default model name (e.g., "gpt-4o-mini")
    """
    pass
```
### format_tools()

Format tools for the provider-specific API.

```python
@abstractmethod
def format_tools(self, tools: List[Any]) -> List[ToolDefinition]:
    """Format tools for this provider.

    Args:
        tools: List of tools in various formats

    Returns:
        List[ToolDefinition]: Formatted tool definitions
    """
    pass
```
A reference implementation accepts both OpenAI-style and simplified tool dictionaries:

```python
def format_tools(self, tools: List[Any]) -> List[Dict[str, Any]]:
    formatted_tools = []
    for tool in tools:
        if isinstance(tool, dict):
            if "type" in tool and "function" in tool:
                # Already in OpenAI function-call format
                formatted_tools.append(tool)
            elif "name" in tool and "parameters" in tool:
                # Convert from the simplified format
                formatted_tools.append({
                    "type": "function",
                    "function": {
                        "name": tool["name"],
                        "description": tool.get("description", ""),
                        "parameters": tool["parameters"]
                    }
                })
        elif isinstance(tool, ToolDefinition):
            formatted_tools.append(tool.model_dump())
    return formatted_tools
```
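Both accepted shapes normalize to the same structure. A quick sketch, assuming a concrete `provider` instance:

```python
tools = [
    # Already in OpenAI function-call format
    {"type": "function", "function": {"name": "echo", "description": "", "parameters": {"type": "object", "properties": {}}}},
    # Simplified format with name and parameters at the top level
    {"name": "ping", "description": "Ping a host", "parameters": {"type": "object", "properties": {}}},
]
formatted = provider.format_tools(tools)
assert all(t["type"] == "function" for t in formatted)
```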
### parse_tool_calls()

Parse tool calls from a completion response.

```python
@abstractmethod
def parse_tool_calls(self, response: ChatCompletion) -> List[ToolCall]:
    """Parse tool calls from completion response.

    Args:
        response: Chat completion response

    Returns:
        List[ToolCall]: Parsed tool calls
    """
    pass
```
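A typical dispatch loop over the parsed calls, inside an async function (`run_tool` here is a hypothetical dispatcher you would supply):

```python
import json

completion = await provider.create_completion(params)
for call in provider.parse_tool_calls(completion):
    name = call.function["name"]
    args = json.loads(call.function["arguments"])  # arguments arrive as a JSON string
    result = run_tool(name, args)  # hypothetical dispatcher
```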
### get_langchain_model()

Get a LangChain-compatible model instance.

```python
@abstractmethod
def get_langchain_model(self, **kwargs) -> Any:
    """Get a LangChain-compatible model instance.

    This method returns a LangChain ChatModel instance that can be used
    with LangGraph and other LangChain components.

    Args:
        **kwargs: Additional model configuration to override defaults

    Returns:
        Any: LangChain ChatModel instance
    """
    pass
```
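Because the returned object is a standard LangChain chat model, it can be invoked directly. A minimal sketch, assuming a concrete `provider`:

```python
model = provider.get_langchain_model(model="gpt-4o-mini", temperature=0.2)
reply = await model.ainvoke([{"role": "user", "content": "Ping"}])  # inside an async function
print(reply.content)
```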
## Concrete Methods

These methods have default implementations but can be overridden:

### supports_tools()

Check if the provider supports tool/function calling.

```python
def supports_tools(self) -> bool:
    """Check if provider supports tool/function calling.

    Returns:
        bool: True if tools are supported (default: True)
    """
    return True
```
### supports_streaming()

Check if the provider supports streaming responses.

```python
def supports_streaming(self) -> bool:
    """Check if provider supports streaming responses.

    Returns:
        bool: True if streaming is supported (default: True)
    """
    return True
```
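These capability flags let callers degrade gracefully. A sketch, inside an async function:

```python
if provider.supports_streaming():
    async for chunk in provider.create_stream(params):
        delta = chunk.choices[0]["delta"]
        print(delta.get("content") or "", end="")
else:
    completion = await provider.create_completion(params)
    print(completion.choices[0]["message"]["content"])
```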
### validate_config()

Validate the provider configuration.

```python
def validate_config(self, config: ModelProviderConfig) -> bool:
    """Validate provider configuration.

    Args:
        config: Configuration to validate

    Returns:
        bool: True if configuration is valid
    """
    return bool(config.api_key and config.api_key.strip())
```
## Protected Methods

These helper methods are available to subclasses:

### _create_error()

Create a model provider error.

```python
def _create_error(
    self,
    message: str,
    error_type: str = "unknown",
    details: Optional[Any] = None
) -> ModelProviderError:
    """Create a model provider error.

    Args:
        message: Error message
        error_type: Type of error
        details: Additional error details

    Returns:
        ModelProviderError: Formatted error
    """
    return ModelProviderError(
        message=message,
        error_type=error_type,
        details=details
    )
```
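Typical use inside a subclass, for example when an upstream client rejects the credentials:

```python
raise self._create_error(
    message="Invalid API key",
    error_type="authentication",
    details={"status": 401}
)
```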
### _is_non_retryable_error()

Check whether an error should not be retried.

```python
def _is_non_retryable_error(self, error: Exception) -> bool:
    """Check if an error should not be retried.

    Args:
        error: Error to check

    Returns:
        bool: True if error should not be retried
    """
    if isinstance(error, ModelProviderError):
        non_retryable_types = ["authentication", "invalid_request"]
        return error.error_type in non_retryable_types
    return False
```
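These helpers combine naturally into a retry loop. A sketch (not part of the base class) that honors `config.max_retries` and `config.retry_delay`:

```python
import asyncio

async def completion_with_retries(provider, params):
    for attempt in range(provider.config.max_retries + 1):
        try:
            return await provider.create_completion(params)
        except ModelProviderError as e:
            if provider._is_non_retryable_error(e) or attempt == provider.config.max_retries:
                raise
            await asyncio.sleep(provider.config.retry_delay / 1000)  # retry_delay is in ms
```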
## Data Models

### ModelProviderConfig

Configuration dataclass for model providers.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelProviderConfig:
    """Configuration for model providers."""
    api_key: str                         # Required: API key for authentication
    base_url: Optional[str] = None       # Optional: Custom API endpoint
    default_model: Optional[str] = None  # Optional: Default model name
    temperature: float = 0.7             # Optional: Sampling temperature
    timeout: int = 30000                 # Optional: Request timeout (ms)
    max_retries: int = 3                 # Optional: Maximum retry attempts
    retry_delay: int = 1000              # Optional: Delay between retries (ms)
    proxy: Optional[str] = None          # Optional: Proxy URL
    organization: Optional[str] = None   # Optional: Organization ID
    project: Optional[str] = None        # Optional: Project ID
```
Example:

```python
from ag_kit_py.providers import ModelProviderConfig

config = ModelProviderConfig(
    api_key="sk-...",
    base_url="https://api.openai.com/v1",
    default_model="gpt-4o-mini",
    temperature=0.7,
    timeout=30000,
    max_retries=3
)
```
### ChatCompletionParams

Parameters for chat completion requests.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

from ag_kit_py.providers import ToolDefinition  # used by the tools field below

@dataclass
class ChatCompletionParams:
    """Parameters for chat completion requests."""
    model: str                                        # Required: Model name
    messages: List[Dict[str, Any]]                    # Required: Conversation messages
    temperature: Optional[float] = 0.7                # Optional: Sampling temperature
    max_tokens: Optional[int] = None                  # Optional: Maximum tokens to generate
    top_p: Optional[float] = None                     # Optional: Nucleus sampling parameter
    frequency_penalty: Optional[float] = None         # Optional: Frequency penalty
    presence_penalty: Optional[float] = None          # Optional: Presence penalty
    tools: Optional[List[ToolDefinition]] = None      # Optional: Available tools
    tool_choice: Optional[Union[str, Dict]] = "auto"  # Optional: Tool choice strategy
    stream: bool = False                              # Optional: Enable streaming
    stop: Optional[List[str]] = None                  # Optional: Stop sequences
```
Example:

```python
from ag_kit_py.providers import ChatCompletionParams

params = ChatCompletionParams(
    model="gpt-4o-mini",
    messages=[
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hello!"}
    ],
    temperature=0.7,
    max_tokens=1000,
    tools=[
        {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {"type": "object", "properties": {...}}
        }
    ]
)
```
### ChatCompletion

Chat completion response.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class ChatCompletion:
    """Chat completion response."""
    id: str                                                       # Unique completion ID
    object: str = "chat.completion"                               # Object type
    created: int = 0                                              # Creation timestamp
    model: str = ""                                               # Model used
    choices: List[Dict[str, Any]] = field(default_factory=list)  # Response choices
    usage: Dict[str, int] = field(default_factory=dict)          # Token usage
```
Example response:

```json
{
  "id": "chatcmpl-123",
  "object": "chat.completion",
  "created": 1677652288,
  "model": "gpt-4o-mini",
  "choices": [{
    "index": 0,
    "message": {
      "role": "assistant",
      "content": "Hello! How can I help you?",
      "tool_calls": null
    },
    "finish_reason": "stop"
  }],
  "usage": {
    "prompt_tokens": 10,
    "completion_tokens": 20,
    "total_tokens": 30
  }
}
```
### ChatCompletionChunk

Streaming chat completion chunk.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class ChatCompletionChunk:
    """Chat completion chunk for streaming responses."""
    id: str                                                       # Unique completion ID
    object: str = "chat.completion.chunk"                         # Object type
    created: int = 0                                              # Creation timestamp
    model: str = ""                                               # Model used
    choices: List[Dict[str, Any]] = field(default_factory=list)  # Response chunks
```
Example chunk:

```json
{
  "id": "chatcmpl-123",
  "object": "chat.completion.chunk",
  "created": 1677652288,
  "model": "gpt-4o-mini",
  "choices": [{
    "index": 0,
    "delta": {
      "role": "assistant",
      "content": "Hello"
    },
    "finish_reason": null
  }]
}
```
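Successive deltas concatenate into the full message. A sketch, inside an async function:

```python
parts = []
async for chunk in provider.create_stream(params):
    delta = chunk.choices[0]["delta"]
    if delta.get("content"):
        parts.append(delta["content"])
full_text = "".join(parts)
```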
### ToolDefinition

Tool definition for model providers.

```python
from pydantic import BaseModel, Field
from typing import Any, Dict

class ToolDefinition(BaseModel):
    """Tool definition for model providers."""
    type: str = Field(default="function", description="Tool type")
    function: Dict[str, Any] = Field(
        description="Function definition with name, description, and parameters"
    )
```
Example:

```python
from ag_kit_py.providers import ToolDefinition

tool = ToolDefinition(
    type="function",
    function={
        "name": "get_weather",
        "description": "Get current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "City name"
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "Temperature unit"
                }
            },
            "required": ["location"]
        }
    }
)
```
### ToolCall

Tool call representation.

```python
from pydantic import BaseModel, Field
from typing import Dict

class ToolCall(BaseModel):
    """Tool call representation."""
    id: str = Field(description="Unique identifier for the tool call")
    function: Dict[str, str] = Field(description="Function name and arguments")
```
Example:

```python
from ag_kit_py.providers import ToolCall

call = ToolCall(
    id="call_abc123",
    function={
        "name": "get_weather",
        "arguments": '{"location": "Paris", "unit": "celsius"}'
    }
)
```
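Note that `arguments` is a JSON-encoded string, so decode it before use:

```python
import json

args = json.loads(call.function["arguments"])
print(call.function["name"], args["location"])  # prints: get_weather Paris
```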
### ModelProviderError

Exception class for provider errors.

```python
from typing import Any, Optional

class ModelProviderError(Exception):
    """Base exception for model provider errors."""

    def __init__(
        self,
        message: str,
        error_type: str = "unknown",
        status: Optional[int] = None,
        code: Optional[str] = None,
        details: Optional[Any] = None
    ):
        """Initialize model provider error.

        Args:
            message: Error message
            error_type: Type of error (authentication, rate_limit, etc.)
            status: HTTP status code if applicable
            code: Error code from provider
            details: Additional error details
        """
        super().__init__(message)
        self.message = message
        self.error_type = error_type
        self.status = status
        self.code = code
        self.details = details
```
Recognized error types:

- `authentication`: Invalid API key or credentials
- `rate_limit`: Rate limit exceeded
- `quota_exceeded`: API quota exceeded
- `invalid_request`: Invalid request parameters
- `server_error`: Provider server error
- `timeout`: Request timeout
- `unknown`: Unknown error
Example:

```python
from ag_kit_py.providers import ModelProviderError

try:
    response = await provider.create_completion(params)
except ModelProviderError as e:
    if e.error_type == "authentication":
        print("Invalid API key")
    elif e.error_type == "rate_limit":
        print("Rate limit exceeded, retry later")
    else:
        print(f"Error: {e.message}")
```
### ProviderType

Enum of supported provider types.

```python
from enum import Enum

class ProviderType(str, Enum):
    """Supported model provider types."""
    OPENAI = "openai"
    ZHIPU = "zhipu"
    QWEN = "qwen"
    DEEPSEEK = "deepseek"
    ANTHROPIC = "anthropic"
    CUSTOM = "custom"
```
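The enum is the natural key for provider dispatch. A minimal sketch (this `create_provider` helper is hypothetical; the library's actual utilities are covered under Factory Functions below):

```python
def create_provider(provider_type: ProviderType, config: ModelProviderConfig) -> BaseModelProvider:
    # Hypothetical registry mapping each ProviderType to its provider class
    registry = {
        ProviderType.CUSTOM: CustomProvider,  # defined in the next section
    }
    return registry[provider_type](config)
```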
## Implementing a Custom Provider

To create a custom provider, extend `BaseModelProvider` and implement all abstract methods:
```python
from ag_kit_py.providers import (
    BaseModelProvider,
    ModelProviderConfig,
    ChatCompletion,
    ChatCompletionChunk,
    ChatCompletionParams,
    ToolDefinition,
    ToolCall
)
from typing import Any, AsyncIterator, List
import time

class CustomProvider(BaseModelProvider):
    """Custom model provider implementation."""

    def __init__(self, config: ModelProviderConfig):
        super().__init__(config)
        # Initialize your custom client
        self.client = create_custom_client(config.api_key, config.base_url)

    def get_provider_name(self) -> str:
        return "custom"

    def get_default_model(self) -> str:
        return "custom-model-v1"

    async def create_completion(
        self,
        params: ChatCompletionParams
    ) -> ChatCompletion:
        try:
            # Call your custom API
            response = await self.client.chat.create(
                model=params.model,
                messages=params.messages,
                temperature=params.temperature
            )
            # Convert to the standard format
            return ChatCompletion(
                id=response.id,
                object="chat.completion",
                created=int(time.time()),
                model=params.model,
                choices=[{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response.content
                    },
                    "finish_reason": "stop"
                }],
                usage=response.usage
            )
        except Exception as e:
            raise self._create_error(str(e), "unknown", e)

    async def create_stream(
        self,
        params: ChatCompletionParams
    ) -> AsyncIterator[ChatCompletionChunk]:
        try:
            chunk_id = f"chatcmpl-{int(time.time())}"
            async for chunk in self.client.chat.stream(
                model=params.model,
                messages=params.messages
            ):
                yield ChatCompletionChunk(
                    id=chunk_id,
                    object="chat.completion.chunk",
                    created=int(time.time()),
                    model=params.model,
                    choices=[{
                        "index": 0,
                        "delta": {"content": chunk.content},
                        "finish_reason": None
                    }]
                )
        except Exception as e:
            raise self._create_error(str(e), "unknown", e)

    def format_tools(self, tools: List[Any]) -> List[ToolDefinition]:
        # Implement tool formatting for your provider
        return [ToolDefinition(**tool) for tool in tools]

    def parse_tool_calls(self, response: ChatCompletion) -> List[ToolCall]:
        # Implement tool call parsing for your provider
        tool_calls = []
        if response.choices:
            message = response.choices[0].get("message", {})
            raw_calls = message.get("tool_calls") or []  # tool_calls may be None
            for call in raw_calls:
                tool_calls.append(ToolCall(
                    id=call["id"],
                    function=call["function"]
                ))
        return tool_calls

    def get_langchain_model(self, **kwargs) -> Any:
        # Return a LangChain-compatible chat model; implement a wrapper for your provider
        return CustomLangChainModel(self.client, **kwargs)
```
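Using the custom provider end to end (a sketch; `create_custom_client` and `CustomLangChainModel` are the placeholders from the example above):

```python
import asyncio

async def main():
    config = ModelProviderConfig(api_key="sk-...", default_model="custom-model-v1")
    provider = CustomProvider(config)
    params = ChatCompletionParams(
        model=provider.default_model,
        messages=[{"role": "user", "content": "Hello!"}]
    )
    completion = await provider.create_completion(params)
    print(completion.choices[0]["message"]["content"])

asyncio.run(main())
```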
## Related Documentation
- Providers Overview - Provider system overview
- OpenAI Provider - OpenAI implementation example
- Factory Functions - Provider creation utilities