Coverage for lintro/ai/fallback.py: 96% (48 statements)
« prev ^ index » next — coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
1"""Runtime model fallback chain for AI providers.
3When a primary model fails with a retryable error, the fallback chain
4tries each configured fallback model in order before giving up.
5Authentication errors are never retried.
6"""
8from __future__ import annotations
10import threading
11from collections.abc import Callable
12from typing import TypeVar
14from loguru import logger
16from lintro.ai.exceptions import (
17 AIAuthenticationError,
18 AIProviderError,
19 AIRateLimitError,
20)
21from lintro.ai.providers.base import AIResponse, AIStreamResult, BaseAIProvider
# Serializes model_name mutations across concurrent fallback calls
# sharing the same provider instance.
# NOTE(review): this lock is module-level, so it in fact serializes
# across ALL provider instances, not just threads sharing one instance
# — confirm that global serialization is intended.
_model_lock = threading.Lock()

# Generic result type produced by the attempt callable (in practice an
# AIResponse or AIStreamResult).
_T = TypeVar("_T")
def _with_fallback(
    provider: BaseAIProvider,
    attempt_fn: Callable[[str, str | None, int, float], _T],
    prompt: str,
    *,
    fallback_models: list[str] | None = None,
    system: str | None = None,
    max_tokens: int = 1024,
    timeout: float = 60.0,
    label_prefix: str = "Fallback chain",
) -> _T:
    """Run *attempt_fn* with automatic model fallback.

    Tries the provider's current (primary) model first. On
    ``AIProviderError`` or ``AIRateLimitError``, swaps to each fallback
    model in order and retries. ``AIAuthenticationError`` is never
    retried — it propagates immediately.

    **Mutation contract:** the provider's ``model_name`` is temporarily
    mutated to each fallback model during retries, but is always restored
    to its original value — even on success, on error, or if an
    ``AIAuthenticationError`` short-circuits the chain.

    Args:
        provider: AI provider instance whose ``model_name`` may be
            temporarily mutated during retries.
        attempt_fn: Callable with signature
            ``(prompt, system, max_tokens, timeout) -> T``. Typically
            ``provider.complete`` or ``provider.stream_complete``.
        prompt: The user prompt.
        fallback_models: Ordered list of fallback model identifiers.
            When empty or ``None``, behaves identically to a single
            call to *attempt_fn*.
        system: Optional system prompt.
        max_tokens: Maximum tokens to generate.
        timeout: Request timeout in seconds.
        label_prefix: Prefix for debug log messages.

    Returns:
        The first successful result from *attempt_fn*.

    Raises:
        AIAuthenticationError: Immediately on authentication failure.
        AIProviderError: If the primary model and all fallbacks fail.
        AIRateLimitError: If the primary model and all fallbacks fail
            with rate-limit errors.
    """
    # Leading None means "use the provider's current model" for the
    # first attempt; fallbacks follow in caller-supplied order.
    models_to_try: list[str | None] = [None]  # None = keep current model
    if fallback_models:
        models_to_try.extend(fallback_models)

    # Most recent retryable failure; re-raised (wrapped) once every
    # model in the chain has been tried.
    last_error: Exception | None = None

    # Lock serializes model_name access across concurrent threads
    # sharing the same provider instance.
    with _model_lock:
        original_model = provider.model_name

    try:
        for idx, model in enumerate(models_to_try):
            try:
                # Hold lock from model assignment through logging and
                # the provider call to prevent TOCTOU races where
                # another thread swaps model_name between set and use.
                # NOTE(review): _model_lock is module-global, so the
                # (possibly long, up to `timeout` seconds) provider call
                # runs while every other fallback call is blocked —
                # confirm this serialization is acceptable.
                with _model_lock:
                    if model is not None:
                        provider.model_name = model
                    label = provider.model_name
                    logger.debug(
                        "{}: trying model '{}' (attempt {}/{})",
                        label_prefix,
                        label,
                        idx + 1,
                        len(models_to_try),
                    )
                    return attempt_fn(prompt, system, max_tokens, timeout)
            except AIAuthenticationError:
                # Never retry auth errors — restore and propagate.
                raise
            except (AIProviderError, AIRateLimitError) as exc:
                last_error = exc
                if idx < len(models_to_try) - 1:
                    next_model = models_to_try[idx + 1]
                    logger.debug(
                        "{}: model '{}' failed ({}), falling back to '{}'",
                        label_prefix,
                        label,
                        exc,
                        next_model,
                    )
                else:
                    logger.debug(
                        "{}: model '{}' failed ({}), no more fallbacks",
                        label_prefix,
                        label,
                        exc,
                    )
    finally:
        # Restore the caller's model unconditionally — on success,
        # exhaustion, or an auth short-circuit.
        with _model_lock:
            provider.model_name = original_model

    # All models exhausted — wrap the last error so pydoclint can
    # statically verify the Raises section.
    # NOTE(review): re-wrapping via str(last_error) preserves the
    # original only as __cause__; any extra attributes (e.g. a
    # retry-after field) are not copied onto the new exception —
    # confirm callers only need the message.
    if isinstance(last_error, AIRateLimitError):
        raise AIRateLimitError(str(last_error)) from last_error
    if isinstance(last_error, AIProviderError):
        raise AIProviderError(str(last_error)) from last_error
    # Defensive: unreachable in practice, since models_to_try is never
    # empty and every non-returning iteration records last_error.
    raise AIProviderError(f"{label_prefix} exhausted")
def complete_with_fallback(
    provider: BaseAIProvider,
    prompt: str,
    *,
    fallback_models: list[str] | None = None,
    system: str | None = None,
    max_tokens: int = 1024,
    timeout: float = 60.0,
) -> AIResponse:
    """Call ``provider.complete()`` with automatic model fallback.

    The provider's current (primary) model is attempted first; each
    model in *fallback_models* is then tried in order whenever an
    ``AIProviderError`` or ``AIRateLimitError`` occurs.
    ``AIAuthenticationError`` is never retried and propagates at once.

    After all attempts (successful or not), the provider's
    ``model_name`` is restored to its original value.

    Args:
        provider: AI provider instance.
        prompt: The user prompt.
        fallback_models: Ordered list of fallback model identifiers.
            When empty or ``None``, behaves identically to a plain
            ``provider.complete()`` call.
        system: Optional system prompt.
        max_tokens: Maximum tokens to generate.
        timeout: Request timeout in seconds.

    Returns:
        The first successful ``AIResponse``.
    """

    def _call_complete(
        user_prompt: str,
        system_prompt: str | None,
        token_limit: int,
        request_timeout: float,
    ) -> AIResponse:
        # Adapter: the fallback runner supplies positional arguments,
        # while the provider API expects keywords.
        return provider.complete(
            user_prompt,
            system=system_prompt,
            max_tokens=token_limit,
            timeout=request_timeout,
        )

    return _with_fallback(
        provider,
        _call_complete,
        prompt,
        fallback_models=fallback_models,
        system=system,
        max_tokens=max_tokens,
        timeout=timeout,
        label_prefix="Fallback chain",
    )
def stream_complete_with_fallback(
    provider: BaseAIProvider,
    prompt: str,
    *,
    fallback_models: list[str] | None = None,
    system: str | None = None,
    max_tokens: int = 1024,
    timeout: float = 60.0,
) -> AIStreamResult:
    """Call ``provider.stream_complete()`` with automatic model fallback.

    Same fallback logic as ``complete_with_fallback``, but the result is
    a streaming handle. Fallback only covers stream *creation* — once a
    provider has begun yielding tokens, a mid-stream failure is not
    retried, because partial content has already been consumed.

    Args:
        provider: AI provider instance.
        prompt: The user prompt.
        fallback_models: Ordered list of fallback model identifiers.
        system: Optional system prompt.
        max_tokens: Maximum tokens to generate.
        timeout: Request timeout in seconds.

    Returns:
        The first successful ``AIStreamResult``.
    """

    def _open_stream(
        user_prompt: str,
        system_prompt: str | None,
        token_limit: int,
        request_timeout: float,
    ) -> AIStreamResult:
        # Adapter: the fallback runner supplies positional arguments,
        # while the provider API expects keywords.
        return provider.stream_complete(
            user_prompt,
            system=system_prompt,
            max_tokens=token_limit,
            timeout=request_timeout,
        )

    return _with_fallback(
        provider,
        _open_stream,
        prompt,
        fallback_models=fallback_models,
        system=system,
        max_tokens=max_tokens,
        timeout=timeout,
        label_prefix="Stream fallback",
    )