Coverage for tests / unit / ai / test_fix_generation_edge.py: 98%
89 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
1"""Tests for fix generation edge cases, error handling, and retries.
3Covers provider errors, concurrent generation, retry behaviour,
4and authentication error handling.
5"""
7from __future__ import annotations
9import json
11from assertpy import assert_that
13from lintro.ai.fix import (
14 generate_fixes,
15)
16from lintro.ai.providers.base import AIResponse
17from tests.unit.ai.conftest import MockAIProvider, MockIssue
def _make_ai_response(
    original: str = "x = 1",
    suggested: str = "x = 2",
    explanation: str = "Fix",
    confidence: str = "high",
    risk_level: str | None = None,
) -> AIResponse:
    """Build a successful AIResponse carrying a valid JSON fix payload.

    The ``risk_level`` key is included in the payload only when a value
    is supplied, mirroring providers that omit optional fields.
    """
    body: dict[str, str | None] = {
        "original_code": original,
        "suggested_code": suggested,
        "explanation": explanation,
        "confidence": confidence,
    }
    if risk_level is not None:
        body["risk_level"] = risk_level
    serialized = json.dumps(body)
    return AIResponse(
        content=serialized,
        model="mock",
        input_tokens=10,
        output_tokens=10,
        cost_estimate=0.001,
        provider="mock",
    )
46# ---------------------------------------------------------------------------
47# Provider error handling
48# ---------------------------------------------------------------------------
def test_generate_fixes_handles_provider_error(tmp_path):
    """Verify that a provider exception results in an empty fix list."""
    target = tmp_path / "test.py"
    target.write_text("x = 1\n")

    class ErrorProvider(MockAIProvider):
        # Simulate an unconditional provider outage for every call.
        def complete(self, prompt, **kwargs):
            raise RuntimeError("API down")

    failing_issue = MockIssue(
        file=str(target),
        line=1,
        code="B101",
        message="test",
    )

    fixes = generate_fixes(
        [failing_issue],
        ErrorProvider(),
        tool_name="ruff",
    )

    # The error is swallowed per-issue, yielding no suggestions.
    assert_that(fixes).is_empty()
77# ---------------------------------------------------------------------------
78# Concurrent generate_fixes (ThreadPoolExecutor path)
79# ---------------------------------------------------------------------------
def test_concurrent_generation_with_multiple_workers(tmp_path):
    """generate_fixes with max_workers=3 exercises the ThreadPoolExecutor path."""
    # Separate files keep batching from grouping the issues together,
    # forcing one provider call per issue through the thread pool.
    paths = []
    for idx in (1, 2, 3):
        path = tmp_path / f"test{idx}.py"
        path.write_text(f"line{idx}\n")
        paths.append(path)

    issues = [
        MockIssue(
            file=str(path),
            line=1,
            code="B101",
            message=f"Issue {idx}",
        )
        for idx, path in enumerate(paths, start=1)
    ]

    responses = [
        _make_ai_response(
            original=f"line{idx}",
            suggested=f"fixed_line{idx}",
            explanation=f"Fix issue {idx}",
        )
        for idx in (1, 2, 3)
    ]
    provider = MockAIProvider(responses=responses)

    result = generate_fixes(
        issues,
        provider,
        tool_name="ruff",
        max_workers=3,
        workspace_root=tmp_path,
    )

    # Every issue hit the provider and every fix came back.
    assert_that(provider.calls).is_length(3)
    assert_that(result).is_length(3)
def test_concurrent_mixed_success_and_failure(tmp_path):
    """Concurrent mode: one success, one failure -> 1 suggestion returned."""
    # Separate files keep batching from grouping the two issues.
    first = tmp_path / "test1.py"
    first.write_text("line1\n")
    second = tmp_path / "test2.py"
    second.write_text("line2\n")

    issues = [
        MockIssue(file=str(first), line=1, code="B101", message="Issue 1"),
        MockIssue(file=str(second), line=1, code="B101", message="Issue 2"),
    ]

    # One well-formed fix payload followed by an unparseable response.
    good = _make_ai_response(
        original="line1",
        suggested="fixed_line1",
        explanation="Fix issue 1",
    )
    bad = AIResponse(
        content="not-json",
        model="mock",
        input_tokens=10,
        output_tokens=10,
        cost_estimate=0.001,
        provider="mock",
    )
    provider = MockAIProvider(responses=[good, bad])

    result = generate_fixes(
        issues,
        provider,
        tool_name="ruff",
        max_workers=3,
        workspace_root=tmp_path,
    )

    # Both issues reach the provider, but only the valid payload survives.
    assert_that(provider.calls).is_length(2)
    assert_that(result).is_length(1)
196# ---------------------------------------------------------------------------
197# _call_provider retry behaviour (via with_retry)
198# ---------------------------------------------------------------------------
def test_retries_on_provider_error(tmp_path):
    """Transient AIProviderError triggers retries, then succeeds."""
    from lintro.ai.exceptions import AIProviderError

    target = tmp_path / "test.py"
    target.write_text("x = 1\n")

    issue = MockIssue(file=str(target), line=1, code="B101", message="test")

    attempts = {"n": 0}
    ok_response = _make_ai_response()

    class RetryProvider(MockAIProvider):
        # Fail twice, then hand back a valid fix on the third attempt.
        def complete(self, prompt, **kwargs):
            attempts["n"] += 1
            if attempts["n"] < 3:
                raise AIProviderError("transient")
            return ok_response

    result = generate_fixes(
        [issue],
        RetryProvider(),
        tool_name="ruff",
        workspace_root=tmp_path,
        max_retries=3,
    )

    # Two transient failures consumed two retries; third call succeeded.
    assert_that(attempts["n"]).is_equal_to(3)
    assert_that(result).is_length(1)
def test_no_retry_on_auth_error(tmp_path):
    """AIAuthenticationError is never retried."""
    from lintro.ai.exceptions import AIAuthenticationError

    target = tmp_path / "test.py"
    target.write_text("x = 1\n")

    issue = MockIssue(file=str(target), line=1, code="B101", message="test")

    attempts = {"n": 0}

    class AuthErrorProvider(MockAIProvider):
        # Every call fails with a non-retryable credentials error.
        def complete(self, prompt, **kwargs):
            attempts["n"] += 1
            raise AIAuthenticationError("bad key")

    result = generate_fixes(
        [issue],
        AuthErrorProvider(),
        tool_name="ruff",
        workspace_root=tmp_path,
        max_retries=3,
    )

    # Auth errors propagate immediately -- only 1 call, no retries
    assert_that(attempts["n"]).is_equal_to(1)
    assert_that(result).is_empty()
def test_max_retries_zero_means_no_retry(tmp_path):
    """max_retries=0 means no retry on transient error — only one call."""
    from lintro.ai.exceptions import AIProviderError

    target = tmp_path / "test.py"
    target.write_text("x = 1\n")

    issue = MockIssue(file=str(target), line=1, code="B101", message="test")

    attempts = {"n": 0}

    class FailOnceProvider(MockAIProvider):
        # Always raises; with zero retries it should be invoked exactly once.
        def complete(self, prompt, **kwargs):
            attempts["n"] += 1
            raise AIProviderError("transient failure")

    result = generate_fixes(
        [issue],
        FailOnceProvider(),
        tool_name="ruff",
        workspace_root=tmp_path,
        max_retries=0,
    )

    assert_that(attempts["n"]).is_equal_to(1)
    assert_that(result).is_empty()