Coverage for tests / unit / ai / test_fallback.py: 100%
111 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
1"""Tests for the runtime model fallback chain."""
3from __future__ import annotations
5from unittest.mock import MagicMock
7import pytest
8from assertpy import assert_that
10from lintro.ai.exceptions import (
11 AIAuthenticationError,
12 AIProviderError,
13 AIRateLimitError,
14)
15from lintro.ai.fallback import complete_with_fallback
16from lintro.ai.providers.base import AIResponse
19def _make_provider(model: str = "primary-model") -> MagicMock:
20 """Create a mock provider with a configurable model_name attribute."""
21 provider = MagicMock()
22 provider.model_name = model
23 return provider
def _ok_response(model: str = "primary-model") -> AIResponse:
    """Build a canned successful ``AIResponse`` attributed to ``model``.

    Only ``model`` varies. Every other field is fixed, so all responses built
    here share the content ``"ok"``.
    """
    response = AIResponse(
        provider="mock",
        model=model,
        content="ok",
        input_tokens=10,
        output_tokens=5,
        cost_estimate=0.001,
    )
    return response
# -- TestCompleteWithFallbackPrimarySuccess: Primary model succeeds on first try.
def test_returns_response_without_fallback() -> None:
    """A successful primary call with no fallback list returns its response."""
    mock_provider = _make_provider()
    mock_provider.complete.return_value = _ok_response()

    response = complete_with_fallback(mock_provider, "hello")

    mock_provider.complete.assert_called_once()
    assert_that(response.content).is_equal_to("ok")
def test_returns_response_with_empty_fallback_list() -> None:
    """An explicit empty fallback list behaves like having no fallbacks."""
    mock_provider = _make_provider()
    mock_provider.complete.return_value = _ok_response()
    no_fallbacks: list[str] = []

    response = complete_with_fallback(
        mock_provider,
        "hello",
        fallback_models=no_fallbacks,
    )

    assert_that(response.content).is_equal_to("ok")
    assert_that(mock_provider.complete.call_count).is_equal_to(1)
def test_does_not_try_fallbacks_on_success() -> None:
    """Fallback models are never consulted when the primary call succeeds."""
    mock_provider = _make_provider()
    mock_provider.complete.return_value = _ok_response()
    fallbacks = ["fb-1", "fb-2"]

    response = complete_with_fallback(
        mock_provider,
        "hello",
        fallback_models=fallbacks,
    )

    assert_that(response.content).is_equal_to("ok")
    mock_provider.complete.assert_called_once()
    # The primary model must still be configured on the provider afterwards.
    assert_that(mock_provider.model_name).is_equal_to("primary-model")
# -- TestCompleteWithFallbackChain: Primary fails, fallback models are tried in order.
def test_falls_back_on_provider_error() -> None:
    """Fall back to the next model when the primary raises ``AIProviderError``.

    Every canned response has content ``"ok"``, so only the returned
    ``model`` shows that the fallback's response (``fb-1``) was returned.
    """
    provider = _make_provider()
    provider.complete.side_effect = [
        AIProviderError("primary down"),
        _ok_response("fb-1"),
    ]

    result = complete_with_fallback(
        provider,
        "hello",
        fallback_models=["fb-1"],
    )

    assert_that(result.content).is_equal_to("ok")
    assert_that(result.model).is_equal_to("fb-1")
    assert_that(provider.complete.call_count).is_equal_to(2)
    assert_that(provider.model_name).is_equal_to("primary-model")  # restored
def test_falls_back_on_rate_limit_error() -> None:
    """Fall back to the next model when the primary raises ``AIRateLimitError``.

    Checks that the fallback's response is the one returned (via ``model``,
    since all canned responses share the content ``"ok"``). Also checks that
    the primary model is restored, as the other fallback tests do.
    """
    provider = _make_provider()
    provider.complete.side_effect = [
        AIRateLimitError("rate limited"),
        _ok_response("fb-1"),
    ]

    result = complete_with_fallback(
        provider,
        "hello",
        fallback_models=["fb-1"],
    )

    assert_that(result.content).is_equal_to("ok")
    assert_that(result.model).is_equal_to("fb-1")
    assert_that(provider.complete.call_count).is_equal_to(2)
    assert_that(provider.model_name).is_equal_to("primary-model")  # restored
def test_tries_multiple_fallbacks_in_order() -> None:
    """Try fallback models sequentially until one succeeds.

    The primary and ``fb-1`` fail, so the returned response must be the one
    produced for ``fb-2``. All canned responses share the content ``"ok"``,
    so ``model`` is the distinguishing field.
    """
    provider = _make_provider()
    provider.complete.side_effect = [
        AIProviderError("primary down"),
        AIRateLimitError("fb-1 rate limited"),
        _ok_response("fb-2"),
    ]

    result = complete_with_fallback(
        provider,
        "hello",
        fallback_models=["fb-1", "fb-2"],
    )

    assert_that(result.content).is_equal_to("ok")
    assert_that(result.model).is_equal_to("fb-2")
    assert_that(provider.complete.call_count).is_equal_to(3)
    assert_that(provider.model_name).is_equal_to("primary-model")  # restored
def test_model_is_swapped_for_each_fallback() -> None:
    """Verify the provider's model_name is set to each fallback in turn.

    The fake ``complete`` records ``provider.model_name`` at call time. It
    fails on the first two calls and succeeds on the third, so the recorded
    list shows which model was active for each attempt.
    """
    provider = _make_provider("primary")
    models_seen: list[str] = []

    def capture_model(*_args: object, **_kwargs: object) -> AIResponse:
        """Record the current model and fail until the third call."""
        models_seen.append(provider.model_name)
        if len(models_seen) < 3:
            raise AIProviderError("fail")
        return _ok_response(provider.model_name)

    provider.complete.side_effect = capture_model

    result = complete_with_fallback(
        provider,
        "hello",
        fallback_models=["fb-1", "fb-2"],
    )

    assert_that(models_seen).is_equal_to(["primary", "fb-1", "fb-2"])
    # The successful response was produced while fb-2 was active.
    assert_that(result.model).is_equal_to("fb-2")
    assert_that(provider.model_name).is_equal_to("primary")  # restored
# -- TestCompleteWithFallbackAllFail: All models fail -- last error is raised.
def test_raises_last_error_when_all_fail() -> None:
    """When every model fails, the error from the final fallback surfaces."""
    mock_provider = _make_provider()
    failures = [
        AIProviderError("primary down"),
        AIRateLimitError("fb-1 limited"),
        AIProviderError("fb-2 down"),
    ]
    mock_provider.complete.side_effect = failures

    with pytest.raises(AIProviderError, match="fb-2 down"):
        complete_with_fallback(
            mock_provider,
            "hello",
            fallback_models=["fb-1", "fb-2"],
        )

    # The original model is put back even when every attempt failed.
    assert_that(mock_provider.model_name).is_equal_to("primary-model")
def test_raises_primary_error_when_no_fallbacks() -> None:
    """Without fallbacks configured, the primary model's error propagates."""
    primary_failure = AIProviderError("primary down")
    mock_provider = _make_provider()
    mock_provider.complete.side_effect = primary_failure

    with pytest.raises(AIProviderError, match="primary down"):
        complete_with_fallback(mock_provider, "hello")
# -- TestCompleteWithFallbackAuthError: AIAuthenticationError is never retried.
def test_auth_error_propagates_immediately() -> None:
    """An auth failure on the primary model aborts without trying fallbacks."""
    mock_provider = _make_provider()
    mock_provider.complete.side_effect = AIAuthenticationError("bad key")

    with pytest.raises(AIAuthenticationError, match="bad key"):
        complete_with_fallback(
            mock_provider,
            "hello",
            fallback_models=["fb-1", "fb-2"],
        )

    # A single call proves neither fallback model was attempted.
    mock_provider.complete.assert_called_once()
    assert_that(mock_provider.model_name).is_equal_to("primary-model")
def test_auth_error_on_fallback_propagates() -> None:
    """An auth failure raised by a fallback model is re-raised to the caller."""
    mock_provider = _make_provider()
    primary_failure = AIProviderError("primary down")
    fallback_failure = AIAuthenticationError("bad key on fallback")
    mock_provider.complete.side_effect = [primary_failure, fallback_failure]

    with pytest.raises(AIAuthenticationError, match="bad key on fallback"):
        complete_with_fallback(
            mock_provider,
            "hello",
            fallback_models=["fb-1"],
        )

    assert_that(mock_provider.complete.call_count).is_equal_to(2)
    # The primary model is reinstated despite the fallback's auth error.
    assert_that(mock_provider.model_name).is_equal_to("primary-model")
# -- TestCompleteWithFallbackModelRestoration: model_name is restored after every outcome.
def test_model_restored_on_auth_error() -> None:
    """model_name reverts to its original value after an auth failure."""
    original = "orig"
    mock_provider = _make_provider(original)
    mock_provider.complete.side_effect = AIAuthenticationError("err")

    with pytest.raises(AIAuthenticationError):
        complete_with_fallback(
            mock_provider,
            "hello",
            fallback_models=["x"],
        )

    assert_that(mock_provider.model_name).is_equal_to(original)
def test_model_restored_on_provider_error() -> None:
    """model_name reverts to its original value after a provider failure."""
    original = "orig"
    mock_provider = _make_provider(original)
    mock_provider.complete.side_effect = AIProviderError("err")

    with pytest.raises(AIProviderError):
        complete_with_fallback(mock_provider, "hello")

    assert_that(mock_provider.model_name).is_equal_to(original)
def test_model_restored_on_success() -> None:
    """model_name reverts to its original value after a successful fallback."""
    original = "orig"
    mock_provider = _make_provider(original)
    mock_provider.complete.side_effect = [
        AIProviderError("fail"),
        _ok_response("fb"),
    ]

    complete_with_fallback(
        mock_provider,
        "hello",
        fallback_models=["fb"],
    )

    assert_that(mock_provider.model_name).is_equal_to(original)
# -- TestCompleteWithFallbackKwargsPassthrough: kwargs are forwarded to provider.complete.
def test_forwards_all_kwargs() -> None:
    """Keyword arguments reach ``provider.complete`` exactly as supplied."""
    system_prompt = "sys"
    token_limit = 512
    timeout_s = 30.0
    mock_provider = _make_provider()
    mock_provider.complete.return_value = _ok_response()

    complete_with_fallback(
        mock_provider,
        "hello",
        system=system_prompt,
        max_tokens=token_limit,
        timeout=timeout_s,
    )

    mock_provider.complete.assert_called_once_with(
        "hello",
        system=system_prompt,
        max_tokens=token_limit,
        timeout=timeout_s,
    )