Coverage for lintro / ai / interactive.py: 71%

192 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-04-03 18:53 +0000

1"""Interactive fix review loop for AI-generated suggestions. 

2 

3Used by the ``fmt`` flow when ``auto_apply`` is False. Groups suggestions 

4by error code and presents them for batch accept/reject decisions. 

5""" 

6 

7from __future__ import annotations 

8 

9import sys 

10from collections import defaultdict 

11from collections.abc import Sequence 

12from enum import StrEnum 

13from pathlib import Path 

14 

15import click 

16from rich.console import Console, Group, RenderableType 

17from rich.markup import escape 

18from rich.panel import Panel 

19from rich.syntax import Syntax 

20 

21from lintro.ai.apply import apply_fixes 

22from lintro.ai.display.shared import cost_str, print_code_panel, print_section_header 

23from lintro.ai.display.validation import render_validation 

24from lintro.ai.enums import RiskLevel 

25from lintro.ai.models import AIFixSuggestion 

26from lintro.ai.paths import relative_path 

27from lintro.ai.risk import ( 

28 SAFE_STYLE_RISK, 

29 calculate_patch_stats, 

30 classify_fix_risk, 

31 is_safe_style_fix, 

32) 

33from lintro.ai.validation import validate_applied_fixes 

34 

35__all__ = ["apply_fixes", "review_fixes_interactive"] 

36 

37 

38class ReviewKey(StrEnum): 

39 """Key bindings for interactive fix review.""" 

40 

41 ACCEPT = "y" 

42 ACCEPT_ALL = "a" 

43 REJECT = "r" 

44 SHOW_DIFF = "d" 

45 SKIP = "s" 

46 TOGGLE_VALIDATE = "v" 

47 QUIT = "q" 

48 

49 

def _group_by_code(
    suggestions: Sequence[AIFixSuggestion],
) -> dict[str, list[AIFixSuggestion]]:
    """Bucket fix suggestions under their error code.

    Suggestions whose ``code`` is empty or ``None`` are collected under the
    ``"unknown"`` key. Buckets appear in first-seen order and each bucket
    keeps its suggestions in input order.

    Args:
        suggestions: Fix suggestions to group.

    Returns:
        Plain dict mapping error code to the suggestions carrying it.
    """
    buckets: dict[str, list[AIFixSuggestion]] = {}
    for suggestion in suggestions:
        buckets.setdefault(suggestion.code or "unknown", []).append(suggestion)
    return buckets

66 

67 

def _print_group_header(
    console: Console,
    code: str,
    fixes: list[AIFixSuggestion],
    group_index: int,
    total_groups: int,
) -> None:
    """Render the summary panel for a single error-code group.

    The panel body holds a risk/patch-statistics line, the explanation of
    the first fix (when it has one), and one small location panel per fix.
    Rendering is handed to ``print_code_panel`` (imported from
    ``lintro.ai.display.shared``).

    Args:
        console: Rich Console instance.
        code: The error code (e.g. "D107").
        fixes: Suggestions in this group.
        group_index: 1-based index of this group.
        total_groups: Total number of groups.
    """
    patch = calculate_patch_stats(fixes)

    # The group is only "safe style" when every fix classifies that way;
    # any other mix (including an empty group) is shown as behavioral risk.
    distinct_risks = {classify_fix_risk(item) for item in fixes}
    if distinct_risks == {SAFE_STYLE_RISK}:
        group_risk = SAFE_STYLE_RISK
        risk_color = "green"
    else:
        group_risk = RiskLevel.BEHAVIORAL_RISK
        risk_color = "yellow"

    summary_line = (
        f"[{risk_color}]risk: {group_risk}[/{risk_color}]"
        " · "
        f"[dim]patch: {patch.files} files, +{patch.lines_added}/"
        f"-{patch.lines_removed}, {patch.hunks} hunks[/dim]"
    )
    renderables: list[RenderableType] = [summary_line]

    # Only the first fix's explanation is shown for the whole group.
    lead = fixes[0] if fixes else None
    if lead and lead.explanation:
        renderables.append(f"[cyan]{escape(lead.explanation)}[/cyan]")

    for item in fixes:
        shown_path = relative_path(item.file)
        location = f"{shown_path}:{item.line}" if item.line else shown_path
        location_panel = Panel(
            f"[green]{escape(location)}[/green]",
            border_style="dim",
            padding=(0, 1),
        )
        renderables.append(location_panel)

    # ``renderables`` always contains at least the summary line.
    content: RenderableType
    if len(renderables) > 1:
        content = Group(*renderables)
    else:
        content = renderables[0]

    distinct_files = {item.file for item in fixes}
    tool_label = lead.tool_name if lead else ""

    console.print()
    print_code_panel(
        console,
        code=code,
        index=group_index,
        total=total_groups,
        count=len(distinct_files),
        count_label="file",
        content=content,
        tool_name=tool_label,
    )

138 

139 

def _show_group_diffs(
    console: Console,
    fixes: list[AIFixSuggestion],
) -> None:
    """Show individual diffs for a group of fixes.

    Fixes whose diff is missing or whitespace-only are skipped. For every
    other fix a dimmed ``path[:line]`` caption is printed, followed by the
    diff rendered with Rich's ``diff`` syntax highlighting.

    Args:
        console: Rich Console instance.
        fixes: Suggestions to show diffs for.
    """
    for fix in fixes:
        diff_text = fix.diff
        if not diff_text or not diff_text.strip():
            continue

        rel = relative_path(fix.file)
        loc = f"{rel}:{fix.line}" if fix.line else rel
        # Escape the location before embedding it in markup: paths may
        # contain square brackets (e.g. "[slug].tsx") that Rich would
        # otherwise parse as style tags. The group header escapes the same
        # value for the same reason.
        console.print(f"\n [dim]{escape(loc)}[/dim]")

        syntax = Syntax(
            diff_text,
            "diff",
            theme="ansi_dark",
            padding=0,
        )
        console.print(syntax)

165 

166 

def _apply_group(
    console: Console,
    fixes: list[AIFixSuggestion],
    *,
    workspace_root: Path,
    search_radius: int = 5,
) -> tuple[int, list[AIFixSuggestion]]:
    """Apply one group of fixes and print a one-line outcome.

    Args:
        console: Rich Console instance.
        fixes: Suggestions to apply.
        workspace_root: Root directory limiting writable paths.
        search_radius: Max lines above/below the target line to search.

    Returns:
        Tuple of (applied_count, list of successfully applied suggestions).
    """

    def _bottom_up(fix: AIFixSuggestion) -> tuple:
        # Order by file, then by line descending, so edits within a file
        # run bottom-to-top and earlier edits cannot shift later targets.
        return (fix.file, -(fix.line or 0))

    ordered = sorted(fixes, key=_bottom_up)
    succeeded = apply_fixes(
        ordered,
        workspace_root=workspace_root,
        search_radius=search_radius,
    )

    total = len(fixes)
    applied_count = len(succeeded)
    failures = total - applied_count

    pieces = [f" [green]✓ Applied {applied_count}/{total}[/green]"]
    if failures:
        pieces.append(f" [yellow]({failures} failed)[/yellow]")
    console.print("".join(pieces))

    return applied_count, succeeded

200 

201 

def _validate_group(
    console: Console,
    applied_suggestions: Sequence[AIFixSuggestion],
) -> None:
    """Validate one accepted group right away and print the outcome.

    Nothing is printed when validation yields a falsy result, when the
    result carries no counts, new issues, or details, or when rendering
    produces no output.

    Args:
        console: Rich Console instance.
        applied_suggestions: Suggestions that were applied for the group.
    """
    result = validate_applied_fixes(applied_suggestions)
    if not result:
        return

    has_content = (
        result.verified != 0
        or result.unverified != 0
        or result.new_issues
        or result.details
    )
    if not has_content:
        return

    rendered = render_validation(result)
    if rendered:
        console.print(rendered)

220 

221 

def _render_prompt(*, validate_mode: bool, safe_default: bool) -> str:
    """Assemble the one-line key legend shown before reading a keypress.

    Args:
        validate_mode: Whether per-group validation is currently on; shown
            as ``on``/``off`` next to the ``[v]`` key.
        safe_default: Whether to append the hint that Enter accepts the
            group.

    Returns:
        The prompt text, ending in ``": "``.
    """
    verify_state = "on" if validate_mode else "off"
    segments = [
        " [y]accept group [a]accept group + remaining ",
        "[r]reject [d]diffs [s]skip [v]verify fixes:",
        f" {verify_state} (toggle only, no apply) [q]quit",
    ]
    if safe_default:
        segments.append(" (Enter=accept group; safe-style default)")
    segments.append(": ")
    return "".join(segments)

231 

232 

def _stdin_is_interactive() -> bool:
    """Return True only when stdin exists, is open, and is a terminal."""
    stdin = sys.stdin
    # sys.stdin can be None (e.g. when the process has no console attached).
    if stdin is None:
        return False
    try:
        return bool(stdin.isatty())
    except ValueError:
        # isatty() raises ValueError on a closed stream.
        return False


def _apply_and_validate(
    console: Console,
    fixes: list[AIFixSuggestion],
    *,
    validate_mode: bool,
    workspace_root: Path,
    search_radius: int,
) -> tuple[int, list[AIFixSuggestion]]:
    """Apply one group and, when validation is on, validate what was applied."""
    count, group_applied = _apply_group(
        console,
        fixes,
        workspace_root=workspace_root,
        search_radius=search_radius,
    )
    if validate_mode:
        if group_applied:
            _validate_group(console, group_applied)
        else:
            console.print(
                " [dim]Validation skipped "
                "(no fixes applied in this group).[/dim]",
            )
    return count, group_applied


def review_fixes_interactive(
    suggestions: Sequence[AIFixSuggestion],
    *,
    validate_after_group: bool = False,
    workspace_root: Path,
    search_radius: int = 5,
) -> tuple[int, int, list[AIFixSuggestion]]:
    """Present fix suggestions grouped by error code for review.

    Groups suggestions by error code and prompts once per group:
    ``[y]accept group / [a]accept group + remaining / [r]eject /
    [d]iffs / [s]kip / [v]toggle per-group validation / [q]uit``

    Pressing Enter accepts the group when every fix in it is a safe style
    fix and skips it otherwise. Returns ``(0, 0, [])`` without prompting
    when there are no suggestions or stdin is not an interactive terminal
    (including a missing or closed stdin). EOF or Ctrl-C at the prompt ends
    the review immediately and returns the tallies collected so far.

    Args:
        suggestions: Fix suggestions to review.
        validate_after_group: Whether to validate immediately after
            each accepted group.
        workspace_root: Root directory limiting writable paths.
        search_radius: Max lines above/below the target line to search.

    Returns:
        Tuple of (accepted_count, rejected_count, applied_suggestions),
        where accepted_count counts fixes that were actually applied.
    """
    if not suggestions:
        return 0, 0, []

    # Non-interactive environments skip the review
    if not _stdin_is_interactive():
        return 0, 0, []

    console = Console()
    accepted = 0
    rejected = 0
    accept_all = False
    validate_mode = validate_after_group
    all_applied: list[AIFixSuggestion] = []

    groups = _group_by_code(suggestions)
    total_groups = len(groups)
    total_fixes = len(suggestions)
    plural = "es" if total_fixes != 1 else ""

    # Section header
    total_input = sum(s.input_tokens for s in suggestions)
    total_output = sum(s.output_tokens for s in suggestions)
    total_cost = sum(s.cost_estimate for s in suggestions)
    codes = f"{total_groups} code{'s' if total_groups != 1 else ''}"
    cost_info = cost_str(total_input, total_output, total_cost)
    print_section_header(
        console,
        "🤖",
        "AI FIX SUGGESTIONS",
        f"{total_fixes} fix{plural} across {codes}",
        cost_info=cost_info,
    )

    # Tallies for groups applied without prompting after "accept all".
    auto_accepted = 0
    auto_failed = 0
    auto_groups = 0

    # Built once: the set of keys the prompt recognises.
    valid_actions = {key.value for key in ReviewKey}

    for gi, (code, fixes) in enumerate(groups.items(), 1):
        if accept_all:
            count, group_applied = _apply_and_validate(
                console,
                fixes,
                validate_mode=validate_mode,
                workspace_root=workspace_root,
                search_radius=search_radius,
            )
            accepted += count
            auto_accepted += count
            auto_failed += len(fixes) - count
            auto_groups += 1
            all_applied.extend(group_applied)
            continue

        # Group header panel
        _print_group_header(console, code, fixes, gi, total_groups)

        safe_default = all(is_safe_style_fix(fix) for fix in fixes)
        console.print()

        # Keep prompting until a key that resolves the group is pressed;
        # "show diffs", "toggle validation" and unknown keys re-prompt.
        while True:
            prompt_text = click.style(
                _render_prompt(
                    validate_mode=validate_mode,
                    safe_default=safe_default,
                ),
                fg="cyan",
            )
            click.echo(prompt_text, nl=False)
            try:
                choice = click.getchar()
                click.echo(choice)  # echo the keypress
            except (EOFError, KeyboardInterrupt):
                click.echo()
                return accepted, rejected, all_applied

            if choice in ("\r", "\n"):
                choice = ReviewKey.ACCEPT if safe_default else ReviewKey.SKIP
            else:
                choice = choice.lower()

            if choice == ReviewKey.SHOW_DIFF:
                _show_group_diffs(console, fixes)
                console.print()
                continue
            if choice == ReviewKey.TOGGLE_VALIDATE:
                validate_mode = not validate_mode
                state = "enabled" if validate_mode else "disabled"
                console.print(
                    f" [dim]Per-group validation {state} (no fixes applied).[/dim]",
                )
                console.print()
                continue

            if choice not in valid_actions:
                console.print(" [dim]Unrecognized key. Use y/a/r/d/s/v/q.[/dim]")
                console.print()
                continue

            break

        if choice in (ReviewKey.ACCEPT_ALL, ReviewKey.ACCEPT):
            count, group_applied = _apply_and_validate(
                console,
                fixes,
                validate_mode=validate_mode,
                workspace_root=workspace_root,
                search_radius=search_radius,
            )
            accepted += count
            all_applied.extend(group_applied)
            if choice == ReviewKey.ACCEPT_ALL:
                accept_all = True
                console.print(" [dim]Will accept all remaining groups.[/dim]")
        elif choice == ReviewKey.REJECT:
            rejected += len(fixes)
            console.print(
                f" [yellow]✗ Rejected {len(fixes)} "
                f"fix{'es' if len(fixes) != 1 else ''}[/yellow]",
            )
        elif choice == ReviewKey.SKIP:
            console.print(" [dim]⏭ Skipped[/dim]")
        elif choice == ReviewKey.QUIT:
            console.print(" [dim]Quit review.[/dim]")
            break

    # Consolidated line for auto-accepted groups
    if auto_groups > 0:
        total_auto = auto_accepted + auto_failed
        msg = (
            f" [green]✓ Applied {auto_accepted}/{total_auto} "
            f"across {auto_groups} group{'s' if auto_groups != 1 else ''}[/green]"
        )
        if auto_failed:
            msg += f" [yellow]({auto_failed} failed)[/yellow]"
        console.print(msg)

    # Summary: everything neither applied nor rejected is reported as skipped.
    console.print()
    parts: list[str] = []
    if accepted:
        parts.append(f"[green]{accepted} accepted[/green]")
    if rejected:
        parts.append(f"[red]{rejected} rejected[/red]")
    skipped = total_fixes - accepted - rejected
    if skipped:
        parts.append(f"{skipped} skipped")
    if parts:
        console.print(
            f" [bold]Review complete:[/bold] {' · '.join(parts)}",
        )
    console.print()

    return accepted, rejected, all_applied