Coverage for tests / unit / ai / test_registry.py: 100%
82 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-04-03 18:53 +0000
1"""Tests for the AI provider registry."""
3from __future__ import annotations
5from dataclasses import FrozenInstanceError, asdict
7import pytest
8from assertpy import assert_that
10from lintro.ai.registry import (
11 DEFAULT_PRICING,
12 PROVIDERS,
13 AIProvider,
14 ModelPricing,
15 ProviderInfo,
16)
18# -- AIProvider StrEnum ----------------------------------------------------
def test_aiprovider_members():
    """Each known provider member compares equal to its lowercase name."""
    expected_values = {
        AIProvider.ANTHROPIC: "anthropic",
        AIProvider.OPENAI: "openai",
    }
    for member, raw_value in expected_values.items():
        assert_that(member).is_equal_to(raw_value)
def test_aiprovider_is_str():
    """Every AIProvider member is a ``str`` instance (StrEnum contract)."""
    # Collect any offenders so a failure lists them all at once.
    non_strings = [provider for provider in AIProvider if not isinstance(provider, str)]
    assert_that(non_strings).is_empty()
def test_aiprovider_iteration():
    """Iterating AIProvider yields at least the two known providers."""
    listed = tuple(AIProvider)
    assert_that(listed).contains(AIProvider.ANTHROPIC, AIProvider.OPENAI)
    assert_that(len(listed)).is_greater_than_or_equal_to(2)
def test_aiprovider_from_string():
    """Calling AIProvider with a raw string returns the matching member."""
    cases = (
        ("anthropic", AIProvider.ANTHROPIC),
        ("openai", AIProvider.OPENAI),
    )
    for raw, member in cases:
        assert_that(AIProvider(raw)).is_equal_to(member)
def test_aiprovider_invalid_value_raises():
    """An unrecognised provider name is rejected with ValueError."""
    unknown_name = "gemini"
    with pytest.raises(ValueError, match="not a valid"):
        AIProvider(unknown_name)
52# -- ModelPricing ----------------------------------------------------------
def test_model_pricing_fields():
    """Both per-million rates given to ModelPricing are stored verbatim."""
    rates = ModelPricing(output_per_million=15.00, input_per_million=3.00)
    assert_that(rates.input_per_million).is_equal_to(3.00)
    assert_that(rates.output_per_million).is_equal_to(15.00)
def test_model_pricing_frozen():
    """Assigning to a ModelPricing field raises FrozenInstanceError."""
    frozen_pricing = ModelPricing(1.0, 2.0)
    with pytest.raises(FrozenInstanceError):
        # setattr performs the same assignment without needing a type: ignore.
        setattr(frozen_pricing, "input_per_million", 999.0)
69# -- ProviderInfo ----------------------------------------------------------
def test_provider_info_fields():
    """Constructor arguments are exposed unchanged as ProviderInfo attributes."""
    model_name = "test-model"
    provider_info = ProviderInfo(
        default_model=model_name,
        default_api_key_env="TEST_KEY",
        models={model_name: ModelPricing(1.0, 2.0)},
    )
    assert_that(provider_info.default_model).is_equal_to(model_name)
    assert_that(provider_info.default_api_key_env).is_equal_to("TEST_KEY")
    assert_that(provider_info.models).contains_key(model_name)
def test_provider_info_default_models_empty():
    """ProviderInfo.models defaults to a fresh, empty dict for each instance."""
    first = ProviderInfo(default_model="m", default_api_key_env="K")
    second = ProviderInfo(default_model="m", default_api_key_env="K")
    # is_empty() alone would also accept () or []; pin the documented type.
    assert_that(first.models).is_instance_of(dict).is_empty()
    # A single shared default dict would leak models between providers;
    # each instance must receive its own (dataclass default_factory semantics).
    assert_that(first.models).is_not_same_as(second.models)
90# -- AIProviderRegistry ----------------------------------------------------
def test_registry_items():
    """items() yields one (provider, info) pair per AIProvider member."""
    pairs = list(PROVIDERS.items())
    assert_that(pairs).is_length(len(AIProvider))
    registered = [provider for provider, _info in pairs]
    assert_that(registered).contains(AIProvider.ANTHROPIC, AIProvider.OPENAI)
def test_registry_get():
    """get() returns the very ProviderInfo stored on the matching attribute."""
    expectations = (
        (AIProvider.ANTHROPIC, PROVIDERS.anthropic),
        (AIProvider.OPENAI, PROVIDERS.openai),
    )
    for provider, stored_info in expectations:
        assert_that(PROVIDERS.get(provider)).is_same_as(stored_info)
def test_registry_model_pricing_contains_all_models():
    """model_pricing merges every model, with its own pricing, from all providers."""
    pricing = PROVIDERS.model_pricing
    for _provider, info in PROVIDERS.items():
        for model_name, model_price in info.models.items():
            assert_that(pricing).contains_key(model_name)
            # Key presence alone would miss a merge that overwrote one
            # provider's entry with another provider's pricing.
            assert_that(pricing[model_name]).is_equal_to(model_price)
def test_registry_model_pricing_values_are_model_pricing():
    """model_pricing is non-empty and every value is a ModelPricing instance."""
    pricing = PROVIDERS.model_pricing
    # Guard against a vacuous pass: an empty mapping would skip the loop below.
    assert_that(pricing).is_not_empty()
    for p in pricing.values():
        assert_that(p).is_instance_of(ModelPricing)
def test_registry_default_models():
    """default_models maps each AIProvider to that provider's default model."""
    defaults = PROVIDERS.default_models
    # Cover every member rather than a hard-coded subset, so a newly added
    # provider missing from default_models is caught.
    for provider in AIProvider:
        assert_that(defaults).contains_key(provider)
        assert_that(defaults[provider]).is_equal_to(
            PROVIDERS.get(provider).default_model,
        )
    for model in defaults.values():
        assert_that(model).is_instance_of(str)
def test_registry_default_api_key_envs():
    """default_api_key_envs points each provider at its conventional env var."""
    envs = PROVIDERS.default_api_key_envs
    expected_envs = {
        AIProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
        AIProvider.OPENAI: "OPENAI_API_KEY",
    }
    for provider in expected_envs:
        assert_that(envs).contains_key(provider)
    for provider, env_name in expected_envs.items():
        assert_that(envs[provider]).is_equal_to(env_name)
def test_registry_default_model_in_provider_models():
    """Each provider's default model is one of the models it prices."""
    for entry in PROVIDERS.items():
        provider_info = entry[1]
        assert_that(provider_info.models).contains_key(provider_info.default_model)
147# -- asdict ----------------------------------------------------------------
def test_asdict_produces_nested_dict():
    """asdict(PROVIDERS) produces a correct nested dictionary for every provider."""
    d = asdict(PROVIDERS)
    assert_that(d).contains_key("anthropic", "openai")
    # Check every provider, not just the first, against its live ProviderInfo.
    for key, info in (("anthropic", PROVIDERS.anthropic), ("openai", PROVIDERS.openai)):
        provider_dict = d[key]
        assert_that(provider_dict).contains_key(
            "default_model",
            "default_api_key_env",
            "models",
        )
        # Scalar fields survive the conversion unchanged.
        assert_that(provider_dict["default_model"]).is_equal_to(info.default_model)
        assert_that(provider_dict["default_api_key_env"]).is_equal_to(
            info.default_api_key_env,
        )
        # Models are nested dicts with pricing fields, one per ModelPricing.
        assert_that(provider_dict["models"]).is_length(len(info.models))
        for pricing in provider_dict["models"].values():
            assert_that(pricing).contains_key(
                "input_per_million",
                "output_per_million",
            )
168# -- DEFAULT_PRICING -------------------------------------------------------
def test_default_pricing_is_model_pricing():
    """The DEFAULT_PRICING fallback is itself a ModelPricing."""
    fallback = DEFAULT_PRICING
    assert_that(fallback).is_instance_of(ModelPricing)
def test_default_pricing_values():
    """DEFAULT_PRICING falls back to 3.00 input / 15.00 output per million."""
    rates = (DEFAULT_PRICING.input_per_million, DEFAULT_PRICING.output_per_million)
    assert_that(rates).is_equal_to((3.00, 15.00))