Coverage for tests/unit/ai/providers/test_stream.py: 99%
84 statements
« prev ^ index » next — coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
"""Tests for AIStreamResult and BaseAIProvider.stream_complete()."""

from __future__ import annotations

import pytest
from assertpy import assert_that

from lintro.ai.providers.base import AIResponse, AIStreamResult, BaseAIProvider
from lintro.ai.providers.constants import DEFAULT_PER_CALL_MAX_TOKENS, DEFAULT_TIMEOUT
class _StubProvider(BaseAIProvider):
    """Concrete no-op provider used to exercise the base class defaults.

    complete() ignores its arguments and always hands back the canned
    AIResponse supplied at construction time.
    """

    def __init__(self, response: AIResponse) -> None:
        # Canned payload returned by every complete() call.
        self._response = response
        # Attributes the BaseAIProvider machinery expects to find.
        self._client = "fake"
        self._has_sdk = True
        self._provider_name = "stub"
        self._model = "stub-model"
        self._api_key_env = "STUB_KEY"
        self._base_url = None
        self._max_tokens = DEFAULT_PER_CALL_MAX_TOKENS

    def _create_client(self, *, api_key: str) -> object:
        # No real SDK involved; a placeholder object suffices.
        return "fake"

    def complete(
        self,
        prompt: str,
        *,
        system: str | None = None,
        max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AIResponse:
        # All arguments are intentionally ignored.
        return self._response
def _make_response(content: str = "hello world") -> AIResponse:
    """Build an AIResponse carrying fixed test metadata and *content*."""
    metadata: dict[str, object] = {
        "model": "test-model",
        "input_tokens": 10,
        "output_tokens": 5,
        "cost_estimate": 0.001,
        "provider": "test",
    }
    return AIResponse(content=content, **metadata)
def test_stream_result_iter_yields_chunks() -> None:
    """Iterating an AIStreamResult yields all provided chunks."""
    stream = AIStreamResult(
        _chunks=iter(["foo", "bar", "baz"]),
        _on_done=lambda: _make_response("foobarbaz"),
    )

    produced = [piece for piece in stream]

    assert_that(produced).is_equal_to(["foo", "bar", "baz"])
def test_stream_result_response_returns_metadata() -> None:
    """response() returns the AIResponse supplied by _on_done."""
    expected = _make_response()
    stream = AIStreamResult(_chunks=iter([]), _on_done=lambda: expected)

    # Drain the (empty) stream so the result is finalized.
    for _ in stream:
        pass

    assert_that(stream.response()).is_equal_to(expected)
def test_stream_result_collect_concatenates_and_returns_response() -> None:
    """collect() joins chunks and populates content in the returned AIResponse."""
    meta = _make_response("")
    stream = AIStreamResult(
        _chunks=iter(["alpha", " ", "beta"]),
        _on_done=lambda: meta,
    )

    outcome = stream.collect()

    assert_that(outcome.content).is_equal_to("alpha beta")
    # Metadata fields are carried through from the _on_done response.
    for attr, want in (
        ("model", "test-model"),
        ("input_tokens", 10),
        ("output_tokens", 5),
        ("provider", "test"),
    ):
        assert_that(getattr(outcome, attr)).is_equal_to(want)
def test_stream_result_collect_empty_stream() -> None:
    """collect() on a chunk-less stream produces empty content."""
    stream = AIStreamResult(
        _chunks=iter([]),
        _on_done=lambda: _make_response(""),
    )

    collected = stream.collect()

    assert_that(collected.content).is_equal_to("")
@pytest.mark.parametrize(
    ("chunks", "expected"),
    [
        (["a"], "a"),
        (["a", "b", "c"], "abc"),
        ([""], ""),
        (["hello ", "world"], "hello world"),
    ],
    ids=["single", "multi", "empty-chunk", "with-space"],
)
def test_stream_result_collect_various_chunk_patterns(
    chunks: list[str],
    expected: str,
) -> None:
    """collect() joins every chunk pattern into the expected content string."""
    meta = _make_response("")
    stream = AIStreamResult(_chunks=iter(chunks), _on_done=lambda: meta)

    joined = stream.collect().content

    assert_that(joined).is_equal_to(expected)
def test_base_provider_stream_complete_delegates_to_complete() -> None:
    """Default stream_complete wraps complete() in a single-chunk stream."""
    stub = _StubProvider(response=_make_response("delegated content"))

    final = stub.stream_complete("test prompt").collect()

    assert_that(final.content).is_equal_to("delegated content")
    assert_that(final.model).is_equal_to("test-model")
    assert_that(final.provider).is_equal_to("test")
def test_base_provider_stream_complete_passes_kwargs() -> None:
    """Default stream_complete forwards system/max_tokens/timeout to complete."""
    seen: list[dict[str, object]] = []

    class _RecordingProvider(_StubProvider):
        # Signature mirrors the base complete() contract exactly.
        def complete(
            self,
            prompt: str,
            *,
            system: str | None = None,
            max_tokens: int = DEFAULT_PER_CALL_MAX_TOKENS,
            timeout: float = DEFAULT_TIMEOUT,
        ) -> AIResponse:
            seen.append(
                {
                    "prompt": prompt,
                    "system": system,
                    "max_tokens": max_tokens,
                    "timeout": timeout,
                },
            )
            return _make_response()

    provider = _RecordingProvider(response=_make_response())
    stream = provider.stream_complete(
        "my prompt",
        system="sys",
        max_tokens=512,
        timeout=30,
    )
    # Consuming the stream is what triggers the delegated complete() call.
    for _ in stream:
        pass

    assert_that(seen).is_length(1)
    recorded = seen[0]
    assert_that(recorded["prompt"]).is_equal_to("my prompt")
    assert_that(recorded["system"]).is_equal_to("sys")
    assert_that(recorded["max_tokens"]).is_equal_to(512)
    assert_that(recorded["timeout"]).is_equal_to(30)
def test_base_provider_stream_complete_single_chunk_iteration() -> None:
    """Default stream_complete yields exactly one chunk with the full content."""
    provider = _StubProvider(response=_make_response("one shot"))

    emitted = [piece for piece in provider.stream_complete("p")]

    assert_that(emitted).is_equal_to(["one shot"])
def test_collect_raises_on_double_call() -> None:
    """collect() raises RuntimeError when called a second time."""
    meta = _make_response("")
    stream = AIStreamResult(
        _chunks=iter(["alpha", " ", "beta"]),
        _on_done=lambda: meta,
    )

    # The first collect() drains the chunks and succeeds.
    assert_that(stream.collect().content).is_equal_to("alpha beta")

    # A repeat call must fail: the underlying iterator is spent.
    with pytest.raises(RuntimeError, match="already consumed"):
        stream.collect()