Coverage for lintro / utils / async_tool_executor.py: 96%

90 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-04-03 18:53 +0000

1"""Async tool execution for parallel linting. 

2 

3This module provides functionality to run multiple linting tools in parallel 

4using asyncio and ThreadPoolExecutor for subprocess isolation. 

5""" 

6 

7from __future__ import annotations 

8 

9import asyncio 

10import os 

11from collections.abc import Callable 

12from concurrent.futures import ThreadPoolExecutor 

13from dataclasses import dataclass, field 

14from typing import TYPE_CHECKING, Any 

15 

16from loguru import logger 

17 

18if TYPE_CHECKING: 

19 from lintro.enums.action import Action 

20 from lintro.models.core.tool_result import ToolResult 

21 from lintro.plugins.base import BaseToolPlugin 

22 

23 

def _get_default_max_workers() -> int:
    """Compute the default worker-pool size from the host CPU count.

    Returns:
        Number of CPUs available, clamped between 1 and 32.
    """
    # os.cpu_count() may return None in restricted environments; fall back to 4.
    detected = os.cpu_count() or 4
    if detected < 1:
        return 1
    if detected > 32:
        return 32
    return detected

32 

33 

@dataclass
class AsyncToolExecutor:
    """Execute tools in parallel using a thread pool.

    Tools are executed in a ThreadPoolExecutor to avoid blocking the event loop,
    since each tool runs as a subprocess which is inherently blocking.

    Attributes:
        max_workers: Maximum number of concurrent tool executions (default: CPU count).
    """

    max_workers: int = field(default_factory=_get_default_max_workers)
    # Lazily replaced by __post_init__; None after shutdown().
    _executor: ThreadPoolExecutor | None = field(default=None, init=False)

    def __post_init__(self) -> None:
        """Initialize the thread pool executor."""
        self._executor = ThreadPoolExecutor(max_workers=self.max_workers)

    def __enter__(self) -> AsyncToolExecutor:
        """Enter context manager.

        Returns:
            AsyncToolExecutor: This executor instance.
        """
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Exit context manager and cleanup resources.

        Args:
            exc_type: Exception type if an exception was raised.
            exc_val: Exception instance if an exception was raised.
            exc_tb: Traceback if an exception was raised.
        """
        self.shutdown()

    async def run_tool_async(
        self,
        tool: BaseToolPlugin,
        paths: list[str],
        action: Action,
        options: dict[str, Any] | None = None,
        max_fix_retries: int = 3,
    ) -> ToolResult:
        """Run a single tool asynchronously.

        Args:
            tool: The tool plugin to execute.
            paths: List of file paths to process.
            action: The action to perform (check or fix).
            options: Additional options to pass to the tool.
            max_fix_retries: Maximum fix→verify convergence cycles.

        Returns:
            ToolResult: The result of tool execution.

        Raises:
            RuntimeError: If the executor has been shut down.
        """
        # Runtime import: the module-level import is TYPE_CHECKING-only.
        from lintro.enums.action import Action

        if self._executor is None:
            raise RuntimeError("Executor has been shut down")

        loop = asyncio.get_running_loop()
        opts = options or {}

        logger.debug(f"Starting async execution of {tool.definition.name}")
        if action == Action.FIX:
            from lintro.utils.tool_executor import _run_fix_with_retry

            result = await loop.run_in_executor(
                self._executor,
                _run_fix_with_retry,
                tool,
                paths,
                opts,
                max_fix_retries,
            )
        else:
            result = await loop.run_in_executor(
                self._executor,
                tool.check,
                paths,
                opts,
            )
        logger.debug(f"Completed async execution of {tool.definition.name}")

        return result

    async def run_tools_parallel(
        self,
        tools: list[tuple[str, BaseToolPlugin]],
        paths: list[str],
        action: Action,
        options_per_tool: dict[str, dict[str, Any]] | None = None,
        on_result: Callable[[str, ToolResult], None] | None = None,
        max_fix_retries: int = 3,
    ) -> list[tuple[str, ToolResult]]:
        """Run multiple tools in parallel.

        Args:
            tools: List of (tool_name, tool_instance) tuples.
            paths: List of file paths to process.
            action: The action to perform.
            options_per_tool: Optional dict mapping tool names to their options.
            on_result: Optional callback called when each tool completes.
            max_fix_retries: Maximum fix→verify convergence cycles.

        Returns:
            List of (tool_name, ToolResult) tuples in completion order.
        """
        options = options_per_tool or {}

        async def run_with_name(
            name: str,
            tool: BaseToolPlugin,
        ) -> tuple[str, ToolResult]:
            """Run tool and return result with name.

            Args:
                name: Name of the tool.
                tool: Tool plugin instance to run.

            Returns:
                Tuple of (tool_name, ToolResult).
            """
            tool_opts = options.get(name, {})
            result = await self.run_tool_async(
                tool,
                paths,
                action,
                tool_opts,
                max_fix_retries=max_fix_retries,
            )
            if on_result:
                on_result(name, result)
            return (name, result)

        tasks = [run_with_name(name, tool) for name, tool in tools]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Handle any exceptions that occurred. Note: gather() with
        # return_exceptions=True can return BaseException instances
        # (e.g. CancelledError), not just Exception, so narrow on the base.
        processed_results: list[tuple[str, ToolResult]] = []
        for i, result in enumerate(results):
            tool_name = tools[i][0]
            if isinstance(result, BaseException):
                logger.error(f"Tool {tool_name} failed with exception: {result}")
                # Create a failed result
                from lintro.models.core.tool_result import ToolResult

                failed_result = ToolResult(
                    name=tool_name,
                    success=False,
                    output=f"Parallel execution failed: {result}",
                    issues_count=0,
                )
                processed_results.append((tool_name, failed_result))
            else:
                # Result is tuple[str, ToolResult] (narrowed by isinstance check)
                processed_results.append(result)
        return processed_results

    def shutdown(self) -> None:
        """Shutdown the thread pool executor."""
        if self._executor:
            self._executor.shutdown(wait=True)
            self._executor = None

209 

210 

def get_parallel_batches(
    tools: list[str],
    tool_manager: Any,
) -> list[list[str]]:
    """Group tools into batches that can run in parallel.

    Tools with conflicts (e.g., Black and Ruff formatter) must run in separate
    batches to avoid race conditions on the same files.

    Args:
        tools: List of tool names to batch.
        tool_manager: Tool manager instance to query tool definitions.

    Returns:
        List of batches, where each batch is a list of tool names that can
        run in parallel.
    """
    if not tools:
        return []

    # O(1) membership lookups; the original list is kept for deterministic order.
    tool_set = set(tools)

    # Build conflict graph (undirected: a conflict is mutual).
    conflict_graph: dict[str, set[str]] = {name: set() for name in tools}

    for tool_name in tools:
        try:
            tool_instance = tool_manager.get_tool(tool_name)
            for conflict in tool_instance.definition.conflicts_with:
                # Conflict declarations are matched case-insensitively against
                # the requested tool names (which are assumed lowercase).
                conflict_lower = conflict.lower()
                if conflict_lower in tool_set:
                    conflict_graph[tool_name].add(conflict_lower)
                    conflict_graph[conflict_lower].add(tool_name)
        except (KeyError, AttributeError):
            # Tool not found or has no conflicts
            pass

    # Greedy batching: add tools to current batch if they don't conflict
    # with any tool already in the batch
    batches: list[list[str]] = []
    remaining = set(tools)

    while remaining:
        batch: list[str] = []
        batch_conflicts: set[str] = set()

        for tool_name in tools:  # Iterate in original order for determinism
            if tool_name not in remaining:
                continue

            # Check if this tool conflicts with anything in current batch
            if tool_name not in batch_conflicts:
                batch.append(tool_name)
                remaining.remove(tool_name)
                # Add this tool's conflicts to the set
                batch_conflicts.update(conflict_graph[tool_name])
                batch_conflicts.add(tool_name)

        if batch:
            batches.append(batch)
        else:
            # Safety: if we couldn't add anything, break to avoid infinite loop
            break

    return batches