Coverage for lintro/ai/fallback.py: 96%

48 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-04-03 18:53 +0000

1"""Runtime model fallback chain for AI providers. 

2 

3When a primary model fails with a retryable error, the fallback chain 

4tries each configured fallback model in order before giving up. 

5Authentication errors are never retried. 

6""" 

7 

8from __future__ import annotations 

9 

10import threading 

11from collections.abc import Callable 

12from typing import TypeVar 

13 

14from loguru import logger 

15 

16from lintro.ai.exceptions import ( 

17 AIAuthenticationError, 

18 AIProviderError, 

19 AIRateLimitError, 

20) 

21from lintro.ai.providers.base import AIResponse, AIStreamResult, BaseAIProvider 

22 

# Serializes model_name mutations across concurrent fallback calls
# sharing the same provider instance.
_model_lock = threading.Lock()

# Generic result type for _with_fallback: whatever attempt_fn returns
# (AIResponse for complete, AIStreamResult for streaming).
_T = TypeVar("_T")

28 

29 

def _with_fallback(
    provider: BaseAIProvider,
    attempt_fn: Callable[[str, str | None, int, float], _T],
    prompt: str,
    *,
    fallback_models: list[str] | None = None,
    system: str | None = None,
    max_tokens: int = 1024,
    timeout: float = 60.0,
    label_prefix: str = "Fallback chain",
) -> _T:
    """Run *attempt_fn* with automatic model fallback.

    Tries the provider's current (primary) model first. On
    ``AIProviderError`` or ``AIRateLimitError``, swaps to each fallback
    model in order and retries. ``AIAuthenticationError`` is never
    retried — it propagates immediately.

    **Mutation contract:** the provider's ``model_name`` is temporarily
    mutated to each fallback model during retries, but is always restored
    to its original value — even on success, on error, or if an
    ``AIAuthenticationError`` short-circuits the chain.

    Args:
        provider: AI provider instance whose ``model_name`` may be
            temporarily mutated during retries.
        attempt_fn: Callable with signature
            ``(prompt, system, max_tokens, timeout) -> T``. Typically
            ``provider.complete`` or ``provider.stream_complete``.
        prompt: The user prompt.
        fallback_models: Ordered list of fallback model identifiers.
            When empty or ``None``, behaves identically to a single
            call to *attempt_fn*.
        system: Optional system prompt.
        max_tokens: Maximum tokens to generate.
        timeout: Request timeout in seconds.
        label_prefix: Prefix for debug log messages.

    Returns:
        The first successful result from *attempt_fn*.

    Raises:
        AIAuthenticationError: Immediately on authentication failure.
        AIProviderError: If the primary model and all fallbacks fail.
        AIRateLimitError: If the primary model and all fallbacks fail
            with rate-limit errors.
    """
    # Leading None entry means "use the provider's current model" — it
    # keeps the primary attempt and the fallback retries on one code path.
    models_to_try: list[str | None] = [None]  # None = keep current model
    if fallback_models:
        models_to_try.extend(fallback_models)

    # Last retryable error seen; used after the loop to re-raise with
    # the matching exception type.
    last_error: Exception | None = None

    # Lock serializes model_name access across concurrent threads
    # sharing the same provider instance. Acquired briefly here just to
    # snapshot the original model; released before the retry loop so the
    # finally-block restore (which re-acquires) cannot deadlock on this
    # non-reentrant Lock.
    with _model_lock:
        original_model = provider.model_name

    try:
        for idx, model in enumerate(models_to_try):
            try:
                # Hold lock from model assignment through logging and
                # the provider call to prevent TOCTOU races where
                # another thread swaps model_name between set and use.
                # NOTE(review): holding a module-global lock across the
                # network-bound attempt_fn call serializes ALL concurrent
                # fallback calls process-wide, not just those sharing one
                # provider — confirm this throughput trade-off is intended.
                with _model_lock:
                    if model is not None:
                        provider.model_name = model
                    label = provider.model_name
                    # loguru brace-style formatting: args are lazily
                    # interpolated only if the debug sink is active.
                    logger.debug(
                        "{}: trying model '{}' (attempt {}/{})",
                        label_prefix,
                        label,
                        idx + 1,
                        len(models_to_try),
                    )
                    # First success wins; finally still restores model_name.
                    return attempt_fn(prompt, system, max_tokens, timeout)
            except AIAuthenticationError:
                # Never retry auth errors — restore and propagate.
                # (Restore happens in the outer finally.)
                raise
            except (AIProviderError, AIRateLimitError) as exc:
                last_error = exc
                if idx < len(models_to_try) - 1:
                    next_model = models_to_try[idx + 1]
                    logger.debug(
                        "{}: model '{}' failed ({}), falling back to '{}'",
                        label_prefix,
                        label,
                        exc,
                        next_model,
                    )
                else:
                    logger.debug(
                        "{}: model '{}' failed ({}), no more fallbacks",
                        label_prefix,
                        label,
                        exc,
                    )
    finally:
        # Unconditional restore of the caller's model selection — runs on
        # success, on exhaustion, and on the auth short-circuit above.
        with _model_lock:
            provider.model_name = original_model

    # All models exhausted — wrap the last error so pydoclint can
    # statically verify the Raises section. Check AIRateLimitError first:
    # it may be a subclass of AIProviderError (order matters if so —
    # presumed from the exception module layout; TODO confirm).
    if isinstance(last_error, AIRateLimitError):
        raise AIRateLimitError(str(last_error)) from last_error
    if isinstance(last_error, AIProviderError):
        raise AIProviderError(str(last_error)) from last_error
    # Defensive: unreachable unless the loop body never ran or an
    # unexpected error type was stored.
    raise AIProviderError(f"{label_prefix} exhausted")

138 

139 

def complete_with_fallback(
    provider: BaseAIProvider,
    prompt: str,
    *,
    fallback_models: list[str] | None = None,
    system: str | None = None,
    max_tokens: int = 1024,
    timeout: float = 60.0,
) -> AIResponse:
    """Call ``provider.complete()`` with automatic model fallback.

    The provider's current (primary) model is attempted first; each
    entry in *fallback_models* is then tried in order whenever an
    ``AIProviderError`` or ``AIRateLimitError`` occurs.
    ``AIAuthenticationError`` short-circuits the chain and propagates
    immediately. Whatever happens, the provider's ``model_name`` is
    restored to its original value afterwards.

    Args:
        provider: AI provider instance.
        prompt: The user prompt.
        fallback_models: Ordered list of fallback model identifiers.
            When empty or ``None``, behaves identically to a plain
            ``provider.complete()`` call.
        system: Optional system prompt.
        max_tokens: Maximum tokens to generate.
        timeout: Request timeout in seconds.

    Returns:
        The first successful ``AIResponse``.
    """

    def _run_once(
        user_prompt: str,
        sys_prompt: str | None,
        token_limit: int,
        request_timeout: float,
    ) -> AIResponse:
        # Single non-streaming completion against the currently
        # selected model; the fallback driver decides which model
        # that is per attempt.
        return provider.complete(
            user_prompt,
            system=sys_prompt,
            max_tokens=token_limit,
            timeout=request_timeout,
        )

    return _with_fallback(
        provider,
        _run_once,
        prompt,
        fallback_models=fallback_models,
        system=system,
        max_tokens=max_tokens,
        timeout=timeout,
        label_prefix="Fallback chain",
    )

196 

197 

def stream_complete_with_fallback(
    provider: BaseAIProvider,
    prompt: str,
    *,
    fallback_models: list[str] | None = None,
    system: str | None = None,
    max_tokens: int = 1024,
    timeout: float = 60.0,
) -> AIStreamResult:
    """Call ``provider.stream_complete()`` with automatic model fallback.

    Same fallback logic as ``complete_with_fallback`` but returns a
    streaming result. Fallback applies at stream *creation* time only —
    once a provider begins yielding tokens, mid-stream failures are
    not retried because partial content has already been consumed.

    Args:
        provider: AI provider instance.
        prompt: The user prompt.
        fallback_models: Ordered list of fallback model identifiers.
        system: Optional system prompt.
        max_tokens: Maximum tokens to generate.
        timeout: Request timeout in seconds.

    Returns:
        The first successful ``AIStreamResult``.
    """

    def _open_stream(
        user_prompt: str,
        sys_prompt: str | None,
        token_limit: int,
        request_timeout: float,
    ) -> AIStreamResult:
        # Opens one stream against the currently selected model; only
        # creation-time failures reach the fallback driver.
        return provider.stream_complete(
            user_prompt,
            system=sys_prompt,
            max_tokens=token_limit,
            timeout=request_timeout,
        )

    return _with_fallback(
        provider,
        _open_stream,
        prompt,
        fallback_models=fallback_models,
        system=system,
        max_tokens=max_tokens,
        timeout=timeout,
        label_prefix="Stream fallback",
    )