Coverage for tests/unit/ai/test_fallback.py: 100%

111 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-04-03 18:53 +0000

1"""Tests for the runtime model fallback chain.""" 

2 

3from __future__ import annotations 

4 

5from unittest.mock import MagicMock 

6 

7import pytest 

8from assertpy import assert_that 

9 

10from lintro.ai.exceptions import ( 

11 AIAuthenticationError, 

12 AIProviderError, 

13 AIRateLimitError, 

14) 

15from lintro.ai.fallback import complete_with_fallback 

16from lintro.ai.providers.base import AIResponse 

17 

18 

def _make_provider(model: str = "primary-model") -> MagicMock:
    """Build a stand-in AI provider with ``model_name`` preset.

    Args:
        model: Value assigned to the mock's ``model_name`` attribute.

    Returns:
        A fresh MagicMock; its ``complete`` attribute is an auto-created mock
        that each test configures.
    """
    # Mock constructors accept arbitrary keyword arguments and set them as
    # attributes on the new mock.
    return MagicMock(model_name=model)

24 

25 

def _ok_response(model: str = "primary-model") -> AIResponse:
    """Build a canned successful response attributed to *model*.

    Args:
        model: Model name recorded on the response.

    Returns:
        An AIResponse with content ``"ok"``, fixed token counts and cost,
        and provider ``"mock"``.
    """
    return AIResponse(
        provider="mock",
        model=model,
        content="ok",
        cost_estimate=0.001,
        input_tokens=10,
        output_tokens=5,
    )

36 

37 

38# -- TestCompleteWithFallbackPrimarySuccess: Primary model succeeds on first try. 

39 

40 

def test_returns_response_without_fallback() -> None:
    """Primary model answers directly when no fallback list is given."""
    mock_provider = _make_provider()
    mock_provider.complete.return_value = _ok_response()

    response = complete_with_fallback(mock_provider, "hello")

    mock_provider.complete.assert_called_once()
    assert_that(response.content).is_equal_to("ok")

50 

51 

def test_returns_response_with_empty_fallback_list() -> None:
    """An explicitly empty fallback list still returns the primary's answer."""
    no_fallbacks: list[str] = []
    mock_provider = _make_provider()
    mock_provider.complete.return_value = _ok_response()

    response = complete_with_fallback(
        mock_provider,
        "hello",
        fallback_models=no_fallbacks,
    )

    assert_that(response.content).is_equal_to("ok")
    assert_that(mock_provider.complete.call_count).is_equal_to(1)

61 

62 

def test_does_not_try_fallbacks_on_success() -> None:
    """Configured fallbacks are never attempted when the primary succeeds."""
    fallbacks = ["fb-1", "fb-2"]
    mock_provider = _make_provider()
    mock_provider.complete.return_value = _ok_response()

    response = complete_with_fallback(
        mock_provider,
        "hello",
        fallback_models=fallbacks,
    )

    assert_that(response.content).is_equal_to("ok")
    mock_provider.complete.assert_called_once()
    # The primary model name must still be in place afterwards.
    assert_that(mock_provider.model_name).is_equal_to("primary-model")

78 

79 

80# -- TestCompleteWithFallbackChain: Primary fails, fallback models are tried in order. 

81 

82 

def test_falls_back_on_provider_error() -> None:
    """A provider error on the primary hands the request to the next model."""
    mock_provider = _make_provider()
    attempts = [AIProviderError("primary down"), _ok_response("fb-1")]
    mock_provider.complete.side_effect = attempts

    response = complete_with_fallback(
        mock_provider,
        "hello",
        fallback_models=["fb-1"],
    )

    # One failed primary attempt plus one successful fallback attempt.
    assert_that(mock_provider.complete.call_count).is_equal_to(2)
    assert_that(response.content).is_equal_to("ok")
    assert_that(mock_provider.model_name).is_equal_to("primary-model")

100 

101 

def test_falls_back_on_rate_limit_error() -> None:
    """Fall back to next model on rate limit error.

    Also checks that the primary model name is reinstated once the fallback
    has answered.
    """
    provider = _make_provider()
    provider.complete.side_effect = [
        AIRateLimitError("rate limited"),
        _ok_response("fb-1"),
    ]

    result = complete_with_fallback(
        provider,
        "hello",
        fallback_models=["fb-1"],
    )

    assert_that(result.content).is_equal_to("ok")
    assert_that(provider.complete.call_count).is_equal_to(2)
    assert_that(provider.model_name).is_equal_to("primary-model")  # restored

118 

119 

def test_tries_multiple_fallbacks_in_order() -> None:
    """Walk the fallback list in order until one model succeeds."""
    mock_provider = _make_provider()
    primary_failure = AIProviderError("primary down")
    first_fallback_failure = AIRateLimitError("fb-1 rate limited")
    mock_provider.complete.side_effect = [
        primary_failure,
        first_fallback_failure,
        _ok_response("fb-2"),
    ]

    response = complete_with_fallback(
        mock_provider,
        "hello",
        fallback_models=["fb-1", "fb-2"],
    )

    # Three attempts: primary, fb-1, then the successful fb-2.
    assert_that(mock_provider.complete.call_count).is_equal_to(3)
    assert_that(response.content).is_equal_to("ok")
    assert_that(mock_provider.model_name).is_equal_to("primary-model")

138 

139 

def test_model_is_swapped_for_each_fallback() -> None:
    """Verify the provider's model_name is set to each fallback in turn.

    The mock records ``provider.model_name`` on every ``complete`` call. It
    fails the first two attempts and succeeds on the third, so the recorded
    sequence shows which model was active for each attempt.
    """
    provider = _make_provider("primary")
    models_seen: list[str] = []

    def capture_model(*args: object, **kwargs: object) -> AIResponse:
        """Record the current model and fail until the third call."""
        models_seen.append(provider.model_name)
        if len(models_seen) < 3:
            raise AIProviderError("fail")
        return _ok_response(provider.model_name)

    provider.complete.side_effect = capture_model

    result = complete_with_fallback(
        provider,
        "hello",
        fallback_models=["fb-1", "fb-2"],
    )

    assert_that(models_seen).is_equal_to(["primary", "fb-1", "fb-2"])
    # The successful response was built while the last fallback was active.
    assert_that(result.model).is_equal_to("fb-2")
    assert_that(provider.model_name).is_equal_to("primary")  # restored

162 

163 

164# -- TestCompleteWithFallbackAllFail: All models fail -- last error is raised. 

165 

166 

def test_raises_last_error_when_all_fail() -> None:
    """When every model fails, the error from the final attempt surfaces."""
    fallbacks = ["fb-1", "fb-2"]
    mock_provider = _make_provider()
    mock_provider.complete.side_effect = [
        AIProviderError("primary down"),
        AIRateLimitError("fb-1 limited"),
        AIProviderError("fb-2 down"),
    ]

    with pytest.raises(AIProviderError, match="fb-2 down"):
        complete_with_fallback(mock_provider, "hello", fallback_models=fallbacks)

    # The primary model is reinstated even though nothing succeeded.
    assert_that(mock_provider.model_name).is_equal_to("primary-model")

184 

185 

def test_raises_primary_error_when_no_fallbacks() -> None:
    """Raise the primary error when no fallbacks are configured.

    With no fallback models there is nothing else to try, so exactly one
    ``complete`` call is expected before the error propagates.
    """
    provider = _make_provider()
    provider.complete.side_effect = AIProviderError("primary down")

    with pytest.raises(AIProviderError, match="primary down"):
        complete_with_fallback(provider, "hello")

    provider.complete.assert_called_once()

193 

194 

195# -- TestCompleteWithFallbackAuthError: AIAuthenticationError is never retried. 

196 

197 

def test_auth_error_propagates_immediately() -> None:
    """An auth failure on the primary is raised without trying any fallback."""
    fallbacks = ["fb-1", "fb-2"]
    mock_provider = _make_provider()
    mock_provider.complete.side_effect = AIAuthenticationError("bad key")

    with pytest.raises(AIAuthenticationError, match="bad key"):
        complete_with_fallback(mock_provider, "hello", fallback_models=fallbacks)

    # Exactly one attempt means neither fallback model was tried.
    mock_provider.complete.assert_called_once()
    assert_that(mock_provider.model_name).is_equal_to("primary-model")

213 

214 

def test_auth_error_on_fallback_propagates() -> None:
    """An auth failure raised by a fallback model is propagated, not swallowed."""
    mock_provider = _make_provider()
    mock_provider.complete.side_effect = [
        AIProviderError("primary down"),
        AIAuthenticationError("bad key on fallback"),
    ]

    with pytest.raises(AIAuthenticationError, match="bad key on fallback"):
        complete_with_fallback(mock_provider, "hello", fallback_models=["fb-1"])

    # The primary attempt plus the single fallback attempt.
    assert_that(mock_provider.complete.call_count).is_equal_to(2)
    assert_that(mock_provider.model_name).is_equal_to("primary-model")

232 

233 

234# -- TestCompleteWithFallbackModelRestoration: model_name is restored after both success and failure.

235 

236 

def test_model_restored_on_auth_error() -> None:
    """model_name reverts to its original value after an auth failure."""
    original_model = "orig"
    mock_provider = _make_provider(original_model)
    mock_provider.complete.side_effect = AIAuthenticationError("err")

    with pytest.raises(AIAuthenticationError):
        complete_with_fallback(mock_provider, "hello", fallback_models=["x"])

    assert_that(mock_provider.model_name).is_equal_to(original_model)

250 

251 

def test_model_restored_on_provider_error() -> None:
    """model_name reverts to its original value after a provider failure."""
    original_model = "orig"
    mock_provider = _make_provider(original_model)
    mock_provider.complete.side_effect = AIProviderError("err")

    with pytest.raises(AIProviderError):
        complete_with_fallback(mock_provider, "hello")

    assert_that(mock_provider.model_name).is_equal_to(original_model)

261 

262 

def test_model_restored_on_success() -> None:
    """model_name reverts to its original value after a fallback succeeds."""
    original_model = "orig"
    mock_provider = _make_provider(original_model)
    mock_provider.complete.side_effect = [
        AIProviderError("fail"),
        _ok_response("fb"),
    ]

    complete_with_fallback(mock_provider, "hello", fallback_models=["fb"])

    assert_that(mock_provider.model_name).is_equal_to(original_model)

274 

275 

276# -- TestCompleteWithFallbackKwargsPassthrough: keyword arguments are forwarded to provider.complete.

277 

278 

def test_forwards_all_kwargs() -> None:
    """Extra keyword arguments reach provider.complete unchanged."""
    mock_provider = _make_provider()
    mock_provider.complete.return_value = _ok_response()

    complete_with_fallback(
        mock_provider,
        "hello",
        system="sys",
        max_tokens=512,
        timeout=30.0,
    )

    # The prompt stays positional; every option is passed through as a kwarg.
    mock_provider.complete.assert_called_once_with(
        "hello",
        system="sys",
        max_tokens=512,
        timeout=30.0,
    )