Coverage for tests / unit / ai / test_fix_generation_edge.py: 98%

89 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-04-03 18:53 +0000

1"""Tests for fix generation edge cases, error handling, and retries. 

2 

3Covers provider errors, concurrent generation, retry behaviour, 

4and authentication error handling. 

5""" 

6 

7from __future__ import annotations 

8 

9import json 

10 

11from assertpy import assert_that 

12 

13from lintro.ai.fix import ( 

14 generate_fixes, 

15) 

16from lintro.ai.providers.base import AIResponse 

17from tests.unit.ai.conftest import MockAIProvider, MockIssue 

18 

19 

def _make_ai_response(
    original: str = "x = 1",
    suggested: str = "x = 2",
    explanation: str = "Fix",
    confidence: str = "high",
    risk_level: str | None = None,
) -> AIResponse:
    """Build a successful mock ``AIResponse`` carrying a JSON fix payload.

    Args:
        original: Value stored under the ``original_code`` key.
        suggested: Value stored under the ``suggested_code`` key.
        explanation: Value stored under the ``explanation`` key.
        confidence: Value stored under the ``confidence`` key.
        risk_level: Value stored under ``risk_level``; when ``None`` the key
            is left out of the payload entirely.

    Returns:
        An ``AIResponse`` whose ``content`` is the JSON-encoded payload, with
        model/provider ``"mock"``, 10 input and 10 output tokens, and a cost
        estimate of 0.001.
    """
    # Keyword order fixes the key order of the serialized JSON.
    fields = dict(
        original_code=original,
        suggested_code=suggested,
        explanation=explanation,
        confidence=confidence,
    )
    if risk_level is not None:
        # Omit the key rather than emitting a JSON null.
        fields["risk_level"] = risk_level

    encoded = json.dumps(fields)
    return AIResponse(
        content=encoded,
        provider="mock",
        model="mock",
        input_tokens=10,
        output_tokens=10,
        cost_estimate=0.001,
    )

44 

45 

46# --------------------------------------------------------------------------- 

47# Provider error handling 

48# --------------------------------------------------------------------------- 

49 

50 

def test_generate_fixes_handles_provider_error(tmp_path):
    """Verify that a provider exception results in an empty fix list.

    ``workspace_root`` is passed so the temporary source file is treated the
    same way as in the other tests in this module. A call counter then
    confirms that the empty result comes from the provider failing, not from
    the issue being skipped before any provider call is made.
    """
    source = tmp_path / "test.py"
    source.write_text("x = 1\n")

    issue = MockIssue(
        file=str(source),
        line=1,
        code="B101",
        message="test",
    )

    call_count = {"n": 0}

    class ErrorProvider(MockAIProvider):
        def complete(self, prompt, **kwargs):
            call_count["n"] += 1
            raise RuntimeError("API down")

    provider = ErrorProvider()
    result = generate_fixes(
        [issue],
        provider,
        tool_name="ruff",
        workspace_root=tmp_path,
    )

    # Without this, an empty result would also be produced if the issue were
    # filtered out before reaching the provider, and the test would prove
    # nothing about error handling.
    assert_that(call_count["n"]).is_greater_than(0)
    assert_that(result).is_empty()

75 

76 

77# --------------------------------------------------------------------------- 

78# Concurrent generate_fixes (ThreadPoolExecutor path) 

79# --------------------------------------------------------------------------- 

80 

81 

def test_concurrent_generation_with_multiple_workers(tmp_path):
    """generate_fixes with max_workers=3 exercises the ThreadPoolExecutor path."""
    issue_count = 3

    # Use separate files so batching does not group them,
    # exercising the ThreadPoolExecutor path for single-issue calls.
    issues = []
    for i in range(1, issue_count + 1):
        source = tmp_path / f"test{i}.py"
        source.write_text(f"line{i}\n")
        issues.append(
            MockIssue(
                file=str(source),
                line=1,
                code="B101",
                message=f"Issue {i}",
            ),
        )

    # One valid fix payload per issue, built with the shared helper.
    responses = [
        _make_ai_response(
            original=f"line{i}",
            suggested=f"fixed_line{i}",
            explanation=f"Fix issue {i}",
        )
        for i in range(1, issue_count + 1)
    ]
    provider = MockAIProvider(responses=responses)

    result = generate_fixes(
        issues,
        provider,
        tool_name="ruff",
        max_workers=3,
        workspace_root=tmp_path,
    )

    assert_that(provider.calls).is_length(issue_count)
    assert_that(result).is_length(issue_count)

132 

133 

def test_concurrent_mixed_success_and_failure(tmp_path):
    """Concurrent mode: one success, one failure -> 1 suggestion returned."""
    # Use separate files so batching does not group them
    source1 = tmp_path / "test1.py"
    source1.write_text("line1\n")
    source2 = tmp_path / "test2.py"
    source2.write_text("line2\n")

    issues = [
        MockIssue(
            file=str(source1),
            line=1,
            code="B101",
            message="Issue 1",
        ),
        MockIssue(
            file=str(source2),
            line=1,
            code="B101",
            message="Issue 2",
        ),
    ]

    # NOTE(review): presumably MockAIProvider serves responses in call order,
    # so under concurrency either issue may receive the unparseable one.
    # Only the counts are asserted for that reason.
    responses = [
        _make_ai_response(
            original="line1",
            suggested="fixed_line1",
            explanation="Fix issue 1",
        ),
        # Content that is not JSON: no suggestion should come out of it.
        AIResponse(
            content="not-json",
            model="mock",
            input_tokens=10,
            output_tokens=10,
            cost_estimate=0.001,
            provider="mock",
        ),
    ]
    provider = MockAIProvider(responses=responses)

    result = generate_fixes(
        issues,
        provider,
        tool_name="ruff",
        max_workers=3,
        workspace_root=tmp_path,
    )

    assert_that(provider.calls).is_length(2)
    assert_that(result).is_length(1)

194 

195 

196# --------------------------------------------------------------------------- 

197# _call_provider retry behaviour (via with_retry) 

198# --------------------------------------------------------------------------- 

199 

200 

def test_retries_on_provider_error(tmp_path):
    """Transient AIProviderError triggers retries, then succeeds.

    The provider raises on its first two calls and returns a valid fix on
    the third. The run uses ``max_retries=3`` and expects exactly three
    calls and one resulting fix.
    """
    from lintro.ai.exceptions import AIProviderError

    target = tmp_path / "test.py"
    target.write_text("x = 1\n")

    issue = MockIssue(
        file=str(target),
        line=1,
        code="B101",
        message="test",
    )

    attempts = 0
    ok_response = _make_ai_response()

    class RetryProvider(MockAIProvider):
        def complete(self, prompt, **kwargs):
            nonlocal attempts
            attempts += 1
            # Succeed only from the third attempt onward.
            if attempts >= 3:
                return ok_response
            raise AIProviderError("transient")

    outcome = generate_fixes(
        [issue],
        RetryProvider(),
        tool_name="ruff",
        workspace_root=tmp_path,
        max_retries=3,
    )

    assert_that(attempts).is_equal_to(3)
    assert_that(outcome).is_length(1)

236 

237 

def test_no_retry_on_auth_error(tmp_path):
    """AIAuthenticationError is never retried.

    Even with ``max_retries=3`` the provider is called exactly once.
    ``generate_fixes`` still returns normally, with an empty fix list,
    instead of raising to the caller.
    """
    from lintro.ai.exceptions import AIAuthenticationError

    source = tmp_path / "test.py"
    source.write_text("x = 1\n")

    issue = MockIssue(
        file=str(source),
        line=1,
        code="B101",
        message="test",
    )

    call_count = {"n": 0}

    class AuthErrorProvider(MockAIProvider):
        def complete(self, prompt, **kwargs):
            call_count["n"] += 1
            raise AIAuthenticationError("bad key")

    provider = AuthErrorProvider()
    result = generate_fixes(
        [issue],
        provider,
        tool_name="ruff",
        workspace_root=tmp_path,
        max_retries=3,
    )

    # Auth errors bypass the retry loop: exactly one call, no retries.
    # generate_fixes returns (with no fixes) rather than raising.
    assert_that(call_count["n"]).is_equal_to(1)
    assert_that(result).is_empty()

271 

272 

def test_max_retries_zero_means_no_retry(tmp_path):
    """max_retries=0 means no retry on transient error — only one call.

    The provider raises a retryable ``AIProviderError`` on every call. A
    single observed call therefore proves that no retry was attempted.
    """
    from lintro.ai.exceptions import AIProviderError

    source = tmp_path / "test.py"
    source.write_text("x = 1\n")

    issue = MockIssue(
        file=str(source),
        line=1,
        code="B101",
        message="test",
    )

    call_count = {"n": 0}

    # Fails unconditionally; the old name "FailOnceProvider" was misleading.
    class AlwaysFailingProvider(MockAIProvider):
        def complete(self, prompt, **kwargs):
            call_count["n"] += 1
            raise AIProviderError("transient failure")

    provider = AlwaysFailingProvider()
    result = generate_fixes(
        [issue],
        provider,
        tool_name="ruff",
        workspace_root=tmp_path,
        max_retries=0,
    )

    assert_that(call_count["n"]).is_equal_to(1)
    assert_that(result).is_empty()

304 assert_that(result).is_empty()