Coverage for lintro / ai / registry.py: 100%

33 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-04-03 18:53 +0000

1"""AI provider registry — single source of truth for provider metadata. 

2 

3Consolidates model pricing, default models, and API key environment 

4variables into a frozen dataclass hierarchy keyed by an ``AIProvider`` 

5StrEnum. Every piece of provider metadata lives here; downstream 

6modules import what they need rather than maintaining parallel dicts. 

7 

8The ``AIProvider`` enum is defined in :mod:`lintro.ai.provider_enum`, and 

9the ``ModelPricing`` and ``ProviderInfo`` dataclasses are defined in 

10:mod:`lintro.ai.provider_info`; all three are re-exported here for 

11convenience. 

12""" 

13 

14from __future__ import annotations 

15 

16from collections.abc import Iterator 

17from dataclasses import dataclass, field 

18 

19from lintro.ai.provider_enum import AIProvider 

20from lintro.ai.provider_info import ModelPricing, ProviderInfo 

21 

22__all__ = [ 

23 "AIProvider", 

24 "AIProviderRegistry", 

25 "DEFAULT_PRICING", 

26 "ModelPricing", 

27 "PROVIDERS", 

28 "ProviderInfo", 

29] 

30 

31# -- Registry class -------------------------------------------------------- 

32 

33 

@dataclass(frozen=True)
class AIProviderRegistry:
    """Frozen registry of all supported AI providers.

    Access individual providers via attribute (``registry.anthropic``),
    look one up by enum member with :meth:`get`, or iterate over all of
    them with :meth:`items`.

    Each provider field's name must equal the ``value`` of the matching
    ``AIProvider`` member, because :meth:`items` and :meth:`get` resolve
    providers with ``getattr(self, provider.value)``.
    """

    anthropic: ProviderInfo
    openai: ProviderInfo
    # Flattened model -> pricing map, filled in by __post_init__. It is pure
    # derived state, so it is excluded from the generated __eq__/__hash__
    # (compare=False): equality is fully decided by the provider fields, and
    # a dict field would otherwise be fed into the generated hash.
    _cached_model_pricing: dict[str, ModelPricing] = field(
        default_factory=dict,
        init=False,
        repr=False,
        compare=False,
    )

    def __post_init__(self) -> None:
        """Pre-compute the flattened model -> pricing mapping.

        Providers are merged in ``AIProvider`` iteration order, so if two
        providers ever list the same model name, the later provider's
        pricing silently wins.
        """
        pricing: dict[str, ModelPricing] = {}
        for _provider, info in self.items():
            pricing.update(info.models)
        # The dataclass is frozen; bypass its __setattr__ guard to store
        # the derived cache exactly once at construction time.
        object.__setattr__(self, "_cached_model_pricing", pricing)

    def items(self) -> Iterator[tuple[AIProvider, ProviderInfo]]:
        """Yield ``(AIProvider, ProviderInfo)`` pairs.

        Pairs are produced in ``AIProvider`` iteration order.
        """
        for provider in AIProvider:
            yield provider, getattr(self, provider.value)

    def get(self, provider: AIProvider) -> ProviderInfo:
        """Look up a provider by enum member.

        Args:
            provider: The provider to look up.

        Returns:
            ProviderInfo for the requested provider.

        Raises:
            AttributeError: If the registry has no field named
                ``provider.value``.
        """
        info: ProviderInfo = getattr(self, provider.value)
        return info

    @property
    def model_pricing(self) -> dict[str, ModelPricing]:
        """Flat mapping of every known model to its pricing.

        A fresh shallow copy is returned on each access, so callers may
        mutate the result without affecting the registry's cache.
        """
        return dict(self._cached_model_pricing)

    @property
    def default_models(self) -> dict[AIProvider, str]:
        """Map each provider to its default model identifier."""
        return {p: info.default_model for p, info in self.items()}

    @property
    def default_api_key_envs(self) -> dict[AIProvider, str]:
        """Map each provider to its default API-key env var."""
        return {p: info.default_api_key_env for p, info in self.items()}

88 

89 

90# -- Singleton instance ---------------------------------------------------- 

91 

# Module-level registry instance. ModelPricing entries below are passed
# positionally, presumably as (input_per_million, output_per_million) to
# match the keyword names used for DEFAULT_PRICING -- confirm against
# lintro.ai.provider_info if that dataclass is ever reordered.
PROVIDERS = AIProviderRegistry(
    anthropic=ProviderInfo(
        default_model="claude-sonnet-4-6",
        default_api_key_env="ANTHROPIC_API_KEY",
        models={
            "claude-sonnet-4-6": ModelPricing(3.00, 15.00),
            "claude-sonnet-4-20250514": ModelPricing(3.00, 15.00),
            # Claude Haiku 4.5 lists at $1 input / $5 output per million
            # tokens; 0.80 / 4.00 are Claude 3.5 Haiku's rates and would
            # understate cost estimates for this model.
            "claude-haiku-4-5-20251001": ModelPricing(1.00, 5.00),
            "claude-opus-4-20250514": ModelPricing(15.00, 75.00),
        },
    ),
    openai=ProviderInfo(
        default_model="gpt-4o",
        default_api_key_env="OPENAI_API_KEY",
        models={
            "gpt-4o": ModelPricing(2.50, 10.00),
            "gpt-4o-mini": ModelPricing(0.15, 0.60),
            "gpt-4-turbo": ModelPricing(10.00, 30.00),
            "o1": ModelPricing(15.00, 60.00),
            "o1-mini": ModelPricing(1.10, 4.40),
        },
    ),
)

115 

# Pricing assumed for any model that has no entry in ``PROVIDERS``; the
# figures mirror the Sonnet-tier rates listed in the registry above.
DEFAULT_PRICING = ModelPricing(
    output_per_million=15.00,
    input_per_million=3.00,
)