Coverage for lintro/ai/providers/openai.py: 52%
85 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
1"""OpenAI AI provider implementation.
3Uses the OpenAI Python SDK to communicate with GPT models.
4Requires the ``openai`` package (installed via ``lintro[ai]``).
5"""
7from __future__ import annotations
9from collections.abc import Iterator
10from contextlib import contextmanager
11from typing import Any
13from loguru import logger
15from lintro.ai.cost import estimate_cost
16from lintro.ai.exceptions import (
17 AIAuthenticationError,
18 AIProviderError,
19 AIRateLimitError,
20)
21from lintro.ai.providers.base import AIResponse, AIStreamResult, BaseAIProvider
22from lintro.ai.providers.constants import (
23 DEFAULT_MAX_TOKENS,
24 DEFAULT_PER_CALL_MAX_TOKENS,
25 DEFAULT_TIMEOUT,
26)
27from lintro.ai.registry import PROVIDERS, AIProvider
# The OpenAI SDK is an optional dependency (installed via ``lintro[ai]``).
# Record its availability in a flag instead of failing at import time; the
# provider's base class uses this flag to raise a clear error on init.
_has_openai = False
try:
    import openai

    _has_openai = True
except ImportError:
    # SDK not installed — provider construction will surface the error.
    pass

# Provider defaults sourced from the central registry so model/env-var
# names stay consistent across the codebase.
DEFAULT_MODEL = PROVIDERS.openai.default_model
DEFAULT_API_KEY_ENV = PROVIDERS.openai.default_api_key_env
class OpenAIProvider(BaseAIProvider):
    """OpenAI GPT provider.

    Talks to the OpenAI chat-completions API (or any OpenAI-compatible
    endpoint via ``base_url``) and maps SDK errors onto lintro's AI
    exception hierarchy.
    """

    @staticmethod
    @contextmanager
    def _map_errors() -> Iterator[None]:
        """Map OpenAI SDK exceptions to AI exceptions.

        Safe to call only when the ``openai`` SDK is installed —
        the base class ``__init__`` raises ``AINotAvailableError``
        before any method can be called if the SDK is missing.

        Yields:
            None: Control to the wrapped block.

        Raises:
            AIAuthenticationError: On invalid or expired credentials.
            AIRateLimitError: When the API rate limit is exceeded.
            AIProviderError: For any other OpenAI SDK error.
        """
        try:
            yield
        except openai.AuthenticationError as e:
            raise AIAuthenticationError(
                f"OpenAI authentication failed: {e}",
            ) from e
        except openai.RateLimitError as e:
            raise AIRateLimitError(
                f"OpenAI rate limit exceeded: {e}",
            ) from e
        except openai.OpenAIError as e:
            # Catch-all for the SDK's base error type; the more specific
            # subclasses above must be listed first to take effect.
            logger.debug(f"OpenAI API error: {e}")
            raise AIProviderError(
                f"OpenAI API error: {e}",
            ) from e

    @staticmethod
    def _build_messages(
        prompt: str,
        system: str | None,
    ) -> list[dict[str, str]]:
        """Build the chat messages payload shared by both completion paths.

        Args:
            prompt: The user prompt.
            system: Optional system prompt; prepended when provided.

        Returns:
            list[dict[str, str]]: Messages in OpenAI chat format.
        """
        messages: list[dict[str, str]] = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})
        return messages

    def __init__(
        self,
        *,
        model: str | None = None,
        api_key_env: str | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        base_url: str | None = None,
    ) -> None:
        """Initialize the OpenAI provider.

        Args:
            model: Model identifier. Defaults to gpt-4o.
            api_key_env: Environment variable for API key.
                Defaults to OPENAI_API_KEY.
            max_tokens: Default max tokens for completions.
            base_url: Custom API base URL for OpenAI-compatible
                endpoints (Ollama, vLLM, Azure OpenAI, etc.).
        """
        super().__init__(
            provider_name=AIProvider.OPENAI,
            has_sdk=_has_openai,
            sdk_package="openai",
            default_model=DEFAULT_MODEL,
            default_api_key_env=DEFAULT_API_KEY_ENV,
            model=model,
            api_key_env=api_key_env,
            max_tokens=max_tokens,
            base_url=base_url,
        )

    def _create_client(self, *, api_key: str) -> Any:
        """Create the OpenAI SDK client.

        Args:
            api_key: The resolved API key.

        Returns:
            openai.OpenAI: The API client.
        """
        kwargs: dict[str, Any] = {"api_key": api_key}
        # Only pass base_url when set so the SDK's own default applies
        # for the standard OpenAI endpoint.
        if self._base_url:
            kwargs["base_url"] = self._base_url
        return openai.OpenAI(**kwargs)

    def complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIResponse:
        """Generate a completion using GPT.

        Args:
            prompt: The user prompt.
            system: Optional system prompt.
            max_tokens: Maximum tokens to generate.
            timeout: Request timeout in seconds.

        Returns:
            AIResponse: The model's response with usage metadata.

        Raises:
            AIAuthenticationError: On invalid credentials.
            AIRateLimitError: When the API rate limit is exceeded.
            AIProviderError: For any other OpenAI API error.
        """
        client = self._get_client()
        # Per-call cap: the lower of the caller's request and the
        # provider-level cap set at init time.
        effective_max = min(max_tokens, self._max_tokens)

        with self._map_errors():
            messages = self._build_messages(prompt, system)

            response = client.chat.completions.create(
                model=self._model,
                messages=messages,
                max_tokens=effective_max,
                timeout=timeout,
            )

        content = response.choices[0].message.content or ""

        # Usage may be absent on some OpenAI-compatible backends;
        # fall back to zero counts rather than failing.
        input_tokens = 0
        output_tokens = 0
        if response.usage:
            input_tokens = response.usage.prompt_tokens
            output_tokens = response.usage.completion_tokens

        cost = estimate_cost(self._model, input_tokens, output_tokens)

        return AIResponse(
            content=content,
            model=self._model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost_estimate=cost,
            provider=AIProvider.OPENAI,
        )

    def stream_complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIStreamResult:
        """Stream a completion from the OpenAI API token-by-token.

        Args:
            prompt: The user prompt.
            system: Optional system prompt.
            max_tokens: Maximum tokens to generate.
            timeout: Request timeout in seconds.

        Returns:
            An AIStreamResult wrapping the token stream.
        """
        client = self._get_client()
        effective_max = min(max_tokens, self._max_tokens)
        messages = self._build_messages(prompt, system)

        logger.debug(
            f"OpenAI stream request: model={self._model}, "
            f"max_tokens={effective_max}",
        )

        # Closed over by _generate/_on_done: the final AIResponse is only
        # available after the stream has been fully consumed.
        final_response: list[AIResponse] = []
        accumulated_text: list[str] = []

        def _generate() -> Iterator[str]:
            # Yields content deltas; records usage from the final chunk
            # (requested via stream_options include_usage).
            with self._map_errors():
                stream = client.chat.completions.create(
                    model=self._model,
                    messages=messages,
                    max_tokens=effective_max,
                    timeout=timeout,
                    stream=True,
                    stream_options={"include_usage": True},
                )

                input_tokens = 0
                output_tokens = 0

                for chunk in stream:
                    if chunk.choices and chunk.choices[0].delta.content:
                        text = chunk.choices[0].delta.content
                        accumulated_text.append(text)
                        yield text
                    if chunk.usage:
                        input_tokens = chunk.usage.prompt_tokens
                        output_tokens = chunk.usage.completion_tokens

                cost = estimate_cost(self._model, input_tokens, output_tokens)
                final_response.append(
                    AIResponse(
                        content="".join(accumulated_text),
                        model=self._model,
                        input_tokens=input_tokens,
                        output_tokens=output_tokens,
                        cost_estimate=cost,
                        provider=AIProvider.OPENAI,
                    ),
                )

        def _on_done() -> AIResponse:
            # Only valid once _generate has run to completion.
            if not final_response:
                raise AIProviderError(
                    "OpenAI stream was not fully consumed",
                )
            return final_response[0]

        return AIStreamResult(
            _chunks=_generate(),
            _on_done=_on_done,
        )