Coverage for lintro / ai / undo.py: 80%
40 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
1"""AI fix undo/rollback via patch files."""
3from __future__ import annotations
5import difflib
6import os
7import tempfile
8from pathlib import Path
9from typing import TYPE_CHECKING
11if TYPE_CHECKING:
12 from lintro.ai.models import AIFixSuggestion
# Directory, relative to the workspace root, that holds the undo patch.
UNDO_DIR = ".lintro-cache/ai"
# File name of the most recently saved reverse patch inside UNDO_DIR.
UNDO_FILE = "last_fixes.patch"


def _diff_lines(text: str | None) -> list[str]:
    """Split code text into ``\\n``-terminated lines for a unified diff.

    Unlike ``str.splitlines``, this splits on ``\\n`` only. ``splitlines``
    also breaks on form feed, ``\\x1c``-``\\x1e``, ``\\x85``, ``\\u2028``,
    lone ``\\r`` and similar characters. That would yield diff lines with no
    ``\\n`` terminator and corrupt the patch. A missing trailing newline is
    added so every emitted diff line is properly terminated.

    Args:
        text: Code snippet; ``None`` or ``""`` yields no lines.

    Returns:
        List of lines, each ending with ``\\n`` (``\\r\\n`` is preserved).
    """
    if not text:
        return []
    if not text.endswith("\n"):
        text += "\n"
    # The final element after split() is the empty remainder past the last "\n".
    return [f"{line}\n" for line in text.split("\n")[:-1]]


def save_undo_patch(
    suggestions: list[AIFixSuggestion],
    workspace_root: Path,
) -> Path | None:
    """Save a combined reverse patch before applying fixes.

    The patch reverses applied changes (suggested -> original) so that
    running ``git apply <patch>`` restores the original code. The patch is
    written to ``<workspace_root>/.lintro-cache/ai/last_fixes.patch``,
    replacing any previous patch there.

    Args:
        suggestions: List of fix suggestions about to be applied.
        workspace_root: Project root directory.

    Returns:
        Path to the saved patch file, or None if there are no suggestions
        or none of them changes any code.

    Raises:
        OSError: If the undo directory cannot be created or the patch
            cannot be written. Any exception raised while writing is
            re-raised after the temporary file has been removed.
    """
    if not suggestions:
        return None

    patch_lines: list[str] = []
    for s in suggestions:
        # Reverse diff: suggested -> original (for undo).
        # NOTE(review): the diff is computed from the snippets alone, so the
        # hunk line numbers are relative to the snippet (starting at 1), not
        # to the file; confirm `git apply` can locate these hunks.
        patch_lines.extend(
            difflib.unified_diff(
                _diff_lines(s.suggested_code),
                _diff_lines(s.original_code),
                fromfile=f"a/{s.file}",
                tofile=f"b/{s.file}",
            )
        )
    if not patch_lines:
        return None

    undo_dir = workspace_root / UNDO_DIR
    undo_dir.mkdir(parents=True, exist_ok=True)
    patch_path = undo_dir / UNDO_FILE

    # Atomic write: temp file in the same directory + os.replace, so the
    # target is never left half-written.
    fd, tmp = tempfile.mkstemp(dir=undo_dir, suffix=".tmp")
    try:
        try:
            # newline="" disables newline translation: the LF-terminated diff
            # lines are written as-is instead of becoming CRLF on Windows.
            fobj = os.fdopen(fd, "w", encoding="utf-8", newline="")
        except BaseException:
            # fdopen did not take ownership of the descriptor; close it here.
            os.close(fd)
            raise
        with fobj:
            fobj.write("".join(patch_lines))
        Path(tmp).replace(patch_path)
    except BaseException:
        Path(tmp).unlink(missing_ok=True)
        raise
    return patch_path