Coverage for lintro / ai / providers / base.py: 94%

48 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-04-03 18:53 +0000

1"""Abstract base class for AI providers. 

2 

3Defines the contract that all AI provider implementations must follow. 

4Shared initialisation, API-key resolution, and availability logic live 

5here so that concrete providers only implement SDK-specific pieces. 

6""" 

7 

8from __future__ import annotations 

9 

10import os 

11from abc import ABC, abstractmethod 

12from typing import Any 

13 

14from lintro.ai.exceptions import AIAuthenticationError, AINotAvailableError 

15from lintro.ai.providers.constants import ( 

16 DEFAULT_MAX_TOKENS, 

17 DEFAULT_PER_CALL_MAX_TOKENS, 

18 DEFAULT_TIMEOUT, 

19) 

20from lintro.ai.providers.response import AIResponse # noqa: F401 

21from lintro.ai.providers.stream_result import AIStreamResult # noqa: F401 

22 

23__all__ = ["AIResponse", "AIStreamResult", "BaseAIProvider"] 

24 

25 

class BaseAIProvider(ABC):
    """Common foundation shared by every AI provider implementation.

    The base class stores the shared configuration (provider name, model,
    API-key environment variable, token cap, optional base URL), builds
    the SDK client lazily on first use, and supplies the ``is_available``
    check together with the ``name`` / ``model_name`` properties.

    Concrete providers must supply:
        * ``_create_client()`` -- build and return the SDK client object.
        * ``complete()`` -- issue the SDK call and translate its errors.
    """

    def __init__(
        self,
        *,
        provider_name: str,
        has_sdk: bool,
        sdk_package: str,
        default_model: str,
        default_api_key_env: str,
        model: str | None = None,
        api_key_env: str | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        base_url: str | None = None,
    ) -> None:
        """Store the configuration shared by all providers.

        Args:
            provider_name: Human-readable provider name (e.g. "anthropic").
            has_sdk: Whether the provider SDK could be imported.
            sdk_package: Package name mentioned in the install hint.
            default_model: Model used when *model* is ``None`` or empty.
            default_api_key_env: Env-var name used when *api_key_env* is
                ``None`` or empty.
            model: Optional model identifier override.
            api_key_env: Optional override for the API-key env-var name.
                The variable is only read when the client is first
                created; if it is unset and no *base_url* is configured,
                that first creation raises ``AIAuthenticationError``.
            max_tokens: Provider-level cap on generated tokens. It is
                stored on the instance; the base class itself does not
                read it.
            base_url: Optional custom API base URL.

        Raises:
            AINotAvailableError: If *has_sdk* is falsy.
        """
        if not has_sdk:
            provider_label = provider_name.title()
            message = (
                f"{provider_label} provider requires the "
                f"'{sdk_package}' package. "
                "Install with: uv pip install 'lintro[ai]'"
            )
            raise AINotAvailableError(message)

        # The SDK client is built on demand by _get_client().
        self._client: Any = None

        self._provider_name = provider_name
        self._has_sdk = has_sdk
        self._base_url = base_url
        self._max_tokens = max_tokens
        # Falsy overrides (None or "") fall back to the provider defaults.
        self._model = model or default_model
        self._api_key_env = api_key_env or default_api_key_env

    # -- Client management -------------------------------------------------

    def _get_client(self) -> Any:
        """Return the cached SDK client, creating it on first use.

        The API key is read from the configured environment variable. A
        missing key is tolerated only when a custom base URL is set; in
        that case the client is created with an empty key.

        Returns:
            The SDK client instance.

        Raises:
            AIAuthenticationError: If no API key is found and no base URL
                is configured.
        """
        if self._client is None:
            resolved_key = os.environ.get(self._api_key_env) or ""
            if not (resolved_key or self._base_url):
                raise AIAuthenticationError(
                    f"No API key found. Set the {self._api_key_env} "
                    "environment variable.",
                )
            self._client = self._create_client(api_key=resolved_key)

        return self._client

    @abstractmethod
    def _create_client(self, *, api_key: str) -> Any:
        """Build the SDK-specific client object.

        Args:
            api_key: The resolved API key; an empty string when no key is
                set but a custom base URL is configured.

        Returns:
            An SDK client instance.
        """
        ...

    # -- Abstract: SDK-specific completion ---------------------------------

    @abstractmethod
    def complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIResponse:
        """Produce a completion for *prompt* from the AI model.

        Args:
            prompt: The user prompt to send to the model.
            system: Optional system prompt that sets context.
            max_tokens: Maximum number of tokens to generate.
            timeout: Request timeout in seconds.

        Returns:
            AIResponse: The model's response with usage metadata.

        Raises:
            AIProviderError: If the API call fails.
            AIAuthenticationError: If authentication fails.
            AIRateLimitError: If rate limited.
        """
        ...

    # -- Streaming (default delegates to complete) --------------------------

    def stream_complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIStreamResult:
        """Stream a completion; by default this wraps ``complete()``.

        The default implementation performs one blocking ``complete()``
        call and exposes its content as a single-chunk stream. Providers
        with native streaming support should override this method.

        Args:
            prompt: The user prompt text.
            system: Optional system prompt.
            max_tokens: Maximum tokens to generate.
            timeout: Request timeout in seconds.

        Returns:
            An AIStreamResult whose stream yields the full response
            content once and whose completion callback returns the
            response obtained from ``complete()``.
        """
        completed = self.complete(
            prompt,
            system=system,
            max_tokens=max_tokens,
            timeout=timeout,
        )

        def _final_response() -> AIResponse:
            # Hand back the already-finished response when the stream ends.
            return completed

        return AIStreamResult(
            _chunks=iter([completed.content]),
            _on_done=_final_response,
        )

    # -- Concrete shared helpers -------------------------------------------

    def is_available(self) -> bool:
        """Report whether this provider can serve requests.

        The SDK must be installed, and at least one of the following must
        hold: the API-key environment variable is set to a non-empty
        value, a custom base URL is configured, or a client has already
        been created.

        Returns:
            bool: True if the provider can serve requests.
        """
        if not self._has_sdk:
            return False
        if os.environ.get(self._api_key_env):
            return True
        if self._base_url:
            return True
        return self._client is not None

    @property
    def name(self) -> str:
        """Provider identifier given at construction.

        Returns:
            str: Provider identifier (e.g., "anthropic", "openai").
        """
        return self._provider_name

    @property
    def model_name(self) -> str:
        """Model identifier currently configured.

        Returns:
            str: Model identifier being used.
        """
        return self._model

    @model_name.setter
    def model_name(self, value: str) -> None:
        """Replace the configured model identifier.

        Args:
            value: New model identifier.
        """
        self._model = value