Coverage for lintro / ai / providers / base.py: 94%
48 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
1"""Abstract base class for AI providers.
3Defines the contract that all AI provider implementations must follow.
4Shared initialisation, API-key resolution, and availability logic live
5here so that concrete providers only implement SDK-specific pieces.
6"""
8from __future__ import annotations
10import os
11from abc import ABC, abstractmethod
12from typing import Any
14from lintro.ai.exceptions import AIAuthenticationError, AINotAvailableError
15from lintro.ai.providers.constants import (
16 DEFAULT_MAX_TOKENS,
17 DEFAULT_PER_CALL_MAX_TOKENS,
18 DEFAULT_TIMEOUT,
19)
20from lintro.ai.providers.response import AIResponse # noqa: F401
21from lintro.ai.providers.stream_result import AIStreamResult # noqa: F401
23__all__ = ["AIResponse", "AIStreamResult", "BaseAIProvider"]
class BaseAIProvider(ABC):
    """Abstract base class for AI providers.

    Handles common initialisation (model, API-key env var, max tokens,
    base URL), lazy client creation with API-key validation, and the
    ``is_available`` / property boilerplate.

    Subclasses must implement:
    * ``_create_client()`` -- return an SDK-specific client instance.
    * ``complete()`` -- perform the SDK-specific API call and map errors.
    """

    def __init__(
        self,
        *,
        provider_name: str,
        has_sdk: bool,
        sdk_package: str,
        default_model: str,
        default_api_key_env: str,
        model: str | None = None,
        api_key_env: str | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        base_url: str | None = None,
    ) -> None:
        """Initialise the provider with shared parameters.

        Args:
            provider_name: Human-readable provider name (e.g. "anthropic").
            has_sdk: Whether the provider SDK was successfully imported.
            sdk_package: Package name shown in the install hint.
            default_model: Fallback model when *model* is ``None`` or empty.
            default_api_key_env: Fallback env-var name when *api_key_env*
                is ``None`` or empty.
            model: Model identifier override.
            api_key_env: Environment variable for the API key override.
                The key is required unless *base_url* is set; otherwise
                its absence raises ``AIAuthenticationError`` when the
                client is first requested via ``_get_client()``.
            max_tokens: Provider-level cap on generated tokens.
            base_url: Custom API base URL. When set, a missing API key
                is tolerated.

        Raises:
            AINotAvailableError: If the SDK is not installed.
        """
        if not has_sdk:
            raise AINotAvailableError(
                f"{provider_name.title()} provider requires the "
                f"'{sdk_package}' package. "
                "Install with: uv pip install 'lintro[ai]'",
            )

        self._provider_name = provider_name
        self._has_sdk = has_sdk
        # ``or`` (not ``is None``) so an empty string also falls back.
        self._model = model or default_model
        self._api_key_env = api_key_env or default_api_key_env
        self._max_tokens = max_tokens
        self._base_url = base_url
        # Created lazily by ``_get_client()`` on first use.
        self._client: Any = None

    # -- Client management -------------------------------------------------

    def _get_client(self) -> Any:
        """Get or lazily create the SDK client.

        The client is created once and cached on the instance. When a
        custom base URL is configured and no API key is set, the client
        is created with an empty-string key instead of raising.

        Returns:
            The SDK client instance.

        Raises:
            AIAuthenticationError: If no API key is found and no custom
                base URL is configured.
        """
        if self._client is not None:
            return self._client

        # Normalise an unset (or empty) variable to "" for the SDK.
        api_key = os.environ.get(self._api_key_env) or ""
        if not api_key and not self._base_url:
            raise AIAuthenticationError(
                f"No API key found. Set the {self._api_key_env} "
                "environment variable.",
            )

        self._client = self._create_client(api_key=api_key)
        return self._client

    @abstractmethod
    def _create_client(self, *, api_key: str) -> Any:
        """Create the SDK-specific client.

        Args:
            api_key: The resolved API key. May be an empty string when a
                custom base URL is configured and no key is set.

        Returns:
            An SDK client instance.
        """
        ...

    # -- Abstract: SDK-specific completion ---------------------------------

    @abstractmethod
    def complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIResponse:
        """Generate a completion from the AI model.

        Args:
            prompt: The user prompt to send to the model.
            system: Optional system prompt to set context.
            max_tokens: Maximum number of tokens to generate.
            timeout: Request timeout in seconds.

        Returns:
            AIResponse: The model's response with usage metadata.

        Raises:
            AIProviderError: If the API call fails.
            AIAuthenticationError: If authentication fails.
            AIRateLimitError: If rate limited.
        """
        ...

    # -- Streaming (default delegates to complete) --------------------------

    def stream_complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIStreamResult:
        """Stream a completion. Default: delegates to complete().

        Providers with native streaming support should override this.
        This default performs the full ``complete()`` call first and
        then exposes the whole response content as a single chunk.

        Args:
            prompt: The user prompt text.
            system: Optional system prompt.
            max_tokens: Maximum tokens to generate.
            timeout: Request timeout in seconds.

        Returns:
            An AIStreamResult wrapping the token stream.
        """
        response = self.complete(
            prompt,
            system=system,
            max_tokens=max_tokens,
            timeout=timeout,
        )
        return AIStreamResult(
            # One-item iterator: the complete text is the only chunk.
            _chunks=iter([response.content]),
            _on_done=lambda: response,
        )

    # -- Concrete shared helpers -------------------------------------------

    def is_available(self) -> bool:
        """Check if this provider is ready to use.

        A provider is available when its SDK is installed and at least
        one of these is true: an API key env var is set, a custom
        base URL is configured, or a client has already been created.

        Returns:
            bool: True if the provider can serve requests.
        """
        if not self._has_sdk:
            return False
        return bool(
            os.environ.get(self._api_key_env)
            or self._base_url
            or self._client is not None,
        )

    @property
    def name(self) -> str:
        """Return the provider's name.

        Returns:
            str: Provider identifier (e.g., "anthropic", "openai").
        """
        return self._provider_name

    @property
    def model_name(self) -> str:
        """Return the configured model name.

        Returns:
            str: Model identifier being used.
        """
        return self._model

    @model_name.setter
    def model_name(self, value: str) -> None:
        """Set the model name.

        Args:
            value: New model identifier. Stored as given, without
                validation.
        """
        self._model = value