Coverage for tests/unit/utils/async_tool_executor/test_parallel_batches.py: 100%
94 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
1"""Tests for get_parallel_batches function."""
3from __future__ import annotations
5from unittest.mock import MagicMock
7from assertpy import assert_that
9from lintro.utils.async_tool_executor import get_parallel_batches
def test_get_parallel_batches_empty_tools_list() -> None:
    """An empty tool list produces no batches at all."""
    manager = MagicMock()

    result = get_parallel_batches([], manager)

    assert_that(result).is_empty()
def test_get_parallel_batches_single_tool() -> None:
    """A lone tool ends up as the only member of the only batch."""
    tool = MagicMock()
    tool.definition.conflicts_with = []

    manager = MagicMock()
    manager.get_tool.return_value = tool

    result = get_parallel_batches(["ruff"], manager)

    assert_that(result).is_length(1)
    assert_that(result[0]).is_equal_to(["ruff"])
def test_get_parallel_batches_no_conflicts() -> None:
    """Tools that declare no conflicts are all grouped into a single batch."""

    def get_tool(name: str) -> MagicMock:
        # Every looked-up tool reports an empty conflict list.
        tool = MagicMock()
        tool.definition.conflicts_with = []
        return tool

    manager = MagicMock()
    manager.get_tool.side_effect = get_tool

    result = get_parallel_batches(["ruff", "mypy", "bandit"], manager)

    assert_that(result).is_length(1)
    assert_that(result[0]).contains("ruff", "mypy", "bandit")
def test_get_parallel_batches_conflicting_tools() -> None:
    """Test conflicting tools are put in separate batches.

    ``black`` and ``ruff`` list each other in ``conflicts_with``, so they must
    never share a batch. ``mypy`` declares no conflicts. Every input tool must
    still be scheduled exactly once.
    """
    mock_manager = MagicMock()
    conflicts: dict[str, list[str]] = {
        "black": ["ruff"],
        "ruff": ["black"],
    }

    def get_tool(name: str) -> MagicMock:
        mock = MagicMock()
        # Fresh list per mock so no two tools share a mutable conflict list.
        mock.definition.conflicts_with = list(conflicts.get(name, []))
        return mock

    mock_manager.get_tool.side_effect = get_tool

    tools = ["black", "ruff", "mypy"]
    batches = get_parallel_batches(tools, mock_manager)

    assert_that(len(batches)).is_greater_than_or_equal_to(2)

    # No tool may be dropped or duplicated by the batching.
    scheduled = [tool for batch in batches for tool in batch]
    assert_that(sorted(scheduled)).is_equal_to(sorted(tools))

    def batch_index(tool: str) -> int | None:
        """Return the index of the first batch containing *tool*, else None."""
        return next(
            (i for i, batch in enumerate(batches) if tool in batch),
            None,
        )

    black_batch = batch_index("black")
    ruff_batch = batch_index("ruff")

    # Guard against a trivially passing inequality such as ``None != 0``.
    assert_that(black_batch).is_not_none()
    assert_that(ruff_batch).is_not_none()
    assert_that(black_batch).is_not_equal_to(ruff_batch)
def test_get_parallel_batches_multiple_conflicts() -> None:
    """Test multiple conflicting tool pairs create appropriate batches.

    There are two independent, mutually conflicting pairs (``tool_a``/``tool_b``
    and ``tool_c``/``tool_d``) plus one conflict-free tool (``tool_e``). No pair
    may share a batch, and every tool must be scheduled exactly once.
    """
    mock_manager = MagicMock()
    # Built once, rather than on every get_tool() call.
    conflicts: dict[str, list[str]] = {
        "tool_a": ["tool_b"],
        "tool_b": ["tool_a"],
        "tool_c": ["tool_d"],
        "tool_d": ["tool_c"],
        "tool_e": [],
    }

    def get_tool(name: str) -> MagicMock:
        mock = MagicMock()
        # Fresh list per mock so no two tools share a mutable conflict list.
        mock.definition.conflicts_with = list(conflicts.get(name, []))
        return mock

    mock_manager.get_tool.side_effect = get_tool

    tools = ["tool_a", "tool_b", "tool_c", "tool_d", "tool_e"]
    batches = get_parallel_batches(tools, mock_manager)

    # Without this check, the pairwise checks below would pass vacuously
    # if tools were dropped (e.g. an empty result).
    scheduled = [tool for batch in batches for tool in batch]
    assert_that(sorted(scheduled)).is_equal_to(sorted(tools))

    for batch in batches:
        if "tool_a" in batch:
            assert_that(batch).does_not_contain("tool_b")
        if "tool_c" in batch:
            assert_that(batch).does_not_contain("tool_d")
def test_get_parallel_batches_ordering_preserved() -> None:
    """Conflict-free tools keep their input order inside the first batch."""
    tools = ["first", "second", "third"]

    def get_tool(name: str) -> MagicMock:
        tool = MagicMock()
        tool.definition.conflicts_with = []
        return tool

    manager = MagicMock()
    manager.get_tool.side_effect = get_tool

    first_batch = get_parallel_batches(tools, manager)[0]

    assert_that(first_batch).is_equal_to(tools)
def test_get_parallel_batches_handles_missing_tool() -> None:
    """A tool whose lookup raises KeyError is still batched alongside the rest."""

    def get_tool(name: str) -> MagicMock:
        # Simulate the manager not knowing about the "missing" tool.
        if name != "missing":
            found = MagicMock()
            found.definition.conflicts_with = []
            return found
        raise KeyError("Tool not found")

    manager = MagicMock()
    manager.get_tool.side_effect = get_tool

    result = get_parallel_batches(["valid", "missing"], manager)

    assert_that(result).is_length(1)
    assert_that(result[0]).contains("valid", "missing")
def test_get_parallel_batches_handles_tool_without_conflicts_attribute() -> None:
    """Test handling tools without conflicts_with attribute.

    Accessing ``definition.conflicts_with`` on ``no_conflicts_attr`` raises
    AttributeError, because the attribute is deleted from the mock. Both tools
    must still be scheduled together in a single batch.
    """
    mock_manager = MagicMock()

    def get_tool(name: str) -> MagicMock:
        mock = MagicMock()
        if name == "no_conflicts_attr":
            # Deleting a MagicMock attribute makes later access raise AttributeError.
            del mock.definition.conflicts_with
        else:
            mock.definition.conflicts_with = []
        return mock

    mock_manager.get_tool.side_effect = get_tool

    batches = get_parallel_batches(["normal", "no_conflicts_attr"], mock_manager)

    assert_that(batches).is_length(1)
    # A length check alone would not catch a tool being silently dropped.
    assert_that(batches[0]).contains("normal", "no_conflicts_attr")