Coverage for lintro / ai / providers / openai.py: 52%

85 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-04-03 18:53 +0000

1"""OpenAI AI provider implementation. 

2 

3Uses the OpenAI Python SDK to communicate with GPT models. 

4Requires the ``openai`` package (installed via ``lintro[ai]``). 

5""" 

6 

7from __future__ import annotations 

8 

9from collections.abc import Iterator 

10from contextlib import contextmanager 

11from typing import Any 

12 

13from loguru import logger 

14 

15from lintro.ai.cost import estimate_cost 

16from lintro.ai.exceptions import ( 

17 AIAuthenticationError, 

18 AIProviderError, 

19 AIRateLimitError, 

20) 

21from lintro.ai.providers.base import AIResponse, AIStreamResult, BaseAIProvider 

22from lintro.ai.providers.constants import ( 

23 DEFAULT_MAX_TOKENS, 

24 DEFAULT_PER_CALL_MAX_TOKENS, 

25 DEFAULT_TIMEOUT, 

26) 

27from lintro.ai.registry import PROVIDERS, AIProvider 

28 

29_has_openai = False 

30try: 

31 import openai 

32 

33 _has_openai = True 

34except ImportError: 

35 pass 

36 

# Provider defaults come from the central registry so model and API-key
# configuration stays in one place (lintro.ai.registry) rather than being
# duplicated per provider module.
DEFAULT_MODEL = PROVIDERS.openai.default_model
DEFAULT_API_KEY_ENV = PROVIDERS.openai.default_api_key_env

39 

40 

class OpenAIProvider(BaseAIProvider):
    """OpenAI GPT provider."""

    @staticmethod
    @contextmanager
    def _map_errors() -> Iterator[None]:
        """Map OpenAI SDK exceptions to AI exceptions.

        Safe to call only when the ``openai`` SDK is installed —
        the base class ``__init__`` raises ``AINotAvailableError``
        before any method can be called if the SDK is missing.

        Raises:
            AIAuthenticationError: On SDK authentication failures.
            AIRateLimitError: On SDK rate-limit errors.
            AIProviderError: On any other SDK error.
        """
        try:
            yield
        # NOTE: order matters — AuthenticationError and RateLimitError are
        # subclasses of OpenAIError, so they must be caught before the
        # generic base-class handler below.
        except openai.AuthenticationError as e:
            raise AIAuthenticationError(
                f"OpenAI authentication failed: {e}",
            ) from e
        except openai.RateLimitError as e:
            raise AIRateLimitError(
                f"OpenAI rate limit exceeded: {e}",
            ) from e
        except openai.OpenAIError as e:
            # Log the raw SDK error at debug level before re-raising as the
            # project-level exception type callers are expected to catch.
            logger.debug(f"OpenAI API error: {e}")
            raise AIProviderError(
                f"OpenAI API error: {e}",
            ) from e

    def __init__(
        self,
        *,
        model: str | None = None,
        api_key_env: str | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        base_url: str | None = None,
    ) -> None:
        """Initialize the OpenAI provider.

        Args:
            model: Model identifier. Defaults to gpt-4o.
            api_key_env: Environment variable for API key.
                Defaults to OPENAI_API_KEY.
            max_tokens: Default max tokens for completions.
            base_url: Custom API base URL for OpenAI-compatible
                endpoints (Ollama, vLLM, Azure OpenAI, etc.).
        """
        # All provider-specific wiring (SDK-availability check, model and
        # API-key resolution) is delegated to the base class.
        super().__init__(
            provider_name=AIProvider.OPENAI,
            has_sdk=_has_openai,
            sdk_package="openai",
            default_model=DEFAULT_MODEL,
            default_api_key_env=DEFAULT_API_KEY_ENV,
            model=model,
            api_key_env=api_key_env,
            max_tokens=max_tokens,
            base_url=base_url,
        )

    def _create_client(self, *, api_key: str) -> Any:
        """Create the OpenAI SDK client.

        Args:
            api_key: The resolved API key.

        Returns:
            openai.OpenAI: The API client.
        """
        kwargs: dict[str, Any] = {"api_key": api_key}
        # Only pass base_url when explicitly configured so the SDK's own
        # default endpoint is used otherwise.
        if self._base_url:
            kwargs["base_url"] = self._base_url
        return openai.OpenAI(**kwargs)

    def complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIResponse:
        """Generate a completion using GPT.

        Args:
            prompt: The user prompt.
            system: Optional system prompt.
            max_tokens: Maximum tokens to generate.
            timeout: Request timeout in seconds.

        Returns:
            AIResponse: The model's response with usage metadata.

        Raises:
            AIAuthenticationError: If authentication with OpenAI fails.
            AIRateLimitError: If the OpenAI rate limit is exceeded.
            AIProviderError: For any other OpenAI API error.
        """
        client = self._get_client()
        # Per-call cap: the lower of the caller's request and the
        # provider-level cap set at init time.
        effective_max = min(max_tokens, self._max_tokens)

        with self._map_errors():
            messages: list[dict[str, str]] = []
            if system:
                messages.append({"role": "system", "content": system})
            messages.append({"role": "user", "content": prompt})

            response = client.chat.completions.create(
                model=self._model,
                messages=messages,
                max_tokens=effective_max,
                timeout=timeout,
            )

            # The SDK may return a null content field; normalize to "".
            content = response.choices[0].message.content or ""

            # Usage metadata can be absent (e.g. some OpenAI-compatible
            # endpoints omit it); default token counts to zero in that case.
            input_tokens = 0
            output_tokens = 0
            if response.usage:
                input_tokens = response.usage.prompt_tokens
                output_tokens = response.usage.completion_tokens

            cost = estimate_cost(self._model, input_tokens, output_tokens)

            return AIResponse(
                content=content,
                model=self._model,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                cost_estimate=cost,
                provider=AIProvider.OPENAI,
            )

    def stream_complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIStreamResult:
        """Stream a completion from the OpenAI API token-by-token.

        Args:
            prompt: The user prompt.
            system: Optional system prompt.
            max_tokens: Maximum tokens to generate.
            timeout: Request timeout in seconds.

        Returns:
            An AIStreamResult wrapping the token stream.
        """
        client = self._get_client()
        effective_max = min(max_tokens, self._max_tokens)

        messages: list[dict[str, str]] = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        logger.debug(
            f"OpenAI stream request: model={self._model}, "
            f"max_tokens={effective_max}",
        )

        # Closed-over mutable state: _generate appends to these while the
        # caller iterates, and _on_done reads them once the stream is drained.
        # final_response is a one-element list used as a "set once" cell.
        final_response: list[AIResponse] = []
        accumulated_text: list[str] = []

        def _generate() -> Iterator[str]:
            # Lazily runs the API call and streams text deltas to the caller.
            # The error-mapping context must wrap the whole body so that
            # exceptions raised mid-iteration are also translated.
            with self._map_errors():
                stream = client.chat.completions.create(
                    model=self._model,
                    messages=messages,
                    max_tokens=effective_max,
                    timeout=timeout,
                    stream=True,
                    # Ask the API to report token usage on the stream —
                    # presumably delivered on a trailing chunk; the loop
                    # below picks it up whenever it appears.
                    stream_options={"include_usage": True},
                )

                input_tokens = 0
                output_tokens = 0

                for chunk in stream:
                    # Some chunks carry no choices/content (e.g. the usage
                    # chunk); only yield when there is actual delta text.
                    if chunk.choices and chunk.choices[0].delta.content:
                        text = chunk.choices[0].delta.content
                        accumulated_text.append(text)
                        yield text
                    if chunk.usage:
                        input_tokens = chunk.usage.prompt_tokens
                        output_tokens = chunk.usage.completion_tokens

                # Stream exhausted: assemble the final response exactly once.
                cost = estimate_cost(self._model, input_tokens, output_tokens)
                final_response.append(
                    AIResponse(
                        content="".join(accumulated_text),
                        model=self._model,
                        input_tokens=input_tokens,
                        output_tokens=output_tokens,
                        cost_estimate=cost,
                        provider=AIProvider.OPENAI,
                    ),
                )

        def _on_done() -> AIResponse:
            # Guard: the final AIResponse only exists after _generate ran to
            # completion; calling this earlier is a caller bug.
            if not final_response:
                raise AIProviderError(
                    "OpenAI stream was not fully consumed",
                )
            return final_response[0]

        return AIStreamResult(
            _chunks=_generate(),
            _on_done=_on_done,
        )