Coverage for tests / unit / ai / test_stream_fallback.py: 95%

74 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-04-03 18:53 +0000

1"""Tests for stream_complete_with_fallback().""" 

2 

3from __future__ import annotations 

4 

5import pytest 

6from assertpy import assert_that 

7 

8from lintro.ai.exceptions import AIProviderError 

9from lintro.ai.fallback import stream_complete_with_fallback 

10from lintro.ai.providers.base import AIResponse, AIStreamResult, BaseAIProvider 

11from lintro.ai.providers.constants import DEFAULT_PER_CALL_MAX_TOKENS, DEFAULT_TIMEOUT 

12 

13 

def _make_response(content: str = "ok", provider: str = "stub") -> AIResponse:
    """Build a minimal AIResponse stub for assertions in this module."""
    fields = {
        "content": content,
        "model": "m",
        "input_tokens": 1,
        "output_tokens": 1,
        "cost_estimate": 0.0,
        "provider": provider,
    }
    return AIResponse(**fields)

23 

24 

class _SuccessProvider(BaseAIProvider):
    """Stub provider whose completions and streams never fail."""

    def __init__(self, name: str = "success") -> None:
        # Populate the private attributes the base class normally sets,
        # bypassing any real SDK/client construction.
        self._name = name
        self._provider_name = name
        self._model = "test-model"
        self._api_key_env = "TEST_KEY"
        self._max_tokens = DEFAULT_PER_CALL_MAX_TOKENS
        self._base_url = None
        self._has_sdk = True
        self._client = "fake"

    def _create_client(self, *, api_key: str) -> object:
        # No real SDK involved; a placeholder object suffices.
        return "fake"

    def complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIResponse:
        """Return a canned response tagged with this provider's name."""
        return _make_response(content=f"from-{self._name}", provider=self._name)

    def stream_complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIStreamResult:
        """Yield one name-tagged chunk, then an empty final response."""
        final = _make_response(content="", provider=self._name)
        return AIStreamResult(
            _chunks=iter([f"chunk-{self._name}"]),
            _on_done=lambda: final,
        )

64 

65 

class _FailingProvider(BaseAIProvider):
    """Stub provider whose every call raises AIProviderError."""

    def __init__(self) -> None:
        # Mirror the attribute setup of _SuccessProvider, minus a display
        # name, so the base-class machinery sees a fully wired provider.
        self._provider_name = "failing"
        self._model = "fail-model"
        self._api_key_env = "FAIL_KEY"
        self._max_tokens = DEFAULT_PER_CALL_MAX_TOKENS
        self._base_url = None
        self._has_sdk = True
        self._client = "fake"

    def _create_client(self, *, api_key: str) -> object:
        # Never exercised meaningfully; any placeholder works.
        return "fake"

    def complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIResponse:
        """Simulate an unavailable provider for non-streaming calls."""
        raise AIProviderError("provider down")

    def stream_complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIStreamResult:
        """Simulate a broken streaming endpoint."""
        raise AIProviderError("stream provider down")

100 

101 

def test_stream_fallback_returns_first_success() -> None:
    """Return the stream from the first working provider."""
    stream = stream_complete_with_fallback(_SuccessProvider("primary"), "prompt")

    # The primary provider works, so its single chunk comes straight through.
    assert_that(list(stream)).is_equal_to(["chunk-primary"])

109 

110 

def test_stream_fallback_tries_fallback_models() -> None:
    """Falls back to alternate model when primary fails."""
    attempted_models: list[str] = []

    class _ModelTrackingProvider(_SuccessProvider):
        def stream_complete(
            self,
            prompt: str,
            *,
            system: str | None = None,
            max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
            timeout: float = DEFAULT_TIMEOUT,
        ) -> AIStreamResult:
            # Record each model attempted; only the original model fails.
            attempted_models.append(self._model)
            if self._model != "test-model":
                return super().stream_complete(
                    prompt,
                    system=system,
                    max_tokens=max_tokens,
                    timeout=timeout,
                )
            raise AIProviderError("primary failed")

    stream = stream_complete_with_fallback(
        _ModelTrackingProvider("tracker"),
        "prompt",
        fallback_models=["fallback-model"],
    )

    # The fallback model succeeds and produces the provider's chunk.
    assert_that(list(stream)).is_equal_to(["chunk-tracker"])
    # Both the primary and fallback models were attempted, in order.
    assert_that(attempted_models).is_equal_to(["test-model", "fallback-model"])

144 

145 

def test_stream_fallback_raises_when_all_fail() -> None:
    """Raise AIProviderError when all providers fail."""
    # With no fallback models and a provider that always raises, the last
    # error should propagate to the caller.
    with pytest.raises(AIProviderError, match="stream provider down"):
        stream_complete_with_fallback(_FailingProvider(), "prompt")

152 

153 

def test_stream_fallback_restores_model_name() -> None:
    """Provider model name is restored after fallback completes."""

    class _FailThenSuccessProvider(_SuccessProvider):
        def stream_complete(
            self,
            prompt: str,
            *,
            system: str | None = None,
            max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
            timeout: float = DEFAULT_TIMEOUT,
        ) -> AIStreamResult:
            # Fail only while the original model is configured; the
            # fallback model succeeds via the parent implementation.
            if self._model != "test-model":
                return super().stream_complete(
                    prompt,
                    system=system,
                    max_tokens=max_tokens,
                    timeout=timeout,
                )
            raise AIProviderError("primary failed")

    provider = _FailThenSuccessProvider("p1")
    original_model = provider.model_name

    stream = stream_complete_with_fallback(
        provider,
        "prompt",
        fallback_models=["other-model"],
    )
    list(stream)  # drain so the fallback machinery runs to completion

    assert_that(provider.model_name).is_equal_to(original_model)