Coverage for tests/unit/ai/providers/test_stream.py: 99%

84 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-04-03 18:53 +0000

1"""Tests for AIStreamResult and BaseAIProvider.stream_complete().""" 

2 

3from __future__ import annotations 

4 

5import pytest 

6from assertpy import assert_that 

7 

8from lintro.ai.providers.base import AIResponse, AIStreamResult, BaseAIProvider 

9from lintro.ai.providers.constants import DEFAULT_PER_CALL_MAX_TOKENS, DEFAULT_TIMEOUT 

10 

11 

class _StubProvider(BaseAIProvider):
    """Smallest concrete ``BaseAIProvider`` used to exercise inherited defaults.

    Every call to :meth:`complete` hands back the same canned ``AIResponse``
    supplied at construction time. Tests can therefore check how base-class
    helpers such as ``stream_complete`` wrap it.
    """

    def __init__(self, response: AIResponse) -> None:
        # BaseAIProvider.__init__ is not called. These private attributes are
        # assigned directly with fixed dummy values (presumably the state the
        # base class reads -- confirm against lintro/ai/providers/base.py).
        self._provider_name = "stub"
        self._model = "stub-model"
        self._api_key_env = "STUB_KEY"
        self._base_url = None
        self._max_tokens = DEFAULT_PER_CALL_MAX_TOKENS
        self._has_sdk = True
        self._client = "fake"
        self._response = response

    def _create_client(self, *, api_key: str) -> object:
        """Return a placeholder client object; ``api_key`` is ignored."""
        return "fake"

    def complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIResponse:
        """Return the canned response given to ``__init__``, ignoring all arguments."""
        return self._response

37 

38 

def _make_response(content: str = "hello world") -> AIResponse:
    """Build an ``AIResponse`` carrying *content* plus fixed metadata.

    The metadata never varies: provider ``"test"``, model ``"test-model"``,
    10 input tokens, 5 output tokens and a cost estimate of 0.001. Tests in
    this module assert against those values.
    """
    return AIResponse(
        provider="test",
        model="test-model",
        content=content,
        input_tokens=10,
        output_tokens=5,
        cost_estimate=0.001,
    )

48 

49 

def test_stream_result_iter_yields_chunks() -> None:
    """Iteration over an AIStreamResult produces the supplied chunks in order."""
    pieces = ["foo", "bar", "baz"]
    metadata = _make_response("foobarbaz")
    stream = AIStreamResult(_chunks=iter(pieces), _on_done=lambda: metadata)

    yielded = list(stream)

    assert_that(yielded).is_equal_to(["foo", "bar", "baz"])

57 

58 

def test_stream_result_response_returns_metadata() -> None:
    """Once the stream is drained, response() equals the object from _on_done."""
    expected = _make_response()
    stream = AIStreamResult(_chunks=iter([]), _on_done=lambda: expected)

    # Drain the (empty) stream before asking for the metadata.
    list(stream)

    assert_that(stream.response()).is_equal_to(expected)

66 

67 

def test_stream_result_collect_concatenates_and_returns_response() -> None:
    """collect() glues the chunks together and keeps the _on_done metadata."""
    template = _make_response("")
    stream = AIStreamResult(
        _chunks=iter(["alpha", " ", "beta"]),
        _on_done=lambda: template,
    )

    merged = stream.collect()

    # Checked in insertion order; the first mismatch fails the test.
    expected_fields = {
        "content": "alpha beta",
        "model": "test-model",
        "input_tokens": 10,
        "output_tokens": 5,
        "provider": "test",
    }
    for field_name, expected_value in expected_fields.items():
        assert_that(getattr(merged, field_name)).is_equal_to(expected_value)

83 

84 

def test_stream_result_collect_empty_stream() -> None:
    """Collecting a stream that yields nothing gives empty content."""
    placeholder = _make_response("")
    stream = AIStreamResult(_chunks=iter([]), _on_done=lambda: placeholder)

    outcome = stream.collect()

    assert_that(outcome.content).is_equal_to("")

91 

92 

@pytest.mark.parametrize(
    ("chunks", "expected"),
    [
        pytest.param(["a"], "a", id="single"),
        pytest.param(["a", "b", "c"], "abc", id="multi"),
        pytest.param([""], "", id="empty-chunk"),
        pytest.param(["hello ", "world"], "hello world", id="with-space"),
    ],
)
def test_stream_result_collect_various_chunk_patterns(
    chunks: list[str],
    expected: str,
) -> None:
    """collect() content equals the plain concatenation of each chunk list."""
    metadata = _make_response("")
    stream = AIStreamResult(_chunks=iter(chunks), _on_done=lambda: metadata)

    assert_that(stream.collect().content).is_equal_to(expected)

112 

113 

def test_base_provider_stream_complete_delegates_to_complete() -> None:
    """The inherited stream_complete() surfaces whatever complete() returns."""
    canned = _make_response("delegated content")
    stub = _StubProvider(response=canned)

    gathered = stub.stream_complete("test prompt").collect()

    assert_that(gathered.content).is_equal_to("delegated content")
    assert_that(gathered.model).is_equal_to("test-model")
    assert_that(gathered.provider).is_equal_to("test")

125 

126 

def test_base_provider_stream_complete_passes_kwargs() -> None:
    """stream_complete() hands prompt, system, max_tokens and timeout to complete()."""
    recorded: list[dict[str, object]] = []

    class _CapturingProvider(_StubProvider):
        """Stub that records the arguments of every complete() call."""

        def complete(
            self,
            prompt: str,
            *,
            system: str | None = None,
            max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
            timeout: float = DEFAULT_TIMEOUT,
        ) -> AIResponse:
            recorded.append(
                {
                    "prompt": prompt,
                    "system": system,
                    "max_tokens": max_tokens,
                    "timeout": timeout,
                },
            )
            return _make_response()

    capturing = _CapturingProvider(response=_make_response())
    # Drain the stream so complete() has run even if it is invoked lazily.
    list(
        capturing.stream_complete(
            "my prompt",
            system="sys",
            max_tokens=512,
            timeout=30,
        ),
    )

    assert_that(recorded).is_length(1)
    # Whole-dict comparison: every forwarded argument must match exactly.
    assert_that(recorded[0]).is_equal_to(
        {
            "prompt": "my prompt",
            "system": "sys",
            "max_tokens": 512,
            "timeout": 30,
        },
    )

164 

165 

def test_base_provider_stream_complete_single_chunk_iteration() -> None:
    """Without a stream_complete override, the full text arrives as one chunk."""
    stub = _StubProvider(response=_make_response("one shot"))

    received = list(stub.stream_complete("p"))

    assert_that(received).is_equal_to(["one shot"])

175 

176 

def test_collect_raises_on_double_call() -> None:
    """A second collect() on the same stream raises RuntimeError."""
    metadata = _make_response("")
    stream = AIStreamResult(
        _chunks=iter(["alpha", " ", "beta"]),
        _on_done=lambda: metadata,
    )

    # The first collect() drains the chunks normally...
    assert_that(stream.collect().content).is_equal_to("alpha beta")

    # ...after which the stream refuses to be collected again.
    with pytest.raises(RuntimeError, match="already consumed"):
        stream.collect()