Coverage for lintro / ai / config.py: 95%

63 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-04-03 18:53 +0000

1"""AI configuration model for Lintro. 

2 

3Defines the AIConfig Pydantic model used in the ``ai:`` section of 

4.lintro-config.yaml. All AI features are opt-in and disabled by default. 

5 

6Fields are logically grouped into three areas: 

7 

8* **Provider** — model selection, API endpoints, authentication, retry 

9* **Budget** — cost caps, issue limits, parallelism, caching 

10* **Output** — display, verbosity, PR integration, apply behaviour 

11 

12The flat attribute API (``config.provider``, ``config.max_tokens``, …) 

13is the primary interface; the grouping is for documentation only. 

14""" 

15 

16from __future__ import annotations 

17 

18from pydantic import BaseModel, ConfigDict, Field, model_validator 

19 

20from lintro.ai.config_views import AIBudgetConfig, AIOutputConfig, AIProviderConfig 

21from lintro.ai.enums import ConfidenceLevel, SanitizeMode 

22from lintro.ai.registry import AIProvider 

23 

# Public names of this module: ``AIConfig`` plus the three view classes
# imported from ``lintro.ai.config_views``, re-exported for convenience.
__all__ = ["AIBudgetConfig", "AIConfig", "AIOutputConfig", "AIProviderConfig"]

30 

31 

class AIConfig(BaseModel):
    """Settings model for Lintro's AI-powered features.

    Every option is a plain attribute on the instance (for example
    ``config.provider``). When a grouped view is more convenient, the
    ``provider_config``, ``budget_config`` and ``output_config``
    properties return frozen dataclass snapshots built from the current
    attribute values.
    """

    # Instances stay mutable after creation; unknown keys are rejected.
    model_config = ConfigDict(frozen=False, extra="forbid")

    # AI features are off unless explicitly enabled.
    enabled: bool = False
    provider: AIProvider = AIProvider.ANTHROPIC
    model: str | None = None
    api_key_env: str | None = None
    api_base_url: str | None = Field(
        default=None,
        description=(
            "Custom API base URL. Enables Ollama, vLLM, Azure OpenAI, "
            "or any OpenAI-compatible endpoint."
        ),
    )
    api_region: str | None = Field(
        default=None,
        description=(
            "Provider region hint for data residency. "
            "Used with api_base_url for region-specific endpoints."
        ),
    )
    fallback_models: list[str] = Field(default_factory=list)
    default_fix: bool = False
    auto_apply: bool = False
    auto_apply_safe_fixes: bool = True

    # Numeric limits; the ``ge``/``le`` bounds are checked at validation time.
    max_tokens: int = Field(default=4096, ge=1, le=128_000)
    max_fix_attempts: int = Field(
        default=20,
        ge=1,
        description=(
            "Maximum number of issues to attempt fixing per run. "
            "Counts API calls made, not suggestions returned."
        ),
    )
    max_parallel_calls: int = Field(default=5, ge=1, le=20)
    max_retries: int = Field(default=2, ge=0, le=10)
    api_timeout: float = Field(default=60.0, ge=1.0)
    validate_after_group: bool = False
    show_cost_estimate: bool = True
    context_lines: int = Field(default=15, ge=1, le=100)
    fix_search_radius: int = Field(default=5, ge=1, le=50)

    # Retry timing; ``_check_retry_delays`` additionally requires
    # retry_max_delay >= retry_base_delay.
    retry_base_delay: float = Field(default=1.0, ge=0.1)
    retry_max_delay: float = Field(default=30.0, ge=1.0)
    retry_backoff_factor: float = Field(default=2.0, ge=1.0)

    enable_cache: bool = Field(default=False)
    cache_ttl: int = Field(default=3600, ge=60)
    cache_max_entries: int = Field(default=1000, ge=1)
    max_refinement_attempts: int = Field(default=1, ge=0, le=3)
    fail_on_ai_error: bool = Field(default=False)
    fail_on_unfixed: bool = Field(
        default=False,
        description=(
            "When True, unfixable or failed AI fixes contribute to a "
            "non-zero exit code."
        ),
    )
    verbose: bool = Field(default=False)

    # Path and rule filters; each defaults to an empty list.
    include_paths: list[str] = Field(
        default_factory=list,
        description="Glob patterns for paths to include in AI processing.",
    )
    exclude_paths: list[str] = Field(
        default_factory=list,
        description="Glob patterns for paths to exclude from AI processing.",
    )
    include_rules: list[str] = Field(
        default_factory=list,
        description="Glob patterns for rules to include in AI processing.",
    )
    exclude_rules: list[str] = Field(
        default_factory=list,
        description="Glob patterns for rules to exclude from AI processing.",
    )

    min_confidence: ConfidenceLevel = Field(
        default=ConfidenceLevel.LOW,
        description=(
            "Minimum confidence level for AI fix suggestions. "
            "Suggestions below this threshold are discarded. "
            "One of 'low', 'medium', 'high'."
        ),
    )
    github_pr_comments: bool = Field(
        default=False,
        description=(
            "Post AI summaries and fix suggestions as inline PR review "
            "comments when running in GitHub Actions."
        ),
    )
    dry_run: bool = Field(
        default=False,
        description=(
            "Display AI fix suggestions without applying them. "
            "Useful for previewing what changes the AI would make."
        ),
    )
    max_cost_usd: float | None = Field(
        default=None,
        ge=0,
        description=(
            "Maximum total cost in USD per AI session. None disables the limit."
        ),
    )
    max_prompt_tokens: int = Field(
        default=12000,
        ge=1000,
        description="Token budget for fix prompts before context trimming.",
    )
    stream: bool = Field(
        default=False,
        description="Stream AI responses token-by-token in interactive mode.",
    )
    sanitize_mode: SanitizeMode = Field(
        default=SanitizeMode.WARN,
        description=(
            "How to handle detected prompt injection patterns in source "
            "files: 'warn' logs and continues, 'block' skips the file, "
            "'off' disables detection."
        ),
    )

    @model_validator(mode="after")
    def _check_retry_delays(self) -> AIConfig:
        """Ensure the retry delay ceiling is not below the base delay.

        Returns:
            The model instance itself, unchanged.

        Raises:
            ValueError: If ``retry_max_delay`` is smaller than
                ``retry_base_delay``.
        """
        base_delay = self.retry_base_delay
        max_delay = self.retry_max_delay
        if max_delay < base_delay:
            detail = (
                f"retry_max_delay ({max_delay}) must be >= "
                f"retry_base_delay ({base_delay})"
            )
            raise ValueError(detail)
        return self

    # -- Grouped views -----------------------------------------------------

    @property
    def provider_config(self) -> AIProviderConfig:
        """Build an ``AIProviderConfig`` from the current provider settings."""
        # The list field is copied into a tuple before being handed over.
        fallbacks = tuple(self.fallback_models)
        return AIProviderConfig(
            provider=self.provider,
            model=self.model,
            api_key_env=self.api_key_env,
            api_base_url=self.api_base_url,
            api_region=self.api_region,
            fallback_models=fallbacks,
            max_tokens=self.max_tokens,
            max_retries=self.max_retries,
            api_timeout=self.api_timeout,
            retry_base_delay=self.retry_base_delay,
            retry_max_delay=self.retry_max_delay,
            retry_backoff_factor=self.retry_backoff_factor,
        )

    @property
    def budget_config(self) -> AIBudgetConfig:
        """Build an ``AIBudgetConfig`` from the current budget/limit settings."""
        view = AIBudgetConfig(
            max_fix_attempts=self.max_fix_attempts,
            max_parallel_calls=self.max_parallel_calls,
            max_cost_usd=self.max_cost_usd,
            max_prompt_tokens=self.max_prompt_tokens,
            max_refinement_attempts=self.max_refinement_attempts,
            enable_cache=self.enable_cache,
            cache_ttl=self.cache_ttl,
            cache_max_entries=self.cache_max_entries,
            context_lines=self.context_lines,
            fix_search_radius=self.fix_search_radius,
        )
        return view

    @property
    def output_config(self) -> AIOutputConfig:
        """Build an ``AIOutputConfig`` from the current output settings."""
        # The four filter lists are copied into tuples before being handed
        # over.
        path_includes = tuple(self.include_paths)
        path_excludes = tuple(self.exclude_paths)
        rule_includes = tuple(self.include_rules)
        rule_excludes = tuple(self.exclude_rules)
        return AIOutputConfig(
            show_cost_estimate=self.show_cost_estimate,
            verbose=self.verbose,
            stream=self.stream,
            dry_run=self.dry_run,
            github_pr_comments=self.github_pr_comments,
            validate_after_group=self.validate_after_group,
            auto_apply=self.auto_apply,
            auto_apply_safe_fixes=self.auto_apply_safe_fixes,
            default_fix=self.default_fix,
            fail_on_ai_error=self.fail_on_ai_error,
            fail_on_unfixed=self.fail_on_unfixed,
            min_confidence=self.min_confidence,
            sanitize_mode=self.sanitize_mode,
            include_paths=path_includes,
            exclude_paths=path_excludes,
            include_rules=rule_includes,
            exclude_rules=rule_excludes,
        )