Coverage for tests / unit / ai / test_registry.py: 100%

82 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-04-03 18:53 +0000

1"""Tests for the AI provider registry.""" 

2 

3from __future__ import annotations 

4 

5from dataclasses import FrozenInstanceError, asdict 

6 

7import pytest 

8from assertpy import assert_that 

9 

10from lintro.ai.registry import ( 

11 DEFAULT_PRICING, 

12 PROVIDERS, 

13 AIProvider, 

14 ModelPricing, 

15 ProviderInfo, 

16) 

17 

18# -- AIProvider StrEnum ---------------------------------------------------- 

19 

20 

def test_aiprovider_members():
    """Each known AIProvider member compares equal to its lowercase name."""
    expected_values = {
        AIProvider.ANTHROPIC: "anthropic",
        AIProvider.OPENAI: "openai",
    }
    for member, raw_value in expected_values.items():
        assert_that(member).is_equal_to(raw_value)

25 

26 

def test_aiprovider_is_str():
    """Every AIProvider member is also a plain ``str`` instance."""
    # Collect offenders so a failure names exactly which members are not str.
    non_str_members = [
        provider for provider in AIProvider if not isinstance(provider, str)
    ]
    assert_that(non_str_members).is_empty()

31 

32 

def test_aiprovider_iteration():
    """Iterating the AIProvider class yields at least the two known members."""
    provider_list = [*AIProvider]
    member_count = len(provider_list)
    assert_that(member_count).is_greater_than_or_equal_to(2)
    assert_that(provider_list).contains(AIProvider.ANTHROPIC, AIProvider.OPENAI)

38 

39 

def test_aiprovider_from_string():
    """Calling AIProvider with a raw string returns the matching member."""
    round_trip_cases = (
        ("anthropic", AIProvider.ANTHROPIC),
        ("openai", AIProvider.OPENAI),
    )
    for raw_value, expected_member in round_trip_cases:
        assert_that(AIProvider(raw_value)).is_equal_to(expected_member)

44 

45 

def test_aiprovider_invalid_value_raises():
    """An unrecognised provider name is rejected with ValueError."""
    unknown_provider = "gemini"
    with pytest.raises(ValueError, match="not a valid"):
        AIProvider(unknown_provider)

50 

51 

52# -- ModelPricing ---------------------------------------------------------- 

53 

54 

def test_model_pricing_fields():
    """Keyword-constructed ModelPricing exposes both per-million rates."""
    pricing = ModelPricing(input_per_million=3.00, output_per_million=15.00)
    stored_rates = (pricing.input_per_million, pricing.output_per_million)
    assert_that(stored_rates).is_equal_to((3.00, 15.00))

60 

61 

def test_model_pricing_frozen():
    """Assigning to a ModelPricing field raises FrozenInstanceError."""
    pricing = ModelPricing(1.0, 2.0)
    with pytest.raises(FrozenInstanceError):
        # setattr performs the same assignment without needing a type: ignore.
        setattr(pricing, "input_per_million", 999.0)

67 

68 

69# -- ProviderInfo ---------------------------------------------------------- 

70 

71 

def test_provider_info_fields():
    """ProviderInfo keeps the model name, key env var and model table it is given."""
    model_table = {"test-model": ModelPricing(1.0, 2.0)}
    info = ProviderInfo(
        default_model="test-model",
        default_api_key_env="TEST_KEY",
        models=model_table,
    )

    assert_that(info.default_model).is_equal_to("test-model")
    assert_that(info.default_api_key_env).is_equal_to("TEST_KEY")
    assert_that(info.models).contains_key("test-model")

82 

83 

def test_provider_info_default_models_empty():
    """Omitting ``models`` leaves ProviderInfo with an empty model table."""
    bare_info = ProviderInfo(default_model="m", default_api_key_env="K")
    default_models = bare_info.models
    assert_that(default_models).is_empty()

88 

89 

90# -- AIProviderRegistry ---------------------------------------------------- 

91 

92 

def test_registry_items():
    """items() yields every AIProvider member exactly once."""
    items = list(PROVIDERS.items())
    # len() works directly on an Enum class; no throwaway list is needed.
    assert_that(items).is_length(len(AIProvider))
    providers = [provider for provider, _ in items]
    # Compare against the full member set rather than two hard-coded members
    # so a newly added provider missing from the registry is caught. Combined
    # with the length check above, this also rules out duplicate entries.
    assert_that(set(providers)).is_equal_to(set(AIProvider))

99 

100 

def test_registry_get():
    """get(provider) returns the very ProviderInfo stored on the registry."""
    expected_infos = {
        AIProvider.ANTHROPIC: PROVIDERS.anthropic,
        AIProvider.OPENAI: PROVIDERS.openai,
    }
    for provider, stored_info in expected_infos.items():
        assert_that(PROVIDERS.get(provider)).is_same_as(stored_info)

107 

108 

def test_registry_model_pricing_contains_all_models():
    """model_pricing merges every model from all providers with its pricing."""
    pricing = PROVIDERS.model_pricing
    for _provider, info in PROVIDERS.items():
        for model_name, model_pricing in info.models.items():
            assert_that(pricing).contains_key(model_name)
            # Presence alone is not enough: the merged entry must carry the
            # owning provider's rates, not a placeholder or another value.
            assert_that(pricing[model_name]).is_equal_to(model_pricing)

115 

116 

def test_registry_model_pricing_values_are_model_pricing():
    """No entry in the merged model_pricing table has a non-ModelPricing value."""
    # Gather offending model names so a failure reports which entries are wrong.
    mistyped_models = [
        model_name
        for model_name, entry in PROVIDERS.model_pricing.items()
        if not isinstance(entry, ModelPricing)
    ]
    assert_that(mistyped_models).is_empty()

121 

122 

def test_registry_default_models():
    """default_models maps each AIProvider to a string."""
    defaults = PROVIDERS.default_models
    # Iterate every member instead of a hard-coded subset so a newly added
    # provider without a default model fails this test.
    for provider in AIProvider:
        assert_that(defaults).contains_key(provider)
    for model in defaults.values():
        assert_that(model).is_instance_of(str)

130 

131 

def test_registry_default_api_key_envs():
    """default_api_key_envs maps each AIProvider to a string."""
    envs = PROVIDERS.default_api_key_envs
    # Cover every member (not just two hard-coded ones) so a new provider
    # lacking an API-key env var is caught.
    for provider in AIProvider:
        assert_that(envs).contains_key(provider)
        assert_that(envs[provider]).is_instance_of(str)
    # Pin the exact env var names for the known providers.
    assert_that(envs[AIProvider.ANTHROPIC]).is_equal_to("ANTHROPIC_API_KEY")
    assert_that(envs[AIProvider.OPENAI]).is_equal_to("OPENAI_API_KEY")

139 

140 

def test_registry_default_model_in_provider_models():
    """No provider names a default model that is absent from its own table."""
    # Collect providers whose default model is missing, for a clearer failure.
    providers_missing_default = [
        provider
        for provider, info in PROVIDERS.items()
        if info.default_model not in info.models
    ]
    assert_that(providers_missing_default).is_empty()

145 

146 

147# -- asdict ---------------------------------------------------------------- 

148 

149 

def test_asdict_produces_nested_dict():
    """asdict(PROVIDERS) produces a correct nested dictionary for each provider."""
    d = asdict(PROVIDERS)
    assert_that(d).contains_key("anthropic", "openai")
    # Check the nested structure of every provider entry, not only the first.
    for provider_name in ("anthropic", "openai"):
        provider_info = d[provider_name]
        assert_that(provider_info).contains_key(
            "default_model",
            "default_api_key_env",
            "models",
        )
        # Models are nested dicts with pricing fields.
        for pricing in provider_info["models"].values():
            assert_that(pricing).contains_key(
                "input_per_million",
                "output_per_million",
            )

166 

167 

168# -- DEFAULT_PRICING ------------------------------------------------------- 

169 

170 

def test_default_pricing_is_model_pricing():
    """The fallback DEFAULT_PRICING constant is itself a ModelPricing."""
    fallback = DEFAULT_PRICING
    assert_that(isinstance(fallback, ModelPricing)).is_true()

174 

175 

def test_default_pricing_values():
    """DEFAULT_PRICING falls back to 3.00 input / 15.00 output per million."""
    fallback_rates = (
        DEFAULT_PRICING.input_per_million,
        DEFAULT_PRICING.output_per_million,
    )
    assert_that(fallback_rates).is_equal_to((3.00, 15.00))