Coverage for tests/unit/ai/providers/test_openai.py: 100%
116 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
1"""Tests for OpenAI AI provider."""
3from __future__ import annotations
5from unittest.mock import MagicMock, patch
7import pytest
8from assertpy import assert_that
10from lintro.ai.exceptions import (
11 AIAuthenticationError,
12 AINotAvailableError,
13)
14from lintro.ai.providers import openai as mod
15from lintro.ai.providers.openai import OpenAIProvider
def test_openai_provider_raises_when_sdk_missing():
    """Constructing OpenAIProvider fails when the openai SDK is unavailable.

    The module-level ``_has_openai`` flag is forced to False so the outcome
    does not depend on whether the SDK happens to be installed.
    """
    sdk_missing = patch.object(mod, "_has_openai", False)
    with sdk_missing, pytest.raises(AINotAvailableError):
        OpenAIProvider()
def test_openai_provider_default_model():
    """A provider built with no arguments targets gpt-4o under the name 'openai'."""
    with patch.object(mod, "_has_openai", True):
        provider = OpenAIProvider()

    expected_model, expected_name = "gpt-4o", "openai"
    assert_that(provider.model_name).is_equal_to(expected_model)
    assert_that(provider.name).is_equal_to(expected_name)
def test_openai_provider_is_available_with_no_key():
    """is_available() is False when the configured API-key variable is unset."""
    with patch.object(mod, "_has_openai", True):
        provider = OpenAIProvider()
    # Point the provider at a variable the cleared environment cannot contain.
    provider._api_key_env = "NONEXISTENT_KEY_VAR"

    with patch.dict("os.environ", {}, clear=True):
        available = provider.is_available()
    assert_that(available).is_false()
def test_openai_provider_is_available_with_key():
    """is_available() is True once the configured API-key variable is set."""
    with patch.object(mod, "_has_openai", True):
        provider = OpenAIProvider()
    provider._api_key_env = "TEST_API_KEY"

    # Added on top of the real environment (no clear=True).
    env_with_key = {"TEST_API_KEY": "sk-test"}
    with patch.dict("os.environ", env_with_key):
        available = provider.is_available()
    assert_that(available).is_true()
def test_openai_provider_get_client_no_key_raises():
    """_get_client raises AIAuthenticationError if the key variable is unset."""
    with patch.object(mod, "_has_openai", True):
        provider = OpenAIProvider()
    provider._api_key_env = "NONEXISTENT_KEY"

    empty_env = patch.dict("os.environ", {}, clear=True)
    with empty_env, pytest.raises(AIAuthenticationError):
        provider._get_client()
def test_openai_complete_parses_response():
    """complete() maps a chat-completions response onto the result object.

    Content, token counts and provider name are checked exactly. The cost
    estimate only has to be non-negative. The outgoing request must list the
    system prompt before the user prompt.
    """
    with patch.object(mod, "_has_openai", True):
        provider = OpenAIProvider()
    provider._api_key_env = "TEST_KEY"

    # Fake response exposing choices[0].message.content and
    # usage.prompt_tokens / usage.completion_tokens.
    fake_response = MagicMock(
        choices=[MagicMock(message=MagicMock(content="Hello from GPT!"))],
        usage=MagicMock(prompt_tokens=200, completion_tokens=80),
    )
    fake_client = MagicMock()
    create = fake_client.chat.completions.create
    create.return_value = fake_response
    # Pre-seed the client so the provider presumably reuses it instead of
    # building a real SDK client.
    provider._client = fake_client

    with patch.dict("os.environ", {"TEST_KEY": "sk-test"}):
        result = provider.complete("test prompt", system="be helpful")

    assert_that(result.content).is_equal_to("Hello from GPT!")
    assert_that(result.input_tokens).is_equal_to(200)
    assert_that(result.output_tokens).is_equal_to(80)
    assert_that(result.provider).is_equal_to("openai")
    assert_that(result.cost_estimate).is_greater_than_or_equal_to(0.0)

    expected_messages = [
        {"role": "system", "content": "be helpful"},
        {"role": "user", "content": "test prompt"},
    ]
    sent = create.call_args.kwargs
    assert_that(sent["messages"]).is_equal_to(expected_messages)
def test_openai_complete_without_system_prompt():
    """Without a system prompt, complete() sends only the user message."""
    with patch.object(mod, "_has_openai", True):
        provider = OpenAIProvider()

    fake_client = MagicMock()
    create = fake_client.chat.completions.create
    create.return_value = MagicMock(
        choices=[MagicMock(message=MagicMock(content="response"))],
        usage=MagicMock(prompt_tokens=10, completion_tokens=5),
    )
    provider._client = fake_client

    with patch.dict("os.environ", {"OPENAI_API_KEY": "sk-test"}):
        provider.complete("prompt")

    sent = create.call_args.kwargs
    only_user = [{"role": "user", "content": "prompt"}]
    assert_that(sent["messages"]).is_equal_to(only_user)
def test_openai_complete_handles_none_usage():
    """A response whose ``usage`` is None yields zero input/output tokens."""
    with patch.object(mod, "_has_openai", True):
        provider = OpenAIProvider()

    fake_client = MagicMock()
    fake_client.chat.completions.create.return_value = MagicMock(
        choices=[MagicMock(message=MagicMock(content="response"))],
        usage=None,  # the SDK reported no usage block
    )
    provider._client = fake_client

    with patch.dict("os.environ", {"OPENAI_API_KEY": "sk-test"}):
        result = provider.complete("prompt")

    for token_count in (result.input_tokens, result.output_tokens):
        assert_that(token_count).is_equal_to(0)
def test_openai_complete_respects_max_tokens_cap():
    """complete() uses the lower of per-call and provider-level max_tokens.

    Both directions of that rule are exercised. A per-call limit above the
    provider cap (2048) is clamped to the cap. A per-call limit below the cap
    is sent unchanged.
    """
    with patch.object(mod, "_has_openai", True):
        provider = OpenAIProvider(max_tokens=2048)

    mock_message = MagicMock()
    mock_message.content = "ok"
    mock_choice = MagicMock()
    mock_choice.message = mock_message
    mock_usage = MagicMock()
    mock_usage.prompt_tokens = 10
    mock_usage.completion_tokens = 5

    mock_response = MagicMock()
    mock_response.choices = [mock_choice]
    mock_response.usage = mock_usage

    mock_client = MagicMock()
    mock_client.chat.completions.create.return_value = mock_response
    provider._client = mock_client

    create = mock_client.chat.completions.create
    with patch.dict("os.environ", {"OPENAI_API_KEY": "sk-test"}):
        # Per-call limit above the provider cap -> clamped down to the cap.
        provider.complete("prompt", max_tokens=4096)
        clamped_kwargs = create.call_args[1]

        # Per-call limit below the provider cap -> passed through as-is.
        # Without this case, an implementation that always sent the
        # provider cap would still pass the test.
        provider.complete("prompt", max_tokens=512)
        passthrough_kwargs = create.call_args[1]

    assert_that(clamped_kwargs["max_tokens"]).is_equal_to(2048)
    assert_that(passthrough_kwargs["max_tokens"]).is_equal_to(512)