Coverage for tests/unit/ai/test_stream_fallback.py: 95%
74 statements
« prev ^ index » next — coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
1"""Tests for stream_complete_with_fallback()."""
3from __future__ import annotations
5import pytest
6from assertpy import assert_that
8from lintro.ai.exceptions import AIProviderError
9from lintro.ai.fallback import stream_complete_with_fallback
10from lintro.ai.providers.base import AIResponse, AIStreamResult, BaseAIProvider
11from lintro.ai.providers.constants import DEFAULT_PER_CALL_MAX_TOKENS, DEFAULT_TIMEOUT
def _make_response(content: str = "ok", provider: str = "stub") -> AIResponse:
    """Build a minimal, fixed-cost AIResponse for use by the stub providers.

    Args:
        content: Text body to place in the response.
        provider: Provider name to stamp on the response.

    Returns:
        An AIResponse with one input/output token and zero cost estimate.
    """
    fields = {
        "content": content,
        "model": "m",
        "input_tokens": 1,
        "output_tokens": 1,
        "cost_estimate": 0.0,
        "provider": provider,
    }
    return AIResponse(**fields)
class _SuccessProvider(BaseAIProvider):
    """Stub provider whose completions and streams always succeed."""

    def __init__(self, name: str = "success") -> None:
        # Identity and model configuration for the stub.
        self._name = name
        self._provider_name = name
        self._model = "test-model"
        self._api_key_env = "TEST_KEY"
        # SDK/client plumbing is faked out entirely for tests.
        self._has_sdk = True
        self._client = "fake"
        self._base_url = None
        self._max_tokens = DEFAULT_PER_CALL_MAX_TOKENS

    def _create_client(self, *, api_key: str) -> object:
        # No real SDK client is ever constructed in tests.
        return "fake"

    def complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIResponse:
        """Return a canned response tagged with this provider's name."""
        return _make_response(content=f"from-{self._name}", provider=self._name)

    def stream_complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIStreamResult:
        """Return a single-chunk stream whose final response is empty-bodied."""
        final = _make_response(content="", provider=self._name)

        def _finish() -> AIResponse:
            return final

        return AIStreamResult(
            _chunks=iter([f"chunk-{self._name}"]),
            _on_done=_finish,
        )
class _FailingProvider(BaseAIProvider):
    """Stub provider that raises AIProviderError on every call."""

    def __init__(self) -> None:
        # Identity/config for a provider that is permanently "down".
        self._provider_name = "failing"
        self._model = "fail-model"
        self._api_key_env = "FAIL_KEY"
        self._has_sdk = True
        self._client = "fake"
        self._base_url = None
        self._max_tokens = DEFAULT_PER_CALL_MAX_TOKENS

    def _create_client(self, *, api_key: str) -> object:
        # Never reached in practice; complete/stream_complete raise first.
        return "fake"

    def complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIResponse:
        """Always fail, simulating a non-streaming provider outage."""
        raise AIProviderError("provider down")

    def stream_complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIStreamResult:
        """Always fail, simulating a streaming provider outage."""
        raise AIProviderError("stream provider down")
def test_stream_fallback_returns_first_success() -> None:
    """Return the stream from the first working provider."""
    primary = _SuccessProvider("primary")

    stream = stream_complete_with_fallback(primary, "prompt")

    # The primary provider succeeds, so its single chunk is all we get.
    assert_that(list(stream)).is_equal_to(["chunk-primary"])
def test_stream_fallback_tries_fallback_models() -> None:
    """Falls back to alternate model when primary fails."""
    attempted_models: list[str] = []

    class _RecordingProvider(_SuccessProvider):
        def stream_complete(
            self,
            prompt: str,
            *,
            system: str | None = None,
            max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
            timeout: float = DEFAULT_TIMEOUT,
        ) -> AIStreamResult:
            # Record every model attempted; fail only on the default model.
            attempted_models.append(self._model)
            if self._model == "test-model":
                raise AIProviderError("primary failed")
            return super().stream_complete(
                prompt,
                system=system,
                max_tokens=max_tokens,
                timeout=timeout,
            )

    provider = _RecordingProvider("tracker")

    stream = stream_complete_with_fallback(
        provider,
        "prompt",
        fallback_models=["fallback-model"],
    )
    collected = list(stream)

    assert_that(collected).is_equal_to(["chunk-tracker"])
    # Primary model was tried first, then the configured fallback.
    assert_that(attempted_models).is_equal_to(["test-model", "fallback-model"])
def test_stream_fallback_raises_when_all_fail() -> None:
    """Raise AIProviderError when all providers fail."""
    broken = _FailingProvider()

    # With no working provider or fallback, the last error propagates.
    with pytest.raises(AIProviderError, match="stream provider down"):
        stream_complete_with_fallback(broken, "prompt")
def test_stream_fallback_restores_model_name() -> None:
    """Provider model name is restored after fallback completes."""

    class _RecoveringProvider(_SuccessProvider):
        def stream_complete(
            self,
            prompt: str,
            *,
            system: str | None = None,
            max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
            timeout: float = DEFAULT_TIMEOUT,
        ) -> AIStreamResult:
            # Fail on the default model so the fallback path runs.
            if self._model == "test-model":
                raise AIProviderError("primary failed")
            return super().stream_complete(
                prompt,
                system=system,
                max_tokens=max_tokens,
                timeout=timeout,
            )

    provider = _RecoveringProvider("p1")
    model_before = provider.model_name

    stream = stream_complete_with_fallback(
        provider,
        "prompt",
        fallback_models=["other-model"],
    )
    # Fully consume the stream so the fallback's cleanup logic runs.
    list(stream)

    assert_that(provider.model_name).is_equal_to(model_before)