Coverage for tests / unit / ai / test_fix_generation_batch.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-04-03 18:53 +0000

1"""Tests for batch fix generation. 

2 

3Covers multi-issue batching per file, batch prompt construction, 

4and batch fallback to single-issue mode. 

5""" 

6 

7from __future__ import annotations 

8 

9import json 

10 

11from assertpy import assert_that 

12 

13from lintro.ai.fix import ( 

14 generate_fixes, 

15) 

16from lintro.ai.providers.base import AIResponse 

17from tests.unit.ai.conftest import MockAIProvider, MockIssue 

18 

19# --------------------------------------------------------------------------- 

20# P3-1: Multi-issue batching per file 

21# --------------------------------------------------------------------------- 

22 

23 

def test_batch_prompt_for_multi_issue_file(tmp_path):
    """Two issues in the same file are sent to the provider as one batch.

    A three-line file is written under ``tmp_path`` and issues are reported
    on lines 1 and 3.  The mock provider is primed with a single response
    whose content is a JSON array of two fixes.  The test then checks that
    exactly one provider call was recorded, that its prompt mentions both
    issue messages and the phrase "JSON array", and that two results come
    back with the expected line numbers and tool name.
    """
    target = tmp_path / "multi.py"
    target.write_text("x = 1\ny = 2\nz = 3\n")
    target_path = str(target)

    first_issue = MockIssue(
        file=target_path,
        line=1,
        code="B101",
        message="Issue one",
    )
    second_issue = MockIssue(
        file=target_path,
        line=3,
        code="E501",
        message="Issue two",
    )

    # One entry per reported issue, serialized as a single JSON array.
    fix_payload = [
        {
            "line": 1,
            "code": "B101",
            "original_code": "x = 1",
            "suggested_code": "x = 2",
            "explanation": "Fix one",
            "confidence": "high",
            "risk_level": "behavioral-risk",
        },
        {
            "line": 3,
            "code": "E501",
            "original_code": "z = 3",
            "suggested_code": "z = 4",
            "explanation": "Fix two",
            "confidence": "medium",
            "risk_level": "safe-style",
        },
    ]
    provider = MockAIProvider(
        responses=[
            AIResponse(
                content=json.dumps(fix_payload),
                model="mock",
                input_tokens=50,
                output_tokens=50,
                cost_estimate=0.002,
                provider="mock",
            ),
        ],
    )

    fixes = generate_fixes(
        [first_issue, second_issue],
        provider,
        tool_name="ruff",
        workspace_root=tmp_path,
    )

    # Batching means a single provider round-trip instead of one per issue.
    assert_that(provider.calls).is_length(1)
    batch_prompt = provider.calls[0]["prompt"]
    for expected_fragment in ("Issue one", "Issue two", "JSON array"):
        assert_that(batch_prompt).contains(expected_fragment)

    assert_that(fixes).is_length(2)
    assert_that(fixes[0].line).is_equal_to(1)
    assert_that(fixes[1].line).is_equal_to(3)
    assert_that(fixes[0].tool_name).is_equal_to("ruff")

93 

94 

def test_batch_fallback_to_single_on_parse_failure(tmp_path):
    """An unparseable batch response leads to one provider call per issue.

    The mock provider is primed with three responses: the first is not
    JSON, the next two are single-fix JSON objects.  With two issues in
    one file, the test expects three recorded provider calls (one batch
    attempt plus two single-issue calls) and two resulting fixes.
    """
    target = tmp_path / "multi.py"
    target.write_text("x = 1\ny = 2\n")
    target_path = str(target)

    reported = [
        MockIssue(file=target_path, line=1, code="B101", message="Issue one"),
        MockIssue(file=target_path, line=2, code="E501", message="Issue two"),
    ]

    # Metadata shared by every queued mock response.
    response_meta = {
        "model": "mock",
        "input_tokens": 10,
        "output_tokens": 10,
        "cost_estimate": 0.001,
        "provider": "mock",
    }
    single_fix_payloads = [
        {
            "original_code": "x = 1",
            "suggested_code": "x = 2",
            "explanation": "Fix one",
            "confidence": "high",
        },
        {
            "original_code": "y = 2",
            "suggested_code": "y = 3",
            "explanation": "Fix two",
            "confidence": "high",
        },
    ]

    # The batch reply comes first and is deliberately not valid JSON; the
    # replies that follow are well-formed single-issue fixes.
    queued = [AIResponse(content="not-a-json-array", **response_meta)]
    queued.extend(
        AIResponse(content=json.dumps(payload), **response_meta)
        for payload in single_fix_payloads
    )
    provider = MockAIProvider(responses=queued)

    # NOTE(review): max_workers=1 presumably keeps the fallback calls
    # sequential so queued responses are consumed in order - confirm
    # against generate_fixes.
    fixes = generate_fixes(
        reported,
        provider,
        tool_name="ruff",
        workspace_root=tmp_path,
        max_workers=1,
    )

    # 1 batch call + 2 single fallback calls = 3
    assert_that(provider.calls).is_length(3)
    assert_that(fixes).is_length(2)

168 assert_that(result).is_length(2)