Custom Provider¶
Add your own provider in three steps.
1. Subclass BaseProvider¶
# my_provider.py
import os
from typing import Any, AsyncIterator, ClassVar, Iterator
from llmgate.base import BaseProvider
from llmgate.exceptions import AuthError, ProviderAPIError
from llmgate.types import (
Choice, CompletionRequest, CompletionResponse,
Message, StreamChunk, TokenUsage,
)
class MyProvider(BaseProvider):
name: ClassVar[str] = "myprovider"
supported_model_prefixes: ClassVar[tuple[str, ...]] = ("myprovider/",)
def __init__(self, api_key: str | None = None, **kwargs: Any) -> None:
try:
import my_sdk # noqa: PLC0415
except ImportError as e:
raise ImportError("my_sdk required: pip install my-sdk") from e
key = api_key or os.environ.get("MY_API_KEY")
if not key:
raise AuthError("MY_API_KEY not set", provider=self.name)
self._client = my_sdk.Client(api_key=key)
self._async_client = my_sdk.AsyncClient(api_key=key)
def complete(self, request: CompletionRequest) -> CompletionResponse:
model = self._strip_prefix(request.model)
try:
raw = self._client.generate(
model=model,
messages=[m.to_dict() for m in request.messages],
max_tokens=request.max_tokens,
temperature=request.temperature,
)
except Exception as exc:
raise ProviderAPIError(str(exc), provider=self.name) from exc
return CompletionResponse(
id=raw.id,
model=request.model,
provider=self.name,
choices=[Choice(
index=0,
message=Message(role="assistant", content=raw.text),
finish_reason=raw.finish_reason,
)],
usage=TokenUsage(
prompt_tokens=raw.usage.input,
completion_tokens=raw.usage.output,
total_tokens=raw.usage.total,
),
raw=raw,
)
async def acomplete(self, request: CompletionRequest) -> CompletionResponse:
# ... async version
...
def stream(self, request: CompletionRequest) -> Iterator[StreamChunk]:
model = self._strip_prefix(request.model)
for chunk in self._client.stream(model=model, messages=...):
yield StreamChunk(delta=chunk.text)
async def astream(self, request: CompletionRequest) -> AsyncIterator[StreamChunk]:
model = self._strip_prefix(request.model)
async for chunk in self._async_client.stream(model=model, messages=...):
yield StreamChunk(delta=chunk.text)
2. Register it¶
import llmgate.completion as _registry
from my_provider import MyProvider
# Prepend so your provider takes priority over built-ins for its prefix
_registry._OPTIONAL_PROVIDERS.insert(0, ("myprovider/", "my_provider", "MyProvider"))